# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
# SPDX-License-Identifier: BSD-3-Clause
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md
import math
import typing
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, TypedDict
import torch
from torch import Tensor, nn
from fvcore.nn import FlopCountAnalysis, flop_count_table
from enc.component.core.arm import (
    Arm,
    _get_neighbor,
    _get_non_zero_pixel_ctx_index,
    _laplace_cdf,
)
from enc.component.core.quantizer import (
    POSSIBLE_QUANTIZATION_NOISE_TYPE,
    POSSIBLE_QUANTIZER_TYPE,
    quantize,
)
from enc.component.core.synthesis import Synthesis
from enc.component.core.upsampling import Upsampling
from enc.utils.misc import (
    MAX_ARM_MASK_SIZE,
    POSSIBLE_DEVICE,
    DescriptorCoolChic,
    DescriptorNN,
    measure_expgolomb_rate,
)
"""A cool-chic encoder is composed of:
- A set of 2d hierarchical latent grids
- An auto-regressive probability module
- An upsampling module
- A synthesis.
At its core, it is a tool to compress any spatially organized signal, with
one or more features by representing it as a set of 2d entropy coding-friendly
latent grids. After upsampling, these latent grids allows to synthesize the
desired signal.
"""
"""Dataclass to store the parameters of CoolChicEncoder for one frame."""
[docs]
@dataclass
class CoolChicEncoderParameter:
    """Dataclass storing the parameters of a ``CoolChicEncoder``.

    Args:
        img_size (Tuple[int, int]): Height and width :math:`(H, W)` of the
            frame to be coded. Not an ``__init__`` argument; set it afterwards
            through the ``set_image_size()`` function.
        layers_synthesis (List[str]): Describes the architecture of the
            synthesis transform. See the :doc:`synthesis documentation
            <core/synthesis>` for more information.
        n_ft_per_res (List[int]): Number of latent features for each latent
            resolution *i.e.* ``n_ft_per_res[i]`` gives the number of channels
            :math:`C_i` of the latent with resolution :math:`\\frac{H}{2^i},
            \\frac{W}{2^i}`.
        dim_arm (int, Optional): Number of context pixels for the ARM. Also
            corresponds to the ARM hidden layer width. See the :doc:`ARM
            documentation <core/arm>` for more information. Defaults to 24.
        n_hidden_layers_arm (int, Optional): Number of hidden layers in the
            ARM. Set ``n_hidden_layers_arm = 0`` for a linear ARM. Defaults
            to 2.
        upsampling_kernel_size (int, Optional): Kernel size for the upsampler.
            See the :doc:`upsampling documentation <core/upsampling>` for more
            information. Defaults to 8.
        static_upsampling_kernel (bool, Optional): Set this flag to ``True`` to
            prevent learning the upsampling kernel. Defaults to ``False``.
        encoder_gain (int, Optional): Multiply the latent by this value before
            quantization. See the documentation of the Cool-chic forward pass.
            Defaults to 16.
    """

    layers_synthesis: List[str]
    n_ft_per_res: List[int]
    dim_arm: int = 24
    n_hidden_layers_arm: int = 2
    upsampling_kernel_size: int = 8
    static_upsampling_kernel: bool = False
    encoder_gain: int = 16

    # ==================== Not set by the init function ===================== #
    #: Automatically computed: number of different latent resolutions.
    latent_n_grids: int = field(init=False)
    #: Height and width :math:`(H, W)` of the frame to be coded. Must be
    #: set using the ``set_image_size()`` function.
    img_size: Optional[Tuple[int, int]] = field(init=False, default=None)
    # ==================== Not set by the init function ===================== #

    def __post_init__(self):
        self.latent_n_grids = len(self.n_ft_per_res)
    def set_image_size(self, img_size: Tuple[int, int]) -> None:
        """Register the field ``self.img_size``.

        Args:
            img_size: Height and width :math:`(H, W)` of the frame to be coded.
        """
        self.img_size = img_size
    def pretty_string(self) -> str:
        """Return a pretty string formatting the data within the class."""
        ATTRIBUTE_WIDTH = 25
        VALUE_WIDTH = 80

        s = "CoolChicEncoderParameter value:\n"
        s += "-------------------------------\n"
        for k in fields(self):
            s += f"{k.name:<{ATTRIBUTE_WIDTH}}: {str(getattr(self, k.name)):<{VALUE_WIDTH}}\n"
        s += "\n"
        return s
class CoolChicEncoderOutput(TypedDict):
    """``TypedDict`` representing the output of a ``CoolChicEncoder`` forward.

    Args:
        raw_out (Tensor): Output of the synthesis :math:`([B, C, H, W])`.
        rate (Tensor): Rate associated to each latent (in bits). Shape is
            :math:`(N)`, with :math:`N` the total number of latent variables.
        additional_data (Dict[str, Any]): Any other data required to compute
            some logs, stored inside a dictionary.
    """

    raw_out: Tensor
    rate: Tensor
    additional_data: Dict[str, Any]
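
# Example of consuming the forward output (illustrative sketch, assuming an
# already instantiated and trained `encoder`):
#
#     out: CoolChicEncoderOutput = encoder()
#     reconstruction = out["raw_out"]      # [B, C, H, W] synthesized frame
#     total_rate_bits = out["rate"].sum()  # rate summed over all latents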
class CoolChicEncoder(nn.Module):
    """CoolChicEncoder for a single frame."""

    def __init__(self, param: CoolChicEncoderParameter):
        """Instantiate a Cool-chic encoder for one frame.

        Args:
            param (CoolChicEncoderParameter): Architecture of the
                ``CoolChicEncoder``. See the documentation of
                ``CoolChicEncoderParameter`` for more information.
        """
        super().__init__()

        # Everything is stored inside param
        self.param = param
        assert self.param.img_size is not None, (
            "You are trying to instantiate a CoolChicEncoder from a "
            "CoolChicEncoderParameter with a field img_size set to None. Use "
            "the function coolchic_encoder_param.set_image_size((H, W)) before "
            "instantiating the CoolChicEncoder."
        )

        # ================== Synthesis related stuff ================= #
        # Encoder-side latent gain applied prior to quantization. A single
        # scalar shared by all the latent features.
        self.encoder_gains = param.encoder_gain

        # Populate the successive grids
        self.size_per_latent = []
        self.latent_grids = nn.ParameterList()
        for i in range(self.param.latent_n_grids):
            h_grid, w_grid = [int(math.ceil(x / (2**i))) for x in self.param.img_size]
            c_grid = self.param.n_ft_per_res[i]
            cur_size = (1, c_grid, h_grid, w_grid)
            self.size_per_latent.append(cur_size)

            # Instantiate empty tensors; we fill them later on with the
            # function self.initialize_latent_grids()
            self.latent_grids.append(
                nn.Parameter(torch.empty(cur_size), requires_grad=True)
            )
        self.initialize_latent_grids()

        # Instantiate the synthesis MLP with as many inputs as the number
        # of latent channels
        self.synthesis = Synthesis(
            sum([latent_size[1] for latent_size in self.size_per_latent]),
            self.param.layers_synthesis,
        )
        # ================== Synthesis related stuff ================= #

        # ===================== Upsampling stuff ===================== #
        self.upsampling = Upsampling(
            self.param.upsampling_kernel_size, self.param.static_upsampling_kernel
        )
        # ===================== Upsampling stuff ===================== #

        # ===================== ARM related stuff ==================== #
        # Create the probability model for the main INR. It relies on a causal
        # spatial context around each latent pixel.
        # For a given mask size N (odd number e.g. 3, 5, 7), we have at most
        # (N * N - 1) / 2 context pixels in it.
        # Example: the 9x9 mask below has 40 context pixels (indicated with 1s)
        # available to predict the pixel '*'.
        #       1 1 1 1 1 1 1 1 1
        #       1 1 1 1 1 1 1 1 1
        #       1 1 1 1 1 1 1 1 1
        #       1 1 1 1 1 1 1 1 1
        #       1 1 1 1 * 0 0 0 0
        #       0 0 0 0 0 0 0 0 0
        #       0 0 0 0 0 0 0 0 0
        #       0 0 0 0 0 0 0 0 0
        #       0 0 0 0 0 0 0 0 0
        # No more than 40 context pixels i.e. a 9x9 mask size (see example above)
        max_mask_size = MAX_ARM_MASK_SIZE
        max_context_pixel = int((max_mask_size**2 - 1) / 2)
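        # Worked example of the bound above (illustrative): if MAX_ARM_MASK_SIZE
        # were 9, the causal context would hold at most (9**2 - 1) / 2 = 40
        # pixels, so dim_arm would have to be <= 40.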
        assert self.param.dim_arm <= max_context_pixel, (
            f"You cannot have more context pixels "
            f"than {max_context_pixel}. Found {self.param.dim_arm}"
        )

        # Mask of size 2N + 1 when we have N rows & columns of context.
        self.mask_size = max_mask_size

        # 1D tensor containing the indices of the selected context pixels.
        # register_buffer for automatic device management. We set persistent
        # to False to simply use the "automatically move to device" behavior,
        # without considering non_zero_pixel_ctx_index as a parameter (i.e.
        # returned by self.parameters())
        self.register_buffer(
            "non_zero_pixel_ctx_index",
            _get_non_zero_pixel_ctx_index(self.param.dim_arm),
            persistent=False,
        )

        self.arm = Arm(self.param.dim_arm, self.param.n_hidden_layers_arm)
        # ===================== ARM related stuff ==================== #
        # ======================== Monitoring ======================== #
        # Pretty string representing the decoder complexity
        self.flops_str = ""
        # Total number of multiplications to decode the image
        self.total_flops = 0.0
        # Fill the two attributes above
        self.get_flops()
        # ======================== Monitoring ======================== #

        # Something like ['arm', 'synthesis', 'upsampling']
        self.modules_to_send = [tmp.name for tmp in fields(DescriptorCoolChic)]

        # Track the quantization step of each neural network, None if the
        # module is not yet quantized
        self.nn_q_step: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }

        # Track the exponent of the exp-golomb code used for the NN parameters.
        # None if the module is not yet quantized
        self.nn_expgol_cnt: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }

        # Copy of the full precision parameters, set just before calling the
        # quantize_model() function. This is done through the
        # self._store_full_precision_param() function
        self.full_precision_param = None
    # ------- Actual forward
    def forward(
        self,
        quantizer_noise_type: POSSIBLE_QUANTIZATION_NOISE_TYPE = "kumaraswamy",
        quantizer_type: POSSIBLE_QUANTIZER_TYPE = "softround",
        soft_round_temperature: Optional[float] = 0.3,
        noise_parameter: Optional[float] = 1.0,
        AC_MAX_VAL: int = -1,
        flag_additional_outputs: bool = False,
    ) -> CoolChicEncoderOutput:
"""Perform CoolChicEncoder forward pass, to be used during the training.
The main step are as follows:
1. **Scale & quantize the encoder-side latent** :math:`\\mathbf{y}` to
get the decoder-side latent
.. math::
\\hat{\\mathbf{y}} = \\mathrm{Q}(\\Gamma_{enc}\\ \\mathbf{y}),
with :math:`\\Gamma_{enc} \\in \\mathbb{R}` a scalar encoder gain
defined in ``self.param.encoder_gains`` and :math:`\\mathrm{Q}`
the :doc:`quantization operation <core/quantizer>`.
2. **Measure the rate** of the decoder-side latent with the
:doc:`ARM <core/arm>`:
.. math::
\\mathrm{R}(\\hat{\\mathbf{y}}) = -\\log_2 p_{\\psi}(\\hat{\\mathbf{y}}),
where :math:`p_{\\psi}`
is given by the :doc:`Auto-Regressive Module (ARM) <core/arm>`.
3. **Upsample and synthesize** the latent to get the output
.. math::
\\hat{\\mathbf{x}} = f_{\\theta}(f_{\\upsilon}(\\hat{\\mathbf{y}})),
with :math:`f_{\\psi}` the :doc:`Upsampling <core/upsampling>`
and :math:`f_{\\theta}` the :doc:`Synthesis <core/synthesis>`.
Args:
quantizer_noise_type: Defaults to ``"kumaraswamy"``.
quantizer_type: Defaults to ``"softround"``.
soft_round_temperature: Soft round temperature.
This is used for softround modes as well as the
ste mode to simulate the derivative in the backward.
Defaults to 0.3.
noise_parameter: noise distribution parameter. Defaults to 1.0.
AC_MAX_VAL: If different from -1, clamp the value to be in
:math:`[-AC\\_MAX\\_VAL; AC\\_MAX\\_VAL + 1]` to write the actual bitstream.
Defaults to -1.
flag_additional_outputs: True to fill
``CoolChicEncoderOutput['additional_data']`` with many different
quantities which can be used to analyze Cool-chic behavior.
Defaults to False.
Returns:
Output of Cool-chic training forward pass.
"""
        # ! The order of the operations is important as these are asynchronous
        # ! CUDA operations. Some orderings are faster than others...

        # ------ Encoder-side: quantize the latent
        # Convert the N [1, C, H_i, W_i] 4D latents with different resolutions
        # to a single flat vector. This allows calling the quantization
        # only once, which is faster.
        encoder_side_flat_latent = torch.cat(
            [latent_i.view(-1) for latent_i in self.latent_grids]
        )
        flat_decoder_side_latent = quantize(
            encoder_side_flat_latent * self.encoder_gains,
            quantizer_noise_type if self.training else "none",
            quantizer_type if self.training else "hardround",
            soft_round_temperature,
            noise_parameter,
        )
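        # Note: during training the quantization is relaxed (noise + soft
        # rounding) so that gradients can flow through Q. In eval mode the
        # latent is actually rounded ("none" noise + "hardround"), matching
        # what the decoder will read from the bitstream.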
        # Clamp latent if we need to write a bitstream
        if AC_MAX_VAL != -1:
            flat_decoder_side_latent = torch.clamp(
                flat_decoder_side_latent, -AC_MAX_VAL, AC_MAX_VAL + 1
            )

        # Convert the 1D tensor back to a list of N [1, C, H_i, W_i] 4D
        # latents. This requires a few additional pieces of information about
        # each individual latent's dimensions, stored in self.size_per_latent.
        decoder_side_latent = []
        cnt = 0
        for latent_size in self.size_per_latent:
            b, c, h, w = latent_size  # b should be one
            latent_numel = b * c * h * w
            decoder_side_latent.append(
                flat_decoder_side_latent[cnt : cnt + latent_numel].view(latent_size)
            )
            cnt += latent_numel
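        # Illustrative shape walk-through (hypothetical sizes): with two grids
        # of sizes (1, 1, 4, 4) and (1, 1, 2, 2), the flat vector holds
        # 16 + 4 = 20 values; the first 16 are reshaped to (1, 1, 4, 4) and
        # the last 4 to (1, 1, 2, 2).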
        # ----- ARM to estimate the distribution and the rate of each latent
        # As for the quantization, we flatten all the latents and their
        # contexts so that the ARM network is only called once.
        #   flat_latent:  [N] tensor describing the N latents
        #   flat_context: [N, context_size] tensor describing each latent context

        # Get all the contexts as a single 2D tensor of size [N, context_size]
        flat_context = torch.cat(
            [
                _get_neighbor(spatial_latent_i, self.mask_size, self.non_zero_pixel_ctx_index)
                for spatial_latent_i in decoder_side_latent
            ],
            dim=0,
        )

        # Get all the N latent variables as a single one-dimensional tensor
        flat_latent = torch.cat(
            [spatial_latent_i.view(-1) for spatial_latent_i in decoder_side_latent],
            dim=0,
        )

        # Feed the spatial context to the ARM MLP and get mu and scale
        flat_mu, flat_scale, flat_log_scale = self.arm(flat_context)

        # Compute the rate (i.e. the entropy of the flat latent knowing mu and scale)
        proba = torch.clamp_min(
            _laplace_cdf(flat_latent + 0.5, flat_mu, flat_scale)
            - _laplace_cdf(flat_latent - 0.5, flat_mu, flat_scale),
            min=2**-16,  # No value can cost more than 16 bits.
        )
        flat_rate = -torch.log2(proba)
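        # Worked example (illustrative): a latent whose integer bin carries a
        # probability p = 0.5 costs -log2(0.5) = 1 bit. The clamp at 2**-16
        # caps the cost of any single latent at 16 bits.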
        # Upsampling and synthesis to get the output
        synthesis_output = self.synthesis(self.upsampling(decoder_side_latent))

        additional_data = {}
        if flag_additional_outputs:
            # Prepare lists to accommodate the visualisations
            additional_data["detailed_sent_latent"] = []
            additional_data["detailed_mu"] = []
            additional_data["detailed_scale"] = []
            additional_data["detailed_log_scale"] = []
            additional_data["detailed_rate_bit"] = []
            additional_data["detailed_centered_latent"] = []

            # "Pointer" for the reading of the 1D scale, mu and rate
            cnt = 0
            for index_latent_res, _ in enumerate(self.latent_grids):
                c_i, h_i, w_i = decoder_side_latent[index_latent_res].size()[-3:]
                additional_data["detailed_sent_latent"].append(
                    decoder_side_latent[index_latent_res].view((1, c_i, h_i, w_i))
                )

                # Scale, mu and rate are 1D tensors where the N latent grids
                # are flattened together. As such we have to read the
                # appropriate number of values in this 1D vector to
                # reconstruct the i-th grid in 2D.
                mu_i, scale_i, log_scale_i, rate_i = [
                    # Read c_i * h_i * w_i values starting from cnt
                    tmp[cnt : cnt + (c_i * h_i * w_i)].view((1, c_i, h_i, w_i))
                    for tmp in [flat_mu, flat_scale, flat_log_scale, flat_rate]
                ]
                cnt += c_i * h_i * w_i

                additional_data["detailed_mu"].append(mu_i)
                additional_data["detailed_scale"].append(scale_i)
                additional_data["detailed_log_scale"].append(log_scale_i)
                additional_data["detailed_rate_bit"].append(rate_i)
                additional_data["detailed_centered_latent"].append(
                    additional_data["detailed_sent_latent"][-1] - mu_i
                )

        res: CoolChicEncoderOutput = {
            "raw_out": synthesis_output,
            "rate": flat_rate,
            "additional_data": additional_data,
        }
        return res
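
    # A hedged training-loop sketch (illustrative; `target_frame` and `lmbda`
    # are hypothetical names for the frame to code and the rate weight):
    #
    #     out = encoder()
    #     dist = torch.nn.functional.mse_loss(out["raw_out"], target_frame)
    #     rate_bits = out["rate"].sum()
    #     loss = dist + lmbda * rate_bits  # classic rate-distortion objective
    #     loss.backward()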
    # ------- Getter / Setter and Initializer
    def get_param(self) -> OrderedDict[str, Tensor]:
        """Return **a copy** of the weights and biases inside the module.

        Returns:
            OrderedDict[str, Tensor]: A copy of all weights & biases in the module.
        """
        param = OrderedDict({})
        param.update(
            {
                # Detach & clone to create a copy
                f"latent_grids.{k}": v.detach().clone()
                for k, v in self.latent_grids.named_parameters()
            }
        )
        param.update({f"arm.{k}": v for k, v in self.arm.get_param().items()})
        param.update(
            {f"upsampling.{k}": v for k, v in self.upsampling.get_param().items()}
        )
        param.update(
            {f"synthesis.{k}": v for k, v in self.synthesis.get_param().items()}
        )
        return param
    def set_param(self, param: OrderedDict[str, Tensor]):
        """Replace the current parameters of the module with param.

        Args:
            param (OrderedDict[str, Tensor]): Parameters to be set.
        """
        self.load_state_dict(param)
    def initialize_latent_grids(self) -> None:
        """Initialize the latent grids with zeros. The different tensors
        composing the latent grids must have already been created e.g.
        through ``torch.empty()``.
        """
        for latent_index, latent_value in enumerate(self.latent_grids):
            self.latent_grids[latent_index] = nn.Parameter(
                torch.zeros_like(latent_value), requires_grad=True
            )
    def reinitialize_parameters(self):
        """Reinitialize in place the different parameters of a CoolChicEncoder,
        namely the latent grids, the ARM, the upsampling and the synthesis.
        """
        self.arm.reinitialize_parameters()
        self.upsampling.reinitialize_parameters()
        self.synthesis.reinitialize_parameters()
        self.initialize_latent_grids()

        # Reset the quantization steps and exp-golomb counts of the neural
        # networks to None since we are resetting the parameters.
        self.nn_q_step: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }
        self.nn_expgol_cnt: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }
    def _store_full_precision_param(self) -> None:
        """Store the current parameters inside self.full_precision_param.

        This function checks that self.nn_q_step and self.nn_expgol_cnt do not
        already contain saved values. If they did, it would mean that we no
        longer have full precision parameters but quantized ones.
        """
        if self.full_precision_param is not None:
            print(
                "Warning: overwriting already saved full-precision parameters"
                " in CoolChicEncoder _store_full_precision_param()."
            )

        # Check that we haven't already quantized the network by looking at
        # the nn_expgol_cnt and nn_q_step dictionaries
        no_q_step = True
        for _, q_step_dict in self.nn_q_step.items():
            for _, q_step in q_step_dict.items():
                if q_step is not None:
                    no_q_step = False
        assert no_q_step, (
            "Trying to store full precision parameters, while the "
            "CoolChicEncoder nn_q_step attribute is not full of None. This "
            "means that the parameters have already been quantized... aborting!"
        )

        no_expgol_cnt = True
        for _, expgol_cnt_dict in self.nn_expgol_cnt.items():
            for _, expgol_cnt in expgol_cnt_dict.items():
                if expgol_cnt is not None:
                    no_expgol_cnt = False
        assert no_expgol_cnt, (
            "Trying to store full precision parameters, while the "
            "CoolChicEncoder nn_expgol_cnt attribute is not full of None. This "
            "means that the parameters have already been quantized... aborting!"
        )

        # All good, simply save the parameters
        self.full_precision_param = self.get_param()
    def _load_full_precision_param(self) -> None:
        """Restore the full precision parameters previously saved through
        self._store_full_precision_param().
        """
        assert self.full_precision_param is not None, (
            "Trying to load full precision parameters but "
            "self.full_precision_param is None"
        )
        self.set_param(self.full_precision_param)

        # Reset the side information about the quantization step and expgol
        # count so that the rate is no longer computed by the test() function.
        self.nn_q_step: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }
        self.nn_expgol_cnt: Dict[str, DescriptorNN] = {
            k: {"weight": None, "bias": None} for k in self.modules_to_send
        }
    # ------- Get flops, neural network rates and quantization step
    def get_flops(self) -> None:
        """Compute the number of MAC & parameters for the model.

        Update ``self.total_flops`` (integer describing the total number of
        MAC) and ``self.flops_str``, a pretty string for printing the model
        complexity somewhere.

        .. attention::

            ``fvcore`` measures MAC (multiplication & accumulation) but calls
            it FLOP (floating point operation)... We do the same here and call
            everything FLOP even though it would be more accurate to use MAC.
        """
        # Count the number of floating point operations here. It must be done
        # before torch scripting the different modules.
        flops = FlopCountAnalysis(
            self,
            (
                "none",  # Quantization noise
                "hardround",  # Quantizer type
                0.3,  # Soft round temperature
                0.1,  # Noise parameter
                -1,  # AC_MAX_VAL
                False,  # Flag additional outputs
            ),
        )
        flops.unsupported_ops_warnings(False)
        flops.uncalled_modules_warnings(False)

        self.total_flops = flops.total()
        self.flops_str = flop_count_table(flops)
        del flops
    def get_network_rate(self) -> DescriptorCoolChic:
        """Return the rate (in bits) associated to the parameters
        (weights and biases) of the different modules.

        Returns:
            DescriptorCoolChic: The rate (in bits) associated with the weights
            and biases of each module.
        """
        rate_per_module: DescriptorCoolChic = {
            module_name: {"weight": 0.0, "bias": 0.0}
            for module_name in self.modules_to_send
        }

        for module_name in self.modules_to_send:
            cur_module = getattr(self, module_name)
            rate_per_module[module_name] = measure_expgolomb_rate(
                cur_module,
                self.nn_q_step.get(module_name),
                self.nn_expgol_cnt.get(module_name),
            )

        return rate_per_module
    def get_network_quantization_step(self) -> DescriptorCoolChic:
        """Return the quantization step associated to the parameters (weights
        and biases) of the different modules. Those quantization steps can be
        ``None`` if the model has not yet been quantized.

        Returns:
            DescriptorCoolChic: The quantization step associated with the
            weights and biases of each module.
        """
        return self.nn_q_step
    def get_network_expgol_count(self) -> DescriptorCoolChic:
        """Return the Exp-Golomb count parameter associated to the parameters
        (weights and biases) of the different modules. Those counts can be
        ``None`` if the model has not yet been quantized.

        Returns:
            DescriptorCoolChic: The Exp-Golomb count parameter associated
            with the weights and biases of each module.
        """
        return self.nn_expgol_cnt
    def str_complexity(self) -> str:
        """Return a string describing the number of MAC (**not MAC per
        pixel**) and the number of parameters for the different modules of
        Cool-chic.

        Returns:
            str: A pretty string about Cool-chic complexity.
        """
        if not self.flops_str:
            self.get_flops()

        msg_total_mac = "----------------------------------\n"
        msg_total_mac += (
            f"Total MAC / decoded pixel: {self.get_total_mac_per_pixel():.1f}"
        )
        msg_total_mac += "\n----------------------------------"

        return self.flops_str + "\n\n" + msg_total_mac
    def get_total_mac_per_pixel(self) -> float:
        """Count the number of Multiplication-Accumulation (MAC) per decoded
        pixel for this model.

        Returns:
            float: Number of MAC per decoded pixel.
        """
        if not self.flops_str:
            self.get_flops()

        n_pixels = self.param.img_size[-2] * self.param.img_size[-1]
        return self.total_flops / n_pixels
    # ------- Useful functions
    def to_device(self, device: POSSIBLE_DEVICE) -> None:
        """Push a model to a given device.

        Args:
            device (POSSIBLE_DEVICE): The device on which the model should run.
        """
        assert device in typing.get_args(
            POSSIBLE_DEVICE
        ), f"Unknown device {device}, should be in {typing.get_args(POSSIBLE_DEVICE)}"

        self = self.to(device)

        # Push the integerized weights and biases of the ARM MLP (resp. qw
        # and qb) to the required device
        for idx_layer, layer in enumerate(self.arm.mlp):
            if hasattr(layer, "qw") and layer.qw is not None:
                self.arm.mlp[idx_layer].qw = layer.qw.to(device)
            if hasattr(layer, "qb") and layer.qb is not None:
                self.arm.mlp[idx_layer].qb = layer.qb.to(device)