Source code for enc.component.video

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md


import copy
import os
import subprocess
import time
from typing import Dict, List, Tuple

import torch
from enc.utils.manager import FrameEncoderManager
from enc.component.coolchic import CoolChicEncoderParameter
from enc.component.frame import FrameEncoder, load_frame_encoder
from enc.training.quantizemodel import quantize_model
from enc.training.test import test
from enc.training.train import train
from enc.training.warmup import warmup
from enc.utils.codingstructure import CodingStructure, Frame, FrameData
from enc.utils.misc import POSSIBLE_DEVICE, TrainingExitCode, is_job_over, mem_info
from enc.io.io import load_frame_data_from_file
import torch._dynamo.exc

[docs] class VideoEncoder():
def __init__(
    self,
    coding_structure: CodingStructure,
    shared_coolchic_parameter: CoolChicEncoderParameter,
    shared_frame_encoder_manager: FrameEncoderManager,
):
    """Build a ``VideoEncoder``, the top-level object used to encode a
    video *i.e.* one I-frame followed by 0 to N inter (P or B) frames.

    Args:
        coding_structure: Organization of the different frames of the
            video to encode.
        shared_coolchic_parameter: Cool-chic parameters common to all
            frames (*e.g.* synthesis architecture). They serve as a
            starting point and can be adapted per frame when encoding.
        shared_frame_encoder_manager: Training parameters common to all
            frames (*e.g.* max. number of iterations). They serve as a
            starting point and can be adapted per frame when encoding.
    """
    # Nothing is trained at construction time: this dictionary starts empty
    # and is populated as frames get encoded. Each key is the coding index
    # of a frame (as a string); each value pairs the FrameEncoder of that
    # frame with its FrameEncoderManager.
    self.all_frame_encoders: Dict[
        str, Tuple[FrameEncoder, FrameEncoderManager]
    ] = {}

    # Settings shared by every frame of the video.
    self.shared_frame_encoder_manager = shared_frame_encoder_manager
    self.shared_coolchic_parameter = shared_coolchic_parameter
    self.coding_structure = coding_structure
[docs] def encode( self, path_original_sequence: str, device: POSSIBLE_DEVICE, workdir: str, job_duration_min: int = -1, print_detailed_archi: bool = False, ) -> TrainingExitCode: """Main training function of a ``VideoEncoder``. Encode all required frames (*i.e.* as stated in ``self.coding_structure``) of the video located at ``path_original_sequence``. This will fill the dictionary ``self.all_frame_encoders`` containing the successively overfitted frame encoders. There is a series of nested loops to encode the video, following roughly this process: .. code-block:: python # Code all frames for idx_coding_order in range(n_frames): # Perform n_loops independent encoding for idx_loop in range(n_loops): # Find the best initialization frame_encoder = warmup(...) # Perform the successive training stages for training_phase in all_training_phases: frame_encoder = train(frame_encoder, training_phase) # Training is over, test and save results = test(frame_encoder) frame_encoder.save() Args: path_original_sequence: Absolute path to the original image or video to be compressed. device: On which device should the training run workdir: Where we'll save many thing job_duration_min: Exit and save the job after this duration is passed. Use -1 to only exit at the end of the entire encoding. Default to -1. print_detailed_archi: True to print the detailed decoder architecture Returns: Either ``TrainingExitCode.REQUEUE`` if job has run for longer than ``job_duration_min`` and should be put back into the job queue, or ``TrainingExitCode.END`` if the encoding is actually over. 
""" start_time = time.time() n_frames = self.coding_structure.get_number_of_frames() for idx_coding_order in range(n_frames): frame = self.coding_structure.get_frame_from_coding_order(idx_coding_order) if frame.already_encoded: continue # Load the original data and its references frame.data = load_frame_data_from_file( path_original_sequence, frame.display_order ) frame.refs_data = self.get_ref_data(frame) # Everything concerning this frame will be written here frame_workdir = self.get_frame_workdir(workdir, frame.display_order) current_coolchic_parameter = copy.deepcopy(self.shared_coolchic_parameter) current_coolchic_parameter.set_image_size(frame.data.img_size) current_coolchic_parameter.encoder_gain = ( 16 if frame.frame_type == "I" else 16 ) match frame.frame_type: case "I": n_output_synthesis = 3 case "P": n_output_synthesis = 6 case "B": n_output_synthesis = 9 case _: print( f"Unknown frame_type {frame.frame_type}" ) # Change the number of channels for the synthesis output current_coolchic_parameter.layers_synthesis = [ lay.replace("X", str(n_output_synthesis)) for lay in current_coolchic_parameter.layers_synthesis ] # We have started to encode this frame so we already have a # frame_encoder_manager associated if str(idx_coding_order) in self.all_frame_encoders: _, frame_encoder_manager = ( self.all_frame_encoders.get(str(idx_coding_order)) ) # We need to create a new frame_encoder_manager else: print( "-" * 80 + "\n" + f'{" " * 12} Coding frame {frame.coding_order + 1} / {n_frames} ' + f"- Display order: {frame.display_order} - " + f"Coding order: {frame.coding_order}\n" + "-" * 80 ) print("\n" + frame.data.to_string() + "\n") # ----- Set the parameters for the frame frame_encoder_manager = copy.deepcopy( self.shared_frame_encoder_manager ) # Change the lambda according to the depth of the frame in the GOP # The deeper the frame, the bigger the lambda, the smaller the rate frame_encoder_manager.lmbda = self.get_lmbda_from_depth( frame.depth, 
self.shared_frame_encoder_manager.lmbda ) # Plug the current frame type into the current frame encoder manager frame_encoder_manager.frame_type = frame.frame_type subprocess.call(f"mkdir -p {frame_workdir}", shell=True) # Log a few details about the model print(f"\n{frame_encoder_manager.pretty_string()}") print(f"{current_coolchic_parameter.pretty_string()}") print(f"{frame_encoder_manager.preset.pretty_string()}") for index_loop in range( frame_encoder_manager.loop_counter, frame_encoder_manager.n_loops, ): print( "-" * 80 + "\n" + f'{" " * 30} Training loop {frame_encoder_manager.loop_counter + 1} / ' + f"{frame_encoder_manager.n_loops}\n" + "-" * 80 ) frame.to_device(device) # Get the number of candidates from the initial warm-up phase n_initial_warmup_candidate = ( frame_encoder_manager.preset.warmup.phases[ 0 ].candidates ) list_candidates = [] torch.set_float32_matmul_precision('high') for _ in range(n_initial_warmup_candidate): cur_frame_encoder = FrameEncoder( coolchic_encoder_param=current_coolchic_parameter, frame_type=frame.frame_type, frame_data_type=frame.data.frame_data_type, bitdepth=frame.data.bitdepth ).to(device) list_candidates.append(cur_frame_encoder) # Use the first candidate of the list to log the architecture with open(f"{frame_workdir}/archi.txt", "w") as f_out: f_out.write(str(list_candidates[0].coolchic_encoder) + "\n\n") f_out.write(list_candidates[0].coolchic_encoder.str_complexity() + "\n") print( list_candidates[0].coolchic_encoder.pretty_string( print_detailed_archi=print_detailed_archi ) + "\n\n" ) # Use warm-up to find the best initialization among the list # of candidates parameters. frame_encoder = warmup( frame_encoder_manager=frame_encoder_manager, list_candidates=list_candidates, frame=frame, device=device, ) frame_encoder.to_device(device) # Compile only after the warm-up to compile only once. 
# No compilation for torch version anterior to 2.5.0 major, minor = [int(x) for x in torch.__version__.split(".")[:2]] use_compile = False if major > 2: use_compile = True elif major == 2: use_compile = minor >= 5 if frame_encoder_manager.preset.preset_name == "debug": print("Skip compilation when debugging\n") elif not use_compile: print("No compilation for torch version anterior to 2.5.0\n") else: print("Compiling frame encoder!\n") frame_encoder = torch.compile( frame_encoder, dynamic=False, mode="reduce-overhead", # Some part of the frame_encoder forward (420-related stuff) # are not (yet) compatible with compilation. So we can't # capture the full graph for yuv420 frame fullgraph=frame.data.frame_data_type != "yuv420", ) for idx_phase, training_phase in enumerate(frame_encoder_manager.preset.all_phases): print(f'{"-" * 30} Training phase: {idx_phase:>2} {"-" * 30}\n') mem_info("Training phase " + str(idx_phase)) frame_encoder = train( frame_encoder=frame_encoder, frame=frame, frame_encoder_manager=frame_encoder_manager, start_lr=training_phase.lr, cosine_scheduling_lr=training_phase.schedule_lr, max_iterations=training_phase.max_itr, frequency_validation=training_phase.freq_valid, patience=training_phase.patience, optimized_module=training_phase.optimized_module, quantizer_type=training_phase.quantizer_type, quantizer_noise_type=training_phase.quantizer_noise_type, softround_temperature=training_phase.softround_temperature, noise_parameter=training_phase.noise_parameter, ) if training_phase.quantize_model: # Store full precision parameters inside the # frame_encoder for later use if needed frame_encoder.coolchic_encoder._store_full_precision_param() frame_encoder = quantize_model( frame_encoder, frame, frame_encoder_manager, ) phase_results = test( frame_encoder, frame, frame_encoder_manager, ) print("\nResults at the end of the phase:") print("--------------------------------") print( f'\n{phase_results.pretty_string(show_col_name=True, mode="short")}\n' ) # 
At the end of each loop, compute the final loss loop_results = test( frame_encoder, frame, frame_encoder_manager, ) # Write results file path_results_log = f"{frame_workdir}results_loop_{frame_encoder_manager.loop_counter + 1}.tsv" with open(path_results_log, "w") as f_out: f_out.write( loop_results.pretty_string(show_col_name=True, mode="all") + "\n" ) # We've beaten our record if frame_encoder_manager.record_beaten(loop_results.loss): print(f'Best loss beaten at loop {frame_encoder_manager.loop_counter + 1}') print(f'Previous best loss: {frame_encoder_manager.best_loss * 1e3 :.6f}') print(f'New best loss : {loop_results.loss.cpu().item() * 1e3 :.6f}') frame_encoder_manager.set_best_loss(loop_results.loss.cpu().item()) # Save best results with open(f'{frame_workdir}results_best.tsv', 'w') as f_out: f_out.write(loop_results.pretty_string(show_col_name=True, mode='all') + '\n') self.concat_results_file(workdir) best_frame_encoder = frame_encoder # We haven't beaten our record, keep the old frame encoder as # the current best frame encoder else: best_frame_encoder = self.all_frame_encoders[str(frame.coding_order)][0] frame_encoder_manager.loop_counter += 1 # Store the current best FrameEncoder and the corresponding # frame_encoder_manager self.all_frame_encoders[str(frame.coding_order)] = ( copy.deepcopy(best_frame_encoder), copy.deepcopy(frame_encoder_manager) ) print('End of training loop\n\n') self.save(f'{workdir}video_encoder.pt') # The save function unload the decoded frames and the original # ones. We need to reload them frame.data = load_frame_data_from_file( path_original_sequence, frame.display_order ) frame.refs_data = self.get_ref_data(frame) if is_job_over(start_time=start_time, max_duration_job_min=job_duration_min): return TrainingExitCode.REQUEUE self.coding_structure.set_encoded_flag( coding_order=frame.coding_order, flag_value=True ) print(self.coding_structure.pretty_string()) self.save(f'{workdir}video_encoder.pt') return TrainingExitCode.END
def get_frame_workdir(self, workdir: str, frame_display_order: int) -> str:
    """Build the path of the directory dedicated to a single frame.

    The directory is named ``frame_XXX`` where ``XXX`` is the display
    order of the frame, left-padded with zeros to at least 3 characters.

    Args:
        workdir: Main working directory of the video encoder
        frame_display_order: Display order of the frame

    Returns:
        Working directory of the frame, ending with a ``/``.
    """
    padded_display_order = str(frame_display_order).zfill(3)
    frame_dirname = f"frame_{padded_display_order}"
    return f"{workdir}/{frame_dirname}/"
def concat_results_file(self, workdir: str) -> None:
    """Look at all the already encoded frames inside ``workdir`` and
    concatenate their result files
    (``workdir/frame_XXX/results_best.tsv``) into a single result file
    ``workdir/results_best.tsv``.

    The first result file found is copied entirely (column names
    included). For each following file, only the second line is kept
    (no need to repeat the column names); if such a file has a single
    line, that line is kept.

    Any previous ``workdir/results_best.tsv`` is replaced. If no frame
    has a result file yet, the previous file is removed and none is
    written.

    Args:
        workdir: Working directory of the video encoder
    """
    results_name = "results_best.tsv"

    # Frames are visited in display order; those without results are skipped
    list_results_file = []
    for idx_display_order in range(self.coding_structure.get_number_of_frames()):
        cur_res_file = os.path.join(
            self.get_frame_workdir(workdir, idx_display_order), results_name
        )
        if os.path.isfile(cur_res_file):
            list_results_file.append(cur_res_file)

    # Works whether or not workdir ends with a path separator
    out_path = os.path.join(workdir, results_name)

    if not list_results_file:
        # Nothing to concatenate: only get rid of a stale output file
        try:
            os.remove(out_path)
        except FileNotFoundError:
            pass
        return

    # Plain file I/O instead of shell commands: paths with spaces or shell
    # metacharacters are handled, and no external tool is required.
    merged_lines = []
    for idx, frame_path in enumerate(list_results_file):
        with open(frame_path, "r") as f_in:
            cur_lines = f_in.readlines()

        if idx == 0:
            merged_lines.extend(cur_lines)
        else:
            # Last of the first two lines i.e. the second line when present
            merged_lines.extend(cur_lines[:2][-1:])

    with open(out_path, "w") as f_out:
        f_out.writelines(merged_lines)
@torch.no_grad()
def get_ref_data(self, frame: Frame) -> List[FrameData]:
    """Return a list of the (decoded) reference frames. The decoded data
    are obtained by recursively inferring the already learned
    ``FrameEncoder``. Once decoded, a reference frame keeps its decoded
    data so that it is not inferred again on the next call.

    Args:
        frame: The frame whose reference(s) we want.

    Returns:
        The decoded reference frames, in the order given by
        ``frame.index_references``.

    Raises:
        KeyError: If a reference frame has to be decoded but no
            ``FrameEncoder`` is stored for it in
            ``self.all_frame_encoders``.
    """
    # We obtain the reference frames by re-inferring the already encoded frames.
    ref_data = []

    # idx_ref is in display order
    for idx_ref in frame.index_references:
        ref_frame = self.coding_structure.get_frame_from_display_order(idx_ref)

        # Re-infer the reference only if it has not been decoded already
        if ref_frame.decoded_data is None:
            # The reference may itself depend on other references
            ref_frame.refs_data = self.get_ref_data(ref_frame)

            print(
                f"get_ref_data(): Decoding frame {ref_frame.display_order:<3}..."
            )

            # Load the best encoder for the reference frame
            # No need to load the corresponding frame_encoder_manager
            # hence the "_"
            encoder_key = str(ref_frame.coding_order)
            try:
                frame_encoder, _ = self.all_frame_encoders[encoder_key]
            except KeyError:
                # Explicit error instead of the opaque "cannot unpack
                # non-iterable NoneType" obtained with dict.get()
                raise KeyError(
                    f"Cannot decode reference frame (display order "
                    f"{ref_frame.display_order}): no FrameEncoder stored "
                    f"for coding order {encoder_key}"
                ) from None

            # Infer it to get the data of the references
            frame_encoder.set_to_eval()
            frame_encoder.to_device("cpu")
            ref_frame.upsample_reference_to_444()

            frame_encoder_out = frame_encoder.forward(
                reference_frames=[ref_i.data for ref_i in ref_frame.refs_data],
                quantizer_noise_type="none",
                quantizer_type="hardround",
                AC_MAX_VAL=-1,
                flag_additional_outputs=False,
            )

            ref_frame.set_decoded_data(
                FrameData(
                    frame_encoder.bitdepth,
                    frame_encoder.frame_data_type,
                    frame_encoder_out.decoded_image,
                )
            )

        ref_data.append(ref_frame.decoded_data)

    return ref_data
def get_lmbda_from_depth(self, depth: float, initial_lmbda: float) -> float:
    """Scale the rate constraint according to the depth of a frame
    (QP offset):

    .. math::

        \\lambda_i = (\\frac{3}{2})^{d} \\lambda,

    Args:
        depth: The depth :math:`d` of the frame in the GOP.
        initial_lmbda: The lmbda of the I frame :math:`\\lambda`.

    Returns:
        The lambda of the current frame :math:`\\lambda_i`.
    """
    # Each additional level of depth multiplies lambda by 3/2
    depth_scaling = 1.5**depth
    return initial_lmbda * depth_scaling
def save(self, save_path: str) -> None:
    """Save current VideoEncoder at given path. It contains everything,
    the ``CodingStructure``, the shared parameters between the different
    frames as well as all the successive ``FrameEncoder`` and their
    respective ``FrameEncoderManager``.

    The parent directory of ``save_path`` is created if needed.

    Note that saving unloads the original frames, the references data
    and the decoded data from ``self.coding_structure``: they have to be
    reloaded afterward by whoever still needs them.

    Args:
        save_path: Where to save the model
    """
    # os.makedirs instead of a shell "mkdir -p": no issue with spaces or
    # shell metacharacters in the path. A bare file name has an empty
    # dirname, in which case there is nothing to create.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # We don't need to save the original frames nor the coded ones.
    # The original frames can be reloaded from the dataset. The coded ones
    # can be retrieved by inferring the trained FrameEncoders.
    self.coding_structure.unload_all_original_frames()
    self.coding_structure.unload_all_references_data()
    self.coding_structure.unload_all_decoded_data()

    # Each FrameEncoder is stored as whatever its own save() returns,
    # next to its FrameEncoderManager.
    all_frame_encoders = {
        key: (frame_encoder.save(), frame_encoder_manager)
        for key, (frame_encoder, frame_encoder_manager)
        in self.all_frame_encoders.items()
    }

    data_to_save = {
        "coding_structure": self.coding_structure,
        "shared_coolchic_parameter": self.shared_coolchic_parameter,
        "shared_frame_encoder_manager": self.shared_frame_encoder_manager,
        "all_frame_encoders": all_frame_encoders,
    }

    torch.save(data_to_save, save_path)
def load_video_encoder(load_path: str) -> VideoEncoder:
    """Rebuild a ``VideoEncoder`` from a file written by
    ``VideoEncoder.save``.

    Args:
        load_path: Absolute path where the VideoEncoder should be loaded.

    Returns:
        The loaded VideoEncoder
    """
    print(f"Loading a video encoder from {load_path}")

    # NOTE: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, so only load files coming from a trusted source.
    checkpoint = torch.load(load_path, map_location="cpu", weights_only=False)

    # The constructor only stores the shared objects; the frame encoders
    # are restored separately below.
    video_encoder = VideoEncoder(
        coding_structure=checkpoint["coding_structure"],
        shared_coolchic_parameter=checkpoint["shared_coolchic_parameter"],
        shared_frame_encoder_manager=checkpoint["shared_frame_encoder_manager"],
    )

    # Load all the frame encoders to reconstruct the reference frames when needed
    # TODO: Load only the required frame encoder
    for coding_key, (serialized_encoder, manager) in checkpoint[
        "all_frame_encoders"
    ].items():
        restored_encoder = load_frame_encoder(serialized_encoder)
        video_encoder.all_frame_encoders[coding_key] = (restored_encoder, manager)

    return video_encoder