Source code for enc.component.intercoding.warp

# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md

# Some code here is inspired from
# https://github.com/microsoft/DCVC/blob/main/DCVC-FM/src/models/block_mc.py
# License: MIT

import torch
from torch import Tensor, nn


# backward_grid = None

[docs] def warp_fn(x: Tensor, flow: Tensor) -> Tensor: """Motion compensation (warping) of a tensor [B, C, H, W] with a 2-d displacement [B, 2, H, W]. Args: x (Tensor): Tensor to be motion compensated [B, C, H, W]. flow (Tensor): Displacement [B, C, H, W]. flow[:, 0, :, :] corresponds to the horizontal displacement. flow[:, 1, :, :] is the vertical displacement. Returns: Tensor: Motion compensated tensor [B, C, H, W]. """ B, _, H, W = x.size() cur_device = x.device # TODO: Could be better managed by avoiding the reallocation each time # TODO: we call the warp function # global backward_grid # if backward_grid is None: tensor_hor = ( torch.linspace(-1.0, 1.0, W, device=cur_device, dtype=torch.float32) .view(1, 1, 1, W) .expand(B, -1, H, -1) ) tensor_ver = ( torch.linspace(-1.0, 1.0, H, device=cur_device, dtype=torch.float32) .view(1, 1, H, 1) .expand(B, -1, -1, W) ) backward_grid = torch.cat([tensor_hor, tensor_ver], 1) flow = torch.cat( [ flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0), ], dim=1, ) grid = backward_grid + flow output = nn.functional.grid_sample( x, grid.permute(0, 2, 3, 1), mode="bilinear", padding_mode="border", align_corners=True, ) return output