Auto-Regressive Module (ARM)

class Arm

Instantiate an autoregressive probability module, modelling the conditional distribution \(p_{\psi}(\hat{y}_i \mid \mathbf{c}_i)\) of a (quantized) latent pixel \(\hat{y}_i\), conditioned on neighboring already decoded context pixels \(\mathbf{c}_i \in \mathbb{Z}^C\), where \(C\) denotes the number of context pixels.

The distribution \(p_{\psi}\) is assumed to be a Laplace distribution, parameterized by an expectation \(\mu\) and a scale \(b\). The scale and the variance \(\sigma^2\) are related by \(\sigma^2 = 2 b^2\).
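The relation between scale and variance can be checked numerically. The snippet below is a minimal sketch using PyTorch's generic `torch.distributions.Laplace`, which is not necessarily what the ARM uses internally; the values of `mu` and `b` are arbitrary.

```python
import torch
from torch.distributions import Laplace

# Arbitrary illustrative parameters (not taken from the library).
mu = torch.tensor(0.0)
b = torch.tensor(1.5)

dist = Laplace(loc=mu, scale=b)

# Closed-form variance reported by PyTorch vs. the formula sigma^2 = 2 * b^2.
variance = dist.variance
expected = 2.0 * b ** 2
```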

The parameters of the Laplace distribution for a given latent pixel \(\hat{y}_i\) are obtained by passing its context pixels \(\mathbf{c}_i\) through an MLP \(f_{\psi}\):

\[p_{\psi}(\hat{y}_i \mid \mathbf{c}_i) \sim \mathcal{L}(\mu_i, b_i), \text{ where } \mu_i, b_i = f_{\psi}(\mathbf{c}_i).\]


The MLP \(f_{\psi}\) has a few constraints on its architecture:

  • The width of every hidden layer (i.e. the output of every layer except the final one) is identical to the number of context pixels \(C\);

  • All layers except the last one are residual layers, followed by a ReLU non-linearity;

  • \(C\) must be a multiple of 8.

The MLP \(f_{\psi}\) is made of custom Linear layers instantiated from the ArmLinear class.
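The constraints above can be sketched as follows. This is an illustrative reconstruction, not the library's code: `ResidualLinear` is a hypothetical stand-in for ArmLinear, `build_arm_mlp` is a hypothetical helper, and the final layer is assumed to have 2 outputs (one per Laplace parameter).

```python
import torch
from torch import nn


class ResidualLinear(nn.Module):
    """Hypothetical stand-in for ArmLinear: y = Wx + b (+ x if residual)."""

    def __init__(self, in_channels: int, out_channels: int, residual: bool = False):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.residual = residual

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.linear(x)
        return out + x if self.residual else out


def build_arm_mlp(dim_arm: int, n_hidden_layers_arm: int) -> nn.Sequential:
    """Hypothetical helper assembling an MLP that satisfies the constraints."""
    if dim_arm % 8 != 0:
        raise ValueError("C (dim_arm) must be a multiple of 8")
    layers = []
    for _ in range(n_hidden_layers_arm):
        # Hidden layers: width C, residual, followed by a ReLU.
        layers.append(ResidualLinear(dim_arm, dim_arm, residual=True))
        layers.append(nn.ReLU())
    # Final layer: plain linear map, assumed to output 2 values per pixel.
    layers.append(ResidualLinear(dim_arm, 2, residual=False))
    return nn.Sequential(*layers)


mlp = build_arm_mlp(dim_arm=16, n_hidden_layers_arm=2)
raw_out = mlp(torch.zeros(5, 16))  # 5 contexts of 16 context pixels each
```

With `n_hidden_layers_arm=0` the helper returns a single linear layer, which matches the "linear ARM" case.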

__init__(dim_arm: int, n_hidden_layers_arm: int)

Parameters:

  • dim_arm (int) – Number of context pixels \(C\), which is also the width of every hidden layer.

  • n_hidden_layers_arm (int) – Number of hidden layers. Set it to 0 for a linear ARM.

forward(x: Tensor) → Tuple[Tensor, Tensor, Tensor]

Perform the auto-regressive module (ARM) forward pass. The ARM takes as input a tensor of shape \([B, C]\), i.e. \(B\) contexts of \(C\) context pixels each. It outputs \([B, 2]\) values, corresponding to \(\mu\) and \(b\) for each of the \(B\) input pixels.


Note that the ARM expects input to be flattened i.e. spatial dimensions \(H, W\) are collapsed into a single batch-like dimension \(B = HW\), leading to an input of shape \([B, C]\), gathering the \(C\) contexts for each of the \(B\) pixels to model.
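As an illustration of the flattening, the sketch below assumes the contexts have already been gathered into a tensor of shape \([H, W, C]\). That layout and the variable names are assumptions made for the example; the way contexts are actually gathered is not described here.

```python
import torch

H, W, C = 4, 6, 8

# Hypothetical pre-gathered contexts: one C-dimensional context per pixel.
contexts = torch.arange(H * W * C, dtype=torch.float32).reshape(H, W, C)

# Collapse the spatial dimensions into a single batch-like dimension B = H * W.
x = contexts.reshape(H * W, C)
```

With this row-major layout, row `i * W + j` of `x` holds the context of the pixel at position `(i, j)`.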


The ARM MLP does not directly output the scale \(b\). Denoting by \(s\) the raw output of the MLP, the scale is obtained as follows:

\[b = e^{s - 4}\]
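A minimal sketch of this conversion, with \(s\) the raw MLP output for the scale. The column order of `raw_out` (first \(\mu\), then \(s\)) and the input values are assumptions made for illustration.

```python
import torch

# Hypothetical raw MLP output of shape [B, 2]; column order is an assumption.
raw_out = torch.tensor([[0.5, 4.0],
                        [-1.0, 0.0]])

mu = raw_out[:, 0]       # expectation, shape [B]
s = raw_out[:, 1]        # raw log scale, shape [B]
b = torch.exp(s - 4.0)   # scale, always strictly positive
```

The exponential guarantees \(b > 0\), and the offset of 4 means that a raw output of \(s = 0\) maps to a small scale \(e^{-4} \approx 0.018\).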

Parameters:

x (Tensor) – Concatenation of all input contexts \(\mathbf{c}_i\). Tensor of shape \([B, C]\).


Returns:

The Laplace distribution parameters \(\mu\) and \(b\) for all pixels, each a tensor of shape \([B]\), followed by the log scale \(s\) described above, also a tensor of shape \([B]\).

Return type:

Tuple[Tensor, Tensor, Tensor]

get_param() → OrderedDict[str, Tensor]

Return a copy of the weights and biases inside the module.


Returns:

A copy of all weights and biases in the layers.

Return type:

OrderedDict[str, Tensor]

set_param(param: OrderedDict[str, Tensor]) → None

Replace the current parameters of the module with param.
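The snapshot-and-restore pattern behind get_param and set_param can be sketched with standard PyTorch calls. This is not the Arm implementation: a plain `nn.Linear` stands in for an Arm instance, and `state_dict` / `load_state_dict` stand in for get_param / set_param.

```python
from collections import OrderedDict

import torch
from torch import nn

module = nn.Linear(8, 2)  # stand-in for an Arm instance (assumption)

# Snapshot: clone every tensor so later updates do not alter the copy.
snapshot = OrderedDict(
    (name, tensor.detach().clone()) for name, tensor in module.state_dict().items()
)

# Modify the parameters in place, e.g. as a training step would.
with torch.no_grad():
    module.weight.add_(1.0)
changed = not torch.equal(module.weight, snapshot["weight"])

# Restore the saved parameters.
module.load_state_dict(snapshot)
```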


Parameters:

param (OrderedDict[str, Tensor]) – Parameters to be set.

Return type:

None

reinitialize_parameters() → None

Re-initialize in place the parameters of all the ArmLinear layers.

Return type:

None

class ArmLinear

Create a Linear layer of the Auto-Regressive Module (ARM). This is a wrapper around the usual nn.Linear layer of PyTorch, with a custom initialization. It performs the following operations:

  • \(\mathbf{x}_{out} = \mathbf{W}\mathbf{x}_{in} + \mathbf{b}\) if residual is False

  • \(\mathbf{x}_{out} = \mathbf{W}\mathbf{x}_{in} + \mathbf{b} + \mathbf{x}_{in}\) if residual is True.

The input \(\mathbf{x}_{in}\) is a \([B, C_{in}]\) tensor, the output \(\mathbf{x}_{out}\) is a \([B, C_{out}]\) tensor.

The layer weight and bias shapes are \(\mathbf{W} \in \mathbb{R}^{C_{out} \times C_{in}}\) and \(\mathbf{b} \in \mathbb{R}^{C_{out}}\).
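The two operations and their shapes can be illustrated with `torch.nn.functional.linear`. This is a sketch of the arithmetic only, not the ArmLinear class itself; the sizes and variable names are arbitrary. Note that the residual form adds \(\mathbf{x}_{in}\) to the output, so it needs \(C_{out} = C_{in}\).

```python
import torch
import torch.nn.functional as F

B, C_in, C_out = 3, 8, 8  # residual form requires C_out == C_in

x_in = torch.randn(B, C_in)
W = torch.randn(C_out, C_in)   # weight of shape [C_out, C_in]
bias = torch.randn(C_out)      # bias of shape [C_out]

# residual is False: x_out = W x_in + b
plain = F.linear(x_in, W, bias)

# residual is True: x_out = W x_in + b + x_in
residual = F.linear(x_in, W, bias) + x_in
```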

__init__(in_channels: int, out_channels: int, residual: bool = False)

Parameters:

  • in_channels (int) – Number of input features \(C_{in}\).

  • out_channels (int) – Number of output features \(C_{out}\).

  • residual (bool) – True to add a residual connection to the layer. Defaults to False.

initialize_parameters() → None

Initialize in place the weight and the bias of the linear layer.

  • Biases are always set to zero.

  • Weights are set to zero if residual == True. Otherwise, they are sampled from the normal distribution: \(\mathbf{W} \sim \mathcal{N}(0, \tfrac{1}{(C_{out})^4})\).
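The initialization rules can be sketched as below. `init_arm_linear_` is a hypothetical helper operating on a plain `nn.Linear`, not the library's method, and it assumes the second argument of \(\mathcal{N}\) is a variance, so that the standard deviation is \(1 / (C_{out})^2\).

```python
import torch
from torch import nn


def init_arm_linear_(layer: nn.Linear, residual: bool) -> None:
    """Hypothetical helper applying the rules above in place."""
    with torch.no_grad():
        layer.bias.zero_()  # biases are always zero
        if residual:
            layer.weight.zero_()
        else:
            out_channels = layer.weight.shape[0]
            # Assumed: variance 1 / C_out^4, i.e. std 1 / C_out^2.
            layer.weight.normal_(mean=0.0, std=1.0 / out_channels ** 2)


res_layer = nn.Linear(8, 8)
init_arm_linear_(res_layer, residual=True)

out_layer = nn.Linear(8, 2)
init_arm_linear_(out_layer, residual=False)

# With zero weight and bias, Wx + b is zero, so Wx + b + x reduces to x.
x = torch.randn(4, 8)
linear_part = res_layer(x)
```

One consequence of the zero initialization is that a residual layer starts out as the identity mapping, since \(\mathbf{W}\mathbf{x}_{in} + \mathbf{b}\) is zero and only the skip term \(\mathbf{x}_{in}\) remains.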

Return type:

None

forward(x: Tensor) → Tensor

Perform the forward pass of this layer.


Parameters:

x (Tensor) – Input tensor of shape \([B, C_{in}]\).


Returns:

Tensor with shape \([B, C_{out}]\).

Return type:

Tensor