"""
BiaPy data generators package.
This package provides data generator classes and utility functions for loading,
augmenting, and batching image and mask data for deep learning workflows in BiaPy.
It supports 2D and 3D data, chunked loading, distributed training, and advanced
augmentation pipelines.
"""
import torch
from typing import List, Dict, Any, Tuple, Optional
from torch.utils.data import (
DistributedSampler,
DataLoader,
SequentialSampler,
)
import numpy as np
from numpy.typing import NDArray
from tqdm import tqdm
from yacs.config import CfgNode as CN
from biapy.data.pre_processing import calculate_volume_prob_map
from biapy.data.generators.pair_data_2D_generator import Pair2DImageDataGenerator
from biapy.data.generators.pair_data_3D_generator import Pair3DImageDataGenerator
from biapy.data.generators.single_data_2D_generator import Single2DImageDataGenerator
from biapy.data.generators.single_data_3D_generator import Single3DImageDataGenerator
from biapy.data.generators.test_pair_data_generators import test_pair_data_generator
from biapy.data.generators.test_single_data_generator import test_single_data_generator
from biapy.data.generators.chunked_test_pair_data_generator import chunked_test_pair_data_generator
from biapy.data.generators.chunked_workflow_process_generator import chunked_workflow_process_generator
from biapy.data.pre_processing import preprocess_data
from biapy.data.data_manipulation import save_tif
from biapy.data.dataset import BiaPyDataset
from biapy.utils.misc import get_rank, get_world_size, is_dist_avail_and_initialized, os_walk_clean, is_main_process
from biapy.models.bmz_utils import extract_BMZ_sample_and_cover
[docs]
def _train_val_worker_init_fn(worker_id):
    """
    Restrict a DataLoader worker process to a single torch thread.

    Kept at module level (instead of being nested inside
    ``create_train_val_augmentors``) so that it can be pickled when the worker
    processes are started with the ``spawn`` or ``forkserver`` start methods.

    Parameters
    ----------
    worker_id : int
        Worker index supplied by the DataLoader. Not used.
    """
    torch.set_num_threads(1)


def create_train_val_augmentors(
    cfg: CN,
    system_dict: Dict[str, Any],
    X_train: BiaPyDataset,
    X_val: BiaPyDataset,
    norm_module: Dict,
    Y_train: Optional[BiaPyDataset] = None,
    Y_val: Optional[BiaPyDataset] = None,
) -> Tuple[DataLoader, DataLoader, int, NDArray, NDArray, NDArray]:
    """
    Create training and validation generators.

    Parameters
    ----------
    cfg : Config
        BiaPy configuration.

    system_dict : dict
        System dictionary containing:
            * 'cpu_budget': int, Total CPU budget.
            * 'cpu_per_rank': int, CPU budget per rank.
            * 'main_threads': int, Number of main threads.
            * 'num_workers_hint': int, Hint for the number of workers.

        Only ``'num_workers_hint'`` is read here, and only when
        ``cfg.SYSTEM.NUM_WORKERS`` is ``-1``.

    X_train : BiaPyDataset
        Loaded train X data.

    X_val : BiaPyDataset
        Loaded validation X data.

    norm_module : Dict
        Normalization module that defines the normalization steps to apply.

    Y_train : BiaPyDataset, optional
        Loaded train Y data.

    Y_val : BiaPyDataset, optional
        Loaded validation Y data.

    Returns
    -------
    train_generator : DataLoader
        Training data generator.

    val_generator : DataLoader
        Validation data generator.

    num_training_steps_per_epoch: int
        Number of training steps per epoch, computed as the number of training samples
        floor-divided by ``batch size * world size * gradient accumulation iterations``.

    bmz_input_sample : Numpy array
        Sample of the input data to be used for exporting the model to BMZ, as ``float32``.
        A leading batch axis is added if missing and the last axis is then moved to position 1
        (Numpy -> Torch layout), so the array has 4 dimensions for ``2D`` and 5 for ``3D``.

    cover_raw : Numpy array
        Sample of the raw cover data to be used for exporting the model to BMZ, as returned by
        ``extract_BMZ_sample_and_cover``.

    cover_gt : Numpy array
        Sample of the GT cover data to be used for exporting the model to BMZ, as returned by
        ``extract_BMZ_sample_and_cover``.
    """
    # Workflows that only need the raw image (no paired target) use the "Single" generators.
    single_input = cfg.PROBLEM.TYPE == "CLASSIFICATION" or (
        cfg.PROBLEM.TYPE == "SELF_SUPERVISED" and cfg.PROBLEM.SELF_SUPERVISED.PRETEXT_TASK == "masking"
    )

    if cfg.PROBLEM.NDIM == "2D":
        f_name = Single2DImageDataGenerator if single_input else Pair2DImageDataGenerator
    else:
        f_name = Single3DImageDataGenerator if single_input else Pair3DImageDataGenerator

    ndim = 3 if cfg.PROBLEM.NDIM == "3D" else 2

    # Arguments accepted by both the single and the pair training generators.
    dic: Dict[str, Any] = dict(
        ndim=ndim,
        X=X_train,
        seed=cfg.SYSTEM.SEED,
        da=cfg.AUGMENTOR.ENABLE,
        da_prob=cfg.AUGMENTOR.DA_PROB,
        rotation90=cfg.AUGMENTOR.ROT90,
        rand_rot=cfg.AUGMENTOR.RANDOM_ROT,
        rnd_rot_range=cfg.AUGMENTOR.RANDOM_ROT_RANGE,
        shear=cfg.AUGMENTOR.SHEAR,
        shear_range=cfg.AUGMENTOR.SHEAR_RANGE,
        zoom=cfg.AUGMENTOR.ZOOM,
        zoom_range=cfg.AUGMENTOR.ZOOM_RANGE,
        zoom_in_z=cfg.AUGMENTOR.ZOOM_IN_Z,
        shift=cfg.AUGMENTOR.SHIFT,
        shift_range=cfg.AUGMENTOR.SHIFT_RANGE,
        affine_mode=cfg.AUGMENTOR.AFFINE_MODE,
        vflip=cfg.AUGMENTOR.VFLIP,
        hflip=cfg.AUGMENTOR.HFLIP,
        elastic=cfg.AUGMENTOR.ELASTIC,
        e_alpha=cfg.AUGMENTOR.E_ALPHA,
        e_sigma=cfg.AUGMENTOR.E_SIGMA,
        e_mode=cfg.AUGMENTOR.E_MODE,
        g_blur=cfg.AUGMENTOR.G_BLUR,
        g_sigma=cfg.AUGMENTOR.G_SIGMA,
        median_blur=cfg.AUGMENTOR.MEDIAN_BLUR,
        mb_kernel=cfg.AUGMENTOR.MB_KERNEL,
        motion_blur=cfg.AUGMENTOR.MOTION_BLUR,
        motb_k_range=cfg.AUGMENTOR.MOTB_K_RANGE,
        gamma_contrast=cfg.AUGMENTOR.GAMMA_CONTRAST,
        gc_gamma=cfg.AUGMENTOR.GC_GAMMA,
        dropout=cfg.AUGMENTOR.DROPOUT,
        drop_range=cfg.AUGMENTOR.DROP_RANGE,
        norm_module=norm_module,
        convert_to_rgb=cfg.DATA.FORCE_RGB,
        preprocess_f=preprocess_data if cfg.DATA.PREPROCESS.TRAIN else None,
        preprocess_cfg=cfg.DATA.PREPROCESS if cfg.DATA.PREPROCESS.TRAIN else None,
    )

    if single_input:
        r_shape = cfg.DATA.PATCH_SIZE
        if cfg.MODEL.ARCHITECTURE == "efficientnet_b0" and cfg.DATA.PATCH_SIZE[:-1] != (
            224,
            224,
        ):
            r_shape = (224, 224) + (cfg.DATA.PATCH_SIZE[-1],)
            print("Changing patch size from {} to {} to use efficientnet_b0".format(cfg.DATA.PATCH_SIZE[:-1], r_shape))
        dic["resize_shape"] = r_shape
    else:
        # Extra arguments only passed to the pair (image + target) generators.
        dic.update(
            Y=Y_train,
            brightness=cfg.AUGMENTOR.BRIGHTNESS,
            brightness_factor=cfg.AUGMENTOR.BRIGHTNESS_FACTOR,
            contrast=cfg.AUGMENTOR.CONTRAST,
            contrast_factor=cfg.AUGMENTOR.CONTRAST_FACTOR,
            cutout=cfg.AUGMENTOR.CUTOUT,
            cout_nb_iterations=cfg.AUGMENTOR.COUT_NB_ITERATIONS,
            cout_size=cfg.AUGMENTOR.COUT_SIZE,
            cout_cval=cfg.AUGMENTOR.COUT_CVAL,
            cout_apply_to_mask=cfg.AUGMENTOR.COUT_APPLY_TO_MASK,
            cutblur=cfg.AUGMENTOR.CUTBLUR,
            cblur_size=cfg.AUGMENTOR.CBLUR_SIZE,
            cblur_down_range=cfg.AUGMENTOR.CBLUR_DOWN_RANGE,
            cblur_inside=cfg.AUGMENTOR.CBLUR_INSIDE,
            cutmix=cfg.AUGMENTOR.CUTMIX,
            cmix_size=cfg.AUGMENTOR.CMIX_SIZE,
            cutnoise=cfg.AUGMENTOR.CUTNOISE,
            cnoise_size=cfg.AUGMENTOR.CNOISE_SIZE,
            cnoise_nb_iterations=cfg.AUGMENTOR.CNOISE_NB_ITERATIONS,
            cnoise_scale=cfg.AUGMENTOR.CNOISE_SCALE,
            misalignment=cfg.AUGMENTOR.MISALIGNMENT,
            ms_displacement=cfg.AUGMENTOR.MS_DISPLACEMENT,
            ms_rotate_ratio=cfg.AUGMENTOR.MS_ROTATE_RATIO,
            missing_sections=cfg.AUGMENTOR.MISSING_SECTIONS,
            missp_iterations=cfg.AUGMENTOR.MISSP_ITERATIONS,
            missp_channel_pb=cfg.AUGMENTOR.MISSP_CHANNEL_PB,
            grayscale=cfg.AUGMENTOR.GRAYSCALE,
            channel_shuffle=cfg.AUGMENTOR.CHANNEL_SHUFFLE,
            gridmask=cfg.AUGMENTOR.GRIDMASK,
            grid_ratio=cfg.AUGMENTOR.GRID_RATIO,
            grid_d_range=cfg.AUGMENTOR.GRID_D_RANGE,
            grid_rotate=cfg.AUGMENTOR.GRID_ROTATE,
            grid_invert=cfg.AUGMENTOR.GRID_INVERT,
            gaussian_noise=cfg.AUGMENTOR.GAUSSIAN_NOISE,
            gaussian_noise_mean=cfg.AUGMENTOR.GAUSSIAN_NOISE_MEAN,
            gaussian_noise_var=cfg.AUGMENTOR.GAUSSIAN_NOISE_VAR,
            gaussian_noise_use_input_img_mean_and_var=cfg.AUGMENTOR.GAUSSIAN_NOISE_USE_INPUT_IMG_MEAN_AND_VAR,
            poisson_noise=cfg.AUGMENTOR.POISSON_NOISE,
            salt=cfg.AUGMENTOR.SALT,
            salt_amount=cfg.AUGMENTOR.SALT_AMOUNT,
            pepper=cfg.AUGMENTOR.PEPPER,
            pepper_amount=cfg.AUGMENTOR.PEPPER_AMOUNT,
            salt_and_pepper=cfg.AUGMENTOR.SALT_AND_PEPPER,
            salt_pep_amount=cfg.AUGMENTOR.SALT_AND_PEPPER_AMOUNT,
            salt_pep_proportion=cfg.AUGMENTOR.SALT_AND_PEPPER_PROP,
            shape=cfg.DATA.PATCH_SIZE,
            resolution=cfg.DATA.TRAIN.RESOLUTION,
            n_classes=cfg.DATA.N_CLASSES,
            ignore_index=cfg.LOSS.IGNORE_INDEX,
            random_crop_scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
        )

    if cfg.PROBLEM.NDIM == "3D":
        dic["zflip"] = cfg.AUGMENTOR.ZFLIP

    if cfg.PROBLEM.TYPE == "INSTANCE_SEG":
        dic["instance_problem"] = True
    elif cfg.PROBLEM.TYPE == "DENOISING":
        dic["n2v"] = True
        dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX
        dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR
        dic["n2v_neighborhood_radius"] = cfg.PROBLEM.DENOISING.N2V_NEIGHBORHOOD_RADIUS
        dic["n2v_structMask"] = (
            np.array([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]) if cfg.PROBLEM.DENOISING.N2V_STRUCTMASK else None
        )
        dic["n2v_load_gt"] = cfg.PROBLEM.DENOISING.LOAD_GT_DATA

    print("Initializing train data generator . . .")
    train_generator = f_name(**dic)  # type: ignore

    print("Initializing val data generator . . .")
    if single_input:
        val_generator = f_name(
            ndim=ndim,
            X=X_val,
            seed=cfg.SYSTEM.SEED,
            da=False,
            resize_shape=r_shape,
            norm_module=norm_module,
            preprocess_f=preprocess_data if cfg.DATA.PREPROCESS.VAL else None,
            preprocess_cfg=cfg.DATA.PREPROCESS if cfg.DATA.PREPROCESS.VAL else None,
        )
    else:
        val_dic: Dict[str, Any] = dict(
            ndim=ndim,
            X=X_val,
            Y=Y_val,
            da=False,
            shape=cfg.DATA.PATCH_SIZE,
            val=True,
            n_classes=cfg.DATA.N_CLASSES,
            ignore_index=cfg.LOSS.IGNORE_INDEX,
            seed=cfg.SYSTEM.SEED,
            norm_module=norm_module,
            resolution=cfg.DATA.VAL.RESOLUTION,
            random_crop_scale=cfg.PROBLEM.SUPER_RESOLUTION.UPSCALING,
            preprocess_f=preprocess_data if cfg.DATA.PREPROCESS.VAL else None,
            preprocess_cfg=cfg.DATA.PREPROCESS if cfg.DATA.PREPROCESS.VAL else None,
        )
        if cfg.PROBLEM.TYPE == "INSTANCE_SEG":
            val_dic["instance_problem"] = True
        elif cfg.PROBLEM.TYPE == "DENOISING":
            val_dic["n2v"] = True
            val_dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX
            val_dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR
            val_dic["n2v_neighborhood_radius"] = cfg.PROBLEM.DENOISING.N2V_NEIGHBORHOOD_RADIUS
        val_generator = f_name(**val_dic)  # type: ignore

    # Generate examples of data augmentation
    if cfg.AUGMENTOR.AUG_SAMPLES and cfg.AUGMENTOR.ENABLE:
        print("Creating generator samples . . .")
        train_generator.get_transformed_samples(
            cfg.AUGMENTOR.AUG_NUM_SAMPLES,
            save_to_dir=True,
            train=False,
            out_dir=cfg.PATHS.DA_SAMPLES,
            draw_grid=cfg.AUGMENTOR.DRAW_GRID,
        )

    # Training dataset
    total_batch_size = cfg.TRAIN.BATCH_SIZE * get_world_size() * cfg.TRAIN.ACCUM_ITER
    training_samples = len(train_generator)

    # ---- Choose num_workers for this DataLoader ----
    # Priority:
    #   1) If user explicitly set SYSTEM.NUM_WORKERS != -1 => respect it
    #   2) Else use the precomputed hint from startup
    if cfg.SYSTEM.NUM_WORKERS != -1:
        num_workers = max(0, int(cfg.SYSTEM.NUM_WORKERS))
    else:
        # Use the value computed earlier at startup
        num_workers = int(system_dict.get("num_workers_hint", 0))

    # Don't spawn more workers than samples (helps tiny datasets / edge cases)
    num_workers = min(num_workers, training_samples) if training_samples > 0 else 0

    # Choose the training sampler: distributed when running multi-process on GPU(s)
    if is_dist_avail_and_initialized() and cfg.SYSTEM.NUM_GPUS >= 1:
        sampler_train = DistributedSampler(
            train_generator,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=True,
        )
        DataLoader_shuffle = False  # IMPORTANT: shuffle must be False when sampler is used
    else:
        sampler_train = None
        DataLoader_shuffle = True

    num_training_steps_per_epoch = training_samples // total_batch_size

    print(f"Train/val generators with {num_workers} workers")
    print("Accumulate grad iterations: %d" % cfg.TRAIN.ACCUM_ITER)
    print("Effective batch size: %d" % total_batch_size)
    print("Sampler_train = %s" % str(sampler_train))

    train_dataset = DataLoader(
        train_generator,
        shuffle=DataLoader_shuffle,
        sampler=sampler_train,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=num_workers,
        pin_memory=(cfg.SYSTEM.PIN_MEM and cfg.SYSTEM.NUM_GPUS > 0),
        drop_last=False,
        # Module-level function: ensures DataLoader workers don't each spawn many threads
        worker_init_fn=_train_val_worker_init_fn,
        persistent_workers=(num_workers > 0),
        prefetch_factor=2 if num_workers > 0 else None,
    )

    # Save a sample to export the model to BMZ
    bmz_input_sample, mask_sample = train_generator.load_sample(0, first_load=True)
    bmz_input_sample, cover_raw, cover_gt = extract_BMZ_sample_and_cover(
        img=bmz_input_sample,
        img_gt=mask_sample if not isinstance(mask_sample, int) else None,
        patch_size=cfg.DATA.PATCH_SIZE,
        is_3d=cfg.PROBLEM.NDIM == "3D",
        input_axis_order=cfg.DATA.TRAIN.INPUT_IMG_AXES_ORDER,
    )
    bmz_input_sample = bmz_input_sample.astype(np.float32)

    # Ensure dimensions
    if cfg.PROBLEM.NDIM == "2D":
        if bmz_input_sample.ndim == 3:
            bmz_input_sample = np.expand_dims(bmz_input_sample, 0)
        bmz_input_sample = bmz_input_sample.transpose(0, 3, 1, 2)  # Numpy -> Torch
    else:  # 3D
        if bmz_input_sample.ndim == 4:
            bmz_input_sample = np.expand_dims(bmz_input_sample, 0)
        bmz_input_sample = bmz_input_sample.transpose(0, 4, 1, 2, 3)  # Numpy -> Torch

    # Validation dataset
    sampler_val = None
    if cfg.DATA.VAL.DIST_EVAL:
        if len(val_generator) % get_world_size() != 0:
            print(
                "Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. "
                "This will slightly alter validation results as extra duplicate entries are added to achieve "
                "equal num of samples per-process."
            )
        sampler_val = DistributedSampler(
            val_generator,
            num_replicas=get_world_size(),
            rank=get_rank(),
            shuffle=False,
        )
    else:
        sampler_val = SequentialSampler(val_generator)

    # Don't spawn more workers than validation samples
    val_samples = len(val_generator)
    num_workers_val = min(num_workers, val_samples) if val_samples > 0 else 0

    val_dataset = DataLoader(
        val_generator,
        sampler=sampler_val,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=num_workers_val,
        pin_memory=(cfg.SYSTEM.PIN_MEM and cfg.SYSTEM.NUM_GPUS > 0),
        drop_last=False,
        shuffle=False,
        worker_init_fn=_train_val_worker_init_fn,
        persistent_workers=(num_workers_val > 0),
        prefetch_factor=2 if num_workers_val > 0 else None,
    )

    return train_dataset, val_dataset, num_training_steps_per_epoch, bmz_input_sample, cover_raw, cover_gt
[docs]
def create_test_generator(
    cfg: CN,
    X_test: Any,
    Y_test: Any,
    norm_module: Dict,
) -> Tuple[test_pair_data_generator | test_single_data_generator, NDArray, NDArray, NDArray]:
    """
    Build the test data generator together with the samples needed for a BMZ export.

    Parameters
    ----------
    cfg : Config
        BiaPy configuration.

    X_test : Any
        Test data handed to the generator as ``X``.

    Y_test : Any
        Test mask/class data. Only forwarded (as ``Y``) when the pair generator is used, i.e. in every
        workflow except classification and self-supervised masking.

    norm_module : Dict
        Normalization module that defines the normalization steps to apply.

    Returns
    -------
    test_generator : test_pair_data_generator/test_single_data_generator
        Test data generator.

    bmz_input_sample : Numpy array
        Sample of the input data to be used for exporting the model to BMZ. A leading batch axis is
        added if missing and the last axis is then moved to position 1 (Numpy -> Torch layout).

    cover_raw : Numpy array
        Sample of the raw cover data to be used for exporting the model to BMZ, as returned by
        ``extract_BMZ_sample_and_cover``.

    cover_gt : Numpy array
        Sample of the GT cover data to be used for exporting the model to BMZ, as returned by
        ``extract_BMZ_sample_and_cover``.
    """
    problem_type = cfg.PROBLEM.TYPE
    masking_pretext = problem_type == "SELF_SUPERVISED" and cfg.PROBLEM.SELF_SUPERVISED.PRETEXT_TASK == "masking"

    # Ground truth is never provided for the masking pretext task.
    if masking_pretext:
        provide_Y = False
    else:
        provide_Y = cfg.DATA.TEST.LOAD_GT or cfg.DATA.TEST.USE_VAL_AS_TEST

    ndim: int = 3 if cfg.PROBLEM.NDIM == "3D" else 2
    run_preprocess = cfg.DATA.PREPROCESS.TEST

    dic = {
        "X": X_test,
        "provide_Y": provide_Y,
        "ndim": ndim,
        "seed": cfg.SYSTEM.SEED,
        "norm_module": norm_module,
        "convert_to_rgb": cfg.DATA.FORCE_RGB,
        "filter_props": cfg.DATA.TEST.FILTER_SAMPLES.PROPS,
        "filter_vals": cfg.DATA.TEST.FILTER_SAMPLES.VALUES,
        "filter_signs": cfg.DATA.TEST.FILTER_SAMPLES.SIGNS,
        "preprocess_data": preprocess_data if run_preprocess else None,
        "preprocess_cfg": cfg.DATA.PREPROCESS if run_preprocess else None,
        "reflect_to_complete_shape": cfg.DATA.REFLECT_TO_COMPLETE_SHAPE,
        "data_shape": cfg.DATA.PATCH_SIZE,
    }

    use_single_generator = problem_type == "CLASSIFICATION" or masking_pretext
    if use_single_generator:
        gen_name = test_single_data_generator
        r_shape = cfg.DATA.PATCH_SIZE
        if cfg.MODEL.ARCHITECTURE == "efficientnet_b0" and cfg.DATA.PATCH_SIZE[:-1] != (224, 224):
            r_shape = (224, 224) + (cfg.DATA.PATCH_SIZE[-1],)
            print("Changing patch size from {} to {} to use efficientnet_b0".format(cfg.DATA.PATCH_SIZE[:-1], r_shape))
        if problem_type == "CLASSIFICATION":
            dic["crop_center"] = True
            dic["data_shape"] = r_shape
    else:
        gen_name = test_pair_data_generator
        dic["Y"] = Y_test
        dic["test_by_chunks"] = cfg.TEST.BY_CHUNKS.ENABLE
        dic["instance_problem"] = problem_type == "INSTANCE_SEG"
        dic["ignore_index"] = cfg.LOSS.IGNORE_INDEX
        dic["n_classes"] = cfg.DATA.N_CLASSES

    test_generator = gen_name(**dic)

    # Load the first sample; it is the one used to export the model to BMZ.
    if use_single_generator:
        first_img, _, _, _, _ = test_generator.load_sample(0, first_load=True)  # type: ignore
        first_mask = None
    else:
        first_img, first_mask, _, _, _, _ = test_generator.load_sample(0, first_load=True)  # type: ignore

    def _first_if_array(arr):
        # Numpy samples are indexed with [0] unless testing by chunks; anything else is passed through.
        if isinstance(arr, np.ndarray) and not cfg.TEST.BY_CHUNKS.ENABLE:
            return arr[0]
        return arr

    bmz_input_sample, cover_raw, cover_gt = extract_BMZ_sample_and_cover(
        img=_first_if_array(first_img),
        img_gt=_first_if_array(first_mask),
        patch_size=cfg.DATA.PATCH_SIZE,
        is_3d=cfg.PROBLEM.NDIM == "3D",
        input_axis_order=cfg.DATA.TEST.INPUT_IMG_AXES_ORDER,
    )

    # Ensure a batch axis exists and move the last axis to position 1 (Numpy -> Torch)
    if cfg.PROBLEM.NDIM == "2D":
        if bmz_input_sample.ndim == 3:
            bmz_input_sample = np.expand_dims(bmz_input_sample, 0)
        bmz_input_sample = bmz_input_sample.transpose(0, 3, 1, 2)
    else:  # 3D
        if bmz_input_sample.ndim == 4:
            bmz_input_sample = np.expand_dims(bmz_input_sample, 0)
        bmz_input_sample = bmz_input_sample.transpose(0, 4, 1, 2, 3)

    return test_generator, bmz_input_sample, cover_raw, cover_gt
[docs]
def by_chunks_collate_fn(data):
    """
    Collate function to avoid the default one with type checking. It does nothing special but stack the images.

    Each sample in ``data`` is indexed at positions 0 to 5. Positions 1 and 2 are stacked with
    ``np.stack``; the remaining positions are gathered into plain lists.

    Parameters
    ----------
    data : sequence of tuples
        Samples of the batch, each one with at least six elements.

    Returns
    -------
    data : tuple
        Six-element tuple: list of elements 0, stacked elements 1, stacked elements 2 (or ``None``
        when the first sample has ``None`` at position 2), and lists of elements 3, 4 and 5.

    Raises
    ------
    ValueError
        Raised by ``np.stack`` if ``data`` is empty or the stacked elements have mismatching shapes.
    """
    first_items = [sample[0] for sample in data]
    # np.stack raises ValueError on an empty batch, so from here on ``data`` is non-empty.
    stacked_second = np.stack([sample[1] for sample in data])

    # The first sample decides whether the batch carries the third element at all.
    if data[0][2] is not None:
        stacked_third = np.stack([sample[2] for sample in data])
    else:
        stacked_third = None

    return (
        first_items,
        stacked_second,
        stacked_third,
        [sample[3] for sample in data],
        [sample[4] for sample in data],
        [sample[5] for sample in data],
    )
[docs]
def _chunked_test_worker_init_fn(worker_id):
    """
    Restrict a DataLoader worker process to a single torch thread.

    Kept at module level (instead of being nested inside
    ``create_chunked_test_generator``) so that it can be pickled when the worker
    processes are started with the ``spawn`` or ``forkserver`` start methods.

    Parameters
    ----------
    worker_id : int
        Worker index supplied by the DataLoader. Not used.
    """
    torch.set_num_threads(1)


def create_chunked_test_generator(
    cfg: CN,
    system_dict: Dict[str, Any],
    current_sample: Dict,
    norm_module: Dict,
    out_dir: str,
    dtype_str: str,
) -> DataLoader:
    """
    Create a DataLoader for chunked test data using chunked_test_pair_data_generator.

    The generator is configured with the test image/mask axes order, patch size, padding,
    normalization module and Z range taken from ``cfg``, and is then wrapped in a PyTorch
    DataLoader that collates batches with ``by_chunks_collate_fn``.

    Parameters
    ----------
    cfg : CN
        BiaPy configuration node.

    system_dict : dict
        System dictionary containing:
            * 'cpu_budget': int, Total CPU budget.
            * 'cpu_per_rank': int, CPU budget per rank.
            * 'main_threads': int, Number of main threads.
            * 'num_workers_hint': int, Hint for the number of workers.

        Only ``'num_workers_hint'`` is read here, and only when
        ``cfg.SYSTEM.NUM_WORKERS`` is ``-1``.

    current_sample : dict
        Dictionary containing the sample to process (e.g., file pointers, data arrays).

    norm_module : Dict
        Normalization module to apply to the data.

    out_dir : str
        Output directory to save results.

    dtype_str : str
        Data type string for output files.

    Returns
    -------
    test_dataset : DataLoader
        PyTorch DataLoader wrapping the chunked test data generator.
    """
    chunked_generator = chunked_test_pair_data_generator(
        sample_to_process=current_sample,
        norm_module=norm_module,
        input_axes=cfg.DATA.TEST.INPUT_IMG_AXES_ORDER,
        mask_input_axes=cfg.DATA.TEST.INPUT_MASK_AXES_ORDER,
        crop_shape=cfg.DATA.PATCH_SIZE,
        padding=cfg.DATA.TEST.PADDING,
        out_dir=out_dir,
        dtype_str=dtype_str,
        n_classes=cfg.DATA.N_CLASSES,
        ignore_index=cfg.LOSS.IGNORE_INDEX,
        instance_problem=cfg.PROBLEM.TYPE == "INSTANCE_SEG",
        z_start=cfg.TEST.BY_CHUNKS.Z_START,
        z_end=cfg.TEST.BY_CHUNKS.Z_END,
    )

    # ---- Choose num_workers for this DataLoader ----
    # Priority:
    #   1) Respect explicit SYSTEM.NUM_WORKERS if set
    #   2) Else reuse the precomputed hint from startup (system_dict["num_workers_hint"])
    if cfg.SYSTEM.NUM_WORKERS != -1:
        num_workers = max(0, int(cfg.SYSTEM.NUM_WORKERS))
    else:
        num_workers = int(system_dict.get("num_workers_hint", 0))

    # Cap by dataset length if the generator supports __len__
    try:
        n_chunks = len(chunked_generator)  # may raise TypeError if __len__ not implemented
    except TypeError:
        # length unknown -> keep computed num_workers
        pass
    else:
        num_workers = min(num_workers, n_chunks) if n_chunks > 0 else 0

    if is_main_process():
        print(f"Chunked test generator with {num_workers} workers")

    test_dataset = DataLoader(
        chunked_generator,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=num_workers,
        collate_fn=by_chunks_collate_fn,
        pin_memory=(cfg.SYSTEM.PIN_MEM and cfg.SYSTEM.NUM_GPUS > 0),
        drop_last=False,
        # Module-level function: ensures DataLoader workers don't each spawn many threads
        worker_init_fn=_chunked_test_worker_init_fn,
        persistent_workers=(num_workers > 0),
        prefetch_factor=2 if num_workers > 0 else None,
    )
    return test_dataset
[docs]
def by_chunks_workflow_collate_fn(data):
    """
    Collate function that bypasses the default one (and its type checking).

    No stacking is performed: the first three elements of every sample are simply
    regrouped into three parallel lists.

    Parameters
    ----------
    data : sequence of tuples
        Samples of the batch, each one with at least three elements.

    Returns
    -------
    data : tuple
        Three lists holding, respectively, elements 0, 1 and 2 of every sample.
    """
    firsts, seconds, thirds = [], [], []
    for sample in data:
        firsts.append(sample[0])
        seconds.append(sample[1])
        thirds.append(sample[2])
    return (firsts, seconds, thirds)
[docs]
def _chunked_workflow_worker_init_fn(worker_id):
    """
    Restrict a DataLoader worker process to a single torch thread.

    Kept at module level (instead of being nested inside
    ``create_chunked_workflow_process_generator``) so that it can be pickled when
    the worker processes are started with the ``spawn`` or ``forkserver`` start methods.

    Parameters
    ----------
    worker_id : int
        Worker index supplied by the DataLoader. Not used.
    """
    torch.set_num_threads(1)


def create_chunked_workflow_process_generator(
    cfg: CN,
    system_dict: Dict[str, Any],
    model_predictions: str,
    out_dir: str,
    dtype_str: str,
) -> DataLoader:
    """
    Create a DataLoader over model predictions using chunked_workflow_process_generator.

    The generator is configured with the test image axes order (with a ``"C"`` axis appended
    when it is missing), the patch size and the Z range taken from ``cfg``, and is then wrapped
    in a PyTorch DataLoader that collates batches with ``by_chunks_workflow_collate_fn``.

    Parameters
    ----------
    cfg : CN
        BiaPy configuration node.

    system_dict : dict
        System dictionary containing:
            * 'cpu_budget': int, Total CPU budget.
            * 'cpu_per_rank': int, CPU budget per rank.
            * 'main_threads': int, Number of main threads.
            * 'num_workers_hint': int, Hint for the number of workers.

        Only ``'num_workers_hint'`` is read here, and only when
        ``cfg.SYSTEM.NUM_WORKERS`` is ``-1``.

    model_predictions : str
        Path to the model predictions to process.

    out_dir : str
        Output directory to save results.

    dtype_str : str
        Data type string for output files.

    Returns
    -------
    test_dataset : DataLoader
        PyTorch DataLoader wrapping the chunked workflow process generator.
    """
    # The generator is always given an axes order that contains a channel axis.
    if "C" not in cfg.DATA.TEST.INPUT_IMG_AXES_ORDER:
        out_data_order = cfg.DATA.TEST.INPUT_IMG_AXES_ORDER + "C"
    else:
        out_data_order = cfg.DATA.TEST.INPUT_IMG_AXES_ORDER

    chunked_generator = chunked_workflow_process_generator(
        model_predictions=model_predictions,
        input_axes=out_data_order,
        crop_shape=cfg.DATA.PATCH_SIZE,
        out_dir=out_dir,
        dtype_str=dtype_str,
        z_start=cfg.TEST.BY_CHUNKS.Z_START,
        z_end=cfg.TEST.BY_CHUNKS.Z_END,
    )

    # ---- Choose num_workers for this DataLoader ----
    # Priority:
    #   1) Respect explicit SYSTEM.NUM_WORKERS if set
    #   2) Else reuse the precomputed hint from startup (system_dict["num_workers_hint"])
    if cfg.SYSTEM.NUM_WORKERS != -1:
        num_workers = max(0, int(cfg.SYSTEM.NUM_WORKERS))
    else:
        num_workers = int(system_dict.get("num_workers_hint", 0))

    # Cap by dataset length if the generator supports __len__
    try:
        n_chunks = len(chunked_generator)  # may raise TypeError if __len__ not implemented
    except TypeError:
        # length unknown -> keep computed num_workers
        pass
    else:
        num_workers = min(num_workers, n_chunks) if n_chunks > 0 else 0

    # Only the main process reports, matching create_chunked_test_generator.
    if is_main_process():
        print(f"Chunked test generator with {num_workers} workers")

    test_dataset = DataLoader(
        chunked_generator,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=num_workers,
        collate_fn=by_chunks_workflow_collate_fn,
        pin_memory=(cfg.SYSTEM.PIN_MEM and cfg.SYSTEM.NUM_GPUS > 0),
        drop_last=False,
        # Module-level function: ensures DataLoader workers don't each spawn many threads
        worker_init_fn=_chunked_workflow_worker_init_fn,
        persistent_workers=(num_workers > 0),
        prefetch_factor=2 if num_workers > 0 else None,
    )
    return test_dataset
[docs]
def check_generator_consistence(
    gen: DataLoader, data_out_dir: str, mask_out_dir: str, filenames: List[str] | None = None
):
    """
    Dump every sample produced by a generator to disk as TIF files.

    Exactly ``len(gen)`` batches are pulled from ``gen``. Each batch is unpacked as an ``(X, Y)``
    pair and every item of it is saved twice under the same file name: the data in
    ``data_out_dir`` and the mask in ``mask_out_dir``.

    Parameters
    ----------
    gen : DataLoader
        Generator to extract the data from.

    data_out_dir : str
        Path to store the generator data samples.

    mask_out_dir : str
        Path to store the generator data mask samples.

    filenames : List, optional
        Filenames that should be used when saving each image. When not given (or empty) the
        samples are named ``sample_<index>.tif``, with a running index across all batches.
    """
    print("Check generator . . .")

    batches = iter(gen)
    saved = 0  # running index over all saved samples
    for _ in tqdm(range(len(gen))):
        X_test, Y_test = next(batches)
        for k in range(len(X_test)):
            if filenames:
                fil = [filenames[saved]]
            else:
                fil = [f"sample_{saved}.tif"]
            # A leading axis is added so that each item is saved as a one-element stack.
            save_tif(np.expand_dims(X_test[k], 0), data_out_dir, fil, verbose=False)
            save_tif(np.expand_dims(Y_test[k], 0), mask_out_dir, fil, verbose=False)
            saved += 1