Source code for src.model.permuted
import torch
import torch.nn as nn
[docs]
class Permuted(nn.Module):
    """
    Dimension permutation layer.

    Wraps `Tensor.permute` so it can be used inside `nn.Sequential`.
    """

    def __init__(self, *dims: int):
        """
        Args:
            *dims (int): Target dimension ordering, stored as a tuple and
                forwarded unchanged to `Tensor.permute` on every call.
        """
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor; its number of dimensions must
                match the number of dims given at construction.

        Returns:
            torch.Tensor: `x` with its dimensions reordered by `self.dims`
            (a view sharing storage with `x`).
        """
        return x.permute(*self.dims)

    def extra_repr(self) -> str:
        # Picked up by nn.Module.__repr__, so printing a model shows the
        # permutation order, e.g. "Permuted(dims=(0, 2, 1))".
        return f"dims={self.dims}"