# Source code for component.core.coolchic

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md

import math
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, OrderedDict, Tuple, TypedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from fvcore.nn import FlopCountAnalysis, flop_count_table
from torch import Tensor, nn

from coolchic.component.core.arm import (
    Arm,
    Ifce,
    _get_mask_size_ctx,
    _get_neighbor,
    _get_non_zero_pixel_ctx_index,
    _laplace_cdf,
)
from coolchic.component.core.noise import CommonGaussianNoiseGenerator
from coolchic.component.core.quantizer import (
    POSSIBLE_QUANTIZATION_NOISE_TYPE,
    POSSIBLE_QUANTIZER_TYPE,
    quantize,
)
from coolchic.component.core.synthesis import Synthesis
from coolchic.component.core.types import DescriptorCoolChic, DescriptorNN
from coolchic.component.core.upsampling import Upsampling, fixed_upsampling

"""A cool-chic encoder is composed of:
    - A set of 2d hierarchical latent grids
    - An auto-regressive probability module + Inter-feature context extractor.
    - An upsampling module
    - A synthesis.

At its core, it is a tool to compress any spatially organized signal with
one or more features, by representing it as a set of 2d entropy coding-friendly
latent grids. After upsampling, these latent grids allow the synthesis of the
desired signal.
"""


@dataclass
class CoolChicEncoderParameter:
    """Dataclass storing the parameters of a ``CoolChicEncoder``.

    Args:
        layers_synthesis (List[str]): Describes the architecture of the
            synthesis transform. See the :doc:`synthesis documentation
            <synthesis>` for more information.
        linear_stabiliser_synth (bool): Flag indicating the usage of the
            linear stabiliser for the synthesis.
        ups_k_size (int): Upsampling kernel size for the transposed
            convolutions. See the :doc:`upsampling documentation
            <upsampling>` for more information.
        ups_preconcat_k_size (int): Upsampling kernel size for the
            pre-concatenation convolutions. See the :doc:`upsampling
            documentation <upsampling>` for more information.
        ifce_resolution (Optional[Tuple[int, int]]): Lowest and highest base
            two downsampling of the latent using the IFCEs. E.g., (0, 2)
            means latents between downsampling 1/2^0 and 1/2^2. Set to None
            to disable.
        output_feature_ifce (int): Number of output features of the IFCEs.
            Ignored if ifce_resolution is None.
        spatial_context_arm (int): Number of spatial contexts for the ARM.
        linear_stabiliser_arm (bool): Flag indicating the usage of the
            linear stabiliser for the ARM.
        n_hidden_layers_arm (int): Number of hidden layers in the ARM. Set
            to zero for a linear ARM.
        latent_resolution (Tuple[int, int]): Lowest and highest base two
            downsampling of the latent grids. E.g., (0, 4) means 5 latent
            grids from downsampling 1/2^0 to 1/2^4.
        hyperlatent_resolution (Optional[Tuple[int, int]]): Identical to
            latent_resolution but for hyperlatents, *i.e.,* additional
            latent grids which are used only for the entropy modeling and
            not by the synthesis. Set to None to disable.
        flag_common_randomness (bool): Flag indicating the usage of common
            randomness latent grids, with resolution identical to the
            latent_resolution parameter.
        img_size (Tuple[int, int]): Height and width :math:`(H, W)` of the
            frame to be coded.
        final_upsampling_type (Literal["nearest", "bilinear", "bicubic"]):
            If the resolution of the biggest latent grid is smaller than
            the input image, upsample it using the specified filter to the
            image size.
        encoder_gain (int): Multiply the latent by this value before
            quantization. Defaults to 16.
    """

    # ---- Synthesis
    layers_synthesis: List[str]
    linear_stabiliser_synth: bool
    input_feature_synthesis: int = field(init=False)

    # ---- Upsampling
    ups_k_size: int
    ups_preconcat_k_size: int

    # ---- Entropy model
    ifce_resolution: Optional[Tuple[int, int]]
    output_feature_ifce: int
    spatial_context_arm: int
    linear_stabiliser_arm: bool
    n_hidden_layers_arm: int
    total_context_arm: int = field(init=False)

    # ---- Latent grids and hyper latent grids
    latent_resolution: Tuple[int, int]
    hyperlatent_resolution: Optional[Tuple[int, int]]
    flag_common_randomness: bool

    # ---- Others
    img_size: Tuple[int, int]
    # If the synthesis output is smaller than img_size i.e. when the highest
    # latent resolution is not 1/1, there is a final upsampling.
    final_upsampling_type: Literal["nearest", "bilinear", "bicubic"]
    encoder_gain: int = 16

    # ==================== Not set by the init function ===================== #
    # Set to true if there is at least one feature of common randomness requested
    flag_ifce: bool = field(init=False)
    flag_hyperlatent: bool = field(init=False)
    cr_latent_resolution: Optional[Tuple[int, int]] = field(init=False)

    # size_per_latent[i] = Dimension of the i-th latent: (1, 1, H_i, W_i)
    # size_per_latent[0] is the biggest
    size_per_latent: List[Tuple[int, int, int, int]] = field(init=False, default_factory=list)

    # Same thing but for the common randomness latent
    # size_per_latent_cr[i] = Dimension of the i-th common randomness latent: (1, 1, H_i, W_i)
    # size_per_latent_cr[0] is the biggest
    size_per_latent_cr: List[Tuple[int, int, int, int]] = field(init=False, default_factory=list)

    # flag_is_hyperlatent[i] = True --> i-th latent is only used for the entropy
    # decoding and discarded before the upsampling/synthesis
    flag_is_hyperlatent: List[bool] = field(init=False, default_factory=list)

    # input_features_ifce[i] = number of input feature for the ifce associated to the
    # i-th latent grid. Set to zero if the i-th latent does not have an IFCE associated
    input_features_ifce: List[int] = field(init=False, default_factory=list)

    # Total number of latent transmitted
    n_latent_grids: int = field(init=False)
    # ==================== Not set by the init function ===================== #

    def __post_init__(self):
        # Order is important. Some parameters set in post_init_latent() are reused
        # in the initialization of the synthesis or IFCE for instance.
        self.post_init_latent()
        self.post_init_arm()
        self.post_init_common_randomness()
        self.post_init_synthesis()
        self.post_init_ifce()

    def post_init_latent(self) -> None:
        """Compute the spatial dimension of every (hyper)latent grid and flag
        which grids are hyperlatents (discarded before the synthesis)."""
        self.flag_hyperlatent = self.hyperlatent_resolution is not None

        # Compute all latent spatial dimension and fill the flag_is_hyperlatent list
        # to indicate which grids are to be discarded before the synthesis
        if self.flag_hyperlatent:
            min_downsampling = min(self.latent_resolution + self.hyperlatent_resolution)
            max_downsampling = max(self.latent_resolution + self.hyperlatent_resolution)
        else:
            min_downsampling, max_downsampling = self.latent_resolution

        for i in range(min_downsampling, max_downsampling + 1):
            h_grid, w_grid = [int(math.ceil(x / (2**i))) for x in self.img_size]
            cur_size = (1, 1, h_grid, w_grid)

            # Add the grid if it falls inside the required latent resolution
            if self.latent_resolution[0] <= i <= self.latent_resolution[1]:
                self.size_per_latent.append(cur_size)
                self.flag_is_hyperlatent.append(False)

            if self.flag_hyperlatent:
                # Add the grid if it falls inside the required hyperlatent resolution
                if self.hyperlatent_resolution[0] <= i <= self.hyperlatent_resolution[1]:
                    self.size_per_latent.append(cur_size)
                    self.flag_is_hyperlatent.append(True)

        self.n_latent_grids = len(self.size_per_latent)

        if self.flag_common_randomness:
            for i in range(self.latent_resolution[0], self.latent_resolution[1] + 1):
                h_grid, w_grid = [int(math.ceil(x / (2**i))) for x in self.img_size]
                cur_size = (1, 1, h_grid, w_grid)
                self.size_per_latent_cr.append(cur_size)

    def post_init_arm(self) -> None:
        """ARM context = spatial neighbors + inter-feature (IFCE) features."""
        self.total_context_arm = self.spatial_context_arm + self.output_feature_ifce

    def post_init_common_randomness(self) -> None:
        """Common randomness has the same resolution as the latent variables."""
        if self.flag_common_randomness:
            self.cr_latent_resolution = (self.latent_resolution[0], self.latent_resolution[1])
        else:
            self.cr_latent_resolution = None

    def post_init_synthesis(self) -> None:
        """Number of synthesis input features, doubled when common randomness
        features are concatenated to the upsampled latents."""
        # latent_resolution = (0, 6) --> 7 = 6 - 0 + 1 latent features
        self.input_feature_synthesis = self.latent_resolution[1] - self.latent_resolution[0] + 1
        if self.flag_common_randomness:
            self.input_feature_synthesis *= 2

    def post_init_ifce(self) -> None:
        """Compute, for each latent grid, the number of IFCE input features."""
        self.flag_ifce = self.ifce_resolution is not None

        for i, size_latent_i in enumerate(self.size_per_latent):
            # We assume identical downsampling ratio for height and width
            downsampling_ratio = int(math.ceil(math.log2(self.img_size[0] / size_latent_i[-2])))

            if not self.flag_ifce:
                self.input_features_ifce.append(0)
            # We do have an IFCE
            elif self.ifce_resolution[0] <= downsampling_ratio <= self.ifce_resolution[1]:
                # How many latents are already decoded when we're decoding latent_i.
                # max(X, 1) because we always have at least one input feature.
                # Padding if need be
                self.input_features_ifce.append(max(self.n_latent_grids - 1 - i, 1))
            else:
                self.input_features_ifce.append(0)

    def pretty_string(self) -> str:
        """Return a pretty string presenting the CoolChicEncoderParameter."""
        ATTRIBUTE_WIDTH = 35
        VALUE_WIDTH = 80
        s = ""
        for k in fields(self):
            v = getattr(self, k.name)
            # Print only height and width of each latent grid
            if k.name.startswith("size_per_latent"):
                v = [v_i[-2:] for v_i in v]
            s += f"{k.name:<{ATTRIBUTE_WIDTH}}: {str(v):<{VALUE_WIDTH}}\n"
        s += "\n"
        return s
class CoolChicEncoderOutput(TypedDict):
    """``TypedDict`` representing the output of the CoolChicEncoder forward.

    Args:
        raw_out (Tensor): Output of the synthesis :math:`([B, C, H, W])`.
        rate (Tensor): rate associated to each latent (in bits). Shape is
            :math:`(N)`, with :math:`N` the total number of latent variables.
        additional_data (Dict[str, Any]): Any other data required to compute
            some logs, stored inside a dictionary
    """

    raw_out: Tensor
    rate: Tensor
    additional_data: Dict[str, Any]
class CoolChicEncoder(nn.Module):
    """CoolChicEncoder for a single frame."""

    def __init__(self, param: CoolChicEncoderParameter):
        """Instantiate a cool-chic encoder for one frame.

        Args:
            param (CoolChicEncoderParameter): Architecture of the
                `CoolChicEncoder`. See the documentation of
                `CoolChicEncoderParameter` for more information
        """
        super().__init__()

        # Everything is stored inside param
        self.param = param

        assert self.param.img_size is not None, (
            "You are trying to instantiate a CoolChicEncoder from a "
            "CoolChicEncoderParameter with a field img_size set to None. Use "
            "the function coolchic_encoder_param.set_img_size((H, W)) before "
            "instantiating the CoolChicEncoder."
        )

        # ================== Latent related stuff ================= #
        # Encoder-side latent gain applied prior to quantization.
        # NOTE(review): despite the plural name, this is a single scalar
        # (param.encoder_gain) shared by all latent grids.
        self.encoder_gains = param.encoder_gain
        self.latent_grids = instantiate_latent_grids_from_cc_param(self.param)
        self.initialize_latent_grids()
        # Common randomness grids: plain tensors (not parameters), possibly empty
        self.cr = instantiate_common_randomness_from_cc_param(self.param)
        # ================== Latent related stuff ================= #

        # ===================== ARM related stuff ==================== #
        # All context pixels are centered in a mask_size x mask_size window
        # centered on the pixel to be entropy coded
        self.mask_size = _get_mask_size_ctx(self.param.spatial_context_arm)

        # 1D tensor containing the indices of the selected context pixels.
        # register_buffer for automatic device management. We set persistent to
        # False to simply use the "automatically move to device" behaviour,
        # without considering non_zero_pixel_ctx_index as a parameter
        # (i.e. returned by self.parameters())
        self.register_buffer(
            "non_zero_pixel_ctx_index",
            _get_non_zero_pixel_ctx_index(self.param.spatial_context_arm),
            persistent=False,
        )

        self.arm = instantiate_arm_from_cc_param(self.param)
        self.synthesis = instantiate_syn_from_cc_param(self.param)
        self.upsampling = instantiate_ups_from_cc_param(self.param)
        # None when the architecture does not use any IFCE
        self.ifce = instantiate_ifce_from_cc_param(self.param)
        # ===================== ARM related stuff ==================== #

        # Something like ['arm', 'synthesis', 'upsampling', 'ifce']
        self.modules_to_send = [tmp.name for tmp in fields(DescriptorCoolChic)]
        if self.ifce is None:
            self.modules_to_send.remove("ifce")

        # ======================== Monitoring ======================== #
        # Pretty string representing the decoder complexity
        self.flops_str = ""
        # Total number of multiplications to decode the image
        self.total_flops = 0.0
        self.flops_per_module = {k: 0 for k in self.modules_to_send}
        # Fill the three attributes above
        self.get_flops()
        # ======================== Monitoring ======================== #

        # Track the quantization step of each neural network, None if the
        # module is not yet quantized. Default initialization of
        # DescriptorCoolChic is None everywhere
        self.nn_q_step = DescriptorCoolChic()

        # Track the exponent of the exp-golomb code used for the NN parameters.
        # None if module is not yet quantized. Default initialization of
        # DescriptorCoolChic is None everywhere
        self.nn_expgol_cnt = DescriptorCoolChic()

        # Copy of the full precision parameters, set just before calling the
        # quantize_model() function. This is done through the
        # self._store_full_precision_param() function
        self.full_precision_param = None
# ------- Actual forward
    def forward(
        self,
        quantizer_noise_type: POSSIBLE_QUANTIZATION_NOISE_TYPE = "gaussian",
        quantizer_type: POSSIBLE_QUANTIZER_TYPE = "softround",
        soft_round_temperature: Optional[Tensor] = torch.tensor(0.35),
        noise_parameter: Optional[Tensor] = torch.tensor(0.22),
        AC_MAX_VAL: int = -1,
        flag_additional_outputs: bool = False,
        no_common_randomness: bool = False,
        only_common_randomness: bool = False,
    ) -> CoolChicEncoderOutput:
        """Perform the CoolChicEncoder forward pass, used during training.

        The main steps are as follows:

        1. Scale & quantize the encoder-side latent (multiplied by the scalar
           encoder gain ``self.encoder_gains``) to get the decoder-side
           latent.
        2. Measure the rate of the decoder-side latent with the ARM, whose
           context concatenates spatial neighbors and, when enabled, the
           inter-feature context produced by the IFCE.
        3. Upsample and synthesize the latent to get the output frame.

        Args:
            quantizer_noise_type: Defaults to ``"gaussian"``.
            quantizer_type: Defaults to ``"softround"``.
            soft_round_temperature: Soft round temperature. This is used for
                softround modes as well as the ste mode to simulate the
                derivative in the backward. Defaults to 0.35.
            noise_parameter: noise distribution parameter. Defaults to 0.22.
            AC_MAX_VAL: If different from -1, clamp the value to be in
                :math:`[-AC\\_MAX\\_VAL; AC\\_MAX\\_VAL + 1]` to write the
                actual bitstream. Defaults to -1.
            flag_additional_outputs: True to fill
                ``CoolChicEncoderOutput['additional_data']`` with many
                different quantities which can be used to analyze Cool-chic
                behavior. Defaults to False.
            no_common_randomness: Zero out the upsampled common randomness
                before the synthesis. Only relevant if the model uses common
                randomness.
            only_common_randomness: Zero out the upsampled latent so the
                synthesis only sees the common randomness features.

        Returns:
            Output of Cool-chic training forward pass.
        """
        decoder_side_latent = self.get_quantize_latent(
            quantizer_noise_type=quantizer_noise_type,
            quantizer_type=quantizer_type,
            soft_round_temperature=soft_round_temperature,
            noise_parameter=noise_parameter,
            AC_MAX_VAL=AC_MAX_VAL,
        )

        # ----- ARM to estimate the distribution and the rate of each latent
        # As for the quantization, we flatten all the latent and their context
        # so that the ARM network is only called once.
        # flat_latent: [N, 1] tensor describing N latents
        # flat_context: [N, context_size] tensor describing each latent context

        # Get all the context as a single 2D vector of size [B, context size]
        flat_context_spatial = []
        flat_latent = []
        flat_context_inter_ft = []

        if self.param.flag_ifce:
            # Nearest-neighbor upsampled versions of the latents, used as the
            # "already decoded" input features of the IFCE below.
            _, intermediate_latent_ups = fixed_upsampling(decoder_side_latent, mode="nearest")

        for idx_latent, spatial_latent_i in enumerate(decoder_side_latent):
            # Empty grids contribute nothing to the rate
            if spatial_latent_i.numel() == 0:
                continue

            flat_latent.append(spatial_latent_i.view(-1))

            cur_context_spatial = _get_neighbor(
                spatial_latent_i, self.non_zero_pixel_ctx_index, self.mask_size
            )
            cur_context_spatial = rearrange(cur_context_spatial, "b 1 n_context -> b n_context")
            flat_context_spatial.append(cur_context_spatial)

            # Inter-feature context for this grid, when an IFCE is attached to it
            if self.param.flag_ifce and self.param.input_features_ifce[idx_latent] > 0:
                # NOTE(review): indexing intermediate_latent_ups back-to-front
                # pairs grid idx_latent with the upsampled latents decoded
                # before it — confirm against fixed_upsampling's output order.
                already_decoded_latent = intermediate_latent_ups[
                    len(self.latent_grids) - 1 - idx_latent
                ]
                cur_context_inter_ft = self.ifce(
                    # Flatten for the ARM forward
                    rearrange(already_decoded_latent, "1 c h w -> (h w) c"),
                    idx_latent,
                )
                cur_context_inter_ft = rearrange(
                    cur_context_inter_ft,
                    "(h w) c -> 1 c h w",
                    h=already_decoded_latent.size()[2],
                    w=already_decoded_latent.size()[3],
                )

                # Interpolate one last time to reach the resolution of
                # spatial_latent_i, then crop the possible extra row/column.
                h_i, w_i = spatial_latent_i.size()[-2:]
                cur_context_inter_ft = F.interpolate(
                    cur_context_inter_ft, scale_factor=2.0, mode="nearest"
                )[:, :, :h_i, :w_i]
                inter_neighbors = rearrange(cur_context_inter_ft, "b c h w -> (b h w) c")
                flat_context_inter_ft.append(inter_neighbors)

            # No inter feature ARM for this level: pad the inter-feature
            # context with zeros so all grids share the same context width.
            elif self.param.flag_ifce:
                h_i, w_i = spatial_latent_i.size()[-2:]
                padded_inter_neighbors = torch.zeros(
                    h_i * w_i, self.param.output_feature_ifce, device=spatial_latent_i.device
                )
                flat_context_inter_ft.append(padded_inter_neighbors)

        flat_context_spatial = torch.cat(flat_context_spatial, dim=0)
        flat_latent = torch.cat(flat_latent, dim=0)

        if self.param.flag_ifce:
            flat_context_inter_ft = torch.cat(flat_context_inter_ft, dim=0)
            flat_context = torch.cat((flat_context_spatial, flat_context_inter_ft), dim=1)
        else:
            flat_context = flat_context_spatial

        # Feed the context to the arm MLP and get mu and scale
        flat_mu, flat_scale = self.arm.reparameterize_output(self.arm(flat_context))

        # Compute the rate (i.e. the entropy of flat latent knowing mu and scale)
        proba = torch.clamp_min(
            _laplace_cdf(flat_latent + 0.5, flat_mu, flat_scale)
            - _laplace_cdf(flat_latent - 0.5, flat_mu, flat_scale),
            min=2**-16,  # No value can cost more than 16 bits.
        )
        flat_rate = -torch.log2(proba)

        # Discard the hyperlatent:
        # keep only the feature maps assigned to the image reconstruction
        decoder_side_latent_syn = [
            x for x, m in zip(decoder_side_latent, self.param.flag_is_hyperlatent) if not m
        ]

        ups_latent = self.upsampling(decoder_side_latent_syn)

        # Upsampling and synthesis to get the output
        if self.param.flag_common_randomness:
            # ups_noise is [1, C, H, W] where C = len(self.cr) and H, W is the
            # spatial resolution of the highest resolution in self.cr e.g.
            # self.cr[0].size()[-2:]. If needed we interpolate once more to
            # reach the resolution of the image to be decoded.
            ups_noise, _ = fixed_upsampling(self.cr)
            ups_noise = F.interpolate(ups_noise, size=self.param.img_size, mode="bicubic")
            if no_common_randomness:
                ups_noise = ups_noise * 0
            if only_common_randomness:
                ups_latent = ups_latent * 0
            syn_in = torch.cat([ups_latent, ups_noise], dim=1)
        else:
            syn_in = ups_latent

        synth_out = self.synthesis(syn_in)

        # Upsample the output of the synthesis with the configured filter if
        # required (no-op when the synthesis already works at image size)
        synthesis_output = F.interpolate(
            synth_out, size=self.param.img_size, mode=self.param.final_upsampling_type
        )
        # Trim out additional pixels due to the final upsampling
        synthesis_output = synthesis_output[
            :, :, : self.param.img_size[0], : self.param.img_size[1]
        ]

        additional_data = {}
        if flag_additional_outputs:
            # Prepare list to accommodate the visualisations
            additional_data["detailed_sent_latent"] = []
            additional_data["detailed_mu"] = []
            additional_data["detailed_scale"] = []
            # NOTE(review): detailed_log_scale is initialized but never
            # filled below — it always stays an empty list.
            additional_data["detailed_log_scale"] = []
            additional_data["detailed_rate_bit"] = []
            additional_data["detailed_centered_latent"] = []

            # "Pointer" for the reading of the 1D scale, mu and rate
            cnt = 0
            for index_latent_res, _ in enumerate(self.latent_grids):
                c_i, h_i, w_i = decoder_side_latent[index_latent_res].size()[-3:]
                additional_data["detailed_sent_latent"].append(
                    decoder_side_latent[index_latent_res].view((1, c_i, h_i, w_i))
                )

                # Scale, mu and rate are 1D tensors where the N latent grids
                # are flattened together. As such we have to read the
                # appropriate number of values in this 1D vector to
                # reconstruct the i-th grid in 2D
                mu_i, scale_i, rate_i = [
                    # Read c_i * h_i * w_i values starting from cnt
                    tmp[cnt : cnt + (c_i * h_i * w_i)].view((1, c_i, h_i, w_i))
                    for tmp in [flat_mu, flat_scale, flat_rate]
                ]
                cnt += c_i * h_i * w_i
                additional_data["detailed_mu"].append(mu_i)
                additional_data["detailed_scale"].append(scale_i)
                additional_data["detailed_rate_bit"].append(rate_i)
                additional_data["detailed_centered_latent"].append(
                    additional_data["detailed_sent_latent"][-1] - mu_i
                )

            additional_data["detailed_ups_latent"] = ups_latent
            additional_data["synthesis"] = synth_out
            if self.param.flag_common_randomness:
                additional_data["detailed_ups_noise"] = ups_noise

        res: CoolChicEncoderOutput = {
            "raw_out": synthesis_output,
            "rate": flat_rate,
            "additional_data": additional_data,
        }
        return res
def get_quantize_latent( self, quantizer_noise_type: POSSIBLE_QUANTIZATION_NOISE_TYPE = "kumaraswamy", quantizer_type: POSSIBLE_QUANTIZER_TYPE = "softround", soft_round_temperature: Optional[Tensor] = torch.tensor(0.3), noise_parameter: Optional[Tensor] = torch.tensor(1.0), AC_MAX_VAL: int = -1, ): # ------ Encoder-side: quantize the latent # Convert the N [1, C, H_i, W_i] 4d latents with different resolutions # to a single flat vector. This allows to call the quantization # only once, which is faster encoder_side_flat_latent = torch.cat([latent_i.view(-1) for latent_i in self.latent_grids]) flat_decoder_side_latent = quantize( encoder_side_flat_latent * self.encoder_gains, quantizer_noise_type if self.training else "none", quantizer_type if self.training else "hardround", soft_round_temperature, noise_parameter, ) # Clamp latent if we need to write a bitstream if AC_MAX_VAL != -1: flat_decoder_side_latent = torch.clamp( flat_decoder_side_latent, -AC_MAX_VAL, AC_MAX_VAL - 1 ) # Convert back the 1d tensor to a list of N [1, C, H_i, W_i] 4d latents. # This require a few additional information about each individual # latent dimension, stored in self.param.size_per_latent decoder_side_latent = [] cnt = 0 for latent_size in self.param.size_per_latent: b, c, h, w = latent_size # b should be one latent_numel = b * c * h * w decoder_side_latent.append( flat_decoder_side_latent[cnt : cnt + latent_numel].view(latent_size) ) cnt += latent_numel return decoder_side_latent # ------- Getter / Setter and Initializer
[docs] def get_param(self) -> OrderedDict[str, Tensor]: """Return **a copy** of the weights and biases inside the module. Returns: OrderedDict[str, Tensor]: A copy of all weights & biases in the module. """ param = OrderedDict({}) param.update( { # Detach & clone to create a copy f"latent_grids.{k}": v.detach().clone() for k, v in self.latent_grids.named_parameters() } ) param.update({f"arm.{k}": v for k, v in self.arm.get_param().items()}) param.update({f"upsampling.{k}": v for k, v in self.upsampling.get_param().items()}) param.update({f"synthesis.{k}": v for k, v in self.synthesis.get_param().items()}) if self.ifce is not None: param.update({f"ifce.{k}": v for k, v in self.ifce.get_param().items()}) return param
    def set_param(self, param: OrderedDict[str, Tensor]):
        """Replace the current parameters of the module with param.

        Args:
            param (OrderedDict[str, Tensor]): Parameters to be set.
        """
        self.load_state_dict(param)
[docs] def initialize_latent_grids(self) -> None: """Initialize the latent grids. The different tensors composing the latent grids must have already been created e.g. through ``torch.empty()``. """ for latent_index, latent_value in enumerate(self.latent_grids): self.latent_grids[latent_index] = nn.Parameter( torch.zeros_like(latent_value), requires_grad=True )
    def reinitialize_parameters(self):
        """Reinitialize in place the different parameters of a CoolChicEncoder,
        namely the latent grids, the arm, the upsampling and the synthesis."""
        self.arm.reinitialize_parameters()
        self.upsampling.reinitialize_parameters()
        self.synthesis.reinitialize_parameters()
        self.initialize_latent_grids()

        # Reset the quantization steps and exp-golomb count of the neural
        # network to None since we are resetting the parameters.
        # Default DescriptorCoolChic initialization is None.
        self.nn_q_step = DescriptorCoolChic()
        self.nn_expgol_cnt = DescriptorCoolChic()
    def _store_full_precision_param(self) -> None:
        """Store the current parameters inside self.full_precision_param.

        This function checks that there is no self.nn_q_step and
        self.nn_expgol_cnt already saved. That would mean that we no longer
        have full precision parameters but quantized ones.
        """
        if self.full_precision_param is not None:
            print(
                "Warning: overwriting already saved full-precision parameters"
                " in CoolChicEncoder _store_full_precision_param()."
            )

        # Check that we haven't already quantized the network by looking at
        # the nn_expgol_cnt and nn_q_step dictionaries
        no_q_step = True
        no_expgol_cnt = True
        for field_nn in fields(DescriptorCoolChic):
            for field_wb in fields(DescriptorNN):
                q_step = self.nn_q_step.get_value(field_nn.name, field_wb.name)
                expgol_cnt = self.nn_expgol_cnt.get_value(field_nn.name, field_wb.name)
                if q_step is not None:
                    no_q_step = False
                if expgol_cnt is not None:
                    no_expgol_cnt = False

        assert no_q_step and no_expgol_cnt, (
            "Trying to store full precision parameters, while CoolChicEncoder "
            "nn_q_step or nn_expgol_cnt attributes are not full of None. This means that the "
            "parameters have already been quantized... aborting!"
        )

        # All good, simply save the parameters
        self.full_precision_param = self.get_param()

    def _load_full_precision_param(self) -> None:
        """Restore the parameters saved by _store_full_precision_param()."""
        assert self.full_precision_param is not None, (
            "Trying to load full precision parameters but self.full_precision_param is None"
        )

        self.set_param(self.full_precision_param)

        # Reset the side information about the quantization step and expgol cnt
        # so that the rate is no longer computed by the test() function.
        # Default init --> None
        self.nn_q_step = DescriptorCoolChic()
        self.nn_expgol_cnt = DescriptorCoolChic()
    def get_flops(self) -> None:
        """Compute the number of MAC & parameters for the model.

        Update ``self.total_flops`` (integer describing the number of total
        MAC), ``self.flops_per_module`` and ``self.flops_str``, a pretty
        string allowing to print the model complexity somewhere.

        .. attention::

            ``fvcore`` measures MAC (multiplication & accumulation) but calls
            it FLOP (floating point operation)... We do the same here and
            call everything FLOP even though it would be more accurate to use
            MAC.
        """
        # Count the number of floating point operations here. It must be done
        # before torch scripting the different modules.
        self = self.train(mode=False)

        # The tuple mirrors forward()'s positional arguments:
        # deterministic quantization so the analysis traces the eval path.
        flops = FlopCountAnalysis(
            self,
            (
                "none",  # Quantization noise
                "hardround",  # Quantizer type
                0.3,  # Soft round temperature
                0.1,  # Noise parameter
                -1,  # AC_MAX_VAL
                False,  # Flag additional outputs
            ),
        )
        flops.unsupported_ops_warnings(False)
        flops.uncalled_modules_warnings(False)

        self.total_flops = flops.total()
        for k in self.flops_per_module:
            self.flops_per_module[k] = flops.by_module()[k]

        self.flops_str = flop_count_table(flops)
        del flops

        self = self.train(mode=True)
    def get_network_rate(self) -> Tuple[DescriptorCoolChic, int]:
        """Return the rate (in bits) associated to the parameters (weights
        and biases) of the different modules.

        Returns:
            Tuple[DescriptorCoolChic, int]: The rate (in bits) associated
            with the weights and biases of each module. Also return the
            total rate in bits.
        """
        rate_per_module = DescriptorCoolChic(
            arm=DescriptorNN(weight=0.0, bias=0.0),
            ifce=DescriptorNN(weight=0.0, bias=0.0),
            upsampling=DescriptorNN(weight=0.0, bias=0.0),
            synthesis=DescriptorNN(weight=0.0, bias=0.0),
        )

        total_rate = 0.0
        # Only the modules actually present in the decoder are measured;
        # modules_to_send excludes "ifce" when no IFCE is instantiated.
        for module_name in self.modules_to_send:
            cur_module = getattr(self, module_name)
            module_rate = measure_expgolomb_rate(
                cur_module,
                self.nn_q_step.get_value(module_name),
                self.nn_expgol_cnt.get_value(module_name),
            )
            rate_per_module.set_value(module_rate, module_name)

        total_rate = rate_per_module.sum()
        return rate_per_module, total_rate
    def get_network_quantization_step(self) -> DescriptorCoolChic:
        """Return the quantization step associated to the parameters (weights
        and biases) of the different modules. Those quantization steps can be
        ``None`` if the model has not yet been quantized.

        Returns:
            DescriptorCoolChic: The quantization step associated with the
            weights and biases of each module.
        """
        return self.nn_q_step
    def get_network_expgol_count(self) -> DescriptorCoolChic:
        """Return the Exp-Golomb count parameter associated to the parameters
        (weights and biases) of the different modules. Those exp-golomb
        params can be ``None`` if the model has not yet been quantized.

        Returns:
            DescriptorCoolChic: The Exp-Golomb count parameter associated
            with the weights and biases of each module.
        """
        return self.nn_expgol_cnt
[docs] def str_complexity(self) -> str: """Return a string describing the number of MAC (**not mac per pixel**) and the number of parameters for the different modules of CoolChic Returns: str: A pretty string about CoolChic complexity. """ if not self.flops_str: self.get_flops() msg_total_mac = "----------------------------------\n" msg_total_mac += f"Total MAC / decoded pixel: {self.get_total_mac_per_pixel():.1f}" msg_total_mac += "\n----------------------------------" return self.flops_str + "\n\n" + msg_total_mac
[docs] def get_total_mac_per_pixel(self) -> float: """Count the number of Multiplication-Accumulation (MAC) per decoded pixel for this model. Returns: float: number of floating point operations per decoded pixel. """ if not self.flops_str: self.get_flops() n_pixels = self.param.img_size[-2] * self.param.img_size[-1] return self.total_flops / n_pixels
# ------- Useful functions
[docs] def to_device(self, device: torch.device) -> None: """Push a model to a given device.""" self = self.to(device) if self.param.flag_common_randomness: for i in range(len(self.cr)): self.cr[i] = self.cr[i].to(device)
# # Push integerized weights and biases of the mlp (resp qw and qb) to # # the required device # for idx_layer, layer in enumerate(self.arm.mlp): # if hasattr(layer, "qw"): # if layer.qw is not None: # self.arm.mlp[idx_layer].qw = layer.qw.to(device) # if hasattr(layer, "qb"): # if layer.qb is not None: # self.arm.mlp[idx_layer].qb = layer.qb.to(device)
    def pretty_string(self) -> str:
        """Get a pretty string detailing the complexity of a
        ``CoolChicEncoder``.

        Returns:
            str: a pretty string ready to be printed out
        """
        s = ""

        # Lazily run the fvcore analysis if it has never been performed
        if not self.flops_str:
            self.get_flops()

        n_pixels = self.param.img_size[-2] * self.param.img_size[-1]
        total_mac_per_pix = self.get_total_mac_per_pixel()

        # Per-module MAC / pixel and share (in %) of the overall complexity
        ups_complexity = self.flops_per_module["upsampling"] / n_pixels
        ups_share_complexity = 100 * ups_complexity / total_mac_per_pix
        arm_complexity = self.flops_per_module["arm"] / n_pixels
        arm_share_complexity = 100 * arm_complexity / total_mac_per_pix
        syn_complexity = self.flops_per_module["synthesis"] / n_pixels
        syn_share_complexity = 100 * syn_complexity / total_mac_per_pix

        s = (
            f" - {'ARM':<14} {arm_complexity:5.0f} MAC / pixel; {arm_share_complexity:4.1f} % of the complexity\n"
            f" - {'Upsampling':<14} {ups_complexity:5.0f} MAC / pixel; {ups_share_complexity:4.1f} % of the complexity\n"
            f" - {'Synthesis':<14} {syn_complexity:5.0f} MAC / pixel; {syn_share_complexity:4.1f} % of the complexity\n"
        )

        # The "ifce" entry only exists when an IFCE module is instantiated
        if "ifce" in self.flops_per_module:
            ifce_complexity = self.flops_per_module["ifce"] / n_pixels
            ifce_share_complexity = 100 * ifce_complexity / total_mac_per_pix
            s += f" - {'Inter ft ARM':<14} {ifce_complexity:5.0f} MAC / pixel; {ifce_share_complexity:4.1f} % of the complexity\n"

        return s
def instantiate_latent_grids_from_cc_param(param: CoolChicEncoderParameter) -> nn.ParameterList:
    """Create the learnable hierarchical latent grids described by ``param``.

    Grids are allocated with ``torch.empty`` (uninitialised); the caller is
    expected to fill them before use.
    """
    return nn.ParameterList(
        [nn.Parameter(torch.empty(size_i), requires_grad=True) for size_i in param.size_per_latent]
    )


def instantiate_common_randomness_from_cc_param(param: CoolChicEncoderParameter) -> List[Tensor]:
    """Sample the common-randomness latent grids described by ``param``."""
    common_noise_generator = CommonGaussianNoiseGenerator()
    return [common_noise_generator.sample(size) for size in param.size_per_latent_cr]


def instantiate_arm_from_cc_param(param: CoolChicEncoderParameter) -> Arm:
    """Build the auto-regressive probability module (ARM) described by ``param``."""
    return Arm(
        param.total_context_arm,
        param.n_hidden_layers_arm,
        flag_linear_stabiliser=param.linear_stabiliser_arm,
    )


def instantiate_syn_from_cc_param(param: CoolChicEncoderParameter) -> Synthesis:
    """Build the synthesis transform described by ``param``."""
    return Synthesis(
        param.input_feature_synthesis,
        param.layers_synthesis,
        param.linear_stabiliser_synth,
        param.flag_common_randomness,
    )


def instantiate_ups_from_cc_param(param: CoolChicEncoderParameter) -> Upsampling:
    """Build the upsampling module described by ``param``."""
    # If latent_resolution = (1, 6), there are 6 upsampling steps to go from a
    # downsampling of 2**-6 to 2**0.
    n_ups = param.latent_resolution[1]
    return Upsampling(
        ups_k_size=param.ups_k_size,
        ups_preconcat_k_size=param.ups_preconcat_k_size,
        n_ups_kernel=n_ups,
        n_ups_preconcat_kernel=n_ups,
    )


def instantiate_ifce_from_cc_param(param: CoolChicEncoderParameter) -> Optional[Ifce]:
    """Build the inter-feature context extractor (IFCE), or None when disabled."""
    if param.flag_ifce:
        return Ifce(param.input_features_ifce, param.output_feature_ifce)
    return None


@torch.no_grad()
def measure_expgolomb_rate(
    q_module: nn.Module, q_step: DescriptorNN, expgol_cnt: DescriptorNN
) -> DescriptorNN:
    """Get the rate associated with the current parameters.

    Returns:
        DescriptorNN: The rate of the different modules wrapped inside
        a dictionary of float.

    It does **not** return tensor so no back propagation is possible
    """
    # Concatenate the sent parameters here to measure the entropy later
    sent_param = DescriptorNN(bias=[], weight=[])
    rate_param = DescriptorNN(bias=0.0, weight=0.0)

    param = q_module.get_param()
    # Retrieve all the sent item
    for parameter_name, parameter_value in param.items():
        # Fix: classify the parameter name *before* dereferencing the
        # quantization step. Previously an unrecognised name (neither
        # ".weight" nor ".bias") left current_q_step unbound on the first
        # iteration (NameError) or stale from the previous iteration.
        if ".weight" in parameter_name:
            current_q_step = q_step.weight
            destination = sent_param.weight
        elif ".bias" in parameter_name:
            current_q_step = q_step.bias
            destination = sent_param.bias
        else:
            print(f'Parameter name should include ".weight" or ".bias" Found: {parameter_name}')
            continue

        # Current quantization step is None because the module is not yet
        # quantized. Return an all zero rate
        if current_q_step is None:
            return rate_param

        # Quantization is round(parameter_value / q_step) * q_step so we divide by q_step
        # to obtain the sent latent.
        destination.append((parameter_value / current_q_step).view(-1))

    # For each sent parameters (e.g. all biases and all weights)
    # compute their cost with an exp-golomb coding.
    for field_wb in fields(DescriptorNN):
        weight_or_bias = field_wb.name
        param = getattr(sent_param, weight_or_bias)

        # If we do not have any parameter, there is no rate associated.
        # This can happens for the upsampling biases for instance
        if len(param) == 0:
            setattr(rate_param, weight_or_bias, 0.0)
            continue

        # Current exp-golomb count is None because the module is not yet
        # quantized. Return an all zero rate
        current_expgol_cnt = getattr(expgol_cnt, weight_or_bias)
        if current_expgol_cnt is None:
            return rate_param

        # Concatenate the list of parameters as a big one dimensional tensor
        param = torch.cat(param)

        # This will be pretty long! Could it be vectorized?
        # ! Todo: replace that with the actual encode_exp_golomb code?
        setattr(rate_param, weight_or_bias, exp_golomb_nbins(param, count=current_expgol_cnt))

    return rate_param


@torch.no_grad()
def exp_golomb_nbins(symbol: Tensor, count: int = 0) -> Tensor:
    """Compute the number of bits required to encode a Tensor of integers
    using an exponential-golomb code with exponent ``count``.

    This estimates the rate of an actual exp-golomb code with less than
    0.5% mismatch.

    Args:
        symbol: Tensor to encode
        count (int, optional): Exponent of the exp-golomb code. Defaults to 0.

    Returns:
        Number of bits required to encode all the symbols.
    """
    # We encode the sign equiprobably at the end thus one more bit if symbol != 0
    nbins = 2 * torch.floor(torch.log2(2 * symbol.abs() / (2**count) + 1)) + count + (symbol != 0)
    return nbins.sum()