# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD-3-Clause
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md
"""Gather the different encoding presets here."""
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Tuple, Type
import typing
from enc.component.core.quantizer import (
POSSIBLE_QUANTIZATION_NOISE_TYPE,
POSSIBLE_QUANTIZER_TYPE,
)
from enc.component.types import NAME_COOLCHIC_ENC
MODULE_TO_OPTIMIZE = Literal[
    # All combinations of <coolchic_enc_name>.<module_name>,
    # where "all" designates either all Cool-chic encoders or all
    # modules, e.g.
    # - "residue.all": train the entire Cool-chic residue
    # - "motion.arm": train only the ARM of the Cool-chic motion
    # - "all.latent": train the latent of all Cool-chic encoders
tuple(
[
f"{cc_name}.{mod_name}"
for cc_name in list(typing.get_args(NAME_COOLCHIC_ENC)) + ["all"]
for mod_name in ["all", "arm", "upsampling", "synthesis", "latent"]
]
+
# Train everything
["all"]
)
]
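# Illustrative sanity check (a sketch, not executed here; it assumes
# NAME_COOLCHIC_ENC exposes e.g. "residue" and "motion"):
#
#     >>> "residue.arm" in typing.get_args(MODULE_TO_OPTIMIZE)
#     True
#     >>> "all" in typing.get_args(MODULE_TO_OPTIMIZE)
#     True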
@dataclass
class TrainerPhase:
"""Dataclass representing one phase of an encoding preset.
Args:
lr (float): Initial learning rate of the phase. Can vary if
``schedule_lr`` is True. Defaults to 0.01.
        max_itr (int): Maximum number of iterations for the phase. The actual
            number of iterations can be made smaller through the patience
            mechanism. Defaults to 5000.
        freq_valid (int): Check (and print) the performance every
            ``freq_valid`` iterations. This drives the patience
            mechanism. Defaults to 100.
        patience (int): After ``patience`` iterations without any
            improvement of the results, exit the training. Patience is
            disabled by setting ``patience = max_itr``. If patience is used
            alongside ``schedule_lr``, it does not end the training. Instead,
            once the patience is reached, we simply reload the best model so
            far and the training continues. Defaults to 10000.
        quantize_model (bool): If ``True``, quantize the neural network
            parameters at the end of the training phase. Defaults to ``False``.
        schedule_lr (bool): If ``True``, the learning rate is no longer
            constant. Instead, it follows a cosine schedule, as suggested
in `C3: High-performance and low-complexity neural compression from
a single image or video, Kim et al.
<https://arxiv.org/abs/2312.02753>`_. Defaults to False.
        softround_temperature (Tuple[float, float]): Start and end temperature
            of the :doc:`softround function <../component/core/quantizer>`. It is
used in the forward / backward if ``quantizer_type`` is set to
``"softround"`` or ``"softround_alone"``. It is also used in the
backward pass if ``quantizer_type`` is set to ``"ste"``.
The softround temperature is linearly scheduled
during the training. At iteration n° 0 it is equal to
``softround_temperature[0]`` while at iteration n° ``max_itr`` it is
equal to ``softround_temperature[1]``. Note that the patience might
interrupt the training before it reaches this last value.
Defaults to (0.3, 0.3).
noise_parameter (Tuple[float, float]): The random noise temperature is
linearly scheduled during the training. At iteration n° 0 it is equal
to ``noise_parameter[0]`` while at iteration n° ``max_itr`` it is equal
to ``noise_parameter[1]``. Note that the patience might interrupt
            the training before it reaches this last value. Defaults to
            (1.0, 1.0).
quantizer_noise_type (POSSIBLE_QUANTIZATION_NOISE_TYPE): The random noise
used by the quantizer. More information available in
:doc:`encoder/component/core/quantizer.py <../component/core/quantizer>`.
Defaults to ``"kumaraswamy"``.
quantizer_type (POSSIBLE_QUANTIZER_TYPE): What quantizer to
use during training. See
:doc:`encoder/component/core/quantizer.py <../component/core/quantizer>`
for more information. Defaults to ``"softround"``.
optimized_module (List[MODULE_TO_OPTIMIZE]): List of modules to be
optimized. Most often you'd want to use ``optimized_module = ['all']``.
Defaults to ``['all']``.
"""
lr: float = 1e-2
max_itr: int = 5000
freq_valid: int = 100
patience: int = 10000
quantize_model: bool = False
schedule_lr: bool = False
softround_temperature: Tuple[float, float] = (0.3, 0.3)
noise_parameter: Tuple[float, float] = (1.0, 1.0)
quantizer_noise_type: POSSIBLE_QUANTIZATION_NOISE_TYPE = "kumaraswamy"
quantizer_type: POSSIBLE_QUANTIZER_TYPE = "softround"
optimized_module: List[MODULE_TO_OPTIMIZE] = field(default_factory=lambda: ["all"])
    def __post_init__(self):
        # If "all" is present in the list of modules to be optimized, it
        # overrides everything else: the list collapses to just ["all"].
        if "all" in self.optimized_module:
            self.optimized_module = ["all"]
for cur_module in self.optimized_module:
assert cur_module in list(typing.get_args(MODULE_TO_OPTIMIZE)), (
f"Module to optimize unknown: found {cur_module}. Available: "
f"{list(typing.get_args(MODULE_TO_OPTIMIZE))}."
)
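        # e.g. (sketch, assuming "motion" is a valid Cool-chic encoder name):
        # TrainerPhase(optimized_module=["all", "motion.arm"]) collapses its
        # optimized_module to ["all"].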
def pretty_string(self) -> str:
"""Return a pretty string describing a warm-up phase"""
s = f'{f"{self.lr:1.2e}":^{14}}|'
s += f"{' '.join(self.optimized_module):^{20}}|"
s += f"{self.max_itr:^{9}}|"
s += f"{self.patience:^{16}}|"
s += f"{self.freq_valid:^{13}}|"
s += f"{self.quantize_model:^{13}}|"
s += f"{self.schedule_lr:^{13}}|"
        softround_str = ', '.join([f'{x:1.1e}' for x in self.softround_temperature])
        s += f'{softround_str:^{18}}|'
        noise_str = ', '.join([f'{x:1.2f}' for x in self.noise_parameter])
        s += f'{noise_str:^{14}}|'
return s
@classmethod
def _pretty_string_column_name(cls) -> str:
"""Return the name of the column aligned with the pretty_string function"""
s = f'{"Learn rate":^{14}}|'
s += f'{"Module optimized":^{20}}|'
s += f'{"Max itr":^{9}}|'
s += f'{"Patience [itr]":^{16}}|'
s += f'{"Valid [itr]":^{13}}|'
s += f'{"Quantize NN":^{13}}|'
s += f'{"Schedule lr":^{13}}|'
s += f'{"Softround Temp":^{18}}|'
s += f'{"Noise":^{14}}|'
return s
@classmethod
def _vertical_line_array(cls) -> str:
"""Return a string made of "-" and "+" matching the columns
of the print detailed above"""
s = '-' * 14 + '+'
s += '-' * 20 + '+'
s += '-' * 9 + '+'
s += '-' * 16 + '+'
s += '-' * 13 + '+'
s += '-' * 13 + '+'
s += '-' * 13 + '+'
s += '-' * 18 + '+'
s += '-' * 14 + '+'
return s
@dataclass
class WarmupPhase:
"""Describe one phase of the :doc:`warm-up <../training/warmup>`. At the
beginning of each warm-up phase, we start by keeping the best ``candidates``
systems. We then perform a short training, and we go to the next phase.
Args:
candidates (int): How many candidates are kept at the beginning of the phase.
training_phase (TrainerPhase): Describe how the candidates are trained.
"""
    candidates: int  # Keep the <candidates> best systems at the beginning of this warm-up phase
training_phase: TrainerPhase
def pretty_string(self) -> str:
"""Return a pretty string describing a warm-up phase"""
s = f"|{self.candidates:^{14}}|"
s += f"{self.training_phase.pretty_string()}"
return s
@classmethod
def _pretty_string_column_name(cls) -> str:
"""Return the name of the column aligned with the pretty_string function"""
s = f'|{"Candidates":^{14}}|'
s += f'{TrainerPhase._pretty_string_column_name()}'
return s
@dataclass
class Warmup:
"""A :doc:`warm-up <../training/warmup>` is composed of different phases
where the worse candidates are successively eliminated.
Args:
        phases (List[WarmupPhase]): The successive phases of the Warmup.
Defaults to ``[]``.
"""
phases: List[WarmupPhase] = field(default_factory=lambda: [])
def _get_total_warmup_iterations(self) -> int:
"""Return the total number of iterations for the whole warm-up."""
return sum(
[phase.candidates * phase.training_phase.max_itr for phase in self.phases]
)
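    # Worked example (sketch): with two phases keeping 5 then 2 candidates,
    # each trained for 400 iterations, the total is
    # 5 * 400 + 2 * 400 = 2800 iterations, since every remaining candidate
    # goes through the whole phase.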
@dataclass
class Preset:
"""Dummy parent (abstract) class of all encoder presets. An actual preset
should inherit from this class.
Encoding preset defines how we encode each frame. They are similar to
conventional codecs presets *e.g* x264 ``--slow`` preset offers better
compression performance at the expense of a longer encoding.
Here a preset defines two things: how the :doc:`warm-up <../training/warmup>`
is done, and how the subsequent :doc:`training <../training/train>` is done.
    Args:
        preset_name (str): Name of the preset.
        motion_pretrain_phase (List[TrainerPhase]): The phases used to
            pre-train the motion Cool-chic encoder. Defaults to ``[]``.
        training_phases (List[TrainerPhase]): The successive (post warm-up)
            training phases. Defaults to ``[]``.
        warmup (Warmup): The warm-up parameters. Defaults to ``Warmup()``.
"""
preset_name: str
# Dummy empty training phases and warm-up
motion_pretrain_phase: List[TrainerPhase] = field(default_factory=lambda: []) # All the phases to pre-train the motion
warmup: Warmup = field(default_factory=lambda: Warmup()) # All the warm-up phases
training_phases: List[TrainerPhase] = field(default_factory=lambda: []) # All the post-warm-up training phases
    def __post_init__(self):
        # Check that we quantize the model at least once during the training.
        flag_quantize_model = any(
            training_phase.quantize_model for training_phase in self.training_phases
        )

        # Ignore this assertion if no training phase is described
        assert flag_quantize_model or len(self.training_phases) == 0, (
            f"The selected preset ({self.preset_name}) does not include "
            f"a training phase with neural network quantization.\n"
            f"{self.pretty_string()}"
        )
def _get_total_training_iterations(self, train_phases: List[TrainerPhase]) -> int:
"""Return the total number of iterations for the whole warm-up."""
return sum([phase.max_itr for phase in train_phases])
def pretty_string(self) -> str:
"""Return a pretty string describing a warm-up phase"""
s = f"Preset: {self.preset_name:<10}\n"
s += "-------\n"
if len(self.motion_pretrain_phase) > 0:
s += "\nMotion pre-training\n"
s += "-------------------\n"
s += "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"
s += f'|{"Phase index":^14}|{TrainerPhase._pretty_string_column_name()}\n'
s += "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"
for idx, training_phase in enumerate(self.motion_pretrain_phase):
s += f"|{idx:^14}|{training_phase.pretty_string()}\n"
s += "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"
s += "\nWarm-up\n"
s += "-------\n"
s += "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"
s += WarmupPhase._pretty_string_column_name() + "\n"
s += "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"
for warmup_phase in self.warmup.phases:
s += warmup_phase.pretty_string() + "\n"
s += "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"
s += "\nMain training\n"
s += "-------------\n"
s += "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"
s += f'|{"Phase index":^14}|{TrainerPhase._pretty_string_column_name()}\n'
s += "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"
for idx, training_phase in enumerate(self.training_phases):
s += f"|{idx:^14}|{training_phase.pretty_string()}\n"
s += "+" + "-" * 14 + "+" + TrainerPhase._vertical_line_array() + "\n"
s += "\nMaximum number of iterations (motion / warm-up / training / total):"
motion_max_itr = self._get_total_training_iterations(self.motion_pretrain_phase)
warmup_max_itr = self.warmup._get_total_warmup_iterations()
training_max_itr = self._get_total_training_iterations(self.training_phases)
total_max_itr = motion_max_itr + warmup_max_itr + training_max_itr
s += (
f"{motion_max_itr:^8} / "
f"{warmup_max_itr:^8} / "
f"{training_max_itr:^8} / "
f"{total_max_itr:^8}\n"
)
return s
class PresetC3xIntra(Preset):
def __init__(
self,
start_lr: float = 1e-2,
itr_main_training: int = 100000,
itr_motion_pretrain: int = 1000,
):
super().__init__(preset_name="c3x_intra")
# 1st stage: with soft round and quantization noise
self.training_phases: List[TrainerPhase] = [
TrainerPhase(
lr=start_lr,
max_itr=itr_main_training,
patience=5000,
optimized_module=["all"],
schedule_lr=True,
quantizer_type="softround",
quantizer_noise_type="gaussian",
softround_temperature=(0.3, 0.1),
noise_parameter=(0.25, 0.1),
# quantize_model=True, # ! This is an important parameter
),
# Stage with STE then network quantization
TrainerPhase(
lr=1.0e-4,
max_itr=1500,
patience=1500,
optimized_module=["all"],
schedule_lr=True,
quantizer_type="ste",
quantizer_noise_type="none",
# This is only used to parameterize the backward of the quantization
softround_temperature=(1e-4, 1e-4),
noise_parameter=(1.0, 1.0), # not used since quantizer type is "ste"
quantize_model=True, # ! This is an important parameter
),
# # Re-tune the latent
# TrainerPhase(
# lr=1.0e-4,
# max_itr=1000,
# patience=50,
# quantizer_type="ste",
# quantizer_noise_type="none",
# optimized_module=["latent"], # ! Only fine tune the latent
# freq_valid=10,
# softround_temperature=(1e-4, 1e-4),
# noise_parameter=(1.0, 1.0), # not used since quantizer type is "ste"
# ),
]
self.warmup = Warmup(
[
WarmupPhase(
candidates=5,
training_phase=TrainerPhase(
lr=start_lr,
max_itr=400,
freq_valid=400,
patience=100000,
quantize_model=False,
schedule_lr=False,
softround_temperature=(0.3, 0.3),
noise_parameter=(2.0, 2.0),
quantizer_noise_type="kumaraswamy",
quantizer_type="softround",
optimized_module=["all"],
)
),
WarmupPhase(
candidates=2,
training_phase=TrainerPhase(
lr=start_lr,
max_itr=400,
freq_valid=400,
patience=100000,
quantize_model=False,
schedule_lr=False,
softround_temperature=(0.3, 0.3),
noise_parameter=(2.0, 2.0),
quantizer_noise_type="kumaraswamy",
quantizer_type="softround",
optimized_module=["all"],
)
)
]
)
# self.motion_pretrain_phase: List[TrainerPhase] = []
class PresetC3xInter(Preset):
def __init__(
self,
start_lr: float = 1e-2,
itr_main_training: int = 100000,
itr_motion_pretrain: int = 1000,
):
super().__init__(preset_name="c3x_inter")
# 1st stage: with soft round and quantization noise
self.training_phases: List[TrainerPhase] = [
TrainerPhase(
lr=start_lr,
max_itr=itr_main_training,
patience=5000,
optimized_module=["all"],
schedule_lr=True,
quantizer_type="softround",
quantizer_noise_type="gaussian",
softround_temperature=(0.3, 0.1),
noise_parameter=(0.25, 0.1),
quantize_model=True, # ! This is an important parameter
),
]
self.warmup = Warmup(
[
WarmupPhase(
candidates=2,
training_phase=TrainerPhase(
lr=start_lr,
max_itr=600,
freq_valid=600,
patience=100000,
quantize_model=False,
schedule_lr=False,
softround_temperature=(0.3, 0.3),
noise_parameter=(2.0, 2.0),
quantizer_noise_type="kumaraswamy",
quantizer_type="softround",
optimized_module=["all"],
)
)
]
)
self.motion_pretrain_phase: List[TrainerPhase] = [
TrainerPhase(
lr=1e-2,
max_itr=itr_motion_pretrain,
patience=itr_motion_pretrain,
optimized_module=["all"],
schedule_lr=False,
softround_temperature=(0.3, 0.3),
noise_parameter=(2.0, 2.0),
quantizer_noise_type="kumaraswamy",
quantizer_type="softround",
),
]
class PresetDebug(Preset):
"""Very fast training schedule, should only be used to ensure that the code works properly!"""
def __init__(
self,
start_lr: float = 1e-2,
itr_main_training: int = 100,
itr_motion_pretrain: int = 10,
):
super().__init__(preset_name="debug")
self.training_phases: List[TrainerPhase] = [
TrainerPhase(
lr=start_lr,
max_itr=50,
patience=100000,
optimized_module=["residue.all"],
schedule_lr=True,
quantizer_type="softround",
quantizer_noise_type="gaussian",
softround_temperature=(0.3, 0.1),
noise_parameter=(0.25, 0.1),
)
]
self.training_phases.append(
TrainerPhase(
lr=1e-4,
max_itr=10,
patience=10,
optimized_module=["all"],
quantizer_type="ste",
quantizer_noise_type="none",
quantize_model=True,
softround_temperature=(1e-4, 1e-4),
noise_parameter=(1.0, 1.0), # not used since quantizer type is "ste"
)
)
self.training_phases.append(
# Re-tune the latent
TrainerPhase(
lr=1.0e-4,
max_itr=10,
patience=5,
quantizer_type="ste",
quantizer_noise_type="none",
optimized_module=["all.latent"], # ! Only fine tune the latent
freq_valid=10,
softround_temperature=(1e-4, 1e-4),
noise_parameter=(1.0, 1.0), # not used since quantizer type is "ste"
),
)
self.warmup = Warmup(
[
WarmupPhase(candidates=3, training_phase=TrainerPhase(max_itr=10)),
WarmupPhase(candidates=2, training_phase=TrainerPhase(max_itr=10)),
]
)
self.motion_pretrain_phase: List[TrainerPhase] = [
TrainerPhase(
lr=start_lr,
max_itr=50,
patience=50,
optimized_module=["all"],
schedule_lr=True,
quantizer_type="softround",
quantizer_noise_type="gaussian",
softround_temperature=(0.3, 0.1),
noise_parameter=(0.25, 0.1),
),
]
class PresetMeasureSpeed(Preset):
    def __init__(
        self,
        start_lr: float = 1e-2,
        itr_main_training: int = 100000,
        itr_motion_pretrain: int = 10,
    ):
super().__init__(preset_name="measure_speed")
# Single stage model with the shortest warm-up ever!
self.training_phases: List[TrainerPhase] = [
TrainerPhase(
lr=start_lr,
max_itr=itr_main_training,
patience=5000,
optimized_module=["all"],
schedule_lr=True,
quantizer_type="softround",
quantizer_noise_type="gaussian",
softround_temperature=(0.3, 0.1),
noise_parameter=(0.25, 0.1),
quantize_model=True, # ! This is an important parameter
),
]
self.warmup = Warmup(
[
WarmupPhase(
candidates=1,
training_phase=TrainerPhase(
lr=start_lr,
max_itr=1,
freq_valid=1,
patience=100000,
quantize_model=False,
schedule_lr=False,
softround_temperature=(0.3, 0.3),
noise_parameter=(2.0, 2.0),
quantizer_noise_type="kumaraswamy",
quantizer_type="softround",
optimized_module=["all"],
)
)
]
)
# self.motion_pretrain_phase: List[TrainerPhase] = [
# TrainerPhase(
# lr=start_lr,
# max_itr=50,
# patience=50,
# optimized_module=["all"],
# schedule_lr=True,
# quantizer_type="softround",
# quantizer_noise_type="gaussian",
# softround_temperature=(0.3, 0.1),
# noise_parameter=(0.25, 0.1),
# ),
# ]
AVAILABLE_PRESETS: Dict[str, Type[Preset]] = {
"c3x_intra": PresetC3xIntra,
"c3x_inter": PresetC3xInter,
"debug": PresetDebug,
"measure_speed": PresetMeasureSpeed,
}
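
if __name__ == "__main__":
    # Minimal usage sketch: instantiate a preset and display its schedule.
    # AVAILABLE_PRESETS stores classes, so they must be instantiated first.
    preset = AVAILABLE_PRESETS["debug"]()
    print(preset.pretty_string())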