Source code for src.model.residual
import torch
import torch.nn as nn
[docs]
class Residual(nn.Module):
    """Residual block computing ``y = shortcut(x) + f(x)``.

    ``f`` is the positional ``layers`` chained into an ``nn.Sequential``.
    The shortcut is ``x`` itself, unless both ``input_dim`` and ``output_dim``
    are supplied and differ. In that case the shortcut is
    ``nn.Linear(input_dim, output_dim)`` applied to ``x``.
    """

    def __init__(self, *layers, input_dim=None, output_dim=None):
        """
        Args:
            *layers (nn.Module): Modules forming the main path, applied in
                the order given.
            input_dim (int, optional): Input feature size. It only matters
                when it is given together with a different ``output_dim``.
            output_dim (int, optional): Output feature size.
        """
        super().__init__()
        self.net = nn.Sequential(*layers)

        dims_known = input_dim is not None and output_dim is not None
        self.need_proj = dims_known and input_dim != output_dim

        # With matching (or unspecified) dims the shortcut is the identity,
        # so no projection submodule is registered at all.
        if not self.need_proj:
            return
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the shortcut of ``x`` plus the main-path output ``self.net(x)``."""
        # Compute the shortcut first: if the main path modifies ``x`` in place
        # (e.g. an inplace activation), the projection must still see the input.
        shortcut = self.proj(x) if self.need_proj else x
        return shortcut + self.net(x)