Source code for biapy.engine

import os
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR
import timm
import timm.optim.optim_factory as optim_factory

from biapy.engine.schedulers.warmup_cosine_decay import WarmUpCosineDecayScheduler
from biapy.utils.misc import NativeScalerWithGradNormCount as NativeScaler
from biapy.utils.callbacks import EarlyStopping

def prepare_optimizer(cfg, model_without_ddp, steps_per_epoch):
    """Build the optimizer, learning-rate scheduler and loss scaler for a model.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration. Reads ``cfg.TRAIN.LR``, ``cfg.TRAIN.OPTIMIZER``,
        ``cfg.TRAIN.OPT_BETAS``, ``cfg.TRAIN.W_DECAY``, ``cfg.TRAIN.EPOCHS``
        and the ``cfg.TRAIN.LR_SCHEDULER`` sub-node.

    model_without_ddp : torch.nn.Module
        Model passed to ``timm``'s ``create_optimizer_v2``. Judging by the
        name, this is the bare module rather than a DDP wrapper (confirm with
        caller).

    steps_per_epoch : int
        Number of optimizer steps per epoch. Only used by the ``"onecycle"``
        scheduler.

    Returns
    -------
    optimizer : torch.optim.Optimizer
        Optimizer created by ``timm``'s ``create_optimizer_v2``.

    lr_scheduler : ReduceLROnPlateau, WarmUpCosineDecayScheduler, OneCycleLR or None
        Learning-rate scheduler selected by ``cfg.TRAIN.LR_SCHEDULER.NAME``,
        or ``None`` when that name is the empty string.

    loss_scaler : NativeScalerWithGradNormCount
        Loss scaler instance from ``biapy.utils.misc``.

    Raises
    ------
    ValueError
        If ``cfg.TRAIN.LR_SCHEDULER.NAME`` is not one of ``''``,
        ``'reduceonplateau'``, ``'warmupcosine'`` or ``'onecycle'``.
    """
    sched_name = cfg.TRAIN.LR_SCHEDULER.NAME
    supported = ("", "reduceonplateau", "warmupcosine", "onecycle")
    # Fail loudly on a typo instead of silently training without a scheduler.
    # Checked first so no optimizer is built for an invalid configuration.
    if sched_name not in supported:
        raise ValueError(
            f"Unknown learning rate scheduler '{sched_name}'. "
            f"Supported values: {', '.join(repr(s) for s in supported)}"
        )

    # With warm-up cosine decay the optimizer starts at MIN_LR. Presumably
    # WarmUpCosineDecayScheduler raises it towards TRAIN.LR during warm-up
    # (it does not receive the optimizer here; confirm in the training loop).
    lr = cfg.TRAIN.LR if sched_name != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR

    opt_args = {}
    if cfg.TRAIN.OPTIMIZER in ["ADAM", "ADAMW"]:
        opt_args["betas"] = cfg.TRAIN.OPT_BETAS
    optimizer = optim_factory.create_optimizer_v2(
        model_without_ddp,
        opt=cfg.TRAIN.OPTIMIZER,
        lr=lr,
        weight_decay=cfg.TRAIN.W_DECAY,
        **opt_args,
    )
    print(optimizer)

    # Learning rate schedulers
    lr_scheduler = None
    if sched_name == "reduceonplateau":
        lr_scheduler = ReduceLROnPlateau(
            optimizer,
            patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE,
            factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR,
            min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR,
        )
    elif sched_name == "warmupcosine":
        lr_scheduler = WarmUpCosineDecayScheduler(
            lr=cfg.TRAIN.LR,
            min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR,
            warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS,
            epochs=cfg.TRAIN.EPOCHS,
        )
    elif sched_name == "onecycle":
        # cfg.TRAIN.LR is OneCycleLR's positional ``max_lr`` argument.
        lr_scheduler = OneCycleLR(
            optimizer,
            cfg.TRAIN.LR,
            epochs=cfg.TRAIN.EPOCHS,
            steps_per_epoch=steps_per_epoch,
        )

    loss_scaler = NativeScaler()
    return optimizer, lr_scheduler, loss_scaler
def build_callbacks(cfg):
    """Create the early-stopping callback used during training.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration. Only ``cfg.TRAIN.PATIENCE`` is read.

    Returns
    -------
    earlystopper : EarlyStopping or None
        Early-stopping callback built with ``patience=cfg.TRAIN.PATIENCE``,
        or ``None`` when ``cfg.TRAIN.PATIENCE`` is ``-1``.
    """
    # A patience of -1 disables early stopping entirely.
    if cfg.TRAIN.PATIENCE == -1:
        return None
    return EarlyStopping(patience=cfg.TRAIN.PATIENCE)