Loss

loss_function(
decoded_image: Tensor | DictTensorYUV,
rate_latent_bit: Tensor,
target_image: Tensor | DictTensorYUV,
lmbda: float = 0.001,
rate_mlp_bit: float = 0.0,
compute_logs: bool = False,
) → LossFunctionOutput

Compute the loss and a few other quantities. The loss equation is:

\[\begin{split}\mathcal{L} = ||\mathbf{x} - \hat{\mathbf{x}}||^2 + \lambda (\mathrm{R}(\hat{\mathbf{x}}) + \mathrm{R}_{NN}), \text{ with } \begin{cases} \mathbf{x} & \text{the original image}\\ \hat{\mathbf{x}} & \text{the coded image}\\ \mathrm{R}(\hat{\mathbf{x}}) & \text{A measure of the rate of } \hat{\mathbf{x}} \\ \mathrm{R}_{NN} & \text{The rate of the neural networks} \end{cases}\end{split}\]
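As a rough illustration of the equation above, here is a minimal, dependency-free sketch using plain Python lists instead of tensors. The name `rd_loss_sketch` is hypothetical. Two choices are assumptions, not taken from the documented implementation: the distortion is averaged over pixels (mean squared error), and both rates are divided by the pixel count to give bits per pixel, as the `_bpp` field names of `LossFunctionOutput` seem to suggest. The sketch also treats the image as single-channel, with one value per pixel.

```python
from typing import Sequence


def rd_loss_sketch(
    target: Sequence[float],
    decoded: Sequence[float],
    rate_latent_bit: Sequence[float],
    lmbda: float = 1e-3,
    rate_mlp_bit: float = 0.0,
) -> float:
    """Toy rate-distortion cost: distortion + lmbda * (latent rate + network rate)."""
    n_pixels = len(target)
    # Distortion: squared error averaged over pixels (assumed normalisation).
    mse = sum((x - x_hat) ** 2 for x, x_hat in zip(target, decoded)) / n_pixels
    # Rates: total bits divided by the pixel count, i.e. bits per pixel (assumed).
    rate_latent_bpp = sum(rate_latent_bit) / n_pixels
    rate_nn_bpp = rate_mlp_bit / n_pixels
    return mse + lmbda * (rate_latent_bpp + rate_nn_bpp)


loss = rd_loss_sketch(
    target=[0.0, 0.5, 1.0, 0.5],
    decoded=[0.0, 0.5, 0.5, 0.5],
    rate_latent_bit=[2.0, 1.0, 1.0],
    lmbda=0.001,
    rate_mlp_bit=4.0,
)
print(loss)
```

In this toy case the distortion is 0.0625 and each rate term is 1 bit per pixel, so the cost is about 0.0625 + 0.001 × 2 = 0.0645.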

Warning

There is no back-propagation through the term \(\mathrm{R}_{NN}\). It is included only so that the rate-distortion cost accounts for the rate of the neural networks and therefore better reflects the actual compression performance.
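One way to read this warning: since `rate_mlp_bit` is passed as a plain float, it acts as a constant offset \(\lambda \mathrm{R}_{NN}\) on the loss value and leaves the gradient with respect to the decoded image unchanged. The sketch below checks this numerically on a single-pixel toy loss with a central finite difference. The functions `toy_loss` and `finite_difference` are hypothetical helpers written for this illustration and are not part of the library.

```python
def toy_loss(decoded_pixel: float, rate_mlp_bit: float, lmbda: float = 1e-3) -> float:
    """Single-pixel loss: squared error against a target of 1.0, plus lmbda * R_NN."""
    target_pixel = 1.0
    return (target_pixel - decoded_pixel) ** 2 + lmbda * rate_mlp_bit


def finite_difference(rate_mlp_bit: float, at: float = 0.25, eps: float = 1e-6) -> float:
    """Central-difference estimate of d(loss) / d(decoded_pixel)."""
    up = toy_loss(at + eps, rate_mlp_bit)
    down = toy_loss(at - eps, rate_mlp_bit)
    return (up - down) / (2 * eps)


# The network rate changes the value of the loss...
loss_offset = toy_loss(0.25, 5000.0) - toy_loss(0.25, 0.0)

# ...but not its slope with respect to the decoded pixel.
grad_without_nn_rate = finite_difference(rate_mlp_bit=0.0)
grad_with_nn_rate = finite_difference(rate_mlp_bit=5000.0)
```

Both gradient estimates are close to the analytic value -2 × (1 - 0.25) = -1.5, while the loss value itself shifts by 0.001 × 5000 = 5.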

Parameters:
  • decoded_image (Tensor | DictTensorYUV) – The decoded image, either as a Tensor for RGB or YUV444 data, or as a dictionary of Tensors for YUV420 data.

  • rate_latent_bit (Tensor) – Tensor with the rate of each latent value, in bits.

  • target_image (Tensor | DictTensorYUV) – The target image, either as a Tensor for RGB or YUV444 data, or as a dictionary of Tensors for YUV420 data.

  • lmbda (float) – Rate constraint \(\lambda\), weighting the rate terms against the distortion. Defaults to 1e-3.

  • rate_mlp_bit (float) – Sum of the rates allocated to the different neural networks, in bits. Defaults to 0.0.

  • compute_logs (bool) – True to output a few more quantities besides the loss. Defaults to False.

Returns:

Object gathering the quantities computed by this loss function, chief among them the loss itself.

Return type:

LossFunctionOutput

class LossFunctionOutput

Output for FrameEncoder.loss_function

__init__(
*,
loss: float | None = None,
mse: float | None = None,
rate_nn_bpp: float | None = None,
rate_latent_bpp: float | None = None,
) → None
Parameters:
  • loss (float | None)

  • mse (float | None)

  • rate_nn_bpp (float | None)

  • rate_latent_bpp (float | None)

Return type:

None