# Source code for enc.utils.codingstructure

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md

"""Utilities to define the coding structures."""

import math
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from enc.utils.misc import POSSIBLE_DEVICE

# The different frame types:
#       - I frames have no reference (intra)
#       - P frames have 1 single (past) reference
#       - B frames have 2 (past & future) references.
FRAME_TYPE = Literal["I", "P", "B"]

#   A GOP is defined as something starting with an intra frame and followed
# by an arbitrary number of inter (P or B) frames. As such the number of frames
# in the GOP is the number of inter frames + 1, i.e.
#
#   number_of_frames_in_gop = intra_period + 1 = number_of_inter_frames_in_gop + 1
#
#
#   E.g.:
#       I0 ---> P1 ---> P2 ---> P3 ---> P4 ---> P5 ---> P6 ---> P7 ---> P8
#
# Or a hierarchical random access GOP with nested B-frames (RA).
#   E.g.:
#          I0 -----------------------> P4 ------------------------> P8
#           \----------> B2 <---------/ \----------> B6 <----------/
#            \--> B1 <--/ \--> B3 <--/   \--> B5 <--/  \--> B7 <--/
#
# Here, both GOPs have an intra period of 8 (i.e. 8 inter-frames) in
# between two I-frames. First GOP P-period is 1, while second P period is 4
# which is the distance of the P-frame prediction.
#
# Supported pixel layouts for the data of one frame.
FRAME_DATA_TYPE = Literal["rgb", "yuv420", "yuv444"]
# Supported bitdepths for the data of one frame.
POSSIBLE_BITDEPTH = Literal[8, 10]


class DictTensorYUV(TypedDict):
    """``TypedDict`` representing a YUV420 frame.

    .. hint::

        ``torch.jit`` requires I/O of modules to be either ``Tensor``,
        ``List`` or ``Dict``. So we don't use a python dataclass here and
        rely on ``TypedDict`` instead.

    Args:
        y (Tensor): :math:`([B, 1, H, W])`.
        u (Tensor): :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`.
        v (Tensor): :math:`([B, 1, \\frac{H}{2}, \\frac{W}{2}])`.
    """

    y: Tensor
    u: Tensor
    v: Tensor
def yuv_dict_to_device(yuv: DictTensorYUV, device: POSSIBLE_DEVICE) -> DictTensorYUV:
    """Send a ``DictTensorYUV`` to a device.

    Args:
        yuv: Data to be sent to a device.
        device: The requested device

    Returns:
        Data on the appropriate device.
    """
    # Move each plane individually, then rebuild the TypedDict.
    on_device = {plane: yuv.get(plane).to(device) for plane in ("y", "u", "v")}
    return DictTensorYUV(**on_device)
# ========================= YUV up & down sampling ========================== #
def convert_444_to_420(yuv444: Tensor) -> DictTensorYUV:
    """From a 4D YUV 444 tensor :math:`(B, 3, H, W)`, return a ``DictTensorYUV``.
    The U and V tensors are down sampled using a nearest neighbor downsampling.

    Args:
        yuv444: YUV444 data :math:`(B, 3, H, W)`

    Returns:
        YUV420 dictionary of 4D tensors
    """
    # Bug fix: the check is for a 4D tensor but the old message said "5".
    assert yuv444.dim() == 4, f"Number of dimension should be 4, found {yuv444.dim()}"
    b, c, h, w = yuv444.size()
    assert c == 3, f"Number of channel should be 3, found {c}"

    # No need to downsample the Y channel but it should remain a 4D tensor
    y = yuv444[:, 0, :, :].view(b, 1, h, w)

    # Downsample U and V channels together
    uv = F.interpolate(yuv444[:, 1:3, :, :], scale_factor=(0.5, 0.5), mode="nearest")
    u, v = uv.split(1, dim=1)

    yuv420 = DictTensorYUV(y=y, u=u, v=v)
    return yuv420
def convert_420_to_444(yuv420: DictTensorYUV) -> Tensor:
    """Convert a DictTensorYUV to a 4D tensor :math:`(B, 3, H, W)`.
    The U and V tensors are up sampled using a nearest neighbor upsampling.

    Args:
        yuv420: YUV420 dictionary of 4D tensor

    Returns:
        YUV444 Tensor :math:`(B, 3, H, W)`
    """
    # F.interpolate defaults to nearest-neighbour interpolation.
    chroma = [F.interpolate(yuv420.get(k), scale_factor=(2, 2)) for k in ("u", "v")]
    return torch.cat((yuv420.get("y"), *chroma), dim=1)
# ========================= YUV up & down sampling ========================== #
@dataclass
class FrameData:
    """FrameData is a dataclass storing the actual pixel values of a frame
    and a few additional information about its size, bitdepth of color space.

    Args:
        bitdepth (POSSIBLE_BITDEPTH): Bitdepth, either ``"8"`` or ``"10"``.
        frame_data_type (FRAME_DATA_TYPE): Data type, either ``"rgb"``,
            ``"yuv420"``, ``"yuv444"``.
        data (Union[Tensor, DictTensorYUV]): The actual RGB or YUV data
    """

    bitdepth: POSSIBLE_BITDEPTH
    frame_data_type: FRAME_DATA_TYPE
    data: Union[Tensor, DictTensorYUV]

    # ==================== Not set by the init function ===================== #
    #: Height & width of the video :math:`(H, W)`
    img_size: Tuple[int, int] = field(init=False)
    #: Number of pixels :math:`H \times W`
    n_pixels: int = field(init=False)
    # ==================== Not set by the init function ===================== #

    def __post_init__(self):
        # YUV420 keeps its planes in a dict; every other layout is one tensor.
        if self.frame_data_type in ("rgb", "yuv444"):
            self.img_size = self.data.size()[-2:]
        elif self.frame_data_type == "yuv420":
            # Spatial size is that of the (full-resolution) luma plane.
            self.img_size = self.data.get("y").size()[-2:]
        self.n_pixels = self.img_size[0] * self.img_size[1]

    def to_device(self, device: POSSIBLE_DEVICE) -> None:
        """Push the data attribute to the relevant device **in place**.

        Args:
            device: The device on which the model should run.
        """
        if self.frame_data_type == "yuv420":
            self.data = yuv_dict_to_device(self.data, device)
        elif self.frame_data_type in ("rgb", "yuv444"):
            self.data = self.data.to(device)
@dataclass
class Frame:
    """Dataclass representing a frame to be encoded.

    It gathers bookkeeping information (coding & display indices, depth in
    the GOP, reference indices) together with the data of the original
    frame, its decoded version and its decoded references.

    Args:
        coding_order (int): Frame with ``coding_order=0`` is coded first.
        display_order (int): Frame with ``display_order=0`` is displayed first.
        depth (int): Depth of the frame in the GOP. 0 for Intra, 1 for
            P-frame, 2 or more for B-frames. Roughly corresponds to the
            notion of temporal layers in conventional codecs. Defaults to 0.
        seq_name (str): Name of the video. Mainly used for logging purposes.
            Defaults to ``""``.
        data (Optional[FrameData]): Data of the uncompressed image to be
            coded. Defaults to ``None``.
        decoded_data (Optional[FrameData]): Data of the decoded frame, set
            through ``set_decoded_data()``. Defaults to ``None``.
        already_encoded (bool): ``True`` if the frame has already been coded
            by the VideoEncoder. Defaults to ``False``.
        index_references (List[int]): Index of the frame(s) used as
            references, in **display order**. Leave empty when no reference
            is available *i.e.* for I-frames. Defaults to ``[]``.
        refs_data (List[FrameData]): The actual data describing the decoded
            references. Leave empty when no reference is available *i.e.*
            for I-frames. Defaults to ``[]``.
    """

    coding_order: int
    display_order: int
    depth: int = 0
    seq_name: str = ""
    data: Optional[FrameData] = None
    decoded_data: Optional[FrameData] = None
    already_encoded: bool = False
    index_references: List[int] = field(default_factory=list)
    # Filled up by the set_refs_data() function.
    refs_data: List[FrameData] = field(default_factory=list)

    # ==================== Not set by the init function ===================== #
    #: Automatically set from the number of entry in ``self.index_references``.
    frame_type: FRAME_TYPE = field(init=False)
    # ==================== Not set by the init function ===================== #

    def __post_init__(self):
        assert len(self.index_references) <= 2, (
            "A frame can not have more than 2 references.\n"
            f"Found {len(self.index_references)} references for frame {self.display_order} "
            f"(display order).\n Exiting!"
        )
        # 0 reference -> I-frame, 1 -> P-frame, 2 -> B-frame.
        self.frame_type = ("I", "P", "B")[len(self.index_references)]

    def set_frame_data(
        self,
        data: Union[Tensor, DictTensorYUV],
        frame_data_type: FRAME_DATA_TYPE,
        bitdepth: POSSIBLE_BITDEPTH,
    ) -> None:
        """Set the data representing the frame i.e. create the ``FrameData``
        object describing the actual frame.

        Args:
            data: RGB or YUV value of the frame.
            frame_data_type: Data type.
            bitdepth: Bitdepth.
        """
        self.data = FrameData(
            bitdepth=bitdepth, frame_data_type=frame_data_type, data=data
        )

    def set_decoded_data(self, decoded_data: FrameData) -> None:
        """Set the data representing the decoded frame.

        Args:
            decoded_data: Data of the decoded frame.
        """
        # ! There might be a memory management issue here (deep copy vs. shallow copy)
        self.decoded_data = decoded_data

    def set_refs_data(self, refs_data: List[FrameData]) -> None:
        """Set the data representing the reference(s).

        Args:
            refs_data: Data of the reference(s)
        """
        assert len(refs_data) == len(self.index_references), (
            f"Trying to load data for "
            f"{len(refs_data)} references but current frame only has {len(self.index_references)} "
            f"references. Frame type is {self.frame_type}."
        )
        # ! There might be a memory management issue here (deep copy vs. shallow copy)
        self.refs_data = refs_data

    def upsample_reference_to_444(self) -> None:
        """Upsample the references from 420 to 444 **in place**. Do nothing
        if this is already the case.
        """
        for ref in self.refs_data:
            if ref.frame_data_type != "yuv420":
                continue
            ref.data = convert_420_to_444(ref.data)
            ref.frame_data_type = "yuv444"
        # Rebind to a fresh list holding the (possibly upsampled) references.
        self.refs_data = list(self.refs_data)

    def to_device(self, device: POSSIBLE_DEVICE) -> None:
        """Push the data attribute to the relevant device **in place**.

        Args:
            device: The device on which the model should run.
        """
        if self.data is not None:
            self.data.to_device(device)
        for ref in self.refs_data:
            if ref is not None:
                ref.to_device(device)
@dataclass
class CodingStructure:
    """Dataclass representing the organization of the video *i.e.* which
    frames are coded using which references. A few examples:

    .. code-block::

        # A low-delay P configuration
        # I0 ---> P1 ---> P2 ---> P3 ---> P4 ---> P5 ---> P6 ---> P7 ---> P8
        intra_period=8
        p_period=1

        # A hierarchical Random Access configuration
        # I0 -----------------------------------------------------> P8
        #  \-------------------------> B4 <-------------------------/
        #   \----------> B2 <---------/ \----------> B6 <----------/
        #    \--> B1 <--/ \--> B3 <--/   \--> B5 <--/ \--> B7 <--/
        intra_period=8
        p_period=8

        # There is no more prediction from I0 to P8. Instead the GOP is split in
        # half so that there is no inter frame with reference further than --p_period
        # I0 -----------------------> P4 ------------------------> P8
        #  \----------> B2 <---------/ \----------> B6 <----------/
        #   \--> B1 <--/ \--> B3 <--/   \--> B5 <--/ \--> B7 <--/
        intra_period=8
        p_period=4

    A coding structure is composed of a few hyper-parameters and most
    importantly a list of ``Frame`` describing the different frames to code.

    Args:
        intra_period (int): Number of inter frames in the GOP. As such, the
            first (intra) frame of two successive GOPs would be spaced by
            ``intra_period`` inter frames. Set this to 0 for all intra coding.
        p_period (int): Distance to the furthest P prediction in the GOP.
            Set this to 1 for low-delay P or to ``intra_period`` for the
            usual random access configuration.
        seq_name (str): Name of the video. Mainly used for logging purposes.
            Defaults to ``""``.
    """

    intra_period: int
    p_period: int = 0
    seq_name: str = ""

    # ==================== Not set by the init function ===================== #
    #: All the frames to code, deduced from the GOP type, intra period and P period.
    #: Frames are indexed in display order (i.e. temporal order). frames[0] is the 1st
    #: frame, while frames[-1] is the last one.
    frames: List[Frame] = field(init=False)
    # ==================== Not set by the init function ===================== #

    def __post_init__(self):
        # Derive the full frame list from the two period hyper-parameters.
        self.frames = self.compute_gop(self.intra_period, self.p_period)

    def compute_gop(self, intra_period: int, p_period: int) -> List[Frame]:
        """Return a list of frames with one intra followed by ``intra_period``
        inter frames. The relation between the inter frames is implied by
        ``p_period``. See examples in the class description.

        Args:
            intra_period: Number of inter frames in the GOP.
            p_period: Distance between I0 and the first P frame or between
                subsequent P-frames.

        Returns:
            List describing the frames to code.
        """
        # I-frame: always coded and displayed first, no reference.
        frames = [
            Frame(
                coding_order=0,
                display_order=0,
                index_references=[],
                seq_name=self.seq_name,
            )
        ]

        if intra_period == 0 and p_period == 0:
            print("Intra period is 0 and P period is 0: all intra coding!")
            return frames

        # NOTE(review): intra_period > 0 with p_period == 0 reaches the modulo
        # below and raises ZeroDivisionError — confirm callers always provide
        # p_period >= 1 for inter coding.
        assert intra_period % p_period == 0, (
            f"Intra period must be divisible by P period."
            f" Found intra_period = {intra_period} ; p_period = {p_period}."
        )

        # In the example of RA GOP given above, the number of chained GOP is 2.
        n_chained_gop = intra_period // p_period

        for index_chained_gop in range(n_chained_gop):
            # index_frame_in_gop = p_period corresponds to the trailing P-frame.
            for index_frame_in_gop in range(1, p_period + 1):
                display_order = index_frame_in_gop + index_chained_gop * p_period
                depth_frame_in_gop = self.get_frame_depth_in_gop(index_frame_in_gop)

                # References display order are located at +/- delta_time_ref
                # from the current frame display order
                delta_time_ref = p_period // 2 ** (depth_frame_in_gop - 1)

                # First frame is an intra
                # Last frame of each chained GOP is a P-frame (single, past reference)
                if index_frame_in_gop == p_period:
                    refs = [display_order - delta_time_ref]
                # Otherwise we have a B-frame (past & future references)
                else:
                    refs = [
                        display_order - delta_time_ref,
                        display_order + delta_time_ref,
                    ]

                if depth_frame_in_gop != 0:
                    # Coding order of the first frame with this depth in
                    # the current chained gop.
                    # Until depth = 3 (included), the depth **is** the coding order since
                    # all temporal layer whose depth is 0, 1, 2, 3 have a single frame.
                    # For depth >= 4, we must take into account that each previous layer
                    # of depth d_i < 4 has had 2 ** d_i - 1 frames in it.
                    coding_order_in_gop = depth_frame_in_gop + sum(
                        [2 ** (x - 2) - 1 for x in range(3, depth_frame_in_gop)]
                    )
                    # When depth >= 4 we have multiple frames per layer, this takes it into account
                    # to obtain the proper coding order
                    coding_order_in_gop += (index_frame_in_gop - delta_time_ref) // (
                        2 * delta_time_ref
                    )
                else:
                    coding_order_in_gop = 0

                # Offset by the number of frames coded in the previous chained GOPs.
                coding_order = index_chained_gop * p_period + coding_order_in_gop

                frames.append(
                    Frame(
                        coding_order=coding_order,
                        display_order=display_order,
                        index_references=refs,
                        depth=depth_frame_in_gop,
                        seq_name=self.seq_name,
                    )
                )

        return frames

    def pretty_string(self) -> str:
        """Return a pretty string formatting the data within the class"""
        COL_WIDTH = 14

        s = "Coding configuration:\n"
        s += "---------------------\n"
        s += f'{"Frame type":<{COL_WIDTH}}\t{"Coding order":<{COL_WIDTH}}\t{"Display order":<{COL_WIDTH}}\t'
        s += f'{"Ref 1":<{COL_WIDTH}}\t{"Ref 2":<{COL_WIDTH}}\t{"Depth":<{COL_WIDTH}}\t{"Encoded"}\n'
        # Rows are printed in coding order, not display order.
        for idx_coding_order in range(len(self.frames)):
            cur_frame = self.get_frame_from_coding_order(idx_coding_order)
            s += f"{cur_frame.frame_type:<{COL_WIDTH}}\t"
            s += f"{cur_frame.coding_order:<{COL_WIDTH}}\t"
            s += f"{cur_frame.display_order:<{COL_WIDTH}}\t"
            if len(cur_frame.index_references) > 0:
                s += f"{cur_frame.index_references[0]:<{COL_WIDTH}}\t"
            else:
                s += f'{"/":<{COL_WIDTH}}\t'
            if len(cur_frame.index_references) > 1:
                s += f"{cur_frame.index_references[1]:<{COL_WIDTH}}\t"
            else:
                s += f'{"/":<{COL_WIDTH}}\t'
            s += f"{cur_frame.depth:<{COL_WIDTH}}\t"
            s += f"{cur_frame.already_encoded:<{COL_WIDTH}}\t"
            s += "\n"
        return s

    def get_number_of_frames(self) -> int:
        """Return the number of frames in the coding structure.

        Returns:
            Number of frames in the coding structure.
        """
        return len(self.frames)

    def get_max_depth(self) -> int:
        """Return the maximum depth of a coding configuration

        Returns:
            Maximum depth of the coding configuration
        """
        return max([frame.depth for frame in self.frames])

    def get_all_frames_of_depth(self, depth: int) -> List[Frame]:
        """Return a list with all the frames for a given depth

        Args:
            depth: Depth for which we want the frames.

        Returns:
            List of frames with the given depth
        """
        return [frame for frame in self.frames if frame.depth == depth]

    def get_max_coding_order(self) -> int:
        """Return the maximum coding order of a coding configuration

        Returns:
            Maximum coding order of the coding configuration
        """
        return max([frame.coding_order for frame in self.frames])

    def get_frame_from_coding_order(self, coding_order: int) -> Optional[Frame]:
        """Return the frame whose coding order is equal to ``coding_order``.
        Return ``None`` if no frame has been found.

        Args:
            coding_order: Coding order for which we want the frame.

        Returns:
            Frame whose coding order is equal to ``coding_order``.
        """
        for frame in self.frames:
            if frame.coding_order == coding_order:
                return frame
        return None

    def get_max_display_order(self) -> int:
        """Return the maximum display order of a coding configuration

        Returns:
            Maximum display order of the coding configuration
        """
        return max([frame.display_order for frame in self.frames])

    def get_frame_from_display_order(self, display_order: int) -> Optional[Frame]:
        """Return the frame whose display order is equal to ``display_order``.
        Return None if no frame has been found.

        Args:
            display_order: Display order for which we want the frame.

        Returns:
            Frame whose display order is equal to ``display_order``.
        """
        for frame in self.frames:
            if frame.display_order == display_order:
                return frame
        return None

    def set_encoded_flag(self, coding_order: int, flag_value: bool) -> None:
        """Set the flag ``self.already_encoded`` of the frame whose coding
        order is ``coding_order`` to the value ``flag_value``.

        Args:
            coding_order: Coding order of the frame for which we'll change the flag
            flag_value: Value to be set
        """
        for frame in self.frames:
            if frame.coding_order == coding_order:
                frame.already_encoded = flag_value

    def unload_all_decoded_data(self) -> None:
        """Remove the data describing the decoded data from the memory.
        This is used before saving the coding structure. The decoded data
        can be retrieved by re-inferring the trained model."""
        for idx_display_order in range(self.get_number_of_frames()):
            # if hasattr(self.frames[idx_display_order], "decoded_data"):
            #     del self.frames[idx_display_order].decoded_data
            # TODO: Set to None and rely on the garbage collector to
            # TODO: delete this?
            self.frames[idx_display_order].decoded_data = None

    def unload_all_original_frames(self) -> None:
        """Remove the data describing the original frame from the memory.
        This is used before saving the coding structure. The original frames
        can be retrieved by reloading the sequence"""
        for idx_display_order in range(self.get_number_of_frames()):
            # if hasattr(self.frames[idx_display_order], "data"):
            #     del self.frames[idx_display_order].data
            # TODO: Set to None and rely on the garbage collector to
            # TODO: delete this?
            self.frames[idx_display_order].data = None

    def unload_all_references_data(self) -> None:
        """Remove the data describing all the references from the memory.
        This is used before saving the coding structure. The reference data
        can be retrieved by re-inferring the trained model."""
        for idx_display_order in range(self.get_number_of_frames()):
            # if hasattr(self.frames[idx_display_order], "refs_data"):
            #     del self.frames[idx_display_order].refs_data
            # TODO: Set to None and rely on the garbage collector to
            # TODO: delete this?
            # NOTE(review): this sets refs_data to None although the field is
            # typed List[FrameData]; the sibling unload methods target
            # Optional fields. Confirm downstream code handles None here.
            self.frames[idx_display_order].refs_data = None

    def get_frame_depth_in_gop(self, idx_frame: int) -> int:
        """Return the depth of a frame with index <idx_frame> within a
        hierarchical GOP. The P-period is read from ``self.p_period``.

        Some notes:

            - ``idx_frame == 0`` **always** corresponds to an intra frame i.e. depth = 0
            - ``idx_frame == p_period`` is the P-frame i.e. depth = 1
            - This should be used separately for the successive chained GOPs.

        Args:
            idx_frame: Display order of the frame in the GOP.

        Returns:
            Depth of the frame in the GOP.
        """
        assert idx_frame <= self.p_period, (
            f"idx_frame should be <= to p_period."
            f" P-period is {self.p_period}, Index frame is {idx_frame}."
        )

        # math.log2 raises ValueError when p_period == 0; assumes p_period >= 1
        # whenever this is called — TODO confirm.
        assert math.log2(self.p_period) % 1 == 0, (
            f"p_period should be a power of 2." f" P-period is {self.p_period}."
        )

        if idx_frame == 0:
            return 0

        # Compute the depth: start from the deepest possible layer and move up
        # once for every power of two dividing idx_frame.
        depth = int(math.log2(self.p_period) + 1)
        for i in range(int(math.log2(self.p_period)), 0, -1):
            if idx_frame % 2**i == 0:
                depth -= 1
        return int(depth)