"""
Single image data generator for BiaPy.
This module provides the SingleBaseDataGenerator class, which supports flexible
data loading, augmentation, and normalization for single image data in deep learning
workflows. It includes a wide range of augmentation options for both 2D and 3D data,
and is designed to work with BiaPyDataset objects and normalization modules.
"""
from abc import ABCMeta, abstractmethod
import random
from typing import (
    Dict,
    Literal,
    Optional,
    Tuple,
)

import h5py
import numpy as np
from numpy.typing import NDArray
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from biapy.data.generators.augmentors import *
from biapy.utils.misc import is_main_process
from biapy.data.data_manipulation import load_img_data
from biapy.data.data_3D_manipulation import extract_patch_from_efficient_file
from biapy.data.dataset import BiaPyDataset
from biapy.data.norm import normalize_image
[docs]
class SingleBaseDataGenerator(Dataset, metaclass=ABCMeta):
"""
Custom BaseDataGenerator to transform single image data.
Parameters
----------
ndim : int
Dimensions of the data (``2`` for 2D and ``3`` for 3D).
X : BiaPyDataset
X data.
norm_module : Dict
Normalization module that defines the normalization steps to apply.
seed : int, optional
Seed for random functions.
da : bool, optional
To activate the data augmentation.
da_prob : float, optional
Probability of doing each transformation.
rotation90 : bool, optional
To make square (90, 180, 270) degree rotations.
rand_rot : bool, optional
To make random degree range rotations.
rnd_rot_range : tuple of float, optional
Range of random rotations. E. g. ``(-180, 180)``.
shear : bool, optional
To make shear transformations.
shear_range : tuple of int, optional
Degree range to make shear. E. g. ``(-20, 20)``.
zoom : bool, optional
To make zoom on images.
zoom_range : tuple of floats, optional
Zoom range to apply. E. g. ``(0.8, 1.2)``.
zoom_in_z: bool, optional
Whether to apply or not zoom in Z axis.
shift : bool, optional
To make shifts.
shift_range : tuple of float, optional
Range to make a shift. E. g. ``(0.1, 0.2)``.
affine_mode: str, optional
Method to use when filling in newly created pixels. Same meaning as in `skimage` (and `numpy.pad()`).
E.g. ``constant``, ``reflect`` etc.
vflip : bool, optional
To activate vertical flips.
hflip : bool, optional
To activate horizontal flips.
elastic : bool, optional
To make elastic deformations.
e_alpha : tuple of ints, optional
Strength of the distortion field. E. g. ``(240, 250)``.
e_sigma : int, optional
Standard deviation of the gaussian kernel used to smooth the distortion fields.
e_mode : str, optional
Parameter that defines the handling of newly created pixels with the elastic transformation.
g_blur : bool, optional
To insert gaussian blur on the images.
g_sigma : tuple of floats, optional
Standard deviation of the gaussian kernel. E. g. ``(1.0, 2.0)``.
median_blur : bool, optional
To blur an image by computing median values over neighbourhoods.
mb_kernel : tuple of ints, optional
Median blur kernel size. E. g. ``(3, 7)``.
motion_blur : bool, optional
Blur images in a way that fakes camera or object movements.
motb_k_range : tuple of ints, optional
Kernel size range to use in motion blur. E. g. ``(3, 8)``.
gamma_contrast : bool, optional
To insert gamma contrast changes on images.
gc_gamma : tuple of floats, optional
Exponent for the contrast adjustment. Higher values darken the image. E. g. ``(1.25, 1.75)``.
dropout : bool, optional
To set a certain fraction of pixels in images to zero.
drop_range : tuple of floats, optional
Range to take a probability ``p`` to drop pixels. E.g. ``(0, 0.2)`` will take a ``p`` following ``0<=p<=0.2``
and then drop ``p`` percent of all pixels in the image (i.e. convert them to black pixels).
val : bool, optional
Tells the generator that the images will be used to validate the model, so random crops are not made
(validation data must be the same on each epoch).
resize_shape : tuple of ints, optional
If defined the input samples will be scaled into that shape.
convert_to_rgb : bool, optional
In case RGB images are expected, e.g. if the ``resize_shape`` channel is 3, those images that are grayscale are
converted into RGB.
preprocess_f : function, optional
The preprocessing function, is necessary in case you want to apply any preprocessing.
preprocess_cfg : dict, optional
Configuration parameters for preprocessing, is necessary in case you want to apply any preprocessing.
"""
def __init__(
    self,
    ndim: int,
    X: BiaPyDataset,
    norm_module: Dict,
    seed: int = 0,
    da: bool = True,
    da_prob: float = 0.5,
    rotation90: bool = False,
    rand_rot: bool = False,
    rnd_rot_range=(-180, 180),
    shear: bool = False,
    shear_range=(-20, 20),
    zoom: bool = False,
    zoom_range=(0.8, 1.2),
    zoom_in_z: bool = False,
    shift: bool = False,
    shift_range=(0.1, 0.2),
    affine_mode: Literal["constant", "edge", "symmetric", "reflect", "wrap"] = "constant",
    vflip: bool = False,
    hflip: bool = False,
    elastic: bool = False,
    e_alpha=(240, 250),
    e_sigma: int = 25,
    e_mode: Literal["constant", "edge", "symmetric", "reflect", "wrap"] = "constant",
    g_blur: bool = False,
    g_sigma: Tuple[float, float] = (1.0, 2.0),
    median_blur: bool = False,
    mb_kernel: Tuple[int, int] = (3, 7),
    motion_blur: bool = False,
    motb_k_range: Tuple[int, int] = (3, 8),
    gamma_contrast: bool = False,
    gc_gamma: Tuple[float, float] = (1.25, 1.75),
    dropout: bool = False,
    drop_range: Tuple[float, float] = (0, 0.2),
    val: bool = False,
    resize_shape: Tuple[int, ...] = (256, 256, 1),
    convert_to_rgb: bool = False,
    preprocess_f=None,
    preprocess_cfg=None,
):
    """
    Initialize the SingleBaseDataGenerator.

    Sets up data sources, normalization, augmentation options, and preprocessing
    for single image data. The first sample is loaded (without normalization) to
    check its channel count against ``resize_shape``. When ``resize_shape`` is
    falsy, the shape of that first image is used as the target shape instead.

    Parameters
    ----------
    See class docstring for full parameter list.

    Raises
    ------
    ValueError
        If ``preprocess_f`` is given without ``preprocess_cfg``, if the first
        sample's shape does not have ``ndim`` entries, or if the first image's
        channel count differs from ``resize_shape[-1]``.
    """
    if preprocess_f and preprocess_cfg is None:
        raise ValueError("'preprocess_cfg' needs to be provided with 'preprocess_f'")

    sshape = X.sample_list[0].get_shape()
    if sshape and len(sshape) != ndim:
        raise ValueError("Samples in X must be have {} dimensions. Provided: {}".format(ndim, sshape))

    self.ndim = ndim
    self.z_size = -1
    self.convert_to_rgb = convert_to_rgb
    self.norm_module = norm_module
    self.X = X
    self.length = len(self.X.sample_list)
    # Provisional target shape read by load_sample(); finalized after the first load below.
    self.shape = resize_shape
    self.random_crop_func = random_3D_crop_single if ndim == 3 else random_crop_single
    self.val = val
    self.preprocess_f = preprocess_f
    self.preprocess_cfg = preprocess_cfg

    # X data analysis: load the first sample un-normalized to validate its channels.
    img, _ = self.load_sample(0, first_load=True)
    # Only validate channels when a target shape was given; a falsy resize_shape
    # falls back to the loaded image's shape right below.
    if resize_shape and resize_shape[-1] != img.shape[-1]:
        raise ValueError(
            "Channel of the patch size given {} does not correspond with the loaded image {}. "
            "Please, check the channels of the images!".format(resize_shape[-1], img.shape[-1])
        )
    print("Normalization config used for X: {}".format(self.norm_module))
    self.shape = resize_shape if resize_shape else img.shape

    self.o_indexes = np.arange(self.length)
    self.indexes = self.o_indexes.copy()
    self.seed = seed

    # Data augmentation settings
    self.da = da
    self.da_prob = da_prob
    self.zoom = zoom
    self.zoom_range = zoom_range
    self.zoom_in_z = zoom_in_z
    self.rand_rot = rand_rot
    self.rnd_rot_range = rnd_rot_range
    self.rotation90 = rotation90
    self.affine_mode = affine_mode
    self.gamma_contrast = gamma_contrast
    self.gc_gamma = gc_gamma
    self.elastic = elastic
    self.shear = shear
    self.shift = shift
    self.vflip = vflip
    self.hflip = hflip
    self.g_blur = g_blur
    self.median_blur = median_blur
    self.motion_blur = motion_blur
    self.dropout = dropout
    self.drop_range = drop_range
    self.e_alpha = e_alpha
    self.e_sigma = e_sigma
    self.e_mode = e_mode
    self.shear_range = shear_range
    self.shift_range = shift_range
    self.g_sigma = g_sigma
    self.mb_kernel = mb_kernel
    self.motb_k_range = motb_k_range
    self.da_options = []

    # Compact, space-free tag describing the enabled transformations.
    self.trans_made = ""
    if rotation90:
        self.trans_made += "_rot[90,180,270]"
    if rand_rot:
        self.trans_made += "_rrot" + str(rnd_rot_range)
    if shear:
        self.trans_made += "_shear" + str(shear_range)
    if zoom:
        self.trans_made += "_zoom" + str(zoom_range) + "+" + str(zoom_in_z)
    if shift:
        self.trans_made += "_shift" + str(shift_range)
    if vflip:
        self.trans_made += "_vflip"
    if hflip:
        self.trans_made += "_hflip"
    if elastic:
        self.trans_made += "_elastic" + str(e_alpha) + "+" + str(e_sigma) + "+" + str(e_mode)
    if g_blur:
        self.trans_made += "_gblur" + str(g_sigma)
    if median_blur:
        self.trans_made += "_mblur" + str(mb_kernel)
    if motion_blur:
        self.trans_made += "_motb" + str(motb_k_range)
    if gamma_contrast:
        self.trans_made += "_gcontrast" + str(gc_gamma)
    if dropout:
        self.trans_made += "_drop" + str(drop_range)
    self.trans_made = self.trans_made.replace(" ", "")

    # Seeds Python's global RNG (process-wide side effect).
    random.seed(seed)
[docs]
@abstractmethod
def save_aug_samples(
    self,
    img: NDArray,
    orig_images: Dict,
    i: int,
    pos: int,
    out_dir: str,
    draw_grid: bool,
):
    """
    Persist transformed samples so the generator's output can be inspected.

    Abstract hook: concrete generators must override it. This base
    implementation does nothing except raise.

    Parameters
    ----------
    img : 3D/4D Numpy array
        Sample image, e.g. ``(y, x, channels)`` in ``2D`` or ``(z, y, x, channels)`` in ``3D``.
    orig_images : dict
        Original data, stored under the ``"o_x"`` (image) and ``"o_y"`` (mask) keys.
    i : int
        Position of the sample among the transformed ones.
    pos : int
        Position of the sample within the dataset.
    out_dir : str
        Output directory for the saved images.
    draw_grid : bool
        Whether a grid should be drawn on the saved images.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError()
def __len__(self):
    """Return the number of samples that make up one epoch (``self.length``)."""
    n_samples = self.length
    return n_samples
[docs]
def load_sample(self, idx: int, first_load: bool = False) -> Tuple[NDArray, int]:
    """
    Load one data sample given its corresponding index.

    The image is copied from memory when the sample is already loaded. Otherwise
    it is read from disk, a patch is extracted first if the dataset is a
    parallel/efficient file, and then ``preprocess_f`` is applied if set. The
    image is then cropped and resized to ``self.shape`` when that shape is set
    and the spatial dims differ. Next it is normalized (skipped on
    ``first_load``). Finally, single-channel images are repeated to 3 channels
    when ``convert_to_rgb`` is set.

    Parameters
    ----------
    idx : int
        Sample index counter.
    first_load : bool, optional
        Whether it is the first time a sample is loaded, to prevent normalizing it.

    Returns
    -------
    img : 3D/4D Numpy array
        X element. E.g. ``(y, x, channels)`` in ``2D`` and ``(z, y, x, channels)`` in ``3D``.
    img_class : int
        Class of the image.
    """
    sample = self.X.sample_list[idx]
    dinfo = self.X.dataset_info[sample.fid]

    if sample.img_is_loaded():
        img = sample.img.copy()
    else:
        img, img_file = load_img_data(
            dinfo.path,
            is_3d=(self.ndim == 3),
            data_within_zarr_path=sample.get_path_in_zarr(),
        )
        try:
            if dinfo.is_parallel():
                # Extract this sample's patch from the efficient file before preprocessing
                coords = sample.coords
                data_axes_order = dinfo.get_input_axes()
                assert coords is not None and data_axes_order is not None
                img = extract_patch_from_efficient_file(img, coords, data_axes_order=data_axes_order)
            # Apply preprocessing (after patch extraction in the parallel case)
            if self.preprocess_f:
                img = self.preprocess_f(self.preprocess_cfg, x_data=[img], is_2d=(self.ndim == 2))[0]
        finally:
            # Release the HDF5 handle even if extraction/preprocessing fails
            if isinstance(img_file, h5py.File):
                img_file.close()

    img_class = dinfo.get_class_num()

    # Crop/resize only when a target shape is known and the spatial dims differ
    if self.shape and img.shape[:-1] != self.shape[:-1]:
        img = self.random_crop_func(img, self.shape[:-1], self.val)
        img = resize_img(img, self.shape[:-1])  # type: ignore

    # X normalization (per-dataset config takes precedence over the generator's)
    if not first_load:
        xnorm_info = dinfo.norm_info
        if xnorm_info is None:
            xnorm_info = self.norm_module
        img, _ = normalize_image(img, norm_module=xnorm_info)
        assert isinstance(img, np.ndarray)

    if self.convert_to_rgb and img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    return img, img_class
[docs]
def getitem(self, index: int) -> Tuple[torch.Tensor, int]:
    """
    Public alias of ``__getitem__``: produce one sample of data.

    Parameters
    ----------
    index : int
        Index counter.

    Returns
    -------
    item : tuple
        Whatever ``__getitem__`` returns for ``index``: per the annotation, the
        image as a ``torch.Tensor`` (``(y, x, channels)`` in ``2D`` or
        ``(z, y, x, channels)`` in ``3D``) and its integer class.
    """
    sample_and_class = self.__getitem__(index)
    return sample_and_class
def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
    """
    Generate one sample of data.

    Loads the sample with ``load_sample``, applies data augmentation when
    ``self.da`` is enabled, and converts the image into a ``torch.Tensor``.

    Parameters
    ----------
    index : int
        Sample index counter.

    Returns
    -------
    img : torch.Tensor
        X element, for instance, an image. E.g. ``(y, x, channels)`` in ``2D`` or
        ``(z, y, x, channels)`` in ``3D``.
    img_class : int
        Class of the image.
    """
    img, img_class = self.load_sample(index)

    # Apply transformations
    if self.da:
        img = self.apply_transform(img)

    if img.dtype == np.uint16:
        # If no normalization was applied (as is done with torchvision models) the image
        # can still be uint16, which torch.from_numpy may not support, so cast it to
        # float32. astype() already returns a new array, so no extra copy is needed.
        img = torch.from_numpy(img.astype(np.float32))
    else:
        # torch.from_numpy shares memory with its input; copy so the tensor owns
        # contiguous data (e.g. after flips that yield negative strides).
        img = torch.from_numpy(img.copy())
    return img, img_class
[docs]
def draw_grid(self, im: NDArray, grid_width: Optional[int] = None) -> NDArray:
    """
    Draw a grid on an image, modifying it in place.

    Grid lines are painted with the per-channel maximum value of ``im``.

    Parameters
    ----------
    im : 3D/4D Numpy array
        Image to be modified. E.g. ``(y, x, channels)`` in ``2D`` or ``(z, y, x, channels)`` in ``3D``.
    grid_width : int, optional
        Spacing in pixels between grid lines along both ``y`` and ``x``. When not
        given (or not positive), the spacing along each axis is a fifth of that
        axis' size (at least 1).

    Returns
    -------
    im : 3D/4D Numpy array
        The same array that was passed in, with the grid drawn on it.
    """
    # Per-channel maximum over all non-channel axes, used as the grid colour
    vmax = im.max(axis=tuple(range(im.ndim - 1)))

    y_axis = self.ndim - 2
    x_axis = self.ndim - 1
    if grid_width is not None and grid_width > 0:
        grid_y = grid_x = grid_width
    else:
        # Each axis uses its own size; max(1, ...) avoids a zero slice step on tiny images
        grid_y = max(1, im.shape[y_axis] // 5)
        grid_x = max(1, im.shape[x_axis] // 5)

    if self.ndim == 2:
        im[::grid_y] = vmax
        im[:, ::grid_x] = vmax
    else:
        # Same grid on every z slice
        im[:, :, ::grid_x] = vmax
        im[:, ::grid_y] = vmax
    return im