# Source code for enc.training.warmup

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md


import copy
import time
from typing import List

from enc.utils.manager import FrameEncoderManager
from enc.component.frame import FrameEncoder
from enc.training.test import test
from enc.training.train import train
from enc.utils.codingstructure import Frame
from enc.utils.misc import POSSIBLE_DEVICE, mem_info


def warmup(
    frame_encoder_manager: FrameEncoderManager,
    list_candidates: List[FrameEncoder],
    frame: Frame,
    device: POSSIBLE_DEVICE,
) -> FrameEncoder:
    """Perform the warm-up for a frame encoder.

    It consists in multiple stages with several candidates, filtering out the
    best N candidates at each stage. For instance, we can start with 8
    different FrameEncoder. We train each of them for 400 iterations. Then we
    keep the best 4 of them for 400 additional iterations, while finally
    keeping the final best one.

    At the start of every phase the candidate list is trimmed to
    ``warmup_phase.candidates`` entries. After a phase the list is sorted by
    increasing loss, so trimming keeps the best candidates. Before the first
    phase the list is still in the caller's order, so any surplus candidates
    at the end of ``list_candidates`` are dropped untrained. If a phase
    requests more candidates than are left, all remaining ones are trained.

    .. warning::

        The parameter ``frame_encoder_manager`` tracking the encoding time of
        the frame (``total_training_time_sec``) and the number of encoding
        iterations (``iterations_counter``) is **modified in place** by this
        function.

    Args:
        frame_encoder_manager: Contains (among other things) the rate
            constraint :math:`\\lambda` and description of the warm-up preset.
            It is also used to track the total encoding time and encoding
            iterations. Modified in place.
        list_candidates: The different candidates among which the warm-up will
            find the best starting point. Must not be empty.
        frame: The original image to be compressed and its references.
        device: On which device should the training run.

    Returns:
        A deep copy of the candidate with the lowest loss at the end of the
        last warm-up phase.

    Raises:
        ValueError: If ``list_candidates`` is empty.
    """
    if not list_candidates:
        raise ValueError("warmup() requires at least one candidate frame encoder.")

    start_time = time.time()

    # Named warmup_preset (not warmup) to avoid shadowing this function.
    # NOTE(review): this function never writes to frame_encoder_manager
    # directly; the in-place updates mentioned in the docstring presumably
    # happen inside train() -- confirm.
    warmup_preset = frame_encoder_manager.preset.warmup

    _col_width = 14

    # Construct the list of candidates. Each of them has its own parameters,
    # unique ID and metrics (not yet evaluated so it is set to None).
    all_candidates = [
        {"metrics": None, "id": id_candidate, "encoder": param_candidate}
        for id_candidate, param_candidate in enumerate(list_candidates)
    ]

    for idx_warmup_phase, warmup_phase in enumerate(warmup_preset.phases):
        print(f'{"-" * 30} Warm-up phase: {idx_warmup_phase:>2} {"-" * 30}')
        mem_info(f"Warmup-{idx_warmup_phase:02d}")

        # Keep only the desired number of candidates. Trimming in every phase
        # (first one included) guarantees that each remaining candidate is
        # trained below, so none reaches the sort with metrics still None.
        # max(0, ...) prevents a negative count from being read as a
        # "from the end" slice index.
        del all_candidates[max(0, warmup_phase.candidates):]

        # i is just the index of a candidate in the all_candidates list. It is
        # **not** a unique identifier for this candidate, which is given by
        # all_candidates[i]["id"]. The list is re-sorted after each phase, so
        # the index of a given candidate may change from phase to phase.
        #
        # Train all (remaining) candidates for a little bit.
        for i, cur_candidate in enumerate(all_candidates):
            cur_id = cur_candidate["id"]
            frame_encoder = cur_candidate["encoder"]
            frame_encoder.to_device(device)

            # \u00b0 is the degree sign; escaped so the message survives
            # encoding mishaps in the source file.
            print(
                f"\nCandidate n\u00b0 {i:<2}, ID = {cur_id:<2}:"
                + "\n-------------------------\n"
            )
            mem_info(f"Warmup-cand-in {idx_warmup_phase:02d}-{i:02d}")

            frame_encoder = train(
                frame_encoder=frame_encoder,
                frame=frame,
                frame_encoder_manager=frame_encoder_manager,
                start_lr=warmup_phase.training_phase.lr,
                cosine_scheduling_lr=warmup_phase.training_phase.schedule_lr,
                max_iterations=warmup_phase.training_phase.max_itr,
                patience=warmup_phase.training_phase.patience,
                frequency_validation=warmup_phase.training_phase.freq_valid,
                optimized_module=warmup_phase.training_phase.optimized_module,
                quantizer_type=warmup_phase.training_phase.quantizer_type,
                quantizer_noise_type=warmup_phase.training_phase.quantizer_noise_type,
                softround_temperature=warmup_phase.training_phase.softround_temperature,
                noise_parameter=warmup_phase.training_phase.noise_parameter,
            )

            metrics = test(frame_encoder, frame, frame_encoder_manager)
            frame_encoder.to_device("cpu")

            # Store the updated encoder and its metrics. cur_candidate is the
            # very dict held by the list, so no re-assignment is needed.
            cur_candidate["encoder"] = frame_encoder
            cur_candidate["metrics"] = metrics

        # Best (lowest loss) candidate first. The sort is stable, so ties keep
        # their previous relative order.
        all_candidates.sort(key=lambda x: x["metrics"].loss)

        # Print the results of this warm-up phase.
        table_lines = [
            "\n\nPerformance at the end of the warm-up phase:\n",
            f'{"ID":^{6}}|{"loss":^{_col_width}}|{"rate_bpp":^{_col_width}}|{"psnr_db":^{_col_width}}|',
            f'------|{"-" * _col_width}|{"-" * _col_width}|{"-" * _col_width}|',
        ]
        for candidate in all_candidates:
            candidate_metrics = candidate["metrics"]
            table_lines.append(
                f'{candidate["id"]:^{6}}|'
                f"{candidate_metrics.loss.item() * 1e3:^{_col_width}.4f}|"
                f"{candidate_metrics.rate_latent_bpp:^{_col_width}.4f}|"
                f"{candidate_metrics.psnr_db:^{_col_width}.4f}|"
            )
        print("\n".join(table_lines) + "\n")

    # Keep only the best model.
    frame_encoder = copy.deepcopy(all_candidates[0]["encoder"])

    # We've already worked for that many seconds during warm-up.
    warmup_duration = time.time() - start_time

    print("Intra Warm-up is done!")
    print(f"Intra Warm-up time [s]: {warmup_duration:.2f}")
    print(f'Intra Winner ID       : {all_candidates[0]["id"]}\n')

    return frame_encoder