# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md

from typing import List

import torch
from enc.component.intercoding.warp import warp_fn
from enc.io.format.yuv import convert_420_to_444, rgb2yuv, yuv2rgb
from enc.io.framedata import FrameData
from enc.training.loss import _compute_mse
from torchvision.models.optical_flow import Raft_Large_Weights, raft_large
@torch.no_grad()
def get_raft_optical_flow(
    frame_data: FrameData, ref_data: List[FrameData]
) -> List[torch.Tensor]:
    """Apply a RAFT model to the frame and each of its references. Return
    the RAFT-estimated optical flows.

    Args:
        frame_data: Data for the frame to be encoded.
        ref_data: Data of the references.

    Returns:
        List[Tensor]: List of optical flows, each with a shape of
        [1, 2, H, W].
    """
    print("\nRaft-based motion estimation")
    print("----------------------------\n")

    # Store the old device to move the flows etc. back after the RAFT
    if frame_data.frame_data_type == "yuv420":
        old_device = frame_data.data.get("y").device
    else:
        old_device = frame_data.data.device

    # RAFT at full resolution is memory hungry: run it on the GPU only if
    # at least 24 GB of total memory is available, otherwise fall back to CPU.
    if torch.cuda.is_available():
        total_gpu_mem_gbytes = torch.cuda.mem_get_info()[1] / 1e9
        _RAFT_DEVICE = "cpu" if total_gpu_mem_gbytes < 24 else "cuda:0"
    else:
        _RAFT_DEVICE = "cpu"

    # Get a pre-trained RAFT model to guide the coolchic encoder training
    # with a good optical flow
    raft = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False)
    raft = raft.to(_RAFT_DEVICE)
    raft.eval()

    flows = []
    for i, ref_data_i in enumerate(ref_data):
        raw_ref_data = ref_data_i.data
        raw_target_data = frame_data.data

        # RAFT expects dense frames, not 420 format.
        if ref_data_i.frame_data_type == "yuv420":
            raw_ref_data = convert_420_to_444(raw_ref_data)
        if frame_data.frame_data_type == "yuv420":
            raw_target_data = convert_420_to_444(raw_target_data)

        # RAFT expects RGB data in [-1., 1.]
        if ref_data_i.frame_data_type != "rgb":
            raw_ref_data = yuv2rgb(raw_ref_data * 255.0) / 255.0
        if frame_data.frame_data_type != "rgb":
            raw_target_data = yuv2rgb(raw_target_data * 255.0) / 255.0

        # Convert the data from [0, 1] -> [-1, 1]
        raw_target_data = raw_target_data * 2 - 1
        raw_ref_data = raw_ref_data * 2 - 1

        # Send from old_device -> _RAFT_DEVICE prior to using RAFT.
        # After RAFT, send back from _RAFT_DEVICE -> old_device.
        raw_target_data = raw_target_data.to(_RAFT_DEVICE)
        raw_ref_data = raw_ref_data.to(_RAFT_DEVICE)
        # RAFT returns the list of iteratively refined flows: keep the last
        # (i.e. the most refined) one.
        raft_flow = raft(raw_target_data, raw_ref_data)[-1]
        raw_target_data = raw_target_data.to(old_device)
        raw_ref_data = raw_ref_data.to(old_device)
        raft_flow = raft_flow.to(old_device)

        # Send back the data from [-1, 1] -> [0, 1]
        raw_target_data = (raw_target_data + 1) * 0.5
        raw_ref_data = (raw_ref_data + 1) * 0.5

        flows.append(raft_flow)
        raft_pred = warp_fn(raw_ref_data, raft_flow)

        # Compute the PSNR of the prediction obtained by RAFT, i.e.
        # -10 log10(MSE) for signals in [0, 1]; the 1e-10 avoids log10(0).
        # Note that we compute this prediction in the 444 domain even though
        # we have a 420 frame.
        raft_pred_psnr = -10 * torch.log10(
            _compute_mse(
                rgb2yuv(raft_pred * 255.0) / 255.0,
                rgb2yuv(raw_target_data * 255.0) / 255.0,
            )
            + 1e-10
        )

        # Just as a reference, compute the PSNR when the prediction is the
        # reference itself, i.e. optical flow = 0.
        dummy_pred_psnr = -10 * torch.log10(
            _compute_mse(
                rgb2yuv(raw_ref_data * 255.0) / 255.0,
                rgb2yuv(raw_target_data * 255.0) / 255.0,
            )
            + 1e-10
        )

        print(
            f"Reference {i} | "
            f"RAFT-based prediction PSNR = {raft_pred_psnr:6.3f} dB | "
            f"No motion prediction PSNR = {dummy_pred_psnr:6.3f} dB"
        )

    return flows
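

# A minimal smoke-test sketch, not part of the original module. It relies
# only on the two FrameData attributes that get_raft_optical_flow() actually
# reads (.frame_data_type and .data); the SimpleNamespace stand-ins below
# are purely illustrative, and real encoder code would pass genuine
# FrameData objects instead. RAFT requires spatial dimensions divisible by 8.
if __name__ == "__main__":
    from types import SimpleNamespace

    h, w = 256, 256
    cur = SimpleNamespace(frame_data_type="rgb", data=torch.rand(1, 3, h, w))
    ref = SimpleNamespace(frame_data_type="rgb", data=torch.rand(1, 3, h, w))

    flows = get_raft_optical_flow(cur, [ref])
    # One [1, 2, H, W] flow per reference.
    print([tuple(f.shape) for f in flows])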