Source code for enc.training.test

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md


import copy
import typing
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from enc.component.frame import FrameEncoder, FrameEncoderOutput, NAME_COOLCHIC_ENC
from enc.component.types import DescriptorCoolChic, DescriptorNN
from enc.io.format.yuv import convert_420_to_444
from enc.training.loss import (
    LossFunctionOutput,
    _compute_mse,
    loss_function,
)
from enc.training.manager import FrameEncoderManager
from enc.utils.codingstructure import Frame


@dataclass
class FrameEncoderLogs(LossFunctionOutput):
    """Output of the test function, i.e. the actual results of the encoding
    of one frame by the frame encoder.

    It inherits from LossFunctionOutput, meaning that all attributes of
    LossFunctionOutput are also attributes of FrameEncoderLogs. A
    FrameEncoderLogs is thus initialized from a LossFunctionOutput, and all
    attributes of the LossFunctionOutput are copied as new attributes of the
    class.

    This is what is going to be saved to a log file.
    """

    loss_function_output: LossFunctionOutput  # All outputs from the loss function, copied in __post_init__
    frame_encoder_output: FrameEncoderOutput  # Output of frame encoder forward
    original_frame: Frame  # Non coded frame

    detailed_rate_nn: Dict[
        str, DescriptorCoolChic
    ]  # Rate for each NN weights & bias [bit]
    quantization_param_nn: Dict[
        str, DescriptorCoolChic
    ]  # Quantization step for each NN weights & bias [ / ]
    expgol_count_nn: Dict[
        str, DescriptorCoolChic
    ]  # Exp-Golomb count parameter for each NN weights & bias [ / ]

    lmbda: float  # Rate constraint in D + lambda * R [ / ]
    encoding_time_second: float  # Duration of the encoding [sec]
    encoding_iterations_cnt: int  # Number of encoding iterations [ / ]
    mac_decoded_pixel: float = 0.0  # Number of multiplications per decoded pixel

    # ==================== Not set by the init function ===================== #
    # Everything here is derived from frame_encoder_output and original_frame

    # ----- CoolChicEncoder outputs
    # Spatial distribution of the rate, obtained by summing the rate of the
    # different features for each spatial location (in bit). [1, 1, H, W]
    spatial_rate_bit: Optional[Tensor] = field(init=False)
    # Feature distribution of the rate, obtained by summing all the spatial
    # locations of a given feature. [Number of latent resolutions]
    feature_rate_bpp: Optional[List[float]] = field(
        init=False, default_factory=lambda: []
    )

    # ----- Computed from loss_function_output.rate_latent_bpp
    rate_latent_residue_bpp: float = field(init=False)
    rate_latent_motion_bpp: float = field(init=False)

    # ----- Inter coding module outputs
    alpha: Optional[Tensor] = field(init=False, default=None)  # Inter / intra switch
    beta: Optional[Tensor] = field(
        init=False, default=None
    )  # Bi-directional prediction weighting
    residue: Optional[Tensor] = field(init=False, default=None)  # Residue
    flow_1: Optional[Tensor] = field(
        init=False, default=None
    )  # Optical flow for the first reference
    flow_2: Optional[Tensor] = field(
        init=False, default=None
    )  # Optical flow for the second reference
    pred: Optional[Tensor] = field(init=False, default=None)  # Temporal prediction
    masked_pred: Optional[Tensor] = field(
        init=False, default=None
    )  # Temporal prediction * alpha

    # ----- Prediction performance
    alpha_mean: Optional[float] = field(init=False, default=None)  # Mean value of alpha
    beta_mean: Optional[float] = field(init=False, default=None)  # Mean value of beta
    pred_psnr_db: Optional[float] = field(
        init=False, default=None
    )  # PSNR of the prediction
    dummy_pred_psnr_db: Optional[float] = field(
        init=False, default=None
    )  # PSNR of a prediction if we had no motion

    # ----- Miscellaneous quantities recovered from self.frame
    img_size: Tuple[int, int] = field(init=False)  # [Height, Width]
    n_pixels: int = field(init=False)  # Height x Width
    display_order: int = field(
        init=False
    )  # Index of the current frame in display order
    coding_order: int = field(init=False)  # Index of the current frame in coding order
    frame_offset: int = field(
        init=False, default=0
    )  # Skip the first <frame_offset> frames of the video
    seq_name: str = field(init=False)  # Name of the sequence to which this frame belongs

    # ----- Neural network rate in bits per pixel
    detailed_rate_nn_bpp: Dict[str, DescriptorCoolChic] = field(
        init=False
    )  # Rate for each NN weights & bias [bpp]

    def __post_init__(self):
        # ----- Copy all the attributes of loss_function_output
        for f in fields(self.loss_function_output):
            setattr(self, f.name, getattr(self.loss_function_output, f.name))

        # ----- Retrieve info from the frame
        self.img_size = self.original_frame.data.img_size
        self.n_pixels = self.original_frame.data.n_pixels
        self.display_order = self.original_frame.display_order
        self.coding_order = self.original_frame.coding_order
        self.frame_offset = self.original_frame.frame_offset
        self.seq_name = self.original_frame.seq_name

        # ----- Information related to the NN rate / quantization & exp-golomb
        # Iterate on all the possible names for a cool-chic encoder. If we
        # find something, we keep the value, otherwise we put a 0 so that all
        # fields are always filled.
        # Divide each entry of self.detailed_rate_nn by the number of pixels.
        self.detailed_rate_nn_bpp: Dict[NAME_COOLCHIC_ENC, DescriptorCoolChic] = {}
        # Loop on all possible cool-chic encoder names
        for cc_name in typing.get_args(NAME_COOLCHIC_ENC):
            # Loop on all neural networks composing a Cool-chic Encoder
            self.detailed_rate_nn_bpp[cc_name] = {}
            for nn_name in [x.name for x in fields(DescriptorCoolChic)]:
                self.detailed_rate_nn_bpp[cc_name][nn_name] = {}
                # Loop on weights and biases
                for weight_or_bias in [x.name for x in fields(DescriptorNN)]:
                    # Convert the value to bpp if we have one
                    if cc_name in self.detailed_rate_nn:
                        self.detailed_rate_nn_bpp[cc_name][nn_name][weight_or_bias] = (
                            self.detailed_rate_nn[cc_name][nn_name][weight_or_bias]
                            / self.n_pixels
                        )
                    # Else we set it to 0
                    else:
                        self.detailed_rate_nn_bpp[cc_name][nn_name][weight_or_bias] = 0

        # ----- Get the sum of the latent rate for each cool-chic encoder
        for cc_name in typing.get_args(NAME_COOLCHIC_ENC):
            setattr(
                self,
                f"rate_latent_{cc_name}_bpp",
                self.rate_latent_bpp.get(cc_name, 0),
            )

        # ----- Copy all the quantities present in InterCodingModuleOutput
        quantities_from_inter_coding = [
            "alpha",
            "beta",
            "residue",
            "flow_1",
            "flow_2",
            "pred",
            "masked_pred",
        ]
        for k in quantities_from_inter_coding:
            if k in self.frame_encoder_output.additional_data:
                setattr(self, k, self.frame_encoder_output.additional_data.get(k))

        # ----- Compute several additional quantities
        if self.alpha is not None:
            self.alpha_mean = self.alpha.mean().item()
        else:
            # No alpha for intra frames
            self.alpha_mean = 0

        if self.beta is not None:
            self.beta_mean = self.beta.mean().item()
        else:
            # No beta for I & P frames
            self.beta_mean = 0

        if self.pred is not None:
            # Transform the reference to yuv444 if needed
            if self.original_frame.data.frame_data_type == "yuv420":
                original_frame_data = convert_420_to_444(self.original_frame.data.data)
            else:
                original_frame_data = self.original_frame.data.data

            self.pred_psnr_db = -10 * torch.log10(
                _compute_mse(self.pred, original_frame_data) + 1e-10
            )

            # Compute the dumbest prediction, i.e. the average of the references
            dummy_pred = torch.zeros_like(self.pred)
            for ref in self.original_frame.refs_data:
                dummy_pred += ref.data
            dummy_pred /= len(self.original_frame.refs_data)

            self.dummy_pred_psnr_db = -10 * torch.log10(
                _compute_mse(dummy_pred, original_frame_data) + 1e-10
            )
        else:
            # No prediction for intra frames
            self.pred_psnr_db = 0
            self.dummy_pred_psnr_db = 0

        # ----- Retrieve things related to the CoolChicEncoder from the
        # ----- additional outputs of the frame encoder
        if "detailed_rate_bit" in self.frame_encoder_output.additional_data:
            detailed_rate_bit = self.frame_encoder_output.additional_data.get(
                "detailed_rate_bit"
            )

            # Sum on the last three dimensions
            self.feature_rate_bpp = [
                x.sum(dim=(-1, -2, -3)) / (self.img_size[0] * self.img_size[1])
                for x in detailed_rate_bit
            ]

            upscaled_rate = []
            for rate in detailed_rate_bit:
                cur_c, cur_h, cur_w = rate.size()[-3:]

                # Ignore tensors with no channel
                if cur_c == 0:
                    continue

                # Rate is in bit, but since we upsample the rate values to
                # match the actual image size, we want to keep the total
                # number of bits consistent. To do so, we divide the rate by
                # the upsampling ratio.
                # Example: a 2x2 feature map with 8 bits per sample gives a
                # 4x4 visualisation with 2 bits per sample, so that the total
                # number of bits stays identical.
                rate /= (self.img_size[0] * self.img_size[1]) / (cur_h * cur_w)
                upscaled_rate.append(
                    F.interpolate(rate, size=self.img_size, mode="nearest")
                )

            upscaled_rate = torch.cat(upscaled_rate, dim=1)
            self.spatial_rate_bit = upscaled_rate.sum(dim=1, keepdim=True)
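
    # A minimal sketch (not part of the original module) of the PSNR formula
    # used in __post_init__ above, assuming signals normalised in [0., 1.]:
    #
    #     mse = torch.tensor(1e-4)
    #     psnr_db = -10 * torch.log10(mse + 1e-10)   # ~40 dB
    #
    # The 1e-10 term only guards against log10(0) for a perfect prediction.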
    def pretty_string(
        self,
        show_col_name: bool = False,
        mode: Literal["all", "short"] = "all",
        additional_data: Dict[str, Any] = {},
    ) -> str:
        """Return a pretty string formatting the data within the class.

        Args:
            show_col_name (bool, optional): True to also display the column
                names. Defaults to False.
            mode (str, optional): Either "short" or "all". Defaults to "all".
            additional_data (Dict[str, Any], optional): Extra key-value pairs
                appended as additional columns. Defaults to {}.

        Returns:
            str: The formatted results
        """
        col_name = ""
        values = ""
        COL_WIDTH = 10
        INTER_COLUMN_SPACE = " "

        for k in fields(self):
            if not self._should_be_printed(k.name, mode=mode):
                continue

            # ! Deep copying is needed, though it is not clear why.
            val = copy.deepcopy(getattr(self, k.name))

            if val is None:
                continue

            if k.name == "feature_rate_bpp":
                for i in range(len(val)):
                    col_name += f'{k.name + f"_{str(i).zfill(2)}":<{COL_WIDTH}}{INTER_COLUMN_SPACE}'
                    values += f"{self._format_value(val[i], attribute_name=k.name):<{COL_WIDTH}}{INTER_COLUMN_SPACE}"

            # This concerns the different neural networks composing the
            # different cool-chic encoders.
            elif k.name in [
                "detailed_rate_nn_bpp",
                "quantization_param_nn",
                "expgol_count_nn",
            ]:
                match k.name:
                    case "detailed_rate_nn_bpp":
                        col_name_suffix = "_rate_bpp"
                    case "quantization_param_nn":
                        col_name_suffix = "_q_step"
                    case "expgol_count_nn":
                        col_name_suffix = "_exp_cnt"
                    case _:
                        pass

                # Loop on all possible cool-chic encoder names
                for cc_name in typing.get_args(NAME_COOLCHIC_ENC):
                    # Loop on all neural networks composing a Cool-chic Encoder
                    for nn_name in [x.name for x in fields(DescriptorCoolChic)]:
                        for weight_or_bias in [x.name for x in fields(DescriptorNN)]:
                            tmp_col_name = (
                                cc_name
                                + "_"
                                + nn_name
                                + "_"
                                + weight_or_bias
                                + col_name_suffix
                            )
                            col_name += (
                                f"{tmp_col_name:<{COL_WIDTH}}{INTER_COLUMN_SPACE}"
                            )

                            try:
                                tmp_val = val[cc_name][nn_name][weight_or_bias]
                            except KeyError:
                                tmp_val = 0

                            tmp_val = self._format_value(
                                tmp_val, attribute_name=k.name
                            )
                            values += f"{tmp_val:<{COL_WIDTH}}{INTER_COLUMN_SPACE}"

            # Default case
            else:
                col_name += f"{self._format_column_name(k.name):<{COL_WIDTH}}{INTER_COLUMN_SPACE}"
                values += f"{self._format_value(val, attribute_name=k.name):<{COL_WIDTH}}{INTER_COLUMN_SPACE}"

        for k, v in additional_data.items():
            col_name += f"{k:<{COL_WIDTH}}{INTER_COLUMN_SPACE}"
            values += f"{v:<{COL_WIDTH}}{INTER_COLUMN_SPACE}"

        if show_col_name:
            return col_name + "\n" + values
        else:
            return values
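
    # Illustration (not part of the original module) of the fixed-width
    # formatting used in pretty_string, with COL_WIDTH = 10:
    #
    #     f"{'loss':<10}" + " "  ->  "loss       "
    #
    # Each column is left-aligned, padded to COL_WIDTH characters and
    # followed by INTER_COLUMN_SPACE, so col_name and values line up.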
    def _should_be_printed(self, attribute_name: str, mode: str) -> bool:
        """Return True if the attribute named <attribute_name> should be
        printed in mode <mode>.

        Args:
            attribute_name (str): Candidate attribute to print
            mode (str): Either "short" or "all"

        Returns:
            bool: True if the attribute should be printed, False otherwise
        """
        # Syntax: {'attribute': [printed in mode xxx]}
        ATTRIBUTES = {
            # ----- This is printed in every mode
            "loss": ["short", "all"],
            "psnr_db": ["short", "all"],
            "total_rate_bpp": ["short", "all"],
            "total_rate_latent_bpp": ["short", "all"],
            "total_rate_nn_bpp": ["short", "all"],
            "encoding_time_second": ["short", "all"],
            "encoding_iterations_cnt": ["short", "all"],
            # ----- This is only printed in mode all
            "alpha_mean": ["all"],
            "beta_mean": ["all"],
            "pred_psnr_db": ["all"],
            "dummy_pred_psnr_db": ["all"],
            "display_order": ["all"],
            "coding_order": ["all"],
            "frame_offset": ["all"],
            "lmbda": ["all"],
            "seq_name": ["all"],
            "feature_rate_bpp": ["all"],
            "detailed_rate_nn_bpp": ["all"],
            "ms_ssim_db": ["all"],
            "lpips_db": ["all"],
            "n_pixels": ["all"],
            "img_size": ["all"],
            "mac_decoded_pixel": ["all"],
            # Uncomment to log the quantization step and exp-golomb param
            # "quantization_param_nn": ["all"],
            # "expgol_count_nn": ["all"],
        }

        for cc_name in typing.get_args(NAME_COOLCHIC_ENC):
            ATTRIBUTES[f"rate_latent_{cc_name}_bpp"] = ["all"]

        # Add a few quantities for some particular frame types
        if self.original_frame.frame_type != "I":
            ATTRIBUTES["rate_latent_residue_bpp"] = ["all", "short"]
            ATTRIBUTES["rate_latent_motion_bpp"] = ["all", "short"]
            ATTRIBUTES["alpha_mean"].append("short")
            ATTRIBUTES["pred_psnr_db"].append("short")
            ATTRIBUTES["dummy_pred_psnr_db"].append("short")

        if self.original_frame.frame_type == "B":
            ATTRIBUTES["beta_mean"].append("short")

        if attribute_name not in ATTRIBUTES:
            return False

        if mode not in ATTRIBUTES.get(attribute_name):
            return False

        return True

    def _format_value(
        self, value: Union[str, int, float, Tensor], attribute_name: str = ""
    ) -> str:
        if attribute_name == "loss":
            value *= 1000

        if attribute_name == "img_size":
            value = "x".join([str(tmp) for tmp in value])

        # Number of digits after the decimal point
        DEFAULT_ACCURACY = "6"
        ACCURACY = {
            "alpha_mean": "3",
            "beta_mean": "3",
            "pred_psnr_db": "3",
            "dummy_pred_psnr_db": "3",
            "encoding_time_second": "1",
        }
        cur_accuracy = ACCURACY.get(attribute_name, DEFAULT_ACCURACY)

        if isinstance(value, str):
            return value
        elif isinstance(value, int):
            return str(value)
        elif isinstance(value, float):
            return f"{value:.{cur_accuracy}f}"
        elif isinstance(value, Tensor):
            return f"{value.item():.{cur_accuracy}f}"

    def _format_column_name(self, col_name: str) -> str:
        # Syntax: {'long_name': 'short_name'}
        LONG_TO_SHORT = {
            "total_rate_latent_bpp": "latent_bpp",
            "total_rate_nn_bpp": "nn_bpp",
            "total_rate_bpp": "rate_bpp",
            "encoding_time_second": "time_sec",
            "encoding_iterations_cnt": "itr",
            "alpha_mean": "alpha",
            "beta_mean": "beta",
            "pred_psnr_db": "pred_db",
            "dummy_pred_psnr_db": "dummy_pred",
            "rate_latent_residue_bpp": "residue_bpp",
            "rate_latent_motion_bpp": "motion_bpp",
        }

        if col_name not in LONG_TO_SHORT:
            return col_name
        else:
            return LONG_TO_SHORT.get(col_name)
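
# For reference (not part of the original module): the nested dictionaries
# printed by pretty_string follow the layout
#
#     detailed_rate_nn_bpp = {cc_name: {nn_name: {weight_or_bias: bpp}}}
#
# where cc_name spans the NAME_COOLCHIC_ENC literals (judging from the
# rate_latent_residue_bpp / rate_latent_motion_bpp fields above, these are
# "residue" and "motion"), nn_name spans the fields of DescriptorCoolChic
# (the neural networks of one Cool-chic encoder), and the innermost keys are
# the fields of DescriptorNN, commented throughout as weights & biases.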

@torch.no_grad()
def test(
    frame_encoder: FrameEncoder,
    frame: Frame,
    frame_encoder_manager: FrameEncoderManager,
) -> FrameEncoderLogs:
    """Evaluate the performance of a ``FrameEncoder`` when encoding a ``Frame``.

    Args:
        frame_encoder: FrameEncoder to be evaluated.
        frame: The original frame to compress. It provides both the target
            (original non compressed frame) as well as the reference(s)
            (list of already decoded images).
        frame_encoder_manager: Contains (among other things) the rate
            constraint :math:`\\lambda`, required to compute a meaningful loss
            :math:`\\mathcal{L} = \\mathrm{D} + \\lambda \\mathrm{R}`. It is
            also used to track the total encoding time and the number of
            encoding iterations.

    Returns:
        Many logs on the performance of the FrameEncoder. See the doc of
        ``FrameEncoderLogs``.
    """
    # 1. Get the rate associated to the network ----------------------------- #
    # The rate associated with the network is zero if it has not been
    # quantized before calling the test function.
    detailed_rate_nn, total_rate_nn_bit = frame_encoder.get_network_rate()

    # 2. Measure performance ------------------------------------------------ #
    frame_encoder.set_to_eval()

    # flag_additional_outputs set to True to obtain more outputs
    frame_encoder_out = frame_encoder.forward(
        reference_frames=[ref_i.data for ref_i in frame.refs_data],
        quantizer_noise_type="none",
        quantizer_type="hardround",
        AC_MAX_VAL=-1,
        flag_additional_outputs=True,
    )

    loss_fn_output = loss_function(
        frame_encoder_out.decoded_image,
        frame_encoder_out.rate,
        frame.data.data,
        lmbda=frame_encoder_manager.lmbda,
        total_rate_nn_bit=total_rate_nn_bit,
        compute_logs=True,
    )

    encoder_logs = FrameEncoderLogs(
        loss_function_output=loss_fn_output,
        frame_encoder_output=frame_encoder_out,
        original_frame=frame,
        detailed_rate_nn=detailed_rate_nn,
        quantization_param_nn=frame_encoder.get_network_quantization_step(),
        expgol_count_nn=frame_encoder.get_network_expgol_count(),
        encoding_time_second=frame_encoder_manager.total_training_time_sec,
        encoding_iterations_cnt=frame_encoder_manager.iterations_counter,
        mac_decoded_pixel=frame_encoder.get_total_mac_per_pixel(),
        lmbda=frame_encoder_manager.lmbda,
    )

    # 3. Restore training mode ---------------------------------------------- #
    frame_encoder.set_to_train()

    return encoder_logs
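
# A minimal usage sketch (not part of the original module), assuming a
# trained ``frame_encoder``, its ``frame_encoder_manager`` and the ``frame``
# being coded are already available, e.g. at the end of a training loop:
#
#     logs = test(frame_encoder, frame, frame_encoder_manager)
#     print(logs.pretty_string(show_col_name=True, mode="short"))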