"""
Base data generator for paired image and mask data in BiaPy.
This module provides the PairBaseDataGenerator class, which supports flexible
data loading, augmentation, and normalization for deep learning workflows.
It includes a wide range of augmentation options for both 2D and 3D data,
and is designed to work with BiaPyDataset objects and normalization modules.
"""
from typing import (
Tuple,
Literal,
Dict,
)
import numpy as np
import random
import torch
import os
import h5py
from tqdm import tqdm
from skimage.util import random_noise
from abc import ABCMeta, abstractmethod
from torch.utils.data import Dataset
from numpy.typing import NDArray
from biapy.data.generators.augmentors import *
from biapy.utils.misc import is_main_process, os_walk_clean
from biapy.data.data_manipulation import pad_and_reflect, load_img_data, extract_patch_within_image
from biapy.data.data_3D_manipulation import extract_patch_from_efficient_file
from biapy.data.dataset import BiaPyDataset
from biapy.data.norm import normalize_image, normalize_mask, update_mask_norm_info
[docs]
class PairBaseDataGenerator(Dataset, metaclass=ABCMeta):
"""
Custom BaseDataGenerator to transform paired image and mask data.
Parameters
----------
ndim : int
Dimensions of the data (``2`` for ``2D`` and ``3`` for 3D).
X : BiaPyDataset
X dataset.
Y : BiaPyDataset
Y dataset.
norm_module : Dict
Normalization module that defines the normalization steps to apply.
seed : int, optional
Seed for random functions.
da : bool, optional
To activate the data augmentation.
da_prob : float, optional
Probability of doing each transformation.
rotation90 : bool, optional
To make square (90, 180,270) degree rotations.
rand_rot : bool, optional
To make random degree range rotations.
rnd_rot_range : tuple of float, optional
Range of random rotations. E. g. ``(-180, 180)``.
shear : bool, optional
To make shear transformations.
shear_range : tuple of int, optional
Degree range to make shear. E. g. ``(-20, 20)``.
zoom : bool, optional
To make zoom on images.
zoom_range : tuple of floats, optional
Zoom range to apply. E. g. ``(0.8, 1.2)``.
zoom_in_z: bool, optional
Whether to apply or not zoom in Z axis.
shift : float, optional
To make shifts.
shift_range : tuple of float, optional
Range to make a shift. E. g. ``(0.1, 0.2)``.
affine_mode: str, optional
Method to use when filling in newly created pixels. Same meaning as in `skimage` (and `numpy.pad()`).
E.g. ``constant``, ``reflect`` etc.
vflip : bool, optional
To activate vertical flips.
hflip : bool, optional
To activate horizontal flips.
elastic : bool, optional
To make elastic deformations.
e_alpha : tuple of ints, optional
Strength of the distortion field. E. g. ``(240, 250)``.
e_sigma : int, optional
Standard deviation of the gaussian kernel used to smooth the distortion fields.
e_mode : str, optional
Parameter that defines the handling of newly created pixels with the elastic transformation.
g_blur : bool, optional
To insert gaussian blur on the images.
g_sigma : tuple of floats, optional
Standard deviation of the gaussian kernel. E. g. ``(1.0, 2.0)``.
median_blur : bool, optional
To blur an image by computing median values over neighbourhoods.
mb_kernel : tuple of ints, optional
Median blur kernel size. E. g. ``(3, 7)``.
motion_blur : bool, optional
Blur images in a way that fakes camera or object movements.
motb_k_range : int, optional
Kernel size to use in motion blur.
gamma_contrast : bool, optional
To insert gamma constrast changes on images.
gc_gamma : tuple of floats, optional
Exponent for the contrast adjustment. Higher values darken the image. E. g. ``(1.25, 1.75)``.
brightness : bool, optional
To aply brightness to the images as `PyTorch Connectomics
<https://github.com/zudi-lin/pytorch_connectomics/blob/master/connectomics/data/augmentation/grayscale.py>`_.
brightness_factor : tuple of 2 floats, optional
Strength of the brightness range, with valid values being ``0 <= brightness_factor <= 1``. E.g. ``(0.1, 0.3)``.
contrast : boolen, optional
To apply contrast changes to the images as `PyTorch Connectomics
<https://github.com/zudi-lin/pytorch_connectomics/blob/master/connectomics/data/augmentation/grayscale.py>`_.
contrast_factor : tuple of 2 floats, optional
Strength of the contrast change range, with valid values being ``0 <= contrast_factor <= 1``.
E.g. ``(0.1, 0.3)``.
dropout : bool, optional
To set a certain fraction of pixels in images to zero.
drop_range : tuple of floats, optional
Range to take a probability ``p`` to drop pixels. E.g. ``(0, 0.2)`` will take a ``p`` folowing ``0<=p<=0.2``
and then drop ``p`` percent of all pixels in the image (i.e. convert them to black pixels).
cutout : bool, optional
To fill one or more rectangular areas in an image using a fill mode.
cout_nb_iterations : tuple of ints, optional
Range of number of areas to fill the image with. E. g. ``(1, 3)``.
cout_size : tuple of floats, optional
Range to select the size of the areas in % of the corresponding image size. Values between ``0`` and ``1``.
E. g. ``(0.2, 0.4)``.
cout_cval : int, optional
Value to fill the area of cutout with.
cout_apply_to_mask : boolen, optional
Whether to apply cutout to the mask.
cutblur : boolean, optional
Blur a rectangular area of the image by downsampling and upsampling it again.
cblur_size : tuple of floats, optional
Range to select the size of the area to apply cutblur on. E. g. ``(0.2, 0.4)``.
cblur_inside : boolean, optional
If ``True`` only the region inside will be modified (cut LR into HR image). If ``False`` the ``50%`` of the
times the region inside will be modified (cut LR into HR image) and the other ``50%`` the inverse will be
done (cut HR into LR image). See Figure 1 of the official `paper <https://arxiv.org/pdf/2004.00448.pdf>`__.
cutmix : boolean, optional
Combine two images pasting a region of one image to another.
cmix_size : tuple of floats, optional
Range to select the size of the area to paste one image into another. E. g. ``(0.2, 0.4)``.
cnoise : boolean, optional
Randomly add noise to a cuboid region in the image.
cnoise_scale : tuple of floats, optional
Range to choose a value that will represent the % of the maximum value of the image that will be used as the
std of the Gaussian Noise distribution. E.g. ``(0.1, 0.2)``.
cnoise_nb_iterations : tuple of ints, optional
Number of areas with noise to create. E.g. ``(1, 3)``.
cnoise_size : tuple of floats, optional
Range to choose the size of the areas to transform. E.g. ``(0.2, 0.4)``.
misalignment : boolean, optional
To add miss-aligment augmentation.
ms_displacement : int, optional
Maximum pixel displacement in `xy`-plane for misalignment.
ms_rotate_ratio : float, optional
Ratio of rotation-based mis-alignment
missing_sections : boolean, optional
Augment the image by creating a black line in a random position.
missp_iterations : tuple of 2 ints, optional
Iterations to dilate the missing line with. E.g. ``(30, 40)``.
missp_channel_pb : float, optional
Probability of applying missing section to each channel. E.g. ``0.5``.
grayscale : bool, optional
Whether to augment images converting partially in grayscale.
gridmask : bool, optional
Whether to apply gridmask to the image. See the official `paper <https://arxiv.org/abs/2001.04086v1>`__ for
more information about it and its parameters.
grid_ratio : float, optional
Determines the keep ratio of an input image (``r`` in the original paper).
grid_d_range : tuple of floats, optional
Range to choose a ``d`` value. It represents the % of the image size. E.g. ``(0.4,1)``.
grid_rotate : float, optional
Rotation of the mask in GridMask. Needs to be between ``[0,1]`` where 1 is 360 degrees.
grid_invert : bool, optional
Whether to invert the mask of GridMask.
channel_shuffle : bool, optional
Whether to shuflle the channels of the images.
gaussian_noise : bool, optional
To apply Gaussian noise to the images.
gaussian_noise_mean : tuple of ints, optional
Mean of the Gaussian noise.
gaussian_noise_var : tuple of ints, optional
Variance of the Gaussian noise.
gaussian_noise_use_input_img_mean_and_var : bool, optional
Whether to use the mean and variance of the input image instead of ``gaussian_noise_mean``
and ``gaussian_noise_var``.
poisson_noise : bool, optional
To apply Poisson noise to the images.
salt : tuple of ints, optional
Mean of the gaussian noise.
salt_amount : tuple of ints, optional
Variance of the gaussian noise.
pepper : bool, optional
To apply poisson noise to the images.
pepper_amount : tuple of ints, optional
Mean of the gaussian noise.
salt_and_pepper : bool, optional
To apply poisson noise to the images.
salt_pep_amount : tuple of ints, optional
Mean of the gaussian noise.
salt_pep_proportion : bool, optional
To apply poisson noise to the images.
resolution : 2D tuple of floats, optional
Resolution of the given data ``(y,x)``. E.g. ``(8,8)``.
val : bool, optional
Advise the generator that the images will be to validate the model to not make random crops (as the val.
data must be the same on each epoch).
n_classes : int, optional
Number of classes.
ignore_index : int, optional
Value to ignore in the loss/metrics.
n2v : bool, optional
Whether to create `Noise2Void <https://openaccess.thecvf.com/content_CVPR_2019/papers/Krull_Noise2Void_-_Learning_Denoising_From_Single_Noisy_Images_CVPR_2019_paper.pdf>`__
mask. Used in DENOISING problem type.
n2v_perc_pix : float, optional
Input image pixels to be manipulated.
n2v_manipulator : str, optional
How to manipulate the input pixels. Most pixel manipulators will compute the replacement value based on a neighborhood.
Possible options: `normal_withoutCP`: samples the neighborhood according to a normal gaussian distribution, but without
the center pixel; `normal_additive`: adds a random number to the original pixel value. The random number is sampled from
a gaussian distribution with zero-mean and sigma = `n2v_neighborhood_radius` ; `normal_fitted`: uses a random value from
a gaussian normal distribution with mean equal to the mean of the neighborhood and standard deviation equal to the
standard deviation of the neighborhood ; `identity`: performs no pixel manipulation.
n2v_neighborhood_radius : int, optional
Neighborhood size to use when manipulating the values.
n2v_structMask : Array of ints, optional
Masking kernel for StructN2V to hide pixels adjacent to main blind spot. Value 1 = 'hidden', Value 0 = 'non hidden'.
Nested lists equivalent to ndarray. Must have odd length in each dimension (center pixel is blind spot). ``None``
implies normal N2V masking.
instance_problem : bool, optional
Advice the class that the workflow is of instance segmentation to divide the labels by channels.
random_crop_scale : tuple of ints, optional
Scale factor the mask used in super-resolution workflow. E.g. ``(2,2)``.
convert_to_rgb : bool, optional
In case RGB images are expected, e.g. if ``crop_shape`` channel is 3, those images that are grayscale are
converted into RGB.
preprocess_f : function, optional
The preprocessing function, is necessary in case you want to apply any preprocessing.
preprocess_cfg : dict, optional
Configuration parameters for preprocessing, is necessary in case you want to apply any preprocessing.
"""
def __init__(
self,
ndim: int,
X: BiaPyDataset,
Y: BiaPyDataset,
norm_module: Dict,
seed: int = 0,
da: bool = True,
da_prob: float = 0.5,
rotation90: bool = False,
rand_rot: bool = False,
rnd_rot_range: Tuple[int, int] = (-180, 180),
shear: bool = False,
shear_range: Tuple[int, int] = (-20, 20),
zoom: bool = False,
zoom_range: Tuple[float, float] = (0.8, 1.2),
zoom_in_z: bool = False,
shift: bool = False,
shift_range: Tuple[float, float] = (0.1, 0.2),
affine_mode: Literal["constant", "edge", "symmetric", "reflect", "wrap"] = "constant",
vflip: bool = False,
hflip: bool = False,
elastic: bool = False,
e_alpha: Tuple[int, int] = (240, 250),
e_sigma: int = 25,
e_mode: Literal["constant", "edge", "symmetric", "reflect", "wrap"] = "constant",
g_blur: bool = False,
g_sigma: Tuple[float, float] = (1.0, 2.0),
median_blur: bool = False,
mb_kernel: Tuple[int, int] = (3, 7),
motion_blur: bool = False,
motb_k_range: Tuple[int, int] = (3, 8),
gamma_contrast: bool = False,
gc_gamma: Tuple[float, float] = (1.25, 1.75),
brightness: bool = False,
brightness_factor: Tuple[int, int] = (1, 3),
brightness_mode: str = "2D",
contrast: bool = False,
contrast_factor: Tuple[int, int] = (1, 3),
contrast_mode: str = "2D",
dropout: bool = False,
drop_range: Tuple[float, float] = (0.0, 0.2),
cutout: bool = False,
cout_nb_iterations: Tuple[int, int] = (1, 3),
cout_size: Tuple[float, float] = (0.2, 0.4),
cout_cval: int = 0,
cout_apply_to_mask: bool = False,
cutblur: bool = False,
cblur_size: Tuple[float, float] = (0.1, 0.5),
cblur_down_range: Tuple[int, int] = (2, 8),
cblur_inside: bool = True,
cutmix: bool = False,
cmix_size: Tuple[float, float] = (0.2, 0.4),
cutnoise: bool = False,
cnoise_scale: Tuple[float, float] = (0.1, 0.2),
cnoise_nb_iterations: Tuple[int, int] = (1, 3),
cnoise_size: Tuple[float, float] = (0.2, 0.4),
misalignment: bool = False,
ms_displacement: int = 16,
ms_rotate_ratio: float = 0.0,
missing_sections: bool = False,
missp_iterations: Tuple[int, int] = (30, 40),
missp_channel_pb: float = 0.5,
grayscale: bool = False,
channel_shuffle: bool = False,
gridmask: bool = False,
grid_ratio: float = 0.6,
grid_d_range: Tuple[float, float] = (0.4, 1.0),
grid_rotate: int = 1,
grid_invert: bool = False,
gaussian_noise: bool = False,
gaussian_noise_mean: int = 0,
gaussian_noise_var: float = 0.01,
gaussian_noise_use_input_img_mean_and_var: bool = False,
poisson_noise: bool = False,
salt: bool = False,
salt_amount: float = 0.05,
pepper: bool = False,
pepper_amount: float = 0.05,
salt_and_pepper: bool = False,
salt_pep_amount: float = 0.05,
salt_pep_proportion: float = 0.5,
shape: Tuple[int, int, int] = (256, 256, 1),
resolution: Tuple[int, ...] = (-1,),
val: bool = False,
n_classes: int = 1,
ignore_index: Optional[int] = None,
n2v: bool = False,
n2v_perc_pix: float = 0.198,
n2v_manipulator="uniform_withCP",
n2v_neighborhood_radius: int = 5,
n2v_structMask=np.array([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]),
n2v_load_gt: bool = False,
instance_problem: bool = False,
random_crop_scale: Tuple[int, ...] = (1, 1),
convert_to_rgb: bool = False,
preprocess_f=None,
preprocess_cfg=None,
):
"""
Initialize the PairBaseDataGenerator.
Sets up data sources, normalization, augmentation options, and preprocessing
for paired image and mask data.
Parameters
----------
See class docstring for full parameter list.
"""
if preprocess_f and preprocess_cfg is None:
raise ValueError("'preprocess_cfg' needs to be provided with 'preprocess_f'")
self.ndim = ndim
self.z_size = -1
self.val = val
self.convert_to_rgb = convert_to_rgb
self.norm_module = norm_module.copy()
self.preprocess_f = preprocess_f
self.preprocess_cfg = preprocess_cfg
self.random_crop_func = random_3D_crop_pair if ndim == 3 else random_crop_pair
# Super-resolution options
self.random_crop_scale = random_crop_scale
sshape = X.sample_list[0].get_shape()
if sshape and len(sshape) != ndim:
raise ValueError(
"Samples in X must be have {} dimensions. Provided: {}".format(ndim, X.sample_list[0].get_shape())
)
sshape = Y.sample_list[0].get_shape()
if sshape and len(sshape) != ndim:
raise ValueError(
"Samples in Y must be have {} dimensions. Provided: {}".format(ndim, Y.sample_list[0].get_shape())
)
if rotation90 and rand_rot:
print("Warning: you selected double rotation type. Maybe you should set only 'rand_rot'?")
self.X = X
self.Y = Y
self.length = len(self.X.sample_list)
self.real_length = self.length
self.no_bin_channel_found = False
self.shape = shape
# X data analysis
img, _ = self.load_sample(0, first_load=True)
if norm_module["type"] in ["div", "scale_range"]:
if shape[-1] != img.shape[-1]:
raise ValueError(
"Channel of the patch size given {} does not correspond with the loaded image {}. "
"Please, check the channels of the images!".format(shape[-1], img.shape[-1])
)
self.Y_channels = img.shape[-1]
xnorm_info = self.X.dataset_info[self.X.sample_list[0].fid].norm_info
if xnorm_info is None:
xnorm_info = self.norm_module
img, xnorm_example = normalize_image(img, norm_module=xnorm_info)
del img
# Y data analysis
# Loop over a few masks to ensure foreground class is present to decide normalization
self.mask_norm = None
if self.norm_module["mask_norm"] == "as_mask":
print("Checking which channel of the mask needs normalization . . .")
n_samples = min(50, len(self.X.sample_list))
if instance_problem:
n_samples = 1
for i in tqdm(range(n_samples), total=n_samples):
_, mask = self.load_sample(i, first_load=True)
# Store which channels are binary or not (e.g. distance transform channel is not binary)
mask, new_mask_norm = normalize_mask(
mask,
norm_module=self.norm_module,
n_classes=n_classes,
ignore_index=ignore_index,
instance_problem=instance_problem,
apply_norm=False
)
if self.mask_norm is None:
self.mask_norm = new_mask_norm
else:
self.mask_norm = update_mask_norm_info(self.mask_norm, new_mask_norm)
# Check if any channel is not binary to set no_bin_channel_found to True
self.no_bin_channel_found = any(
self.mask_norm["per_channel_info"][j]["type"] == "no_bin"
for j in range(len(self.mask_norm["per_channel_info"]))
)
else:
self.mask_norm = self.norm_module
_, mask = self.load_sample(0)
self.Y_channels = mask.shape[-1]
self.Y_dtype = mask.dtype
del mask
print("Normalization config used for X (first sample): {}".format(xnorm_info))
print("Normalization config used for Y: {}".format(self.mask_norm))
if self.ndim == 2:
resolution = tuple(resolution[i] for i in [1, 0]) # y, x -> x, y
self.res_relation = (1.0, resolution[0] / resolution[1])
else:
resolution = tuple(resolution[i] for i in [2, 1, 0]) # z, y, x -> x, y, z
self.res_relation = (
1.0,
resolution[0] / resolution[1],
resolution[0] / resolution[2],
)
self.resolution = resolution
self.o_indexes = np.arange(self.length)
self.n_classes = n_classes
self.da = da
self.da_prob = da_prob
self.cutout = cutout
self.cout_nb_iterations = cout_nb_iterations
self.cout_size = cout_size
self.cout_cval = cout_cval
self.cout_apply_to_mask = cout_apply_to_mask
self.cutblur = cutblur
self.cblur_size = cblur_size
self.cblur_down_range = cblur_down_range
self.cblur_inside = cblur_inside
self.cutmix = cutmix
self.cmix_size = cmix_size
self.cutnoise = cutnoise
self.cnoise_scale = cnoise_scale
self.cnoise_nb_iterations = cnoise_nb_iterations
self.cnoise_size = cnoise_size
self.misalignment = misalignment
self.ms_displacement = ms_displacement
self.ms_rotate_ratio = ms_rotate_ratio
self.brightness = brightness
self.contrast = contrast
self.missing_sections = missing_sections
self.missp_iterations = missp_iterations
self.missp_channel_pb = missp_channel_pb
self.grayscale = grayscale
self.gridmask = gridmask
self.grid_ratio = grid_ratio
self.grid_d_range = grid_d_range
self.grid_rotate = grid_rotate
self.grid_invert = grid_invert
self.grid_d_size = (
self.shape[0] * grid_d_range[0],
self.shape[1] * grid_d_range[1],
)
self.channel_shuffle = channel_shuffle
self.gaussian_noise = gaussian_noise
self.gaussian_noise_mean = gaussian_noise_mean
self.gaussian_noise_var = gaussian_noise_var
self.gaussian_noise_use_input_img_mean_and_var = gaussian_noise_use_input_img_mean_and_var
self.poisson_noise = poisson_noise
self.salt = salt
self.salt_amount = salt_amount
self.pepper = pepper
self.pepper_amount = pepper_amount
self.salt_and_pepper = salt_and_pepper
self.salt_pep_amount = salt_pep_amount
self.salt_pep_proportion = salt_pep_proportion
self.rand_rot = rand_rot
self.rnd_rot_range = rnd_rot_range
self.rotation90 = rotation90
self.affine_mode = affine_mode
self.zoom = zoom
self.zoom_range = zoom_range
self.zoom_in_z = zoom_in_z
self.gamma_contrast = gamma_contrast
self.gc_gamma = gc_gamma
self.elastic = elastic
self.shear = shear
self.shift = shift
self.vflip = vflip
self.hflip = hflip
self.g_blur = g_blur
self.median_blur = median_blur
self.motion_blur = motion_blur
self.dropout = dropout
self.drop_range = drop_range
self.e_alpha = e_alpha
self.e_sigma = e_sigma
self.e_mode = e_mode
self.shear_range = shear_range
self.shift_range = shift_range
self.affine_mode = affine_mode
self.g_sigma = g_sigma
self.mb_kernel = mb_kernel
self.motb_k_range = motb_k_range
self.brightness_factor = brightness_factor
self.contrast_factor = contrast_factor
# Instance segmentation options
self.instance_problem = instance_problem
# Denoising options
self.n2v = n2v
self.val = val
if self.n2v:
from biapy.engine.denoising import (
get_stratified_coords2D,
get_stratified_coords3D,
get_value_manipulation,
apply_structN2Vmask,
apply_structN2Vmask3D,
)
self.box_size = int(np.round(np.sqrt(100 / n2v_perc_pix)))
self.get_stratified_coords = get_stratified_coords2D if self.ndim == 2 else get_stratified_coords3D
self.value_manipulation = get_value_manipulation(n2v_manipulator, n2v_neighborhood_radius)
self.n2v_structMask = n2v_structMask
self.n2v_load_gt = n2v_load_gt
self.apply_structN2Vmask_func = apply_structN2Vmask if self.ndim == 2 else apply_structN2Vmask3D
self.da_options = []
self.trans_made = ""
if rotation90:
self.trans_made += "_rot[90,180,270]"
if rand_rot:
self.trans_made += "_rrot" + str(rnd_rot_range)
if shear:
self.trans_made += "_shear" + str(shear_range)
if zoom:
self.trans_made += "_zoom" + str(zoom_range) + "+" + str(zoom_in_z)
if shift:
self.trans_made += "_shift" + str(shift_range)
if vflip:
self.trans_made += "_vflip"
if hflip:
self.trans_made += "_hflip"
if elastic:
self.trans_made += "_elastic" + str(e_alpha) + "+" + str(e_sigma) + "+" + str(e_mode)
if g_blur:
self.trans_made += "_gblur" + str(g_sigma)
if median_blur:
self.trans_made += "_mblur" + str(mb_kernel)
if motion_blur:
self.trans_made += "_motb" + str(motb_k_range)
if gamma_contrast:
self.trans_made += "_gcontrast" + str(gc_gamma)
if brightness:
self.trans_made += "_brightness" + str(brightness_factor)
if contrast:
self.trans_made += "_contrast" + str(contrast_factor)
if dropout:
self.trans_made += "_drop" + str(drop_range)
if grayscale:
self.trans_made += "_gray"
if gridmask:
self.trans_made += (
"_gridmask"
+ str(self.grid_ratio)
+ "+"
+ str(self.grid_d_range)
+ "+"
+ str(self.grid_rotate)
+ "+"
+ str(self.grid_invert)
)
if channel_shuffle:
self.trans_made += "_chshuffle"
if cutout:
self.trans_made += (
"_cout"
+ str(cout_nb_iterations)
+ "+"
+ str(cout_size)
+ "+"
+ str(cout_cval)
+ "+"
+ str(cout_apply_to_mask)
)
if cutblur:
self.trans_made += "_cblur" + str(cblur_size) + "+" + str(cblur_down_range) + "+" + str(cblur_inside)
if cutmix:
self.trans_made += "_cmix" + str(cmix_size)
if cutnoise:
self.trans_made += "_cnoi" + str(cnoise_scale) + "+" + str(cnoise_nb_iterations) + "+" + str(cnoise_size)
if misalignment:
self.trans_made += "_msalg" + str(ms_displacement) + "+" + str(ms_rotate_ratio)
if missing_sections:
self.trans_made += "_missp" + "+" + str(missp_iterations)
if gaussian_noise:
self.trans_made += "_gausnoise" + "+" + str(gaussian_noise_mean) + "+" + str(gaussian_noise_var)
if poisson_noise:
self.trans_made += "_poisnoise"
if salt:
self.trans_made += "_salt" + "+" + str(salt_amount)
if pepper:
self.trans_made += "_pepper" + "+" + str(pepper_amount)
if salt_and_pepper:
self.trans_made += "_salt_and_pepper" + "+" + str(salt_pep_amount) + "+" + str(salt_pep_proportion)
self.trans_made = self.trans_made.replace(" ", "")
self.seed = seed
random.seed(seed)
self.indexes = self.o_indexes.copy()
[docs]
@abstractmethod
def save_aug_samples(
self,
img: NDArray,
mask: NDArray,
orig_images: Dict,
i: int,
pos: int,
out_dir: str,
):
"""
Save transformed samples in order to check the generator.
Parameters
----------
img : 3D/4D Numpy array
Image to use as sample. E.g. ``(y, x, channels)`` for ``2D`` and ``(z, y, x, channels)`` for ``3D``.
mask : 3D/4D Numpy array
Mask to use as sample. E.g. ``(y, x, channels)`` for ``2D`` and ``(z, y, x, channels)`` for ``3D``.
orig_images : dict
Dict where the original image and mask are saved in "o_x" and "o_y", respectively.
i : int
Number of the sample within the transformed ones.
pos : int
Number of the sample within the dataset.
out_dir : str
Directory to save the images.
"""
raise NotImplementedError
def __len__(self):
"""Define the number of samples per epoch."""
return self.length
[docs]
def load_sample(self, _idx: int, first_load: bool = False) -> Tuple[NDArray, NDArray]:
"""
Load one data sample given its corresponding index.
Parameters
----------
_idx : int
Sample index counter.
first_load : bool, optional
Whether its the first time a sample is loaded to prevent normalizing it.
Returns
-------
img : 3D/4D Numpy array
X element. E.g. ``(y, x, channels)`` in ``2D`` and ``(z, y, x, channels)`` in ``3D``.
mask : 3D/4D Numpy array
Y element. E.g. ``(y, x, channels)`` in ``2D`` and ``(z, y, x, channels)`` in ``3D``.
"""
idx = _idx % self.real_length
sample = self.X.sample_list[idx]
# X data
if sample.img_is_loaded():
img = sample.img.copy()
else:
img, img_file = load_img_data(
self.X.dataset_info[sample.fid].path,
is_3d=(self.ndim == 3),
data_within_zarr_path=sample.get_path_in_zarr(),
)
if not self.X.dataset_info[sample.fid].is_parallel():
# Apply preprocessing
if self.preprocess_f:
img = self.preprocess_f(self.preprocess_cfg, x_data=[img], is_2d=(self.ndim == 2))[0]
img = pad_and_reflect(img, self.shape, verbose=False)
# Extract the sample within the image
if sample.coords:
img = extract_patch_within_image(img, sample.coords, is_3d=(self.ndim == 3))
else:
coords = sample.coords
data_axes_order = self.X.dataset_info[sample.fid].get_input_axes()
assert coords is not None and data_axes_order is not None
img = extract_patch_from_efficient_file(img, coords, data_axes_order=data_axes_order)
# Apply preprocessing after extract sample
if self.preprocess_f:
img = self.preprocess_f(self.preprocess_cfg, x_data=[img], is_2d=(self.ndim == 2))[0]
if isinstance(img_file, h5py.File):
img_file.close()
# Y data
# "gt_associated_id" available only in PROBLEM.IMAGE_TO_IMAGE.MULTIPLE_RAW_ONE_TARGET_LOADER
gt_id = sample.get_gt_associated_id()
if gt_id is not None:
msample = self.Y.sample_list[gt_id]
mask, _ = load_img_data(
self.Y.dataset_info[msample.fid].path,
is_3d=(self.ndim == 3),
)
# Extract the sample within the image
if msample.coords:
mask = extract_patch_within_image(mask, msample.coords, is_3d=(self.ndim == 3))
else:
msample = self.Y.sample_list[idx]
if msample.img_is_loaded():
mask = msample.img.copy()
else:
fid = msample.fid
mask, mask_file = load_img_data(
self.Y.dataset_info[fid].path,
is_3d=(self.ndim == 3),
data_within_zarr_path=msample.get_path_in_zarr(),
)
if not self.Y.dataset_info[msample.fid].is_parallel():
# Apply preprocessing
if self.preprocess_f:
mask = self.preprocess_f(
self.preprocess_cfg, y_data=[mask], is_2d=(self.ndim == 2), is_y_mask=True
)[0]
mask = pad_and_reflect(mask, self.shape, verbose=False)
# Extract the sample within the image
if msample.coords:
coords = msample.coords
assert coords is not None
mask = extract_patch_within_image(mask, coords, is_3d=(self.ndim == 3))
else:
coords = msample.coords
data_axes_order = self.Y.dataset_info[msample.fid].get_input_axes()
assert coords is not None and data_axes_order is not None
mask = extract_patch_from_efficient_file(
mask,
coords,
data_axes_order=data_axes_order,
)
# Apply preprocessing after extract sample
if self.preprocess_f:
mask = self.preprocess_f(
self.preprocess_cfg, y_data=[mask], is_2d=(self.ndim == 2), is_y_mask=True
)[0]
if self.Y.dataset_info[msample.fid].is_parallel():
if mask_file and isinstance(mask_file, h5py.File):
mask_file.close()
# Apply random crops if it is selected
if sample.coords is None:
img, mask = self.random_crop_func( # type: ignore
img,
mask,
self.shape[: self.ndim],
self.val,
scale=self.random_crop_scale,
)
if not first_load:
xnorm_info = self.X.dataset_info[sample.fid].norm_info
if xnorm_info is None:
xnorm_info = self.norm_module
img, _ = normalize_image(img, norm_module=xnorm_info)
mask, _ = normalize_mask(mask, norm_module=self.mask_norm)
assert isinstance(img, np.ndarray) and isinstance(mask, np.ndarray)
if self.convert_to_rgb:
if img.shape[-1] == 1:
img = np.repeat(img, 3, axis=-1)
if self.norm_module["mask_norm"] == "as_image" and mask.shape[-1] == 1:
mask = np.repeat(mask, 3, axis=-1)
return img, mask
[docs]
def getitem(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate one pair of data.
Parameters
----------
index : int
Index counter.
Returns
-------
item : 3D/4D Torch tensors
X and Y (if avail) elements. Each one shape is ``(z, y, x, channels)`` if ``2D`` or ``(y, x, channels)``
if ``3D``.
"""
return self.__getitem__(index)
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate one pair of data.
Parameters
----------
index : int
Index counter.
Returns
-------
img : 3D/4D Torch tensor
X element, for instance, an image. E.g. ``(y, x, channels)`` in ``2D`` and ``(z, y, x, channels)`` in ``3D``.
mask : 3D/4D Torch tensor
Y element, for instance, a mask. E.g. ``(y, x, channels)`` in ``2D`` and ``(z, y, x, channels)`` in ``3D``.
"""
img, mask = self.load_sample(index)
# Apply transformations
if self.da:
e_img, e_mask = None, None
if self.cutmix:
extra_img = np.random.randint(0, self.length - 1) if self.length > 2 else 0
e_img, e_mask = self.load_sample(extra_img)
img, mask = self.apply_transform(img, mask, e_im=e_img, e_mask=e_mask)
# Prepare mask when denoising with Noise2Void
if self.n2v:
img, mask = self.prepare_n2v(img, mask)
# If no normalization was applied, as is done with torchvision models, it can be an image of uint16
# so we need to convert it to
if img.dtype == np.uint16:
img = torch.from_numpy(img.astype(np.float32))
else:
img = torch.from_numpy(img.copy())
mask = torch.from_numpy(mask.copy())
return img, mask
[docs]
def draw_grid(self, im: NDArray, grid_width: Optional[int] = None) -> NDArray:
"""
Draw grid of the specified size on an image.
Parameters
----------
im : 3D/4D Numpy array
Image to draw the grid into. E.g. ``(y, x, channels)`` in ``2D`` or ``(z, y, x, channels)`` in ``3D``.
grid_width : int, optional
Grid's width.
"""
vmax = []
for c in range(im.shape[-1]):
vmax.append(np.max(im[...,c]))
if grid_width is not None and grid_width > 0:
grid_y = grid_width
grid_x = grid_width
else:
grid_y = im.shape[self.ndim - 2] // 5
grid_x = im.shape[self.ndim - 2] // 5
if self.ndim == 2:
for i in range(0, im.shape[0], grid_y):
im[i] = vmax
for j in range(0, im.shape[1], grid_x):
im[:, j] = vmax
else:
for k in range(0, im.shape[0]):
for i in range(0, im.shape[2], grid_x):
im[k, :, i] = vmax
for j in range(0, im.shape[1], grid_y):
im[k, j] = vmax
return im
[docs]
def prepare_n2v(self, _img: NDArray, _mask: NDArray) -> Tuple[NDArray, NDArray]:
"""
Create Noise2Void mask.
Parameters
----------
_img : 3D/4D Numpy array
Image to wipe some pixels from. E.g. ``(y, x, channels)`` in ``2D`` or ``(z, y, x, channels)`` in ``3D``.
_mask : 3D/4D Numpy array
Mask to use values from. Only used when the ground truth is loaded with ``self.n2v_load_gt``.
E.g. ``(y, x, channels)`` in ``2D`` or ``(z, y, x, channels)`` in ``3D``.
Returns
-------
img : 3D/4D Numpy array
Input image modified removing some pixels. E.g. ``(y, x, channels)`` in ``2D`` or ``(y, x, z, channels)`` in ``3D``.
mask : 3D/4D Numpy array
Noise2Void mask created. E.g. ``(y, x, channels)`` in ``2D`` or ``(y, x, z, channels)`` in ``3D``.
"""
img = _img.copy()
mask = np.zeros(img.shape[:-1] + (img.shape[-1] * 2,), dtype=np.float32)
if self.val:
np.random.seed(0)
for c in range(self.Y_channels):
coords = self.get_stratified_coords(box_size=self.box_size, shape=self.shape)
indexing = coords + (c,)
indexing_mask = coords + (c + self.Y_channels,)
if self.n2v_load_gt:
y_val = _mask[indexing]
else:
y_val = img[indexing]
x_val = self.value_manipulation(img[..., c], coords, self.ndim, self.n2v_structMask)
mask[indexing] = y_val
mask[indexing_mask] = 1
img[indexing] = x_val
if self.n2v_structMask is not None:
self.apply_structN2Vmask_func(img[..., c], coords, self.n2v_structMask)
return img, mask