Source code for biapy.engine.base_workflow

"""
Base workflow class for BiaPy.

This module defines the Base_Workflow abstract class, which provides the main
structure and utility methods for building training and inference workflows in BiaPy.
It handles configuration, model preparation, data loading, training, testing,
metrics, logging, and post-processing for both 2D and 3D biomedical image analysis.
"""
import math
import os
import datetime
import time
import json
import torch
import argparse
import numpy as np
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
import torch.distributed as dist
from typing import Any, Dict, Optional, List
from numpy.typing import NDArray
from yacs.config import CfgNode as CN
import pandas as pd

import biapy
from bioimageio.spec import load_description

from biapy.models import (
    build_model,
    build_torchvision_model,
    build_bmz_model,
    is_biapy_model,
    get_bmz_model_kwargs,
    check_bmz_args,
    get_last_layer_info,
)
from biapy.models.blocks import get_activation
from biapy.engine import prepare_optimizer, build_callbacks
from biapy.data.generators import (
    create_train_val_augmentors,
    create_test_generator,
    create_chunked_test_generator,
    check_generator_consistence,
    create_chunked_workflow_process_generator,
)
from biapy.utils.misc import (
    get_world_size,
    get_rank,
    is_main_process,
    save_model,
    time_text,
    load_model_checkpoint,
    TensorboardLogger,
    MetricLogger,
    to_pytorch_format,
    to_numpy_format,
    is_dist_avail_and_initialized,
    setup_for_distributed,
    update_dict_with_existing_keys,
)
from biapy.engine.check_configuration import (
    convert_old_model_cfg_to_current_version,
    diff_between_configs,
    check_configuration,
)
from biapy.utils.util import (
    create_plots,
    check_downsample_division,
    get_cfg_key_value,
)
from biapy.engine.train_engine import train_one_epoch, evaluate
from biapy.data.data_2D_manipulation import (
    crop_data_with_overlap,
    merge_data_with_overlap,
)
from biapy.data.data_3D_manipulation import (
    crop_3D_data_with_overlap, 
    merge_3D_data_with_overlap, 
    order_dimensions,
    looks_like_hdf5,
)
from biapy.data.data_manipulation import (
    load_and_prepare_train_data,
    load_and_prepare_test_data,
    read_img_as_ndarray,
    save_tif,
    resize,
)
from biapy.data.pre_processing import resize_images
from biapy.data.post_processing.post_processing import (
    ensemble8_2d_predictions,
    ensemble16_3d_predictions,
    apply_binary_mask,
)
from biapy.data.post_processing import apply_post_processing
from biapy.data.pre_processing import preprocess_data
from biapy.data.norm import normalize_image, normalize_mask
from biapy.data.generators.chunked_test_pair_data_generator import chunked_test_pair_data_generator
from biapy.data.dataset import PatchCoords
from biapy.models.memory_bank import MemoryBank


[docs] class Base_Workflow(metaclass=ABCMeta): """ Base workflow class. A new workflow should extend this class. Parameters ---------- cfg : YACS configuration Running configuration. Job_identifier : str Complete name of the running job. device : Torch device Device used. args : argpase class Arguments used in BiaPy's call. """ def __init__( self, cfg: CN, job_identifier: str, device: torch.device, system_dict: Dict[str, int], args: argparse.Namespace, ): """ Initialize the Base_Workflow object. Sets up configuration, device, job identifier, and initializes all workflow attributes and state variables. Parameters ---------- cfg : CN Running configuration. job_identifier : str Complete name of the running job. device : torch.device Device used. system_dict : dict System dictionary containing: * 'cpu_budget': int, Total CPU budget. * 'cpu_per_rank': int, CPU budget per rank. * 'main_threads': int, Number of main threads. * 'num_workers_hint': int, Hint for the number of workers. args : argparse.Namespace Arguments used in BiaPy's call. """ self.cfg = cfg self.args = args self.job_identifier = job_identifier self.device = device self.system_dict = system_dict if self.cfg.TEST.METRICS_IN_CPU: self.test_device = torch.device("cpu") else: self.test_device = device self.original_test_mask_path = None self.test_mask_filenames = None self.cross_val_samples_ids = None self.post_processing = {} self.post_processing["per_image"] = False self.post_processing["as_3D_stack"] = False self.model = None self.model_build_kwargs = None self.checkpoint_path = None self.optimizer = None self.model_prepared = False self.dtype = np.float32 if not self.cfg.TEST.REDUCE_MEMORY else np.float16 self.dtype_str = "float32" if not self.cfg.TEST.REDUCE_MEMORY else "float16" self.loss_dtype = torch.float32 self.dims = 2 if self.cfg.PROBLEM.NDIM == "2D" else 3 self.apply_activations = True self.use_gt = False if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST: self.use_gt = True # Save paths in case we need them in a future self.orig_train_path = self.cfg.DATA.TRAIN.PATH self.orig_train_mask_path = self.cfg.DATA.TRAIN.GT_PATH self.orig_val_path = self.cfg.DATA.VAL.PATH self.orig_val_mask_path = self.cfg.DATA.VAL.GT_PATH self.all_pred = [] self.all_gt = [] self.stats = {} # Per crop self.stats["per_crop"] = {} self.stats["patch_by_batch_counter"] = 0 # Merging the image self.stats["merge_patches"] = {} self.stats["merge_patches_post"] = {} # As 3D stack self.stats["as_3D_stack"] = {} self.stats["as_3D_stack_post"] = {} # Full image self.stats["full_image"] = {} self.stats["full_image_post"] = {} # To store all the metrics for each test file in order to create a final csv file with the results self.metrics_per_test_file = [] self.mask_path = "" self.is_y_mask = False self.model_output_channels = [] self.model_output_channel_info = [] self.head_activations = [] self.separated_class_channel = None self.train_metrics = [] self.train_metric_best = [] self.train_metric_names = [] self.test_metrics = [] self.test_metric_best = [] self.test_metric_names = [] self.loss = None self.memory_bank = None self.gt_channels_expected = -1 self.train_metrics_message = "" self.test_metrics_message = "" self.resolution: List[int | float] = list(self.cfg.DATA.TEST.RESOLUTION) if self.cfg.PROBLEM.NDIM == "2D": self.resolution = [ 1, ] + self.resolution self.world_size = get_world_size() self.global_rank = get_rank() # Test variables if self.cfg.TEST.POST_PROCESSING.MEDIAN_FILTER: if self.cfg.PROBLEM.NDIM == "2D": if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK: self.post_processing["as_3D_stack"] = True else: self.post_processing["per_image"] = True else: self.post_processing["per_image"] = True # Define permute shapes to pass from Numpy axis order (Y,X,C) to Pytorch's (C,Y,X) self.axes_order = (0, 3, 1, 2) if self.cfg.PROBLEM.NDIM == "2D" else (0, 4, 1, 2, 3) self.axes_order_back = (0, 2, 3, 1) if self.cfg.PROBLEM.NDIM == "2D" else (0, 2, 3, 4, 1) # Tochvision variables self.torchvision_preprocessing = None # Load pretrained model configuration if needed and check consistency with current configuration self.bmz_config = {} if self.cfg.MODEL.SOURCE == "biapy": # Obtain model spec from checkpoint if self.cfg.MODEL.LOAD_CHECKPOINT: # Take cfg from the checkpoint saved_cfg, biapy_ckpt_version = load_model_checkpoint( cfg=self.cfg, jobname=self.job_identifier, model_without_ddp=None, device=self.device, just_extract_checkpoint_info=True, skip_unmatched_layers=self.cfg.MODEL.SKIP_UNMATCHED_LAYERS, ) assert isinstance(saved_cfg, CN), "There was an error loading the checkpoint configuration. The loaded configuration is not a YACS CfgNode object but of type {}. Check that the checkpoint file is not corrupted.".format(type(saved_cfg)) if saved_cfg: if len(self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT) > 0: print("Checkpoint file loaded. Extracting the following items (if available): {} . Checking consistency with current configuration . . .".format(", ".join(self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT))) # Checks that this config and previous represent the same workflow header_message = "There is an inconsistency between the configuration loaded from checkpoint and the actual one. Error:\n" tmp_cfg = convert_old_model_cfg_to_current_version(saved_cfg.clone()) # Override model specs if self.cfg.PROBLEM.PRINT_OLD_KEY_CHANGES: print("The following changes were made in order to adapt the loaded input configuration from checkpoint into the current configuration version:") diff_between_configs(saved_cfg, tmp_cfg) if "weights" in self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT and "norm" not in self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT: print("WARNING: Weights will be loaded from checkpoint but not normalization instructions. This can lead to inconsistent results if the normalization instructions in the current configuration are different from the ones used in the checkpoint. Consider adding 'norm' to 'MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT' to avoid this issue.") if "norm" in self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT: print("Normalization instructions will be loaded from checkpoint.") update_dict_with_existing_keys(self.cfg["DATA"]["NORMALIZATION"], tmp_cfg["DATA"]["NORMALIZATION"]) if "model_arch" in self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT: print("Model architecture will be loaded from checkpoint.") # Save current model config tmp_BMZ_config = self.cfg.MODEL.clone() update_dict_with_existing_keys(self.cfg["MODEL"], tmp_cfg["MODEL"]) # Restore some model config self.cfg["MODEL"]["BMZ"] = tmp_BMZ_config["BMZ"] self.cfg["MODEL"]["OUT_CHECKPOINT_FORMAT"] = tmp_BMZ_config["OUT_CHECKPOINT_FORMAT"] # Check if the merge is coherent self.cfg["MODEL"]["LOAD_CHECKPOINT"] = True self.cfg["MODEL"]["LOAD_MODEL_FROM_CHECKPOINT"] = False check_configuration(self.cfg, self.job_identifier) # Load BioImage Model Zoo pretrained model information elif self.cfg.MODEL.SOURCE == "bmz": self.bmz_config["preprocessing"], opts = check_bmz_args(self.cfg.MODEL.BMZ.SOURCE_MODEL_ID, self.cfg) print("[BMZ] Overriding preprocessing steps to the ones fixed in BMZ model: {}".format(self.bmz_config["preprocessing"])) # Adapt configuration to match the one defined in the RDF option_list = [] for key, val in opts.items(): old_val = get_cfg_key_value(cfg, key) change = False # Not changing patch size if only the channel dimension is different if "DATA.PATCH_SIZE" in key: if old_val[:-1] != val[:-1]: change = True elif old_val != val: change = True if change: print(f"[BMZ] Changed '{key}' from '{old_val}' to '{val}' as defined in the RDF") option_list.append(key) option_list.append(val) print("Loading BioImage Model Zoo pretrained model . . .") self.bmz_config["original_bmz_config"] = load_description(self.cfg.MODEL.BMZ.SOURCE_MODEL_ID) self.cfg.merge_from_list(option_list) # Check consistency of the resulting configuration after merging with the BMZ model configuration check_configuration(self.cfg, self.job_identifier) # Save number of channels to be created by the model self.define_activations_and_channels() # Define metrics self.define_metrics() # Normalization checks print("Creating normalization module . . .") self.norm_module = { "type": cfg.DATA.NORMALIZATION.TYPE, "mask_norm": "as_mask", "out_dtype": "float32", "percentile_clip": cfg.DATA.NORMALIZATION.PERC_CLIP.ENABLE, "per_lower_bound": cfg.DATA.NORMALIZATION.PERC_CLIP.LOWER_PERC, "per_upper_bound": cfg.DATA.NORMALIZATION.PERC_CLIP.UPPER_PERC, "lower_bound_val": cfg.DATA.NORMALIZATION.PERC_CLIP.LOWER_VALUE, "upper_bound_val": cfg.DATA.NORMALIZATION.PERC_CLIP.UPPER_VALUE, "mean": cfg.DATA.NORMALIZATION.ZERO_MEAN_UNIT_VAR.MEAN_VAL, "std": cfg.DATA.NORMALIZATION.ZERO_MEAN_UNIT_VAR.STD_VAL, } print("Normalization module created with the following configuration:") for key, val in self.norm_module.items(): print(f" {key}: {val}") self.test_norm_module = self.norm_module.copy() self.test_norm_module["train_normalization"] = False self.test_norm_module["out_dtype"] = "float32" if not cfg.TEST.REDUCE_MEMORY else "float16" if self.cfg.MODEL.SOURCE == "torchvision": print("Creating normalization module . . .") self.torchvision_norm = { "type": "scale_range", "mask_norm": "as_mask", "out_dtype": "float32" if not cfg.TEST.REDUCE_MEMORY else "float16", "percentile_clip": cfg.DATA.NORMALIZATION.PERC_CLIP.ENABLE, "per_lower_bound": cfg.DATA.NORMALIZATION.PERC_CLIP.LOWER_PERC, "per_upper_bound": cfg.DATA.NORMALIZATION.PERC_CLIP.UPPER_PERC, "lower_bound_val": cfg.DATA.NORMALIZATION.PERC_CLIP.LOWER_VALUE, "upper_bound_val": cfg.DATA.NORMALIZATION.PERC_CLIP.UPPER_VALUE, } print("Torchvision normalization module created with the following configuration:") for key, val in self.torchvision_norm.items(): print(f" {key}: {val}") # Chunked workflow process generator placeholder self.test_chunked_workflow_process_vars = { "out_dir": self.cfg.PATHS.RESULT_DIR.PER_IMAGE, "dtype_str": self.dtype_str, }
[docs] def define_activations_and_channels(self): """ Define the activations to be applied to the model output and the channels that the model will output. This function must define the following variables: self.model_output_channels : List of int Number of channels for each output head of the model. E.g. [3] for a model with one head outputting 3 channels, [1, 5] for a model with two heads outputting 1 and 5 channels respectively, etc. self.model_output_channel_info : List of str Information about the output channels. A value per output head of the model must be defined. self.separated_class_channel : bool Whether if we should expect a separated output channel for classification. self.head_activations : List of str Activations to be applied to the model output. A value per output channel (not output head) of the model must be defined. "linear" and "ce_sigmoid" will not be applied. E.g. ["linear"] for a model with one channel, ["linear", "sigmoid"] for a model with two channels, etc. Example of a correct definition of the function for a model with two output heads: 1) the first one will be predicting foreground and contours; 2) the second one will classify into 3 classes the predicted objects. In this case the following definition would be correct:: self.model_output_channels = [1, 3] self.model_output_channel_info = ["mask", "class"] self.separated_class_channel = True self.head_activations = ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"] """ if not self.model_output_channels: raise ValueError( "'model_output_channels' needs to be defined. Correct define_activations_and_channels() function" ) if not isinstance(self.model_output_channels, list): raise ValueError("'self.model_output_channels' must be a list") for x in self.model_output_channels: if not isinstance(x, int): raise ValueError("'self.model_output_channels' must be a list of integers") if self.model_output_channel_info is None: raise ValueError("'model_output_channel_info' needs to be defined. Correct define_activations_and_channels() function") if not isinstance(self.model_output_channel_info, list): raise ValueError("'self.model_output_channel_info' must be a list") for x in self.model_output_channel_info: if not isinstance(x, str): raise ValueError("'self.model_output_channel_info' must be a list of strings") if self.separated_class_channel is None: raise ValueError("'separated_class_channel' needs to be defined. Correct define_activations_and_channels() function") if not self.head_activations: raise ValueError("'self.head_activations' needs to be defined. Correct define_activations_and_channels() function") if not isinstance(self.head_activations, list): raise ValueError("'self.head_activations' must be a list of strings") for x in self.head_activations: if not isinstance(x, str): raise ValueError("'self.head_activations' must be a list of strings") head_number = sum(self.model_output_channels) assert len(self.head_activations) == head_number, "Activations and output channels do not match. " "{} activations vs {} output channels".format(len(self.head_activations), head_number) assert len(self.model_output_channels) == len(self.model_output_channel_info), "Output channel info and output channels do not match. " "{} output channel info vs {} output channels".format(len(self.model_output_channel_info), len(self.model_output_channels)) if self.gt_channels_expected == -1: raise ValueError( "'gt_channels_expected' needs to be defined. Correct define_activations_and_channels() function" )
[docs] def define_metrics(self): """ Define the metrics to be calculated during training and test. This function must define the following variables: self.train_metrics : List of functions Metrics to be calculated during model's training. self.train_metric_names : List of str Names of the metrics calculated during training. self.train_metric_best : List of str To know which value should be considered as the best one. Options must be: "max" or "min". self.test_metrics : List of functions Metrics to be calculated during model's test/inference. self.test_metric_names : List of str Names of the metrics calculated during test/inference. self.loss : Function Loss function used during training and test. """ if not self.train_metrics: raise ValueError("'train_metrics' needs to be defined. Correct define_metrics() function") if not self.train_metric_names: raise ValueError("'train_metric_names' needs to be defined. Correct define_metrics() function") if not self.train_metric_best: raise ValueError("'train_metric_best' needs to be defined. Correct define_metrics() function") else: assert all( [True if x in ["max", "min"] else False for x in self.train_metric_best] ), "'train_metric_best' needs to be one between ['max', 'min']" if not self.test_metrics: raise ValueError("'test_metrics' needs to be defined. Correct define_metrics() function") if not self.test_metric_names: raise ValueError("'test_metric_names' needs to be defined. Correct define_metrics() function") if self.loss is None: raise ValueError("'loss' needs to be defined. Correct define_metrics() function")
[docs] @abstractmethod def metric_calculation( self, output: NDArray | torch.Tensor, targets: NDArray | torch.Tensor, train: bool = True, metric_logger: Optional[MetricLogger] = None, ) -> Dict: """ Execute the calculation of metrics defined in :func:`~define_metrics` function. Parameters ---------- output : Torch Tensor Prediction of the model. targets : Torch Tensor Ground truth to compare the prediction with. train : bool, optional Whether to calculate train or test metrics. metric_logger : MetricLogger, optional Class to be updated with the new metric(s) value(s) calculated. Returns ------- value : float Value of the metric for the given prediction. """ raise NotImplementedError
[docs] def prepare_targets(self, targets, batch): """ Location to perform any necessary data transformations to ``targets`` before calculating the loss. Parameters ---------- targets : Torch Tensor Ground truth to compare the prediction with. batch : Torch Tensor Prediction of the model. Only used in SSL workflow. Returns ------- targets : Torch tensor Resulting targets. """ # We do not use 'batch' input but in SSL workflow return to_pytorch_format(targets, self.axes_order, self.device, dtype=targets.dtype)
[docs] def load_train_data(self): """Load training and validation data.""" print("##########################") print("# LOAD TRAINING DATA #") print("##########################") train_zarr_data_information = { "raw_path": self.cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA_RAW_PATH, "gt_path": self.cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA_GT_PATH, "use_gt_path": self.cfg.PROBLEM.TYPE != "INSTANCE_SEG", "multiple_data_within_zarr": self.cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA, "input_img_axes": self.cfg.DATA.TRAIN.INPUT_IMG_AXES_ORDER, "input_mask_axes": self.cfg.DATA.TRAIN.INPUT_MASK_AXES_ORDER, } val_zarr_data_information = { "raw_path": self.cfg.DATA.VAL.INPUT_ZARR_MULTIPLE_DATA_RAW_PATH, "gt_path": self.cfg.DATA.VAL.INPUT_ZARR_MULTIPLE_DATA_GT_PATH, "use_gt_path": self.cfg.PROBLEM.TYPE != "INSTANCE_SEG", "multiple_data_within_zarr": self.cfg.DATA.VAL.INPUT_ZARR_MULTIPLE_DATA, "input_img_axes": self.cfg.DATA.VAL.INPUT_IMG_AXES_ORDER, "input_mask_axes": self.cfg.DATA.VAL.INPUT_MASK_AXES_ORDER, } ( self.X_train, self.Y_train, self.X_val, self.Y_val, ) = load_and_prepare_train_data( train_path=self.cfg.DATA.TRAIN.PATH, train_mask_path=self.mask_path, train_in_memory=self.cfg.DATA.TRAIN.IN_MEMORY, train_ov=self.cfg.DATA.TRAIN.OVERLAP, train_padding=self.cfg.DATA.TRAIN.PADDING, val_path=self.cfg.DATA.VAL.PATH, val_mask_path=self.cfg.DATA.VAL.GT_PATH, val_in_memory=self.cfg.DATA.VAL.IN_MEMORY, val_ov=self.cfg.DATA.VAL.OVERLAP, val_padding=self.cfg.DATA.VAL.PADDING, norm_module=self.norm_module, crop_shape=self.cfg.DATA.PATCH_SIZE, cross_val=self.cfg.DATA.VAL.CROSS_VAL, cross_val_nsplits=self.cfg.DATA.VAL.CROSS_VAL_NFOLD, cross_val_fold=self.cfg.DATA.VAL.CROSS_VAL_FOLD, val_split=self.cfg.DATA.VAL.SPLIT_TRAIN if self.cfg.DATA.VAL.FROM_TRAIN else 0.0, seed=self.cfg.SYSTEM.SEED, shuffle_val=self.cfg.DATA.VAL.RANDOM, train_preprocess_f=preprocess_data if self.cfg.DATA.PREPROCESS.TRAIN else None, train_preprocess_cfg=self.cfg.DATA.PREPROCESS if self.cfg.DATA.PREPROCESS.TRAIN else None, train_filter_props=( self.cfg.DATA.TRAIN.FILTER_SAMPLES.PROPS if self.cfg.DATA.TRAIN.FILTER_SAMPLES.ENABLE else [] ), train_filter_vals=( self.cfg.DATA.TRAIN.FILTER_SAMPLES.VALUES if self.cfg.DATA.TRAIN.FILTER_SAMPLES.ENABLE else [] ), train_filter_signs=( self.cfg.DATA.TRAIN.FILTER_SAMPLES.SIGNS if self.cfg.DATA.TRAIN.FILTER_SAMPLES.ENABLE else [] ), val_preprocess_f=preprocess_data if self.cfg.DATA.PREPROCESS.VAL else None, val_preprocess_cfg=self.cfg.DATA.PREPROCESS if self.cfg.DATA.PREPROCESS.VAL else None, val_filter_props=( self.cfg.DATA.VAL.FILTER_SAMPLES.PROPS if self.cfg.DATA.VAL.FILTER_SAMPLES.ENABLE else [] ), val_filter_vals=( self.cfg.DATA.VAL.FILTER_SAMPLES.VALUES if self.cfg.DATA.VAL.FILTER_SAMPLES.ENABLE else [] ), val_filter_signs=( self.cfg.DATA.VAL.FILTER_SAMPLES.SIGNS if self.cfg.DATA.VAL.FILTER_SAMPLES.ENABLE else [] ), filter_by_entire_image=self.cfg.DATA.FILTER_BY_IMAGE, norm_before_filter=self.cfg.DATA.TRAIN.FILTER_SAMPLES.NORM_BEFORE, y_upscaling=self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING, gt_channels_expected=self.gt_channels_expected, reflect_to_complete_shape=self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE, convert_to_rgb=self.cfg.DATA.FORCE_RGB, is_y_mask=self.is_y_mask, is_3d=(self.cfg.PROBLEM.NDIM == "3D"), train_zarr_data_information=train_zarr_data_information, val_zarr_data_information=val_zarr_data_information, multiple_raw_images=( self.cfg.PROBLEM.TYPE == "IMAGE_TO_IMAGE" and self.cfg.PROBLEM.IMAGE_TO_IMAGE.MULTIPLE_RAW_ONE_TARGET_LOADER ), save_filtered_images=self.cfg.DATA.SAVE_FILTERED_IMAGES, save_filtered_images_dir=self.cfg.PATHS.FIL_SAMPLES_DIR, save_filtered_images_num=self.cfg.DATA.SAVE_FILTERED_IMAGES_NUM, ) # Ensure all the processes have read the data if is_dist_avail_and_initialized(): print("Waiting until all processes have read the data . . .") dist.barrier()
[docs] def destroy_train_data(self): """Delete training variables to release memory.""" print("Releasing memory . . .") if "X_train" in locals() or "X_train" in globals(): del self.X_train if "Y_train" in locals() or "Y_train" in globals(): del self.Y_train if "X_val" in locals() or "X_val" in globals(): del self.X_val if "Y_val" in locals() or "Y_val" in globals(): del self.Y_val if "train_generator" in locals() or "train_generator" in globals(): del self.train_generator if "val_generator" in locals() or "val_generator" in globals(): del self.val_generator
[docs] def prepare_train_generators(self): """Build training and validation generators.""" if self.cfg.TRAIN.ENABLE: print("##############################") print("# PREPARE TRAIN GENERATORS #") print("##############################") ( self.train_generator, self.val_generator, self.num_training_steps_per_epoch, self.bmz_config["test_input"], self.bmz_config["cover_raw"], self.bmz_config["cover_gt"], ) = create_train_val_augmentors( self.cfg, system_dict=self.system_dict, X_train=self.X_train, X_val=self.X_val, Y_train=self.Y_train, Y_val=self.Y_val, norm_module=self.norm_module, ) if self.cfg.DATA.CHECK_GENERATORS and self.cfg.PROBLEM.TYPE != "CLASSIFICATION": check_generator_consistence( self.train_generator, self.cfg.PATHS.GEN_CHECKS + "_train", self.cfg.PATHS.GEN_MASK_CHECKS + "_train", ) check_generator_consistence( self.val_generator, self.cfg.PATHS.GEN_CHECKS + "_val", self.cfg.PATHS.GEN_MASK_CHECKS + "_val", )
[docs] def bmz_model_call(self, in_img, is_train=False): """ Call BioImage Model Zoo model. Parameters ---------- in_img : torch.Tensor Input image to pass through the model. is_train : bool, optional Whether if the call is during training or inference. Returns ------- prediction : torch.Tensor Image prediction. """ assert self.model prediction = self.model(in_img) return prediction
[docs] @abstractmethod def torchvision_model_call(self, in_img: torch.Tensor, is_train=False) -> torch.Tensor: """ Call a regular Pytorch model. Parameters ---------- in_img : torch.Tensor Input image to pass through the model. is_train : bool, optional Whether if the call is during training or inference. Returns ------- prediction : torch.Tensor Image prediction. """ raise NotImplementedError
[docs] def model_call_func( self, in_img: NDArray | torch.Tensor, is_train: bool = False, apply_act: bool = True ) -> Any: """ Call a regular Pytorch model. Parameters ---------- in_img : torch.Tensor Input image to pass through the model. is_train : bool, optional Whether if the call is during training or inference. apply_act : bool, optional Whether to apply activations or not. Returns ------- prediction : torch.Tensor Image prediction. """ in_img = to_pytorch_format(in_img, self.axes_order, self.device) assert isinstance(in_img, torch.Tensor) if self.cfg.MODEL.SOURCE == "biapy": assert self.model pred = self.model(in_img) # Recover the original shape of the input, as not all the model return a prediction # of the same size as the input image if ( not (self.cfg.LOSS.CONTRAST.ENABLE and is_train) and not (self.cfg.PROBLEM.TYPE == "SELF_SUPERVISED" and self.cfg.PROBLEM.SELF_SUPERVISED.PRETEXT_TASK.lower() == "masking") and self.cfg.PROBLEM.TYPE not in ["CLASSIFICATION", "SUPER_RESOLUTION"] ): if isinstance(pred, dict): if "pred" in pred and pred["pred"].shape[2:] != in_img.shape[2:]: mode = "bilinear" if self.cfg.PROBLEM.NDIM == "2D" else "trilinear" sshape = (in_img.shape[0],) + (pred["pred"].shape[1],) + in_img.shape[2:] pred["pred"] = resize(pred["pred"], sshape, mode=mode) if "class" in pred and pred["class"].shape[2:] != in_img.shape[2:]: sshape = (in_img.shape[0],) + (pred["class"].shape[1],) + in_img.shape[2:] pred["class"] = resize(pred["class"], sshape, mode="nearest") elif not isinstance(pred, list): if pred.shape[2:] != in_img.shape[2:]: mode = "bilinear" if self.cfg.PROBLEM.NDIM == "2D" else "trilinear" sshape = (in_img.shape[0],) + (pred.shape[1],) + in_img.shape[2:] pred = resize(pred, sshape, mode=mode) # Allow multiple outputs if isinstance(pred, list): for i in range(len(pred)): if apply_act: pred[i] = self.apply_model_activations(pred[i], training=is_train) else: if apply_act: pred = self.apply_model_activations(pred, training=is_train) # type: ignore elif self.cfg.MODEL.SOURCE == "bmz": pred = self.apply_model_activations(self.bmz_model_call(in_img, is_train), training=is_train) elif self.cfg.MODEL.SOURCE == "torchvision": pred = self.torchvision_model_call(in_img, is_train) if not is_train: if isinstance(pred, dict): for key in pred: if torch.is_tensor(pred[key]): pred[key] = pred[key].to(self.test_device) elif torch.is_tensor(pred): pred = pred.to(self.test_device) return pred
[docs] def prepare_model(self): """Build the model.""" if self.model_prepared: print("Model already prepared!") return print("###############") print("# Build model #") print("###############") if self.cfg.MODEL.SOURCE == "biapy": ( self.model, self.bmz_config["callable_model"], self.bmz_config["collected_sources"], self.bmz_config["all_import_lines"], self.bmz_config["scanned_files"], self.model_build_kwargs, self.network_stride, ) = build_model( self.cfg, self.model_output_channels, self.model_output_channel_info, self.head_activations, self.device ) self.bmz_config["is_biapy_model"] = True elif self.cfg.MODEL.SOURCE == "torchvision": self.model, self.torchvision_preprocessing = build_torchvision_model(self.cfg, self.device) self.bmz_config["is_biapy_model"] = False # BioImage Model Zoo pretrained models elif self.cfg.MODEL.SOURCE == "bmz": self.model = build_bmz_model(self.cfg, self.bmz_config["original_bmz_config"], self.device) # For old BMZ models uploaded in BMZ we need to explicitly insert the sigmoid activation as postprocessing self.bmz_config["is_biapy_model"] = is_biapy_model(self.bmz_config["original_bmz_config"]) bmz_model_kwargs = get_bmz_model_kwargs(self.bmz_config["original_bmz_config"]) if self.bmz_config["is_biapy_model"] and "explicit_activation" not in bmz_model_kwargs: self.bmz_config["postprocessing"] = [] if self.head_activations[0] in ["ce_sigmoid", "Sigmoid"]: self.bmz_config["postprocessing"].append("sigmoid") layer_info = get_last_layer_info(self.model) if layer_info['is_activation']: print("[BMZ] Disabling manual activations after the model as the last layer of the model is already an activation function ({}).".format(layer_info['layer_type'])) self.apply_activations = False self.model_without_ddp = self.model if self.args.distributed: self.model = torch.nn.parallel.DistributedDataParallel( self.model, device_ids=[self.args.gpu], find_unused_parameters=False, ) self.model_without_ddp = self.model.module self.model_prepared = True # Load checkpoint if necessary if self.cfg.MODEL.SOURCE == "biapy" and self.cfg.MODEL.LOAD_CHECKPOINT: self.start_epoch, self.checkpoint_path = load_model_checkpoint( cfg=self.cfg, jobname=self.job_identifier, model_without_ddp=self.model_without_ddp, device=self.device, optimizer=self.optimizer, skip_unmatched_layers=self.cfg.MODEL.SKIP_UNMATCHED_LAYERS, ) else: self.start_epoch = 0
[docs] def prepare_logging_tool(self): """Prepare looging tool.""" print("#######################") print("# Prepare logging tool #") print("#######################") # To start the logging now = datetime.datetime.now() now = now.strftime("%Y_%m_%d_%H_%M_%S") self.log_file = os.path.join( self.cfg.LOG.LOG_DIR, self.cfg.LOG.LOG_FILE_PREFIX + "_log_" + str(now) + ".txt", ) if self.global_rank == 0: os.makedirs(self.cfg.LOG.LOG_DIR, exist_ok=True) os.makedirs(self.cfg.PATHS.CHECKPOINT, exist_ok=True) self.log_writer = TensorboardLogger(log_dir=self.cfg.LOG.TENSORBOARD_LOG_DIR) else: self.log_writer = None self.plot_values = {} self.plot_values["loss"] = [] self.plot_values["val_loss"] = [] for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]] = [] self.plot_values["val_" + self.train_metric_names[i]] = []
[docs] def train(self): """Training phase.""" self.load_train_data() if not self.model_prepared: self.prepare_model() self.prepare_train_generators() self.prepare_logging_tool() self.early_stopping = build_callbacks(self.cfg) assert ( self.start_epoch is not None and self.model is not None and self.model_without_ddp is not None and self.loss ) assert isinstance(self.start_epoch, int), "'start_epoch' should be an integer" self.optimizer, self.lr_scheduler = prepare_optimizer( self.cfg, self.model_without_ddp, len(self.train_generator) ) contrast_init_iter = 0 if self.cfg.LOSS.CONTRAST.ENABLE: self.memory_bank = MemoryBank( num_classes=self.gt_channels_expected, memory_size = self.cfg.LOSS.CONTRAST.MEMORY_SIZE, feature_dims = self.cfg.LOSS.CONTRAST.PROJ_DIM, network_stride = self.network_stride, pixel_update_freq=self.cfg.LOSS.CONTRAST.PIXEL_UPD_FREQ, device = self.device, ignore_index = self.cfg.LOSS.IGNORE_INDEX, ) self.memory_bank.to(self.device) # When to activate the contrastive loss contrast_init_iter = self.cfg.LOSS.CONTRAST.MEMORY_SIZE if self.cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": contrast_init_iter += self.cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS print("#####################") print("# TRAIN THE MODEL #") print("#####################") print(f"Start training in epoch {self.start_epoch+1} - Total: {self.cfg.TRAIN.EPOCHS}") start_time = time.time() self.val_best_metric = np.zeros(len(self.train_metric_names), dtype=np.float32) self.val_best_loss = np.inf total_iters = 0 for epoch in range(self.start_epoch, self.cfg.TRAIN.EPOCHS): print("~~~ Epoch {}/{} ~~~\n".format(epoch + 1, self.cfg.TRAIN.EPOCHS)) e_start = time.time() if self.args.distributed: self.train_generator.sampler.set_epoch(epoch) # type: ignore if self.log_writer: self.log_writer.set_step(epoch * self.num_training_steps_per_epoch) # Train train_stats, iterations_done = train_one_epoch( self.cfg, model=self.model, model_call_func=self.model_call_func, loss_function=self.loss, metric_function=self.metric_calculation, prepare_targets=self.prepare_targets, data_loader=self.train_generator, optimizer=self.optimizer, device=self.device, epoch=epoch, log_writer=self.log_writer, lr_scheduler=self.lr_scheduler, verbose=self.cfg.TRAIN.VERBOSE, memory_bank=self.memory_bank, total_iters=total_iters, contrast_warmup_iters=contrast_init_iter, ) total_iters += iterations_done # Save checkpoint if self.cfg.MODEL.SAVE_CKPT_FREQ != -1: if ( (epoch + 1) % self.cfg.MODEL.SAVE_CKPT_FREQ == 0 or epoch + 1 == self.cfg.TRAIN.EPOCHS and is_main_process() ): save_model( cfg=self.cfg, biapy_version=biapy.__version__, jobname=self.job_identifier, model_without_ddp=self.model_without_ddp, optimizer=self.optimizer, epoch=epoch + 1, model_build_kwargs=self.model_build_kwargs, extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT, ) # Validation if self.val_generator: test_stats = evaluate( self.cfg, model=self.model, model_call_func=self.model_call_func, loss_function=self.loss, metric_function=self.metric_calculation, prepare_targets=self.prepare_targets, epoch=epoch, data_loader=self.val_generator, lr_scheduler=self.lr_scheduler, memory_bank=self.memory_bank, ) # Save checkpoint is val loss improved if test_stats["loss"] < self.val_best_loss: f = os.path.join( self.cfg.PATHS.CHECKPOINT, "{}-checkpoint-best.pth".format(self.job_identifier), ) print( "Val loss improved from {} to {}, saving model to {}".format( self.val_best_loss, test_stats["loss"], f ) ) m = " " for i in range(len(self.val_best_metric)): self.val_best_metric[i] = test_stats[self.train_metric_names[i]] m += f"{self.train_metric_names[i]}: {self.val_best_metric[i]:.4f} " self.val_best_loss = test_stats["loss"] if is_main_process(): self.checkpoint_path = save_model( cfg=self.cfg, biapy_version=biapy.__version__, jobname=self.job_identifier, model_without_ddp=self.model_without_ddp, optimizer=self.optimizer, epoch="best", model_build_kwargs=self.model_build_kwargs, extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT, ) print(f"[Val] best loss: {self.val_best_loss:.4f} best " + m) # Store validation stats if self.log_writer: self.log_writer.update(test_loss=test_stats["loss"], head="perf", step=epoch) for i in range(len(self.train_metric_names)): self.log_writer.update( test_iou=test_stats[self.train_metric_names[i]], head="perf", step=epoch, ) log_stats = { **{f"train_{k}": v for k, v in train_stats.items()}, **{f"test_{k}": v for k, v in test_stats.items()}, "epoch": epoch, } else: log_stats = { **{f"train_{k}": v for k, v in train_stats.items()}, "epoch": epoch, } # Write statistics in the logging file if is_main_process(): # Log epoch stats if self.log_writer: self.log_writer.flush() with open(self.log_file, mode="a", encoding="utf-8") as f: f.write(json.dumps(log_stats) + "\n") # Create training plot self.plot_values["loss"].append(train_stats["loss"]) if self.val_generator: self.plot_values["val_loss"].append(test_stats["loss"]) for i in range(len(self.train_metric_names)): self.plot_values[self.train_metric_names[i]].append(train_stats[self.train_metric_names[i]]) if self.val_generator: self.plot_values["val_" + self.train_metric_names[i]].append( test_stats[self.train_metric_names[i]] ) if (epoch + 1) % self.cfg.LOG.CHART_CREATION_FREQ == 0: create_plots( self.plot_values, self.train_metric_names, self.job_identifier, self.cfg.PATHS.CHARTS, ) if self.val_generator and self.early_stopping: self.early_stopping(test_stats["loss"]) if self.early_stopping.early_stop: print("Early stopping") break e_end = time.time() t_epoch = e_end - e_start print( "[Time] {} {}/{}\n".format( time_text(t_epoch), time_text(e_end - start_time), time_text((e_end - start_time) + (t_epoch * (self.cfg.TRAIN.EPOCHS - epoch))), ) ) total_time = time.time() - start_time self.total_training_time_str = str(datetime.timedelta(seconds=int(total_time))) print("Training time: {}".format(self.total_training_time_str)) self.train_metrics_message += ("Train loss: {}\n".format(train_stats["loss"])) for i in range(len(self.train_metric_names)): self.train_metrics_message += ("Train {}: {}\n".format(self.train_metric_names[i], train_stats[self.train_metric_names[i]])) if self.val_generator: self.train_metrics_message += ("Validation loss: {}\n".format(self.val_best_loss)) for i in range(len(self.train_metric_names)): self.train_metrics_message += ("Validation {}: {}\n".format(self.train_metric_names[i], self.val_best_metric[i])) if self.train_metrics_message != "": for line in self.train_metrics_message.split("\n"): print(line) print("Finished Training") if is_dist_avail_and_initialized(): print(f"[Rank {get_rank()} ({os.getpid()})] Process waiting (train finished, step 1) . . . ") dist.barrier() # Save output sample to export the model to BMZ if "test_output" not in self.bmz_config: assert self.model_without_ddp self.model_without_ddp.eval() # Load best checkpoint on validation to ensure it _ = load_model_checkpoint( cfg=self.cfg, jobname=self.job_identifier, model_without_ddp=self.model_without_ddp, device=self.device, skip_unmatched_layers=self.cfg.MODEL.SKIP_UNMATCHED_LAYERS, ) # Save BMZ input/output so the user could export the model to BMZ later self.prepare_bmz_data(self.bmz_config["test_input"]) self.destroy_train_data()
[docs] def load_test_data(self): """Load test data.""" print("######################") print("# LOAD TEST DATA #") print("######################") self.X_test, self.Y_test = None, None if self.cfg.DATA.TEST.USE_VAL_AS_TEST: print("Loading train data information to extract the validation to be used as test") self.cfg.merge_from_list(["DATA.TRAIN.IN_MEMORY", False, "DATA.VAL.IN_MEMORY", False]) self.load_train_data() self.X_test = self.X_val.copy() if self.Y_val: self.Y_test = self.Y_val.copy() else: # Paths to the raw and gt within the Zarr file. Only used when 'DATA.TEST.INPUT_ZARR_MULTIPLE_DATA' is True. test_zarr_data_information = None if self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA: use_gt_path = True if self.cfg.PROBLEM.TYPE != "INSTANCE_SEG" and self.cfg.PROBLEM.INSTANCE_SEG.TYPE != "synapses": use_gt_path = False test_zarr_data_information = { "raw_path": self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_RAW_PATH, "gt_path": self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_GT_PATH, "use_gt_path": use_gt_path, } ( self.X_test, self.Y_test, self.test_filenames, ) = load_and_prepare_test_data( test_path=self.cfg.DATA.TEST.PATH, test_mask_path=self.cfg.DATA.TEST.GT_PATH if self.use_gt else None, multiple_raw_images=( self.cfg.PROBLEM.TYPE == "IMAGE_TO_IMAGE" and self.cfg.PROBLEM.IMAGE_TO_IMAGE.MULTIPLE_RAW_ONE_TARGET_LOADER ), test_zarr_data_information=test_zarr_data_information, )
[docs] def destroy_test_data(self): """Delete test variable to release memory.""" print("Releasing memory . . .") if "X_test" in locals() or "X_test" in globals(): del self.X_test if "Y_test" in locals() or "Y_test" in globals(): del self.Y_test if "test_generator" in locals() or "test_generator" in globals(): del self.test_generator if "current_sample" in locals() or "current_sample" in globals(): del self.current_sample
[docs] def prepare_test_generators(self): """Prepare test data generator.""" if self.cfg.TEST.ENABLE: print("############################") print("# PREPARE TEST GENERATOR #") print("############################") ( self.test_generator, test_input, cover_raw, cover_gt, ) = create_test_generator( cfg=self.cfg, X_test=self.X_test, Y_test=self.Y_test, norm_module=self.test_norm_module, ) # Save BMZ data if not available if "cover_raw" not in self.bmz_config or "cover_gt" not in self.bmz_config: self.bmz_config["cover_raw"] = cover_raw self.bmz_config["cover_gt"] = cover_gt if "test_input" not in self.bmz_config: self.bmz_config["test_input"] = test_input
[docs] def apply_model_activations(self, pred: torch.Tensor | Dict, training=False) -> torch.Tensor | Dict: """ Apply the last activation (if any) to the model's output. Parameters ---------- pred : Torch Tensor Predictions of the model. training : bool, optional To advice the function if this is being applied during training of inference. During training, ``ce_sigmoid`` activations will NOT be applied, as ``torch.nn.BCEWithLogitsLoss`` will apply ``Sigmoid`` automatically in a way that is more stable numerically (`ref <https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html>`_). Returns ------- pred : Torch tensor Resulting predictions after applying last activation(s). """ # Do not apply any activation when using masking as pretext task if ( not self.apply_activations or (self.cfg.PROBLEM.TYPE == "SELF_SUPERVISED" and self.cfg.PROBLEM.SELF_SUPERVISED.PRETEXT_TASK.lower() == "masking") ): return pred # 1. Expand channel info to map 1-to-1 with every channel all_channel_info = [] for i, c_info in enumerate(self.model_output_channel_info): for _ in range(self.model_output_channels[i]): # For semantic segmentation and classification problems, we consider that all the channels are "class" channels if self.cfg.PROBLEM.TYPE in ["SEMANTIC_SEG", "CLASSIFICATION"]: c_info = "class" all_channel_info.append(c_info) def __apply_acts(tensor, acts, c_infos): if len(acts) == 0: return tensor out_slices = [] # Find the exact index where "class" channels begin class_start_idx = len(c_infos) for i, info in enumerate(c_infos): if "class" in info.lower(): class_start_idx = i break # --- PART A: Process standard channels (1-by-1) --- for i in range(min(class_start_idx, tensor.shape[1])): chunk = tensor[:, i:i+1, ...] act_str = acts[i].lower() # Skip if linear, or if training and it's handled by the loss function if act_str == "linear" or (training and act_str in ["ce_sigmoid", "ce_softmax"]): out_slices.append(chunk) else: clean_act = "sigmoid" if act_str == "ce_sigmoid" else act_str act_fn = get_activation(clean_act) out_slices.append(act_fn(chunk)) # --- PART B: Process the entire "class" block (all at once) --- if class_start_idx < tensor.shape[1]: class_chunk = tensor[:, class_start_idx:, ...] act_str = acts[class_start_idx].lower() # Skip if linear, or if training and it's handled by the loss function if act_str == "linear" or (training and act_str in ["ce_sigmoid", "ce_softmax"]): out_slices.append(class_chunk) else: # Apply a single activation to the whole multidimensional block clean_act = "softmax" if "softmax" in act_str else ("sigmoid" if act_str == "ce_sigmoid" else act_str) act_fn = get_activation(clean_act) out_slices.append(act_fn(class_chunk)) return torch.cat(out_slices, dim=1) if out_slices else tensor # 2. Apply to inputs (handling both Dict and Tensor formats) if isinstance(pred, dict): pred_len = pred["pred"].shape[1] pred["pred"] = __apply_acts(pred["pred"], self.head_activations[:pred_len], all_channel_info[:pred_len]) if "class" in pred: class_len = pred["class"].shape[1] pred["class"] = __apply_acts(pred["class"], self.head_activations[-class_len:], all_channel_info[-class_len:]) else: pred = __apply_acts(pred, self.head_activations, all_channel_info) return pred
[docs] @torch.no_grad() def test(self): """Test/Inference step.""" self.load_test_data() if not self.model_prepared: self.prepare_model() self.prepare_test_generators() # Switch to evaluation mode assert self.model_without_ddp self.model_without_ddp.eval() # When not training was done if "test_output" not in self.bmz_config: # Save BMZ input/output so the user could export the model to BMZ later self.prepare_bmz_data(self.bmz_config["test_input"]) # Load best checkpoint on validation if self.cfg.TRAIN.ENABLE: self.start_epoch, self.checkpoint_path = load_model_checkpoint( cfg=self.cfg, jobname=self.job_identifier, model_without_ddp=self.model_without_ddp, device=self.device, skip_unmatched_layers=self.cfg.MODEL.SKIP_UNMATCHED_LAYERS, ) # Check possible checkpoint problems if self.start_epoch == -1: raise ValueError("There was a problem loading the checkpoint. Test phase aborted!") image_counter = 0 print("###############") print("# INFERENCE #") print("###############") print("Making predictions on test data . . .") # Reactivate prints to see each rank progress if self.cfg.TEST.BY_CHUNKS.ENABLE and self.cfg.PROBLEM.NDIM == "3D": setup_for_distributed(True) # Process all the images for i, self.current_sample in enumerate(self.test_generator): # type: ignore self.current_sample_metrics = {"file": self.current_sample["X_filename"]} self.f_numbers = [i] if "Y" not in self.current_sample: self.current_sample["Y"] = None # Decide whether to infer by chunks or not discarded = False _, file_extension = os.path.splitext(self.current_sample["X_filename"]) if self.cfg.TEST.BY_CHUNKS.ENABLE and self.cfg.PROBLEM.NDIM == "3D": by_chunks = True if not looks_like_hdf5(self.current_sample["X_filename"]) and file_extension not in [".zarr", ".n5"]: print( "WARNING: You are not using an image format that can extract patches without loading it entirely into memory. " "The image formats that support this are: '.hdf5', '.hdf', '.h5', '.zarr' and '.n5'. " ) else: by_chunks = False if by_chunks: me = "[Rank {} ({})] Processing image (by chunks): {}".format( get_rank(), os.getpid(), self.current_sample["X_filename"] ) if "Y_filename" in self.current_sample: me += " (GT: {})".format(self.current_sample["Y_filename"]) print(me) self.process_test_sample_by_chunks() else: if is_main_process(): if self.cfg.PROBLEM.TYPE != "CLASSIFICATION": me = "Processing image: {}".format(self.current_sample["X_filename"]) if "Y_filename" in self.current_sample: me += " (GT: {})".format(self.current_sample["Y_filename"]) print(me) print("Normalization used: {}".format(self.current_sample["X_norm"])) discarded = self.process_test_sample() # If process_test_sample() returns True means that the sample was skipped due to filter set # up with: DATA.TEST.FILTER_SAMPLES if discarded: print(" Skipping image: {}".format(self.current_sample["X_filename"])) else: image_counter += 1 self.metrics_per_test_file.append(self.current_sample_metrics) # Only enable print for the main rank again if self.cfg.TEST.BY_CHUNKS.ENABLE and self.cfg.PROBLEM.NDIM == "3D": setup_for_distributed(is_main_process()) self.destroy_test_data() if is_main_process(): self.after_all_images() print("#############") print("# RESULTS #") print("#############") print("The values below represent the averages across all test samples") if self.cfg.TRAIN.ENABLE: print("Epoch number: {}".format(len(self.plot_values["val_loss"]))) print("Train time (s): {}".format(self.total_training_time_str)) print("Train loss: {}".format(np.min(self.plot_values["loss"]))) for i in range(len(self.train_metric_names)): metric_name = ( "Foreground IoU" if self.train_metric_names[i] == "IoU" else self.train_metric_names[i] ) print( "Train {}: {}".format( metric_name, ( np.max(self.plot_values[self.train_metric_names[i]]) if self.train_metric_best[i] == "max" else np.min(self.plot_values[self.train_metric_names[i]]) ), ) ) print("Validation loss: {}".format(self.val_best_loss)) for i in range(len(self.train_metric_names)): metric_name = ( "Foreground IoU" if self.train_metric_names[i] == "IoU" else self.train_metric_names[i] ) print( "Validation {}: {}".format( metric_name, self.val_best_metric[i], ) ) self.print_stats(image_counter)
[docs] def predict_batches_in_test( self, x_batch: NDArray, y_batch: Optional[NDArray], stats_name="per_crop", disable_tqdm: bool = False ) -> NDArray: """ Predict data for the test phase. Parameters ---------- x_batch : NDArray X data. Expected axes are: ``(num_patches, z, y, x, channels)`` for 3D and ``(num_patches, y, x, channels)`` for 2D. y_batch: NDArray Y data. Expected axes are: ``(num_patches, z, y, x, channels)`` for 3D and ``(num_patches, y, x, channels)`` for 2D. stats_name : str, optional Name of the statistics to save. disable_tqdm : bool, optional Whether to disable tqdm or not. Returns ------- pred : NDArray Predicted batch. """ if self.cfg.TEST.AUGMENTATION: for k in tqdm(range(x_batch.shape[0]), leave=False, disable=disable_tqdm): if self.cfg.PROBLEM.NDIM == "2D": p = ensemble8_2d_predictions( x_batch[k], axes_order_back=self.axes_order_back, axes_order=self.axes_order, device=self.test_device, pred_func=self.model_call_func, mode=self.cfg.TEST.AUGMENTATION_MODE, ) else: p = ensemble16_3d_predictions( x_batch[k], batch_size_value=self.cfg.TRAIN.BATCH_SIZE, axes_order_back=self.axes_order_back, axes_order=self.axes_order, device=self.test_device, pred_func=self.model_call_func, mode=self.cfg.TEST.AUGMENTATION_MODE, ) # Multi-head concatenation if isinstance(p, dict): if "class" in p: p = torch.cat((p["pred"], p["class"]), dim=1) else: p = p["pred"] # Calculate the metrics if y_batch is not None: metric_values = self.metric_calculation(output=p, targets=np.expand_dims(y_batch[k],0), train=False) for metric in metric_values: if str(metric).lower() not in self.stats[stats_name]: self.stats[stats_name][str(metric).lower()] = 0 self.stats[stats_name][str(metric).lower()] += metric_values[metric] self.stats["patch_by_batch_counter"] += 1 p = to_numpy_format(p, self.axes_order_back) if "pred" not in locals(): pred = np.zeros((x_batch.shape[0],) + p.shape[1:], dtype=self.dtype) pred[k] = p else: l = int(math.ceil(x_batch.shape[0] / self.cfg.TRAIN.BATCH_SIZE)) for k in tqdm(range(l), leave=False, disable=disable_tqdm): top = ( (k + 1) * self.cfg.TRAIN.BATCH_SIZE if (k + 1) * self.cfg.TRAIN.BATCH_SIZE < x_batch.shape[0] else x_batch.shape[0] ) p = self.model_call_func(x_batch[k * self.cfg.TRAIN.BATCH_SIZE : top]) # Multi-head concatenation if isinstance(p, dict): if "class" in p: p = torch.cat((p["pred"], p["class"]), dim=1) else: p = p["pred"] # Calculate the metrics if y_batch is not None: metric_values = self.metric_calculation( output=p, targets=y_batch[k * self.cfg.TRAIN.BATCH_SIZE : top], train=False, ) for metric in metric_values: if str(metric).lower() not in self.stats[stats_name]: self.stats[stats_name][str(metric).lower()] = 0 self.stats[stats_name][str(metric).lower()] += metric_values[metric] self.stats["patch_by_batch_counter"] += 1 p = to_numpy_format(p, self.axes_order_back) if "pred" not in locals(): pred = np.zeros((x_batch.shape[0],) + p.shape[1:], dtype=self.dtype) pred[k * self.cfg.TRAIN.BATCH_SIZE : top] = p return pred
[docs] def prepare_bmz_data(self, img): """ Prepare required data for exporting a model into BMZ. Parameters ---------- img : 4D/5D Numpy array Image to save (unnormalized). The axes must be in Torch format already, i.e. ``(b,c,y,x)`` for 2D or ``(b,c,z,y,x)`` for 3D. """ def _prepare_bmz_sample(sample_key, img, apply_norm=True): """ Prepare a sample from the given ``img`` using the patch size in the configuration. It also saves the sample in ``self.bmz_config`` using the ``sample_key``. Parameters ---------- sample_key : str Key to store the sample into. Must be one between: ``["test_input", "test_output"]`` img : 4D/5D Numpy array Image to extract the sample from. The axes must be in Torch format already, i.e. ``(b,c,y,x)`` for 2D or ``(b,c,z,y,x)`` for 3D. """ img = img.astype(np.float32) if len(img.shape) == 2: # Classification self.bmz_config[sample_key] = img.copy() else: self.bmz_config[sample_key] = img[0].copy() # Ensure dimensions if self.cfg.PROBLEM.NDIM == "2D": if self.bmz_config[sample_key].ndim == 3: self.bmz_config[sample_key] = np.expand_dims(self.bmz_config[sample_key], 0) else: # 3D if self.bmz_config[sample_key].ndim == 4: self.bmz_config[sample_key] = np.expand_dims(self.bmz_config[sample_key], 0) # Apply normalization if apply_norm: if self.cfg.PROBLEM.NDIM == "2D": # Transpose to (b,y,x,c) for normalization and back to (b,c,y,x) after self.bmz_config[sample_key], bmz_norm_used = normalize_image(self.bmz_config[sample_key].transpose(0, 2, 3, 1), self.norm_module) self.bmz_config[sample_key] = self.bmz_config[sample_key].astype(np.float32).transpose(0, 3, 1, 2) else: # Transpose to (b,z,y,x,c) for normalization and back to (b,c,z,y,x) after self.bmz_config[sample_key], bmz_norm_used = normalize_image(self.bmz_config[sample_key].transpose(0, 2, 3, 4, 1), self.norm_module) self.bmz_config[sample_key] = self.bmz_config[sample_key].astype(np.float32).transpose(0, 4, 1, 2, 3) print("Normalization used in when creating BMZ data: {}".format(bmz_norm_used)) # Save test_input without the normalization if not already saved if "test_input" not in self.bmz_config: _prepare_bmz_sample("test_input", img, apply_norm=False) # Save test_input with the normalization if "test_input_norm" not in self.bmz_config: _prepare_bmz_sample("test_input_norm", img) # Model prediction assert self.model and self.model_without_ddp assert isinstance(self.bmz_config["test_input_norm"], np.ndarray) pred = self.model(torch.from_numpy(self.bmz_config["test_input_norm"]).to(self.device)) # MAE if isinstance(pred, dict) and "mask" in pred: pred = self.apply_model_activations(pred) assert isinstance(pred, dict), "The model output should be a dictionary containing 'pred' and 'mask' for the MAE pretext task." assert "pred" in pred and "mask" in pred, "The model output should contain 'pred' and 'mask' for the MAE pretext task." mask = pred["mask"] pred = pred["pred"] pred, p_mask, _ = self.model_without_ddp.save_images( torch.from_numpy(self.bmz_config["test_input_norm"]).to(self.device), pred, mask, self.dtype, ) # type: ignore # We call MAE again with "return_just_preds" that sets a seed in the random masking to ensure that the same mask is applied here # and in the BMZ exported model. If not BMZ check will crash reproducing the output pred = self.model(torch.from_numpy(self.bmz_config["test_input_norm"]).to(self.device), return_just_preds=True) else: pred = self.apply_model_activations(pred) # Multi-head concatenation if isinstance(pred, dict): if "class" in pred: pred = torch.cat((pred["pred"], torch.argmax(pred["class"], dim=1).unsqueeze(1)), dim=1) else: pred = pred["pred"] # Save output _prepare_bmz_sample("test_output", pred.clone().cpu().detach().numpy().astype(np.float32), apply_norm=False) if "cover_gt" not in self.bmz_config or ("cover_gt" in self.bmz_config and self.bmz_config["cover_gt"] is None): if self.bmz_config["test_output"].ndim == 1: # Classification self.bmz_config["cover_gt"] = self.bmz_config["test_output"].copy() elif self.cfg.PROBLEM.TYPE == "SELF_SUPERVISED" and self.cfg.PROBLEM.SELF_SUPERVISED.PRETEXT_TASK.lower() == "masking": self.bmz_config["cover_gt"] = p_mask[0].copy() else: self.bmz_config["cover_gt"] = self.bmz_config["test_output"].copy().transpose(0, *range(2, self.bmz_config["test_output"].ndim), 1)[0] if self.cfg.DATA.N_CLASSES > 2 and self.cfg.PROBLEM.TYPE == "SEMANTIC_SEG": self.bmz_config["cover_gt"] = np.expand_dims(np.argmax(self.bmz_config["cover_gt"], -1), -1)
[docs] def process_test_sample(self): """Process a sample in the inference phase.""" # Skip processing image if "discard" in self.current_sample and self.current_sample["discard"]: return True ################# ### PER PATCH ### ################# if not self.cfg.TEST.FULL_IMG or self.cfg.PROBLEM.NDIM == "3D": if not self.cfg.TEST.REUSE_PREDICTIONS: original_data_shape = self.current_sample["X"].shape # Crop if necessary if self.current_sample["X"].shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]: # Copy X to be used later in full image if self.cfg.PROBLEM.NDIM != "3D": X_original = self.current_sample["X"].copy() if ( self.current_sample["Y"] is not None and self.current_sample["X"].shape[:-1] != self.current_sample["Y"].shape[:-1] ): raise ValueError( "Image {} ({}) and mask {} ({}) differ in shape (without considering the channels, i.e. last dimension). " "Please check the images.".format( self.current_sample["X"].shape, self.current_sample['X_filename'], self.current_sample["Y"].shape, self.current_sample['Y_filename'] ) ) if self.cfg.PROBLEM.NDIM == "2D": obj = crop_data_with_overlap( self.current_sample["X"], self.cfg.DATA.PATCH_SIZE, data_mask=self.current_sample["Y"], overlap=self.cfg.DATA.TEST.OVERLAP, padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE, ) if self.current_sample["Y"] is not None: self.current_sample["X"], self.current_sample["Y"], _ = obj # type: ignore else: self.current_sample["X"], _ = obj # type: ignore del obj else: if self.cfg.TEST.REDUCE_MEMORY: self.current_sample["X"], _ = crop_3D_data_with_overlap( # type: ignore self.current_sample["X"][0], self.cfg.DATA.PATCH_SIZE, overlap=self.cfg.DATA.TEST.OVERLAP, padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE, median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING, ) if self.current_sample["Y"] is not None: self.current_sample["Y"], _ = crop_3D_data_with_overlap( # type: ignore self.current_sample["Y"][0], self.cfg.DATA.PATCH_SIZE[:-1] + (self.current_sample["Y"].shape[-1],), overlap=self.cfg.DATA.TEST.OVERLAP, padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE, median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING, ) else: if self.current_sample["Y"] is not None: self.current_sample["Y"] = self.current_sample["Y"][0] obj = crop_3D_data_with_overlap( self.current_sample["X"][0], self.cfg.DATA.PATCH_SIZE, data_mask=self.current_sample["Y"], overlap=self.cfg.DATA.TEST.OVERLAP, padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE, median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING, ) if self.current_sample["Y"] is not None: self.current_sample["X"], self.current_sample["Y"], _ = obj # type: ignore else: self.current_sample["X"], _ = obj # type: ignore del obj pred = self.predict_batches_in_test(self.current_sample["X"], self.current_sample["Y"]) # Reconstruct the predictions if original_data_shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]: if self.cfg.PROBLEM.NDIM == "3D": original_data_shape = original_data_shape[1:] f_name = merge_data_with_overlap if self.cfg.PROBLEM.NDIM == "2D" else merge_3D_data_with_overlap if self.cfg.TEST.REDUCE_MEMORY: pred = f_name( pred, original_data_shape[:-1] + (pred.shape[-1],), padding=self.cfg.DATA.TEST.PADDING, overlap=self.cfg.DATA.TEST.OVERLAP, verbose=self.cfg.TEST.VERBOSE, ) if self.current_sample["Y"] is not None: self.current_sample["Y"] = f_name( self.current_sample["Y"], original_data_shape[:-1] + (self.current_sample["Y"].shape[-1],), padding=self.cfg.DATA.TEST.PADDING, overlap=self.cfg.DATA.TEST.OVERLAP, verbose=self.cfg.TEST.VERBOSE, ) else: obj = f_name( pred, original_data_shape[:-1] + (pred.shape[-1],), data_mask=self.current_sample["Y"], padding=self.cfg.DATA.TEST.PADDING, overlap=self.cfg.DATA.TEST.OVERLAP, verbose=self.cfg.TEST.VERBOSE, ) if self.current_sample["Y"] is not None: pred, self.current_sample["Y"] = obj else: pred = obj del obj self.current_sample["X"] = f_name( self.current_sample["X"], original_data_shape[:-1] + (self.current_sample["X"].shape[-1],), padding=self.cfg.DATA.TEST.PADDING, overlap=self.cfg.DATA.TEST.OVERLAP, verbose=self.cfg.TEST.VERBOSE, ) assert isinstance(pred, np.ndarray) if self.cfg.PROBLEM.NDIM != "3D": self.current_sample["X"] = X_original.copy() del X_original else: pred = np.expand_dims(pred, 0) self.current_sample["X"] = np.expand_dims(self.current_sample["X"], 0) if self.current_sample["Y"] is not None: self.current_sample["Y"] = np.expand_dims(self.current_sample["Y"], 0) # Resize to original shape if self.cfg.DATA.PREPROCESS.TEST and "rescaled_shape" in self.current_sample: rescaled_shape = (1,) + self.current_sample["rescaled_shape"][:-1]+(pred.shape[-1],) if self.cfg.TEST.VERBOSE: print( "Resizing prediction from {} to {}".format( pred.shape, rescaled_shape ) ) pred = resize_images( [pred], output_shape=rescaled_shape, order=self.cfg.DATA.PREPROCESS.RESIZE.ORDER, mode=self.cfg.DATA.PREPROCESS.RESIZE.MODE, cval=self.cfg.DATA.PREPROCESS.RESIZE.CVAL, clip=self.cfg.DATA.PREPROCESS.RESIZE.CLIP, preserve_range=self.cfg.DATA.PREPROCESS.RESIZE.PRESERVE_RANGE, anti_aliasing=self.cfg.DATA.PREPROCESS.RESIZE.ANTI_ALIASING, )[0] self.current_sample["X"] = resize_images( [self.current_sample["X"]], output_shape=rescaled_shape, order=self.cfg.DATA.PREPROCESS.RESIZE.ORDER, mode=self.cfg.DATA.PREPROCESS.RESIZE.MODE, cval=self.cfg.DATA.PREPROCESS.RESIZE.CVAL, clip=self.cfg.DATA.PREPROCESS.RESIZE.CLIP, preserve_range=self.cfg.DATA.PREPROCESS.RESIZE.PRESERVE_RANGE, anti_aliasing=self.cfg.DATA.PREPROCESS.RESIZE.ANTI_ALIASING, )[0] if self.current_sample["Y"] is not None: self.current_sample["Y"] = resize_images( [self.current_sample["Y"]], output_shape=self.current_sample["rescaled_shape"][:-1]+(self.current_sample["Y"].shape[-1],), order=0, mode=self.cfg.DATA.PREPROCESS.RESIZE.MODE, cval=self.cfg.DATA.PREPROCESS.RESIZE.CVAL, clip=self.cfg.DATA.PREPROCESS.RESIZE.CLIP, preserve_range=self.cfg.DATA.PREPROCESS.RESIZE.PRESERVE_RANGE, anti_aliasing=self.cfg.DATA.PREPROCESS.RESIZE.ANTI_ALIASING, )[0] if self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE: reflected_orig_shape = (1,) + self.current_sample["reflected_orig_shape"] if reflected_orig_shape != pred.shape: if self.cfg.TEST.VERBOSE: print( "Cropping prediction to original shape {}".format( self.current_sample["reflected_orig_shape"] ) ) if self.cfg.PROBLEM.NDIM == "2D": pred = pred[:, -reflected_orig_shape[1] :, -reflected_orig_shape[2] :] self.current_sample["X"] = self.current_sample["X"][ :, -reflected_orig_shape[1] :, -reflected_orig_shape[2] :, ] if self.current_sample["Y"] is not None: self.current_sample["Y"] = self.current_sample["Y"][ :, -reflected_orig_shape[1] :, -reflected_orig_shape[2] :, ] else: pred = pred[ :, -reflected_orig_shape[1] :, -reflected_orig_shape[2] :, -reflected_orig_shape[3] :, ] self.current_sample["X"] = self.current_sample["X"][ :, -reflected_orig_shape[1] :, -reflected_orig_shape[2] :, -reflected_orig_shape[3] :, ] if self.current_sample["Y"] is not None: self.current_sample["Y"] = self.current_sample["Y"][ :, -reflected_orig_shape[1] :, -reflected_orig_shape[2] :, -reflected_orig_shape[3] :, ] # Apply mask if self.cfg.TEST.POST_PROCESSING.APPLY_MASK: pred = np.expand_dims(apply_binary_mask(pred[0], self.cfg.DATA.TEST.BINARY_MASKS), 0) if self.separated_class_channel: class_idx = self.model_output_channel_info.index("class") if "class" in self.model_output_channel_info else -1 pred = np.concatenate( ( pred[...,:-self.model_output_channels[class_idx]], np.expand_dims(np.argmax(pred[..., -self.model_output_channels[class_idx]:], axis=-1), axis=-1) ), axis=-1) # Save image if self.cfg.PATHS.RESULT_DIR.PER_IMAGE != "" and self.cfg.TEST.SAVE_MODEL_RAW_OUTPUT: save_tif( pred, self.cfg.PATHS.RESULT_DIR.PER_IMAGE, [self.current_sample["X_filename"]], verbose=self.cfg.TEST.VERBOSE, ) # Calculate the metrics if self.current_sample["Y"] is not None: metric_values = self.metric_calculation(output=pred, targets=self.current_sample["Y"], train=False) for metric in metric_values: if str(metric).lower() not in self.stats["merge_patches"]: self.stats["merge_patches"][str(metric).lower()] = 0 self.stats["merge_patches"][str(metric).lower()] += metric_values[metric] self.current_sample_metrics[str(metric).lower()] = metric_values[metric] ############################ ### POST-PROCESSING (3D) ### ############################ if self.post_processing["per_image"]: pred = apply_post_processing(self.cfg, pred) # Calculate the metrics if self.current_sample["Y"] is not None: metric_values = self.metric_calculation( output=pred, targets=self.current_sample["Y"], train=False ) for metric in metric_values: if str(metric).lower() not in self.stats["merge_patches_post"]: self.stats["merge_patches_post"][str(metric).lower()] = 0 self.stats["merge_patches_post"][str(metric).lower()] += metric_values[metric] self.current_sample_metrics[str(metric).lower() + " (post-processing)"] = metric_values[metric] save_tif( pred, self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING, [self.current_sample["X_filename"]], verbose=self.cfg.TEST.VERBOSE, ) else: # Load prediction from file folder = ( self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING if self.post_processing["per_image"] else self.cfg.PATHS.RESULT_DIR.PER_IMAGE ) # read file created by 'save_tif' (it always has .tif extension) test_file = os.path.join(folder, os.path.splitext(self.current_sample["X_filename"])[0]+'.tif') pred = read_img_as_ndarray(test_file, is_3d=self.cfg.PROBLEM.NDIM == "3D") pred = np.expand_dims(pred, 0) # expand dimensions to include "batch" # Calculate the metrics if self.current_sample["Y"] is not None: metric_values = self.metric_calculation(output=pred, targets=self.current_sample["Y"], train=False) for metric in metric_values: if str(metric).lower() not in self.stats["merge_patches"]: self.stats["merge_patches"][str(metric).lower()] = 0 self.stats["merge_patches"][str(metric).lower()] += metric_values[metric] self.current_sample_metrics[str(metric).lower()] = metric_values[metric] self.after_merge_patches(pred) if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK: assert isinstance(self.all_pred, list) and isinstance(self.all_gt, list) self.all_pred.append(pred) if self.current_sample["Y"] is not None and self.all_gt is not None: self.all_gt.append(self.current_sample["Y"]) ################## ### FULL IMAGE ### ################## if self.cfg.TEST.FULL_IMG and self.cfg.PROBLEM.NDIM == "2D": self.current_sample["X"], o_test_shape = check_downsample_division( self.current_sample["X"], len(self.cfg.MODEL.FEATURE_MAPS) - 1 ) if not self.cfg.TEST.REUSE_PREDICTIONS: if self.current_sample["Y"] is not None: self.current_sample["Y"], _ = check_downsample_division( self.current_sample["Y"], len(self.cfg.MODEL.FEATURE_MAPS) - 1 ) # Make the prediction if self.cfg.TEST.AUGMENTATION: pred = ensemble8_2d_predictions( self.current_sample["X"][0], axes_order_back=self.axes_order_back, pred_func=self.model_call_func, axes_order=self.axes_order, device=self.test_device, mode=self.cfg.TEST.AUGMENTATION_MODE, ) else: pred = self.model_call_func(self.current_sample["X"]) # Multi-head concatenation if isinstance(pred, dict): if "class" in pred: pred = torch.cat((pred["pred"], torch.argmax(pred["class"], dim=1).unsqueeze(1)), dim=1) else: pred = pred["pred"] pred = to_numpy_format(pred, self.axes_order_back) del self.current_sample["X"] # Recover original shape if padded with check_downsample_division pred = pred[:, : o_test_shape[1], : o_test_shape[2]] if self.current_sample["Y"] is not None: self.current_sample["Y"] = self.current_sample["Y"][:, : o_test_shape[1], : o_test_shape[2]] # Save image save_tif( pred, self.cfg.PATHS.RESULT_DIR.FULL_IMAGE, [self.current_sample["X_filename"]], verbose=self.cfg.TEST.VERBOSE, ) if self.cfg.TEST.POST_PROCESSING.APPLY_MASK: pred = apply_binary_mask(pred, self.cfg.DATA.TEST.BINARY_MASKS) else: # load prediction from file # read file created by 'save_tif' (it always has .tif extension) test_file = os.path.join(self.cfg.PATHS.RESULT_DIR.FULL_IMAGE, os.path.splitext(self.current_sample["X_filename"])[0]+'.tif') pred = read_img_as_ndarray(test_file, is_3d=self.cfg.PROBLEM.NDIM == "3D") pred = np.expand_dims(pred, 0) # expand dimensions to include "batch" # Calculate the metrics if self.current_sample["Y"] is not None: metric_values = self.metric_calculation(output=pred, targets=self.current_sample["Y"], train=False) for metric in metric_values: if str(metric).lower() not in self.stats["full_image"]: self.stats["full_image"][str(metric).lower()] = 0 self.stats["full_image"][str(metric).lower()] += metric_values[metric] self.current_sample_metrics[str(metric).lower()] = metric_values[metric] if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK: assert isinstance(self.all_pred, list) and isinstance(self.all_gt, list) self.all_pred.append(pred) if self.current_sample["Y"] is not None and self.all_gt is not None: self.all_gt.append(self.current_sample["Y"]) self.after_full_image(pred)
[docs] def normalize_stats(self, image_counter): """ Normalize statistics. Parameters ---------- image_counter : int Number of images to average the metrics. """ # Per crop for metric in self.stats["per_crop"]: self.stats["per_crop"][metric] = ( self.stats["per_crop"][metric] / self.stats["patch_by_batch_counter"] if self.stats["patch_by_batch_counter"] != 0 else 0 ) # Merge patches for metric in self.stats["merge_patches"]: self.stats["merge_patches"][metric] = ( self.stats["merge_patches"][metric] / image_counter if image_counter != 0 else 0 ) # Full image for metric in self.stats["full_image"]: self.stats["full_image"][metric] = ( self.stats["full_image"][metric] / image_counter if image_counter != 0 else 0 ) if self.post_processing["per_image"]: for metric in self.stats["merge_patches_post"]: self.stats["merge_patches_post"][metric] = ( self.stats["merge_patches_post"][metric] / image_counter if image_counter != 0 else 0 )
[docs] def print_stats(self, image_counter): """ Print statistics. Parameters ---------- image_counter : int Number of images to call ``normalize_stats``. """ self.normalize_stats(image_counter) if self.cfg.DATA.TEST.LOAD_GT: if not self.cfg.TEST.FULL_IMG or (len(self.stats["per_crop"]) > 0 or len(self.stats["merge_patches"]) > 0): if len(self.stats["per_crop"]) > 0: for metric in self.test_metric_names: if metric.lower() in self.stats["per_crop"]: metric_name = "Foreground IoU" if metric == "IoU" else metric self.test_metrics_message += ( "Test {} (per patch): {}\n".format( metric_name, self.stats["per_crop"][metric.lower()], ) ) if len(self.stats["merge_patches"]) > 0: for metric in self.test_metric_names: if metric.lower() in self.stats["merge_patches"]: metric_name = "Foreground IoU" if metric == "IoU" else metric self.test_metrics_message += ( "Test {} (merge patches): {}\n".format( metric_name, self.stats["merge_patches"][metric.lower()], ) ) else: if len(self.stats["full_image"]) > 0: for metric in self.test_metric_names: if metric.lower() in self.stats["full_image"]: metric_name = "Foreground IoU" if metric == "IoU" else metric self.test_metrics_message += ( "Test {} (per image): {}\n".format( metric_name, self.stats["full_image"][metric.lower()], ) ) if self.post_processing["per_image"] and len(self.stats["merge_patches_post"]) > 0: for metric in self.test_metric_names: if metric.lower() in self.stats["merge_patches_post"]: metric_name = "Foreground IoU" if metric == "IoU" else metric self.test_metrics_message += ( "Test {} (merge patches - post-processing): {}\n".format( metric_name, self.stats["merge_patches_post"][metric.lower()], ) ) if self.post_processing["as_3D_stack"] and len(self.stats["as_3D_stack_post"]) > 0: for metric in self.test_metric_names: if metric.lower() in self.stats["as_3D_stack_post"]: metric_name = "Foreground IoU" if metric == "IoU" else metric self.test_metrics_message += ( "Test {} (as 3D stack - post-processing): {}\n".format( metric_name, self.stats["as_3D_stack_post"][metric.lower()], ) ) if self.test_metrics_message != "": for line in self.test_metrics_message.split("\n"): print(line) df_metrics = pd.DataFrame(self.metrics_per_test_file) os.makedirs(self.cfg.PATHS.RESULT_DIR.PATH, exist_ok=True) df_metrics.to_csv( os.path.join( self.cfg.PATHS.RESULT_DIR.PATH, "test_results_metrics.csv", ), index=False, )
[docs] @abstractmethod def after_merge_patches(self, pred): """ Place any code that needs to be done after merging all predicted patches into the original image. Parameters ---------- pred : Torch Tensor Model prediction. """ raise NotImplementedError
[docs] @abstractmethod def after_full_image(self, pred: NDArray): """ Place here any code that must be executed after generating the prediction by supplying the entire image to the model. To enable this, the model should be convolutional, and the image(s) should be in a 2D format. Using 3D images as direct inputs to the model is not feasible due to their large size. Parameters ---------- pred : NDArray Model prediction. """ raise NotImplementedError
[docs] def after_all_images(self): """Place here any code that must be done after predicting all images.""" ############################ ### POST-PROCESSING (2D) ### ############################ if self.post_processing["as_3D_stack"]: self.all_pred = np.expand_dims(np.concatenate(self.all_pred), 0) if self.cfg.DATA.TEST.LOAD_GT and self.all_gt is not None: self.all_gt = np.expand_dims(np.concatenate(self.all_gt), 0) save_tif( self.all_pred, self.cfg.PATHS.RESULT_DIR.AS_3D_STACK, verbose=self.cfg.TEST.VERBOSE, ) save_tif( (self.all_pred > 0.5).astype(np.uint8), self.cfg.PATHS.RESULT_DIR.AS_3D_STACK_BIN, verbose=self.cfg.TEST.VERBOSE, ) self.all_pred = apply_post_processing(self.cfg, self.all_pred) # Calculate the metrics if self.cfg.DATA.TEST.LOAD_GT: metric_values = self.metric_calculation(output=self.all_pred[0], targets=self.all_gt[0], train=False) for metric in metric_values: self.stats["as_3D_stack_post"][str(metric).lower()] = metric_values[metric] self.current_sample_metrics[str(metric).lower() + " as 3D stack (post-processing)"] = metric_values[metric] save_tif( self.all_pred, self.cfg.PATHS.RESULT_DIR.AS_3D_STACK_POST_PROCESSING, verbose=self.cfg.TEST.VERBOSE, )
######################### ### BY CHUNKS METHODS ### ######################### # The order of the execution of the "by chunks" methods is the following: # * 'process_test_sample_by_chunks': process a sample in the inference phase in "by chunks" setting, this is the main method # that calls the other three methods below # 1. 'after_one_chunk_raw_prediction': after predicting one chunk # Once the predictions for all the chunks are generated each workflow will process the generated zarr in an specific way. The # process is the following: # 2. 'after_all_chunk_prediction_workflow_process': process to be done after predicting all the chunks in all ranks # 2.1 'after_one_chunk_workflow_process': process a list of chunks # 3. 'after_all_chunk_prediction_workflow_process_master_rank': process to be done after predicting all the chunks # but only on the master rank. #
[docs] def process_test_sample_by_chunks(self): """ Process a sample in the inference phase. A final H5/Zarr file is created in "TZCYX" or "TZYXC" order depending on ``DATA.TEST.INPUT_IMG_AXES_ORDER`` ('T' is always included). """ if not self.cfg.TEST.REUSE_PREDICTIONS and "prediction" in self.cfg.TEST.BY_CHUNKS.PHASES: # Create the generator self.test_generator = create_chunked_test_generator( self.cfg, system_dict=self.system_dict, current_sample=self.current_sample, norm_module=self.norm_module, out_dir=self.cfg.PATHS.RESULT_DIR.PER_IMAGE, dtype_str=self.dtype_str, ) tgen: chunked_test_pair_data_generator = self.test_generator.dataset # type: ignore # Get parallel data shape is ZYX _, z_dim, _, y_dim, x_dim = order_dimensions( tgen.X_parallel_data.shape, self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER ) self.parallel_data_shape = [z_dim, y_dim, x_dim] samples_visited = {} for obj_list in self.test_generator: sampler_ids, img, mask, patch_in_data, added_pad, norm_extra_info = obj_list if self.cfg.TEST.VERBOSE: print( "[Rank {} ({})] Patch number {} processing patches {} from {}".format( get_rank(), os.getpid(), sampler_ids, patch_in_data, self.parallel_data_shape ) ) # Pass the batch through the model pred = self.predict_batches_in_test(img, mask, disable_tqdm=True) if self.separated_class_channel: class_idx = self.model_output_channel_info.index("class") if "class" in self.model_output_channel_info else -1 pred = np.concatenate( ( pred[...,:-self.model_output_channels[class_idx]], np.expand_dims(np.argmax(pred[..., -self.model_output_channels[class_idx]:], axis=-1), axis=-1) ), axis=-1) lbreaked = False for i in range(pred.shape[0]): # Break the loop as those samples were created just to complete the last batch if sampler_ids[i] < sampler_ids[0] or sampler_ids[i] in samples_visited: print( "[Rank {} ({})] Patch {} discarded".format( get_rank(), os.getpid(), sampler_ids[i], ) ) lbreaked = True break single_pred = pred[i] single_pred_pad = added_pad[i] single_patch_in_data = patch_in_data[i] self.after_one_chunk_raw_prediction( sampler_ids[i], single_pred, single_patch_in_data, single_pred_pad ) # Remove padding if added single_pred = single_pred[ single_pred_pad[0][0] : single_pred.shape[0] - single_pred_pad[0][1], single_pred_pad[1][0] : single_pred.shape[1] - single_pred_pad[1][1], single_pred_pad[2][0] : single_pred.shape[2] - single_pred_pad[2][1], ] # Insert into the data tgen.insert_patch_in_file(single_pred, single_patch_in_data) samples_visited[sampler_ids[i]] = True if lbreaked and sampler_ids[i] in samples_visited: print( "[Rank {} ({})] Finishing the loop. Seems that the patches are starting to repeat".format( get_rank(), os.getpid(), ) ) break # Wait until all threads are done so the main thread can create the full size image if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized(): if self.cfg.TEST.VERBOSE: print( f"[Rank {get_rank()} ({os.getpid()})] Finished predicting patches. Waiting for all ranks . . ." ) dist.barrier() tgen.close_open_files() # Only after everyone finished writing, optionally convert to TIF on rank0 if self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF: if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized(): dist.barrier() if is_main_process(): tgen.save_parallel_data_as_tif() if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized(): dist.barrier() self.after_all_chunk_prediction_workflow_process() if is_main_process(): self.after_all_chunk_prediction_workflow_process_master_rank() # Wait until all threads are done so the main thread can create the full size image if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized(): if self.cfg.TEST.VERBOSE: print(f"[Rank {get_rank()} ({os.getpid()})] Finished predicting sample. Waiting for all ranks . . .") dist.barrier()
[docs] def after_one_chunk_raw_prediction( self, chunk_id: int, chunk: NDArray, chunk_in_data: PatchCoords, added_pad: List[List[int]] ): """ Place any code that needs to be done after predicting one chunk of data in "by chunks" setting. Parameters ---------- chunk_id: int Chunk identifier. chunk : NDArray Predicted chunk patch_in_data : PatchCoords Global coordinates of the chunk. added_pad: List of list of ints Padding added to the chunk in each dimension. The order of dimensions is the same as the input image, and the order of the list is: [[pad_before_dim1, pad_after_dim1], [pad_before_dim2, pad_after_dim2], ...]. """ raise NotImplementedError
[docs] def after_one_chunk_workflow_process(self, chunks: List[NDArray], patch_in_data: List) -> Optional[List[NDArray]]: """ Process a list of chunks during inference in "by chunks" setting. Each workflow should have its own implementation of this method. Parameters ---------- chunks : List[NDArray] List of chunks. Expected axes are: ``(z, y, x, channels)`` for 3D and ``(y, x, channels)`` for 2D. patch_in_data : List[PatchCoords] Spatial coordinates of each chunk in the full volume. Returns ------- chunks : Optional[List[NDArray]] Processed chunks. """ raise NotImplementedError
[docs] def after_all_chunk_prediction_workflow_process(self): """ Place any code that needs to be done after predicting all patches in "by chunks" setting. This function is called on all ranks. """ print("Processing generated predictions . . .") # Create the generator fpath = os.path.join( self.cfg.PATHS.RESULT_DIR.PER_IMAGE, os.path.splitext(self.current_sample["X_filename"])[0] + ".zarr" ) self.test_generator = create_chunked_workflow_process_generator( self.cfg, system_dict=self.system_dict, model_predictions=fpath, out_dir=self.test_chunked_workflow_process_vars["out_dir"], dtype_str=self.test_chunked_workflow_process_vars["dtype_str"], ) tgen: chunked_workflow_process_generator = self.test_generator.dataset # type: ignore # Get parallel data shape is ZYX _, z_dim, _, y_dim, x_dim = order_dimensions( tgen.X_parallel_data.shape, self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER ) self.parallel_data_shape = [z_dim, y_dim, x_dim] samples_visited = {} for obj_list in self.test_generator: sampler_ids, chunks, patch_in_data = obj_list if self.cfg.TEST.VERBOSE: print( "[Rank {} ({})] Patch number {} processing patches {} from {}".format( get_rank(), os.getpid(), sampler_ids, patch_in_data, self.parallel_data_shape ) ) processed_chunks = self.after_one_chunk_workflow_process(chunks, patch_in_data) assert processed_chunks is not None, "The after_one_chunk_workflow_process() method must return a value." lbreaked = False for i in range(len(processed_chunks)): # Break the loop as those samples were created just to complete the last batch if sampler_ids[i] < sampler_ids[0] or sampler_ids[i] in samples_visited: print( "[Rank {} ({})] Patch {} discarded".format( get_rank(), os.getpid(), sampler_ids[i], ) ) lbreaked = True break tgen.insert_patch_in_file(processed_chunks[i], patch_in_data[i]) samples_visited[sampler_ids[i]] = True if lbreaked and sampler_ids[i] in samples_visited: print( "[Rank {} ({})] Finishing the loop. Seems that the patches are starting to repeat".format( get_rank(), os.getpid(), ) ) break # Wait until all threads are done so the main thread can create the full size image if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized(): if self.cfg.TEST.VERBOSE: print( f"[Rank {get_rank()} ({os.getpid()})] Finished predicting patches. Waiting for all ranks . . ." ) dist.barrier() tgen.close_open_files() # Only after everyone finished writing, optionally convert to TIF on rank0 if self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF: if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized(): dist.barrier() if is_main_process(): tgen.save_parallel_data_as_tif() if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized(): dist.barrier()
[docs] def after_all_chunk_prediction_workflow_process_master_rank(self): """ Place any code that needs to be done after predicting all the patches in the "by chunks" setting. This function is only called on the master rank. """ raise NotImplementedError