# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md
"""Gather the different encoding presets here."""
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Tuple
import typing
from enc.component.core.quantizer import (
POSSIBLE_QUANTIZATION_NOISE_TYPE,
POSSIBLE_QUANTIZER_TYPE,
)
# Names accepted as elements of ``TrainerPhase.optimized_module``, i.e. the
# parts of the model that a training phase is allowed to optimize.
MODULE_TO_OPTIMIZE = Literal["all", "arm", "upsampling", "synthesis", "latent"]
@dataclass
class TrainerPhase:
    """Dataclass representing one phase of an encoding preset.

    Args:
        lr (float): Initial learning rate of the phase. Can vary if
            ``schedule_lr`` is True. Defaults to 0.01.
        max_itr (int): Maximum number of iterations for the phase. The actual
            number of iterations can be made smaller through the patience
            mechanism. Defaults to 5000.
        freq_valid (int): Check (and print) the performance each
            ``freq_valid`` iterations. This drives the patience mechanism.
            Defaults to 100.
        patience (int): After ``patience`` iterations without any
            improvement to the results, exit the training. Patience is
            disabled by setting ``patience = max_itr``. If patience is used
            alongside the cosine learning rate scheduling (``schedule_lr``),
            then it does not end the training. Instead, we simply reload the
            best model so far once we reach the patience, and the training
            continues. Defaults to 10000.
        quantize_model (bool): If ``True``, quantize the neural networks
            parameters at the end of the training phase. Defaults to
            ``False``.
        schedule_lr (bool): If ``True``, the learning rate is no longer
            constant. Instead, it varies with a cosine scheduling, as
            suggested in `C3: High-performance and low-complexity neural
            compression from a single image or video, Kim et al.
            <https://arxiv.org/abs/2312.02753>`_. Defaults to False.
        softround_temperature (Tuple[float, float]): Start, end temperature
            of the :doc:`softround function <../component/core/quantizer>`.
            It is used in the forward / backward if ``quantizer_type`` is set
            to ``"softround"`` or ``"softround_alone"``. It is also used in
            the backward pass if ``quantizer_type`` is set to ``"ste"``.
            The softround temperature is linearly scheduled during the
            training. At iteration n° 0 it is equal to
            ``softround_temperature[0]`` while at iteration n° ``max_itr`` it
            is equal to ``softround_temperature[1]``. Note that the patience
            might interrupt the training before it reaches this last value.
            Defaults to (0.3, 0.3).
        noise_parameter (Tuple[float, float]): The random noise temperature
            is linearly scheduled during the training. At iteration n° 0 it
            is equal to ``noise_parameter[0]`` while at iteration n°
            ``max_itr`` it is equal to ``noise_parameter[1]``. Note that the
            patience might interrupt the training before it reaches this last
            value. Defaults to (1.0, 1.0).
        quantizer_noise_type (POSSIBLE_QUANTIZATION_NOISE_TYPE): The random
            noise used by the quantizer. More information available in
            :doc:`encoder/component/core/quantizer.py <../component/core/quantizer>`.
            Defaults to ``"kumaraswamy"``.
        quantizer_type (POSSIBLE_QUANTIZER_TYPE): What quantizer to use
            during training. See
            :doc:`encoder/component/core/quantizer.py <../component/core/quantizer>`
            for more information. Defaults to ``"softround"``.
        optimized_module (List[MODULE_TO_OPTIMIZE]): List of modules to be
            optimized. Most often you'd want to use
            ``optimized_module = ['all']``. If ``'all'`` is listed alongside
            other modules, the list is reduced to ``['all']``.
            Defaults to ``['all']``.
    """

    # Width of each column of the table printed by ``pretty_string``,
    # ``_pretty_string_column_name`` and ``_vertical_line_array``. It has no
    # annotation on purpose, so that the dataclass does not see it as a field.
    _COLUMN_WIDTHS = (14, 9, 16, 13, 13, 13, 18, 14)

    lr: float = 1e-2
    max_itr: int = 5000
    freq_valid: int = 100
    patience: int = 10000
    quantize_model: bool = False
    schedule_lr: bool = False
    softround_temperature: Tuple[float, float] = (0.3, 0.3)
    noise_parameter: Tuple[float, float] = (1.0, 1.0)
    quantizer_noise_type: POSSIBLE_QUANTIZATION_NOISE_TYPE = "kumaraswamy"
    quantizer_type: POSSIBLE_QUANTIZER_TYPE = "softround"
    optimized_module: List[MODULE_TO_OPTIMIZE] = field(default_factory=lambda: ["all"])

    def __post_init__(self):
        """Normalize ``optimized_module`` and validate the quantizer options.

        Raises:
            AssertionError: If ``quantizer_noise_type`` or ``quantizer_type``
                is not an allowed value, or if they are incompatible with
                each other.
        """
        # If all is present in the list of modules to be optimized, alongside
        # something else, it overrides everything, leaving the list of modules
        # to be optimized to just ['all'].
        if "all" in self.optimized_module:
            self.optimized_module = ["all"]

        # Some checks about quantization options mismatch. They are done here
        # to avoid doing it each time we do a forward pass inside the quantize
        # function. Additionally, torch.compile messes up the assertion in the
        # quantize function anyway.
        assert self.quantizer_noise_type in typing.get_args(POSSIBLE_QUANTIZATION_NOISE_TYPE), (
            f"quantizer_noise_type must be in {POSSIBLE_QUANTIZATION_NOISE_TYPE}"
            f" found {self.quantizer_noise_type}"
        )
        assert self.quantizer_type in typing.get_args(POSSIBLE_QUANTIZER_TYPE), (
            f"quantizer_type must be in {POSSIBLE_QUANTIZER_TYPE}"
            f" found {self.quantizer_type}"
        )

        # The quantizer types listed below do not need any noise addition.
        # Otherwise, we need a type of noise, i.e. either kumaraswamy or
        # gaussian noise.
        if self.quantizer_type in ["softround_alone", "hardround", "ste", "none"]:
            assert self.quantizer_noise_type == "none", (
                f"Using quantizer type {self.quantizer_type} does not require "
                "to have any random noise.\n quantizer_noise_type must be "
                f"'none', found {self.quantizer_noise_type}."
            )
        else:
            assert self.quantizer_noise_type != "none", (
                "Using quantizer_noise_type = 'none' is only possible with "
                "quantizer_type = 'softround_alone', 'hardround', 'ste' or 'none'.\n"
                f"Trying to use {self.quantizer_type} quantizer which does require "
                "some kind of random noise such as 'gaussian' or 'kumaraswamy'."
            )

    def pretty_string(self) -> str:
        """Return one table row describing this training phase.

        Each value is centered in its column and followed by ``"|"``. The
        columns match ``_pretty_string_column_name``.
        """
        softround_str = ", ".join(f"{x:1.1e}" for x in self.softround_temperature)
        noise_str = ", ".join(f"{x:1.2f}" for x in self.noise_parameter)
        cells = (
            f"{self.lr:1.2e}",
            self.max_itr,
            self.patience,
            self.freq_valid,
            # Booleans go through the integer formatting here, so they are
            # displayed as 1 / 0.
            self.quantize_model,
            self.schedule_lr,
            softround_str,
            noise_str,
        )
        return "".join(
            f"{cell:^{width}}|" for cell, width in zip(cells, self._COLUMN_WIDTHS)
        )

    @classmethod
    def _pretty_string_column_name(cls) -> str:
        """Return the name of the column aligned with the pretty_string function"""
        names = (
            "Learn rate",
            "Max itr",
            "Patience [itr]",
            "Valid [itr]",
            "Quantize NN",
            "Schedule lr",
            "Softround Temp",
            "Noise",
        )
        return "".join(
            f"{name:^{width}}|" for name, width in zip(names, cls._COLUMN_WIDTHS)
        )

    @classmethod
    def _vertical_line_array(cls) -> str:
        """Return a string made of "-" and "+" matching the columns
        of the print detailed above"""
        return "".join("-" * width + "+" for width in cls._COLUMN_WIDTHS)
@dataclass
class WarmupPhase:
    """Describe one phase of the :doc:`warm-up <../training/warmup>`. At the
    beginning of each warm-up phase, we start by keeping the best ``candidates``
    systems. We then perform a short training, and we go to the next phase.

    Args:
        candidates (int): How many candidates are kept at the beginning of the phase.
        training_phase (TrainerPhase): Describe how the candidates are trained.
    """

    # Keep the first <candidates> best systems at the beginning of this warmup phase
    candidates: int
    training_phase: TrainerPhase

    def pretty_string(self) -> str:
        """Return one table row describing this warm-up phase.

        The row starts with the number of candidates, followed by the row
        of the underlying training phase.
        """
        return f"|{self.candidates:^14}|{self.training_phase.pretty_string()}"

    @classmethod
    def _pretty_string_column_name(cls) -> str:
        """Return the name of the column aligned with the pretty_string function"""
        return f'|{"Candidates":^14}|{TrainerPhase._pretty_string_column_name()}'
@dataclass
class Warmup:
    """A :doc:`warm-up <../training/warmup>` is composed of different phases
    where the worse candidates are successively eliminated.

    Args:
        phases (List[WarmupPhase]): The successive phases of the Warmup.
            Defaults to ``[]``.
    """

    phases: List[WarmupPhase] = field(default_factory=list)

    def _get_total_warmup_iterations(self) -> int:
        """Return the total number of iterations for the whole warm-up.

        Each phase counts for ``candidates * max_itr`` iterations, since
        every candidate kept for the phase is trained.
        """
        return sum(
            phase.candidates * phase.training_phase.max_itr for phase in self.phases
        )
@dataclass
class Preset:
    """Dummy parent (abstract) class of all encoder presets. An actual preset
    should inherit from this class.

    Encoding preset defines how we encode each frame. They are similar to
    conventional codecs presets *e.g* x264 ``--slow`` preset offers better
    compression performance at the expense of a longer encoding.

    Here a preset defines two things: how the :doc:`warm-up <../training/warmup>`
    is done, and how the subsequent :doc:`training <../training/train>` is done.

    Args:
        preset_name (str): Name of the preset.
        all_phases (List[TrainerPhase]): The successive (post warm-up) training
            phase. Defaults to ``[]``.
        warmup (Warmup): The warm-up parameters. Defaults to ``Warmup()``.
    """

    preset_name: str
    # Dummy empty training phases and warm-up
    all_phases: List[TrainerPhase] = field(default_factory=list)  # All the post-warm-up training phases
    warmup: Warmup = field(default_factory=lambda: Warmup())  # All the warm-up phases

    def __post_init__(self):
        """Check that at least one training phase quantizes the model.

        The check is skipped when ``all_phases`` is empty.

        Raises:
            AssertionError: If ``all_phases`` is not empty and none of its
                phases has ``quantize_model`` set.
        """
        has_quantization_phase = any(
            training_phase.quantize_model for training_phase in self.all_phases
        )
        # Ignore this assertion if there is no self.all_phases described
        assert has_quantization_phase or len(self.all_phases) == 0, (
            f"The selected preset ({self.preset_name}) does not include "
            f" a training phase with neural network quantization.\n"
            f"{self.pretty_string()}"
        )

    def _get_total_training_iterations(self) -> int:
        """Return the total number of iterations for the whole (post warm-up)
        training, i.e. the sum of ``max_itr`` over ``all_phases``."""
        return sum(phase.max_itr for phase in self.all_phases)

    def pretty_string(self) -> str:
        """Return a pretty string describing the whole preset: one table for
        the warm-up phases, one table for the main training phases and the
        maximum number of iterations."""
        # Horizontal rule; the leading 14-wide column holds either the number
        # of candidates (warm-up) or the phase index (main training).
        separator = "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"

        parts = [
            f"Preset: {self.preset_name:<10}\n",
            "-------\n",
            "\nWarm-up\n",
            "-------\n",
            separator,
            WarmupPhase._pretty_string_column_name() + "\n",
            separator,
        ]
        parts.extend(
            warmup_phase.pretty_string() + "\n" for warmup_phase in self.warmup.phases
        )
        parts.extend(
            [
                separator,
                "\nMain training\n",
                "-------------\n",
                separator,
                f'|{"Phase index":^14}|{TrainerPhase._pretty_string_column_name()}\n',
                separator,
            ]
        )
        parts.extend(
            f"|{idx:^14}|{training_phase.pretty_string()}\n"
            for idx, training_phase in enumerate(self.all_phases)
        )
        parts.append(separator)

        warmup_max_itr = self.warmup._get_total_warmup_iterations()
        training_max_itr = self._get_total_training_iterations()
        total_max_itr = warmup_max_itr + training_max_itr
        parts.append("\nMaximum number of iterations (warm-up / training / total):")
        parts.append(
            f"{warmup_max_itr:^8} / {training_max_itr:^8} / {total_max_itr:^8}\n\n"
        )
        return "".join(parts)
class PresetC3x(Preset):
    """Preset named ``"c3x"``.

    The warm-up has two phases of 400 iterations each, keeping 5 then 2
    candidates. The main training has three phases: a softround phase with
    gaussian noise and scheduled learning rate, a short STE phase ending with
    the quantization of the neural networks, and a final STE phase which only
    fine-tunes the latent.

    Args:
        start_lr (float): Learning rate of the warm-up phases and initial
            learning rate of the first training phase. Defaults to 1e-2.
        n_itr_per_phase (int): Maximum number of iterations of the first
            training phase. Defaults to 100000.
    """

    def __init__(self, start_lr: float = 1e-2, n_itr_per_phase: int = 100000):
        all_phases = [
            # 1st stage: with soft round and quantization noise
            TrainerPhase(
                lr=start_lr,
                max_itr=n_itr_per_phase,
                patience=5000,
                optimized_module=["all"],
                schedule_lr=True,
                quantizer_type="softround",
                quantizer_noise_type="gaussian",
                softround_temperature=(0.3, 0.1),
                noise_parameter=(0.25, 0.1),
            ),
            # Stage with STE then network quantization
            TrainerPhase(
                lr=1.0e-4,
                max_itr=1500,
                patience=1500,
                optimized_module=["all"],
                schedule_lr=True,
                quantizer_type="ste",
                quantizer_noise_type="none",
                # This is only used to parameterize the backward of the quantization
                softround_temperature=(1e-4, 1e-4),
                noise_parameter=(1.0, 1.0),  # not used since quantizer type is "ste"
                quantize_model=True,  # ! This is an important parameter
            ),
            # Re-tune the latent
            TrainerPhase(
                lr=1.0e-4,
                max_itr=1000,
                patience=50,
                quantizer_type="ste",
                quantizer_noise_type="none",
                optimized_module=["latent"],  # ! Only fine tune the latent
                freq_valid=10,
                softround_temperature=(1e-4, 1e-4),
                noise_parameter=(1.0, 1.0),  # not used since quantizer type is "ste"
            ),
        ]

        # Both warm-up phases share the same training settings, only the
        # number of candidates differs. Each phase gets its own TrainerPhase.
        warmup = Warmup(
            [
                WarmupPhase(
                    candidates=n_candidates,
                    training_phase=TrainerPhase(
                        lr=start_lr,
                        max_itr=400,
                        freq_valid=400,
                        patience=100000,
                        quantize_model=False,
                        schedule_lr=False,
                        softround_temperature=(0.3, 0.3),
                        noise_parameter=(2.0, 2.0),
                        quantizer_noise_type="kumaraswamy",
                        quantizer_type="softround",
                        optimized_module=["all"],
                    ),
                )
                for n_candidates in (5, 2)
            ]
        )

        # Give the phases to the parent constructor (instead of assigning them
        # afterwards) so that Preset.__post_init__ checks the actual phases
        # rather than an empty list.
        super().__init__(preset_name="c3x", all_phases=all_phases, warmup=warmup)
class PresetDebug(Preset):
    """Very fast training schedule, should only be used to ensure that the code works properly!

    Args:
        start_lr (float): Initial learning rate of the first training phase.
            Defaults to 1e-2.
        n_itr_per_phase (int): Not used by this preset, all the numbers of
            iterations are hard-coded. Kept so that all presets share the
            same constructor signature. Defaults to 100000.
    """

    def __init__(self, start_lr: float = 1e-2, n_itr_per_phase: int = 100000):
        all_phases = [
            # Softround with gaussian noise
            TrainerPhase(
                lr=start_lr,
                max_itr=50,
                patience=100000,
                optimized_module=["all"],
                schedule_lr=True,
                quantizer_type="softround",
                quantizer_noise_type="gaussian",
                softround_temperature=(0.3, 0.1),
                noise_parameter=(0.25, 0.1),
            ),
            # STE then network quantization
            TrainerPhase(
                lr=1e-4,
                max_itr=10,
                patience=10,
                optimized_module=["all"],
                quantizer_type="ste",
                quantizer_noise_type="none",
                quantize_model=True,
                softround_temperature=(1e-4, 1e-4),
                noise_parameter=(1.0, 1.0),  # not used since quantizer type is "ste"
            ),
            # Only fine tune the latent
            TrainerPhase(
                lr=1e-4,
                max_itr=10,
                patience=50,
                optimized_module=["latent"],
                freq_valid=5,
                quantizer_type="ste",
                quantizer_noise_type="none",
                softround_temperature=(1e-4, 1e-4),
                noise_parameter=(1.0, 1.0),  # not used since quantizer type is "ste"
            ),
        ]

        # Two warm-up phases of 10 iterations, otherwise using the default
        # TrainerPhase settings.
        warmup = Warmup(
            [
                WarmupPhase(candidates=3, training_phase=TrainerPhase(max_itr=10)),
                WarmupPhase(candidates=2, training_phase=TrainerPhase(max_itr=10)),
            ]
        )

        # Give the phases to the parent constructor (instead of assigning them
        # afterwards) so that Preset.__post_init__ checks the actual phases
        # rather than an empty list.
        super().__init__(preset_name="debug", all_phases=all_phases, warmup=warmup)
class PresetMeasureSpeed(Preset):
    """Preset with a single training phase and the shortest possible warm-up
    (one candidate trained for one iteration).

    Args:
        start_lr (float): Learning rate of the warm-up phase and initial
            learning rate of the training phase. Defaults to 1e-2.
        n_itr_per_phase (int): Maximum number of iterations of the training
            phase. Defaults to 100000.
    """

    def __init__(self, start_lr: float = 1e-2, n_itr_per_phase: int = 100000):
        # Single stage model with the shortest warm-up ever!
        all_phases = [
            TrainerPhase(
                lr=start_lr,
                max_itr=n_itr_per_phase,
                patience=5000,
                optimized_module=["all"],
                schedule_lr=True,
                quantizer_type="softround",
                quantizer_noise_type="gaussian",
                softround_temperature=(0.3, 0.1),
                noise_parameter=(0.25, 0.1),
                quantize_model=True,  # ! This is an important parameter
            ),
        ]

        warmup = Warmup(
            [
                WarmupPhase(
                    candidates=1,
                    training_phase=TrainerPhase(
                        lr=start_lr,
                        max_itr=1,
                        freq_valid=1,
                        patience=100000,
                        quantize_model=False,
                        schedule_lr=False,
                        softround_temperature=(0.3, 0.3),
                        noise_parameter=(2.0, 2.0),
                        quantizer_noise_type="kumaraswamy",
                        quantizer_type="softround",
                        optimized_module=["all"],
                    ),
                )
            ]
        )

        # The name matches the key of this preset in AVAILABLE_PRESETS (it
        # used to be "c3x", copied from PresetC3x). The phases are given to
        # the parent constructor so that Preset.__post_init__ checks the
        # actual phases rather than an empty list.
        super().__init__(
            preset_name="measure_speed", all_phases=all_phases, warmup=warmup
        )
# Map a preset name to the preset *class* (not an instance). The class must
# be called to build the actual preset.
AVAILABLE_PRESETS: Dict[str, typing.Type[Preset]] = {
"c3x": PresetC3x,
"debug": PresetDebug,
"measure_speed": PresetMeasureSpeed,
}