Source code for biapy.models

import importlib
import os
from functools import partial
import torch
import numpy as np
import torch.nn as nn
from torchinfo import summary 

from biapy.utils.misc import is_main_process
from biapy.engine import prepare_optimizer
from biapy.models.blocks import get_activation

def build_model(cfg, job_identifier, device):
    """
    Build the selected model.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.

    job_identifier : str
        Job name.

    device : Torch device
        Device used ("cpu" or "cuda" for GPU).

    Returns
    -------
    model : Torch model
        Selected model.
    """
    # Import the module that implements the selected architecture
    if 'efficientnet' in cfg.MODEL.ARCHITECTURE.lower():
        modelname = 'efficientnet'
    else:
        modelname = str(cfg.MODEL.ARCHITECTURE).lower()
    mdl = importlib.import_module('biapy.models.' + modelname)

    # Expose the module's public names (the model classes) in the global namespace
    names = [x for x in mdl.__dict__ if not x.startswith("_")]
    globals().update({k: getattr(mdl, k) for k in names})

    ndim = 3 if cfg.PROBLEM.NDIM == "3D" else 2

    # Model building
    if modelname in ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet']:
        args = dict(image_shape=cfg.DATA.PATCH_SIZE, activation=cfg.MODEL.ACTIVATION.lower(),
            feature_maps=cfg.MODEL.FEATURE_MAPS, drop_values=cfg.MODEL.DROPOUT_VALUES,
            batch_norm=cfg.MODEL.BATCH_NORMALIZATION, k_size=cfg.MODEL.KERNEL_SIZE,
            upsample_layer=cfg.MODEL.UPSAMPLE_LAYER, z_down=cfg.MODEL.Z_DOWN)
        if modelname == 'unet':
            f_name = U_Net
        elif modelname == 'resunet':
            f_name = ResUNet
        elif modelname == 'resunet++':
            f_name = ResUNetPlusPlus
        elif modelname == 'attention_unet':
            f_name = Attention_U_Net
        elif modelname == 'seunet':
            f_name = SE_U_Net
        args['output_channels'] = cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS if cfg.PROBLEM.TYPE == 'INSTANCE_SEG' else None
        if cfg.PROBLEM.TYPE == 'SUPER_RESOLUTION':
            args['upsampling_factor'] = cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING
            args['upsampling_position'] = cfg.MODEL.UNET_SR_UPSAMPLE_POSITION
            args['n_classes'] = cfg.DATA.PATCH_SIZE[-1]
        else:
            args['n_classes'] = cfg.MODEL.N_CLASSES if cfg.PROBLEM.TYPE != 'DENOISING' else cfg.DATA.PATCH_SIZE[-1]
        model = f_name(**args)
    else:
        if modelname == 'simple_cnn':
            model = simple_CNN(image_shape=cfg.DATA.PATCH_SIZE, activation=cfg.MODEL.ACTIVATION.lower(),
                n_classes=cfg.MODEL.N_CLASSES)
        elif 'efficientnet' in modelname:
            # EfficientNet weights expect 224x224 inputs
            shape = (224, 224) + (cfg.DATA.PATCH_SIZE[-1],) if cfg.DATA.PATCH_SIZE[:-1] != (224, 224) else cfg.DATA.PATCH_SIZE
            model = efficientnet(cfg.MODEL.ARCHITECTURE.lower(), shape, n_classes=cfg.MODEL.N_CLASSES)
        elif modelname == 'vit':
            args = dict(img_size=cfg.DATA.PATCH_SIZE[0], patch_size=cfg.MODEL.VIT_TOKEN_SIZE,
                in_chans=cfg.DATA.PATCH_SIZE[-1], ndim=ndim, num_classes=cfg.MODEL.N_CLASSES,
                norm_layer=partial(nn.LayerNorm, eps=1e-6))
            if cfg.MODEL.VIT_MODEL == "custom":
                args2 = dict(embed_dim=cfg.MODEL.VIT_EMBED_DIM, depth=cfg.MODEL.VIT_NUM_LAYERS,
                    num_heads=cfg.MODEL.VIT_NUM_HEADS, mlp_ratio=cfg.MODEL.VIT_MLP_RATIO,
                    drop_rate=cfg.MODEL.DROPOUT_VALUES[0])
                args.update(args2)
                model = VisionTransformer(**args)
            else:
                model = eval(cfg.MODEL.VIT_MODEL)(**args)
        elif modelname == 'multiresunet':
            args = dict(input_channels=cfg.DATA.PATCH_SIZE[-1], ndim=ndim, alpha=1.67,
                z_down=cfg.MODEL.Z_DOWN)
            args['output_channels'] = cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS if cfg.PROBLEM.TYPE == 'INSTANCE_SEG' else None
            if cfg.PROBLEM.TYPE == 'SUPER_RESOLUTION':
                args['upsampling_factor'] = cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING
                args['upsampling_position'] = cfg.MODEL.UNET_SR_UPSAMPLE_POSITION
                args['n_classes'] = cfg.DATA.PATCH_SIZE[-1]
            else:
                args['n_classes'] = cfg.MODEL.N_CLASSES if cfg.PROBLEM.TYPE != 'DENOISING' else cfg.DATA.PATCH_SIZE[-1]
            model = MultiResUnet(**args)
        elif modelname == 'unetr':
            args = dict(input_shape=cfg.DATA.PATCH_SIZE, patch_size=cfg.MODEL.VIT_TOKEN_SIZE,
                embed_dim=cfg.MODEL.VIT_EMBED_DIM, depth=cfg.MODEL.VIT_NUM_LAYERS,
                num_heads=cfg.MODEL.VIT_NUM_HEADS, mlp_ratio=cfg.MODEL.VIT_MLP_RATIO,
                num_filters=cfg.MODEL.UNETR_VIT_NUM_FILTERS, n_classes=cfg.MODEL.N_CLASSES,
                decoder_activation=cfg.MODEL.UNETR_DEC_ACTIVATION, ViT_hidd_mult=cfg.MODEL.UNETR_VIT_HIDD_MULT,
                batch_norm=cfg.MODEL.BATCH_NORMALIZATION, dropout=cfg.MODEL.DROPOUT_VALUES[0],
                k_size=cfg.MODEL.UNETR_DEC_KERNEL_SIZE)
            args['output_channels'] = cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS if cfg.PROBLEM.TYPE == 'INSTANCE_SEG' else None
            model = UNETR(**args)
        elif modelname == 'edsr':
            model = EDSR(ndim=ndim, num_filters=64, num_of_residual_blocks=16,
                upsampling_factor=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
                num_channels=cfg.DATA.PATCH_SIZE[-1])
        elif modelname == 'rcan':
            scale = cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING
            if isinstance(scale, tuple):
                scale = scale[0]
            model = rcan(ndim=ndim, filters=16, scale=scale, n_sub_block=int(np.log2(scale)),
                num_channels=cfg.DATA.PATCH_SIZE[-1])
        elif modelname == 'dfcan':
            model = DFCAN(ndim=ndim, input_shape=cfg.DATA.PATCH_SIZE,
                scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING, n_ResGroup=4, n_RCAB=4)
        elif modelname == 'wdsr':
            model = wdsr(scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING, num_filters=32,
                num_res_blocks=8, res_block_expansion=6, num_channels=cfg.DATA.PATCH_SIZE[-1])
        elif modelname == 'mae':
            model = MaskedAutoencoderViT(
                img_size=cfg.DATA.PATCH_SIZE[0], patch_size=cfg.MODEL.VIT_TOKEN_SIZE,
                in_chans=cfg.DATA.PATCH_SIZE[-1], ndim=ndim, norm_layer=partial(nn.LayerNorm, eps=1e-6),
                embed_dim=cfg.MODEL.VIT_EMBED_DIM, depth=cfg.MODEL.VIT_NUM_LAYERS,
                num_heads=cfg.MODEL.VIT_NUM_HEADS, decoder_embed_dim=512, decoder_depth=8,
                decoder_num_heads=16, mlp_ratio=cfg.MODEL.VIT_MLP_RATIO,
                masking_type=cfg.MODEL.MAE_MASK_TYPE, mask_ratio=cfg.MODEL.MAE_MASK_RATIO,
                device=device)

    # Check the network created
    model.to(device)
    if cfg.PROBLEM.NDIM == '2D':
        sample_size = (1, cfg.DATA.PATCH_SIZE[2], cfg.DATA.PATCH_SIZE[0], cfg.DATA.PATCH_SIZE[1])
    else:
        sample_size = (1, cfg.DATA.PATCH_SIZE[3], cfg.DATA.PATCH_SIZE[0], cfg.DATA.PATCH_SIZE[1],
            cfg.DATA.PATCH_SIZE[2])
    summary(model, input_size=sample_size, col_names=("input_size", "output_size", "num_params"),
        depth=10, device="cpu" if "cuda" not in device.type else "cuda")

    return model
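
# Illustrative usage sketch (not part of the original module): a minimal YACS
# configuration covering only the keys that build_model() reads for a plain 2D
# U-Net. The key names mirror the cfg.* accesses above; the concrete values
# (patch size, feature maps, etc.) are assumptions chosen for the example and
# would normally come from BiaPy's own configuration defaults.
def _example_build_unet():
    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.PROBLEM = CN()
    cfg.PROBLEM.NDIM = "2D"
    cfg.PROBLEM.TYPE = "SEMANTIC_SEG"
    cfg.DATA = CN()
    cfg.DATA.PATCH_SIZE = (256, 256, 1)  # (y, x, channels)
    cfg.MODEL = CN()
    cfg.MODEL.ARCHITECTURE = "unet"
    cfg.MODEL.ACTIVATION = "ELU"
    cfg.MODEL.FEATURE_MAPS = [16, 32, 64, 128, 256]
    cfg.MODEL.DROPOUT_VALUES = [0.0, 0.0, 0.0, 0.0, 0.0]
    cfg.MODEL.BATCH_NORMALIZATION = False
    cfg.MODEL.KERNEL_SIZE = 3
    cfg.MODEL.UPSAMPLE_LAYER = "convtranspose"
    cfg.MODEL.Z_DOWN = [2, 2, 2, 2]
    cfg.MODEL.N_CLASSES = 2
    # Builds the model on CPU and prints the torchinfo summary
    return build_model(cfg, "example_job", torch.device("cpu"))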

def build_torchvision_model(cfg, device):
    """
    Build a TorchVision model with its pretrained weights.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.

    device : Torch device
        Device used ("cpu" or "cuda" for GPU).

    Returns
    -------
    model : Torch model
        Selected model.

    transforms : TorchVision transforms
        Preprocessing transforms tied to the loaded weights.
    """
    # Find the model in TorchVision
    if 'quantized_' in cfg.MODEL.TORCHVISION_MODEL_NAME:
        mdl = importlib.import_module('torchvision.models.quantization', cfg.MODEL.TORCHVISION_MODEL_NAME)
        w_prefix = "_quantizedweights"
        tc_model_name = cfg.MODEL.TORCHVISION_MODEL_NAME.replace('quantized_', '')
        mdl_weights = importlib.import_module('torchvision.models', cfg.MODEL.TORCHVISION_MODEL_NAME)
    else:
        w_prefix = "_weights"
        tc_model_name = cfg.MODEL.TORCHVISION_MODEL_NAME
        if cfg.PROBLEM.TYPE == 'CLASSIFICATION':
            mdl = importlib.import_module('torchvision.models', cfg.MODEL.TORCHVISION_MODEL_NAME)
        elif cfg.PROBLEM.TYPE == 'SEMANTIC_SEG':
            mdl = importlib.import_module('torchvision.models.segmentation', cfg.MODEL.TORCHVISION_MODEL_NAME)
        elif cfg.PROBLEM.TYPE in ['INSTANCE_SEG', 'DETECTION']:
            mdl = importlib.import_module('torchvision.models.detection', cfg.MODEL.TORCHVISION_MODEL_NAME)
        mdl_weights = mdl

    # Import model and weights: take the first weight enum whose lowercased
    # name matches "<model name><w_prefix>"
    names = [x for x in mdl.__dict__ if not x.startswith("_")]
    for weight_name in names:
        if tc_model_name + w_prefix in weight_name.lower():
            break
    weight_name = weight_name.replace('Quantized', '')
    print(f"PyTorch model selected: {tc_model_name} (weights: {weight_name})")
    globals().update({
        tc_model_name: getattr(mdl, tc_model_name),
        weight_name: getattr(mdl_weights, weight_name),
    })

    # Load model and weights
    model_torchvision_weights = eval(weight_name).DEFAULT
    model = eval(tc_model_name)(weights=model_torchvision_weights)

    # Create a new head
    sample_size = None
    out_classes = cfg.MODEL.N_CLASSES if cfg.MODEL.N_CLASSES > 2 else 1
    if cfg.PROBLEM.TYPE == 'CLASSIFICATION':
        # The default weights are trained on ImageNet's 1000 classes
        if cfg.MODEL.N_CLASSES != 1000:
            print(f"WARNING: Model's head changed from 1000 to {out_classes} classes, so fine-tuning is "
                "required to obtain good results")
            if cfg.MODEL.TORCHVISION_MODEL_NAME in ['squeezenet1_0', 'squeezenet1_1']:
                head = torch.nn.Conv2d(model.classifier[1].in_channels, out_classes, kernel_size=1, stride=1)
                model.classifier[1] = head
            else:
                if hasattr(model, 'fc'):
                    layer = "fc"
                elif hasattr(model, 'classifier'):
                    layer = "classifier"
                else:
                    layer = "head"
                if isinstance(getattr(model, layer), (list, torch.nn.Sequential)):
                    head = torch.nn.Linear(getattr(model, layer)[-1].in_features, out_classes, bias=True)
                    getattr(model, layer)[-1] = head
                else:
                    head = torch.nn.Linear(getattr(model, layer).in_features, out_classes, bias=True)
                    setattr(model, layer, head)
        # Fix the sample input shape as required by some models
        if cfg.MODEL.TORCHVISION_MODEL_NAME in ['maxvit_t']:
            sample_size = (1, 3, 224, 224)
    elif cfg.PROBLEM.TYPE == 'SEMANTIC_SEG':
        head = torch.nn.Conv2d(model.classifier[-1].in_channels, out_classes, kernel_size=1, stride=1)
        model.classifier[-1] = head
        head = torch.nn.Conv2d(model.aux_classifier[-1].in_channels, out_classes, kernel_size=1, stride=1)
        model.aux_classifier[-1] = head
    elif cfg.PROBLEM.TYPE == 'INSTANCE_SEG':
        # MaskRCNN ships with 91 classes (COCO) by default
        if cfg.MODEL.N_CLASSES != 91:
            cls_score = torch.nn.Linear(in_features=1024, out_features=out_classes, bias=True)
            model.roi_heads.box_predictor.cls_score = cls_score
            mask_fcn_logits = torch.nn.Conv2d(model.roi_heads.mask_predictor.mask_fcn_logits.in_channels,
                out_classes, kernel_size=1, stride=1)
            model.roi_heads.mask_predictor.mask_fcn_logits = mask_fcn_logits
            print(f"Model's head changed from 91 to {out_classes} classes, so fine-tuning is required")

    # Check the network created
    model.to(device)
    if sample_size is None:
        if cfg.PROBLEM.NDIM == '2D':
            sample_size = (1, cfg.DATA.PATCH_SIZE[2], cfg.DATA.PATCH_SIZE[0], cfg.DATA.PATCH_SIZE[1])
        else:
            sample_size = (1, cfg.DATA.PATCH_SIZE[3], cfg.DATA.PATCH_SIZE[0], cfg.DATA.PATCH_SIZE[1],
                cfg.DATA.PATCH_SIZE[2])
    summary(model, input_size=sample_size, col_names=("input_size", "output_size", "num_params"),
        depth=10, device="cpu" if "cuda" not in device.type else "cuda")

    return model, model_torchvision_weights.transforms()
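
# Illustrative usage sketch (not part of the original module): builds a
# pretrained TorchVision classifier and adapts its head through
# build_torchvision_model(). Only the config keys read above are set; the
# concrete values (model name, patch size, class count) are assumptions
# chosen for the example.
def _example_build_torchvision_classifier():
    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.PROBLEM = CN()
    cfg.PROBLEM.TYPE = "CLASSIFICATION"
    cfg.PROBLEM.NDIM = "2D"
    cfg.DATA = CN()
    cfg.DATA.PATCH_SIZE = (224, 224, 3)  # (y, x, channels)
    cfg.MODEL = CN()
    cfg.MODEL.TORCHVISION_MODEL_NAME = "resnet50"
    cfg.MODEL.N_CLASSES = 2  # binary problem -> single output logit
    # Returns the adapted model plus the preprocessing transforms that the
    # pretrained weights expect
    model, transforms = build_torchvision_model(cfg, torch.device("cpu"))
    return model, transforms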