# Source code for enc.training.quantizemodel

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md

import itertools
import time
from typing import Optional, OrderedDict

import torch
from enc.utils.misc import exp_golomb_nbins
from enc.training.loss import loss_function
from enc.utils.manager import FrameEncoderManager
from enc.component.frame import FrameEncoder
from enc.utils.codingstructure import Frame
from enc.utils.misc import (
    MAX_AC_MAX_VAL,
    POSSIBLE_EXP_GOL_COUNT,
    POSSIBLE_Q_STEP,
    DescriptorNN,
    get_q_step_from_parameter_name,
)
from torch import Tensor


def _quantize_parameters(
    fp_param: OrderedDict[str, Tensor],
    q_step: DescriptorNN,
) -> Optional[OrderedDict[str, Tensor]]:
    """Quantize a dictionary of parameters ``fp_param`` with the given
    quantization steps (e.g. one for the biases and one for the weights).

    Each parameter ``v`` is quantized as ``round(v / q) * q``, where ``q`` is
    the step returned by ``get_q_step_from_parameter_name`` for the parameter
    name. Return None if quantization fails, i.e. if any entry of
    ``round(v / q)`` has an absolute value greater than ``MAX_AC_MAX_VAL``.

    Args:
        fp_param (OrderedDict[str, Tensor]): Full precision parameters, usually
            the output of ``get_param()`` (any name -> tensor mapping providing
            ``.items()``, e.g. ``dict(module.named_parameters())``).
        q_step (DescriptorNN): A dictionary with one quantization step for the
            weights and one for the biases.

    Returns:
        Optional[OrderedDict[str, Tensor]]: The quantized parameters, in the
            same order as ``fp_param``, or None if quantization failed.
    """
    q_param = OrderedDict()
    for k, v in fp_param.items():
        current_q_step = get_q_step_from_parameter_name(k, q_step)
        # Integer values actually sent to the decoder.
        sent_param = torch.round(v / current_q_step)

        # The sent values must fit within MAX_AC_MAX_VAL. Zero-element tensors
        # are skipped because max() on an empty tensor raises.
        if sent_param.numel() > 0 and sent_param.abs().max() > MAX_AC_MAX_VAL:
            print(
                f"Sent param {k} exceed MAX_AC_MAX_VAL! Q step {current_q_step} too small."
            )
            return None

        q_param[k] = sent_param * current_q_step

    return q_param

def _find_best_expgol_cnt(
    param: OrderedDict[str, Tensor],
    weight_or_bias: str,
    q_step,
    candidate_expgol_cnt,
) -> int:
    """Return the exp-golomb count giving the lowest rate for the sent values
    of every parameter whose name ends with ``weight_or_bias``.

    Args:
        param: Quantized parameters of a module, i.e. values equal to
            ``round(fp_value / q_step) * q_step``.
        weight_or_bias: Suffix selecting the parameters to consider, either
            ``"weight"`` or ``"bias"``.
        q_step: Quantization step used for the selected parameters.
        candidate_expgol_cnt: Iterable of exp-golomb counts to try.

    Returns:
        int: The candidate count minimizing ``exp_golomb_nbins``. On ties the
            first candidate is kept.

    Raises:
        ValueError: If no count could be selected (no candidate was given, or
            no candidate produced a rate comparable to +inf, e.g. NaN rates).
    """
    # Quantization is round(v / q_step) * q_step, so dividing by q_step gives
    # back the integer values actually sent. The explicit round removes any
    # floating-point noise introduced by the division.
    sent_param = torch.cat(
        [
            torch.round(parameter_value / q_step).view(-1)
            for parameter_name, parameter_value in param.items()
            if parameter_name.endswith(weight_or_bias)
        ]
    )

    best_cnt = None
    best_rate = float("inf")
    for expgol_cnt in candidate_expgol_cnt:
        cur_rate = exp_golomb_nbins(sent_param, count=expgol_cnt)
        if cur_rate < best_rate:
            best_rate = cur_rate
            best_cnt = expgol_cnt

    if best_cnt is None:
        raise ValueError(
            f"Could not select an exp-golomb count for the {weight_or_bias} "
            f"parameters among {candidate_expgol_cnt}."
        )
    return int(best_cnt)


@torch.no_grad()
def quantize_model(
    frame_encoder: FrameEncoder,
    frame: Frame,
    frame_encoder_manager: FrameEncoderManager,
) -> FrameEncoder:
    """Quantize a ``FrameEncoder`` compressing a ``Frame`` under a rate
    constraint ``lmbda`` and return it.

    This function iterates on all the neural networks sent from the encoder
    to the decoder, listed in ``frame_encoder.coolchic_encoder.modules_to_send``.
    For each module :math:`m`, we want to find the most suited pair of
    quantization steps for the weights and the biases
    :math:`(\\Delta_w^m, \\Delta_b^m)`. To do so, a greedy search is used:
    we quantize the weights and biases using all the possible pairs of
    quantization steps, and we compute the :doc:`usual loss function <./loss>`.
    The loss measures the impact of the NN quantization steps
    :math:`(\\Delta_w^m, \\Delta_b^m)` on the MSE / rate of the decoded image
    and on the rate of the NN. For each candidate pair, the exp-golomb count
    minimizing the rate of the sent weights (resp. biases) is selected and
    stored in ``nn_expgol_cnt`` before the network rate is measured.

    In the end, we select the pair of quantization steps minimizing the loss:

    .. math::

        (\\Delta_w^m, \\Delta_b^m) = \\arg\\min ||\\mathbf{x} - \\hat{\\mathbf{x}}||^2
        + \\lambda (\\mathrm{R}(\\hat{\\mathbf{x}}) + \\mathrm{R}_{NN}),
        \\text{ with }
        \\begin{cases}
            \\mathbf{x} & \\text{the original image}\\\\
            \\hat{\\mathbf{x}} & \\text{the coded image}\\\\
            \\mathrm{R}(\\hat{\\mathbf{x}}) & \\text{A measure of the rate of } \\hat{\\mathbf{x}} \\\\
            \\mathrm{R}_{NN} & \\text{The rate of the neural networks}
        \\end{cases}

    Then we quantize the next module to be sent.

    .. warning::

        The parameter ``frame_encoder_manager`` is modified **in place**:
        the time spent in this function is added to its
        ``total_training_time_sec``. ``frame_encoder`` is modified in place
        as well (eval mode, quantized parameters, ``nn_q_step`` and
        ``nn_expgol_cnt``).

    Args:
        frame_encoder: Model to be compressed.
        frame: Original frame to code, including its references.
        frame_encoder_manager: Contains (among other things) the rate
            constraint :math:`\\lambda`. It is also used to track the total
            encoding time. Modified in place.

    Returns:
        Model with quantized parameters.

    Raises:
        RuntimeError: If, for a module, no quantization step pair can be used
            (every candidate overflows ``MAX_AC_MAX_VAL`` or yields a NaN /
            infinite loss).
        ValueError: If no exp-golomb count can be selected for a module.
    """
    start_time = time.time()
    frame_encoder.set_to_eval()

    # We have to quantize all the modules that we want to send.
    module_to_quantize = {
        module_name: getattr(frame_encoder.coolchic_encoder, module_name)
        for module_name in frame_encoder.coolchic_encoder.modules_to_send
    }

    for module_name, cur_module in sorted(module_to_quantize.items()):
        # Start the RD search with an infinite cost so that the first valid
        # candidate is always retained, whatever the magnitude of the loss.
        best_loss = float("inf")
        best_q_step = None
        # Overall best exp-golomb count for this module weights and biases.
        final_best_expgol_cnt = {}

        # All possible quantization steps / exp-golomb counts for this module.
        all_q_step = POSSIBLE_Q_STEP.get(module_name)
        all_expgol_cnt = POSSIBLE_EXP_GOL_COUNT.get(module_name)

        # Save full precision parameters: every candidate is quantized from
        # them, not from the previously tested quantized values.
        fp_param = cur_module.get_param()

        for q_step_w, q_step_b in itertools.product(
            all_q_step.get("weight"), all_q_step.get("bias")
        ):
            current_q_step: DescriptorNN = {"weight": q_step_w, "bias": q_step_b}
            q_param = _quantize_parameters(fp_param, current_q_step)

            # Quantization has failed (a sent value exceeds MAX_AC_MAX_VAL).
            if q_param is None:
                continue

            cur_module.set_param(q_param)

            # Plug the quantized module back into Cool-chic.
            setattr(frame_encoder.coolchic_encoder, module_name, cur_module)
            frame_encoder.coolchic_encoder.nn_q_step[module_name] = current_q_step

            # Test Cool-chic performance with this quantization steps pair.
            frame_encoder_out = frame_encoder.forward(
                reference_frames=[ref_i.data for ref_i in frame.refs_data],
                quantizer_noise_type="none",
                quantizer_type="hardround",
                AC_MAX_VAL=-1,
                flag_additional_outputs=False,
            )

            # Best exp-golomb count for this quantization step pair. It must
            # be stored before get_network_rate() is called below.
            param = cur_module.get_param()
            best_expgol_cnt = {
                weight_or_bias: _find_best_expgol_cnt(
                    param,
                    weight_or_bias,
                    current_q_step.get(weight_or_bias),
                    all_expgol_cnt.get(weight_or_bias),
                )
                for weight_or_bias in ["weight", "bias"]
            }
            frame_encoder.coolchic_encoder.nn_expgol_cnt[module_name] = best_expgol_cnt

            # Rate of all the sent networks: sum over modules and over their
            # weight / bias entries.
            rate_mlp = 0.0
            rate_per_module = frame_encoder.coolchic_encoder.get_network_rate()
            for module_rate in rate_per_module.values():
                for param_rate in module_rate.values():  # weight, bias
                    rate_mlp += param_rate

            loss_fn_output = loss_function(
                frame_encoder_out.decoded_image,
                frame_encoder_out.rate,
                frame.data.data,
                lmbda=frame_encoder_manager.lmbda,
                rate_mlp_bit=rate_mlp,
                compute_logs=True,
            )

            # Store best quantization steps (strict < keeps the first minimum).
            if loss_fn_output.loss < best_loss:
                best_loss = loss_fn_output.loss
                best_q_step = current_q_step
                final_best_expgol_cnt = best_expgol_cnt

        if best_q_step is None:
            raise RuntimeError(
                f"quantize_model(): no valid quantization step pair for module "
                f"{module_name}. Every candidate either exceeded MAX_AC_MAX_VAL "
                "or gave a NaN / infinite loss."
            )

        # Once we've tested all the possible quantization steps and expgol_cnt,
        # quantize one last time with the best ones found to actually use them.
        frame_encoder.coolchic_encoder.nn_q_step[module_name] = best_q_step
        frame_encoder.coolchic_encoder.nn_expgol_cnt[module_name] = final_best_expgol_cnt

        q_param = _quantize_parameters(fp_param, best_q_step)
        assert q_param is not None, (
            "_quantize_parameters() failed with q_step "
            f"{best_q_step}"
        )

        cur_module.set_param(q_param)
        # Plug the quantized module back into Cool-chic.
        setattr(frame_encoder.coolchic_encoder, module_name, cur_module)

    time_nn_quantization = time.time() - start_time

    print(f"\nTime quantize_model(): {time_nn_quantization:4.1f} seconds\n")
    frame_encoder_manager.total_training_time_sec += time_nn_quantization

    return frame_encoder