Source code for enc.component.coolchic

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md

import math
import typing
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, TypedDict

from enc.visu.console import pretty_string_nn, pretty_string_ups
from torch import nn, Tensor

import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

from enc.component.core.arm import (
    Arm,
    _get_neighbor,
    _get_non_zero_pixel_ctx_index,
    _laplace_cdf,
)
from enc.component.core.quantizer import (
    POSSIBLE_QUANTIZATION_NOISE_TYPE,
    POSSIBLE_QUANTIZER_TYPE,
    quantize,
)
from enc.component.core.synthesis import Synthesis
from enc.component.core.upsampling import Upsampling
from enc.utils.misc import (
    MAX_ARM_MASK_SIZE,
    POSSIBLE_DEVICE,
    DescriptorCoolChic,
    DescriptorNN,
    measure_expgolomb_rate,
)


"""A cool-chic encoder is composed of:
    - A set of 2d hierarchical latent grids
    - An auto-regressive probability module
    - An upsampling module
    - A synthesis.

At its core, it is a tool to compress any spatially organized signal, with
one or more features by representing it as a set of 2d entropy coding-friendly
latent grids. After upsampling, these latent grids allows to synthesize the
desired signal.
"""

"""Dataclass to store the parameters of CoolChicEncoder for one frame."""


[docs] @dataclass class CoolChicEncoderParameter: """Dataclass storing the parameters of a ``CoolChicEncoder``. Args: img_size (Tuple[int, int]): Height and width :math:`(H, W)` of the frame to be coded layers_synthesis (List[str]): Describes the architecture of the synthesis transform. See the :doc:`synthesis documentation <core/synthesis>` for more information. n_ft_per_res (List[int]): Number of latent features for each latent resolution *i.e.* ``n_ft_per_res[i]`` gives the number of channel :math:`C_i` of the latent with resolution :math:`\\frac{H}{2^i}, \\frac{W}{2^i}`. dim_arm (int, Optional): Number of context pixels for the ARM. Also corresponds to the ARM hidden layer width. See the :doc:`ARM documentation <core/arm>` for more information. Defaults to 24 n_hidden_layers_arm (int, Optional): Number of hidden layers in the ARM. Set ``n_hidden_layers_arm = 0`` for a linear ARM. Defaults to 2. ups_k_size (int, Optional): Upsampling kernel size for the transposed convolutions. See the :doc:`upsampling documentation <core/upsampling>` for more information. Defaults to 8. ups_preconcat_k_size (int, Optional): Upsampling kernel size for the pre-concatenation convolutions. See the :doc:`upsampling documentation <core/upsampling>` for more information. Defaults to 7. encoder_gain (int, Optional): Multiply the latent by this value before quantization. See the documentation of Cool-chic forward pass. Defaults to 16. """ layers_synthesis: List[str] n_ft_per_res: List[int] dim_arm: int = 24 n_hidden_layers_arm: int = 2 encoder_gain: int = 16 ups_k_size: int = 8 ups_preconcat_k_size: int = 7 # ==================== Not set by the init function ===================== # #: Automatically computed, number of different latent resolutions latent_n_grids: int = field(init=False) #: Height and width :math:`(H, W)` of the frame to be coded. Must be #: set using the ``set_image_size()`` function. img_size: Optional[Tuple[int, int]] = field(init=False, default=None) # ==================== Not set by the init function ===================== # def __post_init__(self): self.latent_n_grids = len(self.n_ft_per_res)
[docs] def set_image_size(self, img_size: Tuple[int, int]) -> None: """Register the field self.img_size. Args: img_size: Height and width :math:`(H, W)` of the frame to be coded """ self.img_size = img_size
[docs] def pretty_string(self) -> str: """Return a pretty string formatting the data within the class""" ATTRIBUTE_WIDTH = 25 VALUE_WIDTH = 80 s = "CoolChicEncoderParameter value:\n" s += "-------------------------------\n" for k in fields(self): s += f"{k.name:<{ATTRIBUTE_WIDTH}}: {str(getattr(self, k.name)):<{VALUE_WIDTH}}\n" s += "\n" return s
[docs] class CoolChicEncoderOutput(TypedDict): """``TypedDict`` representing the output of CoolChicEncoder forward. Args: raw_out (Tensor): Output of the synthesis :math:`([B, C, H, W])`. rate (Tensor): rate associated to each latent (in bits). Shape is :math:`(N)`, with :math:`N` the total number of latent variables. additional_data (Dict[str, Any]): Any other data required to compute some logs, stored inside a dictionary """ raw_out: Tensor rate: Tensor additional_data: Dict[str, Any]
[docs] class CoolChicEncoder(nn.Module): """CoolChicEncoder for a single frame."""
[docs] def __init__(self, param: CoolChicEncoderParameter): """Instantiate a cool-chic encoder for one frame. Args: param (CoolChicEncoderParameter): Architecture of the `CoolChicEncoder`. See the documentation of `CoolChicEncoderParameter` for more information """ super().__init__() # Everything is stored inside param self.param = param assert self.param.img_size is not None, ( "You are trying to instantiate a CoolChicEncoder from a " "CoolChicEncoderParameter with a field img_size set to None. Use " "the function coolchic_encoder_param.set_img_size((H, W)) before " "instantiating the CoolChicEncoder." ) # ================== Synthesis related stuff ================= # # Encoder-side latent gain applied prior to quantization, one per feature self.encoder_gains = param.encoder_gain # Populate the successive grids self.size_per_latent = [] self.latent_grids = nn.ParameterList() for i in range(self.param.latent_n_grids): h_grid, w_grid = [int(math.ceil(x / (2**i))) for x in self.param.img_size] c_grid = self.param.n_ft_per_res[i] cur_size = (1, c_grid, h_grid, w_grid) self.size_per_latent.append(cur_size) # Instantiate empty tensor, we fill them later on with the function # self.initialize_latent_grids() self.latent_grids.append( nn.Parameter(torch.empty(cur_size), requires_grad=True) ) self.initialize_latent_grids() # Instantiate the synthesis MLP with as many inputs as the number # of latent channels self.synthesis = Synthesis( sum([latent_size[1] for latent_size in self.size_per_latent]), self.param.layers_synthesis, ) # ================== Synthesis related stuff ================= # # ===================== Upsampling stuff ===================== # self.upsampling = Upsampling( ups_k_size=self.param.ups_k_size, ups_preconcat_k_size=self.param.ups_preconcat_k_size, # Instantiate one different upsampling and pre-concatenation # filters for each of the upsampling step. Could also be set to one # to share the same filter across all latents. n_ups_kernel=self.param.latent_n_grids - 1, n_ups_preconcat_kernel=self.param.latent_n_grids - 1, ) # ===================== Upsampling stuff ===================== # # ===================== ARM related stuff ==================== # # Create the probability model for the main INR. It uses a spatial context # parameterized by the spatial context # For a given mask size N (odd number e.g. 3, 5, 7), we have at most # (N * N - 1) / 2 context pixels in it. # Example, a 9x9 mask as below has 40 context pixel (indicated with 1s) # available to predict the pixel '*' # 1 1 1 1 1 1 1 1 1 # 1 1 1 1 1 1 1 1 1 # 1 1 1 1 1 1 1 1 1 # 1 1 1 1 1 1 1 1 1 # 1 1 1 1 * 0 0 0 0 # 0 0 0 0 0 0 0 0 0 # 0 0 0 0 0 0 0 0 0 # 0 0 0 0 0 0 0 0 0 # 0 0 0 0 0 0 0 0 0 # No more than 40 context pixels i.e. a 9x9 mask size (see example above) max_mask_size = MAX_ARM_MASK_SIZE max_context_pixel = int((max_mask_size**2 - 1) / 2) assert self.param.dim_arm <= max_context_pixel, ( f"You can not have more context pixels " f" than {max_context_pixel}. Found {self.param.dim_arm}" ) # Mask of size 2N + 1 when we have N rows & columns of context. self.mask_size = max_mask_size # 1D tensor containing the indices of the selected context pixels. # register_buffer for automatic device management. We set persistent to false # to simply use the "automatically move to device" function, without # considering non_zero_pixel_ctx_index as a parameters (i.e. returned # by self.parameters()) self.register_buffer( "non_zero_pixel_ctx_index", _get_non_zero_pixel_ctx_index(self.param.dim_arm), persistent=False, ) self.arm = Arm(self.param.dim_arm, self.param.n_hidden_layers_arm) # ===================== ARM related stuff ==================== # # Something like ['arm', 'synthesis', 'upsampling'] self.modules_to_send = [tmp.name for tmp in fields(DescriptorCoolChic)] # ======================== Monitoring ======================== # # Pretty string representing the decoder complexity self.flops_str = "" # Total number of multiplications to decode the image self.total_flops = 0.0 self.flops_per_module = {k: 0 for k in self.modules_to_send} # Fill the two attributes aboves self.get_flops() # ======================== Monitoring ======================== # # Track the quantization step of each neural network, None if the # module is not yet quantized self.nn_q_step: Dict[str, DescriptorNN] = { k: {"weight": None, "bias": None} for k in self.modules_to_send } # Track the exponent of the exp-golomb code used for the NN parameters. # None if module is not yet quantized self.nn_expgol_cnt: Dict[str, DescriptorNN] = { k: {"weight": None, "bias": None} for k in self.modules_to_send } # Copy of the full precision parameters, set just before calling the # quantize_model() function. This is done through the # self._store_full_precision_param() function self.full_precision_param = None
# ------- Actual forward
[docs] def forward( self, quantizer_noise_type: POSSIBLE_QUANTIZATION_NOISE_TYPE = "kumaraswamy", quantizer_type: POSSIBLE_QUANTIZER_TYPE = "softround", soft_round_temperature: Optional[float] = 0.3, noise_parameter: Optional[float] = 1.0, AC_MAX_VAL: int = -1, flag_additional_outputs: bool = False, ) -> CoolChicEncoderOutput: """Perform CoolChicEncoder forward pass, to be used during the training. The main step are as follows: 1. **Scale & quantize the encoder-side latent** :math:`\\mathbf{y}` to get the decoder-side latent .. math:: \\hat{\\mathbf{y}} = \\mathrm{Q}(\\Gamma_{enc}\\ \\mathbf{y}), with :math:`\\Gamma_{enc} \\in \\mathbb{R}` a scalar encoder gain defined in ``self.param.encoder_gains`` and :math:`\\mathrm{Q}` the :doc:`quantization operation <core/quantizer>`. 2. **Measure the rate** of the decoder-side latent with the :doc:`ARM <core/arm>`: .. math:: \\mathrm{R}(\\hat{\\mathbf{y}}) = -\\log_2 p_{\\psi}(\\hat{\\mathbf{y}}), where :math:`p_{\\psi}` is given by the :doc:`Auto-Regressive Module (ARM) <core/arm>`. 3. **Upsample and synthesize** the latent to get the output .. math:: \\hat{\\mathbf{x}} = f_{\\theta}(f_{\\upsilon}(\\hat{\\mathbf{y}})), with :math:`f_{\\psi}` the :doc:`Upsampling <core/upsampling>` and :math:`f_{\\theta}` the :doc:`Synthesis <core/synthesis>`. Args: quantizer_noise_type: Defaults to ``"kumaraswamy"``. quantizer_type: Defaults to ``"softround"``. soft_round_temperature: Soft round temperature. This is used for softround modes as well as the ste mode to simulate the derivative in the backward. Defaults to 0.3. noise_parameter: noise distribution parameter. Defaults to 1.0. AC_MAX_VAL: If different from -1, clamp the value to be in :math:`[-AC\\_MAX\\_VAL; AC\\_MAX\\_VAL + 1]` to write the actual bitstream. Defaults to -1. flag_additional_outputs: True to fill ``CoolChicEncoderOutput['additional_data']`` with many different quantities which can be used to analyze Cool-chic behavior. Defaults to False. Returns: Output of Cool-chic training forward pass. """ # ! Order of the operations are important as these are asynchronous # ! CUDA operations. Some ordering are faster than other... # ------ Encoder-side: quantize the latent # Convert the N [1, C, H_i, W_i] 4d latents with different resolutions # to a single flat vector. This allows to call the quantization # only once, which is faster encoder_side_flat_latent = torch.cat( [latent_i.view(-1) for latent_i in self.latent_grids] ) flat_decoder_side_latent = quantize( encoder_side_flat_latent * self.encoder_gains, quantizer_noise_type if self.training else "none", quantizer_type if self.training else "hardround", soft_round_temperature, noise_parameter, ) # Clamp latent if we need to write a bitstream if AC_MAX_VAL != -1: flat_decoder_side_latent = torch.clamp( flat_decoder_side_latent, -AC_MAX_VAL, AC_MAX_VAL + 1 ) # Convert back the 1d tensor to a list of N [1, C, H_i, W_i] 4d latents. # This require a few additional information about each individual # latent dimension, stored in self.size_per_latent decoder_side_latent = [] cnt = 0 for latent_size in self.size_per_latent: b, c, h, w = latent_size # b should be one latent_numel = b * c * h * w decoder_side_latent.append( flat_decoder_side_latent[cnt : cnt + latent_numel].view(latent_size) ) cnt += latent_numel # ----- ARM to estimate the distribution and the rate of each latent # As for the quantization, we flatten all the latent and their context # so that the ARM network is only called once. # flat_latent: [N, 1] tensor describing N latents # flat_context: [N, context_size] tensor describing each latent context # Get all the context as a single 2D vector of size [B, context size] flat_context = torch.cat( [ _get_neighbor(spatial_latent_i, self.mask_size, self.non_zero_pixel_ctx_index) for spatial_latent_i in decoder_side_latent ], dim=0, ) # Get all the B latent variables as a single one dimensional vector flat_latent = torch.cat( [spatial_latent_i.view(-1) for spatial_latent_i in decoder_side_latent], dim=0 ) # Feed the spatial context to the arm MLP and get mu and scale flat_mu, flat_scale, flat_log_scale = self.arm(flat_context) # Compute the rate (i.e. the entropy of flat latent knowing mu and scale) proba = torch.clamp_min( _laplace_cdf(flat_latent + 0.5, flat_mu, flat_scale) - _laplace_cdf(flat_latent - 0.5, flat_mu, flat_scale), min=2**-16, # No value can cost more than 16 bits. ) flat_rate = -torch.log2(proba) # Upsampling and synthesis to get the output synthesis_output = self.synthesis(self.upsampling(decoder_side_latent)) additional_data = {} if flag_additional_outputs: # Prepare list to accommodate the visualisations additional_data["detailed_sent_latent"] = [] additional_data["detailed_mu"] = [] additional_data["detailed_scale"] = [] additional_data["detailed_log_scale"] = [] additional_data["detailed_rate_bit"] = [] additional_data["detailed_centered_latent"] = [] additional_data["hpfilters"] = [] # "Pointer" for the reading of the 1D scale, mu and rate cnt = 0 # for i, _ in enumerate(filtered_latent): for index_latent_res, _ in enumerate(self.latent_grids): c_i, h_i, w_i = decoder_side_latent[index_latent_res].size()[-3:] additional_data["detailed_sent_latent"].append( decoder_side_latent[index_latent_res].view((1, c_i, h_i, w_i)) ) # Scale, mu and rate are 1D tensors where the N latent grids # are flattened together. As such we have to read the appropriate # number of values in this 1D vector to reconstruct the i-th grid in 2D mu_i, scale_i, log_scale_i, rate_i = [ # Read h_i * w_i values starting from cnt tmp[cnt : cnt + (c_i * h_i * w_i)].view((1, c_i, h_i, w_i)) for tmp in [flat_mu, flat_scale, flat_log_scale, flat_rate] ] cnt += c_i * h_i * w_i additional_data["detailed_mu"].append(mu_i) additional_data["detailed_scale"].append(scale_i) additional_data["detailed_log_scale"].append(log_scale_i) additional_data["detailed_rate_bit"].append(rate_i) additional_data["detailed_centered_latent"].append( additional_data["detailed_sent_latent"][-1] - mu_i ) res: CoolChicEncoderOutput = { "raw_out": synthesis_output, "rate": flat_rate, "additional_data": additional_data, } return res
# ------- Getter / Setter and Initializer
[docs] def get_param(self) -> OrderedDict[str, Tensor]: """Return **a copy** of the weights and biases inside the module. Returns: OrderedDict[str, Tensor]: A copy of all weights & biases in the module. """ param = OrderedDict({}) param.update( { # Detach & clone to create a copy f"latent_grids.{k}": v.detach().clone() for k, v in self.latent_grids.named_parameters() } ) param.update({f"arm.{k}": v for k, v in self.arm.get_param().items()}) param.update( {f"upsampling.{k}": v for k, v in self.upsampling.get_param().items()} ) param.update( {f"synthesis.{k}": v for k, v in self.synthesis.get_param().items()} ) return param
[docs] def set_param(self, param: OrderedDict[str, Tensor]): """Replace the current parameters of the module with param. Args: param (OrderedDict[str, Tensor]): Parameters to be set. """ self.load_state_dict(param)
[docs] def initialize_latent_grids(self) -> None: """Initialize the latent grids. The different tensors composing the latent grids must have already been created e.g. through ``torch.empty()``. """ for latent_index, latent_value in enumerate(self.latent_grids): self.latent_grids[latent_index] = nn.Parameter( torch.zeros_like(latent_value), requires_grad=True )
[docs] def reinitialize_parameters(self): """Reinitialize in place the different parameters of a CoolChicEncoder namely the latent grids, the arm, the upsampling and the weights. """ self.arm.reinitialize_parameters() self.upsampling.reinitialize_parameters() self.synthesis.reinitialize_parameters() self.initialize_latent_grids() # Reset the quantization steps and exp-golomb count of the neural # network to None since we are resetting the parameters. self.nn_q_step: Dict[str, DescriptorNN] = { k: {"weight": None, "bias": None} for k in self.modules_to_send } self.nn_expgol_cnt: Dict[str, DescriptorNN] = { k: {"weight": None, "bias": None} for k in self.modules_to_send }
def _store_full_precision_param(self) -> None: """Store the current parameters inside self.full_precision_param This function checks that there is no self.nn_q_step and self.nn_expgol_cnt already saved. This would mean that we no longer have full precision parameters but quantized ones. """ if self.full_precision_param is not None: print( "Warning: overwriting already saved full-precision parameters" " in CoolChicEncoder _store_full_precision_param()." ) # Check that we haven't already quantized the network by looking at # the nn_expgol_cnt and nn_q_step dictionaries no_q_step = True for _, q_step_dict in self.nn_q_step.items(): for _, q_step in q_step_dict.items(): if q_step is not None: no_q_step = False assert no_q_step, ( "Trying to store full precision parameters, while CoolChicEncoder " "nn_q_step attributes is not full of None. This means that the " "parameters have already been quantized... aborting!" ) no_expgol_cnt = True for _, expgol_cnt_dict in self.nn_expgol_cnt.items(): for _, expgol_cnt in expgol_cnt_dict.items(): if expgol_cnt is not None: no_expgol_cnt = False assert no_expgol_cnt, ( "Trying to store full precision parameters, while CoolChicEncoder " "nn_expgol_cnt attributes is not full of None. This means that the " "parameters have already been quantized... aborting!" ) # All good, simply save the parameters self.full_precision_param = self.get_param() def _load_full_precision_param(self) -> None: assert self.full_precision_param is not None, ( "Trying to load full precision parameters but " "self.full_precision_param is None" ) self.set_param(self.full_precision_param) # Reset the side information about the quantization step and expgol cnt # so that the rate is no longer computed by the test() function. self.nn_q_step: Dict[str, DescriptorNN] = { k: {"weight": None, "bias": None} for k in self.modules_to_send } self.nn_expgol_cnt: Dict[str, DescriptorNN] = { k: {"weight": None, "bias": None} for k in self.modules_to_send } # ------- Get flops, neural network rates and quantization step
[docs] def get_flops(self) -> None: """Compute the number of MAC & parameters for the model. Update ``self.total_flops`` (integer describing the number of total MAC) and ``self.flops_str``, a pretty string allowing to print the model complexity somewhere. .. attention:: ``fvcore`` measures MAC (multiplication & accumulation) but calls it FLOP (floating point operation)... We do the same here and call everything FLOP even though it would be more accurate to use MAC. """ # print("Ignoring get_flops") # Count the number of floating point operations here. It must be done before # torch scripting the different modules. self = self.train(mode=False) flops = FlopCountAnalysis( self, ( "none", # Quantization noise "hardround", # Quantizer type 0.3, # Soft round temperature 0.1, # Noise parameter -1, # AC_MAX_VAL False, # Flag additional outputs ), ) flops.unsupported_ops_warnings(False) flops.uncalled_modules_warnings(False) self.total_flops = flops.total() for k in self.flops_per_module: self.flops_per_module[k] = flops.by_module()[k] self.flops_str = flop_count_table(flops) del flops self = self.train(mode=True)
[docs] def get_network_rate(self) -> DescriptorCoolChic: """Return the rate (in bits) associated to the parameters (weights and biases) of the different modules Returns: DescriptorCoolChic: The rate (in bits) associated with the weights and biases of each module """ rate_per_module: DescriptorCoolChic = { module_name: {"weight": 0.0, "bias": 0.0} for module_name in self.modules_to_send } for module_name in self.modules_to_send: cur_module = getattr(self, module_name) rate_per_module[module_name] = measure_expgolomb_rate( cur_module, self.nn_q_step.get(module_name), self.nn_expgol_cnt.get(module_name), ) return rate_per_module
[docs] def get_network_quantization_step(self) -> DescriptorCoolChic: """Return the quantization step associated to the parameters (weights and biases) of the different modules. Those quantization can be ``None`` if the model has not yet been quantized. Returns: DescriptorCoolChic: The quantization step associated with the weights and biases of each module. """ return self.nn_q_step
[docs] def get_network_expgol_count(self) -> DescriptorCoolChic: """Return the Exp-Golomb count parameter associated to the parameters (weights and biases) of the different modules. Those quantization can be ``None`` if the model has not yet been quantized. Returns: DescriptorCoolChic: The Exp-Golomb count parameter associated with the weights and biases of each module. """ return self.nn_expgol_cnt
[docs] def str_complexity(self) -> str: """Return a string describing the number of MAC (**not mac per pixel**) and the number of parameters for the different modules of CoolChic Returns: str: A pretty string about CoolChic complexity. """ if not self.flops_str: self.get_flops() msg_total_mac = "----------------------------------\n" msg_total_mac += ( f"Total MAC / decoded pixel: {self.get_total_mac_per_pixel():.1f}" ) msg_total_mac += "\n----------------------------------" return self.flops_str + "\n\n" + msg_total_mac
[docs] def get_total_mac_per_pixel(self) -> float: """Count the number of Multiplication-Accumulation (MAC) per decoded pixel for this model. Returns: float: number of floating point operations per decoded pixel. """ if not self.flops_str: self.get_flops() n_pixels = self.param.img_size[-2] * self.param.img_size[-1] return self.total_flops / n_pixels
# ------- Useful functions
[docs] def to_device(self, device: POSSIBLE_DEVICE) -> None: """Push a model to a given device. Args: device (POSSIBLE_DEVICE): The device on which the model should run. """ assert device in typing.get_args( POSSIBLE_DEVICE ), f"Unknown device {device}, should be in {typing.get_args(POSSIBLE_DEVICE)}" self = self.to(device) # Push integerized weights and biases of the mlp (resp qw and qb) to # the required device for idx_layer, layer in enumerate(self.arm.mlp): if hasattr(layer, "qw"): if layer.qw is not None: self.arm.mlp[idx_layer].qw = layer.qw.to(device) if hasattr(layer, "qb"): if layer.qb is not None: self.arm.mlp[idx_layer].qb = layer.qb.to(device)
[docs] def pretty_string(self) -> str: """Get a pretty string representing the layer of a ``CoolChicEncoder``""" s = "" if not self.flops_str: self.get_flops() n_pixels = self.param.img_size[-2] * self.param.img_size[-1] total_mac_per_pix = self.get_total_mac_per_pixel() title = f"Cool-chic architecture {total_mac_per_pix:.0f} MAC / pixel" s += ( f"\n{title}\n" f"{'-' * len(title)}\n\n" ) complexity = self.flops_per_module['upsampling'] / n_pixels share_complexity = 100 * complexity / total_mac_per_pix title = f"Upsampling {complexity:.0f} MAC/pixel ; {share_complexity:.1f} % of the complexity" s += ( f"{title}\n" f"{'=' * len(title)}\n" "Note: all upsampling layers are separable and symmetric " "(transposed) convolutions.\n\n" ) s += pretty_string_ups(self.upsampling, "") complexity = self.flops_per_module['arm'] / n_pixels share_complexity = 100 * complexity / total_mac_per_pix title = f"ARM {complexity:.0f} MAC/pixel ; {share_complexity:.1f} % of the complexity" s += ( f"\n\n\n{title}\n" f"{'=' * len(title)}\n\n\n" ) input_arm = f"{self.arm.dim_arm}-pixel context" output_arm = "mu, log scale" s += pretty_string_nn( self.arm.mlp, "", input_arm, output_arm ) complexity = self.flops_per_module['synthesis'] / n_pixels share_complexity = 100 * complexity / total_mac_per_pix title = f"Synthesis {complexity:.0f} MAC/pixel ; {share_complexity:.1f} % of the complexity" s += ( f"\n\n\n{title}\n" f"{'=' * len(title)}\n\n\n" ) input_syn = f"{self.synthesis.input_ft} features" output_syn = "Decoded image" s += pretty_string_nn( self.synthesis.layers, "", input_syn, output_syn ) return s