# Software Name: Cool-Chic
# SPDX-FileCopyrightText: Copyright (c) 2023-2024 Orange
# SPDX-License-Identifier: BSD 3-Clause "New"
#
# This software is distributed under the BSD-3-Clause license.
#
# Authors: see CONTRIBUTORS.md
import math
from typing import List, OrderedDict
import torch
import torch.nn.functional as F
from torch import Tensor, nn
class SynthesisConv2d(nn.Module):
    r"""Instantiate a synthesis layer applying the following operation to an
    input tensor :math:`\mathbf{x}` with shape :math:`[B, C_{in}, H, W]`, producing
    an output tensor :math:`\mathbf{y}` with shape :math:`[B, C_{out}, H, W]`.

    .. math::

        \mathbf{y} =
        \begin{cases}
            \mathrm{conv}(\mathbf{x}) + \mathbf{x} & \text{if residual,} \\
            \mathrm{conv}(\mathbf{x}) & \text{otherwise.} \\
        \end{cases}

    The input is padded by replicating its border values before the
    convolution. The spatial dimensions :math:`H, W` are therefore preserved
    for odd kernel sizes; an even kernel size yields an output that is one
    pixel smaller in each spatial dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        residual: bool = False,
    ):
        """
        Args:
            in_channels: Number of input channels :math:`C_{in}`.
            out_channels: Number of output channels :math:`C_{out}`.
            kernel_size: Kernel size (height and width are identical). Should
                be odd so that the output keeps the input spatial dimensions.
            residual: True to add a residual connexion to the layer.
                Default to False.

        Raises:
            ValueError: If ``residual`` is True while ``in_channels`` differs
                from ``out_channels``. The residual sum ``conv(x) + x`` would
                otherwise fail in ``forward`` or, when one side has a single
                channel, silently broadcast to the wrong number of channels.
        """
        super().__init__()

        if residual and in_channels != out_channels:
            raise ValueError(
                "A residual SynthesisConv2d layer requires in_channels == "
                f"out_channels. Found in_channels={in_channels} and "
                f"out_channels={out_channels}."
            )

        # Same amount of padding on every side: keeps H and W unchanged for
        # odd kernel sizes.
        self.pad = (kernel_size - 1) // 2
        self.residual = residual

        # -------- Instantiate empty parameters, set by the initialize function
        self.groups = 1  # Hardcoded for now
        self.weight = nn.Parameter(
            torch.empty(
                out_channels, in_channels // self.groups, kernel_size, kernel_size
            ),
            requires_grad=True,
        )
        self.bias = nn.Parameter(torch.empty(out_channels), requires_grad=True)
        self.initialize_parameters()

    def forward(self, x: Tensor) -> Tensor:
        """Perform the forward pass of this layer.

        Args:
            x: Input tensor of shape :math:`[B, C_{in}, H, W]`.

        Returns:
            Output tensor of shape :math:`[B, C_{out}, H, W]`.
        """
        padded_x = F.pad(x, (self.pad, self.pad, self.pad, self.pad), mode="replicate")
        y = F.conv2d(padded_x, self.weight, self.bias, groups=self.groups)
        if self.residual:
            return y + x
        return y

    def initialize_parameters(self) -> None:
        r"""Initialize **in place** the weights and biases of the
        ``SynthesisConv2d`` layer.

        * Biases are always set to zero.

        * Weights are set to zero if ``residual`` is ``True``. Otherwise, they
          follow a Uniform distribution: :math:`\mathbf{W} \sim
          \mathcal{U}(-a, a)`, where :math:`a =
          \frac{1}{C_{out}^2\sqrt{C_{in}k^2}}` with :math:`k` the kernel size.

        The existing ``nn.Parameter`` objects are overwritten in place rather
        than replaced. Any reference held elsewhere (e.g. an optimizer built
        before this call) therefore keeps pointing at the live parameters.
        Both parameters are (re)marked as requiring gradients.
        """
        # In-place writes on leaf tensors that require grad must not be
        # recorded by autograd.
        with torch.no_grad():
            self.bias.zero_()

            if self.residual:
                self.weight.zero_()
            else:
                # Default PyTorch initialization for convolution 2d:
                # weight ~ Uniform(-sqrt(k), sqrt(k)) with
                # k = groups / (C_in * kernel_height * kernel_width).
                # Empirically, it works better if we further divide the
                # resulting weights by out_channel ** 2.
                out_channel, in_channel_per_group, kernel_height, kernel_width = (
                    self.weight.size()
                )
                in_channel = in_channel_per_group * self.groups
                k = self.groups / (in_channel * kernel_height * kernel_width)
                sqrt_k = math.sqrt(k)
                # Same random draw and arithmetic as a fresh parameter would
                # get; only the destination (in place) differs.
                self.weight.copy_(
                    (torch.rand_like(self.weight) - 0.5) * 2 * sqrt_k / (out_channel**2)
                )

        # Preserve the historical contract: parameters are trainable after
        # (re)initialization, even if they had been frozen beforehand.
        self.weight.requires_grad_(True)
        self.bias.requires_grad_(True)
class Synthesis(nn.Module):
    r"""Instantiate Cool-chic convolution-based synthesis transform. It
    performs the following operation.

    .. math::

        \hat{\mathbf{x}} = f_{\theta}(\hat{\mathbf{z}}).

    Where :math:`\hat{\mathbf{x}}` is the :math:`[B, C_{out}, H, W]`
    synthesis output, :math:`\hat{\mathbf{z}}` is the :math:`[B, C_{in}, H,
    W]` synthesis input (i.e. the upsampled latent variable) and
    :math:`\theta` the synthesis parameters.

    The synthesis is composed of one or more convolution layers,
    instantiated using the class ``SynthesisConv2d``. The parameter
    ``layers_dim`` sets the synthesis architecture. Each layer is described
    as follows: ``<output_dim>-<kernel_size>-<type>-<non_linearity>``

    * ``output_dim``: number of output features :math:`C_{out}`.

    * ``kernel_size``: spatial dimension of the kernel. Use 1 to mimic an MLP.

    * ``type``: either ``linear`` or ``residual`` *i.e.*

      .. math::

          \mathbf{y} =
          \begin{cases}
              \mathrm{conv}(\mathbf{x}) + \mathbf{x} & \text{if residual,} \\
              \mathrm{conv}(\mathbf{x}) & \text{otherwise.} \\
          \end{cases}

    * ``non_linearity``: either ``none`` (no non-linearity) or ``relu``.
      The non-linearity is applied after the residual connexion if any.

    The number of input features of a layer is the number of output features
    of the previous layer (``input_ft`` for the first one).

    Example of a convolution layer with 40 output features, a 3x3 kernel and
    a residual connexion followed with a relu: ``40-3-residual-relu``. Being
    residual, this layer also requires 40 input features.
    """

    possible_non_linearity = {
        "none": nn.Identity,
        "relu": nn.ReLU,
        # "leakyrelu": nn.LeakyReLU,  # Unsupported by the decoder
        # "gelu": nn.GELU,  # Unsupported by the decoder
    }
    possible_mode = ["linear", "residual"]

    def __init__(self, input_ft: int, layers_dim: List[str]):
        """
        Args:
            input_ft: Number of input features :math:`C_{in}`. This corresponds
                to the number of latent features.
            layers_dim: Description of each synthesis layer as a list of strings
                following the notation detailed above.

        Raises:
            ValueError: If a layer description does not have exactly four
                ``-``-separated fields, or if its ``output_dim`` /
                ``kernel_size`` fields are not integers.
            AssertionError: If a layer ``type`` or ``non_linearity`` is unknown.
        """
        super().__init__()

        # Plain list: the modules are registered once by nn.Sequential below.
        layers_list = []

        # Construct the hidden layer(s)
        for layer_spec in layers_dim:
            fields = layer_spec.split("-")
            if len(fields) != 4:
                raise ValueError(
                    f"Invalid synthesis layer description {layer_spec!r}. "
                    "Expected <output_dim>-<kernel_size>-<type>-<non_linearity>, "
                    "e.g. 40-3-residual-relu."
                )
            out_ft, k_size, mode, non_linearity = fields
            out_ft = int(out_ft)
            k_size = int(k_size)

            # Check that mode and non linearity are correct. Explicit raises
            # (instead of `assert`) keep the check active under `python -O`
            # while preserving the AssertionError type callers may expect.
            if mode not in Synthesis.possible_mode:
                raise AssertionError(
                    f"Unknown mode. Found {mode}. Should be in {Synthesis.possible_mode}"
                )
            if non_linearity not in Synthesis.possible_non_linearity:
                raise AssertionError(
                    f"Unknown non linearity. Found {non_linearity}. "
                    f"Should be in {Synthesis.possible_non_linearity.keys()}"
                )

            # Instantiate them
            layers_list.append(
                SynthesisConv2d(input_ft, out_ft, k_size, residual=mode == "residual")
            )
            layers_list.append(Synthesis.possible_non_linearity[non_linearity]())

            # The next layer consumes this layer's output features.
            input_ft = out_ft

        self.layers = nn.Sequential(*layers_list)

    def forward(self, x: Tensor) -> Tensor:
        r"""Perform the synthesis forward pass :math:`\hat{\mathbf{x}} =
        f_{\theta}(\hat{\mathbf{z}})`, where :math:`\hat{\mathbf{x}}` is the
        :math:`[B, C_{out}, H, W]` synthesis output, :math:`\hat{\mathbf{z}}` is
        the :math:`[B, C_{in}, H, W]` synthesis input (i.e. the upsampled latent
        variable) and :math:`\theta` the synthesis parameters.

        Args:
            x: Dense latent representation :math:`[B, C_{in}, H, W]`.

        Returns:
            Raw output features :math:`[B, C_{out}, H, W]`.
        """
        return self.layers(x)

    def get_param(self) -> OrderedDict[str, Tensor]:
        """Return **a copy** of the weights and biases inside the module.

        Returns:
            A copy of all weights & biases in the layers.
        """
        # Detach & clone so the copy shares neither storage nor autograd
        # history with the live parameters.
        return OrderedDict(
            (name, param.detach().clone()) for name, param in self.named_parameters()
        )

    def set_param(self, param: OrderedDict[str, Tensor]):
        """Replace the current parameters of the module with param.

        Args:
            param: Parameters to be set.
        """
        self.load_state_dict(param)

    def reinitialize_parameters(self) -> None:
        """Re-initialize in place the parameters of all the SynthesisConv2d layers."""
        for layer in self.layers.children():
            if isinstance(layer, SynthesisConv2d):
                layer.initialize_parameters()