Source code for biapy.models

"""
This package (`biapy.models`) is responsible for building and managing deep learning models within the BiaPy framework.

It provides functionalities to:

1.  **Dynamically build models**: Select and instantiate various neural network architectures
    (e.g., U-Net, ResUNet, ViT, ConvNeXt variants, etc.) based on configuration settings.
2.  **Integrate with BioImage Model Zoo (BMZ)**: Facilitate the loading and compatibility
    checking of pre-trained models from the BioImage Model Zoo, enabling easy reuse
    of community-contributed models.
3.  **Extract model source code**: Collect the necessary source code for a given model
    and its dependencies, which is crucial for reproducibility and export functionalities.

The module handles different problem types (e.g., semantic segmentation, super-resolution,
classification) and adapts model configurations (e.g., 2D/3D, input/output channels,
normalization, dropout) accordingly.
"""

from importlib import import_module
import os
import math
import re
import torch
import torch.nn as nn
from torchinfo import summary
from typing import Iterable, Optional, Dict, Tuple, List, Callable, Any
from packaging.version import Version
from yacs.config import CfgNode as CN
import ast
from collections import deque, defaultdict
import requests
import sys
import importlib
import yaml

from bioimageio.spec.utils import download
from bioimageio.core.backends.pytorch_backend import load_torch_model
from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4
from bioimageio.spec.model.v0_5 import ModelDescr as ModelDescr_v0_5
from bioimageio.spec import load_description
from bioimageio.spec.model import v0_4, v0_5

def build_model(
    cfg: CN,
    output_channels: List[int],
    output_channel_info: List[str],
    head_activations: List[str],
    device: torch.device,
) -> Tuple[nn.Module, str, Dict[str, str], List[str], List[str], Dict, Optional[List[int]]]:
    """
    Build the model selected in the configuration.

    The module ``biapy.models.<architecture>`` is imported dynamically and all of its
    public names are injected into this module's globals, which is how the architecture
    classes referenced below (``U_Net``, ``ResUNet``, ...) are resolved.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.

    output_channels : List[int]
        Number of output channels for each head.

    output_channel_info : List[str]
        Information about each output channel.

    head_activations : List[str]
        Activation functions for each output head.

    device : Torch device
        Using device. Most commonly "cpu" or "cuda" for GPU, but also potentially "mps", "xpu", "xla" or "meta".

    Returns
    -------
    model : Pytorch model
        Selected model, already moved to ``device``.

    model_name : str
        ``__name__`` of the callable used to create the model.

    collected_sources : dict
        Source code of the model callable and its dependencies (see ``extract_model``).

    all_import_lines : list of str
        Merged external import lines needed by the collected sources.

    scanned_files : list of str
        Files scanned while collecting the sources.

    args : dict
        Keyword arguments the model was created with.

    network_stride : list of int or None
        Per-axis stride set for U-Net-like and HRNet models; ``None`` for the rest.

    Raises
    ------
    ValueError
        If ``MODEL.HRNET.VARIANT`` cannot be parsed or the architecture is not handled here.
    """
    # Import the model
    if "efficientnet" in cfg.MODEL.ARCHITECTURE.lower():
        modelname = "efficientnet"
    elif "hrnet" in cfg.MODEL.ARCHITECTURE.lower():
        modelname = "hrnet"
    else:
        modelname = str(cfg.MODEL.ARCHITECTURE).lower()
    mdl = import_module("biapy.models." + modelname)
    model_file = os.path.abspath(mdl.__file__)  # type: ignore
    names = [x for x in mdl.__dict__ if not x.startswith("_")]
    # Makes the architecture classes of the imported module available by name below.
    globals().update({k: getattr(mdl, k) for k in names})

    ndim = 3 if cfg.PROBLEM.NDIM == "3D" else 2
    network_stride = None

    # Put again the specific model name
    if "hrnet" in cfg.MODEL.ARCHITECTURE.lower():
        modelname = cfg.MODEL.ARCHITECTURE.lower()

    # Model building
    # NOTE: the class names are looked up lazily, branch by branch, because only the
    # names of the single imported architecture module are guaranteed to exist.
    if modelname in [
        "unet",
        "resunet",
        "resunet++",
        "seunet",
        "resunet_se",
        "attention_unet",
        "unext_v1",
        "unext_v2",
    ]:
        separated_decoders = False
        if cfg.PROBLEM.TYPE == "IMAGE_TO_IMAGE" and cfg.PROBLEM.IMAGE_TO_IMAGE.SEPARATED_DECODERS_PER_HEAD:
            separated_decoders = True
        elif cfg.PROBLEM.TYPE == "INSTANCE_SEG" and cfg.PROBLEM.INSTANCE_SEG.SEPARATED_DECODERS_PER_HEAD:
            separated_decoders = True

        args = dict(
            image_shape=cfg.DATA.PATCH_SIZE,
            activation=cfg.MODEL.ACTIVATION.lower(),
            feature_maps=cfg.MODEL.FEATURE_MAPS,
            drop_values=cfg.MODEL.DROPOUT_VALUES,
            normalization=cfg.MODEL.NORMALIZATION,
            k_size=cfg.MODEL.KERNEL_SIZE,
            upsample_layer=cfg.MODEL.UPSAMPLE_LAYER,
            yx_down=cfg.MODEL.YX_DOWN,
            z_down=cfg.MODEL.Z_DOWN,
            output_channels=output_channels,
            output_channel_info=output_channel_info,
            head_activations=head_activations,
            explicit_activations=False,
            contrast=cfg.LOSS.CONTRAST.ENABLE,
            contrast_proj_dim=cfg.LOSS.CONTRAST.PROJ_DIM,
            separated_decoders=separated_decoders,
            isotropy=cfg.MODEL.ISOTROPY,
            larger_io=cfg.MODEL.LARGER_IO,
            return_one_tensor=False,
        )
        if modelname == "unet":
            callable_model = U_Net  # type: ignore
        elif modelname == "resunet":
            callable_model = ResUNet  # type: ignore
        elif modelname == "resunet++":
            callable_model = ResUNetPlusPlus  # type: ignore
        elif modelname == "attention_unet":
            callable_model = Attention_U_Net  # type: ignore
        elif modelname == "seunet":
            callable_model = SE_U_Net  # type: ignore
        elif modelname == "resunet_se":
            callable_model = ResUNet_SE  # type: ignore
        elif modelname in ["unext_v1", "unext_v2"]:
            args["cn_layers"] = cfg.MODEL.CONVNEXT_LAYERS
            args["stochastic_depth_prob"] = cfg.MODEL.CONVNEXT_SD_PROB
            args["stem_k_size"] = cfg.MODEL.CONVNEXT_STEM_K_SIZE
            del args["activation"]  # ConvNeXt uses GELU activation by default
            del args["drop_values"]  # ConvNeXt uses DropPath for regularization, not standard dropout
            del args["normalization"]  # ConvNeXt uses LayerNorm, not BatchNorm or GroupNorm
            del args["k_size"]  # ConvNeXt uses 7x7 kernels in the early layers, but this is fixed in the model definition
            del args["larger_io"]  # ConvNeXt does not use larger input/output layers, but this is fixed in the model definition
            if modelname == "unext_v1":
                callable_model = U_NeXt_V1  # type: ignore
                args["layer_scale"] = cfg.MODEL.CONVNEXT_LAYER_SCALE
            else:
                callable_model = U_NeXt_V2  # type: ignore

        if cfg.PROBLEM.TYPE == "SUPER_RESOLUTION":
            args["upsampling_factor"] = cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING
            args["upsampling_position"] = cfg.MODEL.UNET_SR_UPSAMPLE_POSITION

        # NOTE(review): the nesting of the next three lines was reconstructed from a
        # whitespace-collapsed listing; confirm the unit stride is meant for every
        # U-Net-like model and not only for SUPER_RESOLUTION.
        network_stride = [1, 1]
        if ndim == 3:
            network_stride = [1] + network_stride
        model = callable_model(**args)
    elif "hrnet" in modelname:
        args = dict(
            image_shape=cfg.DATA.PATCH_SIZE,
            normalization=cfg.MODEL.NORMALIZATION,
            output_channels=output_channels,
            contrast=cfg.LOSS.CONTRAST.ENABLE,
            contrast_proj_dim=cfg.LOSS.CONTRAST.PROJ_DIM,
            head_type=cfg.MODEL.HRNET.HEAD_TYPE,
            head_activations=head_activations,
            output_channel_info=output_channel_info,
            explicit_activations=False,
            activation=cfg.MODEL.ACTIVATION.lower(),
            return_one_tensor=False,
        )

        variant = str(cfg.MODEL.HRNET.VARIANT).lower()
        if variant == "custom":
            # Pass the full custom configuration exactly as defined in the yaml/config
            args["cfg"] = cfg.MODEL.HRNET
        else:
            # Extract base channels directly from the variant string
            try:
                base_channels = int(variant.replace("w", ""))  # Remove 'w' prefix if present and convert to int
            except ValueError as err:
                raise ValueError(
                    f"Invalid MODEL.HRNET.VARIANT: '{variant}'. "
                    "Expected 'W18', 'W32', 'W48', 'W64', or 'custom'."
                ) from err

            # Auto-generate standard HRNet topology
            num_stages = 3
            num_modules = [1, 4, 3]
            num_branches = [2, 3, 4]

            # Procedurally generate blocks and channels based on the number of branches
            num_blocks = [[4] * b for b in num_branches]
            num_channels = [[base_channels * (2**i) for i in range(b)] for b in num_branches]

            args["cfg"] = {
                'Z_DOWN': cfg.MODEL.HRNET.Z_DOWN,
                'YX_DOWN': cfg.MODEL.HRNET.YX_DOWN,
                'BLOCK_TYPE': cfg.MODEL.HRNET.BLOCK_TYPE,
                'NUM_STAGES': num_stages,
                'NUM_MODULES': num_modules,
                'NUM_BRANCHES': num_branches,
                'NUM_BLOCKS': num_blocks,
                'NUM_CHANNELS': num_channels,
            }

        callable_model = HighResolutionNet  # type: ignore
        model = callable_model(**args)

        # Calculate YX total stride (e.g., [2, 2, 2] -> 8)
        yx_schedule = args["cfg"].get("YX_DOWN", [2, 2, 2])
        yx_total_stride = math.prod(yx_schedule)
        network_stride = [yx_total_stride, yx_total_stride]
        if ndim == 3:
            z_schedule = args["cfg"].get("Z_DOWN")
            z_total_stride = math.prod(z_schedule)
            network_stride = [z_total_stride] + network_stride
    elif "stunet" in modelname:
        callable_model = build_stunet  # type: ignore
        args = dict(
            image_shape=cfg.DATA.PATCH_SIZE,
            output_channels=output_channels,
            variant=cfg.MODEL.STUNET.VARIANT,
            deep_supervision=False,
            explicit_activations=False,
            head_activations=head_activations,
            output_channel_info=output_channel_info,
            return_one_tensor=False,
            pretrained=cfg.MODEL.STUNET.PRETRAINED,
        )
        model = build_stunet(**args)  # type: ignore
    elif modelname == "simple_cnn":
        args = dict(
            image_shape=cfg.DATA.PATCH_SIZE,
            activation=cfg.MODEL.ACTIVATION.lower(),
            n_classes=cfg.DATA.N_CLASSES,
        )
        model = simple_CNN(**args)  # type: ignore
        callable_model = simple_CNN  # type: ignore
    elif "efficientnet" in modelname:
        args = dict(efficientnet_name=cfg.MODEL.ARCHITECTURE.lower(), n_classes=cfg.DATA.N_CLASSES)
        model = efficientnet(**args)  # type: ignore
        callable_model = efficientnet  # type: ignore
    elif modelname == "vit":
        args = dict(
            img_size=cfg.DATA.PATCH_SIZE[0],
            patch_size=cfg.MODEL.VIT_TOKEN_SIZE,
            in_chans=cfg.DATA.PATCH_SIZE[-1],
            ndim=ndim,
            num_classes=cfg.DATA.N_CLASSES,
            explicit_activations=False,
            head_activations=head_activations,
            output_channel_info=output_channel_info,
        )
        if cfg.MODEL.VIT_MODEL == "custom":
            args2 = dict(
                embed_dim=cfg.MODEL.VIT_EMBED_DIM,
                depth=cfg.MODEL.VIT_NUM_LAYERS,
                num_heads=cfg.MODEL.VIT_NUM_HEADS,
                mlp_ratio=cfg.MODEL.VIT_MLP_RATIO,
                drop_rate=cfg.MODEL.DROPOUT_VALUES[0],
            )
            args.update(args2)
            model = VisionTransformer(**args)  # type: ignore
            callable_model = VisionTransformer  # type: ignore
        else:
            # SECURITY NOTE(review): ``eval`` runs a string taken from the configuration.
            # It is kept because callers may rely on it, but it must never be fed an
            # untrusted config; consider replacing it with an explicit name lookup.
            callable_model = eval(cfg.MODEL.VIT_MODEL)
            model = callable_model(**args)
    elif modelname == "multiresunet":
        args = dict(
            input_channels=cfg.DATA.PATCH_SIZE[-1],
            ndim=ndim,
            alpha=1.67,
            z_down=cfg.MODEL.Z_DOWN,
            output_channels=output_channels,
            explicit_activations=False,
        )
        if cfg.PROBLEM.TYPE == "SUPER_RESOLUTION":
            args["upsampling_factor"] = cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING
            args["upsampling_position"] = cfg.MODEL.UNET_SR_UPSAMPLE_POSITION
        model = MultiResUnet(**args)  # type: ignore
        callable_model = MultiResUnet  # type: ignore
    elif modelname == "unetr":
        args = dict(
            input_shape=cfg.DATA.PATCH_SIZE,
            patch_size=cfg.MODEL.VIT_TOKEN_SIZE,
            embed_dim=cfg.MODEL.VIT_EMBED_DIM,
            depth=cfg.MODEL.VIT_NUM_LAYERS,
            num_heads=cfg.MODEL.VIT_NUM_HEADS,
            mlp_ratio=cfg.MODEL.VIT_MLP_RATIO,
            num_filters=cfg.MODEL.UNETR_VIT_NUM_FILTERS,
            output_channels=output_channels,
            output_channel_info=output_channel_info,
            head_activations=head_activations,
            explicit_activations=False,
            decoder_activation=cfg.MODEL.ACTIVATION.lower(),
            ViT_hidd_mult=cfg.MODEL.UNETR_VIT_HIDD_MULT,
            normalization=cfg.MODEL.NORMALIZATION,
            dropout=cfg.MODEL.DROPOUT_VALUES[0],
            k_size=cfg.MODEL.KERNEL_SIZE,
            contrast=cfg.LOSS.CONTRAST.ENABLE,
            contrast_proj_dim=cfg.LOSS.CONTRAST.PROJ_DIM,
            return_one_tensor=False,
        )
        model = UNETR(**args)  # type: ignore
        callable_model = UNETR  # type: ignore
    elif modelname == "edsr":
        args = dict(
            ndim=ndim,
            num_filters=64,
            num_of_residual_blocks=16,
            upsampling_factor=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
            num_channels=cfg.DATA.PATCH_SIZE[-1],
        )
        # Unpack like every other architecture; passing the dict positionally would
        # hand the whole mapping to the first constructor parameter.
        model = EDSR(**args)  # type: ignore
        callable_model = EDSR  # type: ignore
    elif modelname == "rcan":
        args = dict(
            ndim=ndim,
            filters=cfg.MODEL.RCAN_CONV_FILTERS,
            scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
            num_rg=cfg.MODEL.RCAN_RG_BLOCK_NUM,
            num_rcab=cfg.MODEL.RCAN_RCAB_BLOCK_NUM,
            reduction=cfg.MODEL.RCAN_REDUCTION_RATIO,
            num_channels=cfg.DATA.PATCH_SIZE[-1],
            upscaling_layer=cfg.MODEL.RCAN_UPSCALING_LAYER,
        )
        model = rcan(**args)  # type: ignore
        callable_model = rcan  # type: ignore
    elif modelname == "dfcan":
        args = dict(
            ndim=ndim,
            input_shape=cfg.DATA.PATCH_SIZE,
            scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
            n_ResGroup=4,
            n_RCAB=4,
        )
        model = DFCAN(**args)  # type: ignore
        callable_model = DFCAN  # type: ignore
    elif modelname == "wdsr":
        args = dict(
            scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
            num_filters=32,
            num_res_blocks=8,
            res_block_expansion=6,
            num_channels=cfg.DATA.PATCH_SIZE[-1],
        )
        model = wdsr(**args)  # type: ignore
        callable_model = wdsr  # type: ignore
    elif modelname == "mae":
        args = dict(
            img_size=cfg.DATA.PATCH_SIZE[0],
            patch_size=cfg.MODEL.VIT_TOKEN_SIZE,
            in_chans=cfg.DATA.PATCH_SIZE[-1],
            ndim=ndim,
            embed_dim=cfg.MODEL.VIT_EMBED_DIM,
            depth=cfg.MODEL.VIT_NUM_LAYERS,
            num_heads=cfg.MODEL.VIT_NUM_HEADS,
            decoder_embed_dim=512,
            decoder_depth=8,
            decoder_num_heads=16,
            mlp_ratio=cfg.MODEL.VIT_MLP_RATIO,
            masking_type=cfg.MODEL.MAE_MASK_TYPE,
            mask_ratio=cfg.MODEL.MAE_MASK_RATIO,
            return_just_preds=False,
            device=device.type,
        )
        model = MaskedAutoencoderViT(**args)  # type: ignore
        callable_model = MaskedAutoencoderViT  # type: ignore
    else:
        # Without this guard an unhandled module name would only surface later as an
        # unbound-variable error on ``model``.
        raise ValueError(f"Unsupported model architecture: '{cfg.MODEL.ARCHITECTURE}'")

    # Check the network created
    model.to(device)
    if cfg.PROBLEM.NDIM == "2D":
        sample_size = (
            1,
            cfg.DATA.PATCH_SIZE[2],
            cfg.DATA.PATCH_SIZE[0],
            cfg.DATA.PATCH_SIZE[1],
        )
    else:
        sample_size = (
            1,
            cfg.DATA.PATCH_SIZE[3],
            cfg.DATA.PATCH_SIZE[0],
            cfg.DATA.PATCH_SIZE[1],
            cfg.DATA.PATCH_SIZE[2],
        )

    summary(
        model,
        input_size=sample_size,
        col_names=("input_size", "output_size", "num_params"),
        depth=10,
        device=device.type,
    )

    # Queue for recursive dependency tracing
    dependency_queue = deque()
    dependency_queue.append(callable_model)
    collected_sources, all_import_lines, scanned_files = extract_model(dependency_queue, model_file)
    all_import_lines = merge_import_lines(all_import_lines)

    # Special handling for instance segmentation models with sigma outputs
    if cfg.PROBLEM.TYPE == "INSTANCE_SEG" and "E_sigma" in cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS:
        init_embedding_output(model, n_sigma=2 if cfg.PROBLEM.NDIM == "2D" else 3)

    return model, str(callable_model.__name__), collected_sources, all_import_lines, scanned_files, args, network_stride  # type: ignore
def init_embedding_output(model: nn.Module, n_sigma: int = 2):
    """
    Initialize the output layer of the model for embedding.

    The first ``n_sigma`` output channels of ``model.last_block`` get zero weights and
    zero bias; the following ``n_sigma`` channels get zero weights and a bias of one.
    The update is done in place under ``torch.no_grad()``.

    Parameters
    ----------
    model : nn.Module
        The model whose output layer needs to be initialized. It must expose a
        ``last_block`` attribute with ``weight`` and ``bias`` tensors.

    n_sigma : int
        Number of sigma channels to initialize.

    Raises
    ------
    ValueError
        If the layer cannot be initialized (e.g. ``last_block`` is missing). The
        original error is attached as ``__cause__``.
    """
    try:
        with torch.no_grad():
            last_block = model.last_block
            print("Initialize last layer with size: ", last_block.weight.size())
            print("*************************")
            # Slicing only the channel axis keeps this independent of the kernel rank
            # (2D and 3D convolutions alike).
            last_block.weight[0:n_sigma].fill_(0)
            last_block.bias[0:n_sigma].fill_(0)

            last_block.weight[n_sigma : n_sigma + n_sigma].fill_(0)
            last_block.bias[n_sigma : n_sigma + n_sigma].fill_(1)
    except Exception as err:
        # ``Exception`` (not a bare ``except``) so KeyboardInterrupt/SystemExit still propagate.
        raise ValueError("Could not initialize embedding output layer. Check the model structure.") from err
def extract_model(dependency_queue: deque, model_file: str) -> Tuple[Dict[str, str], List[str], List[str]]:
    """
    Extract the source code of the model and its dependencies.

    The returned sources are ordered so that every definition appears after the
    definitions it references (post-order traversal of the reference graph), which
    keeps base classes ahead of their subclasses.

    Parameters
    ----------
    dependency_queue : deque
        Queue of objects (anything with a ``__name__``) whose definitions must be
        collected. The queue is consumed.

    model_file : str
        Path to the main model file.

    Returns
    -------
    collected_sources : dict
        Dictionary mapping definition name to source code, dependencies first and
        each requested object after everything it depends on.

    all_import_lines : list of str
        Sorted list of all external (non-``biapy``) import lines found in the scanned files.

    scanned_files : list of str
        List of all files that were scanned for dependencies.
    """
    from importlib.util import find_spec

    visited_files = set()
    all_import_lines = set()
    scanned_files: List[str] = []

    # {name: source_code} for all top-level class/function definitions
    name_to_source: Dict[str, str] = {}

    # === Step 1: Scan all relevant files and build name -> source map ===
    pending_files = [model_file]
    while pending_files:
        filepath = os.path.abspath(pending_files.pop())
        if filepath in visited_files:
            continue
        visited_files.add(filepath)
        scanned_files.append(filepath)

        # Python source files default to UTF-8; do not depend on the locale encoding.
        with open(filepath, "r", encoding="utf-8") as f:
            source_lines = f.readlines()
        tree = ast.parse("".join(source_lines), filename=filepath)

        biapy_module_names = []
        for node in ast.walk(tree):
            # Import parsing
            if isinstance(node, ast.Import):
                for alias in node.names:
                    mod = alias.name
                    if mod.startswith("biapy"):
                        biapy_module_names.append(mod)
                    else:
                        all_import_lines.add(f"import {mod}" + (f" as {alias.asname}" if alias.asname else ""))
            elif isinstance(node, ast.ImportFrom):
                mod = node.module
                if not mod:
                    continue
                if mod.startswith("biapy"):
                    biapy_module_names.append(mod)
                else:
                    names = ", ".join(
                        f"{alias.name}" + (f" as {alias.asname}" if alias.asname else "") for alias in node.names
                    )
                    all_import_lines.add(f"from {mod} import {names}")

        # Extract all top-level classes and functions and map name -> source.
        # The AST gives the exact span: an indentation heuristic would cut a
        # multi-line signature at its closing parenthesis.
        for top_node in tree.body:
            if isinstance(top_node, (ast.FunctionDef, ast.ClassDef)):
                # ``lineno`` points at the ``def``/``class`` keyword, so start at the
                # first decorator (if any) to keep decorators in the extracted source.
                first_line = min([top_node.lineno] + [dec.lineno for dec in top_node.decorator_list])
                name_to_source[top_node.name] = "".join(source_lines[first_line - 1 : top_node.end_lineno])

        # Follow BiaPy module imports (if file-based)
        resolved = set()
        for module_name in biapy_module_names:
            if module_name in resolved:
                continue
            resolved.add(module_name)
            try:
                spec = find_spec(module_name)
                if spec and spec.origin and os.path.isfile(spec.origin):
                    pending_files.append(spec.origin)
            except Exception as e:
                print(f"Warning: Failed to resolve {module_name}: {e}")

    # === Step 2: Traverse the reference graph, emitting dependencies first ===
    collected_sources: Dict[str, str] = {}
    visited_names = set()

    def _referenced_names(source: str) -> List[str]:
        """Return the identifiers used as plain names inside ``source``."""
        return [n.id for n in ast.walk(ast.parse(source)) if isinstance(n, ast.Name)]

    def _collect(name: str) -> None:
        """Store ``name`` after recursively storing every known definition it references."""
        if name in visited_names:
            return
        # Marked before descending so reference cycles terminate.
        visited_names.add(name)

        source = name_to_source.get(name)
        if not source:
            print(f"Warning: Source not found for {name}")
            return

        for dep_name in _referenced_names(source):
            if dep_name in name_to_source:
                _collect(dep_name)
        collected_sources[name] = source

    while dependency_queue:
        _collect(dependency_queue.popleft().__name__)

    return collected_sources, sorted(all_import_lines), scanned_files
def merge_import_lines(import_lines: List[str]) -> List[str]:
    """
    Merge import lines, collapsing ``from X import ...`` statements per module.

    Parameters
    ----------
    import_lines : list of str
        List of import lines to be merged.

    Returns
    -------
    merged : list of str
        Sorted import lines. All names imported from the same module are joined,
        alphabetically, into one ``from`` statement; every other line is kept as-is
        (after stripping) and de-duplicated.
    """
    names_by_module = defaultdict(set)
    untouched = set()

    for raw_line in import_lines:
        statement = raw_line.strip()
        if not statement.startswith("from "):
            # Plain ``import ...`` statements and anything unrecognised pass through.
            untouched.add(statement)
            continue

        try:
            pieces = statement.split(" import ")
            module = pieces[0][5:].strip()  # drop the leading "from "
            # Evaluated before touching the defaultdict so a malformed line
            # does not leave an empty entry behind.
            imported = [part.strip() for part in pieces[1].split(",")]
        except Exception as e:
            print(f"Warning: could not parse import line '{statement}': {e}")
        else:
            names_by_module[module].update(imported)

    merged = [f"from {module} import {', '.join(sorted(found))}" for module, found in names_by_module.items()]
    merged += untouched
    merged.sort()
    return merged
def adapt_bmz_model_kwargs(model_kwargs: Dict, model_to_consume: bool) -> Dict:
    """
    Return a copy of BMZ model arguments with BiaPy's output-related flags adjusted.

    Parameters
    ----------
    model_kwargs : dict
        Dictionary of model arguments to be adapted. It is not modified.

    model_to_consume : bool
        Whether the model is being adapted for consumption (True) or for exporting (False).

    Returns
    -------
    adapted_args : dict
        Shallow copy of ``model_kwargs`` in which ``explicit_activations``,
        ``return_just_preds`` and ``return_one_tensor`` are set to
        ``not model_to_consume`` when they are present.
    """
    flag_value = not model_to_consume
    adapted_args = model_kwargs.copy()
    for flag in ("explicit_activations", "return_just_preds", "return_one_tensor"):
        # Only flags the model already declares are touched; none are added.
        if flag in model_kwargs:
            adapted_args[flag] = flag_value
    return adapted_args
def get_bmz_model_kwargs(model: ModelDescr_v0_4 | ModelDescr_v0_5) -> Dict:
    """
    Get the architecture keyword arguments stored in a BMZ model's PyTorch state dict weights.

    Parameters
    ----------
    model : ModelDescr_v0_4 | ModelDescr_v0_5
        BMZ model description.

    Returns
    -------
    model_kwargs : dict
        The ``kwargs`` of the weight entry for the v0.4 format, or the ``kwargs`` of
        its ``architecture`` for the v0.5 format.

    Raises
    ------
    ValueError
        If the weight entry is neither a v0.4 nor a v0.5 PyTorch state dict description.
    """
    state_dict_descr = model.weights.pytorch_state_dict
    assert state_dict_descr

    if isinstance(state_dict_descr, v0_4.PytorchStateDictWeightsDescr):
        return state_dict_descr.kwargs
    if isinstance(state_dict_descr, v0_5.PytorchStateDictWeightsDescr):
        # In this format the kwargs are stored under the architecture entry.
        return state_dict_descr.architecture.kwargs
    raise ValueError("Unsupported weight specification type in BMZ model description.")
def update_bmz_model_kwargs_to_biapy(new_biapy_model_kwargs: Dict, model: ModelDescr_v0_4 | ModelDescr_v0_5):
    """
    Overlay BiaPy model arguments onto the arguments stored in a BMZ model description.

    Parameters
    ----------
    new_biapy_model_kwargs : dict
        Dictionary of model arguments as expected by BiaPy's model building functions.

    model : ModelDescr
        BMZ model RDF that contains all the information of the model.

    Returns
    -------
    bmz_model_kwargs : dict
        The dictionary returned by ``get_bmz_model_kwargs(model)``, updated in place:
        differing values are replaced and missing arguments are added. Each change is
        printed.
    """
    bmz_model_kwargs = get_bmz_model_kwargs(model)

    for arg_name, new_value in new_biapy_model_kwargs.items():
        if arg_name not in bmz_model_kwargs:
            print(f" Adding new argument '{arg_name}' with value '{new_value}'")
        elif bmz_model_kwargs[arg_name] != new_value:
            print(f" Updating BMZ model argument '{arg_name}' from '{bmz_model_kwargs[arg_name]}' to '{new_value}'")
        else:
            # Equal values are left untouched (the stored object is not replaced).
            continue
        bmz_model_kwargs[arg_name] = new_value

    return bmz_model_kwargs
def build_bmz_model(cfg: CN, model: ModelDescr_v0_4 | ModelDescr_v0_5, device: torch.device) -> nn.Module:
    """
    Instantiate a Bioimage Model Zoo (BMZ) model from its PyTorch state dict weights.

    The stored architecture arguments are first adapted for consumption with
    ``adapt_bmz_model_kwargs`` (written back onto the weight description), then the
    model is loaded with its weights and a layer summary is printed.

    Parameters
    ----------
    cfg : YACS configuration
        Running configuration. ``PROBLEM.NDIM`` and ``DATA.PATCH_SIZE`` define the
        input size used for the summary.

    model : ModelDescr
        BMZ model RDF that contains all the information of the model.

    device : Torch device
        Device used.

    Returns
    -------
    model_instance : Torch model
        Torch model.
    """
    assert model.weights.pytorch_state_dict
    state_dict_descr = model.weights.pytorch_state_dict

    if isinstance(state_dict_descr, v0_4.PytorchStateDictWeightsDescr):
        state_dict_descr.kwargs = adapt_bmz_model_kwargs(state_dict_descr.kwargs, model_to_consume=True)
    elif isinstance(state_dict_descr, v0_5.PytorchStateDictWeightsDescr):
        architecture = state_dict_descr.architecture
        architecture.kwargs = adapt_bmz_model_kwargs(architecture.kwargs, model_to_consume=True)

    model_instance = load_torch_model(state_dict_descr, load_state=True, devices=[device])

    # Check the network created: batch of one, last patch axis moved right after the batch axis
    patch = cfg.DATA.PATCH_SIZE
    if cfg.PROBLEM.NDIM == "2D":
        sample_size = (1, patch[2], patch[0], patch[1])
    else:
        sample_size = (1, patch[3], patch[0], patch[1], patch[2])

    summary(
        model_instance,
        input_size=sample_size,
        col_names=("input_size", "output_size", "num_params"),
        depth=10,
        device=device.type,
    )
    return model_instance
def is_biapy_model(model: ModelDescr_v0_4 | ModelDescr_v0_5) -> bool:
    """
    Check if a model is a BiaPy model.

    The checks are done in this order and the first hit wins:

    1) an author whose GitHub username is "danifranco" or whose name is one of
       "Daniel Franco-Barranco", "Daniel Franco Barranco" or "Daniel Franco".
    2) the "biapy" tag among the model tags.
    3) a citation whose text is exactly "BiaPy: accessible deep learning on bioimages".

    Parameters
    ----------
    model : ModelDescr_v0_4 | ModelDescr_v0_5
        The model to check.

    Returns
    -------
    bool
        True if the model is a BiaPy model, False otherwise. Any error raised while
        inspecting the model is printed as a warning and yields False.
    """
    author_names = ("Daniel Franco-Barranco", "Daniel Franco Barranco", "Daniel Franco")
    biapy_citation = "BiaPy: accessible deep learning on bioimages"

    try:
        # Check authors for "danifranco" GitHub username or "Daniel Franco" name
        for author in model.authors:
            github_username = getattr(author, "github_user", "")
            author_name = getattr(author, "name", "")
            if github_username == "danifranco" or author_name in author_names:
                return True

        if "biapy" in getattr(model, "tags", [""]):
            return True

        if any(getattr(reference, "text", "") == biapy_citation for reference in model.cite):
            return True
    except Exception as e:
        print(f"Warning: Could not determine if model is a BiaPy model due to error: {e}")

    return False
def find_bmz_models(
    model_ID: Optional[str] = None,
    url: str = "https://hypha.aicell.io/bioimage-io/artifacts/bioimage.io/children?limit=1000000",
    timeout: int = 30,
):
    """
    Query the BioImage.IO Hypha API and return the entries of type ``model``.

    When ``model_ID`` is given, only models whose nickname, id or rdf_source contains
    it (case-insensitive substring match) are returned.

    Parameters
    ----------
    model_ID : str
        Model identifier. It can be either its ``DOI`` or ``nickname``. Leave it as None to
        get all available models.

    url : str
        URL to the BioImage.IO Hypha API endpoint to query for models.

    timeout : int
        Timeout for the HTTP request in seconds.

    Returns
    -------
    out : list of dict
        List of dictionaries containing model information. Each dictionary has the following keys:
        `id`, `alias`, `nickname`, `rdf_source`, `version`, `format_version`, `artifact_path` (same
        value as `id`), `urls` (which contains `covers` and `documentation`), and `raw` (the original
        item from the API response).

    Raises
    ------
    requests.HTTPError
        If the API answers with an error status (via ``raise_for_status``).
    """
    needle = str(model_ID).lower() if model_ID else None

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    entries = response.json()
    # The payload is either the list itself or a dict wrapping it under "children".
    if isinstance(entries, dict) and "children" in entries:
        entries = entries["children"]

    out = []
    for entry in entries or []:
        if (entry or {}).get("type") != "model":
            continue

        # Pull common fields defensively from the manifest config
        manifest = entry.get("manifest") or {}
        config = manifest.get("config") or {}
        bioimageio_cfg = config.get("bioimageio") or {}

        nickname = bioimageio_cfg.get("nickname") or entry.get("alias")
        rdf_source = bioimageio_cfg.get("rdf_source") or bioimageio_cfg.get("source")  # some deployments use 'source'
        version = bioimageio_cfg.get("version") or config.get("version") or entry.get("version")
        format_version = bioimageio_cfg.get("format_version") or config.get("format_version")

        if needle:
            # Fields the identifier is searched in
            searchable = (nickname or "", entry.get("id") or "", rdf_source or "")
            if all(needle not in text.lower() for text in searchable):
                continue

        record = {
            "artifact_path": entry.get("id"),  # usable as path for other calls
            "id": entry.get("id"),
            "alias": entry.get("alias"),
            "nickname": nickname,
            "rdf_source": rdf_source,
            "version": version,
            "format_version": format_version,
            "urls": {
                "covers": (bioimageio_cfg.get("thumbnails") or {}),
                "documentation": bioimageio_cfg.get("documentation"),
            },
            "raw": entry,  # keep the original item in case you need more fields
        }
        out.append(record)

    return out
def check_bmz_args(
    model_ID: str,
    cfg: CN,
) -> Tuple[List, Dict]:
    """
    Check that the BMZ model requested by the user exists and can be used in BiaPy.

    Parameters
    ----------
    model_ID : str
        Model identifier. It can be either its ``DOI`` or ``nickname``.

    cfg : YACS configuration
        Running configuration. ``PROBLEM.TYPE``, ``PROBLEM.NDIM`` and ``DATA.N_CLASSES``
        are forwarded as the workflow specification.

    Returns
    -------
    preproc_info : list
        Preprocessing information returned by ``check_bmz_model_compatibility``.

    opts : dict
        Options returned by ``check_bmz_model_compatibility``.

    Raises
    ------
    ValueError
        If no model or more than one model matches ``model_ID``, or if the model is
        reported as not compatible.
    """
    # Checking BMZ model compatibility using the available model list provided by BMZ
    candidates = find_bmz_models(model_ID)
    if not candidates:
        raise ValueError(f"No model found with the provided DOI/name: {model_ID}")
    if len(candidates) > 1:
        raise ValueError(f"More than one model found with the provided DOI/name ({model_ID}). Contact BiaPy team.")

    model_dict = candidates[0]
    workflow_specs = {
        "workflow_type": cfg.PROBLEM.TYPE,
        "ndim": cfg.PROBLEM.NDIM,
        "nclasses": cfg.DATA.N_CLASSES,
    }

    preproc_info, error, error_message, opts = check_bmz_model_compatibility(
        model_dict,
        workflow_specs=workflow_specs,
    )
    if error:
        raise ValueError(f"Model {model_ID} can not be used in BiaPy. Message:\n{error_message}\n")

    return preproc_info, opts
def check_bmz_model_compatibility(
    model_rdf: Dict,
    workflow_specs: Optional[Dict] = None,
) -> Tuple[List, bool, str, Dict]:
    """
    Check one model compatibility with BiaPy by looking at its RDF file provided by BMZ.

    This function is the one used in BMZ's continuous integration with BiaPy. Every
    incompatibility found in the RDF is reported through the returned tuple
    (``error=True`` plus a message).

    Parameters
    ----------
    model_rdf : dict
        BMZ model RDF that contains all the information of the model. When the RDF is
        nested under ``model_rdf["raw"]["manifest"]`` that inner dict is the one inspected.

    workflow_specs : dict, optional
        Specifications of the workflow. Keys read: ``"workflow_type"``, ``"ndim"`` and
        ``"nclasses"``. If not provided all possible models will be considered.

    Returns
    -------
    preproc_info : list or dict
        Preprocessing description read from the model input. It is the list found in the
        RDF (or an empty list) while no entry has been processed, and the last processed
        entry once the preprocessing loop has run.

    error : bool
        Whether there is a problem to consume the model in BiaPy or not.

    reason_message : str
        Reason why the model can not be consumed if there is any (empty string otherwise).

    opts : dict
        BiaPy configuration values (e.g. ``"DATA.N_CLASSES"``, ``"DATA.PATCH_SIZE"``,
        ``"DATA.NORMALIZATION.TYPE"``) deduced from the RDF.
    """

    # --------- helpers ---------
    def g(d, *ks, default=None):
        # Nested lookup: walk ``ks`` through dicts, returning ``default`` as soon as a
        # level is not a dict or does not hold the key.
        cur = d
        for k in ks:
            if isinstance(cur, dict) and k in cur:
                cur = cur[k]
            else:
                return default
        return cur

    m = g(model_rdf, "raw", "manifest", default=model_rdf) or model_rdf

    specific_workflow = "all" if workflow_specs is None else workflow_specs["workflow_type"]
    specific_dims = "all" if workflow_specs is None else workflow_specs["ndim"]
    ref_classes = "all" if workflow_specs is None else workflow_specs["nclasses"]

    preproc_info: List = []
    opts = {}

    # --------- Accept only PyTorch state dict models with a single input ---------
    weights = g(m, "weights", "pytorch_state_dict")
    inputs = g(m, "inputs") or []
    if not (isinstance(weights, dict) and weights):
        reason_message = f"[{specific_workflow}] pytorch_state_dict not found in model RDF\n"
        return preproc_info, True, reason_message, opts
    if not (isinstance(inputs, list) and len(inputs) == 1):
        reason_message = f"[{specific_workflow}] Model needs to have a single input.\n"
        return preproc_info, True, reason_message, opts

    # Model format version (defaults to 0.5 for your legacy logic)
    model_version = Version("0.5")
    fmt = g(m, "format_version")
    if isinstance(fmt, str):
        try:
            model_version = Version(fmt)
        except Exception:
            # Deliberate best-effort: an unparsable version string keeps the default
            pass

    # --------- Extract model kwargs ---------
    model_kwargs = None
    if "kwargs" in weights:
        model_kwargs = weights["kwargs"]
    elif "architecture" in weights and isinstance(weights["architecture"], dict):
        model_kwargs = weights["architecture"].get("kwargs", None)
    if model_kwargs is None:
        return preproc_info, True, f"[{specific_workflow}] Couldn't extract kwargs from model description.\n", opts

    # --------- Problem type via tags ---------
    tags = g(m, "tags", default=[]) or []

    if (specific_workflow in ["all", "SEMANTIC_SEG"]) and (
        "semantic-segmentation" in tags or ("segmentation" in tags and "instance-segmentation" not in tags)
    ):
        # classes: first matching kwarg wins, -1 means "not found"
        classes = -1
        for k in ("n_classes", "out_channels", "output_channels", "classes"):
            if k in model_kwargs:
                classes = model_kwargs[k]
                break
        if isinstance(classes, list):
            classes = classes[-1]
        if not isinstance(classes, int):
            reason_message = (
                f"[{specific_workflow}] 'DATA.N_CLASSES' not extracted. Obtained {classes}. Please check it!\n"
            )
            return preproc_info, True, reason_message, opts

        if (
            classes == -1
            and "architecture" in weights
            and isinstance(weights["architecture"], dict)
            and ("callable" in weights["architecture"] or "source" in weights["architecture"])
        ):
            # Check if the model is one of the known architectures and assume it returns 1 class (as is the default in BiaPy)
            for arch in [weights["architecture"].get("callable", None), weights["architecture"].get("source", None)]:
                if arch is not None:
                    arch = str(arch).lower().replace(".py", "")
                    if arch in [
                        "unet",
                        "resunet",
                        "resunet++",
                        "seunet",
                        "attention_unet",
                        "resunet_se",
                        "unetr",
                        "multiresunet",
                        "unext_v1",
                        "unext_v2",
                        "hrnet",
                    ]:
                        classes = 1
                if classes != -1:
                    print(f"[BMZ] Detected BiaPy model ({arch}) so assuming 1 as the class output, which is the default in BiaPy")
                    break

        if isinstance(classes, int) and classes != -1:
            if ref_classes != "all":
                if classes > 2 and ref_classes != classes:
                    reason_message = f"[{specific_workflow}] 'DATA.N_CLASSES' does not match network's output classes. Please check it!\n"
                    return preproc_info, True, reason_message, opts
        else:
            reason_message = f"[{specific_workflow}] Couldn't find the classes this model is returning so please be aware to match it\n"
            return preproc_info, True, reason_message, opts

        opts["DATA.N_CLASSES"] = max(2, classes)

    elif specific_workflow in ["all", "INSTANCE_SEG"] and "instance-segmentation" in tags:
        # Assumed it's F + C. This needs a more elaborated process. Still deciding this:
        # https://github.com/bioimage-io/spec-bioimage-io/issues/621

        # Defaults
        channels = 2
        channel_code = ["F", "C"]
        classes = 2

        if "out_channels" in model_kwargs:
            channels = model_kwargs["out_channels"]
        elif "output_channels" in model_kwargs:
            channels = model_kwargs["output_channels"]

        if "biapy" in tags:
            if "description" in m and "representation:" in m["description"]:
                try:
                    representation = m["description"].split("representation:")[1].split("\n")[0].strip().split("+")
                    channel_code = [x.strip() for x in representation]
                except Exception:
                    print(f"[{specific_workflow}] couldn't extract channel representation from model RDF description: {m['description']}. Setting the default F+C\n")

            # CartoCell models
            if "cyst" in tags and "3d" in tags and "fluorescence" in tags:
                channel_code = ["F", "C", "M"]

            # Handle separated_class_channel. The channel kwarg comes from the RDF, so it
            # is type-checked explicitly rather than asserted: an int (or the default)
            # simply carries no class-channel information and must not raise.
            if isinstance(channels, list) and len(channels) == 2:
                classes = channels[-1]
                channels = channels[0]
        else:
            # for other models set some defaults
            if isinstance(channels, list):
                channels = channels[-1]
            if channels == 1:
                channel_code = ["C"]
            elif channels == 2:
                channel_code = ["F", "C"]
            elif channels == 8:
                channel_code = ["A"]  # wild-whale

        opts["PROBLEM.INSTANCE_SEG.DATA_CHANNELS"] = channel_code
        opts["PROBLEM.INSTANCE_SEG.DATA_CHANNEL_WEIGHTS"] = (1, 1)
        opts["PROBLEM.INSTANCE_SEG.DATA_CHANNELS_LOSSES"] = []
        if any(x in channel_code for x in ("F_pre", "F_post", "F_cleft")):
            opts["PROBLEM.INSTANCE_SEG.TYPE"] = "synapses"
        else:
            opts["PROBLEM.INSTANCE_SEG.TYPE"] = "regular"
        opts["PROBLEM.INSTANCE_SEG.WATERSHED.SEED_CHANNELS"] = []
        opts["PROBLEM.INSTANCE_SEG.WATERSHED.TOPOGRAPHIC_SURFACE_CHANNEL"] = ""
        opts["PROBLEM.INSTANCE_SEG.WATERSHED.GROWTH_MASK_CHANNELS"] = []
        opts["PROBLEM.INSTANCE_SEG.INSTANCE_CREATION_PROCESS"] = ""
        opts["PROBLEM.INSTANCE_SEG.DATA_CHANNELS_EXTRA_OPTS"] = [{}]
        if classes != 2:
            opts["DATA.N_CLASSES"] = max(2, classes)

    elif specific_workflow in ["all", "DETECTION"] and "detection" in tags:
        pass
    elif specific_workflow in ["all", "DENOISING"] and "denoising" in tags:
        pass
    elif specific_workflow in ["all", "SUPER_RESOLUTION"] and ("super-resolution" in tags or "superresolution" in tags):
        pass
    elif specific_workflow in ["all", "SELF_SUPERVISED"] and "self-supervision" in tags:
        pass
    elif specific_workflow in ["all", "CLASSIFICATION"] and "classification" in tags:
        pass
    elif specific_workflow in ["all", "IMAGE_TO_IMAGE"] and any(
        t in tags for t in ("pix2pix", "image-reconstruction", "image-to-image", "image-restoration")
    ):
        pass
    else:
        reason_message = f"[{specific_workflow}] no workflow tag recognized in {tags}.\n"
        return preproc_info, True, reason_message, opts

    # --------- Axes checks ---------
    axes_order = g(inputs[0], "axes")
    input_image_shape = []
    if isinstance(axes_order, str):
        # Axes given as a plain string (e.g. "bcyx"): the shape lives in inputs[0]["shape"].
        # It may be a parametrized dict holding "min" or an explicit list of ints; anything
        # else leaves the shape empty so the error tuple below is returned.
        declared_shape = inputs[0].get("shape", {})
        if isinstance(declared_shape, dict):
            input_image_shape = declared_shape.get("min", [])
        elif isinstance(declared_shape, (list, tuple)):
            input_image_shape = list(declared_shape)
    elif isinstance(axes_order, list):
        # Axes given as a list of descriptors: rebuild the axes string and the shape
        _axes_order = ""
        for axis in axes_order:
            if "type" in axis:
                if axis["type"] == "batch":
                    _axes_order += "b"
                    input_image_shape += [1]
                elif axis["type"] == "channel":
                    _axes_order += "c"
                    input_image_shape += [1]
                elif "id" in axis:
                    if isinstance(axis.get("size"), int):
                        input_image_shape += [axis["size"]]
                    elif isinstance(axis.get("size"), dict) and "min" in axis["size"]:
                        input_image_shape += [axis["size"]["min"]]
                    _axes_order += axis["id"]
            elif "id" in axis:
                if axis["id"] == "channel":
                    _axes_order += "c"
                    input_image_shape += [1]
                else:
                    if isinstance(axis.get("size"), int):
                        input_image_shape += [axis["size"]]
                    elif isinstance(axis.get("size"), dict) and "min" in axis["size"]:
                        input_image_shape += [axis["size"]["min"]]
                    _axes_order += axis["id"]
        axes_order = _axes_order

    for x in input_image_shape:
        if not isinstance(x, int):
            reason_message = f"[{specific_workflow}] couldn't extract input image shape from model RDF: {input_image_shape}\n"
            return preproc_info, True, reason_message, opts
    try:
        opts["DATA.PATCH_SIZE"] = tuple(input_image_shape[2:] + [input_image_shape[1]])  # (z) y x c
    except Exception:
        reason_message = f"[{specific_workflow}] couldn't extract input image shape from model RDF: {input_image_shape}\n"
        return preproc_info, True, reason_message, opts

    if specific_dims == "2D":
        if axes_order != "bcyx":
            reason_message = f"[{specific_workflow}] In a 2D problem the axes need to be 'bcyx', found {axes_order}\n"
            return preproc_info, True, reason_message, opts
        elif "2d" not in tags and "3d" in tags:
            reason_message = f"[{specific_workflow}] Selected model seems to not be 2D\n"
            return preproc_info, True, reason_message, opts
    elif specific_dims == "3D":
        if axes_order != "bczyx":
            reason_message = f"[{specific_workflow}] In a 3D problem the axes need to be 'bczyx', found {axes_order}\n"
            return preproc_info, True, reason_message, opts
        elif "3d" not in tags and "2d" in tags:
            reason_message = f"[{specific_workflow}] Selected model seems to not be 3D\n"
            return preproc_info, True, reason_message, opts
    else:  # "all"
        if axes_order not in ["bcyx", "bczyx"]:
            reason_message = (
                f"[{specific_workflow}] Accepting models only with ['bcyx', 'bczyx'] axis order, found {axes_order}\n"
            )
            return preproc_info, True, reason_message, opts

    # --------- Preprocessing ---------
    if "preprocessing" in (inputs[0] or {}):
        preproc_info = inputs[0]["preprocessing"]
        key_to_find = "id" if model_version > Version("0.5.0") else "name"

        if isinstance(preproc_info, list):
            # remove ensure_dtype->float casts (BiaPy does it anyway)
            filtered_preproc_info = []
            for preproc in preproc_info:
                if key_to_find in preproc and not (
                    preproc[key_to_find] == "ensure_dtype"
                    and "kwargs" in preproc
                    and "dtype" in preproc["kwargs"]
                    and "float" in str(preproc["kwargs"]["dtype"])
                ):
                    filtered_preproc_info.append(preproc)

            # NOTE(review): the loop variable rebinds ``preproc_info``, so from here on
            # the value returned to the caller is the entry being processed (the last one
            # on success). Kept as is because callers may rely on it - confirm.
            for preproc_info in filtered_preproc_info:
                if key_to_find not in preproc_info:
                    reason_message = (
                        f"[{specific_workflow}] Not recognized preprocessing structure found: {preproc_info}\n"
                    )
                    return preproc_info, True, reason_message, opts

                proc_id = preproc_info[key_to_find]
                if proc_id not in [
                    "zero_mean_unit_variance",
                    "fixed_zero_mean_unit_variance",
                    "scale_range",
                    "scale_linear",
                    "clip",
                ]:
                    reason_message = (
                        f"[{specific_workflow}] Not recognized preprocessing found: {proc_id}\n"
                    )
                    return preproc_info, True, reason_message, opts

                # Entries without (or with empty) kwargs fall back to the defaults below
                proc_kwargs = preproc_info.get("kwargs") or {}

                # zero_mean_unit_variance / fixed_zero_mean_unit_variance -> zero_mean_unit_variance(mean,std)
                if proc_id in ["fixed_zero_mean_unit_variance", "zero_mean_unit_variance"]:
                    if "kwargs" in preproc_info and "mean" in preproc_info["kwargs"]:
                        mean = preproc_info["kwargs"]["mean"]
                        std = preproc_info["kwargs"]["std"]
                    elif "mean" in preproc_info:
                        mean = preproc_info["mean"]
                        std = preproc_info["std"]
                    else:
                        mean, std = -1.0, -1.0

                    if not isinstance(mean, list):
                        mean = [float(mean)]
                    if not isinstance(std, list):
                        std = [float(std)]

                    opts["DATA.NORMALIZATION.TYPE"] = "zero_mean_unit_variance"
                    opts["DATA.NORMALIZATION.ZERO_MEAN_UNIT_VAR.MEAN_VAL"] = mean
                    opts["DATA.NORMALIZATION.ZERO_MEAN_UNIT_VAR.STD_VAL"] = std

                # scale_linear ~ div (gain not handled, same as original)
                elif proc_id == "scale_linear":
                    opts["DATA.NORMALIZATION.TYPE"] = "div"

                # scale_range -> scale_range (+ optional PERC_CLIP)
                elif proc_id == "scale_range":
                    opts["DATA.NORMALIZATION.TYPE"] = "scale_range"
                    min_percentile = float(proc_kwargs.get("min_percentile", 0))
                    max_percentile = float(proc_kwargs.get("max_percentile", 100))

                    # Check if there is percentile clipping
                    if min_percentile != 0 or max_percentile != 100:
                        opts["DATA.NORMALIZATION.PERC_CLIP.ENABLE"] = True
                        opts["DATA.NORMALIZATION.PERC_CLIP.LOWER_PERC"] = min_percentile
                        opts["DATA.NORMALIZATION.PERC_CLIP.UPPER_PERC"] = max_percentile

                elif proc_id == "clip":
                    opts["DATA.NORMALIZATION.PERC_CLIP.ENABLE"] = True
                    min_percentile = float(proc_kwargs.get("min_percentile", 0))
                    max_percentile = float(proc_kwargs.get("max_percentile", 100))
                    # -1 is used as the "not provided" marker for absolute clip values
                    max_value = float(proc_kwargs.get("max_value", -1))
                    min_value = float(proc_kwargs.get("min_value", -1))

                    if min_percentile != 0 or max_percentile != 100:
                        opts["DATA.NORMALIZATION.PERC_CLIP.LOWER_PERC"] = min_percentile
                        opts["DATA.NORMALIZATION.PERC_CLIP.UPPER_PERC"] = max_percentile
                    elif min_value != -1 or max_value != -1:
                        opts["DATA.NORMALIZATION.PERC_CLIP.LOWER_VALUE"] = min_value
                        opts["DATA.NORMALIZATION.PERC_CLIP.UPPER_VALUE"] = max_value

    # --------- Post-processing in kwargs (unsupported) ---------
    if "postprocessing" in model_kwargs and model_kwargs["postprocessing"] is not None:
        reason_message = (
            f"[{specific_workflow}] Currently no postprocessing is supported. Found: {model_kwargs['postprocessing']}\n"
        )
        return preproc_info, True, reason_message, opts

    # --------- Dependency checks ---------
    if "dependencies" in weights and weights["dependencies"] is not None:
        try:
            nickname = model_rdf.get("nickname") or model_rdf.get("alias")
        except Exception:
            return preproc_info, True, f"[{specific_workflow}] Couldn't extract model nickname from model description for dependency check.\n", opts

        try:
            current_model = load_description(nickname)
        except Exception:
            return preproc_info, True, f"[{specific_workflow}] Couldn't load model for dependency check.\n", opts

        ok, msg = True, ""
        try:
            deps = current_model.weights.pytorch_state_dict.dependencies
            if deps is None:
                # nothing to check
                ok, msg = True, ""
            elif hasattr(deps, "get_reader"):
                # newer spec: dependencies is (or behaves like) a FileDescr
                yaml_reader = deps.get_reader()
                ok, msg = can_import_env_deps(yaml_reader)
            else:
                # v0.4 spec: deps is a Dependencies object with a .file FileSource_
                # (see DependenciesNode.file in v0_4.py)
                yaml_reader = download(deps.file)  # returns a file-like BytesReader
                ok, msg = can_import_env_deps(yaml_reader)
        except Exception:
            return preproc_info, True, f"[{specific_workflow}] Couldn't read dependencies file for dependency check.\n", opts

        if not ok:
            return preproc_info, True, f"[{specific_workflow}] Model has incompatible dependencies: {msg}\n", opts

    # All checks passed
    return preproc_info, False, "", opts
def build_torchvision_model(cfg: CN, device: torch.device) -> Tuple[nn.Module, Callable]:
    """
    Build and adapt a model from the `torchvision.models` library based on the configuration.

    The model builder and its weights enum are looked up by name in the matching
    torchvision module, the model is created with the enum's ``DEFAULT`` weights and its
    output layer(s) are replaced to match the number of classes of the problem.

    Parameters
    ----------
    cfg : YACS CN object
        The configuration object. Key parameters used are:
        - `cfg.MODEL.TORCHVISION_MODEL_NAME`: Name of the torchvision model to load. A
          ``"quantized_"`` prefix selects the builder from `torchvision.models.quantization`.
        - `cfg.PROBLEM.TYPE`: ``"CLASSIFICATION"``, ``"SEMANTIC_SEG"``, ``"INSTANCE_SEG"``
          or ``"DETECTION"``; selects the torchvision module and the head adaptation.
        - `cfg.DATA.N_CLASSES`: Number of classes. The new heads have that many outputs
          when it is greater than 2 and a single output otherwise.
        - `cfg.DATA.PATCH_SIZE`: Input patch size, used for generating the model summary.
        - `cfg.PROBLEM.NDIM`: ``"2D"`` or anything else (treated as 3D) for the summary input.

    device : torch.device
        The PyTorch device on which the model will be loaded and summarized.

    Returns
    -------
    model : nn.Module
        The instantiated and adapted PyTorch model from torchvision.

    transforms : Callable
        The preprocessing transforms returned by ``transforms()`` of the loaded weights.

    Raises
    ------
    ValueError
        If `cfg.PROBLEM.TYPE` is not one of the supported types (non-quantized models), or
        if no weights enum matching the model name exists in the selected module.

    Notes
    -----
    - Classification: the final linear layer (``fc``, ``classifier`` or ``head``) is always
      replaced; for `squeezenet1_0`/`squeezenet1_1` the replaced layer is a 1x1 convolution.
      A warning is printed if `cfg.DATA.N_CLASSES` differs from 1000 (ImageNet).
    - Semantic segmentation: the last layer of the classifier and of the auxiliary
      classifier (when the model has one) are replaced; `lraspp_mobilenet_v3_large` has its
      low/high classifiers replaced instead. A warning is printed if the number of classes
      differs from 21.
    - Instance segmentation (MaskRCNN): when the number of classes differs from 91 the box
      predictor's classification score layer and the mask predictor's final convolution
      are replaced.
    - For `maxvit_t` in classification, a fixed sample input size of (1, 3, 224, 224) is
      used for the model summary.
    """
    model_name = cfg.MODEL.TORCHVISION_MODEL_NAME

    # Find model in TorchVision
    if "quantized_" in model_name:
        mdl = import_module("torchvision.models.quantization")
        w_prefix = "_quantizedweights"
        tc_model_name = model_name.replace("quantized_", "")
        weights_module = import_module("torchvision.models")
    else:
        w_prefix = "_weights"
        tc_model_name = model_name
        if cfg.PROBLEM.TYPE == "CLASSIFICATION":
            mdl = import_module("torchvision.models")
        elif cfg.PROBLEM.TYPE == "SEMANTIC_SEG":
            mdl = import_module("torchvision.models.segmentation")
        elif cfg.PROBLEM.TYPE in ["INSTANCE_SEG", "DETECTION"]:
            mdl = import_module("torchvision.models.detection")
        else:
            # Fail with a clear message instead of an unbound-variable error below
            raise ValueError(
                f"TorchVision models are not available for problem type '{cfg.PROBLEM.TYPE}'"
            )
        weights_module = mdl

    # Import model and weights: first public name of the module containing the expected
    # "<model>_weights" pattern (case-insensitive)
    wanted = tc_model_name + w_prefix
    weight_name = next(
        (name for name in mdl.__dict__ if not name.startswith("_") and wanted in name.lower()),
        None,
    )
    if weight_name is None:
        # Previously the search silently kept the last public name of the module
        raise ValueError(f"No weights matching '{wanted}' were found in '{mdl.__name__}'")
    weight_name = weight_name.replace("Quantized", "")
    print(f"Pytorch model selected: {tc_model_name} (weights: {weight_name})")

    model_builder = getattr(mdl, tc_model_name)
    weights_enum = getattr(weights_module, weight_name)
    # The names are still published in the module globals, as before, in case other code
    # relies on that side effect; the objects themselves are used directly (no eval).
    globals().update(
        {
            tc_model_name: model_builder,
            weight_name: weights_enum,
        }
    )

    # Load model and weights
    model_torchvision_weights = weights_enum.DEFAULT
    model = model_builder(weights=model_torchvision_weights)

    # Create new head
    sample_size = None
    out_classes = cfg.DATA.N_CLASSES if cfg.DATA.N_CLASSES > 2 else 1
    if cfg.PROBLEM.TYPE == "CLASSIFICATION":
        if (
            cfg.DATA.N_CLASSES != 1000
        ):  # 1000 classes are the ones by default in ImageNet, which are the weights loaded by default
            print(
                f"WARNING: Model's head changed from 1000 to {out_classes} so a finetunning is required to have good results"
            )
        if model_name in ["squeezenet1_0", "squeezenet1_1"]:
            head = torch.nn.Conv2d(
                model.classifier[1].in_channels,
                out_classes,
                kernel_size=1,
                stride=1,
            )
            model.classifier[1] = head
        else:
            if hasattr(model, "fc"):
                layer = "fc"
            elif hasattr(model, "classifier"):
                layer = "classifier"
            else:
                layer = "head"
            current_head = getattr(model, layer)
            if isinstance(current_head, (list, torch.nn.Sequential)):
                # Multi-layer head: only its last layer is swapped
                head = torch.nn.Linear(current_head[-1].in_features, out_classes, bias=True)
                current_head[-1] = head
            else:
                head = torch.nn.Linear(current_head.in_features, out_classes, bias=True)
                setattr(model, layer, head)

        # Fix sample input shape as required by some models
        if model_name in ["maxvit_t"]:
            sample_size = (1, 3, 224, 224)
    elif cfg.PROBLEM.TYPE == "SEMANTIC_SEG":
        if cfg.DATA.N_CLASSES != 21:
            print(
                f"WARNING: Model's head changed from 21 to {out_classes} so a finetunning is required to have good results"
            )
        if tc_model_name == "lraspp_mobilenet_v3_large":
            head = torch.nn.Conv2d(model.classifier.low_classifier.in_channels, out_classes, kernel_size=1, stride=1)
            model.classifier.low_classifier = head
            head = torch.nn.Conv2d(model.classifier.high_classifier.in_channels, out_classes, kernel_size=1, stride=1)
            model.classifier.high_classifier = head
        else:
            head = torch.nn.Conv2d(model.classifier[-1].in_channels, out_classes, kernel_size=1, stride=1)
            model.classifier[-1] = head
            # The auxiliary classifier is only adapted when the model actually has one
            if getattr(model, "aux_classifier", None) is not None:
                head = torch.nn.Conv2d(model.aux_classifier[-1].in_channels, out_classes, kernel_size=1, stride=1)
                model.aux_classifier[-1] = head

    elif cfg.PROBLEM.TYPE == "INSTANCE_SEG":
        # MaskRCNN
        if cfg.DATA.N_CLASSES != 91:  # 91 classes are the ones by default in MaskRCNN
            # Input width is read from the layer being replaced instead of being hard-coded
            cls_score = torch.nn.Linear(
                in_features=model.roi_heads.box_predictor.cls_score.in_features,
                out_features=out_classes,
                bias=True,
            )
            model.roi_heads.box_predictor.cls_score = cls_score
            mask_fcn_logits = torch.nn.Conv2d(
                model.roi_heads.mask_predictor.mask_fcn_logits.in_channels,
                out_classes,
                kernel_size=1,
                stride=1,
            )
            model.roi_heads.mask_predictor.mask_fcn_logits = mask_fcn_logits
            print(f"Model's head changed from 91 to {out_classes} so a finetunning is required")

    # Check the network created
    model.to(device)
    if sample_size is None:
        # PATCH_SIZE is read as (y, x, c) in 2D and (z, y, x, c) otherwise; the sample is
        # batch-first with the last PATCH_SIZE entry moved right after the batch
        if cfg.PROBLEM.NDIM == "2D":
            sample_size = (
                1,
                cfg.DATA.PATCH_SIZE[2],
                cfg.DATA.PATCH_SIZE[0],
                cfg.DATA.PATCH_SIZE[1],
            )
        else:
            sample_size = (
                1,
                cfg.DATA.PATCH_SIZE[3],
                cfg.DATA.PATCH_SIZE[0],
                cfg.DATA.PATCH_SIZE[1],
                cfg.DATA.PATCH_SIZE[2],
            )

    summary(
        model,
        input_size=sample_size,
        col_names=("input_size", "output_size", "num_params"),
        depth=10,
        device=device.type,
    )

    return model, model_torchvision_weights.transforms()
def can_import_env_deps(
    yaml_reader,
    import_overrides={'pyyaml': 'yaml', 'scikit-learn': 'sklearn'},
    allowlist: Optional[Iterable[str]] = {"pytorch", "torch", "pytorch-cuda", "pytorch-mutex", "torchvision"},
) -> Tuple[bool, str]:
    """
    Check if all dependencies listed in a conda-style environment yaml file can be imported.

    Each entry of the top-level ``dependencies`` list (and of its nested ``pip`` list) is
    reduced to its distribution name, mapped to an import name and imported. Dependencies
    whose *distribution name* is in `allowlist` are ignored. ``python`` is never imported:
    an exact ``python=X.Y`` pin is compared with the running interpreter and any other
    ``python`` entry is skipped. ``pip`` itself is skipped too.

    Parameters
    ----------
    yaml_reader : file-like object
        Provides the content of a conda-style environment yaml file through ``read()``
        (``bytes`` are decoded as UTF-8).

    import_overrides : dict, optional
        Map dist name -> import name (e.g. {'pyyaml': 'yaml'}). Keys are matched
        case-insensitively. Dist names without an override are imported with ``-``
        replaced by ``_``.

    allowlist : Iterable[str], optional
        Dist names to skip (case-insensitive).
        Example: {"pytorch", "torch", "pytorch-cuda"}

    Returns
    -------
    ok : bool
        ``True`` when no dependency failed.

    msg : str
        Comma-separated list of the failed dependencies (empty string when `ok`).
    """
    # The default arguments are only read, never mutated, so sharing them between calls is safe
    import_overrides = {k.lower(): v for k, v in (import_overrides or {}).items()}
    allow = {a.lower() for a in (allowlist or [])}

    raw = yaml_reader.read()
    text = raw.decode("utf-8", errors="replace") if isinstance(raw, bytes) else str(raw)

    doc = yaml.safe_load(text) or {}
    deps = doc.get("dependencies", []) if isinstance(doc, dict) else []
    if not isinstance(deps, list):
        # e.g. an empty "dependencies:" key is loaded as None
        deps = []

    failures = []

    # A dist name ends at the first version operator, extras bracket, environment marker,
    # direct-reference "@" or whitespace (conda also accepts "name version")
    dist_end = re.compile(r"[<>=!~\[;@\s]")
    # "python=3.9", "python==3.9" or "python 3.9" (major.minor only)
    python_pin = re.compile(r"python\s*(?:==?\s*)?(\d+)\.(\d+)")

    def dist_name(spec: str) -> str:
        return dist_end.split(spec.strip(), maxsplit=1)[0].strip()

    def try_import(dist: str):
        if dist.strip().lower() in allow:
            return
        # Most common mapping: "foo-bar" -> "foo_bar"
        mod = import_overrides.get(dist.lower(), dist.replace("-", "_"))
        try:
            importlib.import_module(mod)
        except Exception:
            # Any failure while importing counts as an unusable dependency
            failures.append(dist)

    # Check conda-style deps and pip deps
    for item in deps:
        if isinstance(item, str):
            # Drop a leading "channel::" qualifier
            spec = item.strip().split("::", 1)[-1].strip()
            dist = dist_name(spec)
            low = dist.lower()
            if low == "python":
                # "python" is not an importable module: compare an exact pin with the
                # running interpreter and skip every other form of the entry
                pin = python_pin.match(spec.lower())
                if pin:
                    req_major, req_minor = int(pin.group(1)), int(pin.group(2))
                    if (sys.version_info.major, sys.version_info.minor) != (req_major, req_minor):
                        failures.append(f"python={req_major}.{req_minor}")
                continue
            if not dist or low == "pip":
                continue
            try_import(dist)
        elif isinstance(item, dict) and "pip" in item and isinstance(item["pip"], list):
            for req in item["pip"]:
                dist = dist_name(str(req))
                if dist:
                    try_import(dist)

    ok = len(failures) == 0
    return ok, ("" if ok else ", ".join(failures))
def get_last_layer_info(model: nn.Module) -> Dict[str, Any]:
    """
    Locate the deepest last submodule of a model and tell whether it is an activation.

    Starting at `model`, the last child is followed repeatedly until a module without
    children is reached.

    Parameters
    ----------
    model : nn.Module
        The PyTorch model to analyze.

    Returns
    -------
    dict
        A dictionary containing:
        - "layer_object": The module reached by the descent (`model` itself when it has
          no children).
        - "layer_type": The class name of that module.
        - "is_activation": Whether that module is an instance of one of the activation
          classes listed in the function body.
    """
    # Activation classes recognised by the check
    known_activations = (
        nn.ReLU,
        nn.Sigmoid,
        nn.Softmax,
        nn.LogSoftmax,
        nn.Tanh,
        nn.LeakyReLU,
        nn.ELU,
        nn.PReLU,
        nn.GELU,
        nn.Softplus,
        nn.Softsign,
        nn.Hardtanh,
    )

    # Descend through the last child until reaching a module without children
    node = model
    while True:
        kids = list(node.children())
        if not kids:
            break
        node = kids[-1]

    info: Dict[str, Any] = {}
    info["layer_object"] = node
    info["layer_type"] = type(node).__name__
    info["is_activation"] = isinstance(node, known_activations)
    return info