# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md

from dataclasses import dataclass, field, fields
from typing import Literal, Optional, Tuple

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, nn

POSSIBLE_WARP_MODE = Literal[
    "torch_nearest", "torch_bilinear", "torch_bicubic", "bilinear", "bicubic", "sinc"
]


@dataclass
class WarpParameter:
"""Dataclass storing the parameters of the motion compensation warping"""
filter_size: int
# While we provide an explicit implementation for bilinear and
# bicubic filtering, the actual PyTorch grid_sample implementation
# is faster. Set this flag to False to use our explicit implementation
# nevertheless.
use_torch_if_available: bool = True
# The actual mode is derived from the desired filter_size
mode: Optional[POSSIBLE_WARP_MODE] = field(init=False, default=None)
# At inference time, the flow is constrained to have only
# <fractional_accuracy> possible values.
fractional_accuracy: int = field(init=False, default=64)
def __post_init__(self):
assert self.filter_size % 2 == 0, (
f"Warp filter size should be even. Found filter_size={self.filter_size}."
)
assert self.filter_size >= 2, (
f"Warp filter size should be >= 2. Found filter_size={self.filter_size}."
)
match self.filter_size:
case 2:
self.mode = (
"torch_bilinear" if self.use_torch_if_available else "bilinear"
)
case 4:
self.mode = (
"torch_bicubic" if self.use_torch_if_available else "bicubic"
)
case _:
self.mode = "sinc"

    def pretty_string(self) -> str:
        """Return a pretty string presenting the WarpParameter.

        Returns:
            str: Pretty string ready to be printed.
        """
        ATTRIBUTE_WIDTH = 25
        VALUE_WIDTH = 80

        s = ""
        for k in fields(self):
            s += f"{k.name:<{ATTRIBUTE_WIDTH}}: {str(getattr(self, k.name)):<{VALUE_WIDTH}}\n"
        s += "\n"
        return s
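
# A minimal sketch of how filter_size selects the interpolation mode
# (illustrative, following __post_init__ above):
#
#   WarpParameter(filter_size=2).mode                                # "torch_bilinear"
#   WarpParameter(filter_size=4, use_torch_if_available=False).mode  # "bicubic"
#   WarpParameter(filter_size=8).mode                                # "sinc"

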
class Warper(nn.Module):
    def __init__(self, param: WarpParameter, img_size: Tuple[int, int]):
        """Instantiate a warper module, parameterized by `param`.

        Args:
            param: Warping parameters (filter length, fractional accuracy).
            img_size: [Height, Width].
        """
        super().__init__()
        self.param = param
        self.filter_size = param.filter_size

        # Leverage grid_sample for these modes
        self.native_pytorch_warp = self.param.mode in [
            "torch_nearest",
            "torch_bilinear",
            "torch_bicubic",
        ]

        if self.native_pytorch_warp:
            B = 1
            H, W = img_size
            tensor_hor = (
                torch.linspace(-1.0, 1.0, W, dtype=torch.float32)
                .view(1, 1, 1, W)
                .expand(B, -1, H, -1)
            )
            tensor_ver = (
                torch.linspace(-1.0, 1.0, H, dtype=torch.float32)
                .view(1, 1, H, 1)
                .expand(B, -1, -1, W)
            )
            self.register_buffer(
                "backward_grid",
                torch.cat([tensor_hor, tensor_ver], 1),
                persistent=False,
            )

        # Custom implementation of different interpolation filters, including
        # some already offered by PyTorch e.g. bilinear and bicubic
        else:
            # We always interpolate a point x in [0., 1.[ based on filter_size
            # neighbors. With filter_size = 4, we have e.g.
            # left_neighbor = -1 and right_neighbor = 2:
            #
            #   -1   0   x   1   2  ==> x is a weighted sum of these 4 neighbors
            left_top_neighbor = -int(self.filter_size // 2) + 1
            right_bot_neighbor = int(self.filter_size // 2)

            grids = []
            # + 1 because the right bound is included
            for i in range(left_top_neighbor, right_bot_neighbor + 1):
                for j in range(left_top_neighbor, right_bot_neighbor + 1):
                    grid = self.coords_grid(*img_size)  # H, W
                    grid[:, 0] += j
                    grid[:, 1] += i
                    grids.append(grid)

            # register_buffer for automatic device management. We set persistent
            # to False to simply use the "automatically move to device" behavior,
            # without considering grids as parameters (i.e. returned by
            # self.parameters()).
            #
            # self.grids dimension is [filter_size ** 2, 2, H, W]. For each pixel
            # in the HxW frame, it stores the x and y indices of each of the
            # filter_size ** 2 neighboring values. With filter_size = 4,
            # self.grids is something like
            #   self.grids[:, :, i, j] = [
            #       [j - 1, i - 1],
            #       [j    , i - 1],
            #       [j + 1, i - 1],
            #       [j + 2, i - 1],
            #       [j - 1, i    ],
            #       [j    , i    ],
            #       [j + 1, i    ],
            #       [j + 2, i    ],
            #       ...
            #       [j + 2, i + 2],
            #   ]
            self.register_buffer(
                "grids",
                torch.cat(grids, dim=0),
                persistent=False,
            )

            if self.param.mode == "sinc":
                # Corresponds to \kappa_i in eq. 6 in "Efficient Sub-pixel Motion
                # Compensation in Learned Video Codecs".
                self.register_buffer(
                    "relative_neighbor_idx",
                    # + 1 so that the right bound is included
                    torch.arange(left_top_neighbor, right_bot_neighbor + 1).view(
                        1, -1, 1, 1
                    ),
                    persistent=False,
                )
            elif self.param.mode == "bicubic":
                # ! Exactly like PyTorch bicubic grid_sample
                a = -0.75
                bicubic_init = torch.tensor(
                    [
                        [0, a, -2 * a, a],
                        [1, 0, -(a + 3), a + 2],
                        [0, -a, (2 * a + 3), -(a + 2)],
                        [0, 0, a, -a],
                    ]
                )
                self.register_buffer("B", bicubic_init, persistent=False)
            elif self.param.mode == "bilinear":
                # ! Exactly like PyTorch bilinear grid_sample
                bilinear_init = torch.tensor(
                    [
                        [1.0, -1.0],
                        [0.0, 1.0],
                    ]
                )
                self.register_buffer("B", bilinear_init, persistent=False)

    def coords_grid(self, h: int, w: int) -> Tensor:
        """Return a [1, 2, H, W] tensor, where the 1st feature gives the
        column index and the 2nd the row index. For instance:

        .. code::

            coords_grid(3, 5) =
                tensor([[[[0., 1., 2., 3., 4.],
                          [0., 1., 2., 3., 4.],
                          [0., 1., 2., 3., 4.]],
                         [[0., 0., 0., 0., 0.],
                          [1., 1., 1., 1., 1.],
                          [2., 2., 2., 2., 2.]]]])

        Args:
            h: Height of the grid.
            w: Width of the grid.

        Returns:
            Tensor: Tensor giving the column and row indices.
        """
        # [H, W]
        y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
        # [1, 2, H, W]
        return torch.stack([x, y], dim=0).unsqueeze(0).float()

    def get_coeffs(self, s: Tensor) -> Tensor:
        """Generate interpolation coefficients from the fractional displacement
        s in [0, 1[.

        Args:
            s: Fractional displacement for each pixel, shape is [1, 1, H, W].

        Returns:
            Tensor: Corresponding interpolation coefficients for each pixel,
            shape is [1, filter_size, H, W].
        """
        if self.param.mode == "sinc":
            s = torch.repeat_interleave(s, self.filter_size, dim=1)
            # Corresponds to eq. 6 in "Efficient Sub-pixel Motion
            # Compensation in Learned Video Codecs".
            window = torch.cos(
                torch.pi * (s - self.relative_neighbor_idx) / self.filter_size
            )
            coeffs = window * torch.sinc(s - self.relative_neighbor_idx)
        elif self.param.mode in ["bilinear", "bicubic"]:
            # Both modes behave similarly: we derive the filter coeffs as
            #   coeffs = B @ t_exponents
            # where
            #   coeffs      --> the N taps of the filter
            #   t_exponents --> a [t^0, ..., t^deg] vector
            #   B           --> a [N, deg + 1] matrix generating the coefficients
            #                   for any value of t in [0., 1.]
            h, w = s.size()[-2:]
            s = rearrange(s, "1 1 h w -> (h w) 1")
            max_deg = self.B.size()[1]
            # s_exponents shape is [HW, max_deg] i.e. each row is
            # [s^0, ..., s^(max_deg - 1)]
            s_exponents = torch.cat(
                [torch.pow(s, exponent=expo) for expo in range(max_deg)], dim=1
            )
            coeffs = F.linear(s_exponents, self.B, bias=None)
            coeffs = rearrange(
                coeffs,
                "(h w) n_coef -> 1 n_coef h w",
                n_coef=self.filter_size,
                h=h,
                w=w,
            )
        return coeffs
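
    # Worked example (illustrative): in "bilinear" mode, B = [[1., -1.], [0., 1.]],
    # so a fractional displacement s yields coeffs = [1 - s, s]. With s = 0.25:
    #
    #   s_exponents = [1, 0.25]                    # [s^0, s^1]
    #   coeffs = s_exponents @ B.T = [0.75, 0.25]  # left / right neighbor weights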

    def interpolate_1d(self, neighbors: Tensor, fractional_flow: Tensor) -> Tensor:
        """Perform the interpolation of neighboring integer values to get the
        value located at fractional_flow.

        Args:
            neighbors: [B, filter_size, C, H, W]. We compute B x C x H x W
                interpolations in parallel. Each of them has filter_size neighbors.
            fractional_flow: [1, 1, H, W]. Fractional flow for each warping.
                All channels share the same fractional_flow, hence C=1 here.
                All batches share the same fractional_flow, hence B=1 here.
                This is used to shift 2d blocks of pixels with B=filter_size.

        Returns:
            Tensor: [B, C, H, W] the interpolated neighbors.
        """
        coeffs = self.get_coeffs(fractional_flow)
        # Add a one-dimensional channel index in between filter_size and h w;
        # this will be broadcast to all the channels in neighbors.
        coeffs = rearrange(coeffs, "b filter_size h w -> b filter_size 1 h w")
        res = torch.sum(neighbors * coeffs, dim=1)
        return res
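
    # Note: the 2-D warping in forward() below is separable. This method is
    # first called with the horizontal fractional flow to collapse the
    # filter_size columns, then with the vertical fractional flow to collapse
    # the filter_size rows, which amounts to applying the outer product of the
    # two 1-D kernels.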

    def forward(self, x: Tensor, flow: Tensor) -> Tensor:
        """Warp a [1, C, H, W] tensor using a [1, 2, H, W] optical flow. The
        optical flow is expressed in absolute pixels i.e. a horizontal motion
        of -3 means that the output pixel at [i, j] is equal to x[i, j - 3].

            y = warp(x, flow) --> y[i, j] = x[i + vertical flow, j + horizontal flow]

        where horizontal flow is flow[:, 0, :, :] (column-wise) and vertical
        flow is flow[:, 1, :, :] (row-wise).

        As such, the value in flow describes the motion from y (the warping
        result) to x (the reference).

        Args:
            x: Tensor to be warped, shape is [1, C, H, W].
            flow: Motion to warp x, shape is [1, 2, H, W].

        Returns:
            Tensor: Warped tensor [1, C, H, W].
        """
        _, C, H, W = x.size()

        if self.training:
            q_flow = flow
        else:
            q_flow = (
                torch.round(flow * self.param.fractional_accuracy)
                / self.param.fractional_accuracy
            )
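
        # Illustrative note: with the default fractional_accuracy of 64, flows
        # are rounded to the nearest 1/64th of a pixel at inference, e.g.
        # 0.3 -> round(0.3 * 64) / 64 = 19 / 64 = 0.296875.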

        if self.native_pytorch_warp:
            q_flow = torch.cat(
                [
                    q_flow[:, 0:1, :, :] / ((W - 1.0) / 2.0),
                    q_flow[:, 1:2, :, :] / ((H - 1.0) / 2.0),
                ],
                dim=1,
            )
            grid = self.backward_grid + q_flow
            output = nn.functional.grid_sample(
                x,
                grid.permute(0, 2, 3, 1),
                mode=self.param.mode.replace("torch_", ""),
                padding_mode="border",
                align_corners=True,
            )
            return output
        else:
            # We first apply the integer part of the flow using simple
            # re-indexing i.e. grid_sample with mode="nearest". Then we
            # interpolate to get the fractional flow value.
            # No need to backward through that!
            rounded_flow = torch.floor(q_flow)
            fractional_flow = q_flow - rounded_flow

            # self.grids is [filter_size ** 2, 2, H, W], rounded_flow is
            # [1, 2, H, W]: shift the position of all neighbors by the integer
            # displacement.
            neighbors = self.grids + rounded_flow

            # Transform absolute pixel values to relative values in [-1, 1]
            normalized_neighbors = torch.cat(
                [
                    # fmt: off
                    2 * torch.clamp(neighbors[:, 0:1, :, :], min=0, max=W - 1) / (W - 1)
                    - 1,
                    2 * torch.clamp(neighbors[:, 1:2, :, :], min=0, max=H - 1) / (H - 1)
                    - 1,
                    # fmt: on
                ],
                dim=1,
            )

            # warped_x_rounded_flow shape is [filter_size ** 2, C, H, W].
            # Each CxHxW has filter_size ** 2 neighboring values.
            warped_x_rounded_flow = torch.nn.functional.grid_sample(
                x.expand(self.filter_size**2, C, H, W),
                normalized_neighbors.permute(0, 2, 3, 1),
                align_corners=True,
                mode="nearest",
                padding_mode="border",
            )

            # Split the filter_size ** 2 neighboring values into filter_size
            # rows and filter_size columns.
            stacked_lines = torch.stack(
                torch.split(warped_x_rounded_flow, self.filter_size, dim=0), dim=0
            )

            # First interpolate each of the filter_size rows along the column axis.
            # stacked_lines is [filter_size, filter_size, C, H, W]
            # interpolated_lines is [filter_size, C, H, W]
            interpolated_lines = self.interpolate_1d(
                stacked_lines, fractional_flow[:, 0:1]
            )

            # interpolated_lines.unsqueeze(0) is [1, filter_size, C, H, W]:
            # the interpolated rows now play the role of the filter taps.
            interpolated_column = self.interpolate_1d(
                interpolated_lines.unsqueeze(0), fractional_flow[:, 1:2]
            )
            # interpolated_column shape is [1, C, H, W]
            return interpolated_column
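
# A minimal usage sketch of Warper (illustrative shapes and values):
#
#   warper = Warper(WarpParameter(filter_size=4), (64, 64))
#   x = torch.rand(1, 3, 64, 64)
#   flow = torch.zeros(1, 2, 64, 64)
#   flow[:, 0] = 0.5  # half-pixel horizontal motion: y[i, j] ~ x[i, j + 0.5]
#   y = warper(x, flow)  # [1, 3, 64, 64]

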
def vanilla_warp_fn(x: Tensor, flow: Tensor, mode: str = "bicubic") -> Tensor:
"""Motion compensation (warping) of a tensor [B, C, H, W] with a 2-d displacement
[B, 2, H, W]. This function does not allows for longer filters than 4 taps
(bicubic) and does not quantize the flows to a given subpixel accuracy in
eval mode.
Some code in this function is inspired from
https://github.com/microsoft/DCVC/blob/main/DCVC-FM/src/models/block_mc.py
License: MIT
Args:
x: Tensor to be motion compensated [B, C, H, W].
flow: Displacement [B, C, H, W]. flow[:, 0, :, :] corresponds to
the horizontal displacement. flow[:, 1, :, :] is the vertical displacement.
Returns:
Tensor: Motion compensated tensor [B, C, H, W].
"""
B, _, H, W = x.size()
cur_device = x.device
tensor_hor = (
torch.linspace(-1.0, 1.0, W, device=cur_device, dtype=torch.float32)
.view(1, 1, 1, W)
.expand(B, -1, H, -1)
)
tensor_ver = (
torch.linspace(-1.0, 1.0, H, device=cur_device, dtype=torch.float32)
.view(1, 1, H, 1)
.expand(B, -1, -1, W)
)
backward_grid = torch.cat([tensor_hor, tensor_ver], 1)
flow = torch.cat(
[
flow[:, 0:1, :, :] / ((W - 1.0) / 2.0),
flow[:, 1:2, :, :] / ((H - 1.0) / 2.0),
],
dim=1,
)
grid = backward_grid + flow
output = nn.functional.grid_sample(
x,
grid.permute(0, 2, 3, 1),
mode=mode,
padding_mode="border",
align_corners=True,
)
return output
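
# Sanity-check sketch for vanilla_warp_fn (illustrative): a constant horizontal
# flow of +1 reproduces x shifted by one column, up to border padding:
#
#   x = torch.arange(5.0).view(1, 1, 1, 5).expand(1, 1, 3, 5)
#   flow = torch.zeros(1, 2, 3, 5)
#   flow[:, 0] = 1.0
#   y = vanilla_warp_fn(x, flow, mode="bilinear")
#   # y[0, 0, 0] ~ [1., 2., 3., 4., 4.]

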
if __name__ == "__main__":
    # Check that our custom warping behaves like PyTorch
    h, w = 480, 732
    dummy_img = torch.rand((1, 3, h, w))
    dummy_flow = torch.randn((1, 2, h, w)) * 30

    print("Checking that Cool-chic warping behaves similarly to PyTorch grid_sample.")
    print("PSNR should be above 60 dB\n")

    s = f"{'Warping mode':<20}{'PSNR PyTorch / Cool-chic [dB]':<30}\n"
    for filter_size in [2, 4]:
        warper = Warper(
            WarpParameter(
                filter_size=filter_size,
                # Do not use torch, that's what we want to compare
                use_torch_if_available=False,
            ),
            [h, w],
        )
        warp_coolchic = warper.forward(dummy_img, dummy_flow)

        mode = warper.param.mode
        warp_torch = vanilla_warp_fn(
            dummy_img, dummy_flow, mode=mode.replace("torch_", "")
        )

        mse = (warp_torch - warp_coolchic).square().mean()
        psnr = -10 * torch.log10(mse)
        str_psnr = f"{psnr:7.4f}"
        s += f"{mode:<20}{str_psnr:<30}\n"
    print(s)

    print("timing...")
    import time

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    dummy_img = dummy_img.to(device)
    dummy_flow = dummy_flow.to(device)

    print(
        f"{'Warping mode':<20}"
        f"{'Time torch [s]':<30}"
        f"{'Time Cool-chic warper [s]':<30}"
        f"{'Ratio Cool-chic / torch':<30}"
    )
    for filter_size in [2, 4, 8, 12]:
        N = 200
        time_torch = 0
        time_coolchic = 0

        cool_chic_warper = Warper(
            WarpParameter(
                filter_size=filter_size,
                # Do not use torch, that's what we want to compare
                use_torch_if_available=False,
            ),
            [h, w],
        )
        cool_chic_warper.to(device)
        cool_chic_warper.eval()
        cool_chic_warper = torch.compile(
            cool_chic_warper,
            dynamic=False,
            mode="reduce-overhead",
            fullgraph=True,
        )

        for idx in range(N):
            start_time = time.time()
            mode = cool_chic_warper.param.mode
            if filter_size in [2, 4]:
                warp_torch = vanilla_warp_fn(
                    dummy_img, dummy_flow, mode=mode.replace("torch_", "")
                )
                if device == "cuda:0":
                    torch.cuda.synchronize()
                # First N // 2 iterations are warm-up for more accurate timings
                if idx > N // 2:
                    time_torch += time.time() - start_time
            else:
                time_torch = "N/A"

            start_time = time.time()
            warp_coolchic = cool_chic_warper(dummy_img, dummy_flow)
            if device == "cuda:0":
                torch.cuda.synchronize()
            # First N // 2 iterations are warm-up for more accurate timings
            if idx > N // 2:
                time_coolchic += time.time() - start_time

        if time_torch != "N/A":
            ratio = f"{time_coolchic / time_torch:.3f}"
            time_torch = f"{time_torch:.4f}"
        else:
            ratio = "N/A"
        time_coolchic = f"{time_coolchic:.4f}"

        print(
            f"{mode + str(cool_chic_warper.param.filter_size):<20}"
            f"{time_torch:<30}"
            f"{time_coolchic:<30}"
            f"{ratio:<30}"
        )