src.tasks.train_once module

src.tasks.train_once.train_once(args, train_loaders: List[torch.utils.data.DataLoader], model: Model, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, diffusion: DDPM, epoch: int) → Dict

Run one training epoch over the provided loaders.

Parameters:
  • args – Global arguments.

  • train_loaders – List of training dataloaders; the epoch iterates over all of them.

  • model – Model to train.

  • optimizer – Optimizer.

  • criterion – Loss function.

  • diffusion – DDPM diffusion module.

  • epoch – Current epoch index.

Returns:

all_records – Training metrics dictionary.

Return type:

Dict
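This page does not show the function body. The following is a minimal, hypothetical sketch of what one epoch of DDPM training over several loaders usually involves. It assumes the standard noise-prediction objective: sample a timestep t, form x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε, predict ε, and minimise the squared error. The helper names `linear_alpha_bars` and `train_once_sketch`, the toy scalar model ε̂ = w·x_t, the linear β schedule, the plain SGD step, and the metric keys are all illustrative assumptions, not the module's actual code. Pure Python is used so the sketch runs without torch.

```python
import math
import random
from typing import Dict, List, Sequence


def linear_alpha_bars(num_steps: int, beta_start: float = 1e-4,
                      beta_end: float = 0.02) -> List[float]:
    """Return alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule (assumed)."""
    alpha_bars, prod = [], 1.0
    for t in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * t / max(num_steps - 1, 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def train_once_sketch(train_loaders: Sequence[Sequence[List[float]]],
                      weight: List[float], lr: float,
                      alpha_bars: List[float], epoch: int,
                      seed: int = 0) -> Dict:
    """Hypothetical epoch of epsilon-prediction training with a toy model eps_hat = w * x_t.

    `weight` is a one-element list standing in for the model parameters,
    so the in-place SGD update is visible to the caller.
    """
    rng = random.Random(seed)
    total_loss, num_batches = 0.0, 0
    for loader in train_loaders:              # iterate over every provided loader
        for batch in loader:                  # each batch: list of clean samples x_0
            if not batch:
                continue                      # skip empty batches (avoid division by zero)
            grad, batch_loss = 0.0, 0.0
            for x0 in batch:
                t = rng.randrange(len(alpha_bars))   # uniform random timestep
                eps = rng.gauss(0.0, 1.0)            # target noise
                ab = alpha_bars[t]
                # Forward noising q(x_t | x_0) in closed form.
                x_t = math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * eps
                err = weight[0] * x_t - eps          # eps_hat - eps
                batch_loss += err * err              # squared-error criterion
                grad += 2.0 * err * x_t              # d(err^2)/dw
            n = len(batch)
            weight[0] -= lr * grad / n               # plain SGD step (stands in for optimizer)
            total_loss += batch_loss / n             # mean loss of this batch
            num_batches += 1
    # Metrics dictionary analogous to the documented return value.
    return {"epoch": epoch,
            "loss": total_loss / max(num_batches, 1),
            "num_batches": num_batches}
```

For example, `train_once_sketch([[[0.5, -0.2], [1.0, 0.3]], [[-0.7, 0.1]]], [0.0], 0.1, linear_alpha_bars(100), epoch=0)` runs three batches drawn from two loaders and returns the mean batch loss. The loss is averaged per batch rather than per sample; either convention is plausible for the real function.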