# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md
"""A frame encoder is composed of one or two CoolChicEncoder."""
import typing
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
from enc.component.types import DescriptorCoolChic, NAME_COOLCHIC_ENC
from enc.utils.termprint import center_str
import torch
import torch.nn.functional as F
from enc.component.coolchic import (
    CoolChicEncoder,
    CoolChicEncoderParameter,
)
from enc.component.core.quantizer import (
    POSSIBLE_QUANTIZATION_NOISE_TYPE,
    POSSIBLE_QUANTIZER_TYPE,
)
from enc.component.intercoding.warp import vanilla_warp_fn, Warper, WarpParameter
from enc.io.types import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH
from enc.io.format.yuv import DictTensorYUV, convert_444_to_420, yuv_dict_clamp
from enc.utils.codingstructure import FRAME_TYPE
from enc.training.manager import FrameEncoderManager
from enc.utils.device import POSSIBLE_DEVICE
from torch import Tensor, nn
[docs]
@dataclass
class FrameEncoderOutput:
    """Dataclass representing the output of FrameEncoder forward."""
    # Either a [B, 3, H, W] tensor representing the decoded image or a
    # dictionary with the following keys for yuv420:
    #   {
    #         'y': [B, 1, H, W],
    #         'u': [B, 1, H / 2, W / 2],
    #         'v': [B, 1, H / 2, W / 2],
    #   }
    # Note: yuv444 data are represented as a simple [B, 3, H, W] tensor
    decoded_image: Union[Tensor, DictTensorYUV]
    # Rate associated to each cool-chic encoder
    rate: Dict[NAME_COOLCHIC_ENC, Tensor]
    # Any other data required to compute some logs, stored inside a dictionary
    additional_data: Dict[str, Any] = field(default_factory=lambda: {}) 
[docs]
class FrameEncoder(nn.Module):
    """A ``FrameEncoder`` is the object containing everything
    required to encode a video frame or an image. It is composed of
    one or more ``CoolChicEncoder``.
    """
[docs]
    def __init__(
        self,
        coolchic_enc_param: Dict[NAME_COOLCHIC_ENC, CoolChicEncoderParameter],
        warp_parameter: Optional[WarpParameter] = None,
        frame_type: FRAME_TYPE = "I",
        frame_data_type: FRAME_DATA_TYPE = "rgb",
        bitdepth: POSSIBLE_BITDEPTH = 8,
        index_references: List[int] = [],
        frame_display_index: int = 0,
    ):
        """
        Args:
            coolchic_enc_param: Parameters for the underlying CoolChicEncoders
            warp_parameter: Parameters for the Warper. Can be None for intra frame.
            frame_type: More info in
                :doc:`coding_structure.py <../utils/codingstructure>`.
                Defaults to "I".
            frame_data_type: More info in
                :doc:`coding_structure.py <../utils/codingstructure>`.
                Defaults to "rgb"
            bitdepth: More info in
                :doc:`coding_structure.py <../utils/codingstructure>`.
                Defaults to 8.
            index_references: List of the display index of the references.
                Defaults to []
            frame_display_index: display index of the frame being encoded.
        """
        super().__init__()
        # ----- Copy the parameters
        self.coolchic_enc_param = coolchic_enc_param
        self.warp_parameter = warp_parameter
        self.frame_type = frame_type
        self.frame_data_type = frame_data_type
        self.bitdepth = bitdepth
        self.index_references = index_references
        self.frame_display_index = frame_display_index
        # Check we've passed the expected number of frames.
        all_expected_n_ref = {"I": 0, "P": 1, "B": 2}
        for frame_type, expected_n_ref in all_expected_n_ref.items():
            if self.frame_type == frame_type:
                assert len(self.index_references) == expected_n_ref, (
                    f"{frame_type} frame must have {expected_n_ref} references. "
                    f"Found {len(self.index_references)}: {self.index_references}."
                )
        # "Core" CoolChic codec. This will be reset by the warm-up function
        self.coolchic_enc: Dict[NAME_COOLCHIC_ENC, CoolChicEncoder] = nn.ModuleDict()
        for name, cc_enc_param in self.coolchic_enc_param.items():
            self.coolchic_enc[name] = CoolChicEncoder(cc_enc_param)
        if self.frame_type != "I":
            self.warper = Warper(
                self.warp_parameter,
                img_size=self.coolchic_enc_param.get("residue").img_size,
            )
        # Global motion. Only here for saving purposes. Not used in the forward
        # We shift the references instead!
        # Global motion --> Shift the entire ref by a constant motion prior to
        # using the optical flow recovered from the motion cool-chic.
        # register_buffer for automatic device management. We set persistent to false
        # to simply use the "automatically move to device" function, without
        # considering global_flow_1 as a parameters (i.e. returned
        # by self.parameters())
        self.register_buffer("global_flow_1", torch.zeros(1, 2, 1, 1), persistent=False)
        self.register_buffer("global_flow_2", torch.zeros(1, 2, 1, 1), persistent=False) 
        # self.global_flow_1 = nn.Parameter(torch.zeros(1, 2, 1, 1), requires_grad=True)
        # self.global_flow_2 = nn.Parameter(torch.zeros(1, 2, 1, 1), requires_grad=True)
[docs]
    def forward(
        self,
        reference_frames: Optional[List[Tensor]] = None,
        quantizer_noise_type: POSSIBLE_QUANTIZATION_NOISE_TYPE = "kumaraswamy",
        quantizer_type: POSSIBLE_QUANTIZER_TYPE = "softround",
        soft_round_temperature: Optional[Tensor] = torch.tensor(0.3),
        noise_parameter: Optional[Tensor] = torch.tensor(1.0),
        AC_MAX_VAL: int = -1,
        flag_additional_outputs: bool = False,
        no_common_randomness: bool = False,
        only_common_randomness: bool = False,
    ) -> FrameEncoderOutput:
        """Perform the entire forward pass of a video frame / image.
        1. **Simulate Cool-chic decoding** to obtain both the decoded image
           :math:`\\hat{\\mathbf{x}}` as a :math:`(B, 3, H, W)` tensor
           and its associated rate :math:`\\mathrm{R}(\\hat{\\mathbf{x}})` as
           as :math:`(N)` tensor`, where :math:`N` is the number of latent
           pixels. The rate is given in bits.
        2. **Simulate the saving of the image to a file (Optional)**.
            *Only if the model has been set in test mode* e.g.
            ``self.set_to_eval()`` . Take into account that
            :math:`\\hat{\\mathbf{x}}` is a float Tensor, which is
            gonna be saved as integer values in a file.
            .. math::
                \\hat{\\mathbf{x}}_{saved} = \\mathtt{round}(\Delta_q \\
                \\hat{\\mathbf{x}}) / \\Delta_q, \\text{ with }
                \\Delta_q = 2^{bitdepth} - 1
        3. **Downscale to YUV 420 (Optional)**. *Only if the required output
           format is YUV420*. The current output is a dense Tensor. Downscale
           the last two channels to obtain a YUV420-like representation. This
           is done with a nearest neighbor downsampling.
        4. **Clamp the output** to be in :math:`[0, 1]`.
        Args:
            reference_frames: List of tensors representing the reference
                frames. Can be set to None if no reference frame is available.
                Default to None.
            quantizer_noise_type: Defaults to ``"kumaraswamy"``.
            quantizer_type: Defaults to ``"softround"``.
            soft_round_temperature: Soft round temperature.
                This is used for softround modes as well as the
                ste mode to simulate the derivative in the backward.
                Defaults to 0.3.
            noise_parameter: noise distribution parameter. Defaults to 1.0.
            AC_MAX_VAL: If different from -1, clamp the value to be in
                :math:`[-AC\\_MAX\\_VAL; AC\\_MAX\\_VAL + 1]` to write the actual bitstream.
                Defaults to -1.
            flag_additional_outputs: True to fill
                ``CoolChicEncoderOutput['additional_data']`` with many different
                quantities which can be used to analyze Cool-chic behavior.
                Defaults to False.
        Returns:
            Output of the FrameEncoder for the forward pass.
        """
        # Common parameters for all cool-chic encoders
        cc_forward_param = {
            "quantizer_noise_type": quantizer_noise_type,
            "quantizer_type": quantizer_type,
            "soft_round_temperature": soft_round_temperature,
            "noise_parameter": noise_parameter,
            "AC_MAX_VAL": AC_MAX_VAL,
            "flag_additional_outputs": flag_additional_outputs,
            "no_common_randomness": no_common_randomness,
            "only_common_randomness": only_common_randomness,
        }
        cc_enc_out = {
            cc_name: cc_enc(**cc_forward_param)
            for cc_name, cc_enc in self.coolchic_enc.items()
        }
        # Get the rate of each cool-chic encoder
        rate = {
            cc_name: cc_enc_out_i.get("rate")
            for cc_name, cc_enc_out_i in cc_enc_out.items()
        }
        if self.frame_type == "I":
            decoded_image = cc_enc_out["residue"].get("raw_out")
        elif self.frame_type in ["P", "B"]:
            residue = cc_enc_out["residue"].get("raw_out")[:, :3, :, :]
            alpha = torch.clamp(
                cc_enc_out["residue"].get("raw_out")[:, 3:4, :, :] + 0.5, 0.0, 1.0
            )
            flow_1 = cc_enc_out["motion"].get("raw_out")[:, 0:2, :, :]
            # Apply each global flow on each reference.
            # Upsample the global flow beforehand to obtain a constant [1, 2, H, W] optical flow.
            shifted_ref = []
            for ref_i, global_flow_i in zip(
                reference_frames, [self.global_flow_1, self.global_flow_2]
            ):
                ups_global_flow_i = F.interpolate(
                    global_flow_i, size=ref_i.size()[-2:], mode="nearest"
                )
                shifted_ref.append(
                    vanilla_warp_fn(ref_i, ups_global_flow_i, mode="nearest")
                )
            if self.frame_type == "P":
                pred = self.warper(shifted_ref[0], flow_1)
            elif self.frame_type == "B":
                flow_2 = cc_enc_out["motion"].get("raw_out")[:, 2:4, :, :]
                beta = torch.clamp(
                    cc_enc_out["motion"].get("raw_out")[:, 4:5, :, :] + 0.5, 0.0, 1.0
                )
                pred = beta * self.warper(shifted_ref[0], flow_1) + (
                    1 - beta
                ) * self.warper(shifted_ref[1], flow_2)
            decoded_image = alpha * pred + residue
        # Clamp decoded image & down sample YUV channel if needed
        if not self.training:
            max_dynamic = 2 ** (self.bitdepth) - 1
            decoded_image = torch.round(decoded_image * max_dynamic) / max_dynamic
        if self.frame_data_type == "yuv420":
            decoded_image = convert_444_to_420(decoded_image)
            decoded_image = yuv_dict_clamp(decoded_image, min_val=0.0, max_val=1.0)
        elif self.frame_data_type != "flow":
            decoded_image = torch.clamp(decoded_image, 0.0, 1.0)
        additional_data = {}
        if flag_additional_outputs:
            # Browse all the cool-chic output to get their additional data
            for cc_name, cc_enc_out_i in cc_enc_out.items():
                additional_data.update(
                    {
                        # Append the cc_name (e.g. residue) in front of the key
                        f"{cc_name}_{k}": v
                        for k, v in cc_enc_out_i.get("additional_data").items()
                    }
                )
            # Also add the residue, flow, pred and beta
            if self.frame_type in ["P", "B"]:
                additional_data["residue"] = residue
                additional_data["alpha"] = alpha
                additional_data["flow_1"] = flow_1
                additional_data["pred"] = pred
            if self.frame_type == "B":
                additional_data["flow_2"] = flow_2
                additional_data["beta"] = beta
        results = FrameEncoderOutput(
            decoded_image=decoded_image,
            rate=rate,
            additional_data=additional_data,
        )
        return results 
    # ------- Getter / Setter and Initializer
[docs]
    def get_param(self) -> OrderedDict[NAME_COOLCHIC_ENC, Tensor]:
        """Return **a copy** of the weights and biases inside the module.
        Returns:
            OrderedDict[NAME_COOLCHIC_ENC, Tensor]: A copy of all weights & biases in the module.
        """
        param = OrderedDict({})
        for cc_name, cc_enc in self.coolchic_enc.items():
            param.update(
                {
                    f"coolchic_enc.{cc_name}.{k}": v
                    for k, v in cc_enc.get_param().items()
                }
            )
        return param 
[docs]
    def set_param(self, param: OrderedDict[NAME_COOLCHIC_ENC, Tensor]):
        """Replace the current parameters of the module with param.
        Args:
            param (OrderedDict[NAME_COOLCHIC_ENC, Tensor]): Parameters to be set.
        """
        self.load_state_dict(param) 
[docs]
    def reinitialize_parameters(self) -> None:
        """Reinitialize in place the different parameters of a FrameEncoder."""
        for _, cc_enc in self.coolchic_enc.items():
            print("CHECK THAT I DO RE-INIT THE PARAM!")
            cc_enc.reinitialize_parameters() 
    def _store_full_precision_param(self) -> None:
        """For all the coolchic_enc,  store their current parameters inside
        self.full_precision_param.
        This function checks that there is no self.nn_q_step and
        self.nn_expgol_cnt already saved. This would mean that we no longer
        have full precision parameters but quantized ones.
        """
        for _, cc_enc in self.coolchic_enc.items():
            cc_enc._store_full_precision_param()
[docs]
    def set_to_train(self) -> None:
        """Set the current model to training mode, in place."""
        self = self.train()
        for _, cc_enc in self.coolchic_enc.items():
            cc_enc.train() 
[docs]
    def set_to_eval(self) -> None:
        """
        Set the current model to test mode, in place. This affects latent
        quantization, forcing it to mode="hardround" in eval mode.
        For video coding, it also affects the optical flows value,
        quantizing them at a given subpixel accuracy, defined in
        self.warp_parameter.fractional_accuracy.
        """
        self = self.eval()
        for _, cc_enc in self.coolchic_enc.items():
            cc_enc.eval() 
[docs]
    def set_global_flow(self, global_flow_1: Tensor, global_flow_2: Tensor) -> None:
        """Set the value of the global flows.
        The global flows are 2-element tensors. The first one is the horizontal
        displacement and the second one the vertical displacement.
        Args:
            global_flow_1 (Tensor): Value of global flow for reference 1. Must have 2 elements.
            global_flow_2 (Tensor): Value of global flow for reference 2. Must have 2 elements.
        """
        assert global_flow_1.numel() == 2, (
            f"global_flow_1 must have 2 parameters. Found {global_flow_1.numel()} "
            " parameters."
        )
        assert global_flow_2.numel() == 2, (
            f"global_flow_2 must have 2 parameters. Found {global_flow_2.numel()} "
            " parameters."
        )
        self.global_flow_1 = global_flow_1.view(self.global_flow_1.size())
        self.global_flow_2 = global_flow_2.view(self.global_flow_2.size()) 
[docs]
    def get_network_rate(
        self,
    ) -> Tuple[Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic], int]:
        """Return the rate (in bits) associated to the parameters
        (weights and biases) of the different modules
        Returns:
            Tuple[Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic], int]: The rate (in bits)
            associated with the weights and biases of each module of each
            cool-chic decoder. Also return the overall rate in bits.
        """
        detailed_rate_bit = {}
        total_rate_bit = 0.0
        for cc_name, cc_enc in self.coolchic_enc.items():
            detailed_rate_bit[cc_name], sum_rate = cc_enc.get_network_rate()
            total_rate_bit += sum_rate
        return detailed_rate_bit, total_rate_bit 
[docs]
    def get_network_quantization_step(
        self,
    ) -> Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic]:
        """Return the quantization step associated to the parameters (weights
        and biases) of the different modules of each cool-chic decoder. Those
        quantization can be ``None`` if the model has not yet been quantized.
        E.g. {"residue": {"arm": 4, "upsampling": 12, "synthesis": 1}}
        Returns:
            Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic]: The quantization step
            associated with the weights and biases of each module of each
            cool-chic decoder.
        """
        q_step = {}
        for cc_name, cc_enc in self.coolchic_enc.items():
            q_step[cc_name] = cc_enc.get_network_quantization_step()
        return q_step 
[docs]
    def get_network_expgol_count(self) -> Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic]:
        """Return the Exp-Golomb count parameter associated to the parameters
        (weights and biases) of the different modules of each cool-chic decoder.
        Those exp-golomb param can be ``None`` if the model has not yet
        been quantized.
        E.g. {"residue": {"arm": 4, "upsampling": 12, "synthesis": 1}}
        Returns:
            Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic]: The exp-golomb count
            parameter associated with the weights and biases of each module of
            each cool-chic decoder.
        """
        expgol_cnt = {}
        for cc_name, cc_enc in self.coolchic_enc.items():
            expgol_cnt[cc_name] = cc_enc.get_network_expgol_count()
        return expgol_cnt 
[docs]
    def get_total_mac_per_pixel(self) -> float:
        """Count the number of Multiplication-Accumulation (MAC) per decoded pixel
        for this model.
        Returns:
            float: number of floating point operations per decoded pixel.
        """
        mac_per_pixel = 0
        for cc_name, cc_enc in self.coolchic_enc.items():
            mac_per_pixel += cc_enc.get_total_mac_per_pixel()
        mac_per_pixel += self.get_warp_mac_per_pixel()
        return mac_per_pixel 
[docs]
    def get_warp_mac_per_pixel(self) -> float:
        """Compute the mac per pixel of a warping. The formula is derived from
        the paper "Efficient Sub-pixel Motion Compensation in Learned Video
        Codecs" from Ladune et al.
        Coefficient are supposed pre-computed. Warping is applyed on blocks
        Returns:
            float: Mac per pixel of the warping
        """
        # No warping for intra frame
        if self.frame_type == "I":
            return 0.0
        # if motion n_ft_per_res is 0,0,0,1,1 then idx_first_lat = 3
        # This corresponds to motion block size of 2 ** 3 = 8 since
        # the motion is decoded at 1/8 of the resolution (corresponding
        # to the biggest latent resolution)
        idx_first_lat = self.coolchic_enc_param["motion"].n_ft_per_res.index(1)
        b = 2**idx_first_lat
        n = self.warp_parameter.filter_size
        warp_single_channel = (n**2 - n) / b + 2 * n
        warp_single_ref = 3 * warp_single_channel
        if self.frame_type == "P":
            return warp_single_ref
        # Two references for B-frame
        elif self.frame_type == "B":
            return 2 * warp_single_ref 
[docs]
    def to_device(self, device: POSSIBLE_DEVICE) -> None:
        """Push a model to a given device.
        Args:
            device: The device on which the model should run.
        """
        assert device in typing.get_args(POSSIBLE_DEVICE), (
            f"Unknown device {device}, should be in {typing.get_args(POSSIBLE_DEVICE)}"
        )
        self = self.to(device)
        for _, cc_enc in self.coolchic_enc.items():
            cc_enc.to_device(device) 
[docs]
    def save(
        self,
        path_file: str,
        frame_encoder_manager: Optional[FrameEncoderManager] = None,
    ) -> None:
        """Save the FrameEncoder into a bytes buffer and return it.
            Optionally save a frame_encoder_manager alongside the current frame
            encoder to keep track of the training time, record loss etc.
        Args:
            path_file: Where to save the FrameEncoder
            frame_encoder_manager: Contains (among other things) the rate
                constraint :math:`\\lambda` and description of the
                warm-up preset. It is also used to track the total encoding time
                and encoding iterations.
            Returns:
                Bytes representing the saved coolchic model
        """
        data_to_save = {
            "bitdepth": self.bitdepth,
            "frame_type": self.frame_type,
            "frame_data_type": self.frame_data_type,
            "index_references": self.index_references,
            "frame_display_index": self.frame_display_index,
            "warp_parameter": self.warp_parameter,
            # Name of the different cool-chic encoder
            "keys_cc_enc": list(self.coolchic_enc.keys()),
            "global_flow_1": self.global_flow_1,
            "global_flow_2": self.global_flow_2,
        }
        for cc_name, cc_enc in self.coolchic_enc.items():
            data_to_save[f"{cc_name}"] = cc_enc.get_param()
            data_to_save[f"{cc_name}_nn_q_step"] = (
                cc_enc.get_network_quantization_step()
            )
            data_to_save[f"{cc_name}_nn_expgol_cnt"] = cc_enc.get_network_expgol_count()
            data_to_save[f"{cc_name}_param"] = self.coolchic_enc_param[cc_name]
            if cc_enc.full_precision_param is not None:
                data_to_save[f"{cc_name}_full_precision_param"] = (
                    cc_enc.full_precision_param
                )
        if frame_encoder_manager is not None:
            data_to_save["frame_encoder_manager"] = frame_encoder_manager
        torch.save(data_to_save, path_file) 
[docs]
    def pretty_string(self, print_detailed_archi: bool = False) -> str:
        """Get a pretty string representing the architectures of
        the different ``CoolChicEncoder`` composing the current ``FrameEncoder``.
        Args:
            print_detailed_archi: True to print the detailed decoder architecture
        Returns:
            str: a pretty string ready to be printed out
        """
        s = ""
        for name, cc_enc in self.coolchic_enc.items():
            total_mac_per_pix = cc_enc.get_total_mac_per_pixel()
            title = (
                "\n\n"
                f"{name} decoder: {total_mac_per_pix:5.0f} MAC / pixel"
                "\n"
                f"{'-' * len(name)}---------------------------"
                "\n"
            )
            s += title
            s += cc_enc.pretty_string(print_detailed_archi=print_detailed_archi) + "\n"
        if self.frame_type != "I":
            s += (
                f"Warping:        {self.get_warp_mac_per_pixel():5.0f} MAC / pixel\n"
                "---------------------------------"
            )
        return s 
[docs]
    def pretty_string_param(self) -> str:
        """Get a pretty string representing the parameters of
        the different ``CoolChicEncoderParameters`` parameterising the current
        ``FrameEncoder``
        """
        s = ""
        for name, cc_enc_param in self.coolchic_enc_param.items():
            title = (
                "\n\n"
                + center_str(f"{name} parameters")
                + "\n"
                + center_str(f"{'-' * len(name)})-----------")
                + "\n\n"
            )
            s += title
            s += cc_enc_param.pretty_string() + "\n"
        if self.frame_type != "I":
            title = (
                "\n\n"
                + center_str("Warper parameters")
                + "\n"
                + center_str(f"{'-' * len(name)})-----------")
                + "\n\n"
            )
            s += title
            s += self.warper.pretty_string() + "\n"
        return s 
 
[docs]
def load_frame_encoder(
    path_file: str,
) -> Tuple[FrameEncoder, Optional[FrameEncoderManager]]:
    """From already loaded raw bytes, load & return a CoolChicEncoder
    Args:
        path_file: Path of the FrameEncoder to be loaded
    Returns:
        Tuple with a FrameEncoder loaded by the function and an optional
        FrameEncoderManager
    """
    loaded_data = torch.load(path_file, map_location="cpu", weights_only=False)
    # Something like ["residue", "motion"]
    list_cc_name = loaded_data["keys_cc_enc"]
    # Load first the CoolChicEncoderParameter of all the Cool-chic encoders
    # for the frame
    coolchic_enc_param = {}
    for cc_name in list_cc_name:
        coolchic_enc_param[cc_name] = loaded_data[f"{cc_name}_param"]
        # Add backward compatibility
        if not hasattr(coolchic_enc_param[cc_name], "n_ft_per_res_cr"):
            coolchic_enc_param[cc_name].n_ft_per_res_cr = []
    if "warp_parameter" in loaded_data:
        warp_parameter = loaded_data["warp_parameter"]
        # Backward compatibility
        warp_parameter = WarpParameter(filter_size=warp_parameter.filter_size)
    else:
        warp_parameter = None
    # Create a, empty frame encoder from the stored parameters
    frame_encoder = FrameEncoder(
        coolchic_enc_param=coolchic_enc_param,
        warp_parameter=warp_parameter,
        frame_type=loaded_data["frame_type"],
        frame_data_type=loaded_data["frame_data_type"],
        bitdepth=loaded_data["bitdepth"],
        index_references=loaded_data["index_references"],
        frame_display_index=loaded_data["frame_display_index"],
    )
    # Load the parameters
    for cc_name in list_cc_name:
        frame_encoder.coolchic_enc[cc_name].set_param(loaded_data[cc_name])
        frame_encoder.coolchic_enc[cc_name].nn_q_step = loaded_data[
            f"{cc_name}_nn_q_step"
        ]
        frame_encoder.coolchic_enc[cc_name].nn_expgol_cnt = loaded_data[
            f"{cc_name}_nn_expgol_cnt"
        ]
        if f"{cc_name}_full_precision_param" in loaded_data:
            frame_encoder.coolchic_enc[cc_name].full_precision_parameter = loaded_data[
                f"{cc_name}_full_precision_param"
            ]
    frame_encoder_manager = loaded_data["frame_encoder_manager"]
    if "global_flow_1" in loaded_data:
        frame_encoder.global_flow_1 = loaded_data["global_flow_1"]
    if "global_flow_2" in loaded_data:
        frame_encoder.global_flow_2 = loaded_data["global_flow_2"]
    return frame_encoder, frame_encoder_manager