src.tasks.train_once module
- src.tasks.train_once.train_once(args, train_loaders: List[torch.utils.data.DataLoader], model: Model, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, diffusion: DDPM, epoch: int) → Dict
Run one training epoch over the provided loaders.
- Parameters:
args – Global arguments.
train_loaders – Training dataloaders.
model – Model to train.
optimizer – Optimizer.
criterion – Loss function.
diffusion – Diffusion module.
epoch – Current epoch index.
- Returns:
all_records, a dictionary of training metrics.
- Return type:
Dict
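The following is a minimal sketch of what a function with this signature might do, not the project's actual implementation. It assumes a DDPM-style noise-prediction objective. `ToyDiffusion`, `ToyModel`, `train_once_sketch`, the `q_sample` method, the `args.device` field, and the keys of the returned dictionary are all hypothetical stand-ins chosen for illustration.

```python
from types import SimpleNamespace
from typing import Dict, List

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ToyDiffusion:
    """Hypothetical stand-in for the DDPM object (not the project's class)."""

    def __init__(self, num_steps: int = 10):
        self.num_steps = num_steps
        betas = torch.linspace(1e-4, 0.02, num_steps)
        self.alpha_bar = torch.cumprod(1.0 - betas, dim=0)

    def q_sample(self, x0, t, noise):
        # Forward process: sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
        a = self.alpha_bar.to(x0.device)[t].unsqueeze(-1)
        return a.sqrt() * x0 + (1.0 - a).sqrt() * noise


class ToyModel(nn.Module):
    """Hypothetical noise predictor that takes (x_t, t)."""

    def __init__(self, dim: int, num_steps: int):
        super().__init__()
        self.num_steps = num_steps
        self.net = nn.Linear(dim + 1, dim)

    def forward(self, x_t, t):
        # Append the normalized timestep as one extra input feature.
        t_feat = (t.float() / self.num_steps).unsqueeze(-1)
        return self.net(torch.cat([x_t, t_feat], dim=-1))


def train_once_sketch(args, train_loaders: List[DataLoader], model, optimizer,
                      criterion, diffusion, epoch: int) -> Dict:
    """Run one epoch over every loader and return averaged metrics."""
    model.train()
    total_loss, num_batches = 0.0, 0
    for loader in train_loaders:
        for (x0,) in loader:
            x0 = x0.to(args.device)  # args.device is an assumed field
            # Sample one random timestep per example, then noise the input.
            t = torch.randint(0, diffusion.num_steps, (x0.shape[0],),
                              device=x0.device)
            noise = torch.randn_like(x0)
            x_t = diffusion.q_sample(x0, t, noise)
            # The model is trained to predict the noise that was added.
            loss = criterion(model(x_t, t), noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
    # Metric keys are illustrative; the real dictionary may differ.
    all_records = {
        "epoch": epoch,
        "loss": total_loss / max(num_batches, 1),
        "num_batches": num_batches,
    }
    return all_records


torch.manual_seed(0)
args = SimpleNamespace(device="cpu")
loaders = [DataLoader(TensorDataset(torch.randn(16, 4)), batch_size=8)
           for _ in range(2)]
diffusion = ToyDiffusion(num_steps=10)
model = ToyModel(dim=4, num_steps=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
records = train_once_sketch(args, loaders, model, optimizer, nn.MSELoss(),
                            diffusion, epoch=0)
```

Under these assumptions, `records` holds the epoch index, the mean loss over all batches from all loaders, and the batch count. Averaging over batches rather than over loaders means larger loaders contribute proportionally more to the reported loss.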