Source code for enc.component.frame

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md


"""A frame encoder is composed of a CoolChicEncoder and a InterCodingModule."""

import typing
from dataclasses import dataclass, field
from io import BytesIO
from typing import Any, Dict, List, Optional, OrderedDict, Union

import torch
from enc.component.coolchic import (
    CoolChicEncoder,
    CoolChicEncoderParameter,
)
from enc.component.core.quantizer import (
    POSSIBLE_QUANTIZATION_NOISE_TYPE,
    POSSIBLE_QUANTIZER_TYPE,
)
from enc.component.intercoding import InterCodingModule
from torch import Tensor, nn
from enc.utils.codingstructure import (
    FRAME_DATA_TYPE,
    FRAME_TYPE,
    POSSIBLE_BITDEPTH,
    DictTensorYUV,
    convert_444_to_420,
)
from enc.utils.misc import POSSIBLE_DEVICE
from enc.utils.yuv import yuv_dict_clamp


@dataclass
class FrameEncoderOutput:
    """Dataclass representing the output of FrameEncoder forward.

    Attributes are documented with ``#`` comments next to each field.
    """

    # Either a [B, 3, H, W] tensor representing the decoded image or a
    # dictionary with the following keys for yuv420:
    #   {
    #         'y': [B, 1, H, W],
    #         'u': [B, 1, H / 2, W / 2],
    #         'v': [B, 1, H / 2, W / 2],
    #   }
    # Note: yuv444 data are represented as a simple [B, 3, H, W] tensor
    decoded_image: Union[Tensor, DictTensorYUV]

    # Rate associated to each latent [total_latent_value]
    rate: Tensor

    # Any other data required to compute some logs, stored inside a dictionary.
    # ``default_factory=dict`` gives each instance its own fresh, empty dict.
    additional_data: Dict[str, Any] = field(default_factory=dict)
class FrameEncoder(nn.Module):
    """A ``FrameEncoder`` is the object containing everything required to
    encode a video frame or an image. It is composed of a ``CoolChicEncoder``
    and an ``InterCodingModule``.
    """

    def __init__(
        self,
        coolchic_encoder_param: CoolChicEncoderParameter,
        frame_type: FRAME_TYPE = "I",
        frame_data_type: FRAME_DATA_TYPE = "rgb",
        bitdepth: POSSIBLE_BITDEPTH = 8,
    ):
        """
        Args:
            coolchic_encoder_param: Parameters for the underlying
                CoolChicEncoder
            frame_type: More info in
                :doc:`coding_structure.py <../utils/codingstructure>`.
                Defaults to "I".
            frame_data_type: More info in
                :doc:`coding_structure.py <../utils/codingstructure>`.
                Defaults to "rgb"
            bitdepth: More info in
                :doc:`coding_structure.py <../utils/codingstructure>`.
                Defaults to 8.
        """
        super().__init__()

        # ----- Copy the parameters
        self.coolchic_encoder_param = coolchic_encoder_param
        self.frame_type = frame_type
        self.frame_data_type = frame_data_type
        self.bitdepth = bitdepth

        # "Core" CoolChic codec. This will be reset by the warm-up function
        self.coolchic_encoder = CoolChicEncoder(self.coolchic_encoder_param)
        self.inter_coding_module = InterCodingModule(self.frame_type)

    def forward(
        self,
        reference_frames: Optional[List[Tensor]] = None,
        quantizer_noise_type: POSSIBLE_QUANTIZATION_NOISE_TYPE = "kumaraswamy",
        quantizer_type: POSSIBLE_QUANTIZER_TYPE = "softround",
        soft_round_temperature: Optional[float] = 0.3,
        noise_parameter: Optional[float] = 1.0,
        AC_MAX_VAL: int = -1,
        flag_additional_outputs: bool = False,
    ) -> FrameEncoderOutput:
        """Perform the entire forward pass of a video frame / image.

        1. **Simulate Cool-chic decoding** to obtain both the decoded image
           :math:`\\hat{\\mathbf{x}}` as a :math:`(B, 3, H, W)` tensor and its
           associated rate :math:`\\mathrm{R}(\\hat{\\mathbf{x}})` as a
           :math:`(N)` tensor, where :math:`N` is the number of latent pixels.
           The rate is given in bits.

        2. **Simulate the saving of the image to a file (Optional)**.
           *Only if the model has been set in test mode* e.g.
           ``self.set_to_eval()`` . Take into account that
           :math:`\\hat{\\mathbf{x}}` is a float Tensor, which is gonna be
           saved as integer values in a file.

           .. math::

               \\hat{\\mathbf{x}}_{saved} = \\mathtt{round}(\\Delta_q \\
               \\hat{\\mathbf{x}}) / \\Delta_q, \\text{ with }
               \\Delta_q = 2^{bitdepth} - 1

        3. **Downscale to YUV 420 (Optional)**. *Only if the required output
           format is YUV420*. The current output is a dense Tensor. Downscale
           the last two channels to obtain a YUV420-like representation. This
           is done with a nearest neighbor downsampling.

        4. **Clamp the output** to be in :math:`[0, 1]`.

        Args:
            reference_frames: List of tensors representing the reference
                frames. Can be set to None if no reference frame is available.
                Default to None.
            quantizer_noise_type: Defaults to ``"kumaraswamy"``.
            quantizer_type: Defaults to ``"softround"``.
            soft_round_temperature: Soft round temperature.
                This is used for softround modes as well as the
                ste mode to simulate the derivative in the backward.
                Defaults to 0.3.
            noise_parameter: noise distribution parameter. Defaults to 1.0.
            AC_MAX_VAL: If different from -1, clamp the value to be in
                :math:`[-AC\\_MAX\\_VAL; AC\\_MAX\\_VAL + 1]` to write the
                actual bitstream. Defaults to -1.
            flag_additional_outputs: True to fill
                ``CoolChicEncoderOutput['additional_data']`` with many
                different quantities which can be used to analyze Cool-chic
                behavior. Defaults to False.

        Returns:
            Output of the FrameEncoder for the forward pass.
        """
        # CoolChic forward pass
        coolchic_encoder_output = self.coolchic_encoder.forward(
            quantizer_noise_type=quantizer_noise_type,
            quantizer_type=quantizer_type,
            soft_round_temperature=soft_round_temperature,
            noise_parameter=noise_parameter,
            AC_MAX_VAL=AC_MAX_VAL,
            flag_additional_outputs=flag_additional_outputs,
        )

        # Combine CoolChic output and reference frames through the inter
        # coding modules. The module always receives a list of references,
        # possibly empty.
        inter_coding_output = self.inter_coding_module.forward(
            coolchic_output=coolchic_encoder_output,
            references=[] if reference_frames is None else reference_frames,
            flag_additional_outputs=flag_additional_outputs,
        )

        # In eval mode, mimic the precision loss of writing the image to a
        # file with ``bitdepth`` bits per sample. In training mode the decoded
        # image is kept as is.
        if self.training:
            decoded_image = inter_coding_output.decoded_image
        else:
            max_dynamic = 2 ** (self.bitdepth) - 1
            decoded_image = (
                torch.round(inter_coding_output.decoded_image * max_dynamic)
                / max_dynamic
            )

        # Clamp decoded image & down sample YUV channel if needed
        if self.frame_data_type == "yuv420":
            decoded_image = convert_444_to_420(decoded_image)
            decoded_image = yuv_dict_clamp(decoded_image, min_val=0.0, max_val=1.0)
        else:
            decoded_image = torch.clamp(decoded_image, 0.0, 1.0)

        # Merge the additional data of both sub-modules only when requested.
        additional_data = {}
        if flag_additional_outputs:
            additional_data.update(coolchic_encoder_output.get("additional_data"))
            additional_data.update(inter_coding_output.additional_data)

        return FrameEncoderOutput(
            decoded_image=decoded_image,
            rate=coolchic_encoder_output.get("rate"),
            additional_data=additional_data,
        )

    # ------- Getter / Setter and Initializer

    def get_param(self) -> OrderedDict[str, Tensor]:
        """Return the weights and biases inside the module.

        The entries come from ``self.coolchic_encoder.get_param()`` and every
        key is prefixed with ``"coolchic_encoder."``. A new dictionary is
        built here; whether the tensors themselves are copies depends on
        ``CoolChicEncoder.get_param()`` (presumably yes, to be confirmed
        there).

        Returns:
            OrderedDict[str, Tensor]: All weights & biases in the module.
        """
        return OrderedDict(
            (f"coolchic_encoder.{k}", v)
            for k, v in self.coolchic_encoder.get_param().items()
        )

    def set_param(self, param: OrderedDict[str, Tensor]):
        """Replace the current parameters of the module with param.

        This is a thin wrapper around ``nn.Module.load_state_dict``.

        Args:
            param (OrderedDict[str, Tensor]): Parameters to be set.
        """
        self.load_state_dict(param)

    def reinitialize_parameters(self) -> None:
        """Reinitialize in place the different parameters of a FrameEncoder."""
        self.coolchic_encoder.reinitialize_parameters()

    def set_to_train(self) -> None:
        """Set the current model to training mode, in place. This only
        affects the quantization.
        """
        # ``train()`` works in place; no need to rebind ``self`` or to
        # re-assign the sub-modules with their own return value.
        self.train()
        self.coolchic_encoder.train()
        self.inter_coding_module.train()

    def set_to_eval(self) -> None:
        """Set the current model to test mode, in place. This only
        affects the quantization.
        """
        # ``eval()`` works in place; no need to rebind ``self`` or to
        # re-assign the sub-modules with their own return value.
        self.eval()
        self.coolchic_encoder.eval()
        self.inter_coding_module.eval()

    def to_device(self, device: POSSIBLE_DEVICE) -> None:
        """Push a model to a given device.

        Args:
            device: The device on which the model should run.

        Raises:
            AssertionError: If ``device`` is not one of the values listed in
                ``POSSIBLE_DEVICE``.
        """
        # NOTE: kept as an assert so that callers still see an AssertionError;
        # be aware that it is skipped when Python runs with -O.
        assert device in typing.get_args(
            POSSIBLE_DEVICE
        ), f"Unknown device {device}, should be in {typing.get_args(POSSIBLE_DEVICE)}"

        # ``nn.Module.to`` moves the module in place.
        self.to(device)
        # Let the CoolChicEncoder run its own device-specific logic.
        self.coolchic_encoder.to_device(device)

    def save(self) -> BytesIO:
        """Save the FrameEncoder into a bytes buffer and return it.

        The buffer holds a dictionary with the frame settings (bitdepth,
        frame type, frame data type), the CoolChicEncoder parameters and
        weights, its network quantization steps and exp-golomb counts, and,
        when available, its full precision parameters.

        Returns:
            Bytes representing the saved coolchic model
        """
        buffer = BytesIO()
        data_to_save = {
            "bitdepth": self.bitdepth,
            "frame_type": self.frame_type,
            "frame_data_type": self.frame_data_type,
            "coolchic_encoder_param": self.coolchic_encoder_param,
            "coolchic_encoder": self.coolchic_encoder.get_param(),
            "coolchic_nn_q_step": self.coolchic_encoder.get_network_quantization_step(),
            "coolchic_nn_expgol_cnt": self.coolchic_encoder.get_network_expgol_count(),
        }

        # The full precision parameters are optional: store them only if set.
        if self.coolchic_encoder.full_precision_param is not None:
            data_to_save["coolchic_full_precision_param"] = (
                self.coolchic_encoder.full_precision_param
            )

        torch.save(data_to_save, buffer)
        return buffer
def load_frame_encoder(raw_bytes: BytesIO) -> FrameEncoder:
    """From already loaded raw bytes, load & return a FrameEncoder.

    Args:
        raw_bytes: Already loaded raw bytes from which we'll instantiate
            the FrameEncoder. This is expected to be the buffer produced by
            ``FrameEncoder.save()``.

    Returns:
        Frame encoder loaded by the function
    """
    # Reset the stream position to the beginning of the BytesIO object & load it
    raw_bytes.seek(0)

    # The saved dictionary contains a pickled parameter object
    # ("coolchic_encoder_param"), not only tensors, so the restricted
    # ``weights_only=True`` unpickler (the default from PyTorch 2.6 onwards)
    # cannot be used here.
    # SECURITY NOTE: full unpickling can execute arbitrary code. Only call this
    # function on buffers coming from a trusted source.
    loaded_data = torch.load(raw_bytes, map_location="cpu", weights_only=False)

    # Create a frame encoder from the stored parameters
    frame_encoder = FrameEncoder(
        coolchic_encoder_param=loaded_data["coolchic_encoder_param"],
        frame_type=loaded_data["frame_type"],
        frame_data_type=loaded_data["frame_data_type"],
        bitdepth=loaded_data["bitdepth"],
    )

    # Load the different submodules (only one cool-chic for now)
    coolchic_encoder = frame_encoder.coolchic_encoder
    coolchic_encoder.set_param(loaded_data["coolchic_encoder"])
    coolchic_encoder.nn_q_step = loaded_data["coolchic_nn_q_step"]

    # Check if coolchic_nn_expgol_cnt is present in loaded data for backward
    # compatibility. Not meant to stay very long.
    if "coolchic_nn_expgol_cnt" in loaded_data:
        coolchic_encoder.nn_expgol_cnt = loaded_data["coolchic_nn_expgol_cnt"]

    # Full precision parameters are only stored when they were available at
    # save time.
    if "coolchic_full_precision_param" in loaded_data:
        coolchic_encoder.full_precision_param = loaded_data[
            "coolchic_full_precision_param"
        ]

    return frame_encoder