biapy.engine.train_engine
Training and evaluation engine for BiaPy.
This module provides functions to train and evaluate deep learning models for one epoch, handling distributed training, logging, learning rate scheduling, and memory bank operations for contrastive/self-supervised learning.
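With distributed training, each process only sees its own shard of batches, so per-epoch metrics have to be combined across processes. A common approach is to sum (value total, sample count) pairs across ranks and divide only at the end. In real PyTorch code, `torch.distributed.all_reduce(tensor, op=ReduceOp.SUM)` performs that sum. The sketch below simulates the reduction in plain Python. It assumes, rather than confirms, that BiaPy reduces metrics this way, and `global_metric_means` is a hypothetical helper.

```python
from typing import Dict, List, Tuple


def global_metric_means(
    per_rank: List[Dict[str, Tuple[float, int]]],
) -> Dict[str, float]:
    """Combine per-process (sum, count) pairs into global means (hypothetical helper).

    Each element of ``per_rank`` stands for one process and maps a metric
    name to (sum of values, number of samples) seen by that process.
    Summing both fields over processes mimics an all-reduce with a SUM op.
    """
    totals: Dict[str, List[float]] = {}
    for rank_stats in per_rank:
        for name, (value_sum, count) in rank_stats.items():
            acc = totals.setdefault(name, [0.0, 0.0])
            acc[0] += value_sum
            acc[1] += count
    # Divide only after reducing, so ranks that saw more samples weigh more.
    return {name: s / c for name, (s, c) in totals.items() if c > 0}
```

For example, if rank 0 accumulates a loss sum of 2.0 over 4 samples and rank 1 a sum of 3.0 over 2 samples, the global mean is 5/6. Naively averaging the per-rank means (0.5 and 1.5) would instead give 1.0.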
- biapy.engine.train_engine.train_one_epoch(cfg: CfgNode, model: Module | DistributedDataParallel, model_call_func: Callable, loss_function: Callable, metric_function: Callable, prepare_targets: Callable, data_loader: DataLoader, optimizer: Optimizer, device: device, epoch: int, log_writer: TensorboardLogger | None = None, lr_scheduler: ReduceLROnPlateau | WarmUpCosineDecayScheduler | OneCycleLR | None = None, verbose: bool = False, memory_bank: MemoryBank | None = None, total_iters: int = 0, contrast_warmup_iters: int = 0)
Train the model for one epoch.
Handles forward and backward passes, loss computation, metric logging, optimizer steps, learning rate scheduling, and optional memory bank updates.
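The exact behaviour of BiaPy's `MemoryBank` is not documented here. A widely used design in contrastive learning is a bounded first-in-first-out queue of past embeddings. It is refreshed with new embeddings during training and read back to provide negatives for the contrastive loss. The hypothetical `FifoMemoryBank` below sketches that idea with one queue per class label. It is not BiaPy's class or API.

```python
from collections import deque
from typing import Deque, Dict, Hashable, List, Sequence


class FifoMemoryBank:
    """Per-class FIFO queue of embedding vectors (hypothetical, illustrative only)."""

    def __init__(self, size_per_class: int) -> None:
        self.size_per_class = size_per_class
        self._queues: Dict[Hashable, Deque[List[float]]] = {}

    def enqueue(
        self, embeddings: Sequence[Sequence[float]], labels: Sequence[Hashable]
    ) -> None:
        # Append each embedding to the queue of its class; deque(maxlen=...)
        # silently drops the oldest entry once the queue is full.
        for emb, label in zip(embeddings, labels):
            queue = self._queues.setdefault(label, deque(maxlen=self.size_per_class))
            queue.append(list(emb))

    def negatives_for(self, label: Hashable) -> List[List[float]]:
        # Embeddings stored under every other class can serve as negatives.
        return [emb for other, q in self._queues.items() if other != label for emb in q]
```

In a real training step, the embeddings would be detached from the autograd graph before being enqueued, so the bank never holds references to old computation graphs.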
- Parameters:
cfg (CN) – BiaPy configuration node.
model (nn.Module or nn.parallel.DistributedDataParallel) – Model to train.
model_call_func (Callable) – Function to call the model (handles multi-heads, etc.).
loss_function (Callable) – Loss function.
metric_function (Callable) – Metric computation function.
prepare_targets (Callable) – Function to prepare targets for loss/metrics.
data_loader (DataLoader) – Training data loader.
optimizer (Optimizer) – Optimizer for model parameters.
device (torch.device) – Device to use.
epoch (int) – Current epoch number.
log_writer (TensorboardLogger, optional) – Logger for TensorBoard.
lr_scheduler (Scheduler, optional) – Learning rate scheduler.
verbose (bool, optional) – Verbosity flag.
memory_bank (MemoryBank, optional) – Memory bank for contrastive/self-supervised learning.
total_iters (int, optional) – Total iterations completed (for contrastive warmup).
contrast_warmup_iters (int, optional) – Number of warmup iterations for contrastive learning.
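Passing `total_iters` into every call lets a warmup threshold span epoch boundaries, because the count is global rather than per-epoch. One common warmup policy in pixel-level contrastive segmentation is to train with the main loss only and to add the contrastive term once the global iteration count reaches the threshold. The sketch below assumes that gating policy and a hypothetical fixed weight of 0.1. BiaPy may ramp or weight the term differently.

```python
def contrastive_term_enabled(total_iters: int, contrast_warmup_iters: int) -> bool:
    # The contrastive loss stays off until the global iteration count
    # reaches the warmup threshold (assumed gating policy).
    return total_iters >= contrast_warmup_iters


def combined_loss(
    main_loss: float,
    contrast_loss: float,
    total_iters: int,
    contrast_warmup_iters: int,
    weight: float = 0.1,  # hypothetical weight, not a BiaPy default
) -> float:
    if contrastive_term_enabled(total_iters, contrast_warmup_iters):
        return main_loss + weight * contrast_loss
    return main_loss
```

With `contrast_warmup_iters=0` (the default in the signature), this gate is open from the first iteration.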
- Returns:
dict – Dictionary of averaged metrics for the epoch.
int – Number of steps (batches) processed.
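The two return values can be pictured with a framework-free sketch. The loop visits every batch, collects that batch's scalar metrics, and returns their per-batch average together with the number of batches processed. Here `step` is a hypothetical stand-in for the forward pass, loss, backward pass and optimizer update. The sketch averages per batch without weighting by batch size, which is an assumption about how BiaPy aggregates.

```python
from typing import Callable, Dict, Iterable, Tuple


def run_epoch_sketch(
    batches: Iterable[Tuple[object, object]],
    step: Callable[[object, object], Dict[str, float]],
) -> Tuple[Dict[str, float], int]:
    """Average per-batch metrics over one pass through ``batches`` (hypothetical)."""
    sums: Dict[str, float] = {}
    steps = 0
    for inputs, targets in batches:
        # ``step`` performs one optimisation step and returns this batch's metrics.
        batch_metrics = step(inputs, targets)
        for name, value in batch_metrics.items():
            sums[name] = sums.get(name, 0.0) + value
        steps += 1
    averages = {name: total / steps for name, total in sums.items()} if steps else {}
    return averages, steps
```

A caller would unpack it the same way as the documented function, e.g. `metrics, steps = run_epoch_sketch(batches, step)`.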
- biapy.engine.train_engine.evaluate(cfg: CfgNode, model: Module | DistributedDataParallel, model_call_func: Callable, loss_function: Callable, metric_function: Callable, prepare_targets: Callable, epoch: int, data_loader: DataLoader, lr_scheduler: ReduceLROnPlateau | WarmUpCosineDecayScheduler | OneCycleLR | None = None, memory_bank: MemoryBank | None = None)
Evaluate the model on the validation set.
Runs the model in evaluation mode, computes the loss and metrics, and updates the learning rate scheduler if needed.
- Parameters:
cfg (CN) – BiaPy configuration node.
model (nn.Module or nn.parallel.DistributedDataParallel) – Model to evaluate.
model_call_func (Callable) – Function to call the model.
loss_function (Callable) – Loss function.
metric_function (Callable) – Metric computation function.
prepare_targets (Callable) – Function to prepare targets for loss/metrics.
epoch (int) – Current epoch number.
data_loader (DataLoader) – Validation data loader.
lr_scheduler (Scheduler, optional) – Learning rate scheduler.
memory_bank (MemoryBank, optional) – Memory bank for contrastive/self-supervised learning.
- Returns:
Dictionary of averaged metrics for the validation set.
- Return type:
dict
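Of the scheduler types in the signature, `ReduceLROnPlateau` is the one that needs a monitored value (typically the validation loss) passed to its `step` call. That is why a validation routine is a natural place to step it, although which schedulers `evaluate` actually steps is an assumption here. The hypothetical `SimplePlateauScheduler` below reproduces the core rule of PyTorch's class in `"min"` mode. When the loss has not improved for more than `patience` consecutive epochs, the learning rate is multiplied by `factor`. The threshold, cooldown and `min_lr` options are omitted.

```python
class SimplePlateauScheduler:
    """Simplified stand-in for torch.optim.lr_scheduler.ReduceLROnPlateau (mode="min")."""

    def __init__(self, lr: float, factor: float = 0.5, patience: int = 2) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> float:
        if val_loss < self.best:
            # Improvement: remember it and reset the stagnation counter.
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                # Too long without improvement: decay the learning rate.
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr
```

With `patience=2`, the validation losses `[1.0, 0.9, 0.95, 0.95, 0.95]` leave the rate unchanged for four epochs and halve it on the fifth, the third epoch in a row without improvement.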