# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md
"""Utilities to define the coding structures."""
from dataclasses import dataclass, field
from typing import List, Literal, Optional
from enc.io.framedata import FrameData
# The different frame types:
# - I frames have no reference (intra)
# - P frames have 1 single (past) reference
# - B frames have 2 references (one in the past, one in the future).
FRAME_TYPE = Literal["I", "P", "B"]
@dataclass
class Frame:
    """Dataclass representing a frame to be encoded. It contains useful info
    like the display & coding indices, the indices of its references as well
    as the data of the decoded references and the original (*i.e.*
    uncompressed) frame.

    Args:
        coding_order (int): Frame with ``coding_order=0`` is coded first.
        display_order (int): Frame with ``display_order=0`` is displayed first.
        frame_offset (int): Shift the position of the 0-th frame of the video.
            If frame_offset=15 skip the first 15 frames of the video. That is
            the display index 0 corresponds to the 16th frame.
            This is only used to load the data + for logging purposes.
            Defaults to 0.
        depth (int): Depth of the frame in the GOP. Roughly corresponds to the
            notion of temporal layers in conventional codecs. Intra frames
            have a depth of 0. Frames built by ``CodingStructure`` have a
            depth equal to 1 + the biggest depth of their references (e.g. 1
            for a P-frame referencing an intra frame). Defaults to 0.
        seq_name (str): Name of the video. Mainly used for logging purposes.
            Defaults to ``""``.
        data (Optional[FrameData]): Data of the uncompressed image to be coded.
            Defaults to ``None``.
        index_references (List[int]): Index of the frame(s) used as references,
            in **display_order**. At most 2 references are allowed. Leave empty
            when no reference is available *i.e.* for I-frame. A sorted copy
            is stored so that the reference further in the past comes first;
            the list given by the caller is not modified. Defaults to ``[]``.
        refs_data (List[FrameData]): The actual data describing the decoded
            references. Leave empty when no reference is available *i.e.* for
            I-frame. Defaults to ``[]``.
    """

    coding_order: int
    display_order: int
    frame_offset: int = 0
    depth: int = 0
    seq_name: str = ""
    data: Optional[FrameData] = None
    index_references: List[int] = field(default_factory=list)
    # Filled up by the set_refs_data() function.
    refs_data: List[FrameData] = field(default_factory=list)

    # ==================== Not set by the init function ===================== #
    #: Automatically set from the number of entries in ``self.index_references``.
    frame_type: FRAME_TYPE = field(init=False)
    # ==================== Not set by the init function ===================== #

    def __post_init__(self):
        assert len(self.index_references) <= 2, (
            "A frame can not have more than 2 references.\n"
            f"Found {len(self.index_references)} references for frame {self.display_order} "
            f"(display order).\n Exiting!"
        )

        # Store a sorted copy: the reference further in the past is always
        # first, and the list owned by the caller is left untouched.
        self.index_references = sorted(self.index_references)

        # The frame type only depends on the number of references.
        if len(self.index_references) == 2:
            self.frame_type = "B"
        elif len(self.index_references) == 1:
            self.frame_type = "P"
        else:
            self.frame_type = "I"

    def set_frame_data(self, data: FrameData) -> None:
        """Set the data representing the frame i.e. store the ``FrameData``
        object describing the actual (uncompressed) frame.

        Args:
            data: FrameData object representing the frame.
        """
        self.data = data

    def set_refs_data(self, refs_data: List[FrameData]) -> None:
        """Set the data representing the reference(s).

        Args:
            refs_data: Data of the reference(s). Must contain exactly one
                entry per element of ``self.index_references``.
        """
        assert len(refs_data) == len(self.index_references), (
            f"Trying to load data for "
            f"{len(refs_data)} references but current frame only has {len(self.index_references)} "
            f"references. Frame type is {self.frame_type}."
        )
        # NOTE: the caller's list is stored as-is (no copy). There might be a
        # memory management issue here (deep copy vs. shallow copy).
        self.refs_data = refs_data

    def pretty_string(self, show_header: bool = False, show_bottom_line: bool = False) -> str:
        """Return a string describing the frame, formatted as one row of an
        array with 6 columns.

        Args:
            show_header: Also print the column names above the row.
                Defaults to False.
            show_bottom_line: Print a line below the frame description to
                close the array. Defaults to False.

        Returns:
            str: Pretty string describing the frame
        """
        COL_WIDTH = 18
        cell_width = COL_WIDTH - 2

        def format_row(cells: list) -> str:
            # Every cell is centered, cells are separated by a "|".
            return "|" + "|".join(f"{cell:^{cell_width}}" for cell in cells) + "|\n"

        # Horizontal separator: +-----...-----+-----...-----+ ... +
        separator = f"+{'-' * cell_width}" * 6 + "+"

        s = ""
        if show_header:
            s += separator + "\n"
            s += format_row(
                [
                    "Frame type",
                    "Coding order",
                    "Display order",
                    "Ref 1 (disp)",
                    "Ref 2 (disp)",
                    "Depth",
                ]
            )
            s += separator + "\n"

        # "/" is displayed for missing references.
        ref_1 = str(self.index_references[0]) if len(self.index_references) > 0 else "/"
        ref_2 = str(self.index_references[1]) if len(self.index_references) > 1 else "/"
        s += format_row(
            [
                self.frame_type,
                self.coding_order,
                self.display_order,
                ref_1,
                ref_2,
                self.depth,
            ]
        )

        if show_bottom_line:
            s += separator + "\n\n"
        return s
@dataclass
class CodingStructure:
    r"""Dataclass representing the organization of the video *i.e.* which
    frames are coded using which references.

    A few examples:

    .. code-block::

        # A low-delay P configuration
        # I0
        #   \---> P1
        #           \---> P2
        #                   \---> P3
        #                           \---> P4
        --n_frames=5 --intra_pos=0 --p_pos=1-4

        # A hierarchical Random Access configuration, with a closed GOP
        # I0
        #   \-----------------------------------------------------------> P8
        #   \---------------------------> B4 <---------------------------/
        #   \-----------> B2 <-----------/  \-----------> B6 <-----------/
        #   \---> B1 <---/  \---> B3 <---/  \---> B5 <---/  \---> B7 <---/
        --n_frames=9 --intra_pos=0 --p_pos=-1

        # A hierarchical Random Access configuration, with an open GOP
        # I0                                                              I8
        #   \---------------------------> B4 <---------------------------/
        #   \-----------> B2 <-----------/  \-----------> B6 <-----------/
        #   \---> B1 <---/  \---> B3 <---/  \---> B5 <---/  \---> B7 <---/
        --n_frames=9 --intra_pos=0,-1

        # Or some very peculiar structures...
        # I0
        #   \-------------------------------------------> P6
        #   \-------------------> B3 <-------------------/  \-----------> P8
        #   \---> B1 <-----------/  \---> B4 <-----------/  \---> B7 <---/
        #           \---> B2 <---/          \---> B5 <---/
        --n_frames=9 --intra_pos=0 --p_pos=6,8

    A coding structure is composed of a few hyper-parameters and most
    importantly a list of ``Frame`` describing the different frames to code.

    Args:
        n_frames (int): Number of frames in the coding structure.
            Defaults to 1.
        frame_offset (int): Shift the position of the 0-th frame of the video.
            If frame_offset=15 skip the first 15 frames of the video. That is
            the display index 0 corresponds to the 16th frame. Defaults to 0.
        intra_pos (List[int]): Position of all the intra frames in display
            order. Negative values count from the end of the video (-1 is the
            last frame). Duplicates are removed. Defaults to ``[0]``.
        p_pos (List[int]): Position of all the P frames in display order.
            Negative values count from the end of the video (-1 is the last
            frame). Duplicates are removed. Defaults to ``[]``.
        seq_name (str): Name of the video. Mainly used for logging purposes.
            Defaults to ``""``.
    """

    seq_name: str = ""
    n_frames: int = 1
    frame_offset: int = 0
    # Intra and P positions are given in **display** order
    # Always start with an intra
    intra_pos: List[int] = field(default_factory=lambda: [0])
    p_pos: List[int] = field(default_factory=list)

    # ==================== Not set by the init function ===================== #
    #: All the frames to code, deduced from ``intra_pos`` and ``p_pos``.
    #: Frames are indexed in display order (i.e. temporal order). frames[0]
    #: is the 1st frame, while frames[-1] is the last one.
    frames: List[Frame] = field(init=False)
    # ==================== Not set by the init function ===================== #

    def __post_init__(self):
        # Resolve negative positions, sort and remove duplicates. New lists
        # are created so that the lists given by the caller are not modified.
        self.intra_pos = self._sanitize_positions(self.intra_pos, "--intra_pos")
        self.p_pos = self._sanitize_positions(self.p_pos, "--p_pos")

        first_frame_is_intra = bool(self.intra_pos) and self.intra_pos[0] == 0
        assert first_frame_is_intra, (
            "First frame of the video should be an intra frame. Change --intra_pos "
            "to include the frame 0."
        )

        last_frame = self.n_frames - 1
        last_frame_is_intra = last_frame in self.intra_pos
        last_frame_is_p = last_frame in self.p_pos
        assert last_frame_is_intra or last_frame_is_p, (
            "Last frame of the video should be either an intra frame or a P "
            "frame. Add -1 to --intra_pos or --p_pos to include the last frame."
        )

        common_elements = sorted(set(self.intra_pos).intersection(self.p_pos))
        assert not common_elements, (
            "Frames can not be an I-frame and a P-frame at the same time!\n"
            f"Found --intra_pos={self.intra_pos} --p_pos={self.p_pos}.\n"
            f"Frame(s) {common_elements} are in both arguments, they should "
            "be present only in one of them."
        )

        self.frames = self.compute_coding_struct(
            self.n_frames, self.intra_pos, self.p_pos
        )

    def _sanitize_positions(self, positions: List[int], arg_name: str) -> List[int]:
        """Return a new sorted list of display positions without duplicates,
        where negative positions are counted from the end of the video
        (-1 is the last frame).

        Args:
            positions: Positions given by the user (e.g. ``--intra_pos``).
            arg_name: Name of the argument, only used for the messages.

        Returns:
            Sorted list of unique positions in ``[0, n_frames - 1]``.
        """
        out_of_range = [x for x in positions if not -self.n_frames <= x < self.n_frames]
        assert not out_of_range, (
            f"Found out of range value(s) in {arg_name}: {out_of_range}. "
            f"Positions must be in [0, {self.n_frames - 1}] or negative to "
            "count from the end of the video (-1 is the last frame)."
        )

        resolved = sorted(x + self.n_frames if x < 0 else x for x in positions)
        if len(resolved) != len(set(resolved)):
            print(
                f"Found duplicate elements in {arg_name}: {resolved}.\n"
                "They are automatically removed."
            )
        return sorted(set(resolved))

    @staticmethod
    def _get_closest_past_ref(idx_display_order: int, frames: List[Frame]) -> Optional[Frame]:
        """Return the frame with the biggest display order that is still
        smaller than ``idx_display_order`` i.e. the closest past reference.
        Return ``None`` if there is no such frame. ``frames`` is not modified.
        """
        # For instance, a P-frame displayed anywhere between 4 and 7 uses P3
        # as its reference in: I0 P3 * I8
        past_frames = [f for f in frames if f.display_order < idx_display_order]
        return max(past_frames, key=lambda f: f.display_order, default=None)

    @staticmethod
    def _get_closest_future_ref(idx_display_order: int, frames: List[Frame]) -> Optional[Frame]:
        """Return the frame with the smallest display order that is still
        bigger than ``idx_display_order`` i.e. the closest future reference.
        Return ``None`` if there is no such frame. ``frames`` is not modified.
        """
        # For instance, a frame displayed anywhere between 4 and 7 uses I8
        # as its future reference in: I0 P3 * I8
        future_frames = [f for f in frames if f.display_order > idx_display_order]
        return min(future_frames, key=lambda f: f.display_order, default=None)

    def compute_coding_struct(
        self, n_frames: int, intra_pos: List[int], p_pos: List[int]
    ) -> List[Frame]:
        """Construct a coding structure of n_frames. The algorithm works as
        follows.

        Step 1:
        -------

        Position all the intra frames following ``intra_pos``.

        Step 2:
        -------

        Position all the P frames following ``p_pos``. A P-frame uses the
        closest frame in the past as a reference.

        Step 3:
        -------

        Automatically fill the remaining frames with hierarchical B-frames.
        This is achieved by iterating on the list of frames and inserting
        B-frames in between already added frames each time there is a gap.
        For instance:

            frames = [I0, P4]
            ==> [I0, B2, P4]            # Fill the middle frame
            ==> [I0, B1, B2, P4]        # Fill the middle frame
            ==> [I0, B1, B2, B3, P4]    # Fill the middle frame

        Args:
            n_frames (int): Number of frames in the coding structure
            intra_pos (List[int]): Position of all the intra frames in display
                order. Duplicates are ignored.
            p_pos (List[int]): Position of all the P frames in display order.
                Duplicates are ignored.

        Returns:
            List[Frame]: List of all the frames within the coding structure,
            sorted by display order.
        """
        frames: List[Frame] = []

        # ----- Step 1: fill all the intra frames
        for idx_display_order in sorted(set(intra_pos)):
            frames.append(
                Frame(
                    coding_order=len(frames),
                    display_order=idx_display_order,
                    index_references=[],
                    depth=0,  # All intra depth is 0
                    seq_name=self.seq_name,  # Not very elegant... but useful!
                    frame_offset=self.frame_offset,
                )
            )

        # ----- Step 2: fill all the P frames, each one referencing the closest
        # frame in the past (intra or previously added P).
        for idx_display_order in sorted(set(p_pos)):
            past_ref = self._get_closest_past_ref(idx_display_order, frames)
            assert past_ref is not None, (
                f"P-frame {idx_display_order} (display order) has no past frame "
                "to use as reference. The first frame must be an intra frame."
            )
            frames.append(
                Frame(
                    coding_order=len(frames),
                    display_order=idx_display_order,
                    index_references=[past_ref.display_order],
                    depth=past_ref.depth + 1,
                    seq_name=self.seq_name,
                    frame_offset=self.frame_offset,
                )
            )

        # ----- Step 3: Fill out the blanks with B-frames in a hierarchical manner
        # Each iteration finds the first display index not yet coded and
        # creates a B-frame halfway between the two references surrounding it.
        while True:
            coded_display_orders = {frame.display_order for frame in frames}
            first_gap = next(
                (i for i in range(n_frames) if i not in coded_display_orders), None
            )
            if first_gap is None:
                break

            past_ref = self._get_closest_past_ref(first_gap, frames)
            future_ref = self._get_closest_future_ref(first_gap, frames)
            assert past_ref is not None and future_ref is not None, (
                f"Can not find both a past and a future reference for frame "
                f"{first_gap} (display order). The first frame must be an intra "
                "frame and the last frame must be an intra or a P frame."
            )

            # The display order of the frame being created is equal to the
            # past reference + half the distance between its 2 references
            ref_distance = future_ref.display_order - past_ref.display_order
            idx_display_order = past_ref.display_order + ref_distance // 2
            frames.append(
                Frame(
                    coding_order=len(frames),
                    display_order=idx_display_order,
                    index_references=[
                        past_ref.display_order,
                        future_ref.display_order,
                    ],
                    depth=max(past_ref.depth, future_ref.depth) + 1,
                    seq_name=self.seq_name,
                    frame_offset=self.frame_offset,
                )
            )

        frames.sort(key=lambda x: x.display_order)
        return frames

    def pretty_structure_diagram(self) -> str:
        r"""Return a nice diagram presenting the coding structure, with one line
        per depth. For instance, with ``n_frames=5``, ``intra_pos=[0]`` and
        ``p_pos=[4]``, it returns:

        .. code::

            I0
              \---------------------------------------------> P4
              \--------------------> B2 <--------------------/
              \-------> B1 <--------/  \--------> B3 <-------/

        Returns:
            str: A string describing the coding structure. Ready to be printed.
        """
        # Handle edge case where there is a single frame to be coded
        if self.n_frames == 1:
            return "I0"

        # Horizontal position (in characters) of each frame, in display order.
        _LENGTH_PRINT = 10 * self.n_frames
        all_x_pos = [
            round(x / (self.n_frames - 1) * _LENGTH_PRINT) for x in range(self.n_frames)
        ]

        lines = []
        for depth in range(self.get_max_depth() + 1):
            current_x_pos = 0
            s = ""
            for frame in self.get_all_frames_of_depth(depth):
                frame_str = f"{frame.frame_type}{frame.display_order}"
                x_frame = all_x_pos[frame.display_order]

                # No ref, whitespace left and right
                if frame.frame_type == "I":
                    spacing = x_frame - current_x_pos
                    s += f"{' ' * spacing}{frame_str}"
                    current_x_pos += spacing + len(frame_str)
                    continue

                # P and B frames: \----> on the left, from the past reference.
                past_ref = frame.index_references[0]
                # Length of the past reference label: one character for the
                # frame type + its digits e.g. 2 for "B4", 3 for "P12".
                len_left_ref_str = 1 + len(str(past_ref))
                spacing = all_x_pos[past_ref] + len_left_ref_str - current_x_pos
                left_arrow_len = x_frame - all_x_pos[past_ref] - len_left_ref_str
                s += f"{' ' * spacing}"
                s += f"\\{'-' * (left_arrow_len - 3)}> "
                s += f"{frame_str}"
                current_x_pos += spacing + left_arrow_len + len(frame_str)

                # B frames only: <----/ on the right, to the future reference.
                if frame.frame_type == "B":
                    future_ref = frame.index_references[1]
                    right_arrow_len = all_x_pos[future_ref] - x_frame - len(frame_str)
                    s += f" <{'-' * (right_arrow_len - 3)}/"
                    current_x_pos += right_arrow_len

            lines.append(s)

        return "\n".join(lines)

    def pretty_string(self, print_detailed_struct: bool = False) -> str:
        """Return a pretty string formatting the data within the class

        Args:
            print_detailed_struct: True to also print one row per frame (in
                coding order) followed by the structure diagram.

        Returns:
            str: a pretty string ready to be printed out
        """
        s = "Coding configuration:\n"
        s += "---------------------\n"
        s += f"{'n_frames':<26}: {self.n_frames}\n"
        s += f"{'frame_offset':<26}: {self.frame_offset}\n"
        s += f"{'seq_name':<26}: {self.seq_name}\n"
        s += f"{'intra_pos':<26}: {', '.join(str(x) for x in self.intra_pos)}\n"
        s += f"{'p_pos':<26}: {', '.join(str(x) for x in self.p_pos)}\n\n"

        if not print_detailed_struct:
            return s

        # Print one row per frame, in coding order
        n_frames = len(self.frames)
        for idx_coding_order in range(n_frames):
            cur_frame = self.get_frame_from_coding_order(idx_coding_order)
            s += cur_frame.pretty_string(
                show_header=idx_coding_order == 0,
                show_bottom_line=idx_coding_order == n_frames - 1,
            )

        s += self.pretty_structure_diagram()
        return s

    def get_number_of_frames(self) -> int:
        """Return the number of frames in the coding structure.

        Returns:
            Number of frames in the coding structure.
        """
        return len(self.frames)

    def get_max_depth(self) -> int:
        """Return the maximum depth of a coding configuration

        Returns:
            Maximum depth of the coding configuration
        """
        return max(frame.depth for frame in self.frames)

    def get_all_frames_of_depth(self, depth: int) -> List[Frame]:
        """Return a list with all the frames for a given depth, in display
        order.

        Args:
            depth: Depth for which we want the frames.

        Returns:
            List of frames with the given depth
        """
        return [frame for frame in self.frames if frame.depth == depth]

    def get_max_coding_order(self) -> int:
        """Return the maximum coding order of a coding configuration

        Returns:
            Maximum coding order of the coding configuration
        """
        return max(frame.coding_order for frame in self.frames)

    def get_frame_from_coding_order(self, coding_order: int) -> Optional[Frame]:
        """Return the frame whose coding order is equal to ``coding_order``.
        Return ``None`` if no frame has been found.

        Args:
            coding_order: Coding order for which we want the frame.

        Returns:
            Frame whose coding order is equal to ``coding_order``.
        """
        return next(
            (frame for frame in self.frames if frame.coding_order == coding_order),
            None,
        )

    def get_max_display_order(self) -> int:
        """Return the maximum display order of a coding configuration

        Returns:
            Maximum display order of the coding configuration
        """
        return max(frame.display_order for frame in self.frames)

    def get_frame_from_display_order(self, display_order: int) -> Optional[Frame]:
        """Return the frame whose display order is equal to ``display_order``.
        Return None if no frame has been found.

        Args:
            display_order: Display order for which we want the frame.

        Returns:
            Frame whose display order is equal to ``display_order``.
        """
        return next(
            (frame for frame in self.frames if frame.display_order == display_order),
            None,
        )

    def get_all_frames_using_one_ref(self, display_order_ref: int) -> List[Frame]:
        """Return a list of frames using the frame <display_order_ref> as
        a reference.

        Args:
            display_order_ref: Display order of the frame that is used as reference

        Returns:
            List[Frame]: List of frames using one given frame as a reference.
        """
        return [
            frame for frame in self.frames if display_order_ref in frame.index_references
        ]