import math
import os
import datetime
import time
import json
import torch
import h5py
import zarr
import numpy as np
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from sklearn.model_selection import StratifiedKFold
import torch.multiprocessing as mp
import torch.distributed as dist
from biapy.models import build_model, build_torchvision_model
from biapy.engine import prepare_optimizer, build_callbacks
from biapy.data.generators import create_train_val_augmentors, create_test_augmentor, check_generator_consistence
from biapy.utils.misc import (get_world_size, get_rank, is_main_process, save_model, time_text, load_model_checkpoint, TensorboardLogger,
to_pytorch_format, to_numpy_format, is_dist_avail_and_initialized, setup_for_distributed)
from biapy.utils.util import (load_data_from_dir, load_3d_images_from_dir, create_plots, pad_and_reflect, save_tif, check_downsample_division,
read_chunked_data, order_dimensions)
from biapy.engine.train_engine import train_one_epoch, evaluate
from biapy.data.data_2D_manipulation import crop_data_with_overlap, merge_data_with_overlap, load_and_prepare_2D_train_data
from biapy.data.data_3D_manipulation import (crop_3D_data_with_overlap, merge_3D_data_with_overlap, load_and_prepare_3D_data,
load_and_prepare_3D_efficient_format_data, load_3D_efficient_files, extract_3D_patch_with_overlap_yield)
from biapy.data.post_processing.post_processing import ensemble8_2d_predictions, ensemble16_3d_predictions, apply_binary_mask
from biapy.engine.metrics import jaccard_index_numpy, voc_calculation
from biapy.data.post_processing import apply_post_processing
from biapy.data.pre_processing import preprocess_data
class Base_Workflow(metaclass=ABCMeta):
"""
Base workflow class. A new workflow should extend this class.
Parameters
----------
cfg : YACS configuration
Running configuration.
job_identifier : str
Complete name of the running job.
device : Torch device
Device used.
args : argparse Namespace
Arguments used in BiaPy's call.
"""
def __init__(self, cfg, job_identifier, device, args):
self.cfg = cfg
self.args = args
self.job_identifier = job_identifier
self.device = device
self.original_test_path = None
self.original_test_mask_path = None
self.test_mask_filenames = None
self.cross_val_samples_ids = None
self.post_processing = {}
self.post_processing['per_image'] = False
self.post_processing['as_3D_stack'] = False
self.test_filenames = None
self.metrics = []
self.data_norm = None
self.model = None
self.optimizer = None
self.loss_scaler = None
self.model_prepared = False
self.dtype = np.float32 if not self.cfg.TEST.REDUCE_MEMORY else np.float16
self.dtype_str = "float32" if not self.cfg.TEST.REDUCE_MEMORY else "float16"
self.loss_dtype = torch.float32
# Save paths in case we need them in the future
self.orig_train_path = self.cfg.DATA.TRAIN.PATH
self.orig_train_mask_path = self.cfg.DATA.TRAIN.GT_PATH
self.orig_val_path = self.cfg.DATA.VAL.PATH
self.orig_val_mask_path = self.cfg.DATA.VAL.GT_PATH
self.all_pred = []
self.all_gt = []
self.stats = {}
# Per crop
self.stats['loss_per_crop'] = 0
self.stats['iou_per_crop'] = 0
self.stats['patch_counter'] = 0
# Merging the image
self.stats['iou_merge_patches'] = 0
self.stats['ov_iou_merge_patches'] = 0
self.stats['iou_merge_patches_post'] = 0
self.stats['ov_iou_merge_patches_post'] = 0
# As 3D stack
self.stats['iou_as_3D_stack_post'] = 0
self.stats['ov_iou_as_3D_stack_post'] = 0
# Full image
self.stats['loss'] = 0
self.stats['iou'] = 0
self.stats['ov_iou'] = 0
self.world_size = get_world_size()
self.global_rank = get_rank()
if self.cfg.TEST.BY_CHUNKS.ENABLE and self.cfg.PROBLEM.NDIM == '3D':
maxsize = min(10,self.cfg.SYSTEM.NUM_GPUS*10)
self.output_queue = mp.Queue(maxsize=maxsize)
self.input_queue = mp.Queue(maxsize=maxsize)
self.extract_info_queue = mp.Queue()
# Test variables
if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK and self.cfg.PROBLEM.NDIM == "2D":
if self.cfg.TEST.POST_PROCESSING.YZ_FILTERING or self.cfg.TEST.POST_PROCESSING.Z_FILTERING:
self.post_processing['as_3D_stack'] = True
elif self.cfg.PROBLEM.NDIM == "3D":
if self.cfg.TEST.POST_PROCESSING.YZ_FILTERING or self.cfg.TEST.POST_PROCESSING.Z_FILTERING:
self.post_processing['per_image'] = True
# Define permute shapes to pass from Numpy axis order (Y,X,C) to Pytorch's (C,Y,X)
self.axis_order = (0,3,1,2) if self.cfg.PROBLEM.NDIM == "2D" else (0,4,1,2,3)
self.axis_order_back = (0,2,3,1) if self.cfg.PROBLEM.NDIM == "2D" else (0,2,3,4,1)
# Define metrics
self.define_metrics()
# Load Bioimage Model Zoo pretrained model information
self.torchvision_preprocessing = None
self.bmz_test_input = None
self.bmz_test_output = None
self.bmz_model_resource = None
if self.cfg.MODEL.SOURCE == "bmz":
import bioimageio.core
import xarray as xr
print("Loading Bioimage Model Zoo pretrained model . . .")
self.bmz_model_resource = bioimageio.core.load_resource_description(self.cfg.MODEL.BMZ.SOURCE_MODEL_DOI)
# Override PATCH_SIZE with the one stored in the RDF
input_image = np.load(self.bmz_model_resource.test_inputs[0])
opts = ["DATA.PATCH_SIZE", input_image.shape[2:]+(input_image.shape[1],)]
print("[BMZ] Changed 'DATA.PATCH_SIZE' from {} to {} as defined in the RDF"
.format(self.cfg.DATA.PATCH_SIZE,opts[1]))
self.cfg.merge_from_list(opts)
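# --- Illustrative sketch (not called anywhere): how the YACS override above works.
# 'merge_from_list' consumes a flat [key, value, ...] list, which is exactly what
# 'opts' holds when the RDF input shape is adopted. The config keys below are
# hypothetical minimal stand-ins, not BiaPy's real schema.
@staticmethod
def _yacs_override_sketch():
    from yacs.config import CfgNode as CN
    cfg = CN()
    cfg.DATA = CN()
    cfg.DATA.PATCH_SIZE = (256, 256, 1)
    cfg.merge_from_list(["DATA.PATCH_SIZE", (128, 128, 1)])  # same call as above
    assert cfg.DATA.PATCH_SIZE == (128, 128, 1)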
@abstractmethod
def define_metrics(self):
"""
This function must define the following variables:
self.metrics : List of functions
Metrics to be calculated during model's training and inference.
self.metric_names : List of str
Names of the metrics calculated.
self.loss : Function
Loss function used during training.
"""
raise NotImplementedError
@abstractmethod
def metric_calculation(self, output, targets, metric_logger=None):
"""
Execution of the metrics defined in :func:`~define_metrics` function.
Parameters
----------
output : Torch Tensor
Prediction of the model.
targets : Torch Tensor
Ground truth to compare the prediction with.
metric_logger : MetricLogger, optional
Class to be updated with the new metric(s) value(s) calculated.
Returns
-------
value : float
Value of the metric for the given prediction.
"""
raise NotImplementedError
def prepare_targets(self, targets, batch):
"""
Location to perform any necessary data transformations to ``targets``
before calculating the loss.
Parameters
----------
targets : Torch Tensor
Ground truth to compare the prediction with.
batch : Torch Tensor
Prediction of the model. Only used in SSL workflow.
Returns
-------
targets : Torch tensor
Resulting targets.
"""
# The 'batch' input is only used in the SSL workflow
return to_pytorch_format(targets, self.axis_order, self.device)
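# --- Illustrative sketch (assumption: batches are channels-last NumPy arrays).
# What the axis orders defined in __init__ express: 'to_pytorch_format' boils
# down to a permutation from (B, Y, X, C) to PyTorch's (B, C, Y, X).
@staticmethod
def _axis_order_sketch():
    import numpy as np
    import torch
    batch = np.zeros((2, 64, 64, 3), dtype=np.float32)  # (B, Y, X, C)
    axis_order = (0, 3, 1, 2)                           # 2D case from __init__
    tensor = torch.from_numpy(batch.transpose(axis_order))
    assert tuple(tensor.shape) == (2, 3, 64, 64)        # (B, C, Y, X)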
def load_train_data(self):
"""
Load training and validation data.
"""
if self.cfg.TRAIN.ENABLE:
print("##########################")
print("# LOAD TRAINING DATA #")
print("##########################")
self.X_val, self.Y_val = None, None
if self.cfg.DATA.TRAIN.IN_MEMORY:
val_split = self.cfg.DATA.VAL.SPLIT_TRAIN if self.cfg.DATA.VAL.FROM_TRAIN else 0.
f_name = load_and_prepare_2D_train_data if self.cfg.PROBLEM.NDIM == '2D' else load_and_prepare_3D_data
preprocess_cfg = self.cfg.DATA.PREPROCESS if self.cfg.DATA.PREPROCESS.TRAIN else None
preprocess_fn = preprocess_data if self.cfg.DATA.PREPROCESS.TRAIN else None
is_y_mask = self.cfg.PROBLEM.TYPE in ['SEMANTIC_SEG', 'INSTANCE_SEG']
objs = f_name(self.cfg.DATA.TRAIN.PATH, self.mask_path, cross_val=self.cfg.DATA.VAL.CROSS_VAL,
cross_val_nsplits=self.cfg.DATA.VAL.CROSS_VAL_NFOLD, cross_val_fold=self.cfg.DATA.VAL.CROSS_VAL_FOLD,
val_split=val_split, seed=self.cfg.SYSTEM.SEED, shuffle_val=self.cfg.DATA.VAL.RANDOM,
random_crops_in_DA=self.cfg.DATA.EXTRACT_RANDOM_PATCH, crop_shape=self.cfg.DATA.PATCH_SIZE,
y_upscaling=self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING, ov=self.cfg.DATA.TRAIN.OVERLAP,
padding=self.cfg.DATA.TRAIN.PADDING, minimum_foreground_perc=self.cfg.DATA.TRAIN.MINIMUM_FOREGROUND_PER,
reflect_to_complete_shape=self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE, convert_to_rgb=self.cfg.DATA.FORCE_RGB,
preprocess_cfg = preprocess_cfg, is_y_mask = is_y_mask, preprocess_f=preprocess_fn)
if self.cfg.DATA.VAL.FROM_TRAIN:
if self.cfg.DATA.VAL.CROSS_VAL:
self.X_train, self.Y_train, self.X_val, self.Y_val, self.train_filenames, self.cross_val_samples_ids = objs
else:
self.X_train, self.Y_train, self.X_val, self.Y_val, self.train_filenames = objs
else:
self.X_train, self.Y_train, self.train_filenames = objs
del objs
else:
# Checking if the user inputted Zarr/H5 files
zarr_files = sorted(next(os.walk(self.cfg.DATA.TRAIN.PATH))[1])
h5_files = sorted(next(os.walk(self.cfg.DATA.TRAIN.PATH))[2])
if self.cfg.PROBLEM.NDIM == '3D' and ((len(zarr_files) > 0 and '.zarr' in zarr_files[0]) or \
(len(h5_files) > 0 and '.h5' in h5_files[0])):
val_split = self.cfg.DATA.VAL.SPLIT_TRAIN if self.cfg.DATA.VAL.FROM_TRAIN else 0.
if len(zarr_files) > 0 and '.zarr' in zarr_files[0]:
print("Working with Zarr files . . .")
img_files = [os.path.join(self.cfg.DATA.TRAIN.PATH, x) for x in zarr_files]
mask_files = [os.path.join(self.mask_path, x) for x in sorted(next(os.walk(self.mask_path))[1])]
elif len(h5_files) > 0 and '.h5' in h5_files[0]:
print("Working with H5 files . . .")
img_files = [os.path.join(self.cfg.DATA.TRAIN.PATH, x) for x in h5_files]
mask_files = [os.path.join(self.mask_path, x) for x in sorted(next(os.walk(self.mask_path))[2])]
del zarr_files, h5_files
if self.cfg.DATA.EXTRACT_RANDOM_PATCH:
print("WARNING: 'DATA.EXTRACT_RANDOM_PATCH' not taken into account when working with Zarr/H5 images")
if self.cfg.DATA.FORCE_RGB:
print("WARNING: 'DATA.FORCE_RGB' not taken into account when working with Zarr/H5 images")
objs = load_and_prepare_3D_efficient_format_data(
img_files, mask_files, input_img_axes=self.cfg.DATA.TRAIN.INPUT_IMG_AXES_ORDER,
input_mask_axes=self.cfg.DATA.TRAIN.INPUT_MASK_AXES_ORDER,
cross_val=self.cfg.DATA.VAL.CROSS_VAL, cross_val_nsplits=self.cfg.DATA.VAL.CROSS_VAL_NFOLD,
cross_val_fold=self.cfg.DATA.VAL.CROSS_VAL_FOLD, val_split=val_split, seed=self.cfg.SYSTEM.SEED,
shuffle_val=self.cfg.DATA.VAL.RANDOM, crop_shape=self.cfg.DATA.PATCH_SIZE,
y_upscaling=self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
ov=self.cfg.DATA.TRAIN.OVERLAP, padding=self.cfg.DATA.TRAIN.PADDING,
minimum_foreground_perc=self.cfg.DATA.TRAIN.MINIMUM_FOREGROUND_PER)
if self.cfg.DATA.VAL.FROM_TRAIN:
if self.cfg.DATA.VAL.CROSS_VAL:
self.X_train, self.Y_train, self.X_val, self.Y_val, self.cross_val_samples_ids = objs
else:
self.X_train, self.Y_train, self.X_val, self.Y_val = objs
else:
self.X_train, self.Y_train = objs
del objs
else:
self.X_train, self.Y_train = None, None
##################
### VALIDATION ###
##################
if not self.cfg.DATA.VAL.FROM_TRAIN:
if self.cfg.DATA.VAL.IN_MEMORY:
print("1) Loading validation images . . .")
f_name = load_data_from_dir if self.cfg.PROBLEM.NDIM == '2D' else load_3d_images_from_dir
preprocess_cfg = self.cfg.DATA.PREPROCESS if self.cfg.DATA.PREPROCESS.VAL else None
preprocess_fn = preprocess_data if self.cfg.DATA.PREPROCESS.VAL else None
is_y_mask = self.cfg.PROBLEM.TYPE in ['SEMANTIC_SEG', 'INSTANCE_SEG']
self.X_val, _, _ = f_name(self.cfg.DATA.VAL.PATH, crop=True, crop_shape=self.cfg.DATA.PATCH_SIZE,
overlap=self.cfg.DATA.VAL.OVERLAP, padding=self.cfg.DATA.VAL.PADDING,
reflect_to_complete_shape=self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE,
convert_to_rgb=self.cfg.DATA.FORCE_RGB, preprocess_cfg = preprocess_cfg,
is_mask = False, preprocess_f=preprocess_fn)
if self.cfg.PROBLEM.NDIM == '2D':
crop_shape = (self.cfg.DATA.PATCH_SIZE[0]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[0],
self.cfg.DATA.PATCH_SIZE[1]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[1], self.cfg.DATA.PATCH_SIZE[2])
else:
crop_shape = (self.cfg.DATA.PATCH_SIZE[0], self.cfg.DATA.PATCH_SIZE[1]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[0],
self.cfg.DATA.PATCH_SIZE[2]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[1],
self.cfg.DATA.PATCH_SIZE[3]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[2])
if self.load_Y_val:
self.Y_val, _, _ = f_name(self.cfg.DATA.VAL.GT_PATH, crop=True, crop_shape=crop_shape,
overlap=self.cfg.DATA.VAL.OVERLAP, padding=self.cfg.DATA.VAL.PADDING,
reflect_to_complete_shape=self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE,
check_channel=False, check_drange=False,convert_to_rgb=self.cfg.DATA.FORCE_RGB,
preprocess_cfg = preprocess_cfg, is_mask = is_y_mask, preprocess_f=preprocess_fn)
if self.Y_val is not None and len(self.X_val) != len(self.Y_val):
raise ValueError("Different number of raw and ground truth items ({} vs {}). "
"Please check the data!".format(len(self.X_val), len(self.Y_val)))
else:
# Checking if the user inputted Zarr/H5 files
zarr_files = sorted(next(os.walk(self.cfg.DATA.VAL.PATH))[1])
h5_files = sorted(next(os.walk(self.cfg.DATA.VAL.PATH))[2])
if self.cfg.PROBLEM.NDIM == '3D' and ((len(zarr_files) > 0 and '.zarr' in zarr_files[0]) or \
(len(h5_files) > 0 and '.h5' in h5_files[0])):
print("1) Loading validation image information . . .")
if len(zarr_files) > 0 and '.zarr' in zarr_files[0]:
print("Working with Zarr files . . .")
img_files = [os.path.join(self.cfg.DATA.VAL.PATH, x) for x in zarr_files]
mask_files = [os.path.join(self.mask_path, x) for x in sorted(next(os.walk(self.mask_path))[1])]
elif len(h5_files) > 0 and '.h5' in h5_files[0]:
print("Working with H5 files . . .")
img_files = [os.path.join(self.cfg.DATA.VAL.PATH, x) for x in h5_files]
mask_files = [os.path.join(self.mask_path, x) for x in sorted(next(os.walk(self.mask_path))[2])]
del zarr_files, h5_files
if self.cfg.DATA.FORCE_RGB:
print("WARNING: 'DATA.FORCE_RGB' not taken into account when working with Zarr/H5 images")
self.X_val, _ = load_3D_efficient_files(data_path=img_files, input_axes=self.cfg.DATA.VAL.INPUT_IMG_AXES_ORDER,
crop_shape=self.cfg.DATA.PATCH_SIZE, overlap=self.cfg.DATA.VAL.OVERLAP, padding=self.cfg.DATA.VAL.PADDING)
if self.cfg.PROBLEM.NDIM == '2D':
crop_shape = (self.cfg.DATA.PATCH_SIZE[0]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[0],
self.cfg.DATA.PATCH_SIZE[1]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[1], self.cfg.DATA.PATCH_SIZE[2])
else:
crop_shape = (self.cfg.DATA.PATCH_SIZE[0], self.cfg.DATA.PATCH_SIZE[1]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[0],
self.cfg.DATA.PATCH_SIZE[2]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[1],
self.cfg.DATA.PATCH_SIZE[3]*self.cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING[2])
if self.load_Y_val:
print("1) Loading validation GT information . . .")
# GT comes from the mask files, read with their own axes order
self.Y_val, _ = load_3D_efficient_files(data_path=mask_files, input_axes=self.cfg.DATA.VAL.INPUT_MASK_AXES_ORDER,
crop_shape=crop_shape, overlap=self.cfg.DATA.VAL.OVERLAP, padding=self.cfg.DATA.VAL.PADDING, check_channel=False)
else:
self.Y_val = None
if self.Y_val is not None and len(self.X_val) != len(self.Y_val):
raise ValueError("Different number of raw and ground truth items ({} vs {}). "
"Please check the data!".format(len(self.X_val), len(self.Y_val)))
else:
self.X_val, self.Y_val = None, None
# Ensure all the processes have read the data
if is_dist_avail_and_initialized():
print("Waiting until all processes have read the data . . .")
dist.barrier()
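# --- Illustrative sketch of the Zarr/H5 detection used above. 'next(os.walk(p))'
# yields (dirpath, dirnames, filenames): Zarr datasets are directories (index [1])
# while H5 datasets are regular files (index [2]). 'path' is a hypothetical folder.
@staticmethod
def _chunked_format_detection_sketch(path):
    import os
    _, dirnames, filenames = next(os.walk(path))
    zarr_candidates = [d for d in sorted(dirnames) if '.zarr' in d]
    h5_candidates = [f for f in sorted(filenames) if '.h5' in f]
    return zarr_candidates, h5_candidates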
def destroy_train_data(self):
"""
Delete training variables to release memory.
"""
print("Releasing memory . . .")
if hasattr(self, 'X_train'):
del self.X_train
if hasattr(self, 'Y_train'):
del self.Y_train
if hasattr(self, 'X_val'):
del self.X_val
if hasattr(self, 'Y_val'):
del self.Y_val
if hasattr(self, 'train_generator'):
del self.train_generator
if hasattr(self, 'val_generator'):
del self.val_generator
def prepare_train_generators(self):
"""
Build train and val generators.
"""
if self.cfg.TRAIN.ENABLE:
print("##############################")
print("# PREPARE TRAIN GENERATORS #")
print("##############################")
self.train_generator, \
self.val_generator, \
self.data_norm, \
self.num_training_steps_per_epoch = create_train_val_augmentors(self.cfg, self.X_train, self.Y_train,
self.X_val, self.Y_val, self.world_size, self.global_rank, self.args.distributed)
if self.cfg.DATA.CHECK_GENERATORS and self.cfg.PROBLEM.TYPE != 'CLASSIFICATION':
check_generator_consistence(
self.train_generator, self.cfg.PATHS.GEN_CHECKS+"_train", self.cfg.PATHS.GEN_MASK_CHECKS+"_train")
check_generator_consistence(
self.val_generator, self.cfg.PATHS.GEN_CHECKS+"_val", self.cfg.PATHS.GEN_MASK_CHECKS+"_val")
def bmz_model_call(self, in_img, is_train=False):
"""
Call a Bioimage Model Zoo model.
Parameters
----------
in_img : Tensor
Input image to pass through the model.
is_train : bool, optional
Whether the call is made during training or inference.
Returns
-------
prediction : Tensor
Image prediction.
"""
# Convert from Numpy to xarray.DataArray
import xarray as xr  # local import: only needed when running BMZ models
if self.cfg.PROBLEM.NDIM == '2D':
self.bmz_axes = ('b', 'c', 'y', 'x')
else:
self.bmz_axes = ('b', 'c', 'z', 'y', 'x')
in_img = xr.DataArray(in_img.cpu().numpy(), dims=tuple(self.bmz_axes))
# Apply pre-processing
in_img = dict(zip([ipt.name for ipt in self.model.input_specs], (in_img,)))
self.bmz_computed_measures = {}
self.model.apply_preprocessing(in_img, self.bmz_computed_measures)
# Predict
prediction_tensors = self.model.predict(*list(in_img.values()))
# Apply post-processing
prediction = dict(zip([out.name for out in self.model.output_specs], prediction_tensors))
self.model.apply_postprocessing(prediction, self.bmz_computed_measures)
# Convert back to Tensor
prediction = torch.from_numpy(prediction['output0'].to_numpy())
return prediction
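# --- Illustrative sketch (standalone, minimal): wrapping an array as an
# xarray.DataArray with named axes, as done above before handing the input to
# the BMZ prediction pipeline's pre-processing.
@staticmethod
def _xarray_wrap_sketch():
    import numpy as np
    import xarray as xr
    arr = np.zeros((1, 1, 64, 64), dtype=np.float32)   # (b, c, y, x)
    da = xr.DataArray(arr, dims=('b', 'c', 'y', 'x'))
    assert da.dims == ('b', 'c', 'y', 'x')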
@abstractmethod
def torchvision_model_call(self, in_img, is_train=False):
"""
Call a torchvision model.
Parameters
----------
in_img : Tensor
Input image to pass through the model.
is_train : bool, optional
Whether the call is made during training or inference.
Returns
-------
prediction : Tensor
Image prediction.
"""
raise NotImplementedError
def model_call_func(self, in_img, to_pytorch=True, is_train=False):
"""
Call the model, dispatching to the right source: BiaPy, BMZ or torchvision.
Parameters
----------
in_img : Tensor
Input image to pass through the model.
to_pytorch : bool, optional
Whether the input image needs to be converted into PyTorch format.
is_train : bool, optional
Whether the call is made during training or inference.
Returns
-------
prediction : Tensor
Image prediction.
"""
if to_pytorch:
in_img = to_pytorch_format(in_img, self.axis_order, self.device)
if self.cfg.MODEL.SOURCE == "biapy":
p = self.model(in_img)
elif self.cfg.MODEL.SOURCE == "bmz":
p = self.bmz_model_call(in_img, is_train)
elif self.cfg.MODEL.SOURCE == "torchvision":
p = self.torchvision_model_call(in_img, is_train)
return p
def prepare_model(self):
"""
Build the model.
"""
if self.model_prepared:
print("Model already prepared!")
return
print("###############")
print("# Build model #")
print("###############")
if self.cfg.MODEL.SOURCE == "biapy":
self.model = build_model(self.cfg, self.job_identifier, self.device)
elif self.cfg.MODEL.SOURCE == "torchvision":
self.model, self.torchvision_preprocessing = build_torchvision_model(self.cfg, self.device)
# Bioimage Model Zoo pretrained models
elif self.cfg.MODEL.SOURCE == "bmz":
# Create a bioimage pipeline to create predictions
try:
import bioimageio.core  # local import: only needed when loading BMZ models
self.model = bioimageio.core.create_prediction_pipeline(
self.bmz_model_resource, devices=None,
weight_format="torchscript",
)
except Exception as e:
print(f"The error thrown during the BMZ model load was:\n{e}")
raise ValueError("An error ocurred when creating the BMZ model (see above). "
"BiaPy only supports models prepared with Torchscript.")
if self.args.distributed:
raise ValueError("DDP can not be activated when loading a BMZ pretrained model")
self.model_without_ddp = self.model
if self.args.distributed:
find_unused_parameters = self.cfg.MODEL.ARCHITECTURE.lower() == "unetr"
self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.args.gpu],
find_unused_parameters=find_unused_parameters)
self.model_without_ddp = self.model.module
self.model_prepared = True
# Load checkpoint if necessary
if self.cfg.MODEL.SOURCE == "biapy" and self.cfg.MODEL.LOAD_CHECKPOINT:
self.start_epoch = load_model_checkpoint(cfg=self.cfg, jobname=self.job_identifier, model_without_ddp=self.model_without_ddp,
device=self.device, optimizer=self.optimizer, loss_scaler=self.loss_scaler)
else:
self.start_epoch = 0
def train(self):
"""
Training phase.
"""
self.load_train_data()
if not self.model_prepared:
self.prepare_model()
self.prepare_train_generators()
self.prepare_logging_tool()
self.early_stopping = build_callbacks(self.cfg)
self.optimizer, self.lr_scheduler, self.loss_scaler = prepare_optimizer(self.cfg, self.model_without_ddp,
len(self.train_generator))
print("#####################")
print("# TRAIN THE MODEL #")
print("#####################")
print(f"Start training in epoch {self.start_epoch+1} - Total: {self.cfg.TRAIN.EPOCHS}")
start_time = time.time()
val_best_metric = np.zeros(len(self.metric_names), dtype=np.float32)
val_best_loss = np.inf
for epoch in range(self.start_epoch, self.cfg.TRAIN.EPOCHS):
print("~~~ Epoch {}/{} ~~~\n".format(epoch+1, self.cfg.TRAIN.EPOCHS))
e_start = time.time()
if self.args.distributed:
self.train_generator.sampler.set_epoch(epoch)
if self.log_writer is not None:
self.log_writer.set_step(epoch * self.num_training_steps_per_epoch)
# Train
train_stats = train_one_epoch(self.cfg, model=self.model, model_call_func=self.model_call_func, loss_function=self.loss,
activations=self.apply_model_activations, metric_function=self.metric_calculation, prepare_targets=self.prepare_targets,
data_loader=self.train_generator, optimizer=self.optimizer, device=self.device, loss_scaler=self.loss_scaler, epoch=epoch,
log_writer=self.log_writer, lr_scheduler=self.lr_scheduler, start_steps=epoch * self.num_training_steps_per_epoch,
verbose=self.cfg.TRAIN.VERBOSE)
# Save checkpoint
if self.cfg.MODEL.SAVE_CKPT_FREQ != -1:
if ((epoch + 1) % self.cfg.MODEL.SAVE_CKPT_FREQ == 0 or epoch + 1 == self.cfg.TRAIN.EPOCHS) and is_main_process():
save_model(cfg=self.cfg, jobname=self.job_identifier, model=self.model, model_without_ddp=self.model_without_ddp,
optimizer=self.optimizer, loss_scaler=self.loss_scaler, epoch=epoch+1)
# Validation
if self.val_generator is not None:
test_stats = evaluate(self.cfg, model=self.model, model_call_func=self.model_call_func, loss_function=self.loss,
activations=self.apply_model_activations, metric_function=self.metric_calculation, prepare_targets=self.prepare_targets,
epoch=epoch, data_loader=self.val_generator, lr_scheduler=self.lr_scheduler)
# Save checkpoint if val loss improved
if test_stats['loss'] < val_best_loss:
f = os.path.join(self.cfg.PATHS.CHECKPOINT,"{}-checkpoint-best.pth".format(self.job_identifier))
print("Val loss improved from {} to {}, saving model to {}".format(val_best_loss, test_stats['loss'], f))
m = " "
for i in range(len(val_best_metric)):
val_best_metric[i] = test_stats[self.metric_names[i]]
m += f"{self.metric_names[i]}: {val_best_metric[i]:.4f} "
val_best_loss = test_stats['loss']
if is_main_process():
save_model(cfg=self.cfg, jobname=self.job_identifier, model=self.model, model_without_ddp=self.model_without_ddp,
optimizer=self.optimizer, loss_scaler=self.loss_scaler, epoch="best")
print(f'[Val] best loss: {val_best_loss:.4f} best '+m)
# Store validation stats
if self.log_writer is not None:
self.log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
for i in range(len(self.metric_names)):
# Log each validation metric under its own name
self.log_writer.update(head="perf", step=epoch, **{"test_"+self.metric_names[i]: test_stats[self.metric_names[i]]})
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch}
else:
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch}
# Write statistics in the logging file
if is_main_process():
# Log epoch stats
if self.log_writer is not None:
self.log_writer.flush()
with open(self.log_file, mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
# Create training plot
self.plot_values['loss'].append(train_stats['loss'])
if self.val_generator is not None:
self.plot_values['val_loss'].append(test_stats['loss'])
for i in range(len(self.metric_names)):
self.plot_values[self.metric_names[i]].append(train_stats[self.metric_names[i]])
if self.val_generator is not None:
self.plot_values['val_'+self.metric_names[i]].append(test_stats[self.metric_names[i]])
if (epoch+1) % self.cfg.LOG.CHART_CREATION_FREQ == 0:
create_plots(self.plot_values, self.metric_names, self.job_identifier, self.cfg.PATHS.CHARTS)
if self.val_generator is not None and self.early_stopping is not None:
self.early_stopping(test_stats['loss'])
if self.early_stopping.early_stop:
print("Early stopping")
break
e_end = time.time()
t_epoch = e_end - e_start
print("[Time] {} {}/{}\n".format(time_text(t_epoch), time_text(e_end - start_time),
time_text((e_end - start_time)+(t_epoch*(self.cfg.TRAIN.EPOCHS-epoch)))))
total_time = time.time() - start_time
self.total_training_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(self.total_training_time_str))
print("Train loss: {}".format(train_stats['loss']))
for i in range(len(self.metric_names)):
print("Train {}: {}".format(self.metric_names[i], train_stats[self.metric_names[i]]))
if self.val_generator is not None:
print("Val loss: {}".format(val_best_loss))
for i in range(len(self.metric_names)):
print("Val {}: {}".format(self.metric_names[i], val_best_metric[i]))
print('Finished Training')
# Save an input/output sample pair to export the model to BMZ later
if self.bmz_test_input is None:
sample = next(enumerate(self.train_generator))
self.bmz_test_input = sample[1][0][0]
self.bmz_test_output = sample[1][1]
if not isinstance(self.bmz_test_output, int):
self.bmz_test_output = self.bmz_test_output[0]
self.destroy_train_data()
def load_test_data(self):
"""
Load test data.
"""
if self.cfg.TEST.ENABLE:
print("######################")
print("# LOAD TEST DATA #")
print("######################")
if not self.cfg.DATA.TEST.USE_VAL_AS_TEST:
if self.cfg.DATA.TEST.IN_MEMORY:
print("2) Loading test images . . .")
f_name = load_data_from_dir if self.cfg.PROBLEM.NDIM == '2D' else load_3d_images_from_dir
preprocess_cfg = self.cfg.DATA.PREPROCESS if self.cfg.DATA.PREPROCESS.TEST else None
preprocess_fn = preprocess_data if self.cfg.DATA.PREPROCESS.TEST else None
is_y_mask = self.cfg.PROBLEM.TYPE in ['SEMANTIC_SEG', 'INSTANCE_SEG']
self.X_test, _, _ = f_name(self.cfg.DATA.TEST.PATH, convert_to_rgb=self.cfg.DATA.FORCE_RGB,
preprocess_cfg=preprocess_cfg, is_mask=False, preprocess_f=preprocess_fn)
if self.cfg.DATA.TEST.LOAD_GT:
print("3) Loading test masks . . .")
self.Y_test, _, _ = f_name(self.cfg.DATA.TEST.GT_PATH, check_channel=False, check_drange=False,
preprocess_cfg=preprocess_cfg, is_mask=is_y_mask, preprocess_f=preprocess_fn)
if len(self.X_test) != len(self.Y_test):
raise ValueError("Different number of raw and ground truth items ({} vs {}). "
"Please check the data!".format(len(self.X_test), len(self.Y_test)))
else:
self.Y_test = None
else:
self.X_test, self.Y_test = None, None
if self.original_test_path is None:
self.test_filenames = sorted(next(os.walk(self.cfg.DATA.TEST.PATH))[2])
else:
self.test_filenames = sorted(next(os.walk(self.original_test_path))[2])
else:
# The test set is the validation set. As that is only available when the validation is taken from the
# training data and cross-validation is enabled, the test set files reside in the train folder
self.test_filenames = sorted(next(os.walk(self.cfg.DATA.TRAIN.PATH))[2])
self.X_test, self.Y_test = None, None
if self.cross_val_samples_ids is None:
# Recreate the validation split to obtain the test fold when training was not enabled
skf = StratifiedKFold(n_splits=self.cfg.DATA.VAL.CROSS_VAL_NFOLD, shuffle=self.cfg.DATA.VAL.RANDOM,
random_state=self.cfg.SYSTEM.SEED)
fold = 1
test_index = None
A = B = np.zeros(len(self.test_filenames)) # dummy data/labels: only the sample count matters for the split
for _, te_index in skf.split(A, B):
if self.cfg.DATA.VAL.CROSS_VAL_FOLD == fold:
self.cross_val_samples_ids = te_index.copy()
break
fold += 1
if len(self.cross_val_samples_ids) > 5:
print("Fold number {} used for test data. Printing the first 5 ids: {}".format(fold, self.cross_val_samples_ids[:5]))
else:
print("Fold number {}. Indexes used in cross validation: {}".format(fold, self.cross_val_samples_ids))
self.test_filenames = [x for i, x in enumerate(self.test_filenames) if i in self.cross_val_samples_ids]
self.original_test_path = self.orig_train_path
self.original_test_mask_path = self.orig_train_mask_path
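# --- Illustrative sketch of the fold recovery above: with the same seed, shuffle
# flag and number of splits, StratifiedKFold reproduces the split, so the ids used
# for validation at train time can be rebuilt at test time. All numbers below are
# hypothetical.
@staticmethod
def _fold_recovery_sketch():
    import numpy as np
    from sklearn.model_selection import StratifiedKFold
    n_files, n_folds, wanted_fold, seed = 10, 5, 2, 42
    dummy = np.zeros(n_files)  # single-class labels: only the sample count matters
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    for fold, (_, te_index) in enumerate(skf.split(dummy, dummy), start=1):
        if fold == wanted_fold:
            return te_index.copy()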
def destroy_test_data(self):
"""
Delete test variables to release memory.
"""
print("Releasing memory . . .")
if hasattr(self, 'X_test'):
del self.X_test
if hasattr(self, 'Y_test'):
del self.Y_test
if hasattr(self, 'test_generator'):
del self.test_generator
if hasattr(self, '_X'):
del self._X
if hasattr(self, '_Y'):
del self._Y
def prepare_test_generators(self):
"""
Prepare test data generator.
"""
if self.cfg.TEST.ENABLE:
print("############################")
print("# PREPARE TEST GENERATOR #")
print("############################")
self.test_generator, self.data_norm = create_test_augmentor(self.cfg, self.X_test, self.Y_test, self.cross_val_samples_ids)
def apply_model_activations(self, pred, training=False):
"""
Apply the last activation (if any) to the model's output.
Parameters
----------
pred : Torch Tensor
Predictions of the model.
training : bool, optional
To advise the function whether this is being applied during training or inference. During training,
``CE_Sigmoid`` activations will NOT be applied, as ``torch.nn.BCEWithLogitsLoss`` will apply
``Sigmoid`` automatically in a way that is more stable numerically
(`ref <https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html>`_).
Returns
-------
pred : Torch tensor
Resulting predictions after applying last activation(s).
"""
if not isinstance(pred, list):
multiple_heads = False
pred = [pred]
else:
multiple_heads = True
assert len(pred) == len(self.activations), "Activations length needs to match the prediction list length in a multiple-heads setting"
for out_heads in range(len(pred)):
for key, value in self.activations[out_heads].items():
# Ignore CE_Sigmoid as torch.nn.BCEWithLogitsLoss will apply Sigmoid automatically in a way
# that is more stable numerically (ref: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html)
if (training and value not in ["Linear", "CE_Sigmoid"]) or (not training and value != "Linear"):
value = "Sigmoid" if value == "CE_Sigmoid" else value
act = getattr(torch.nn, value)()
if key == ':':
pred[out_heads] = act(pred[out_heads])
else:
pred[out_heads][:,int(key),...] = act(pred[out_heads][:,int(key),...])
if not multiple_heads:
return pred[0]
else:
return pred
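# --- Illustrative sketch (hypothetical 'activations' content). The method above
# expects one dict per output head, mapping a channel index (or ':' for all
# channels) to a 'torch.nn' activation name; 'CE_Sigmoid' is swapped for 'Sigmoid'
# at inference time, as explained in the docstring.
@staticmethod
def _activation_dict_sketch():
    import torch
    activations = [{':': 'CE_Sigmoid'}]   # one head, sigmoid over all channels
    pred = torch.randn(1, 2, 8, 8)        # (B, C, Y, X) logits
    for key, value in activations[0].items():
        value = 'Sigmoid' if value == 'CE_Sigmoid' else value
        act = getattr(torch.nn, value)()
        pred = act(pred) if key == ':' else pred
    assert 0.0 <= float(pred.min()) and float(pred.max()) <= 1.0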
@torch.no_grad()
def test(self):
"""
Test/Inference step.
"""
self.load_test_data()
if not self.model_prepared:
self.prepare_model()
self.prepare_test_generators()
# Switch to evaluation mode
if self.cfg.MODEL.SOURCE != "bmz":
self.model_without_ddp.eval()
# Check possible checkpoint problems
if self.start_epoch == -1:
raise ValueError("There was a problem loading the checkpoint. Test phase aborted!")
image_counter = 0
print("###############")
print("# INFERENCE #")
print("###############")
print("Making predictions on test data . . .")
# Reactivate prints to see each rank progress
if self.cfg.TEST.BY_CHUNKS.ENABLE and self.cfg.PROBLEM.NDIM == '3D':
setup_for_distributed(True)
# Process all the images
for i, gen_obj in tqdm(enumerate(self.test_generator), total=len(self.test_generator), disable=not is_main_process()):
self._X, X_norm, self._Y, Y_norm = None, None, None, None
if 'X' in gen_obj: self._X = gen_obj['X']
if 'X_norm' in gen_obj: X_norm = gen_obj['X_norm']
if 'Y' in gen_obj: self._Y = gen_obj['Y']
if 'Y_norm' in gen_obj: Y_norm = gen_obj['Y_norm']
self.processing_filenames = self.test_filenames[gen_obj['file']] if isinstance(gen_obj['file'], int) else gen_obj['file']
self.processing_filenames = [os.path.basename(self.processing_filenames)]
self.f_numbers = [i]
del gen_obj
if self.cfg.TEST.BY_CHUNKS.ENABLE and self.cfg.PROBLEM.NDIM == '3D':
print(f"[Rank {get_rank()} ({os.getpid()})] Processing image(s): {self.processing_filenames[0]}")
self.process_sample_by_chunks(self.processing_filenames[0])
else:
if is_main_process():
print("Processing image: {}".format(self.processing_filenames[0]))
self.process_sample(norm=(X_norm, Y_norm))
image_counter += 1
self.destroy_test_data()
if is_main_process():
self.after_all_images()
print("#############")
print("# RESULTS #")
print("#############")
if self.cfg.TRAIN.ENABLE:
print("Epoch number: {}".format(len(self.plot_values['val_loss'])))
print("Train time (s): {}".format(self.total_training_time_str))
print("Train loss: {}".format(np.min(self.plot_values['loss'])))
for i in range(len(self.metric_names)):
if self.metric_names[i] == "IoU":
print("Train Foreground {}: {}".format(self.metric_names[i], np.max(self.plot_values[self.metric_names[i]])))
else:
print("Train {}: {}".format(self.metric_names[i], np.max(self.plot_values[self.metric_names[i]])))
print("Validation loss: {}".format(np.min(self.plot_values['val_loss'])))
for i in range(len(self.metric_names)):
if self.metric_names[i] == "IoU":
print("Validation Foreground {}: {}".format(self.metric_names[i], np.max(self.plot_values['val_'+self.metric_names[i]])))
else:
print("Validation {}: {}".format(self.metric_names[i], np.max(self.plot_values['val_'+self.metric_names[i]])))
self.print_stats(image_counter)
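# --- Illustrative sketch of 8-way 2D test-time augmentation, the idea behind
# 'ensemble8_2d_predictions' used during inference: predict over 4 rotations x
# optional flip, undo each transform and average. 'img' and 'model' are
# hypothetical; 'model' is assumed to be any shape-preserving callable on
# (Y, X, C) arrays.
@staticmethod
def _tta8_sketch(img, model):
    import numpy as np
    preds = []
    for flip in (False, True):
        base = np.flip(img, axis=1) if flip else img
        for k in range(4):
            p = model(np.rot90(base, k, axes=(0, 1)))
            p = np.rot90(p, -k, axes=(0, 1))                 # undo rotation
            preds.append(np.flip(p, axis=1) if flip else p)  # undo flip
    return np.mean(preds, axis=0)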
def process_sample_by_chunks(self, filenames):
"""
Function to process a sample in the inference phase. A final H5/Zarr file is created in "TZCYX" or "TZYXC" order
depending on ``TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER`` ('T' is always included).
Parameters
----------
filenames : str
Filename of the sample to process.
"""
filename, file_extension = os.path.splitext(filenames)
if file_extension not in ['.hdf5', '.h5', ".zarr"]:
print("WARNING: you could have saved more memory by converting input test images into H5 file format (.h5) "
"or Zarr (.zarr) as with 'TEST.BY_CHUNKS.ENABLE' option enabled H5/Zarr files will be processed by chunks")
# Load data
if file_extension in ['.hdf5', '.h5', ".zarr"]:
self._X_file, self._X = read_chunked_data(self._X)
else: # Numpy array
if self._X.ndim == 3:
c_pos = -1 if self.cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER[-1] == 'C' else 1
self._X = np.expand_dims(self._X, c_pos)
if is_main_process():
print(f"Loaded image shape is {self._X.shape}")
data_shape = self._X.shape
if self._X.ndim < 3:
raise ValueError("Loaded image need to have at least 3 dimensions: {} (ndim: {})".format(self._X.shape, self._X.ndim))
if len(self.cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER) != self._X.ndim:
raise ValueError("'TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER' value {} does not match the number of dimensions of the loaded H5/Zarr "
"file {} (ndim: {})".format(self.cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER, self._X.shape, self._X.ndim))
# Data paths
os.makedirs(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, exist_ok=True)
ext = ".h5" if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5" else ".zarr"
if self.cfg.SYSTEM.NUM_GPUS > 1:
out_data_filename = os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, filename+"_part"+str(get_rank())+ext)
out_data_mask_filename = os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, filename+"_part"+str(get_rank())+"_mask"+ext)
else:
out_data_filename = os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, filename+"_nodiv"+ext)
out_data_mask_filename = os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, filename+"_mask"+ext)
out_data_div_filename = os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, filename+ext)
in_data = self._X
# Process in charge of processing one predicted patch
output_handle_proc = mp.Process(target=insert_patch_into_dataset, args=(out_data_filename, out_data_mask_filename,
data_shape, self.output_queue, self.extract_info_queue, self.cfg, self.dtype_str, self.dtype,
self.cfg.TEST.BY_CHUNKS.FORMAT, self.cfg.TEST.VERBOSE))
output_handle_proc.daemon=True
output_handle_proc.start()
# Process in charge of loading part of the data
load_data_process = mp.Process(target=extract_patch_from_dataset, args=(in_data, self.cfg, self.input_queue,
self.extract_info_queue, self.cfg.TEST.VERBOSE))
load_data_process.daemon=True
load_data_process.start()
if hasattr(self, '_X_file') and isinstance(self._X_file, h5py.File):
self._X_file.close()
del self._X, in_data
# Keep this process doing inference until there are no more patches
if self.cfg.TEST.VERBOSE and self.cfg.SYSTEM.NUM_GPUS > 1:
print(f"[Rank {get_rank()} ({os.getpid()})] Doing inference ")
while True:
obj = self.input_queue.get(timeout=60)
if obj is None: break
img, patch_coords = obj
img, _ = self.test_generator.norm_X(img)
if self.cfg.TEST.AUGMENTATION:
p = ensemble16_3d_predictions(img[0], batch_size_value=self.cfg.TRAIN.BATCH_SIZE,
axis_order_back=self.axis_order_back, pred_func=self.model_call_func,
axis_order=self.axis_order, device=self.device)
else:
with torch.cuda.amp.autocast():
p = self.model_call_func(img)
p = self.apply_model_activations(p)
# Multi-head concatenation
if isinstance(p, list):
p = torch.cat((p[0], torch.argmax(p[1], axis=1).unsqueeze(1)), dim=1)
p = to_numpy_format(p, self.axis_order_back)
# Create a mask with the overlap. Calculate the exact part of the patch that will be inserted in the
# final H5/Zarr file
p = p[0, self.cfg.DATA.TEST.PADDING[0]:p.shape[1]-self.cfg.DATA.TEST.PADDING[0],
self.cfg.DATA.TEST.PADDING[1]:p.shape[2]-self.cfg.DATA.TEST.PADDING[1],
self.cfg.DATA.TEST.PADDING[2]:p.shape[3]-self.cfg.DATA.TEST.PADDING[2]]
m = np.ones(p.shape, dtype=np.uint8)
# Put the prediction into queue
self.output_queue.put([p, m, patch_coords])
# Get some auxiliary variables
self.stats['patch_counter'] = self.extract_info_queue.get(timeout=60)
if is_main_process():
z_vol_info = self.extract_info_queue.get(timeout=60)
list_of_vols_in_z = self.extract_info_queue.get(timeout=60)
load_data_process.join()
output_handle_proc.join()
# Wait until all threads are done so the main thread can create the full size image
if self.cfg.SYSTEM.NUM_GPUS > 1 :
if self.cfg.TEST.VERBOSE:
print(f"[Rank {get_rank()} ({os.getpid()})] Finish sample inference ")
if is_dist_avail_and_initialized():
dist.barrier()
# Create the final H5/Zarr file that contains all the individual parts
if is_main_process():
if "C" not in self.cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER:
out_data_order = self.cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER + "C"
c_index = -1
else:
out_data_order = self.cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER
c_index = out_data_order.index("C")
if self.cfg.SYSTEM.NUM_GPUS > 1:
# Obtain parts of the data created by all GPUs
if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5":
data_parts_filenames = sorted(next(os.walk(self.cfg.PATHS.RESULT_DIR.PER_IMAGE))[2])
else:
data_parts_filenames = sorted(next(os.walk(self.cfg.PATHS.RESULT_DIR.PER_IMAGE))[1])
parts = []
mask_parts = []
for x in data_parts_filenames:
if filename+"_part" in x and x.endswith(self.cfg.TEST.BY_CHUNKS.FORMAT):
if "_mask" not in x:
parts.append(x)
else:
mask_parts.append(x)
data_parts_filenames = parts
data_parts_mask_filenames = mask_parts
del parts, mask_parts
if max(1, self.cfg.SYSTEM.NUM_GPUS) != len(data_parts_filenames) or len(data_parts_filenames) != len(list_of_vols_in_z):
raise ValueError("Number of data parts is not the same as the number of GPUs")
# Compose the large image
for i, data_part_fname in enumerate(data_parts_filenames):
print("Reading {}".format(os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, data_part_fname)))
data_part_file, data_part = read_chunked_data(os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, data_part_fname))
data_mask_part_file, data_mask_part = read_chunked_data(os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, data_parts_mask_filenames[i]))
if 'data' not in locals():
all_data_filename = os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, filename+ext)
if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5":
allfile = h5py.File(all_data_filename,'w')
data = allfile.create_dataset("data", data_part.shape, dtype=self.dtype_str, compression="gzip")
else:
allfile = zarr.open_group(all_data_filename, mode="w")
data = allfile.create_dataset("data", shape=data_part.shape, dtype=self.dtype_str, compression="gzip")
for j, k in enumerate(list_of_vols_in_z[i]):
slices = (
slice(z_vol_info[k][0],z_vol_info[k][1]), # z (only z axis is distributed across GPUs)
slice(None), # y
slice(None), # x
slice(None), # Channel
)
data_ordered_slices = order_dimensions(
slices,
input_order="ZYXC",
output_order=out_data_order,
default_value=0)
if self.cfg.TEST.VERBOSE:
print(f"Filling {k} [{z_vol_info[k][0]}:{z_vol_info[k][1]}]")
data[data_ordered_slices] = data_part[data_ordered_slices] / data_mask_part[data_ordered_slices]
if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5":
allfile.flush()
if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5":
data_part_file.close()
data_mask_part_file.close()
# Save image
if self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF and self.cfg.PATHS.RESULT_DIR.PER_IMAGE != "":
current_order = np.array(range(len(data.shape)))
transpose_order = order_dimensions(current_order, input_order=out_data_order,
output_order="TZYXC", default_value=np.nan)
transpose_order = [x for x in transpose_order if not np.isnan(x)]
data = np.array(data, dtype=self.dtype).transpose(transpose_order)
if "T" not in out_data_order:
data = np.expand_dims(data,0)
save_tif(data, self.cfg.PATHS.RESULT_DIR.PER_IMAGE, [filename+".tif"], verbose=self.cfg.TEST.VERBOSE)
if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5":
allfile.close()
# Just make the division with the overlap
else:
# Load predictions and overlapping mask
pred_file, pred = read_chunked_data(out_data_filename)
mask_file, mask = read_chunked_data(out_data_mask_filename)
# Create new file
if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5":
fid_div = h5py.File(out_data_div_filename,'w')
pred_div = fid_div.create_dataset("data", pred.shape, dtype=pred.dtype, compression="gzip")
else:
fid_div = zarr.open_group(out_data_div_filename, mode="w")
pred_div = fid_div.create_dataset("data", shape=pred.shape, dtype=pred.dtype)
t_dim, z_dim, c_dim, y_dim, x_dim = order_dimensions(
data_shape, self.cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER)
# Fill the new data
z_vols = math.ceil(z_dim/self.cfg.DATA.PATCH_SIZE[0])
y_vols = math.ceil(y_dim/self.cfg.DATA.PATCH_SIZE[1])
x_vols = math.ceil(x_dim/self.cfg.DATA.PATCH_SIZE[2])
for z in tqdm(range(z_vols), disable=not is_main_process()):
for y in range(y_vols):
for x in range(x_vols):
slices = (
slice(z*self.cfg.DATA.PATCH_SIZE[0], min(z_dim,self.cfg.DATA.PATCH_SIZE[0]*(z+1))),
slice(y*self.cfg.DATA.PATCH_SIZE[1], min(y_dim,self.cfg.DATA.PATCH_SIZE[1]*(y+1))),
slice(x*self.cfg.DATA.PATCH_SIZE[2], min(x_dim,self.cfg.DATA.PATCH_SIZE[2]*(x+1))),
slice(0,pred.shape[c_index]), # Channel
)
data_ordered_slices = order_dimensions(
slices,
input_order = "ZYXC",
output_order = out_data_order,
default_value = 0,
)
pred_div[data_ordered_slices] = pred[data_ordered_slices] / mask[data_ordered_slices]
if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5":
fid_div.flush()
# Save image
if self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF and self.cfg.PATHS.RESULT_DIR.PER_IMAGE != "":
current_order = np.array(range(len(pred_div.shape)))
transpose_order = order_dimensions(current_order, input_order=out_data_order,
output_order="TZYXC", default_value=np.nan)
transpose_order = [x for x in transpose_order if not np.isnan(x)]
pred_div = np.array(pred_div, dtype=self.dtype).transpose(transpose_order)
if "T" not in out_data_order:
pred_div = np.expand_dims(pred_div,0)
save_tif(pred_div, self.cfg.PATHS.RESULT_DIR.PER_IMAGE, [filename+".tif"],
verbose=self.cfg.TEST.VERBOSE)
if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5":
pred_file.close()
mask_file.close()
fid_div.close()
if self.cfg.TEST.BY_CHUNKS.WORKFLOW_PROCESS.ENABLE:
if self.cfg.TEST.BY_CHUNKS.WORKFLOW_PROCESS.TYPE == "chunk_by_chunk":
self.after_merge_patches_by_chunks_proccess_patch(out_data_div_filename)
else:
self.after_merge_patches_by_chunks_proccess_entire_pred(out_data_div_filename)
# Wait until the main thread is done to predict the next sample
if self.cfg.SYSTEM.NUM_GPUS > 1 :
if self.cfg.TEST.VERBOSE:
print(f"[Rank {get_rank()} ({os.getpid()})] Process waiting . . . ")
if is_dist_avail_and_initialized():
dist.barrier()
if self.cfg.TEST.VERBOSE:
print(f"[Rank {get_rank()} ({os.getpid()})] Synched with main thread. Go for the next sample")
def process_sample(self, norm):
"""
Function to process a sample in the inference phase.
Parameters
----------
norm : tuple of dicts
Normalization used during training. Required to denormalize the predictions of the model.
"""
# Data channel check
if self.cfg.DATA.PATCH_SIZE[-1] != self._X.shape[-1]:
raise ValueError("Channel of the DATA.PATCH_SIZE given {} does not correspond with the loaded image {}. "
"Please, check the channels of the images!".format(self.cfg.DATA.PATCH_SIZE[-1], self._X.shape[-1]))
#################
### PER PATCH ###
#################
if not self.cfg.TEST.FULL_IMG or self.cfg.PROBLEM.NDIM == '3D':
if not self.cfg.TEST.REUSE_PREDICTIONS:
# Reflect data to complete the needed shape
if self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE:
reflected_orig_shape = self._X.shape
self._X = np.expand_dims(pad_and_reflect(self._X[0], self.cfg.DATA.PATCH_SIZE, verbose=self.cfg.TEST.VERBOSE),0)
if self.cfg.DATA.TEST.LOAD_GT:
self._Y = np.expand_dims(pad_and_reflect(self._Y[0], self.cfg.DATA.PATCH_SIZE, verbose=self.cfg.TEST.VERBOSE),0)
original_data_shape = self._X.shape
# Crop if necessary
if self._X.shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]:
# Copy X to be used later in full image
if self.cfg.PROBLEM.NDIM != '3D':
X_original = self._X.copy()
if self.cfg.DATA.TEST.LOAD_GT and self._X.shape[:-1] != self._Y.shape[:-1]:
raise ValueError("Image {} and mask {} differ in shape (without considering the channels, i.e. last dimension)"
.format(self._X.shape,self._Y.shape))
if self.cfg.PROBLEM.NDIM == '2D':
obj = crop_data_with_overlap(self._X, self.cfg.DATA.PATCH_SIZE, data_mask=self._Y, overlap=self.cfg.DATA.TEST.OVERLAP,
padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE)
if self.cfg.DATA.TEST.LOAD_GT:
self._X, self._Y = obj
else:
self._X = obj
del obj
else:
if self.cfg.TEST.REDUCE_MEMORY:
self._X = crop_3D_data_with_overlap(self._X[0], self.cfg.DATA.PATCH_SIZE, overlap=self.cfg.DATA.TEST.OVERLAP,
padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE,
median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING)
if self.cfg.DATA.TEST.LOAD_GT:
self._Y = crop_3D_data_with_overlap(self._Y[0], self.cfg.DATA.PATCH_SIZE[:-1]+(self._Y.shape[-1],), overlap=self.cfg.DATA.TEST.OVERLAP,
padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE,
median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING)
else:
if self.cfg.DATA.TEST.LOAD_GT: self._Y = self._Y[0]
obj = crop_3D_data_with_overlap(self._X[0], self.cfg.DATA.PATCH_SIZE, data_mask=self._Y, overlap=self.cfg.DATA.TEST.OVERLAP,
padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE,
median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING)
if self.cfg.DATA.TEST.LOAD_GT:
self._X, self._Y = obj
else:
self._X = obj
del obj
# Evaluate each patch
if self.cfg.DATA.TEST.LOAD_GT and self.cfg.TEST.EVALUATE:
l = int(math.ceil(self._X.shape[0]/self.cfg.TRAIN.BATCH_SIZE))
for k in tqdm(range(l), leave=False):
top = (k+1)*self.cfg.TRAIN.BATCH_SIZE if (k+1)*self.cfg.TRAIN.BATCH_SIZE < self._X.shape[0] else self._X.shape[0]
with torch.cuda.amp.autocast():
output = self.apply_model_activations(self.model_call_func(self._X[k*self.cfg.TRAIN.BATCH_SIZE:top]))
loss = self.loss(output, to_pytorch_format(self._Y[k*self.cfg.TRAIN.BATCH_SIZE:top], self.axis_order, self.device, dtype=self.loss_dtype))
# Calculate the metrics
train_iou = self.metric_calculation(output, to_pytorch_format(self._Y[k*self.cfg.TRAIN.BATCH_SIZE:top], self.axis_order, self.device, dtype=self.loss_dtype))
self.stats['loss_per_crop'] += loss.item()
self.stats['iou_per_crop'] += train_iou
del output
self.stats['patch_counter'] += self._X.shape[0]
# Predict each patch
if self.cfg.TEST.AUGMENTATION:
for k in tqdm(range(self._X.shape[0]), leave=False):
if self.cfg.PROBLEM.NDIM == '2D':
p = ensemble8_2d_predictions(self._X[k], axis_order_back=self.axis_order_back,
pred_func=self.model_call_func, axis_order=self.axis_order, device=self.device)
else:
p = ensemble16_3d_predictions(self._X[k], batch_size_value=self.cfg.TRAIN.BATCH_SIZE,
axis_order_back=self.axis_order_back, pred_func=self.model_call_func,
axis_order=self.axis_order, device=self.device)
p = self.apply_model_activations(p)
# Multi-head concatenation
if isinstance(p, list):
p = torch.cat((p[0], p[1]), dim=1)
p = to_numpy_format(p, self.axis_order_back)
if 'pred' not in locals():
pred = np.zeros((self._X.shape[0],)+p.shape[1:], dtype=self.dtype)
pred[k] = p
else:
l = int(math.ceil(self._X.shape[0]/self.cfg.TRAIN.BATCH_SIZE))
for k in tqdm(range(l), leave=False):
top = (k+1)*self.cfg.TRAIN.BATCH_SIZE if (k+1)*self.cfg.TRAIN.BATCH_SIZE < self._X.shape[0] else self._X.shape[0]
with torch.cuda.amp.autocast():
p = self.apply_model_activations(self.model_call_func(self._X[k*self.cfg.TRAIN.BATCH_SIZE:top]))
# Multi-head concatenation
if isinstance(p, list):
p = torch.cat((p[0], p[1]), dim=1)
p = to_numpy_format(p, self.axis_order_back)
if 'pred' not in locals():
pred = np.zeros((self._X.shape[0],)+p.shape[1:], dtype=self.dtype)
pred[k*self.cfg.TRAIN.BATCH_SIZE:top] = p
# Delete self._X as in 3D there is no full image
if self.cfg.PROBLEM.NDIM == '3D':
del self._X, p
# Reconstruct the predictions
if original_data_shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]:
if self.cfg.PROBLEM.NDIM == '3D': original_data_shape = original_data_shape[1:]
f_name = merge_data_with_overlap if self.cfg.PROBLEM.NDIM == '2D' else merge_3D_data_with_overlap
if self.cfg.TEST.REDUCE_MEMORY:
pred = f_name(pred, original_data_shape[:-1]+(pred.shape[-1],), padding=self.cfg.DATA.TEST.PADDING,
overlap=self.cfg.DATA.TEST.OVERLAP, verbose=self.cfg.TEST.VERBOSE)
if self.cfg.DATA.TEST.LOAD_GT:
self._Y = f_name(self._Y, original_data_shape[:-1]+(self._Y.shape[-1],), padding=self.cfg.DATA.TEST.PADDING,
overlap=self.cfg.DATA.TEST.OVERLAP, verbose=self.cfg.TEST.VERBOSE)
else:
obj = f_name(pred, original_data_shape[:-1]+(pred.shape[-1],), data_mask=self._Y,
padding=self.cfg.DATA.TEST.PADDING, overlap=self.cfg.DATA.TEST.OVERLAP,
verbose=self.cfg.TEST.VERBOSE)
if self.cfg.DATA.TEST.LOAD_GT:
pred, self._Y = obj
else:
pred = obj
del obj
if self.cfg.PROBLEM.NDIM != '3D':
self._X = X_original.copy()
del X_original
else:
pred = np.expand_dims(pred,0)
if self._Y is not None: self._Y = np.expand_dims(self._Y,0)
if self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE:
if self.cfg.PROBLEM.NDIM == '2D':
pred = pred[:,-reflected_orig_shape[1]:,-reflected_orig_shape[2]:]
if self._Y is not None:
self._Y = self._Y[:,-reflected_orig_shape[1]:,-reflected_orig_shape[2]:]
else:
pred = pred[:,-reflected_orig_shape[1]:,-reflected_orig_shape[2]:,-reflected_orig_shape[3]:]
if self._Y is not None:
self._Y = self._Y[:,-reflected_orig_shape[1]:,-reflected_orig_shape[2]:,-reflected_orig_shape[3]:]
# Argmax if needed
if self.cfg.MODEL.N_CLASSES > 2 and self.cfg.DATA.TEST.ARGMAX_TO_OUTPUT:
# Multi-head case of instance segmentation
if pred.shape[-1] > self.cfg.MODEL.N_CLASSES:
pred = np.concatenate([pred[...,:-(self.cfg.MODEL.N_CLASSES)],
np.expand_dims(np.argmax(pred[...,-(self.cfg.MODEL.N_CLASSES):],-1), -1)], axis=-1)
else:
pred = np.expand_dims(np.argmax(pred,-1), -1)
if self.cfg.DATA.TEST.LOAD_GT: self._Y = np.expand_dims(np.argmax(self._Y,-1), -1)
# Apply mask
if self.cfg.TEST.POST_PROCESSING.APPLY_MASK:
pred = np.expand_dims(apply_binary_mask(pred[0], self.cfg.DATA.TEST.BINARY_MASKS),0)
# Save image
if self.cfg.PATHS.RESULT_DIR.PER_IMAGE != "":
save_tif(pred, self.cfg.PATHS.RESULT_DIR.PER_IMAGE, self.processing_filenames,
verbose=self.cfg.TEST.VERBOSE)
if self.cfg.DATA.TEST.LOAD_GT and self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS != "Dv2":
if self.cfg.LOSS.TYPE != 'MASKED_BCE':
_iou_merge_patches = jaccard_index_numpy((self._Y>0.5).astype(np.uint8), (pred>0.5).astype(np.uint8))
_ov_iou_merge_patches = voc_calculation((self._Y>0.5).astype(np.uint8), (pred>0.5).astype(np.uint8),
_iou_merge_patches)
else:
exclusion_mask = self._Y < 2
binY = self._Y * exclusion_mask.astype(float)
_iou_merge_patches = jaccard_index_numpy((binY>0.5).astype(np.uint8), (pred>0.5).astype(np.uint8))
_ov_iou_merge_patches = voc_calculation((binY>0.5).astype(np.uint8), (pred>0.5).astype(np.uint8),
_iou_merge_patches)
self.stats['iou_merge_patches'] += _iou_merge_patches
self.stats['ov_iou_merge_patches'] += _ov_iou_merge_patches
############################
### POST-PROCESSING (3D) ###
############################
if self.post_processing['per_image']:
pred, _iou_post, _ov_iou_post = apply_post_processing(self.cfg, pred, self._Y)
self.stats['iou_merge_patches_post'] += _iou_post
self.stats['ov_iou_merge_patches_post'] += _ov_iou_post
save_tif(pred, self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING, self.processing_filenames,
verbose=self.cfg.TEST.VERBOSE)
else:
# Load predictions from file
f = self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING if self.post_processing['per_image'] else self.cfg.PATHS.RESULT_DIR.PER_IMAGE
f_name = load_data_from_dir if self.cfg.PROBLEM.NDIM == '2D' else load_3d_images_from_dir
pred, _, _ = f_name(f)
self.after_merge_patches(pred)
if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK:
self.all_pred.append(pred)
if self.cfg.DATA.TEST.LOAD_GT: self.all_gt.append(self._Y)
##################
### FULL IMAGE ###
##################
if self.cfg.TEST.FULL_IMG and self.cfg.PROBLEM.NDIM == '2D':
self._X, o_test_shape = check_downsample_division(self._X, len(self.cfg.MODEL.FEATURE_MAPS)-1)
if not self.cfg.TEST.REUSE_PREDICTIONS:
if self.cfg.DATA.TEST.LOAD_GT:
self._Y, _ = check_downsample_division(self._Y, len(self.cfg.MODEL.FEATURE_MAPS)-1)
# Evaluate each img
if self.cfg.DATA.TEST.LOAD_GT:
with torch.cuda.amp.autocast():
output = self.model_call_func(self._X)
loss = self.loss(output, to_pytorch_format(self._Y, self.axis_order, self.device, dtype=self.loss_dtype))
self.stats['loss'] += loss.item()
del output
# Make the prediction
if self.cfg.TEST.AUGMENTATION:
pred = ensemble8_2d_predictions(self._X[0], axis_order_back=self.axis_order_back,
pred_func=self.model_call_func, axis_order=self.axis_order, device=self.device)
else:
with torch.cuda.amp.autocast():
pred = self.model_call_func(self._X)
pred = self.apply_model_activations(pred)
# Multi-head concatenation
if isinstance(pred, list):
pred = torch.cat((pred[0], torch.argmax(pred[1], axis=1).unsqueeze(1)), dim=1)
pred = to_numpy_format(pred, self.axis_order_back)
del self._X
# Recover original shape if padded with check_downsample_division
pred = pred[:,:o_test_shape[1],:o_test_shape[2]]
if self.cfg.DATA.TEST.LOAD_GT: self._Y = self._Y[:,:o_test_shape[1],:o_test_shape[2]]
# Save image
if pred.ndim == 4 and self.cfg.PROBLEM.NDIM == '3D':
save_tif(np.expand_dims(pred,0), self.cfg.PATHS.RESULT_DIR.FULL_IMAGE, self.processing_filenames,
verbose=self.cfg.TEST.VERBOSE)
else:
save_tif(pred, self.cfg.PATHS.RESULT_DIR.FULL_IMAGE, self.processing_filenames, verbose=self.cfg.TEST.VERBOSE)
# Argmax if needed
if self.cfg.MODEL.N_CLASSES > 2 and self.cfg.DATA.TEST.ARGMAX_TO_OUTPUT:
pred = np.expand_dims(np.argmax(pred,-1), -1)
if self.cfg.DATA.TEST.LOAD_GT: self._Y = np.expand_dims(np.argmax(self._Y,-1), -1)
if self.cfg.TEST.POST_PROCESSING.APPLY_MASK:
pred = apply_binary_mask(pred, self.cfg.DATA.TEST.BINARY_MASKS)
if self.cfg.DATA.TEST.LOAD_GT:
score = jaccard_index_numpy((self._Y>0.5).astype(np.uint8), (pred>0.5).astype(np.uint8))
self.stats['iou'] += score
self.stats['ov_iou'] += voc_calculation((self._Y>0.5).astype(np.uint8), (pred>0.5).astype(np.uint8), score)
else:
# load predictions from file
pred, _, _ = load_data_from_dir(self.cfg.PATHS.RESULT_DIR.FULL_IMAGE)
if pred.ndim == 5:
pred = np.squeeze(pred)
if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK:
self.all_pred.append(pred)
if self.cfg.DATA.TEST.LOAD_GT: self.all_gt.append(self._Y)
self.after_full_image(pred)
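# --- Illustrative sketch (1D toy) of why chunked predictions are divided by an
# overlap mask elsewhere in this class: contributions summed over overlapping
# patches are turned into an average by dividing by the per-position visit count.
@staticmethod
def _overlap_average_sketch():
    import numpy as np
    pred_sum = np.zeros(8, dtype=np.float32)
    visits = np.zeros(8, dtype=np.float32)
    for start in (0, 2, 4):                  # three overlapping 1D "patches"
        pred_sum[start:start+4] += np.ones(4, dtype=np.float32)
        visits[start:start+4] += 1
    averaged = pred_sum / visits             # every position back to 1.0
    assert np.allclose(averaged, 1.0)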
def normalize_stats(self, image_counter):
"""
Normalize statistics.
Parameters
----------
image_counter : int
Number of images to average the metrics.
"""
# Per crop
self.stats['loss_per_crop'] = self.stats['loss_per_crop'] / self.stats['patch_counter'] if self.stats['patch_counter'] != 0 else 0
self.stats['iou_per_crop'] = self.stats['iou_per_crop'] / self.stats['patch_counter'] if self.stats['patch_counter'] != 0 else 0
# Merge patches
self.stats['iou_merge_patches'] = self.stats['iou_merge_patches'] / image_counter
self.stats['ov_iou_merge_patches'] = self.stats['ov_iou_merge_patches'] / image_counter
# Full image
self.stats['iou'] = self.stats['iou'] / image_counter
self.stats['loss'] = self.stats['loss'] / image_counter
self.stats['ov_iou'] = self.stats['ov_iou'] / image_counter
if self.post_processing['per_image']:
self.stats['iou_merge_patches_post'] = self.stats['iou_merge_patches_post'] / image_counter
self.stats['ov_iou_merge_patches_post'] = self.stats['ov_iou_merge_patches_post'] / image_counter
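# Minimal sketch (made-up numbers) of the normalization above: per-patch metrics are averaged over
# ``patch_counter`` while per-image metrics are averaged over ``image_counter``.
#
#   stats = {'loss_per_crop': 8.0, 'patch_counter': 20, 'iou': 3.0}
#   stats['loss_per_crop'] /= stats['patch_counter']  # -> 0.4 (mean per patch)
#   stats['iou'] /= 4                                 # -> 0.75 (mean over 4 images)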
[docs] def print_stats(self, image_counter):
"""
Print statistics.
Parameters
----------
image_counter : int
Number of images, forwarded to ``normalize_stats`` to average the metrics.
"""
self.normalize_stats(image_counter)
if self.cfg.DATA.TEST.LOAD_GT:
if not self.cfg.TEST.FULL_IMG:
print("Loss (per patch): {}".format(self.stats['loss_per_crop']))
print("Test Foreground IoU (per patch): {}".format(self.stats['iou_per_crop']))
print(" ")
print("Test Foreground IoU (merge patches): {}".format(self.stats['iou_merge_patches']))
print("Test Overall IoU (merge patches): {}".format(self.stats['ov_iou_merge_patches']))
print(" ")
else:
print("Loss (per image): {}".format(self.stats['loss']))
print("Test Foreground IoU (per image): {}".format(self.stats['iou']))
print("Test Overall IoU (per image): {}".format(self.stats['ov_iou']))
print(" ")
[docs] def print_post_processing_stats(self):
"""
Print post-processing statistics.
"""
if self.post_processing['per_image']:
print("Test Foreground IoU (merge patches - post-processing): {}".format(self.stats['iou_merge_patches_post']))
print("Test Overall IoU (merge patches - post-processing): {}".format(self.stats['ov_iou_merge_patches_post']))
print(" ")
if self.post_processing['as_3D_stack']:
print("Test Foreground IoU (as 3D stack - post-processing): {}".format(self.stats['iou_as_3D_stack_post']))
print("Test Overall IoU (as 3D stack - post-processing): {}".format(self.stats['ov_iou_as_3D_stack_post']))
print(" ")
[docs] @abstractmethod
def after_merge_patches(self, pred):
"""
Place here any code that must be executed after merging all predicted patches back into the original image.
Parameters
----------
pred : Torch Tensor
Model prediction.
"""
raise NotImplementedError
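# Hypothetical subclass hook (a sketch, not an actual BiaPy workflow): a concrete workflow could,
# for instance, binarize the merged prediction at this point.
#
#   class My_Workflow(Base_Workflow):
#       def after_merge_patches(self, pred):
#           pred = (pred > 0.5).astype(np.uint8)  # threshold before any further processing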
[docs] def after_merge_patches_by_chunks_proccess_entire_pred(self, filename):
"""
Place here any code that must be executed after merging all predicted patches back into the
original image, when the merging is done chunk by chunk. This function operates over the entire
predicted image.
Parameters
----------
filename : str
Path to the predicted H5/Zarr image.
"""
# Load H5/Zarr and convert it into numpy array
pred_file, pred = read_chunked_data(filename)
pred = np.squeeze(np.array(pred, dtype=self.dtype))
if self.cfg.TEST.BY_CHUNKS.FORMAT == "h5":
pred_file.close()
# Adjust shape
if pred.ndim < 3:
raise ValueError("Read image seems to be 2D: {}. Path: {}".format(pred.shape, filename))
if pred.ndim == 3:
pred = np.expand_dims(pred, -1)
else:
min_val = min(pred.shape)
channel_pos = pred.shape.index(min_val)
if channel_pos != 3 and pred.shape[channel_pos] <= 4:
new_pos = [x for x in range(4) if x != channel_pos]+[channel_pos,]
pred = pred.transpose(new_pos)
fname, file_extension = os.path.splitext(os.path.basename(filename))
self.processing_filenames = [fname+".tif"]
self.after_merge_patches(pred)
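# Sketch of the channel-detection heuristic above with a made-up shape: the smallest axis
# (if of size <= 4 and not already last) is assumed to be the channel axis and moved to the end.
#
#   pred_shape = (2, 64, 128, 128)                                       # channels first
#   channel_pos = pred_shape.index(min(pred_shape))                      # -> 0
#   new_pos = [x for x in range(4) if x != channel_pos] + [channel_pos]  # -> [1, 2, 3, 0]
#   # pred.transpose(new_pos) would yield shape (64, 128, 128, 2), i.e. channels last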
[docs] @abstractmethod
def after_merge_patches_by_chunks_proccess_patch(self, filename):
"""
Place here any code that must be executed after merging all predicted patches back into the
original image, when the merging is done chunk by chunk. This function operates patch by patch,
with the patch size defined by ``DATA.PATCH_SIZE``.
Parameters
----------
filename : str
Path to the predicted H5/Zarr image.
"""
raise NotImplementedError
[docs] @abstractmethod
def after_full_image(self, pred):
"""
Place here any code that must be executed after generating the prediction by supplying the entire image to the model.
To enable this, the model should be convolutional, and the image(s) should be in a 2D format. Using 3D images as
direct inputs to the model is not feasible due to their large size.
Parameters
----------
pred : Torch Tensor
Model prediction.
"""
raise NotImplementedError
[docs] def after_all_images(self):
"""
Place here any code that must be executed after predicting all images.
"""
############################
### POST-PROCESSING (2D) ###
############################
if self.post_processing['as_3D_stack']:
self.all_pred = np.concatenate(self.all_pred)
self.all_gt = np.concatenate(self.all_gt) if self.cfg.DATA.TEST.LOAD_GT else None
save_tif(np.expand_dims(self.all_pred,0), self.cfg.PATHS.RESULT_DIR.AS_3D_STACK, verbose=self.cfg.TEST.VERBOSE)
save_tif(np.expand_dims((self.all_pred>0.5).astype(np.uint8),0), self.cfg.PATHS.RESULT_DIR.AS_3D_STACK_BIN, verbose=self.cfg.TEST.VERBOSE)
self.all_pred, self.stats['iou_as_3D_stack_post'], self.stats['ov_iou_as_3D_stack_post'] = apply_post_processing(self.cfg, self.all_pred, self.all_gt)
save_tif(np.expand_dims(self.all_pred,0), self.cfg.PATHS.RESULT_DIR.AS_3D_STACK_POST_PROCESSING, verbose=self.cfg.TEST.VERBOSE)
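# Illustrative sketch (hypothetical shapes) of the assembly above: N per-image 2D predictions of
# shape (1, Y, X, C) are concatenated along the first axis and saved as one 3D stack.
#
#   all_pred = [np.zeros((1, 256, 256, 1), dtype=np.float32) for _ in range(5)]
#   stack = np.concatenate(all_pred)      # -> (5, 256, 256, 1)
#   stack = np.expand_dims(stack, 0)      # -> (1, 5, 256, 256, 1), as save_tif expects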
[docs]def insert_patch_into_dataset(data_filename, data_filename_mask, data_shape, output_queue, extract_info_queue, cfg,
dtype_str, dtype, file_type, verbose=False):
"""
Insert predicted patches (taken from ``output_queue``) into their original positions in an H5/Zarr file. Each GPU
will create a file containing the part it has processed (as the same H5/Zarr file can not be written to at the
same time). Then, the main rank will create the final image. This function is run by a child process created for
every test sample.
Parameters
----------
data_filename : str
Path of the H5/Zarr file to create, where the merged prediction will be inserted.
data_filename_mask : str
Same as ``data_filename`` but for the predicted masks.
data_shape : tuple of int
Shape of the H5/Zarr file dataset to create.
output_queue : Multiprocessing queue
Queue to get each prediction from.
extract_info_queue : Multiprocessing queue
Auxiliary queue to pass information between processes.
cfg : YACS configuration
Running configuration.
dtype_str : str
Type of the H5/Zarr dataset to create.
dtype : Numpy dtype
Type of the TIF file to create. Only used if a TIF file is requested through the
``TEST.BY_CHUNKS.SAVE_OUT_TIF`` variable.
file_type : str
Format of the files to create: ``"h5"`` or ``"zarr"``.
verbose : bool, optional
To print useful information for debugging.
"""
if verbose and cfg.SYSTEM.NUM_GPUS > 1:
print(f"[Rank {get_rank()} ({os.getpid()})] In charge of inserting patches into data . . .")
if file_type == "h5":
fid = h5py.File(data_filename, "w")
fid_mask = h5py.File(data_filename_mask, "w")
else:
fid = zarr.open_group(data_filename, mode="w")
fid_mask = zarr.open_group(data_filename_mask, mode="w")
filename, file_extension = os.path.splitext(os.path.basename(data_filename))
# Obtain the total number of patches so it can be displayed to the user
total_patches = extract_info_queue.get(timeout=60)
for i in tqdm(range(total_patches), disable=not is_main_process()):
p, m, patch_coords = output_queue.get(timeout=60)
if 'data' not in locals():
# Channel dimension should be equal to the number of channels of the prediction
out_data_shape = np.array(data_shape)
if "C" not in cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER:
out_data_shape = tuple(out_data_shape) + (p.shape[-1],)
out_data_order = cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER + "C"
else:
out_data_shape[cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER.index("C")] = 1
out_data_shape = tuple(out_data_shape)
out_data_order = cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER
if file_type == "h5":
data = fid.create_dataset("data", out_data_shape, dtype=dtype_str, compression="gzip")
mask = fid_mask.create_dataset("data", out_data_shape, dtype=dtype_str, compression="gzip")
else:
data = fid.create_dataset("data", shape=out_data_shape, dtype=dtype_str)
mask = fid_mask.create_dataset("data", shape=out_data_shape, dtype=dtype_str)
# Adjust slices to calculate where to insert the predicted patch. The channel dimension is left
# open (``slice(None)``) so any number of channels can be inserted
slices = (slice(patch_coords[0][0],patch_coords[0][1]),slice(patch_coords[1][0],patch_coords[1][1]),
slice(patch_coords[2][0],patch_coords[2][1]), slice(None))
data_ordered_slices = tuple(order_dimensions(slices, input_order="ZYXC", output_order=cfg.TEST.BY_CHUNKS.INPUT_IMG_AXES_ORDER,
default_value=0))
# Adjust the patch slice to transpose it before inserting it into the final data
current_order = np.array(range(len(p.shape)))
transpose_order = order_dimensions(current_order, input_order="ZYXC", output_order=out_data_order,
default_value=np.nan)
transpose_order = [x for x in transpose_order if not np.isnan(x)]
data[data_ordered_slices] += p.transpose(transpose_order)
mask[data_ordered_slices] += m.transpose(transpose_order)
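# Minimal sketch (hypothetical axis order) of the reordering above, assuming order_dimensions
# permutes entries from ``input_order`` to ``output_order``:
#
#   slices = (slice(0, 20), slice(0, 128), slice(0, 128), slice(None))  # Z, Y, X, C
#   # with an axes order of "CZYX" the ordered slices would become
#   #   (slice(None), slice(0, 20), slice(0, 128), slice(0, 128))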
# Force flush after some iterations
if i % cfg.TEST.BY_CHUNKS.FLUSH_EACH == 0 and file_type == "h5":
fid.flush()
fid_mask.flush()
# Save image
if cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF and cfg.PATHS.RESULT_DIR.PER_IMAGE != "":
current_order = np.array(range(len(data.shape)))
transpose_order = order_dimensions(current_order, input_order=out_data_order,
output_order="TZYXC", default_value=np.nan)
transpose_order = [x for x in transpose_order if not np.isnan(x)]
data = np.array(data, dtype=dtype).transpose(transpose_order)
mask = np.array(mask, dtype=dtype).transpose(transpose_order)
if "T" not in out_data_order:
data = np.expand_dims(data,0)
mask = np.expand_dims(mask,0)
save_tif(data, cfg.PATHS.RESULT_DIR.PER_IMAGE, [filename+".tif"], verbose=verbose)
save_tif(mask, cfg.PATHS.RESULT_DIR.PER_IMAGE, [filename+"_mask.tif"], verbose=verbose)
if file_type == "h5":
fid.close()
fid_mask.close()
if verbose and cfg.SYSTEM.NUM_GPUS > 1:
print(f"[Rank {get_rank()} ({os.getpid()})] Finish inserting patches into data . . .")