Source code for src.model.multiscale_cnn
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiScaleCNN(nn.Module):
    """
    Multi-scale convolutional neural network module.

    Runs three convolution branches in parallel on the same input to capture
    map information at multiple scales, from local detail to larger structure:

    - ``branch1``: 3x3 convolution (3x3 receptive field).
    - ``branch2``: 5x5 convolution (5x5 receptive field).
    - ``branch3``: 3x3 convolution with dilation 2 (5x5 effective receptive
      field using only 9 weights per channel pair).

    Each branch keeps the channel count and spatial size of the input. The
    branch outputs are ReLU-activated, concatenated along the channel axis,
    and projected to ``model_dim`` channels by a 1x1 convolution.
    """

    def __init__(self, args):
        """
        Args:
            args (Namespace): Configuration parameters containing:
                - map_feature_dim (int): Number of input feature channels
                  (also the per-branch output channel count).
                - model_dim (int): Number of output feature channels.
        """
        super().__init__()
        dim = args.map_feature_dim
        # Padding is chosen so every branch preserves height and width
        # (stride 1): padding = dilation * (kernel_size - 1) // 2.
        # Branch 1: 3x3 receptive field for local detail.
        self.branch1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
        # Branch 2: 5x5 receptive field for medium-scale structure.
        self.branch2 = nn.Conv2d(dim, dim, kernel_size=5, padding=2)
        # Branch 3: dilated 3x3 convolution (effective 5x5) for larger context.
        self.branch3 = nn.Conv2d(dim, dim, kernel_size=3, padding=2, dilation=2)
        # 1x1 fusion over the three concatenated branches (3 * dim channels).
        self.fusion = nn.Conv2d(dim * 3, args.model_dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input feature map.
                Shape: (batch_size, in_channels, height, width)

        Returns:
            torch.Tensor: Output after fusing multi-scale features.
                Shape: (batch_size, out_channels, height, width)
        """
        x1 = F.relu(self.branch1(x))
        x2 = F.relu(self.branch2(x))
        x3 = F.relu(self.branch3(x))
        # Concatenate features from all branches along the channel axis.
        out = torch.cat([x1, x2, x3], dim=1)
        # No activation after fusion: the projection output is returned as-is.
        return self.fusion(out)