Source code for src.model.permuted

import torch
import torch.nn as nn


class Permuted(nn.Module):
    """Dimension permutation layer.

    Wraps ``Tensor.permute`` so a fixed reordering of dimensions can be used
    as a layer, e.g. inside ``nn.Sequential``.
    """

    def __init__(self, *dims):
        """
        Args:
            *dims (int): Target dimension ordering, forwarded unchanged to
                ``Tensor.permute`` on every call.
        """
        super().__init__()
        # Stored as given (a tuple of the positional arguments).
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with its dimensions reordered according to ``self.dims``.

        Args:
            x: Input tensor. Its number of dimensions must match ``len(self.dims)``;
                ``Tensor.permute`` raises otherwise.

        Returns:
            The permuted tensor. ``Tensor.permute`` returns a view, so no data
            is copied.
        """
        return x.permute(*self.dims)

    def extra_repr(self) -> str:
        """Show the configured ordering in the module's ``repr``."""
        # Without this, printing a model shows a bare ``Permuted()`` and hides
        # which permutation the layer applies.
        return f"dims={self.dims}"