"""
Base workflow class for BiaPy.
This module defines the Base_Workflow abstract class, which provides the main
structure and utility methods for building training and inference workflows in BiaPy.
It handles configuration, model preparation, data loading, training, testing,
metrics, logging, and post-processing for both 2D and 3D biomedical image analysis.
"""
import math
import os
import datetime
import time
import json
import torch
import argparse
import numpy as np
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
import torch.distributed as dist
from typing import Any, Dict, Optional, List
from numpy.typing import NDArray
from yacs.config import CfgNode as CN
import pandas as pd
import biapy
from bioimageio.spec import load_description
from biapy.models import (
build_model,
build_torchvision_model,
build_bmz_model,
is_biapy_model,
get_bmz_model_kwargs,
check_bmz_args,
get_last_layer_info,
)
from biapy.models.blocks import get_activation
from biapy.engine import prepare_optimizer, build_callbacks
from biapy.data.generators import (
create_train_val_augmentors,
create_test_generator,
create_chunked_test_generator,
check_generator_consistence,
create_chunked_workflow_process_generator,
)
from biapy.utils.misc import (
get_world_size,
get_rank,
is_main_process,
save_model,
time_text,
load_model_checkpoint,
TensorboardLogger,
MetricLogger,
to_pytorch_format,
to_numpy_format,
is_dist_avail_and_initialized,
setup_for_distributed,
update_dict_with_existing_keys,
)
from biapy.engine.check_configuration import (
convert_old_model_cfg_to_current_version,
diff_between_configs,
check_configuration,
)
from biapy.utils.util import (
create_plots,
check_downsample_division,
get_cfg_key_value,
)
from biapy.engine.train_engine import train_one_epoch, evaluate
from biapy.data.data_2D_manipulation import (
crop_data_with_overlap,
merge_data_with_overlap,
)
from biapy.data.data_3D_manipulation import (
crop_3D_data_with_overlap,
merge_3D_data_with_overlap,
order_dimensions,
looks_like_hdf5,
)
from biapy.data.data_manipulation import (
load_and_prepare_train_data,
load_and_prepare_test_data,
read_img_as_ndarray,
save_tif,
resize,
)
from biapy.data.pre_processing import resize_images
from biapy.data.post_processing.post_processing import (
ensemble8_2d_predictions,
ensemble16_3d_predictions,
apply_binary_mask,
)
from biapy.data.post_processing import apply_post_processing
from biapy.data.pre_processing import preprocess_data
from biapy.data.norm import normalize_image, normalize_mask
from biapy.data.generators.chunked_test_pair_data_generator import chunked_test_pair_data_generator
from biapy.data.dataset import PatchCoords
from biapy.models.memory_bank import MemoryBank
[docs]
class Base_Workflow(metaclass=ABCMeta):
"""
Base workflow class. A new workflow should extend this class.
Parameters
----------
cfg : YACS configuration
Running configuration.
job_identifier : str
Complete name of the running job.
device : Torch device
Device used.
args : argparse.Namespace
Arguments used in BiaPy's call.
"""
def __init__(
self,
cfg: CN,
job_identifier: str,
device: torch.device,
system_dict: Dict[str, int],
args: argparse.Namespace,
):
"""
Initialize the Base_Workflow object.
Stores the configuration, device and job identifier, initializes every
workflow attribute and state variable, optionally merges configuration
coming from a BiaPy checkpoint or a BioImage Model Zoo (BMZ) model, and
builds the normalization settings.
The constructor calls :func:`~define_activations_and_channels` and
:func:`~define_metrics`, so both must succeed for the object to be created.
Parameters
----------
cfg : CN
Running configuration.
job_identifier : str
Complete name of the running job.
device : torch.device
Device used.
system_dict : dict
System dictionary containing:
* 'cpu_budget': int, Total CPU budget.
* 'cpu_per_rank': int, CPU budget per rank.
* 'main_threads': int, Number of main threads.
* 'num_workers_hint': int, Hint for the number of workers.
args : argparse.Namespace
Arguments used in BiaPy's call.
"""
self.cfg = cfg
self.args = args
self.job_identifier = job_identifier
self.device = device
self.system_dict = system_dict
# Device that predictions are moved to outside training; CPU when TEST.METRICS_IN_CPU is set
if self.cfg.TEST.METRICS_IN_CPU:
self.test_device = torch.device("cpu")
else:
self.test_device = device
self.original_test_mask_path = None
self.test_mask_filenames = None
self.cross_val_samples_ids = None
# Post-processing flags; both start disabled and may be enabled further below
self.post_processing = {}
self.post_processing["per_image"] = False
self.post_processing["as_3D_stack"] = False
self.model = None
self.model_build_kwargs = None
self.checkpoint_path = None
self.optimizer = None
self.model_prepared = False
# Half precision is selected when TEST.REDUCE_MEMORY is enabled
self.dtype = np.float32 if not self.cfg.TEST.REDUCE_MEMORY else np.float16
self.dtype_str = "float32" if not self.cfg.TEST.REDUCE_MEMORY else "float16"
self.loss_dtype = torch.float32
self.dims = 2 if self.cfg.PROBLEM.NDIM == "2D" else 3
self.apply_activations = True
# Ground truth is used if it is loaded for test or if validation is reused as test
self.use_gt = False
if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST:
self.use_gt = True
# Keep a copy of the originally configured paths in case they are needed later
self.orig_train_path = self.cfg.DATA.TRAIN.PATH
self.orig_train_mask_path = self.cfg.DATA.TRAIN.GT_PATH
self.orig_val_path = self.cfg.DATA.VAL.PATH
self.orig_val_mask_path = self.cfg.DATA.VAL.GT_PATH
self.all_pred = []
self.all_gt = []
self.stats = {}
# Per crop
self.stats["per_crop"] = {}
self.stats["patch_by_batch_counter"] = 0
# Merging the image
self.stats["merge_patches"] = {}
self.stats["merge_patches_post"] = {}
# As 3D stack
self.stats["as_3D_stack"] = {}
self.stats["as_3D_stack_post"] = {}
# Full image
self.stats["full_image"] = {}
self.stats["full_image_post"] = {}
# To store all the metrics for each test file in order to create a final csv file with the results
self.metrics_per_test_file = []
self.mask_path = ""
self.is_y_mask = False
# Output-head description; validated by define_activations_and_channels() below
self.model_output_channels = []
self.model_output_channel_info = []
self.head_activations = []
self.separated_class_channel = None
# Metric and loss definitions; validated by define_metrics() below
self.train_metrics = []
self.train_metric_best = []
self.train_metric_names = []
self.test_metrics = []
self.test_metric_best = []
self.test_metric_names = []
self.loss = None
self.memory_bank = None
# -1 is the "not defined" sentinel checked by define_activations_and_channels()
self.gt_channels_expected = -1
self.train_metrics_message = ""
self.test_metrics_message = ""
# For 2D problems a leading 1 is prepended to the configured test resolution
self.resolution: List[int | float] = list(self.cfg.DATA.TEST.RESOLUTION)
if self.cfg.PROBLEM.NDIM == "2D":
self.resolution = [
1,
] + self.resolution
self.world_size = get_world_size()
self.global_rank = get_rank()
# Test variables: the median filter decides which post-processing mode is flagged
if self.cfg.TEST.POST_PROCESSING.MEDIAN_FILTER:
if self.cfg.PROBLEM.NDIM == "2D":
if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK:
self.post_processing["as_3D_stack"] = True
else:
self.post_processing["per_image"] = True
else:
self.post_processing["per_image"] = True
# Define permute shapes to pass from Numpy axis order (Y,X,C) to Pytorch's (C,Y,X)
self.axes_order = (0, 3, 1, 2) if self.cfg.PROBLEM.NDIM == "2D" else (0, 4, 1, 2, 3)
self.axes_order_back = (0, 2, 3, 1) if self.cfg.PROBLEM.NDIM == "2D" else (0, 2, 3, 4, 1)
# Torchvision variables
self.torchvision_preprocessing = None
# Load pretrained model configuration if needed and check consistency with current configuration
self.bmz_config = {}
if self.cfg.MODEL.SOURCE == "biapy":
# Obtain model spec from checkpoint
if self.cfg.MODEL.LOAD_CHECKPOINT:
# Take cfg from the checkpoint (only the stored information is extracted, no model is passed)
# NOTE(review): 'biapy_ckpt_version' is not referenced again within this method
saved_cfg, biapy_ckpt_version = load_model_checkpoint(
cfg=self.cfg,
jobname=self.job_identifier,
model_without_ddp=None,
device=self.device,
just_extract_checkpoint_info=True,
skip_unmatched_layers=self.cfg.MODEL.SKIP_UNMATCHED_LAYERS,
)
assert isinstance(saved_cfg, CN), "There was an error loading the checkpoint configuration. The loaded configuration is not a YACS CfgNode object but of type {}. Check that the checkpoint file is not corrupted.".format(type(saved_cfg))
if saved_cfg:
if len(self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT) > 0:
print("Checkpoint file loaded. Extracting the following items (if available): {} . Checking consistency with current configuration . . .".format(", ".join(self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT)))
# Checks that this config and previous represent the same workflow
# NOTE(review): 'header_message' is not referenced again within this method
header_message = "There is an inconsistency between the configuration loaded from checkpoint and the actual one. Error:\n"
# Work on a clone so the configuration read from the checkpoint is left untouched
tmp_cfg = convert_old_model_cfg_to_current_version(saved_cfg.clone())
# Override model specs
if self.cfg.PROBLEM.PRINT_OLD_KEY_CHANGES:
print("The following changes were made in order to adapt the loaded input configuration from checkpoint into the current configuration version:")
diff_between_configs(saved_cfg, tmp_cfg)
if "weights" in self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT and "norm" not in self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT:
print("WARNING: Weights will be loaded from checkpoint but not normalization instructions. This can lead to inconsistent results if the normalization instructions in the current configuration are different from the ones used in the checkpoint. Consider adding 'norm' to 'MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT' to avoid this issue.")
if "norm" in self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT:
print("Normalization instructions will be loaded from checkpoint.")
update_dict_with_existing_keys(self.cfg["DATA"]["NORMALIZATION"], tmp_cfg["DATA"]["NORMALIZATION"])
if "model_arch" in self.cfg.MODEL.ITEMS_TO_LOAD_FROM_CHECKPOINT:
print("Model architecture will be loaded from checkpoint.")
# Save current model config
tmp_BMZ_config = self.cfg.MODEL.clone()
update_dict_with_existing_keys(self.cfg["MODEL"], tmp_cfg["MODEL"])
# Restore the BMZ and output checkpoint format settings of the current run
self.cfg["MODEL"]["BMZ"] = tmp_BMZ_config["BMZ"]
self.cfg["MODEL"]["OUT_CHECKPOINT_FORMAT"] = tmp_BMZ_config["OUT_CHECKPOINT_FORMAT"]
# Check if the merge is coherent
self.cfg["MODEL"]["LOAD_CHECKPOINT"] = True
self.cfg["MODEL"]["LOAD_MODEL_FROM_CHECKPOINT"] = False
check_configuration(self.cfg, self.job_identifier)
# Load BioImage Model Zoo pretrained model information
elif self.cfg.MODEL.SOURCE == "bmz":
self.bmz_config["preprocessing"], opts = check_bmz_args(self.cfg.MODEL.BMZ.SOURCE_MODEL_ID, self.cfg)
print("[BMZ] Overriding preprocessing steps to the ones fixed in BMZ model: {}".format(self.bmz_config["preprocessing"]))
# Adapt configuration to match the one defined in the RDF
# 'option_list' is a flat [key, value, key, value, ...] list, as consumed by merge_from_list()
option_list = []
for key, val in opts.items():
old_val = get_cfg_key_value(cfg, key)
change = False
# Not changing patch size if only the channel dimension is different
if "DATA.PATCH_SIZE" in key:
if old_val[:-1] != val[:-1]:
change = True
elif old_val != val:
change = True
if change:
print(f"[BMZ] Changed '{key}' from '{old_val}' to '{val}' as defined in the RDF")
option_list.append(key)
option_list.append(val)
print("Loading BioImage Model Zoo pretrained model . . .")
self.bmz_config["original_bmz_config"] = load_description(self.cfg.MODEL.BMZ.SOURCE_MODEL_ID)
self.cfg.merge_from_list(option_list)
# Check consistency of the resulting configuration after merging with the BMZ model configuration
check_configuration(self.cfg, self.job_identifier)
# Save number of channels to be created by the model
self.define_activations_and_channels()
# Define metrics
self.define_metrics()
# Normalization checks
print("Creating normalization module . . .")
self.norm_module = {
"type": cfg.DATA.NORMALIZATION.TYPE,
"mask_norm": "as_mask",
"out_dtype": "float32",
"percentile_clip": cfg.DATA.NORMALIZATION.PERC_CLIP.ENABLE,
"per_lower_bound": cfg.DATA.NORMALIZATION.PERC_CLIP.LOWER_PERC,
"per_upper_bound": cfg.DATA.NORMALIZATION.PERC_CLIP.UPPER_PERC,
"lower_bound_val": cfg.DATA.NORMALIZATION.PERC_CLIP.LOWER_VALUE,
"upper_bound_val": cfg.DATA.NORMALIZATION.PERC_CLIP.UPPER_VALUE,
"mean": cfg.DATA.NORMALIZATION.ZERO_MEAN_UNIT_VAR.MEAN_VAL,
"std": cfg.DATA.NORMALIZATION.ZERO_MEAN_UNIT_VAR.STD_VAL,
}
print("Normalization module created with the following configuration:")
for key, val in self.norm_module.items():
print(f" {key}: {val}")
# The test settings start as a shallow copy of the training ones and override two entries
self.test_norm_module = self.norm_module.copy()
self.test_norm_module["train_normalization"] = False
self.test_norm_module["out_dtype"] = "float32" if not cfg.TEST.REDUCE_MEMORY else "float16"
# Torchvision models get their own settings, with the type fixed to "scale_range"
if self.cfg.MODEL.SOURCE == "torchvision":
print("Creating normalization module . . .")
self.torchvision_norm = {
"type": "scale_range",
"mask_norm": "as_mask",
"out_dtype": "float32" if not cfg.TEST.REDUCE_MEMORY else "float16",
"percentile_clip": cfg.DATA.NORMALIZATION.PERC_CLIP.ENABLE,
"per_lower_bound": cfg.DATA.NORMALIZATION.PERC_CLIP.LOWER_PERC,
"per_upper_bound": cfg.DATA.NORMALIZATION.PERC_CLIP.UPPER_PERC,
"lower_bound_val": cfg.DATA.NORMALIZATION.PERC_CLIP.LOWER_VALUE,
"upper_bound_val": cfg.DATA.NORMALIZATION.PERC_CLIP.UPPER_VALUE,
}
print("Torchvision normalization module created with the following configuration:")
for key, val in self.torchvision_norm.items():
print(f" {key}: {val}")
# Chunked workflow process generator placeholder
self.test_chunked_workflow_process_vars = {
"out_dir": self.cfg.PATHS.RESULT_DIR.PER_IMAGE,
"dtype_str": self.dtype_str,
}
[docs]
def define_activations_and_channels(self):
    """
    Validate the output-head layout declared by the workflow.

    This base implementation only checks the attributes listed below; it does not
    assign them. It is called from ``__init__``, so a workflow is expected to have
    set them by then (presumably in an override that delegates here once done).

    Attributes checked
    ------------------
    self.model_output_channels : List of int
        Number of channels for each output head of the model. E.g. ``[3]`` for a model with one
        head outputting 3 channels, ``[1, 5]`` for a model with two heads outputting 1 and 5
        channels respectively, etc.

    self.model_output_channel_info : List of str
        Information about the output channels. One value per output head of the model.

    self.separated_class_channel : bool
        Whether a separated output channel for classification is expected.

    self.head_activations : List of str
        Activations to be applied to the model output. One value per output channel (not per
        output head). "linear" and "ce_sigmoid" will not be applied. E.g. ``["linear"]`` for a
        model with one channel, ``["linear", "sigmoid"]`` for a model with two channels, etc.

    self.gt_channels_expected : int
        Must be changed from its ``-1`` "not defined" value.

    Example of a consistent definition for a model with two output heads, the first with one
    channel and the second with three (four activations in total, one per channel)::

        self.model_output_channels = [1, 3]
        self.model_output_channel_info = ["mask", "class"]
        self.separated_class_channel = True
        self.head_activations = ["ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"]

    Raises
    ------
    ValueError
        If any attribute is missing or has the wrong type.
    AssertionError
        If the number of activations differs from the total number of output channels, or the
        number of channel-info entries differs from the number of output heads.
    """
    if not self.model_output_channels:
        raise ValueError(
            "'model_output_channels' needs to be defined. Correct define_activations_and_channels() function"
        )
    if not isinstance(self.model_output_channels, list):
        raise ValueError("'self.model_output_channels' must be a list")
    for x in self.model_output_channels:
        if not isinstance(x, int):
            raise ValueError("'self.model_output_channels' must be a list of integers")

    if self.model_output_channel_info is None:
        raise ValueError(
            "'model_output_channel_info' needs to be defined. Correct define_activations_and_channels() function"
        )
    if not isinstance(self.model_output_channel_info, list):
        raise ValueError("'self.model_output_channel_info' must be a list")
    for x in self.model_output_channel_info:
        if not isinstance(x, str):
            raise ValueError("'self.model_output_channel_info' must be a list of strings")

    if self.separated_class_channel is None:
        raise ValueError(
            "'separated_class_channel' needs to be defined. Correct define_activations_and_channels() function"
        )

    if not self.head_activations:
        raise ValueError(
            "'self.head_activations' needs to be defined. Correct define_activations_and_channels() function"
        )
    if not isinstance(self.head_activations, list):
        raise ValueError("'self.head_activations' must be a list of strings")
    for x in self.head_activations:
        if not isinstance(x, str):
            raise ValueError("'self.head_activations' must be a list of strings")

    # One activation is required per output channel, i.e. per unit of the summed head sizes.
    # The messages are parenthesised so the counts are really part of the assertion message
    # (as a separate statement the ".format(...)" line would be evaluated and discarded).
    head_number = sum(self.model_output_channels)
    assert len(self.head_activations) == head_number, (
        "Activations and output channels do not match. "
        f"{len(self.head_activations)} activations vs {head_number} output channels"
    )
    assert len(self.model_output_channels) == len(self.model_output_channel_info), (
        "Output channel info and output channels do not match. "
        f"{len(self.model_output_channel_info)} output channel info vs "
        f"{len(self.model_output_channels)} output channels"
    )

    if self.gt_channels_expected == -1:
        raise ValueError(
            "'gt_channels_expected' needs to be defined. Correct define_activations_and_channels() function"
        )
[docs]
def define_metrics(self):
    """
    Validate the metric and loss definitions of the workflow.

    This base implementation only checks the attributes below; it does not assign
    them. It is called from ``__init__``.

    Attributes checked
    ------------------
    self.train_metrics : List of functions
        Metrics to be calculated during model's training.

    self.train_metric_names : List of str
        Names of the metrics calculated during training.

    self.train_metric_best : List of str
        Which direction counts as "best" for each metric. Each entry must be "max" or "min".

    self.test_metrics : List of functions
        Metrics to be calculated during model's test/inference.

    self.test_metric_names : List of str
        Names of the metrics calculated during test/inference.

    self.loss : Function
        Loss function used during training and test.

    Raises
    ------
    ValueError
        If any of the attributes is empty (or ``self.loss`` is ``None``).
    AssertionError
        If ``self.train_metric_best`` holds a value other than "max" or "min".
    """
    # Training-side definitions are checked first, in this fixed order
    for attr_name in ("train_metrics", "train_metric_names", "train_metric_best"):
        if not getattr(self, attr_name):
            raise ValueError(f"'{attr_name}' needs to be defined. Correct define_metrics() function")

    allowed_modes = ["max", "min"]
    assert all(
        mode in allowed_modes for mode in self.train_metric_best
    ), "'train_metric_best' needs to be one between ['max', 'min']"

    # Then the test-side definitions
    for attr_name in ("test_metrics", "test_metric_names"):
        if not getattr(self, attr_name):
            raise ValueError(f"'{attr_name}' needs to be defined. Correct define_metrics() function")

    if self.loss is None:
        raise ValueError("'loss' needs to be defined. Correct define_metrics() function")
[docs]
@abstractmethod
def metric_calculation(
    self,
    output: NDArray | torch.Tensor,
    targets: NDArray | torch.Tensor,
    train: bool = True,
    metric_logger: Optional[MetricLogger] = None,
) -> Dict:
    """
    Compute the metrics declared through :func:`~define_metrics`.

    Abstract: every concrete workflow has to provide its own implementation.

    Parameters
    ----------
    output : NDArray or Torch Tensor
        Prediction of the model.

    targets : NDArray or Torch Tensor
        Ground truth to compare the prediction with.

    train : bool, optional
        Whether to calculate train or test metrics.

    metric_logger : MetricLogger, optional
        Class to be updated with the new metric(s) value(s) calculated.

    Returns
    -------
    metrics : dict
        Metric value(s) obtained for the given prediction. The exact content is
        defined by each workflow (presumably keyed by metric name; confirm against
        the concrete implementations).
    """
    raise NotImplementedError
[docs]
def prepare_targets(self, targets, batch):
    """
    Transform ``targets`` into the form needed to calculate the loss.

    The base implementation converts ``targets`` with ``to_pytorch_format`` using
    the workflow's axis order and device, keeping the dtype of ``targets``.

    Parameters
    ----------
    targets : Torch Tensor
        Ground truth to compare the prediction with.

    batch : Torch Tensor
        Not used by this base implementation; only the SSL workflow uses it.

    Returns
    -------
    targets : Torch tensor
        Resulting targets.
    """
    # 'batch' is deliberately ignored here; it is part of the signature for the SSL workflow
    target_dtype = targets.dtype
    converted_targets = to_pytorch_format(targets, self.axes_order, self.device, dtype=target_dtype)
    return converted_targets
[docs]
def load_train_data(self):
    """
    Load training and validation data.

    Fills ``self.X_train``, ``self.Y_train``, ``self.X_val`` and ``self.Y_val`` with
    the result of ``load_and_prepare_train_data``. When running distributed, waits on
    a barrier afterwards so every process has finished reading.
    """
    for banner_line in (
        "##########################",
        "# LOAD TRAINING DATA #",
        "##########################",
    ):
        print(banner_line)

    cfg = self.cfg
    data_cfg = cfg.DATA
    train_cfg = data_cfg.TRAIN
    val_cfg = data_cfg.VAL
    is_not_instance_seg = cfg.PROBLEM.TYPE != "INSTANCE_SEG"

    def zarr_information(split_cfg):
        # Same six entries for the train and validation splits, read from the given sub-config
        return {
            "raw_path": split_cfg.INPUT_ZARR_MULTIPLE_DATA_RAW_PATH,
            "gt_path": split_cfg.INPUT_ZARR_MULTIPLE_DATA_GT_PATH,
            "use_gt_path": is_not_instance_seg,
            "multiple_data_within_zarr": split_cfg.INPUT_ZARR_MULTIPLE_DATA,
            "input_img_axes": split_cfg.INPUT_IMG_AXES_ORDER,
            "input_mask_axes": split_cfg.INPUT_MASK_AXES_ORDER,
        }

    train_zarr_data_information = zarr_information(train_cfg)
    val_zarr_data_information = zarr_information(val_cfg)

    # Sample filtering and preprocessing are only forwarded for the splits that enable them
    train_filter = train_cfg.FILTER_SAMPLES
    val_filter = val_cfg.FILTER_SAMPLES
    preprocess_train = data_cfg.PREPROCESS.TRAIN
    preprocess_val = data_cfg.PREPROCESS.VAL

    loader_kwargs = dict(
        train_path=train_cfg.PATH,
        train_mask_path=self.mask_path,
        train_in_memory=train_cfg.IN_MEMORY,
        train_ov=train_cfg.OVERLAP,
        train_padding=train_cfg.PADDING,
        val_path=val_cfg.PATH,
        val_mask_path=val_cfg.GT_PATH,
        val_in_memory=val_cfg.IN_MEMORY,
        val_ov=val_cfg.OVERLAP,
        val_padding=val_cfg.PADDING,
        norm_module=self.norm_module,
        crop_shape=data_cfg.PATCH_SIZE,
        cross_val=val_cfg.CROSS_VAL,
        cross_val_nsplits=val_cfg.CROSS_VAL_NFOLD,
        cross_val_fold=val_cfg.CROSS_VAL_FOLD,
        val_split=val_cfg.SPLIT_TRAIN if val_cfg.FROM_TRAIN else 0.0,
        seed=cfg.SYSTEM.SEED,
        shuffle_val=val_cfg.RANDOM,
        train_preprocess_f=preprocess_data if preprocess_train else None,
        train_preprocess_cfg=data_cfg.PREPROCESS if preprocess_train else None,
        train_filter_props=train_filter.PROPS if train_filter.ENABLE else [],
        train_filter_vals=train_filter.VALUES if train_filter.ENABLE else [],
        train_filter_signs=train_filter.SIGNS if train_filter.ENABLE else [],
        val_preprocess_f=preprocess_data if preprocess_val else None,
        val_preprocess_cfg=data_cfg.PREPROCESS if preprocess_val else None,
        val_filter_props=val_filter.PROPS if val_filter.ENABLE else [],
        val_filter_vals=val_filter.VALUES if val_filter.ENABLE else [],
        val_filter_signs=val_filter.SIGNS if val_filter.ENABLE else [],
        filter_by_entire_image=data_cfg.FILTER_BY_IMAGE,
        norm_before_filter=train_filter.NORM_BEFORE,
        y_upscaling=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
        gt_channels_expected=self.gt_channels_expected,
        reflect_to_complete_shape=data_cfg.REFLECT_TO_COMPLETE_SHAPE,
        convert_to_rgb=data_cfg.FORCE_RGB,
        is_y_mask=self.is_y_mask,
        is_3d=(cfg.PROBLEM.NDIM == "3D"),
        train_zarr_data_information=train_zarr_data_information,
        val_zarr_data_information=val_zarr_data_information,
        multiple_raw_images=(
            cfg.PROBLEM.TYPE == "IMAGE_TO_IMAGE"
            and cfg.PROBLEM.IMAGE_TO_IMAGE.MULTIPLE_RAW_ONE_TARGET_LOADER
        ),
        save_filtered_images=data_cfg.SAVE_FILTERED_IMAGES,
        save_filtered_images_dir=cfg.PATHS.FIL_SAMPLES_DIR,
        save_filtered_images_num=data_cfg.SAVE_FILTERED_IMAGES_NUM,
    )

    self.X_train, self.Y_train, self.X_val, self.Y_val = load_and_prepare_train_data(**loader_kwargs)

    # Ensure all the processes have read the data
    if is_dist_avail_and_initialized():
        print("Waiting until all processes have read the data . . .")
        dist.barrier()
[docs]
def destroy_train_data(self):
    """
    Delete training variables to release memory.

    Removes the training/validation data and generator attributes from the
    instance when they exist. Attributes that were never set are skipped, so the
    method can be called at any point, and more than once.
    """
    print("Releasing memory . . .")
    # These names are instance attributes, so their presence has to be checked on
    # 'self'; looking them up in locals()/globals() never finds them and would
    # leave the data alive.
    for attr_name in (
        "X_train",
        "Y_train",
        "X_val",
        "Y_val",
        "train_generator",
        "val_generator",
    ):
        if hasattr(self, attr_name):
            delattr(self, attr_name)
[docs]
def prepare_train_generators(self):
    """
    Build training and validation generators.

    Does nothing unless ``TRAIN.ENABLE`` is set. Otherwise stores the generators,
    the number of training steps per epoch and three ``bmz_config`` entries
    ("test_input", "cover_raw", "cover_gt") returned by
    ``create_train_val_augmentors``. Both generators are then passed through
    ``check_generator_consistence`` when ``DATA.CHECK_GENERATORS`` is enabled and
    the problem type is not "CLASSIFICATION".
    """
    if not self.cfg.TRAIN.ENABLE:
        return

    for banner_line in (
        "##############################",
        "# PREPARE TRAIN GENERATORS #",
        "##############################",
    ):
        print(banner_line)

    augmentor_outputs = create_train_val_augmentors(
        self.cfg,
        system_dict=self.system_dict,
        X_train=self.X_train,
        X_val=self.X_val,
        Y_train=self.Y_train,
        Y_val=self.Y_val,
        norm_module=self.norm_module,
    )
    (
        self.train_generator,
        self.val_generator,
        self.num_training_steps_per_epoch,
        self.bmz_config["test_input"],
        self.bmz_config["cover_raw"],
        self.bmz_config["cover_gt"],
    ) = augmentor_outputs

    # The consistency check is optional and skipped for classification problems
    if not self.cfg.DATA.CHECK_GENERATORS or self.cfg.PROBLEM.TYPE == "CLASSIFICATION":
        return

    for generator, suffix in (
        (self.train_generator, "_train"),
        (self.val_generator, "_val"),
    ):
        check_generator_consistence(
            generator,
            self.cfg.PATHS.GEN_CHECKS + suffix,
            self.cfg.PATHS.GEN_MASK_CHECKS + suffix,
        )
[docs]
def bmz_model_call(self, in_img, is_train=False):
    """
    Run a BioImage Model Zoo model on ``in_img``.

    Parameters
    ----------
    in_img : torch.Tensor
        Input image to pass through the model.

    is_train : bool, optional
        Whether the call happens during training or inference. This base
        implementation does not use it.

    Returns
    -------
    prediction : torch.Tensor
        Image prediction, exactly as returned by the model.
    """
    bmz_model = self.model
    # The model must already be available (and truthy) before it can be called
    assert bmz_model
    return bmz_model(in_img)
[docs]
@abstractmethod
def torchvision_model_call(self, in_img: torch.Tensor, is_train=False) -> torch.Tensor:
    """
    Run a torchvision model on ``in_img``.

    Abstract: every concrete workflow has to provide its own implementation.

    Parameters
    ----------
    in_img : torch.Tensor
        Input image to pass through the model.

    is_train : bool, optional
        Whether the call happens during training or inference.

    Returns
    -------
    prediction : torch.Tensor
        Image prediction.
    """
    raise NotImplementedError
[docs]
def model_call_func(
self, in_img: NDArray | torch.Tensor, is_train: bool = False, apply_act: bool = True
) -> Any:
"""
Run the model selected by ``MODEL.SOURCE`` on ``in_img``.
The input is first converted with ``to_pytorch_format`` and then passed to the
BiaPy model, to :func:`~bmz_model_call` or to :func:`~torchvision_model_call`.
Parameters
----------
in_img : NDArray or torch.Tensor
Input image to pass through the model.
is_train : bool, optional
Whether if the call is during training or inference.
apply_act : bool, optional
Whether to apply activations or not. Only consulted for ``MODEL.SOURCE == "biapy"``;
the "bmz" path always calls ``apply_model_activations`` and the "torchvision" path
never does.
Returns
-------
prediction : torch.Tensor
Image prediction. For BiaPy models it can also be a dict or a list of outputs.
"""
in_img = to_pytorch_format(in_img, self.axes_order, self.device)
assert isinstance(in_img, torch.Tensor)
if self.cfg.MODEL.SOURCE == "biapy":
assert self.model
pred = self.model(in_img)
# Recover the original shape of the input, as not all the model return a prediction
# of the same size as the input image
# Skipped for contrastive training, the SSL "masking" pretext task, classification
# and super-resolution
if (
not (self.cfg.LOSS.CONTRAST.ENABLE and is_train)
and not (self.cfg.PROBLEM.TYPE == "SELF_SUPERVISED" and self.cfg.PROBLEM.SELF_SUPERVISED.PRETEXT_TASK.lower() == "masking")
and self.cfg.PROBLEM.TYPE not in ["CLASSIFICATION", "SUPER_RESOLUTION"]
):
if isinstance(pred, dict):
# "pred" is resized with bilinear/trilinear interpolation depending on PROBLEM.NDIM
if "pred" in pred and pred["pred"].shape[2:] != in_img.shape[2:]:
mode = "bilinear" if self.cfg.PROBLEM.NDIM == "2D" else "trilinear"
sshape = (in_img.shape[0],) + (pred["pred"].shape[1],) + in_img.shape[2:]
pred["pred"] = resize(pred["pred"], sshape, mode=mode)
# "class" is resized with nearest-neighbour interpolation instead
if "class" in pred and pred["class"].shape[2:] != in_img.shape[2:]:
sshape = (in_img.shape[0],) + (pred["class"].shape[1],) + in_img.shape[2:]
pred["class"] = resize(pred["class"], sshape, mode="nearest")
elif not isinstance(pred, list):
if pred.shape[2:] != in_img.shape[2:]:
mode = "bilinear" if self.cfg.PROBLEM.NDIM == "2D" else "trilinear"
sshape = (in_img.shape[0],) + (pred.shape[1],) + in_img.shape[2:]
pred = resize(pred, sshape, mode=mode)
# Allow multiple outputs
if isinstance(pred, list):
for i in range(len(pred)):
if apply_act:
pred[i] = self.apply_model_activations(pred[i], training=is_train)
else:
if apply_act:
pred = self.apply_model_activations(pred, training=is_train) # type: ignore
elif self.cfg.MODEL.SOURCE == "bmz":
pred = self.apply_model_activations(self.bmz_model_call(in_img, is_train), training=is_train)
elif self.cfg.MODEL.SOURCE == "torchvision":
pred = self.torchvision_model_call(in_img, is_train)
# Outside training, move tensor outputs (or the tensor values of a dict output) to the test device
if not is_train:
if isinstance(pred, dict):
for key in pred:
if torch.is_tensor(pred[key]):
pred[key] = pred[key].to(self.test_device)
elif torch.is_tensor(pred):
pred = pred.to(self.test_device)
return pred
[docs]
def prepare_model(self):
"""
Build the model.
Builds the network selected by ``MODEL.SOURCE`` ("biapy", "torchvision" or "bmz"),
wraps it in ``DistributedDataParallel`` when ``args.distributed`` is set and, for
BiaPy models with ``MODEL.LOAD_CHECKPOINT`` enabled, loads the checkpoint. Sets
``self.start_epoch`` (0 when no checkpoint is loaded). Returns immediately if the
model was already prepared.
"""
if self.model_prepared:
print("Model already prepared!")
return
print("###############")
print("# Build model #")
print("###############")
if self.cfg.MODEL.SOURCE == "biapy":
# build_model() also returns the information stored in 'bmz_config', the build kwargs
# and the network stride
(
self.model,
self.bmz_config["callable_model"],
self.bmz_config["collected_sources"],
self.bmz_config["all_import_lines"],
self.bmz_config["scanned_files"],
self.model_build_kwargs,
self.network_stride,
) = build_model(
self.cfg,
self.model_output_channels,
self.model_output_channel_info,
self.head_activations,
self.device
)
self.bmz_config["is_biapy_model"] = True
elif self.cfg.MODEL.SOURCE == "torchvision":
self.model, self.torchvision_preprocessing = build_torchvision_model(self.cfg, self.device)
self.bmz_config["is_biapy_model"] = False
# BioImage Model Zoo pretrained models
elif self.cfg.MODEL.SOURCE == "bmz":
self.model = build_bmz_model(self.cfg, self.bmz_config["original_bmz_config"], self.device)
# For old BMZ models uploaded in BMZ we need to explicitly insert the sigmoid activation as postprocessing
self.bmz_config["is_biapy_model"] = is_biapy_model(self.bmz_config["original_bmz_config"])
bmz_model_kwargs = get_bmz_model_kwargs(self.bmz_config["original_bmz_config"])
if self.bmz_config["is_biapy_model"] and "explicit_activation" not in bmz_model_kwargs:
self.bmz_config["postprocessing"] = []
if self.head_activations[0] in ["ce_sigmoid", "Sigmoid"]:
self.bmz_config["postprocessing"].append("sigmoid")
# When the model already ends in an activation layer, applying activations afterwards is switched off
layer_info = get_last_layer_info(self.model)
if layer_info['is_activation']:
print("[BMZ] Disabling manual activations after the model as the last layer of the model is already an activation function ({}).".format(layer_info['layer_type']))
self.apply_activations = False
# Keep a handle to the unwrapped model; it is the one handed to the checkpoint loader below
self.model_without_ddp = self.model
if self.args.distributed:
self.model = torch.nn.parallel.DistributedDataParallel(
self.model,
device_ids=[self.args.gpu],
find_unused_parameters=False,
)
self.model_without_ddp = self.model.module
self.model_prepared = True
# Load checkpoint if necessary
# NOTE(review): 'self.optimizer' is passed with whatever value it has at this point;
# __init__ sets it to None and train() creates the optimizer after calling prepare_model().
# Confirm whether optimizer state is expected to be restored here.
if self.cfg.MODEL.SOURCE == "biapy" and self.cfg.MODEL.LOAD_CHECKPOINT:
self.start_epoch, self.checkpoint_path = load_model_checkpoint(
cfg=self.cfg,
jobname=self.job_identifier,
model_without_ddp=self.model_without_ddp,
device=self.device,
optimizer=self.optimizer,
skip_unmatched_layers=self.cfg.MODEL.SKIP_UNMATCHED_LAYERS,
)
else:
self.start_epoch = 0
[docs]
def train(self):
    """
    Run the training phase.

    Loads the training data, prepares the model (when it is not prepared yet), the train/validation
    generators, the logging tool, the early-stopping callback and the optimizer/LR scheduler. Then, for
    every epoch in ``range(self.start_epoch, cfg.TRAIN.EPOCHS)``, it trains one epoch, saves periodic
    checkpoints, evaluates on the validation generator (when available), stores the checkpoint with the
    lowest validation loss, appends the epoch statistics to ``self.log_file`` and updates the plots.
    After the loop the BMZ test samples are prepared (if ``"test_output"`` is missing in
    ``self.bmz_config``) and the training data is released with ``self.destroy_train_data()``.
    """
    self.load_train_data()
    if not self.model_prepared:
        self.prepare_model()
    self.prepare_train_generators()
    self.prepare_logging_tool()
    self.early_stopping = build_callbacks(self.cfg)

    assert (
        self.start_epoch is not None and self.model is not None and self.model_without_ddp is not None and self.loss
    )
    assert isinstance(self.start_epoch, int), "'start_epoch' should be an integer"
    self.optimizer, self.lr_scheduler = prepare_optimizer(
        self.cfg, self.model_without_ddp, len(self.train_generator)
    )

    contrast_init_iter = 0
    if self.cfg.LOSS.CONTRAST.ENABLE:
        self.memory_bank = MemoryBank(
            num_classes=self.gt_channels_expected,
            memory_size=self.cfg.LOSS.CONTRAST.MEMORY_SIZE,
            feature_dims=self.cfg.LOSS.CONTRAST.PROJ_DIM,
            network_stride=self.network_stride,
            pixel_update_freq=self.cfg.LOSS.CONTRAST.PIXEL_UPD_FREQ,
            device=self.device,
            ignore_index=self.cfg.LOSS.IGNORE_INDEX,
        )
        self.memory_bank.to(self.device)
        # When to activate the contrastive loss
        contrast_init_iter = self.cfg.LOSS.CONTRAST.MEMORY_SIZE
        if self.cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine":
            contrast_init_iter += self.cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS

    print("#####################")
    print("# TRAIN THE MODEL #")
    print("#####################")
    print(f"Start training in epoch {self.start_epoch+1} - Total: {self.cfg.TRAIN.EPOCHS}")
    start_time = time.time()
    self.val_best_metric = np.zeros(len(self.train_metric_names), dtype=np.float32)
    self.val_best_loss = np.inf
    total_iters = 0
    # Stays None when the loop below does not run (e.g. start_epoch >= TRAIN.EPOCHS)
    train_stats = None
    # Text with the metrics of the best validation epoch. Initialized here so the "[Val]" print
    # cannot hit an unbound name when the first validation loss is not an improvement (e.g. NaN)
    best_metrics_msg = " "

    for epoch in range(self.start_epoch, self.cfg.TRAIN.EPOCHS):
        print("~~~ Epoch {}/{} ~~~\n".format(epoch + 1, self.cfg.TRAIN.EPOCHS))
        e_start = time.time()

        if self.args.distributed:
            self.train_generator.sampler.set_epoch(epoch)  # type: ignore
        if self.log_writer:
            self.log_writer.set_step(epoch * self.num_training_steps_per_epoch)

        # Train
        train_stats, iterations_done = train_one_epoch(
            self.cfg,
            model=self.model,
            model_call_func=self.model_call_func,
            loss_function=self.loss,
            metric_function=self.metric_calculation,
            prepare_targets=self.prepare_targets,
            data_loader=self.train_generator,
            optimizer=self.optimizer,
            device=self.device,
            epoch=epoch,
            log_writer=self.log_writer,
            lr_scheduler=self.lr_scheduler,
            verbose=self.cfg.TRAIN.VERBOSE,
            memory_bank=self.memory_bank,
            total_iters=total_iters,
            contrast_warmup_iters=contrast_init_iter,
        )
        total_iters += iterations_done

        # Save checkpoint
        if self.cfg.MODEL.SAVE_CKPT_FREQ != -1:
            # Parenthesized on purpose: 'and' binds tighter than 'or', so without the parentheses
            # the main-process check only guarded the "last epoch" case
            is_save_epoch = (
                (epoch + 1) % self.cfg.MODEL.SAVE_CKPT_FREQ == 0
                or epoch + 1 == self.cfg.TRAIN.EPOCHS
            )
            if is_save_epoch and is_main_process():
                save_model(
                    cfg=self.cfg,
                    biapy_version=biapy.__version__,
                    jobname=self.job_identifier,
                    model_without_ddp=self.model_without_ddp,
                    optimizer=self.optimizer,
                    epoch=epoch + 1,
                    model_build_kwargs=self.model_build_kwargs,
                    extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT,
                )

        # Validation
        if self.val_generator:
            test_stats = evaluate(
                self.cfg,
                model=self.model,
                model_call_func=self.model_call_func,
                loss_function=self.loss,
                metric_function=self.metric_calculation,
                prepare_targets=self.prepare_targets,
                epoch=epoch,
                data_loader=self.val_generator,
                lr_scheduler=self.lr_scheduler,
                memory_bank=self.memory_bank,
            )

            # Save checkpoint if val loss improved
            if test_stats["loss"] < self.val_best_loss:
                best_ckpt_path = os.path.join(
                    self.cfg.PATHS.CHECKPOINT,
                    "{}-checkpoint-best.pth".format(self.job_identifier),
                )
                print(
                    "Val loss improved from {} to {}, saving model to {}".format(
                        self.val_best_loss, test_stats["loss"], best_ckpt_path
                    )
                )
                best_metrics_msg = " "
                for idx, metric_name in enumerate(self.train_metric_names):
                    self.val_best_metric[idx] = test_stats[metric_name]
                    best_metrics_msg += f"{metric_name}: {self.val_best_metric[idx]:.4f} "
                self.val_best_loss = test_stats["loss"]

                if is_main_process():
                    self.checkpoint_path = save_model(
                        cfg=self.cfg,
                        biapy_version=biapy.__version__,
                        jobname=self.job_identifier,
                        model_without_ddp=self.model_without_ddp,
                        optimizer=self.optimizer,
                        epoch="best",
                        model_build_kwargs=self.model_build_kwargs,
                        extension=self.cfg.MODEL.OUT_CHECKPOINT_FORMAT,
                    )
            print(f"[Val] best loss: {self.val_best_loss:.4f} best " + best_metrics_msg)

            # Store validation stats
            if self.log_writer:
                self.log_writer.update(test_loss=test_stats["loss"], head="perf", step=epoch)
                # NOTE(review): every metric is logged under the same 'test_iou' key, whatever its
                # name is. Kept as is to not rename existing log entries - confirm if intended.
                for metric_name in self.train_metric_names:
                    self.log_writer.update(
                        test_iou=test_stats[metric_name],
                        head="perf",
                        step=epoch,
                    )

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                **{f"test_{k}": v for k, v in test_stats.items()},
                "epoch": epoch,
            }
        else:
            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                "epoch": epoch,
            }

        # Write statistics in the logging file
        if is_main_process():
            # Log epoch stats
            if self.log_writer:
                self.log_writer.flush()
            with open(self.log_file, mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

            # Create training plot
            self.plot_values["loss"].append(train_stats["loss"])
            if self.val_generator:
                self.plot_values["val_loss"].append(test_stats["loss"])
            for metric_name in self.train_metric_names:
                self.plot_values[metric_name].append(train_stats[metric_name])
                if self.val_generator:
                    self.plot_values["val_" + metric_name].append(test_stats[metric_name])
            if (epoch + 1) % self.cfg.LOG.CHART_CREATION_FREQ == 0:
                create_plots(
                    self.plot_values,
                    self.train_metric_names,
                    self.job_identifier,
                    self.cfg.PATHS.CHARTS,
                )

        if self.val_generator and self.early_stopping:
            self.early_stopping(test_stats["loss"])
            if self.early_stopping.early_stop:
                print("Early stopping")
                break

        e_end = time.time()
        t_epoch = e_end - e_start
        # Epoch 'epoch' (0-based) has just finished, so 'epoch + 1' epochs are done
        remaining_epochs = self.cfg.TRAIN.EPOCHS - (epoch + 1)
        print(
            "[Time] {} {}/{}\n".format(
                time_text(t_epoch),
                time_text(e_end - start_time),
                time_text((e_end - start_time) + (t_epoch * remaining_epochs)),
            )
        )

    total_time = time.time() - start_time
    self.total_training_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time: {}".format(self.total_training_time_str))

    # No train statistics are available if no epoch was run
    if train_stats is not None:
        self.train_metrics_message += ("Train loss: {}\n".format(train_stats["loss"]))
        for metric_name in self.train_metric_names:
            self.train_metrics_message += ("Train {}: {}\n".format(metric_name, train_stats[metric_name]))
    if self.val_generator:
        self.train_metrics_message += ("Validation loss: {}\n".format(self.val_best_loss))
        for idx, metric_name in enumerate(self.train_metric_names):
            self.train_metrics_message += ("Validation {}: {}\n".format(metric_name, self.val_best_metric[idx]))
    if self.train_metrics_message != "":
        for line in self.train_metrics_message.split("\n"):
            print(line)

    print("Finished Training")

    if is_dist_avail_and_initialized():
        print(f"[Rank {get_rank()} ({os.getpid()})] Process waiting (train finished, step 1) . . . ")
        dist.barrier()

    # Save output sample to export the model to BMZ
    if "test_output" not in self.bmz_config:
        assert self.model_without_ddp
        self.model_without_ddp.eval()

        # Load the checkpoint again so the BMZ sample is generated with the stored weights
        _ = load_model_checkpoint(
            cfg=self.cfg,
            jobname=self.job_identifier,
            model_without_ddp=self.model_without_ddp,
            device=self.device,
            skip_unmatched_layers=self.cfg.MODEL.SKIP_UNMATCHED_LAYERS,
        )

        # Save BMZ input/output so the user could export the model to BMZ later
        self.prepare_bmz_data(self.bmz_config["test_input"])

    self.destroy_train_data()
[docs]
def load_test_data(self):
    """
    Load the test data into ``self.X_test`` and ``self.Y_test``.

    When ``DATA.TEST.USE_VAL_AS_TEST`` is set, the training data information is loaded (with in-memory
    loading disabled for train and validation) and copies of the validation data are used as test data.
    Otherwise the data is read with ``load_and_prepare_test_data``, which also fills
    ``self.test_filenames``.
    """
    print("######################")
    print("# LOAD TEST DATA #")
    print("######################")
    self.X_test = None
    self.Y_test = None

    if self.cfg.DATA.TEST.USE_VAL_AS_TEST:
        print("Loading train data information to extract the validation to be used as test")
        self.cfg.merge_from_list(["DATA.TRAIN.IN_MEMORY", False, "DATA.VAL.IN_MEMORY", False])
        self.load_train_data()
        self.X_test = self.X_val.copy()
        if self.Y_val:
            self.Y_test = self.Y_val.copy()
        return

    # Paths to the raw and gt within the Zarr file. Only used when 'DATA.TEST.INPUT_ZARR_MULTIPLE_DATA' is True.
    zarr_info = None
    if self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA:
        # NOTE(review): with 'and', the GT path is only disabled when BOTH conditions hold - confirm
        # this is the intended rule and not "unless instance segmentation of synapses"
        gt_path_disabled = (
            self.cfg.PROBLEM.TYPE != "INSTANCE_SEG" and self.cfg.PROBLEM.INSTANCE_SEG.TYPE != "synapses"
        )
        zarr_info = {
            "raw_path": self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_RAW_PATH,
            "gt_path": self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_GT_PATH,
            "use_gt_path": not gt_path_disabled,
        }

    loaded = load_and_prepare_test_data(
        test_path=self.cfg.DATA.TEST.PATH,
        test_mask_path=self.cfg.DATA.TEST.GT_PATH if self.use_gt else None,
        multiple_raw_images=(
            self.cfg.PROBLEM.TYPE == "IMAGE_TO_IMAGE"
            and self.cfg.PROBLEM.IMAGE_TO_IMAGE.MULTIPLE_RAW_ONE_TARGET_LOADER
        ),
        test_zarr_data_information=zarr_info,
    )
    self.X_test, self.Y_test, self.test_filenames = loaded
[docs]
def destroy_test_data(self):
    """
    Delete the test-related attributes to release memory.

    Removes ``X_test``, ``Y_test``, ``test_generator`` and ``current_sample`` from the instance when they
    exist. Attributes that are not set are skipped, so the method can be called more than once.
    """
    print("Releasing memory . . .")
    # These are instance attributes, so they must be looked up on 'self'. Checking 'locals()' or
    # 'globals()' never finds them and nothing would be released.
    for attr_name in ("X_test", "Y_test", "test_generator", "current_sample"):
        if hasattr(self, attr_name):
            delattr(self, attr_name)
[docs]
def prepare_test_generators(self):
    """
    Prepare the test data generator.

    Does nothing unless ``TEST.ENABLE`` is set. Otherwise it builds ``self.test_generator`` with
    ``create_test_generator`` and stores the returned cover images and test input in ``self.bmz_config``
    when those entries are not there yet.
    """
    if not self.cfg.TEST.ENABLE:
        return

    print("############################")
    print("# PREPARE TEST GENERATOR #")
    print("############################")
    generator_outputs = create_test_generator(
        cfg=self.cfg,
        X_test=self.X_test,
        Y_test=self.Y_test,
        norm_module=self.test_norm_module,
    )
    self.test_generator, test_input, cover_raw, cover_gt = generator_outputs

    # Save BMZ data if not available. Both covers are overwritten as soon as one of them is missing.
    has_both_covers = "cover_raw" in self.bmz_config and "cover_gt" in self.bmz_config
    if not has_both_covers:
        self.bmz_config["cover_raw"] = cover_raw
        self.bmz_config["cover_gt"] = cover_gt
    if "test_input" not in self.bmz_config:
        self.bmz_config["test_input"] = test_input
[docs]
def apply_model_activations(self, pred: torch.Tensor | Dict, training=False) -> torch.Tensor | Dict:
"""
Apply the last activation (if any) to the model's output.
Parameters
----------
pred : Torch Tensor
Predictions of the model.
training : bool, optional
To advice the function if this is being applied during training of inference. During training,
``ce_sigmoid`` activations will NOT be applied, as ``torch.nn.BCEWithLogitsLoss`` will apply
``Sigmoid`` automatically in a way that is more stable numerically
(`ref <https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html>`_).
Returns
-------
pred : Torch tensor
Resulting predictions after applying last activation(s).
"""
# Do not apply any activation when using masking as pretext task
if (
not self.apply_activations
or (self.cfg.PROBLEM.TYPE == "SELF_SUPERVISED" and self.cfg.PROBLEM.SELF_SUPERVISED.PRETEXT_TASK.lower() == "masking")
):
return pred
# 1. Expand channel info to map 1-to-1 with every channel
all_channel_info = []
for i, c_info in enumerate(self.model_output_channel_info):
for _ in range(self.model_output_channels[i]):
# For semantic segmentation and classification problems, we consider that all the channels are "class" channels
if self.cfg.PROBLEM.TYPE in ["SEMANTIC_SEG", "CLASSIFICATION"]:
c_info = "class"
all_channel_info.append(c_info)
def __apply_acts(tensor, acts, c_infos):
if len(acts) == 0:
return tensor
out_slices = []
# Find the exact index where "class" channels begin
class_start_idx = len(c_infos)
for i, info in enumerate(c_infos):
if "class" in info.lower():
class_start_idx = i
break
# --- PART A: Process standard channels (1-by-1) ---
for i in range(min(class_start_idx, tensor.shape[1])):
chunk = tensor[:, i:i+1, ...]
act_str = acts[i].lower()
# Skip if linear, or if training and it's handled by the loss function
if act_str == "linear" or (training and act_str in ["ce_sigmoid", "ce_softmax"]):
out_slices.append(chunk)
else:
clean_act = "sigmoid" if act_str == "ce_sigmoid" else act_str
act_fn = get_activation(clean_act)
out_slices.append(act_fn(chunk))
# --- PART B: Process the entire "class" block (all at once) ---
if class_start_idx < tensor.shape[1]:
class_chunk = tensor[:, class_start_idx:, ...]
act_str = acts[class_start_idx].lower()
# Skip if linear, or if training and it's handled by the loss function
if act_str == "linear" or (training and act_str in ["ce_sigmoid", "ce_softmax"]):
out_slices.append(class_chunk)
else:
# Apply a single activation to the whole multidimensional block
clean_act = "softmax" if "softmax" in act_str else ("sigmoid" if act_str == "ce_sigmoid" else act_str)
act_fn = get_activation(clean_act)
out_slices.append(act_fn(class_chunk))
return torch.cat(out_slices, dim=1) if out_slices else tensor
# 2. Apply to inputs (handling both Dict and Tensor formats)
if isinstance(pred, dict):
pred_len = pred["pred"].shape[1]
pred["pred"] = __apply_acts(pred["pred"], self.head_activations[:pred_len], all_channel_info[:pred_len])
if "class" in pred:
class_len = pred["class"].shape[1]
pred["class"] = __apply_acts(pred["class"], self.head_activations[-class_len:], all_channel_info[-class_len:])
else:
pred = __apply_acts(pred, self.head_activations, all_channel_info)
return pred
[docs]
@torch.no_grad()
def test(self):
    """
    Run the test/inference phase.

    Loads the test data, prepares the model (when it is not prepared yet) and the test generator, and
    processes every sample yielded by ``self.test_generator``, either by chunks
    (``process_test_sample_by_chunks``) or as a whole (``process_test_sample``). Once all the samples
    are processed the test data is released and, in the main process, ``after_all_images`` is called
    and the training (if ``TRAIN.ENABLE``) and test results are printed.

    Raises
    ------
    ValueError
        If ``self.start_epoch`` is ``-1``, which is treated as a checkpoint loading problem.
    """
    self.load_test_data()
    if not self.model_prepared:
        self.prepare_model()
    self.prepare_test_generators()

    # Switch to evaluation mode
    assert self.model_without_ddp
    self.model_without_ddp.eval()

    # When no training was done the BMZ output sample does not exist yet
    if "test_output" not in self.bmz_config:
        # Save BMZ input/output so the user could export the model to BMZ later
        self.prepare_bmz_data(self.bmz_config["test_input"])

    # Load best checkpoint on validation
    if self.cfg.TRAIN.ENABLE:
        self.start_epoch, self.checkpoint_path = load_model_checkpoint(
            cfg=self.cfg,
            jobname=self.job_identifier,
            model_without_ddp=self.model_without_ddp,
            device=self.device,
            skip_unmatched_layers=self.cfg.MODEL.SKIP_UNMATCHED_LAYERS,
        )

    # Check possible checkpoint problems
    if self.start_epoch == -1:
        raise ValueError("There was a problem loading the checkpoint. Test phase aborted!")

    # Number of samples that were actually processed (i.e. not discarded)
    image_counter = 0

    print("###############")
    print("# INFERENCE #")
    print("###############")
    print("Making predictions on test data . . .")

    # Reactivate prints to see each rank progress
    if self.cfg.TEST.BY_CHUNKS.ENABLE and self.cfg.PROBLEM.NDIM == "3D":
        setup_for_distributed(True)

    # Process all the images. The current sample is kept in 'self.current_sample' so the
    # per-sample methods can access it.
    for i, self.current_sample in enumerate(self.test_generator):  # type: ignore
        self.current_sample_metrics = {"file": self.current_sample["X_filename"]}
        self.f_numbers = [i]
        # Samples without ground truth are marked with Y = None
        if "Y" not in self.current_sample:
            self.current_sample["Y"] = None

        # Decide whether to infer by chunks or not
        discarded = False
        _, file_extension = os.path.splitext(self.current_sample["X_filename"])
        if self.cfg.TEST.BY_CHUNKS.ENABLE and self.cfg.PROBLEM.NDIM == "3D":
            by_chunks = True
            # Only warn: the sample is still processed by chunks
            if not looks_like_hdf5(self.current_sample["X_filename"]) and file_extension not in [".zarr", ".n5"]:
                print(
                    "WARNING: You are not using an image format that can extract patches without loading it entirely into memory. "
                    "The image formats that support this are: '.hdf5', '.hdf', '.h5', '.zarr' and '.n5'. "
                )
        else:
            by_chunks = False

        if by_chunks:
            me = "[Rank {} ({})] Processing image (by chunks): {}".format(
                get_rank(), os.getpid(), self.current_sample["X_filename"]
            )
            if "Y_filename" in self.current_sample:
                me += " (GT: {})".format(self.current_sample["Y_filename"])
            print(me)
            self.process_test_sample_by_chunks()
        else:
            # NOTE(review): nesting reconstructed from an unindented source; whole-image inference
            # is assumed to run only in the main process - confirm
            if is_main_process():
                if self.cfg.PROBLEM.TYPE != "CLASSIFICATION":
                    me = "Processing image: {}".format(self.current_sample["X_filename"])
                    if "Y_filename" in self.current_sample:
                        me += " (GT: {})".format(self.current_sample["Y_filename"])
                    print(me)
                    print("Normalization used: {}".format(self.current_sample["X_norm"]))
                discarded = self.process_test_sample()

        # If process_test_sample() returns True means that the sample was skipped due to filter set
        # up with: DATA.TEST.FILTER_SAMPLES
        if discarded:
            print(" Skipping image: {}".format(self.current_sample["X_filename"]))
        else:
            image_counter += 1

        # NOTE(review): assumed to run for every sample, discarded ones included - confirm
        self.metrics_per_test_file.append(self.current_sample_metrics)

    # Only enable print for the main rank again
    if self.cfg.TEST.BY_CHUNKS.ENABLE and self.cfg.PROBLEM.NDIM == "3D":
        setup_for_distributed(is_main_process())

    self.destroy_test_data()

    if is_main_process():
        self.after_all_images()

        print("#############")
        print("# RESULTS #")
        print("#############")
        print("The values below represent the averages across all test samples")
        if self.cfg.TRAIN.ENABLE:
            # The epoch count is taken from the number of validation loss values stored
            print("Epoch number: {}".format(len(self.plot_values["val_loss"])))
            print("Train time (s): {}".format(self.total_training_time_str))
            print("Train loss: {}".format(np.min(self.plot_values["loss"])))
            for i in range(len(self.train_metric_names)):
                metric_name = (
                    "Foreground IoU" if self.train_metric_names[i] == "IoU" else self.train_metric_names[i]
                )
                # Report the best value of the metric, being "best" the max or the min
                # depending on 'self.train_metric_best'
                print(
                    "Train {}: {}".format(
                        metric_name,
                        (
                            np.max(self.plot_values[self.train_metric_names[i]])
                            if self.train_metric_best[i] == "max"
                            else np.min(self.plot_values[self.train_metric_names[i]])
                        ),
                    )
                )
            print("Validation loss: {}".format(self.val_best_loss))
            for i in range(len(self.train_metric_names)):
                metric_name = (
                    "Foreground IoU" if self.train_metric_names[i] == "IoU" else self.train_metric_names[i]
                )
                print(
                    "Validation {}: {}".format(
                        metric_name,
                        self.val_best_metric[i],
                    )
                )
        self.print_stats(image_counter)
[docs]
def predict_batches_in_test(
    self, x_batch: NDArray, y_batch: Optional[NDArray], stats_name="per_crop", disable_tqdm: bool = False
) -> NDArray:
    """
    Predict data for the test phase.

    With ``TEST.AUGMENTATION`` enabled the patches are predicted one by one using test-time
    augmentation (``ensemble8_2d_predictions`` for 2D, ``ensemble16_3d_predictions`` otherwise).
    If not, they are predicted in batches of ``TRAIN.BATCH_SIZE`` patches. When ``y_batch`` is given,
    the metrics of each prediction are accumulated in ``self.stats[stats_name]``.

    Parameters
    ----------
    x_batch : NDArray
        X data. Expected axes are: ``(num_patches, z, y, x, channels)`` for 3D and
        ``(num_patches, y, x, channels)`` for 2D.

    y_batch: NDArray
        Y data. Expected axes are: ``(num_patches, z, y, x, channels)`` for 3D and
        ``(num_patches, y, x, channels)`` for 2D. ``None`` to skip the metric calculation.

    stats_name : str, optional
        Name of the statistics to save.

    disable_tqdm : bool, optional
        Whether to disable tqdm or not.

    Returns
    -------
    pred : NDArray
        Predicted batch.

    Raises
    ------
    ValueError
        If ``x_batch`` contains no patches, as no prediction can be created.
    """

    def _merge_heads(out):
        # Multi-head concatenation: join 'pred' and 'class' outputs in the channel axis
        if isinstance(out, dict):
            if "class" in out:
                return torch.cat((out["pred"], out["class"]), dim=1)
            return out["pred"]
        return out

    def _accumulate_metrics(out, targets):
        # Add the metric values of this prediction to the running totals of 'stats_name'
        metric_values = self.metric_calculation(output=out, targets=targets, train=False)
        totals = self.stats[stats_name]
        for metric in metric_values:
            key = str(metric).lower()
            if key not in totals:
                totals[key] = 0
            totals[key] += metric_values[metric]

    n_patches = x_batch.shape[0]
    # Allocated with the first prediction, once the output shape is known
    pred = None

    if self.cfg.TEST.AUGMENTATION:
        for k in tqdm(range(n_patches), leave=False, disable=disable_tqdm):
            if self.cfg.PROBLEM.NDIM == "2D":
                p = ensemble8_2d_predictions(
                    x_batch[k],
                    axes_order_back=self.axes_order_back,
                    axes_order=self.axes_order,
                    device=self.test_device,
                    pred_func=self.model_call_func,
                    mode=self.cfg.TEST.AUGMENTATION_MODE,
                )
            else:
                p = ensemble16_3d_predictions(
                    x_batch[k],
                    batch_size_value=self.cfg.TRAIN.BATCH_SIZE,
                    axes_order_back=self.axes_order_back,
                    axes_order=self.axes_order,
                    device=self.test_device,
                    pred_func=self.model_call_func,
                    mode=self.cfg.TEST.AUGMENTATION_MODE,
                )
            p = _merge_heads(p)

            # Calculate the metrics
            if y_batch is not None:
                _accumulate_metrics(p, np.expand_dims(y_batch[k], 0))
            self.stats["patch_by_batch_counter"] += 1

            p = to_numpy_format(p, self.axes_order_back)
            if pred is None:
                pred = np.zeros((n_patches,) + p.shape[1:], dtype=self.dtype)
            pred[k] = p
    else:
        batch_size = self.cfg.TRAIN.BATCH_SIZE
        n_batches = int(math.ceil(n_patches / batch_size))
        for k in tqdm(range(n_batches), leave=False, disable=disable_tqdm):
            start = k * batch_size
            # The last batch may be smaller than the batch size
            top = min((k + 1) * batch_size, n_patches)

            p = _merge_heads(self.model_call_func(x_batch[start:top]))

            # Calculate the metrics
            if y_batch is not None:
                _accumulate_metrics(p, y_batch[start:top])
            self.stats["patch_by_batch_counter"] += 1

            p = to_numpy_format(p, self.axes_order_back)
            if pred is None:
                pred = np.zeros((n_patches,) + p.shape[1:], dtype=self.dtype)
            pred[start:top] = p

    if pred is None:
        raise ValueError("'x_batch' contains no patches, so there is nothing to predict")
    return pred
[docs]
def prepare_bmz_data(self, img):
    """
    Prepare required data for exporting a model into BMZ.

    Fills ``self.bmz_config`` with the ``"test_input"`` and ``"test_input_norm"`` samples (when missing),
    runs the model on the normalized input to create ``"test_output"`` and, if there is no ``"cover_gt"``
    yet, derives one from the model output.

    Parameters
    ----------
    img : 4D/5D Numpy array
        Image to save (unnormalized). The axes must be in Torch format already, i.e. ``(b,c,y,x)`` for 2D or
        ``(b,c,z,y,x)`` for 3D.
    """

    def _store_sample(sample_key, img, apply_norm=True):
        """
        Prepare a sample from the given ``img`` and save it in ``self.bmz_config[sample_key]``.

        Parameters
        ----------
        sample_key : str
            Key to store the sample into, e.g. ``"test_input"``, ``"test_input_norm"`` or ``"test_output"``.

        img : Numpy array
            Image to extract the sample from. The axes must be in Torch format already, i.e.
            ``(b,c,y,x)`` for 2D or ``(b,c,z,y,x)`` for 3D. 2D arrays (classification) are stored whole.

        apply_norm : bool, optional
            Whether to normalize the sample with ``self.norm_module``.
        """
        img = img.astype(np.float32)
        # Classification outputs are kept whole; for images only the first item of the batch is used
        sample = img.copy() if len(img.shape) == 2 else img[0].copy()

        # Ensure dimensions: put the batch axis back if it was dropped above
        ndim_without_batch = 3 if self.cfg.PROBLEM.NDIM == "2D" else 4
        if sample.ndim == ndim_without_batch:
            sample = np.expand_dims(sample, 0)
        self.bmz_config[sample_key] = sample

        if not apply_norm:
            return

        # Normalization works with channels last, so transpose before it and undo it after
        if self.cfg.PROBLEM.NDIM == "2D":
            channels_last, channels_first = (0, 2, 3, 1), (0, 3, 1, 2)
        else:
            channels_last, channels_first = (0, 2, 3, 4, 1), (0, 4, 1, 2, 3)
        normalized, bmz_norm_used = normalize_image(sample.transpose(*channels_last), self.norm_module)
        self.bmz_config[sample_key] = normalized.astype(np.float32).transpose(*channels_first)
        print("Normalization used in when creating BMZ data: {}".format(bmz_norm_used))

    def _model_input():
        # Fresh tensor of the normalized sample placed in the working device
        return torch.from_numpy(self.bmz_config["test_input_norm"]).to(self.device)

    # Save test_input without the normalization if not already saved
    if "test_input" not in self.bmz_config:
        _store_sample("test_input", img, apply_norm=False)

    # Save test_input with the normalization
    if "test_input_norm" not in self.bmz_config:
        _store_sample("test_input_norm", img)

    # Model prediction
    assert self.model and self.model_without_ddp
    assert isinstance(self.bmz_config["test_input_norm"], np.ndarray)
    pred = self.model(_model_input())

    # An output dict with a 'mask' entry is treated as the MAE (masking pretext task) case
    is_mae_output = isinstance(pred, dict) and "mask" in pred
    pred = self.apply_model_activations(pred)
    if is_mae_output:
        assert isinstance(pred, dict), "The model output should be a dictionary containing 'pred' and 'mask' for the MAE pretext task."
        assert "pred" in pred and "mask" in pred, "The model output should contain 'pred' and 'mask' for the MAE pretext task."
        pred, p_mask, _ = self.model_without_ddp.save_images(
            _model_input(),
            pred["pred"],
            pred["mask"],
            self.dtype,
        )  # type: ignore

        # We call MAE again with "return_just_preds" that sets a seed in the random masking to ensure that the same mask is applied here
        # and in the BMZ exported model. If not BMZ check will crash reproducing the output
        pred = self.model(_model_input(), return_just_preds=True)

    # Multi-head concatenation. The class head is reduced to the index of the winning class
    if isinstance(pred, dict):
        if "class" in pred:
            class_ids = torch.argmax(pred["class"], dim=1).unsqueeze(1)
            pred = torch.cat((pred["pred"], class_ids), dim=1)
        else:
            pred = pred["pred"]

    # Save output
    _store_sample("test_output", pred.clone().cpu().detach().numpy().astype(np.float32), apply_norm=False)

    # Create the cover GT from the output when it is missing or None
    if self.bmz_config.get("cover_gt") is None:
        test_output = self.bmz_config["test_output"]
        if test_output.ndim == 1:  # Classification
            self.bmz_config["cover_gt"] = test_output.copy()
        elif self.cfg.PROBLEM.TYPE == "SELF_SUPERVISED" and self.cfg.PROBLEM.SELF_SUPERVISED.PRETEXT_TASK.lower() == "masking":
            self.bmz_config["cover_gt"] = p_mask[0].copy()
        else:
            # Move the channel axis to the end and take the first item of the batch
            spatial_axes = range(2, test_output.ndim)
            cover = test_output.copy().transpose(0, *spatial_axes, 1)[0]
            if self.cfg.DATA.N_CLASSES > 2 and self.cfg.PROBLEM.TYPE == "SEMANTIC_SEG":
                cover = np.expand_dims(np.argmax(cover, -1), -1)
            self.bmz_config["cover_gt"] = cover
[docs]
def process_test_sample(self):
"""Process a sample in the inference phase."""
# Skip processing image
if "discard" in self.current_sample and self.current_sample["discard"]:
return True
#################
### PER PATCH ###
#################
if not self.cfg.TEST.FULL_IMG or self.cfg.PROBLEM.NDIM == "3D":
if not self.cfg.TEST.REUSE_PREDICTIONS:
original_data_shape = self.current_sample["X"].shape
# Crop if necessary
if self.current_sample["X"].shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]:
# Copy X to be used later in full image
if self.cfg.PROBLEM.NDIM != "3D":
X_original = self.current_sample["X"].copy()
if (
self.current_sample["Y"] is not None
and self.current_sample["X"].shape[:-1] != self.current_sample["Y"].shape[:-1]
):
raise ValueError(
"Image {} ({}) and mask {} ({}) differ in shape (without considering the channels, i.e. last dimension). "
"Please check the images.".format(
self.current_sample["X"].shape, self.current_sample['X_filename'],
self.current_sample["Y"].shape, self.current_sample['Y_filename']
)
)
if self.cfg.PROBLEM.NDIM == "2D":
obj = crop_data_with_overlap(
self.current_sample["X"],
self.cfg.DATA.PATCH_SIZE,
data_mask=self.current_sample["Y"],
overlap=self.cfg.DATA.TEST.OVERLAP,
padding=self.cfg.DATA.TEST.PADDING,
verbose=self.cfg.TEST.VERBOSE,
)
if self.current_sample["Y"] is not None:
self.current_sample["X"], self.current_sample["Y"], _ = obj # type: ignore
else:
self.current_sample["X"], _ = obj # type: ignore
del obj
else:
if self.cfg.TEST.REDUCE_MEMORY:
self.current_sample["X"], _ = crop_3D_data_with_overlap( # type: ignore
self.current_sample["X"][0],
self.cfg.DATA.PATCH_SIZE,
overlap=self.cfg.DATA.TEST.OVERLAP,
padding=self.cfg.DATA.TEST.PADDING,
verbose=self.cfg.TEST.VERBOSE,
median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING,
)
if self.current_sample["Y"] is not None:
self.current_sample["Y"], _ = crop_3D_data_with_overlap( # type: ignore
self.current_sample["Y"][0],
self.cfg.DATA.PATCH_SIZE[:-1] + (self.current_sample["Y"].shape[-1],),
overlap=self.cfg.DATA.TEST.OVERLAP,
padding=self.cfg.DATA.TEST.PADDING,
verbose=self.cfg.TEST.VERBOSE,
median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING,
)
else:
if self.current_sample["Y"] is not None:
self.current_sample["Y"] = self.current_sample["Y"][0]
obj = crop_3D_data_with_overlap(
self.current_sample["X"][0],
self.cfg.DATA.PATCH_SIZE,
data_mask=self.current_sample["Y"],
overlap=self.cfg.DATA.TEST.OVERLAP,
padding=self.cfg.DATA.TEST.PADDING,
verbose=self.cfg.TEST.VERBOSE,
median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING,
)
if self.current_sample["Y"] is not None:
self.current_sample["X"], self.current_sample["Y"], _ = obj # type: ignore
else:
self.current_sample["X"], _ = obj # type: ignore
del obj
pred = self.predict_batches_in_test(self.current_sample["X"], self.current_sample["Y"])
# Reconstruct the predictions
if original_data_shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]:
if self.cfg.PROBLEM.NDIM == "3D":
original_data_shape = original_data_shape[1:]
f_name = merge_data_with_overlap if self.cfg.PROBLEM.NDIM == "2D" else merge_3D_data_with_overlap
if self.cfg.TEST.REDUCE_MEMORY:
pred = f_name(
pred,
original_data_shape[:-1] + (pred.shape[-1],),
padding=self.cfg.DATA.TEST.PADDING,
overlap=self.cfg.DATA.TEST.OVERLAP,
verbose=self.cfg.TEST.VERBOSE,
)
if self.current_sample["Y"] is not None:
self.current_sample["Y"] = f_name(
self.current_sample["Y"],
original_data_shape[:-1] + (self.current_sample["Y"].shape[-1],),
padding=self.cfg.DATA.TEST.PADDING,
overlap=self.cfg.DATA.TEST.OVERLAP,
verbose=self.cfg.TEST.VERBOSE,
)
else:
obj = f_name(
pred,
original_data_shape[:-1] + (pred.shape[-1],),
data_mask=self.current_sample["Y"],
padding=self.cfg.DATA.TEST.PADDING,
overlap=self.cfg.DATA.TEST.OVERLAP,
verbose=self.cfg.TEST.VERBOSE,
)
if self.current_sample["Y"] is not None:
pred, self.current_sample["Y"] = obj
else:
pred = obj
del obj
self.current_sample["X"] = f_name(
self.current_sample["X"],
original_data_shape[:-1] + (self.current_sample["X"].shape[-1],),
padding=self.cfg.DATA.TEST.PADDING,
overlap=self.cfg.DATA.TEST.OVERLAP,
verbose=self.cfg.TEST.VERBOSE,
)
assert isinstance(pred, np.ndarray)
if self.cfg.PROBLEM.NDIM != "3D":
self.current_sample["X"] = X_original.copy()
del X_original
else:
pred = np.expand_dims(pred, 0)
self.current_sample["X"] = np.expand_dims(self.current_sample["X"], 0)
if self.current_sample["Y"] is not None:
self.current_sample["Y"] = np.expand_dims(self.current_sample["Y"], 0)
# Resize to original shape
if self.cfg.DATA.PREPROCESS.TEST and "rescaled_shape" in self.current_sample:
rescaled_shape = (1,) + self.current_sample["rescaled_shape"][:-1]+(pred.shape[-1],)
if self.cfg.TEST.VERBOSE:
print(
"Resizing prediction from {} to {}".format(
pred.shape, rescaled_shape
)
)
pred = resize_images(
[pred],
output_shape=rescaled_shape,
order=self.cfg.DATA.PREPROCESS.RESIZE.ORDER,
mode=self.cfg.DATA.PREPROCESS.RESIZE.MODE,
cval=self.cfg.DATA.PREPROCESS.RESIZE.CVAL,
clip=self.cfg.DATA.PREPROCESS.RESIZE.CLIP,
preserve_range=self.cfg.DATA.PREPROCESS.RESIZE.PRESERVE_RANGE,
anti_aliasing=self.cfg.DATA.PREPROCESS.RESIZE.ANTI_ALIASING,
)[0]
self.current_sample["X"] = resize_images(
[self.current_sample["X"]],
output_shape=rescaled_shape,
order=self.cfg.DATA.PREPROCESS.RESIZE.ORDER,
mode=self.cfg.DATA.PREPROCESS.RESIZE.MODE,
cval=self.cfg.DATA.PREPROCESS.RESIZE.CVAL,
clip=self.cfg.DATA.PREPROCESS.RESIZE.CLIP,
preserve_range=self.cfg.DATA.PREPROCESS.RESIZE.PRESERVE_RANGE,
anti_aliasing=self.cfg.DATA.PREPROCESS.RESIZE.ANTI_ALIASING,
)[0]
if self.current_sample["Y"] is not None:
self.current_sample["Y"] = resize_images(
[self.current_sample["Y"]],
output_shape=self.current_sample["rescaled_shape"][:-1]+(self.current_sample["Y"].shape[-1],),
order=0,
mode=self.cfg.DATA.PREPROCESS.RESIZE.MODE,
cval=self.cfg.DATA.PREPROCESS.RESIZE.CVAL,
clip=self.cfg.DATA.PREPROCESS.RESIZE.CLIP,
preserve_range=self.cfg.DATA.PREPROCESS.RESIZE.PRESERVE_RANGE,
anti_aliasing=self.cfg.DATA.PREPROCESS.RESIZE.ANTI_ALIASING,
)[0]
if self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE:
reflected_orig_shape = (1,) + self.current_sample["reflected_orig_shape"]
if reflected_orig_shape != pred.shape:
if self.cfg.TEST.VERBOSE:
print(
"Cropping prediction to original shape {}".format(
self.current_sample["reflected_orig_shape"]
)
)
if self.cfg.PROBLEM.NDIM == "2D":
pred = pred[:, -reflected_orig_shape[1] :, -reflected_orig_shape[2] :]
self.current_sample["X"] = self.current_sample["X"][
:,
-reflected_orig_shape[1] :,
-reflected_orig_shape[2] :,
]
if self.current_sample["Y"] is not None:
self.current_sample["Y"] = self.current_sample["Y"][
:,
-reflected_orig_shape[1] :,
-reflected_orig_shape[2] :,
]
else:
pred = pred[
:,
-reflected_orig_shape[1] :,
-reflected_orig_shape[2] :,
-reflected_orig_shape[3] :,
]
self.current_sample["X"] = self.current_sample["X"][
:,
-reflected_orig_shape[1] :,
-reflected_orig_shape[2] :,
-reflected_orig_shape[3] :,
]
if self.current_sample["Y"] is not None:
self.current_sample["Y"] = self.current_sample["Y"][
:,
-reflected_orig_shape[1] :,
-reflected_orig_shape[2] :,
-reflected_orig_shape[3] :,
]
# Apply mask
if self.cfg.TEST.POST_PROCESSING.APPLY_MASK:
pred = np.expand_dims(apply_binary_mask(pred[0], self.cfg.DATA.TEST.BINARY_MASKS), 0)
if self.separated_class_channel:
class_idx = self.model_output_channel_info.index("class") if "class" in self.model_output_channel_info else -1
pred = np.concatenate(
(
pred[...,:-self.model_output_channels[class_idx]],
np.expand_dims(np.argmax(pred[..., -self.model_output_channels[class_idx]:], axis=-1), axis=-1)
), axis=-1)
# Save image
if self.cfg.PATHS.RESULT_DIR.PER_IMAGE != "" and self.cfg.TEST.SAVE_MODEL_RAW_OUTPUT:
save_tif(
pred,
self.cfg.PATHS.RESULT_DIR.PER_IMAGE,
[self.current_sample["X_filename"]],
verbose=self.cfg.TEST.VERBOSE,
)
# Calculate the metrics
if self.current_sample["Y"] is not None:
metric_values = self.metric_calculation(output=pred, targets=self.current_sample["Y"], train=False)
for metric in metric_values:
if str(metric).lower() not in self.stats["merge_patches"]:
self.stats["merge_patches"][str(metric).lower()] = 0
self.stats["merge_patches"][str(metric).lower()] += metric_values[metric]
self.current_sample_metrics[str(metric).lower()] = metric_values[metric]
############################
### POST-PROCESSING (3D) ###
############################
if self.post_processing["per_image"]:
pred = apply_post_processing(self.cfg, pred)
# Calculate the metrics
if self.current_sample["Y"] is not None:
metric_values = self.metric_calculation(
output=pred, targets=self.current_sample["Y"], train=False
)
for metric in metric_values:
if str(metric).lower() not in self.stats["merge_patches_post"]:
self.stats["merge_patches_post"][str(metric).lower()] = 0
self.stats["merge_patches_post"][str(metric).lower()] += metric_values[metric]
self.current_sample_metrics[str(metric).lower() + " (post-processing)"] = metric_values[metric]
save_tif(
pred,
self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING,
[self.current_sample["X_filename"]],
verbose=self.cfg.TEST.VERBOSE,
)
else:
# Load prediction from file
folder = (
self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING
if self.post_processing["per_image"]
else self.cfg.PATHS.RESULT_DIR.PER_IMAGE
)
# read file created by 'save_tif' (it always has .tif extension)
test_file = os.path.join(folder, os.path.splitext(self.current_sample["X_filename"])[0]+'.tif')
pred = read_img_as_ndarray(test_file, is_3d=self.cfg.PROBLEM.NDIM == "3D")
pred = np.expand_dims(pred, 0) # expand dimensions to include "batch"
# Calculate the metrics
if self.current_sample["Y"] is not None:
metric_values = self.metric_calculation(output=pred, targets=self.current_sample["Y"], train=False)
for metric in metric_values:
if str(metric).lower() not in self.stats["merge_patches"]:
self.stats["merge_patches"][str(metric).lower()] = 0
self.stats["merge_patches"][str(metric).lower()] += metric_values[metric]
self.current_sample_metrics[str(metric).lower()] = metric_values[metric]
self.after_merge_patches(pred)
if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK:
assert isinstance(self.all_pred, list) and isinstance(self.all_gt, list)
self.all_pred.append(pred)
if self.current_sample["Y"] is not None and self.all_gt is not None:
self.all_gt.append(self.current_sample["Y"])
##################
### FULL IMAGE ###
##################
if self.cfg.TEST.FULL_IMG and self.cfg.PROBLEM.NDIM == "2D":
self.current_sample["X"], o_test_shape = check_downsample_division(
self.current_sample["X"], len(self.cfg.MODEL.FEATURE_MAPS) - 1
)
if not self.cfg.TEST.REUSE_PREDICTIONS:
if self.current_sample["Y"] is not None:
self.current_sample["Y"], _ = check_downsample_division(
self.current_sample["Y"], len(self.cfg.MODEL.FEATURE_MAPS) - 1
)
# Make the prediction
if self.cfg.TEST.AUGMENTATION:
pred = ensemble8_2d_predictions(
self.current_sample["X"][0],
axes_order_back=self.axes_order_back,
pred_func=self.model_call_func,
axes_order=self.axes_order,
device=self.test_device,
mode=self.cfg.TEST.AUGMENTATION_MODE,
)
else:
pred = self.model_call_func(self.current_sample["X"])
# Multi-head concatenation
if isinstance(pred, dict):
if "class" in pred:
pred = torch.cat((pred["pred"], torch.argmax(pred["class"], dim=1).unsqueeze(1)), dim=1)
else:
pred = pred["pred"]
pred = to_numpy_format(pred, self.axes_order_back)
del self.current_sample["X"]
# Recover original shape if padded with check_downsample_division
pred = pred[:, : o_test_shape[1], : o_test_shape[2]]
if self.current_sample["Y"] is not None:
self.current_sample["Y"] = self.current_sample["Y"][:, : o_test_shape[1], : o_test_shape[2]]
# Save image
save_tif(
pred,
self.cfg.PATHS.RESULT_DIR.FULL_IMAGE,
[self.current_sample["X_filename"]],
verbose=self.cfg.TEST.VERBOSE,
)
if self.cfg.TEST.POST_PROCESSING.APPLY_MASK:
pred = apply_binary_mask(pred, self.cfg.DATA.TEST.BINARY_MASKS)
else:
# load prediction from file
# read file created by 'save_tif' (it always has .tif extension)
test_file = os.path.join(self.cfg.PATHS.RESULT_DIR.FULL_IMAGE, os.path.splitext(self.current_sample["X_filename"])[0]+'.tif')
pred = read_img_as_ndarray(test_file, is_3d=self.cfg.PROBLEM.NDIM == "3D")
pred = np.expand_dims(pred, 0) # expand dimensions to include "batch"
# Calculate the metrics
if self.current_sample["Y"] is not None:
metric_values = self.metric_calculation(output=pred, targets=self.current_sample["Y"], train=False)
for metric in metric_values:
if str(metric).lower() not in self.stats["full_image"]:
self.stats["full_image"][str(metric).lower()] = 0
self.stats["full_image"][str(metric).lower()] += metric_values[metric]
self.current_sample_metrics[str(metric).lower()] = metric_values[metric]
if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK:
assert isinstance(self.all_pred, list) and isinstance(self.all_gt, list)
self.all_pred.append(pred)
if self.current_sample["Y"] is not None and self.all_gt is not None:
self.all_gt.append(self.current_sample["Y"])
self.after_full_image(pred)
[docs]
def normalize_stats(self, image_counter):
    """
    Turn the metric sums accumulated in ``self.stats`` into averages.

    Per-crop metrics are divided by ``self.stats["patch_by_batch_counter"]``. Merged-patch,
    full-image and, when ``self.post_processing["per_image"]`` is set, post-processed
    merged-patch metrics are divided by ``image_counter``. A zero divisor yields ``0``.
    All values are overwritten in place.

    Parameters
    ----------
    image_counter : int
        Number of images to average the metrics.
    """
    stats = self.stats

    def _average(group_name, divisor_key=None):
        # Replace each accumulated sum of the group by its mean (0 when nothing was counted).
        group = stats[group_name]
        for name in group:
            divisor = image_counter if divisor_key is None else stats[divisor_key]
            group[name] = group[name] / divisor if divisor != 0 else 0

    # Per crop
    _average("per_crop", divisor_key="patch_by_batch_counter")
    # Merge patches
    _average("merge_patches")
    # Full image
    _average("full_image")
    # Merge patches after post-processing (only tracked when that stage is enabled)
    if self.post_processing["per_image"]:
        _average("merge_patches_post")
[docs]
def print_stats(self, image_counter):
"""
Normalize the accumulated test metrics, print them and dump the per-file metrics to CSV.
``normalize_stats`` is called first. When ``DATA.TEST.LOAD_GT`` is set, one line per
available metric is appended to ``self.test_metrics_message`` for the per-patch,
merged-patch or full-image statistics; the "IoU" metric is reported as "Foreground IoU".
The method also handles post-processing metrics, prints the message and writes
``self.metrics_per_test_file`` to ``test_results_metrics.csv`` under ``PATHS.RESULT_DIR.PATH``.
Parameters
----------
image_counter : int
Number of images to call ``normalize_stats``.
"""
# NOTE(review): the original indentation is not visible in this view, so it is unclear
# whether the post-processing messages, the printing and the CSV dump below are nested
# under the ``LOAD_GT`` check - confirm against the repository before restructuring.
# Turn the accumulated sums into averages before reporting anything
self.normalize_stats(image_counter)
if self.cfg.DATA.TEST.LOAD_GT:
# Patch-based statistics are reported unless only full-image inference produced values
if not self.cfg.TEST.FULL_IMG or (len(self.stats["per_crop"]) > 0 or len(self.stats["merge_patches"]) > 0):
if len(self.stats["per_crop"]) > 0:
for metric in self.test_metric_names:
# Stats keys are stored lower-cased, hence the ``.lower()`` lookups
if metric.lower() in self.stats["per_crop"]:
metric_name = "Foreground IoU" if metric == "IoU" else metric
self.test_metrics_message += (
"Test {} (per patch): {}\n".format(
metric_name,
self.stats["per_crop"][metric.lower()],
)
)
if len(self.stats["merge_patches"]) > 0:
for metric in self.test_metric_names:
if metric.lower() in self.stats["merge_patches"]:
metric_name = "Foreground IoU" if metric == "IoU" else metric
self.test_metrics_message += (
"Test {} (merge patches): {}\n".format(
metric_name,
self.stats["merge_patches"][metric.lower()],
)
)
else:
# Full-image statistics
if len(self.stats["full_image"]) > 0:
for metric in self.test_metric_names:
if metric.lower() in self.stats["full_image"]:
metric_name = "Foreground IoU" if metric == "IoU" else metric
self.test_metrics_message += (
"Test {} (per image): {}\n".format(
metric_name,
self.stats["full_image"][metric.lower()],
)
)
# Metrics measured after the per-image post-processing
if self.post_processing["per_image"] and len(self.stats["merge_patches_post"]) > 0:
for metric in self.test_metric_names:
if metric.lower() in self.stats["merge_patches_post"]:
metric_name = "Foreground IoU" if metric == "IoU" else metric
self.test_metrics_message += (
"Test {} (merge patches - post-processing): {}\n".format(
metric_name,
self.stats["merge_patches_post"][metric.lower()],
)
)
# Metrics measured after post-processing the predictions as a 3D stack
if self.post_processing["as_3D_stack"] and len(self.stats["as_3D_stack_post"]) > 0:
for metric in self.test_metric_names:
if metric.lower() in self.stats["as_3D_stack_post"]:
metric_name = "Foreground IoU" if metric == "IoU" else metric
self.test_metrics_message += (
"Test {} (as 3D stack - post-processing): {}\n".format(
metric_name,
self.stats["as_3D_stack_post"][metric.lower()],
)
)
# Print the collected message line by line (the trailing "\n" produces a final empty line)
if self.test_metrics_message != "":
for line in self.test_metrics_message.split("\n"):
print(line)
# Dump the per-file metrics as a CSV table, creating the results folder if needed
df_metrics = pd.DataFrame(self.metrics_per_test_file)
os.makedirs(self.cfg.PATHS.RESULT_DIR.PATH, exist_ok=True)
df_metrics.to_csv(
os.path.join(
self.cfg.PATHS.RESULT_DIR.PATH,
"test_results_metrics.csv",
),
index=False,
)
[docs]
@abstractmethod
def after_merge_patches(self, pred: NDArray):
    """
    Place any code that needs to be done after merging all predicted patches into the original image.

    Each workflow must provide its own implementation; the base version only raises.

    Parameters
    ----------
    pred : NDArray
        Model prediction for the whole image. The visible callers pass a NumPy array
        (built with ``np.expand_dims`` / read back from disk), not a Torch tensor.

    Raises
    ------
    NotImplementedError
        Always, unless overridden by the concrete workflow.
    """
    raise NotImplementedError
[docs]
@abstractmethod
def after_full_image(self, pred: NDArray):
    """
    Hook executed once the prediction of a whole image fed directly to the model is available.

    Feeding the entire image requires a convolutional model and 2D image(s); 3D images are
    too large to be used as direct inputs. Each workflow must provide its own implementation.

    Parameters
    ----------
    pred : NDArray
        Model prediction.

    Raises
    ------
    NotImplementedError
        Always, unless overridden by the concrete workflow.
    """
    raise NotImplementedError
[docs]
def after_all_images(self):
    """
    Place here any code that must be done after predicting all images.

    When ``self.post_processing["as_3D_stack"]`` is enabled, the per-image predictions
    collected in ``self.all_pred`` (and the ground truth in ``self.all_gt``, if available)
    are concatenated into a single stack with a leading axis of size one. The stack and its
    ``> 0.5`` binarization are saved, post-processing is applied, metrics are computed against
    the ground truth when it is available, and the post-processed stack is saved.

    Ground truth is used only when ``DATA.TEST.LOAD_GT`` is set *and* ``self.all_gt`` is not
    ``None``; the same guard now protects both the concatenation and the metric calculation,
    so a missing ground truth no longer fails on ``self.all_gt[0]``.
    """
    ############################
    ### POST-PROCESSING (2D) ###
    ############################
    if not self.post_processing["as_3D_stack"]:
        return

    verbose = self.cfg.TEST.VERBOSE
    result_dirs = self.cfg.PATHS.RESULT_DIR
    # Single source of truth for "ground truth can be used" (requested and actually collected)
    use_gt = self.cfg.DATA.TEST.LOAD_GT and self.all_gt is not None

    # Stack every collected prediction and add a leading axis of size one
    self.all_pred = np.expand_dims(np.concatenate(self.all_pred), 0)
    if use_gt:
        self.all_gt = np.expand_dims(np.concatenate(self.all_gt), 0)

    # Raw stack and its binarized version
    save_tif(
        self.all_pred,
        result_dirs.AS_3D_STACK,
        verbose=verbose,
    )
    save_tif(
        (self.all_pred > 0.5).astype(np.uint8),
        result_dirs.AS_3D_STACK_BIN,
        verbose=verbose,
    )

    self.all_pred = apply_post_processing(self.cfg, self.all_pred)

    # Calculate the metrics
    if use_gt:
        metric_values = self.metric_calculation(output=self.all_pred[0], targets=self.all_gt[0], train=False)
        for metric in metric_values:
            key = str(metric).lower()
            self.stats["as_3D_stack_post"][key] = metric_values[metric]
            self.current_sample_metrics[key + " as 3D stack (post-processing)"] = metric_values[metric]

    save_tif(
        self.all_pred,
        result_dirs.AS_3D_STACK_POST_PROCESSING,
        verbose=verbose,
    )
#########################
### BY CHUNKS METHODS ###
#########################
# The order of the execution of the "by chunks" methods is the following:
# * 'process_test_sample_by_chunks': process a sample in the inference phase in "by chunks" setting, this is the main method
# that calls the other three methods below
# 1. 'after_one_chunk_raw_prediction': after predicting one chunk
# Once the predictions for all the chunks are generated, each workflow will process the generated zarr in a specific way. The
# process is the following:
# 2. 'after_all_chunk_prediction_workflow_process': process to be done after predicting all the chunks in all ranks
# 2.1 'after_one_chunk_workflow_process': process a list of chunks
# 3. 'after_all_chunk_prediction_workflow_process_master_rank': process to be done after predicting all the chunks
# but only on the master rank.
#
[docs]
def process_test_sample_by_chunks(self):
"""
Process a sample in the inference phase in the "by chunks" setting.
A final H5/Zarr file is created in "TZCYX" or "TZYXC" order
depending on ``DATA.TEST.INPUT_IMG_AXES_ORDER`` ('T' is always included).
The raw-prediction stage is gated by ``TEST.REUSE_PREDICTIONS`` being disabled and
``"prediction"`` being listed in ``TEST.BY_CHUNKS.PHASES``. Each predicted patch is handed to
``after_one_chunk_raw_prediction`` and then written into the output file. The method also
calls ``after_all_chunk_prediction_workflow_process`` and, on the main process,
``after_all_chunk_prediction_workflow_process_master_rank``.
"""
# NOTE(review): the original indentation is not visible in this view, so it is unclear
# which of the statements below are nested under the prediction-phase check - confirm
# against the repository before restructuring.
if not self.cfg.TEST.REUSE_PREDICTIONS and "prediction" in self.cfg.TEST.BY_CHUNKS.PHASES:
# Create the generator
self.test_generator = create_chunked_test_generator(
self.cfg,
system_dict=self.system_dict,
current_sample=self.current_sample,
norm_module=self.norm_module,
out_dir=self.cfg.PATHS.RESULT_DIR.PER_IMAGE,
dtype_str=self.dtype_str,
)
tgen: chunked_test_pair_data_generator = self.test_generator.dataset # type: ignore
# Get parallel data shape is ZYX
_, z_dim, _, y_dim, x_dim = order_dimensions(
tgen.X_parallel_data.shape, self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER
)
self.parallel_data_shape = [z_dim, y_dim, x_dim]
# Sampler ids already written to the output, used to detect repeated patches
samples_visited = {}
for obj_list in self.test_generator:
sampler_ids, img, mask, patch_in_data, added_pad, norm_extra_info = obj_list
if self.cfg.TEST.VERBOSE:
print(
"[Rank {} ({})] Patch number {} processing patches {} from {}".format(
get_rank(), os.getpid(), sampler_ids, patch_in_data, self.parallel_data_shape
)
)
# Pass the batch through the model
pred = self.predict_batches_in_test(img, mask, disable_tqdm=True)
# Replace the trailing class channels with a single channel holding their argmax
if self.separated_class_channel:
class_idx = self.model_output_channel_info.index("class") if "class" in self.model_output_channel_info else -1
pred = np.concatenate(
(
pred[...,:-self.model_output_channels[class_idx]],
np.expand_dims(np.argmax(pred[..., -self.model_output_channels[class_idx]:], axis=-1), axis=-1)
), axis=-1)
# Set when a patch of this batch is discarded as a filler/repeated sample
lbreaked = False
for i in range(pred.shape[0]):
# Break the loop as those samples were created just to complete the last batch
if sampler_ids[i] < sampler_ids[0] or sampler_ids[i] in samples_visited:
print(
"[Rank {} ({})] Patch {} discarded".format(
get_rank(),
os.getpid(),
sampler_ids[i],
)
)
lbreaked = True
break
single_pred = pred[i]
single_pred_pad = added_pad[i]
single_patch_in_data = patch_in_data[i]
# The workflow hook receives the chunk before its padding is removed
self.after_one_chunk_raw_prediction(
sampler_ids[i], single_pred, single_patch_in_data, single_pred_pad
)
# Remove padding if added
single_pred = single_pred[
single_pred_pad[0][0] : single_pred.shape[0] - single_pred_pad[0][1],
single_pred_pad[1][0] : single_pred.shape[1] - single_pred_pad[1][1],
single_pred_pad[2][0] : single_pred.shape[2] - single_pred_pad[2][1],
]
# Insert into the data
tgen.insert_patch_in_file(single_pred, single_patch_in_data)
samples_visited[sampler_ids[i]] = True
# Stop iterating the generator once a discarded patch turns out to be an already written one
if lbreaked and sampler_ids[i] in samples_visited:
print(
"[Rank {} ({})] Finishing the loop. Seems that the patches are starting to repeat".format(
get_rank(),
os.getpid(),
)
)
break
# Wait until all ranks are done so the main process can create the full size image
if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized():
if self.cfg.TEST.VERBOSE:
print(
f"[Rank {get_rank()} ({os.getpid()})] Finished predicting patches. Waiting for all ranks . . ."
)
dist.barrier()
tgen.close_open_files()
# Only after everyone finished writing, optionally convert to TIF on rank0
if self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF:
if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized():
dist.barrier()
if is_main_process():
tgen.save_parallel_data_as_tif()
# Keep the other ranks waiting until the main process finished the TIF export
if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized():
dist.barrier()
# Workflow-specific processing of the stored predictions
self.after_all_chunk_prediction_workflow_process()
# Final workflow-specific step, executed by the main process only
if is_main_process():
self.after_all_chunk_prediction_workflow_process_master_rank()
# Wait until all ranks are done with this sample before continuing
if self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized():
if self.cfg.TEST.VERBOSE:
print(f"[Rank {get_rank()} ({os.getpid()})] Finished predicting sample. Waiting for all ranks . . .")
dist.barrier()
[docs]
def after_one_chunk_raw_prediction(
    self, chunk_id: int, chunk: NDArray, chunk_in_data: PatchCoords, added_pad: List[List[int]]
):
    """
    Hook executed right after one chunk of data has been predicted in the "by chunks" setting.

    Workflows override this method; the base version only raises.

    Parameters
    ----------
    chunk_id : int
        Chunk identifier.

    chunk : NDArray
        Predicted chunk.

    chunk_in_data : PatchCoords
        Global coordinates of the chunk.

    added_pad : List of list of ints
        Padding added to the chunk in each dimension. The order of dimensions is the same as the
        input image, and the order of the list is:
        ``[[pad_before_dim1, pad_after_dim1], [pad_before_dim2, pad_after_dim2], ...]``.

    Raises
    ------
    NotImplementedError
        Always, unless overridden by the concrete workflow.
    """
    raise NotImplementedError
[docs]
def after_one_chunk_workflow_process(self, chunks: List[NDArray], patch_in_data: List) -> Optional[List[NDArray]]:
    """
    Workflow-specific processing of a list of chunks during inference in the "by chunks" setting.

    Each workflow should provide its own implementation; the base version only raises.

    Parameters
    ----------
    chunks : List[NDArray]
        Chunks to process. Expected axes are ``(z, y, x, channels)`` for 3D and
        ``(y, x, channels)`` for 2D.

    patch_in_data : List[PatchCoords]
        Spatial coordinates of each chunk in the full volume.

    Returns
    -------
    chunks : Optional[List[NDArray]]
        Processed chunks.

    Raises
    ------
    NotImplementedError
        Always, unless overridden by the concrete workflow.
    """
    raise NotImplementedError
[docs]
def after_all_chunk_prediction_workflow_process(self):
    """
    Process the stored chunk predictions of the current sample in the "by chunks" setting.

    This function is called on all ranks. The ``.zarr`` file with the predictions of the current
    sample is traversed in batches of chunks; every batch is handed to
    ``after_one_chunk_workflow_process`` and each returned chunk is inserted into the output
    through the generator's dataset. Afterwards the open files are closed and, when
    ``TEST.BY_CHUNKS.SAVE_OUT_TIF`` is set, the main process calls ``save_parallel_data_as_tif``.
    """

    def _multi_rank():
        # True when several GPUs are requested and the distributed backend is up
        return self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized()

    print("Processing generated predictions . . .")

    # Create the generator
    per_image_dir = self.cfg.PATHS.RESULT_DIR.PER_IMAGE
    sample_stem = os.path.splitext(self.current_sample["X_filename"])[0]
    predictions_file = os.path.join(per_image_dir, sample_stem + ".zarr")
    self.test_generator = create_chunked_workflow_process_generator(
        self.cfg,
        system_dict=self.system_dict,
        model_predictions=predictions_file,
        out_dir=self.test_chunked_workflow_process_vars["out_dir"],
        dtype_str=self.test_chunked_workflow_process_vars["dtype_str"],
    )
    dataset: chunked_workflow_process_generator = self.test_generator.dataset  # type: ignore

    # Get parallel data shape is ZYX
    ordered_dims = order_dimensions(dataset.X_parallel_data.shape, self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER)
    _, depth, _, height, width = ordered_dims
    self.parallel_data_shape = [depth, height, width]

    # Sampler ids already written to the output, used to detect repeated patches
    seen_ids = set()
    for sampler_ids, chunks, patch_in_data in self.test_generator:
        if self.cfg.TEST.VERBOSE:
            print(
                "[Rank {} ({})] Patch number {} processing patches {} from {}".format(
                    get_rank(), os.getpid(), sampler_ids, patch_in_data, self.parallel_data_shape
                )
            )

        processed_chunks = self.after_one_chunk_workflow_process(chunks, patch_in_data)
        assert processed_chunks is not None, "The after_one_chunk_workflow_process() method must return a value."

        discarded = False
        for idx, processed in enumerate(processed_chunks):
            # Break the loop as those samples were created just to complete the last batch
            if sampler_ids[idx] < sampler_ids[0] or sampler_ids[idx] in seen_ids:
                print(
                    "[Rank {} ({})] Patch {} discarded".format(
                        get_rank(),
                        os.getpid(),
                        sampler_ids[idx],
                    )
                )
                discarded = True
                break

            dataset.insert_patch_in_file(processed, patch_in_data[idx])
            seen_ids.add(sampler_ids[idx])

        # Stop iterating once a discarded patch turns out to be an already written one
        if discarded and sampler_ids[idx] in seen_ids:
            print(
                "[Rank {} ({})] Finishing the loop. Seems that the patches are starting to repeat".format(
                    get_rank(),
                    os.getpid(),
                )
            )
            break

    # Wait until all ranks are done so the main process can create the full size image
    if _multi_rank():
        if self.cfg.TEST.VERBOSE:
            print(
                f"[Rank {get_rank()} ({os.getpid()})] Finished predicting patches. Waiting for all ranks . . ."
            )
        dist.barrier()

    dataset.close_open_files()

    # Only after everyone finished writing, optionally convert to TIF on rank0
    if self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF:
        if _multi_rank():
            dist.barrier()
        if is_main_process():
            dataset.save_parallel_data_as_tif()
        if _multi_rank():
            dist.barrier()
[docs]
def after_all_chunk_prediction_workflow_process_master_rank(self):
    """
    Hook for the work that must follow the prediction of all patches in the "by chunks" setting.

    This function is only called on the master rank. Workflows override it; the base version
    only raises.

    Raises
    ------
    NotImplementedError
        Always, unless overridden by the concrete workflow.
    """
    raise NotImplementedError