# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD-3-Clause
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md
"""A frame encoder is composed of one or two CoolChicEncoder."""
import typing
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
from enc.component.types import DescriptorCoolChic, NAME_COOLCHIC_ENC
from enc.utils.termprint import center_str
import torch
import torch.nn.functional as F
from enc.component.coolchic import (
CoolChicEncoder,
CoolChicEncoderParameter,
)
from enc.component.core.quantizer import (
POSSIBLE_QUANTIZATION_NOISE_TYPE,
POSSIBLE_QUANTIZER_TYPE,
)
from enc.component.intercoding.warp import warp_fn
from enc.io.types import FRAME_DATA_TYPE, POSSIBLE_BITDEPTH
from enc.io.format.yuv import DictTensorYUV, convert_444_to_420, yuv_dict_clamp
from enc.utils.codingstructure import FRAME_TYPE
from enc.training.manager import FrameEncoderManager
from enc.utils.device import POSSIBLE_DEVICE
from torch import Tensor, nn
@dataclass
class FrameEncoderOutput:
"""Dataclass representing the output of FrameEncoder forward."""
# Either a [B, 3, H, W] tensor representing the decoded image or a
# dictionary with the following keys for yuv420:
# {
# 'y': [B, 1, H, W],
# 'u': [B, 1, H / 2, W / 2],
# 'v': [B, 1, H / 2, W / 2],
# }
# Note: yuv444 data are represented as a simple [B, 3, H, W] tensor
decoded_image: Union[Tensor, DictTensorYUV]
# Rate associated to each cool-chic encoder
rate: Dict[NAME_COOLCHIC_ENC, Tensor]
# Any other data required to compute some logs, stored inside a dictionary
    additional_data: Dict[str, Any] = field(default_factory=dict)
class FrameEncoder(nn.Module):
"""A ``FrameEncoder`` is the object containing everything
required to encode a video frame or an image. It is composed of
one or more ``CoolChicEncoder``.
"""
def __init__(
self,
coolchic_enc_param: Dict[NAME_COOLCHIC_ENC, CoolChicEncoderParameter],
frame_type: FRAME_TYPE = "I",
frame_data_type: FRAME_DATA_TYPE = "rgb",
bitdepth: POSSIBLE_BITDEPTH = 8,
index_references: List[int] = [],
frame_display_index: int = 0,
):
"""
Args:
coolchic_enc_param: Parameters for the underlying CoolChicEncoders
frame_type: More info in
:doc:`coding_structure.py <../utils/codingstructure>`.
Defaults to "I".
            frame_data_type: More info in
                :doc:`coding_structure.py <../utils/codingstructure>`.
                Defaults to "rgb".
            bitdepth: More info in
                :doc:`coding_structure.py <../utils/codingstructure>`.
                Defaults to 8.
            index_references: List of the display indices of the references.
                Defaults to [].
            frame_display_index: Display index of the frame being encoded.
                Defaults to 0.
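
        Example (a minimal sketch; ``residue_param`` and ``motion_param`` are
        assumed to be already-built ``CoolChicEncoderParameter`` objects)::

            # A B-frame at display index 2, predicted from frames 0 and 4.
            frame_enc = FrameEncoder(
                coolchic_enc_param={
                    "residue": residue_param,
                    "motion": motion_param,
                },
                frame_type="B",
                frame_data_type="yuv420",
                bitdepth=8,
                index_references=[0, 4],
                frame_display_index=2,
            )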
"""
super().__init__()
# ----- Copy the parameters
self.coolchic_enc_param = coolchic_enc_param
self.frame_type = frame_type
self.frame_data_type = frame_data_type
self.bitdepth = bitdepth
self.index_references = index_references
self.frame_display_index = frame_display_index
        # Check we've been given the expected number of references.
        all_expected_n_ref = {"I": 0, "P": 1, "B": 2}
        for ftype, expected_n_ref in all_expected_n_ref.items():
            if self.frame_type == ftype:
                assert len(self.index_references) == expected_n_ref, (
                    f"{ftype} frame must have {expected_n_ref} references. "
                    f"Found {len(self.index_references)}: {self.index_references}."
                )
# "Core" CoolChic codec. This will be reset by the warm-up function
self.coolchic_enc: Dict[NAME_COOLCHIC_ENC, CoolChicEncoder] = nn.ModuleDict()
for name, cc_enc_param in self.coolchic_enc_param.items():
self.coolchic_enc[name] = CoolChicEncoder(cc_enc_param)
# Global motion. Only here for saving purposes. Not used in the forward
# We shift the references instead!
# Global motion --> Shift the entire ref by a constant motion prior to
# using the optical flow recovered from the motion cool-chic.
        # register_buffer provides automatic device management: the buffer
        # follows the module across devices without being considered a
        # parameter (i.e. it is not returned by self.parameters()).
        # persistent=False additionally keeps it out of the state_dict.
self.register_buffer("global_flow_1", torch.zeros(1, 2, 1, 1), persistent=False)
self.register_buffer("global_flow_2", torch.zeros(1, 2, 1, 1), persistent=False)
# self.global_flow_1 = nn.Parameter(torch.zeros(1, 2, 1, 1), requires_grad=True)
# self.global_flow_2 = nn.Parameter(torch.zeros(1, 2, 1, 1), requires_grad=True)
def forward(
self,
reference_frames: Optional[List[Tensor]] = None,
quantizer_noise_type: POSSIBLE_QUANTIZATION_NOISE_TYPE = "kumaraswamy",
quantizer_type: POSSIBLE_QUANTIZER_TYPE = "softround",
soft_round_temperature: Optional[Tensor] = torch.tensor(0.3),
noise_parameter: Optional[Tensor] = torch.tensor(1.0),
AC_MAX_VAL: int = -1,
flag_additional_outputs: bool = False,
) -> FrameEncoderOutput:
"""Perform the entire forward pass of a video frame / image.
        1. **Simulate Cool-chic decoding** to obtain both the decoded image
           :math:`\\hat{\\mathbf{x}}` as a :math:`(B, 3, H, W)` tensor
           and its associated rate :math:`\\mathrm{R}(\\hat{\\mathbf{x}})` as
           a :math:`(N)` tensor, where :math:`N` is the number of latent
           pixels. The rate is given in bits.
        2. **Simulate the saving of the image to a file (Optional)**.
           *Only if the model has been set in test mode*, e.g. through
           ``self.set_to_eval()``. This takes into account that
           :math:`\\hat{\\mathbf{x}}` is a float tensor, which will be
           saved as integer values in a file.
           .. math::

               \\hat{\\mathbf{x}}_{saved} = \\mathtt{round}(\\Delta_q \\,
               \\hat{\\mathbf{x}}) / \\Delta_q, \\text{ with }
               \\Delta_q = 2^{bitdepth} - 1.
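
           For instance, with a bitdepth of 8, :math:`\\Delta_q = 255`, so a
           decoded value of 0.5 is saved as
           :math:`\\mathtt{round}(0.5 \\cdot 255) / 255 = 128 / 255 \\approx 0.502`.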
3. **Downscale to YUV 420 (Optional)**. *Only if the required output
format is YUV420*. The current output is a dense Tensor. Downscale
the last two channels to obtain a YUV420-like representation. This
is done with a nearest neighbor downsampling.
4. **Clamp the output** to be in :math:`[0, 1]`.
Args:
            reference_frames: List of tensors representing the reference
                frames. Can be set to None if no reference frame is
                available. Defaults to None.
            quantizer_noise_type: Defaults to ``"kumaraswamy"``.
            quantizer_type: Defaults to ``"softround"``.
            soft_round_temperature: Soft round temperature. This is used for
                the softround modes as well as the ste mode to simulate the
                derivative in the backward pass. Defaults to 0.3.
            noise_parameter: Noise distribution parameter. Defaults to 1.0.
AC_MAX_VAL: If different from -1, clamp the value to be in
:math:`[-AC\\_MAX\\_VAL; AC\\_MAX\\_VAL + 1]` to write the actual bitstream.
Defaults to -1.
            flag_additional_outputs: True to fill
                ``FrameEncoderOutput.additional_data`` with many different
                quantities which can be used to analyze Cool-chic behavior.
                Defaults to False.
Returns:
Output of the FrameEncoder for the forward pass.
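
        Example (a minimal sketch; assumes ``frame_enc`` is an intra
        ``FrameEncoder``, so that no reference frame is required)::

            out = frame_enc.forward()
            # Total rate of the frame in bits: sum the rate of every
            # Cool-chic encoder.
            total_rate_bit = sum(rate_i.sum() for rate_i in out.rate.values())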
"""
# Common parameters for all cool-chic encoders
cc_forward_param = {
"quantizer_noise_type": quantizer_noise_type,
"quantizer_type": quantizer_type,
"soft_round_temperature": soft_round_temperature,
"noise_parameter": noise_parameter,
"AC_MAX_VAL": AC_MAX_VAL,
"flag_additional_outputs": flag_additional_outputs,
}
cc_enc_out = {
cc_name: cc_enc(**cc_forward_param)
for cc_name, cc_enc in self.coolchic_enc.items()
}
# Get the rate of each cool-chic encoder
rate = {
cc_name: cc_enc_out_i.get("rate")
for cc_name, cc_enc_out_i in cc_enc_out.items()
}
if self.frame_type == "I":
decoded_image = cc_enc_out["residue"].get("raw_out")
elif self.frame_type in ["P", "B"]:
residue = cc_enc_out["residue"].get("raw_out")[:, :3, :, :]
alpha = torch.clamp(
cc_enc_out["residue"].get("raw_out")[:, 3:4, :, :] + 0.5, 0.0, 1.0
)
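            # For B frames, the motion decoder output stacks, along dim 1,
            # [flow_1 (2 channels), flow_2 (2 channels), beta (1 channel)];
            # P frames only use flow_1.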
flow_1 = cc_enc_out["motion"].get("raw_out")[:, 0:2, :, :]
# Apply each global flow on each reference.
# Upsample the global flow beforehand to obtain a constant [1, 2, H, W] optical flow.
shifted_ref = []
for ref_i, global_flow_i in zip(reference_frames, [self.global_flow_1, self.global_flow_2]):
ups_global_flow_i = F.interpolate(global_flow_i, size=ref_i.size()[-2:], mode="nearest")
shifted_ref.append(warp_fn(ref_i, ups_global_flow_i))
if self.frame_type == "P":
pred = warp_fn(shifted_ref[0], flow_1)
elif self.frame_type == "B":
flow_2 = cc_enc_out["motion"].get("raw_out")[:, 2:4, :, :]
beta = torch.clamp(
cc_enc_out["motion"].get("raw_out")[:, 4:5, :, :] + 0.5, 0.0, 1.0
)
                pred = beta * warp_fn(shifted_ref[0], flow_1) \
                    + (1 - beta) * warp_fn(shifted_ref[1], flow_2)
decoded_image = alpha * pred + residue
        # Clamp the decoded image & downsample the U and V channels if needed
if not self.training:
max_dynamic = 2 ** (self.bitdepth) - 1
decoded_image = torch.round(decoded_image * max_dynamic) / max_dynamic
if self.frame_data_type == "yuv420":
decoded_image = convert_444_to_420(decoded_image)
decoded_image = yuv_dict_clamp(decoded_image, min_val=0.0, max_val=1.0)
elif self.frame_data_type != "flow":
decoded_image = torch.clamp(decoded_image, 0.0, 1.0)
additional_data = {}
if flag_additional_outputs:
# Browse all the cool-chic output to get their additional data
for cc_name, cc_enc_out_i in cc_enc_out.items():
additional_data.update(
{
# Append the cc_name (e.g. residue) in front of the key
f"{cc_name}{k}": v
for k, v in cc_enc_out_i.get("additional_data").items()
}
)
# Also add the residue, flow, pred and beta
if self.frame_type in ["P", "B"]:
additional_data["residue"] = residue
additional_data["alpha"] = alpha
additional_data["flow_1"] = flow_1
additional_data["pred"] = pred
if self.frame_type == "B":
additional_data["flow_2"] = flow_2
additional_data["beta"] = beta
results = FrameEncoderOutput(
decoded_image=decoded_image,
rate=rate,
additional_data=additional_data,
)
return results
# ------- Getter / Setter and Initializer
def get_param(self) -> OrderedDict[NAME_COOLCHIC_ENC, Tensor]:
"""Return **a copy** of the weights and biases inside the module.
Returns:
OrderedDict[NAME_COOLCHIC_ENC, Tensor]: A copy of all weights & biases in the module.
"""
param = OrderedDict({})
for cc_name, cc_enc in self.coolchic_enc.items():
param.update(
{
f"coolchic_enc.{cc_name}.{k}": v
for k, v in cc_enc.get_param().items()
}
)
return param
def set_param(self, param: OrderedDict[NAME_COOLCHIC_ENC, Tensor]):
"""Replace the current parameters of the module with param.
Args:
param (OrderedDict[NAME_COOLCHIC_ENC, Tensor]): Parameters to be set.
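
        Example (a minimal sketch; assumes ``frame_enc`` is an existing
        ``FrameEncoder``)::

            backup = frame_enc.get_param()  # Copy of all weights & biases
            # ... training iterations perturbing the parameters ...
            frame_enc.set_param(backup)     # Restore the saved parameters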
"""
self.load_state_dict(param)
def reinitialize_parameters(self) -> None:
"""Reinitialize in place the different parameters of a FrameEncoder."""
for _, cc_enc in self.coolchic_enc.items():
print("CHECK THAT I DO RE-INIT THE PARAM!")
cc_enc.reinitialize_parameters()
def _store_full_precision_param(self) -> None:
"""For all the coolchic_enc, store their current parameters inside
self.full_precision_param.
        This function checks that no self.nn_q_step and self.nn_expgol_cnt
        are already saved, since their presence would mean that the current
        parameters are no longer full precision but quantized.
"""
for _, cc_enc in self.coolchic_enc.items():
cc_enc._store_full_precision_param()
def set_to_train(self) -> None:
"""Set the current model to training mode, in place. This only
affects the quantization.
"""
        self.train()
for _, cc_enc in self.coolchic_enc.items():
cc_enc.train()
def set_global_flow(self, global_flow_1: Tensor, global_flow_2: Tensor) -> None:
"""Set the value of the global flows.
The global flows are 2-element tensors. The first one is the horizontal
displacement and the second one the vertical displacement.
Args:
global_flow_1 (Tensor): Value of global flow for reference 1. Must have 2 elements.
global_flow_2 (Tensor): Value of global flow for reference 2. Must have 2 elements.
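
        Example (a minimal sketch; assumes ``frame_enc`` is a B-frame
        ``FrameEncoder``)::

            # Constant (horizontal, vertical) displacement for reference 1;
            # reference 2 is left untouched.
            frame_enc.set_global_flow(
                torch.tensor([2.0, -1.0]), torch.zeros(2)
            )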
"""
        assert global_flow_1.numel() == 2, (
            "global_flow_1 must have 2 elements. Found "
            f"{global_flow_1.numel()} elements."
        )
        assert global_flow_2.numel() == 2, (
            "global_flow_2 must have 2 elements. Found "
            f"{global_flow_2.numel()} elements."
        )
self.global_flow_1 = global_flow_1.view(self.global_flow_1.size())
self.global_flow_2 = global_flow_2.view(self.global_flow_2.size())
    def get_network_rate(self) -> Tuple[Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic], float]:
        """Return the rate (in bits) associated to the parameters
        (weights and biases) of the different modules.

        Returns:
            Tuple[Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic], float]: The
            rate (in bits) associated with the weights and biases of each
            module of each cool-chic decoder, together with the overall rate
            in bits.
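
        Example (a minimal sketch; assumes ``frame_enc`` is an existing
        ``FrameEncoder`` with a "residue" decoder)::

            detailed_rate_bit, total_rate_bit = frame_enc.get_network_rate()
            # e.g. detailed_rate_bit["residue"] maps each module ("arm",
            # "upsampling", "synthesis") to its rate in bits.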
"""
detailed_rate_bit = {}
total_rate_bit = 0.0
for cc_name, cc_enc in self.coolchic_enc.items():
detailed_rate_bit[cc_name], sum_rate = cc_enc.get_network_rate()
total_rate_bit += sum_rate
return detailed_rate_bit, total_rate_bit
def get_network_quantization_step(
self,
) -> Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic]:
"""Return the quantization step associated to the parameters (weights
and biases) of the different modules of each cool-chic decoder. Those
quantization can be ``None`` if the model has not yet been quantized.
E.g. {"residue": {"arm": 4, "upsampling": 12, "synthesis": 1}}
Returns:
Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic]: The quantization step
associated with the weights and biases of each module of each
cool-chic decoder.
"""
q_step = {}
for cc_name, cc_enc in self.coolchic_enc.items():
q_step[cc_name] = cc_enc.get_network_quantization_step()
return q_step
def get_network_expgol_count(self) -> Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic]:
"""Return the Exp-Golomb count parameter associated to the parameters
(weights and biases) of the different modules of each cool-chic decoder.
Those exp-golomb param can be ``None`` if the model has not yet
been quantized.
E.g. {"residue": {"arm": 4, "upsampling": 12, "synthesis": 1}}
Returns:
Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic]: The exp-golomb count
parameter associated with the weights and biases of each module of
each cool-chic decoder.
"""
expgol_cnt = {}
for cc_name, cc_enc in self.coolchic_enc.items():
expgol_cnt[cc_name] = cc_enc.get_network_expgol_count()
return expgol_cnt
def get_total_mac_per_pixel(self) -> float:
"""Count the number of Multiplication-Accumulation (MAC) per decoded pixel
for this model.
Returns:
float: number of floating point operations per decoded pixel.
"""
mac_per_pixel = 0
for cc_name, cc_enc in self.coolchic_enc.items():
mac_per_pixel += cc_enc.get_total_mac_per_pixel()
return mac_per_pixel
def set_to_eval(self) -> None:
"""Set the current model to test mode, in place. This only
affects the quantization.
"""
        self.eval()
for _, cc_enc in self.coolchic_enc.items():
cc_enc.eval()
def to_device(self, device: POSSIBLE_DEVICE) -> None:
"""Push a model to a given device.
Args:
device: The device on which the model should run.
"""
assert device in typing.get_args(
POSSIBLE_DEVICE
), f"Unknown device {device}, should be in {typing.get_args(POSSIBLE_DEVICE)}"
        self.to(device)
for _, cc_enc in self.coolchic_enc.items():
cc_enc.to_device(device)
def save(
self,
path_file: str,
frame_encoder_manager: Optional[FrameEncoderManager] = None,
) -> None:
"""Save the FrameEncoder into a bytes buffer and return it.
Optionally save a frame_encoder_manager alongside the current frame
encoder to keep track of the training time, record loss etc.
Args:
path_file: Where to save the FrameEncoder
frame_encoder_manager: Contains (among other things) the rate
constraint :math:`\\lambda` and description of the
warm-up preset. It is also used to track the total encoding time
and encoding iterations.
Returns:
Bytes representing the saved coolchic model
"""
data_to_save = {
"bitdepth": self.bitdepth,
"frame_type": self.frame_type,
"frame_data_type": self.frame_data_type,
"index_references": self.index_references,
"frame_display_index": self.frame_display_index,
            # Names of the different cool-chic encoders
"keys_cc_enc": list(self.coolchic_enc.keys()),
"global_flow_1": self.global_flow_1,
"global_flow_2": self.global_flow_2,
}
for cc_name, cc_enc in self.coolchic_enc.items():
data_to_save[f"{cc_name}"] = cc_enc.get_param()
data_to_save[f"{cc_name}_nn_q_step"] = (
cc_enc.get_network_quantization_step()
)
data_to_save[f"{cc_name}_nn_expgol_cnt"] = cc_enc.get_network_expgol_count()
data_to_save[f"{cc_name}_param"] = self.coolchic_enc_param[cc_name]
if cc_enc.full_precision_param is not None:
data_to_save[f"{cc_name}_full_precision_param"] = (
cc_enc.full_precision_param
)
if frame_encoder_manager is not None:
data_to_save["frame_encoder_manager"] = frame_encoder_manager
torch.save(data_to_save, path_file)
def pretty_string(self, print_detailed_archi: bool = False) -> str:
"""Get a pretty string representing the architectures of
the different ``CoolChicEncoder`` composing the current ``FrameEncoder``.
Args:
print_detailed_archi: True to print the detailed decoder architecture
Returns:
str: a pretty string ready to be printed out
"""
s = ""
for name, cc_enc in self.coolchic_enc.items():
total_mac_per_pix = cc_enc.get_total_mac_per_pixel()
title = (
"\n\n"
f"{name} decoder: {total_mac_per_pix:5.0f} MAC / pixel"
"\n"
f"{'-' * len(name)}---------------------------"
"\n"
)
s += title
s += cc_enc.pretty_string(print_detailed_archi=print_detailed_archi) + "\n"
return s
def pretty_string_param(self) -> str:
"""Get a pretty string representing the parameters of
the different ``CoolChicEncoderParameters`` parameterising the current
``FrameEncoder``
"""
s = ""
for name, cc_enc_param in self.coolchic_enc_param.items():
title = (
"\n\n"
+ center_str(f"{name} parameters")
+ "\n"
+ center_str(f"{'-' * len(name)})-----------")
+ "\n\n"
)
s += title
s += cc_enc_param.pretty_string() + "\n"
return s
def load_frame_encoder(
path_file: str,
) -> Tuple[FrameEncoder, Optional[FrameEncoderManager]]:
"""From already loaded raw bytes, load & return a CoolChicEncoder
Args:
path_file: Path of the FrameEncoder to be loaded
Returns:
Tuple with a FrameEncoder loaded by the function and an optional
FrameEncoderManager
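
    Example (a minimal sketch; assumes ``frame_enc`` is an existing
    ``FrameEncoder`` and ``manager`` a ``FrameEncoderManager``)::

        frame_enc.save("frame.pt", frame_encoder_manager=manager)
        frame_enc, manager = load_frame_encoder("frame.pt")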
"""
loaded_data = torch.load(path_file, map_location="cpu", weights_only=False)
# Something like ["residue", "motion"]
list_cc_name = loaded_data["keys_cc_enc"]
# Load first the CoolChicEncoderParameter of all the Cool-chic encoders
# for the frame
coolchic_enc_param = {}
for cc_name in list_cc_name:
coolchic_enc_param[cc_name] = loaded_data[f"{cc_name}_param"]
    # Create an empty frame encoder from the stored parameters
frame_encoder = FrameEncoder(
coolchic_enc_param=coolchic_enc_param,
frame_type=loaded_data["frame_type"],
frame_data_type=loaded_data["frame_data_type"],
bitdepth=loaded_data["bitdepth"],
index_references=loaded_data["index_references"],
frame_display_index=loaded_data["frame_display_index"],
)
# Load the parameters
for cc_name in list_cc_name:
frame_encoder.coolchic_enc[cc_name].set_param(loaded_data[cc_name])
frame_encoder.coolchic_enc[cc_name].nn_q_step = loaded_data[
f"{cc_name}_nn_q_step"
]
frame_encoder.coolchic_enc[cc_name].nn_expgol_cnt = loaded_data[
f"{cc_name}_nn_expgol_cnt"
]
if f"{cc_name}_full_precision_param" in loaded_data:
            frame_encoder.coolchic_enc[cc_name].full_precision_param = loaded_data[
f"{cc_name}_full_precision_param"
]
    frame_encoder_manager = loaded_data.get("frame_encoder_manager")
if "global_flow_1" in loaded_data:
frame_encoder.global_flow_1 = loaded_data["global_flow_1"]
if "global_flow_2" in loaded_data:
frame_encoder.global_flow_2 = loaded_data["global_flow_2"]
return frame_encoder, frame_encoder_manager