Source code for enc.nnquant.expgolomb

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD-3-Clause
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md


from enc.component.types import DescriptorNN
from enc.nnquant.quantstep import get_q_step_from_parameter_name
import torch
from torch import Tensor, nn


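# Candidate exp-Golomb exponents (0 to 12, inclusive) for the weights and
# biases of each Cool-Chic module. See measure_expgolomb_rate() below for how
# a given exponent translates into a rate in bits.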
POSSIBLE_EXP_GOL_COUNT = {
    "arm": {
        "weight": torch.linspace(0, 12, 13, device="cpu"),
        "bias": torch.linspace(0, 12, 13, device="cpu"),
    },
    "upsampling": {
        "weight": torch.linspace(0, 12, 13, device="cpu"),
        "bias": torch.linspace(0, 12, 13, device="cpu"),
    },
    "synthesis": {
        "weight": torch.linspace(0, 12, 13, device="cpu"),
        "bias": torch.linspace(0, 12, 13, device="cpu"),
    },
}


@torch.no_grad()
def measure_expgolomb_rate(
    q_module: nn.Module, q_step: DescriptorNN, expgol_cnt: DescriptorNN
) -> DescriptorNN:
    """Get the rate associated with the current parameters.

    Args:
        q_module: The module whose rate is measured.
        q_step: The quantization step of the weights and biases.
        expgol_cnt: The exponent of the exp-golomb code used for the weights
            and biases.

    Returns:
        DescriptorNN: The rate of the different modules wrapped inside a
        dictionary of float. It does **not** return tensors, so no back
        propagation is possible.
    """
    # Concatenate the sent parameters here to measure the entropy later
    sent_param: DescriptorNN = {"bias": [], "weight": []}
    rate_param: DescriptorNN = {"bias": 0.0, "weight": 0.0}

    param = q_module.get_param()
    # Retrieve all the sent items
    for parameter_name, parameter_value in param.items():
        current_q_step = get_q_step_from_parameter_name(parameter_name, q_step)

        # The current quantization step is None because the module is not yet
        # quantized. Return an all-zero rate.
        if current_q_step is None:
            return rate_param

        # Quantization is round(parameter_value / q_step) * q_step, so we
        # divide by q_step to obtain the sent symbols.
        current_sent_param = (parameter_value / current_q_step).view(-1)

        if ".weight" in parameter_name:
            sent_param["weight"].append(current_sent_param)
        elif ".bias" in parameter_name:
            sent_param["bias"].append(current_sent_param)
        else:
            print(
                'Parameter name should include ".weight" or ".bias". '
                f"Found: {parameter_name}"
            )
            return rate_param

    # For each group of sent parameters (i.e. all biases and all weights),
    # compute its cost with an exp-golomb code.
    for k, v in sent_param.items():
        # If we do not have any parameter, there is no rate associated.
        # This can happen for the upsampling biases, for instance.
        if len(v) == 0:
            rate_param[k] = 0.0
            continue

        # The current exp-golomb count is None because the module is not yet
        # quantized. Return an all-zero rate.
        current_expgol_cnt = expgol_cnt[k]
        if current_expgol_cnt is None:
            return rate_param

        # Concatenate the list of parameters into one big one-dimensional tensor
        v = torch.cat(v)

        rate_param[k] = exp_golomb_nbins(v, count=current_expgol_cnt)

    return rate_param
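

# A minimal usage sketch (not from the original source): assuming `q_module`
# is a quantized Cool-Chic component exposing get_param(), and assuming the
# quantization steps and exp-golomb exponents below were already selected
# elsewhere by the encoder, the rate in bits of its weights and biases could
# be measured as follows:
#
#     rate = measure_expgolomb_rate(
#         q_module,
#         q_step={"weight": 2**-4, "bias": 2**-8},
#         expgol_cnt={"weight": 2, "bias": 0},
#     )
#     total_bits = rate["weight"] + rate["bias"]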


def exp_golomb_nbins(symbol: Tensor, count: int = 0) -> Tensor:
    """Compute the number of bits required to encode a Tensor of integers
    using an exponential-golomb code with exponent ``count``.

    Args:
        symbol: Tensor to encode.
        count (int, optional): Exponent of the exp-golomb code. Defaults to 0.

    Returns:
        Number of bits required to encode all the symbols.
    """
    # We encode the sign equiprobably at the end, thus one more bit if
    # symbol != 0.
    nbins = (
        2 * torch.floor(torch.log2(symbol.abs() / (2**count) + 1))
        + count
        + 1
        + (symbol != 0)
    )
    return nbins.sum()
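

# A worked example (illustrative, not part of the original source). With
# count=0, a symbol s costs 2 * floor(log2(|s| + 1)) + 1 bits, plus one sign
# bit when s != 0:
#
#     >>> exp_golomb_nbins(torch.tensor([0.0]))  # floor(log2(1)) = 0
#     tensor(1.)                                 # 1 bit, no sign bit
#     >>> exp_golomb_nbins(torch.tensor([3.0]))  # floor(log2(4)) = 2
#     tensor(6.)                                 # 2*2 + 1 + 1 sign bit
#     >>> exp_golomb_nbins(torch.tensor([0.0, 3.0, -3.0]))
#     tensor(13.)                                # 1 + 6 + 6 bits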