"""
Augmentation utilities for image and mask data in deep learning workflows.
This module provides a variety of data augmentation functions for images and masks,
including cutout, cutblur, cutmix, cutnoise, misalignment, cropping, flipping,
rotation, zoom, gamma/contrast adjustment, blurring, dropout, elastic deformation,
shear, shift, and more. These augmentations are designed to improve model robustness
and generalization for both 2D and 3D data formats.
"""
import cv2
import random
import math
import numpy as np
from PIL import Image
from skimage.transform import resize
from skimage.draw import line
from skimage.exposure import adjust_gamma
from skimage.filters import gaussian
from scipy.ndimage import binary_dilation as binary_dilation_scipy
from scipy.ndimage import rotate
from typing import Tuple, Union, Optional, List
from numpy.typing import NDArray
from scipy.ndimage import median_filter, shift as shift_nd
from skimage.transform import AffineTransform, ProjectiveTransform, warp
[docs]
def cutout(
    img: NDArray,
    mask: NDArray,
    z_size: int,
    nb_iterations: Tuple[int, int] = (1, 3),
    size: Tuple[float, float] = (0.2, 0.4),
    cval: int = 0,
    res_relation: Tuple[float, ...] = (1.0, 1.0),
    apply_to_mask: bool = False,
) -> Tuple[NDArray, NDArray]:
    """
    Erase random rectangular regions of an image (Cutout augmentation).

    Cutout was presented in `Improved Regularization of Convolutional Neural Networks with Cutout
    <https://arxiv.org/pdf/1708.04552.pdf>`_. Both inputs are copied; the originals are never modified.

    Parameters
    ----------
    img : Numpy array
        Image to transform. A 3D array is indexed as ``(y, x, channels)`` and a 4D array as
        ``(z, y, x, channels)``.

    mask : Numpy array
        Mask that accompanies ``img``. It is indexed with the same region as the image.

    z_size : int
        Only consulted for 4D input. With ``-1`` each erased rectangle spans every z plane; with any other
        value a random block of consecutive z planes is erased instead.

    nb_iterations : tuple of ints, optional
        Inclusive range from which the number of erased regions is drawn. E.g. ``(1, 3)``.

    size : tuple of floats, optional
        Range, within ``(0, 1]``, from which the fraction of each axis covered by a region is drawn.

    cval : int, optional
        Value written into the erased regions of the image. The mask is always filled with ``0``.

    res_relation : tuple of floats, optional
        Per-axis factors in ``(x, y, z)`` order that scale the region extent along each axis. E.g.
        ``(1, 1, 0.27)`` for anisotropic data of 8umx8umx30um resolution. Missing entries default to ``1.0``.

    apply_to_mask : boolean, optional
        Whether the same regions are also erased in the mask.

    Returns
    -------
    out : Numpy array
        Transformed copy of ``img``.

    mask : Numpy array
        Copy of ``mask``, with the regions erased only when ``apply_to_mask`` is set.

    Examples
    --------
    Calling this function with ``nb_iterations=(1,3)``, ``size=(0.05,0.3)``,
    ``apply_to_mask=False`` may result in:

    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_cutout.png       | .. figure:: ../../../img/orig_cutout_mask.png  |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Corresponding mask                           |
    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/cutout.png            | .. figure:: ../../../img/cutout_mask.png       |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Augmented image                              |   Augmented mask                               |
    +------------------------------------------------+------------------------------------------------+

    The grid is painted for visualization purposes.
    """
    assert img.ndim in [3, 4], f"Image must be 3D or 4D, got shape {img.shape}"
    assert mask.ndim in [3, 4], f"Mask must be 3D or 4D, got shape {mask.shape}"
    assert len(nb_iterations) == 2 and nb_iterations[0] <= nb_iterations[1]
    assert len(size) == 2 and 0.0 < size[0] <= size[1] <= 1.0

    # Pad the resolution factors so that x, y and z are always available.
    factors = [float(value) for value in res_relation[:3]]
    factors += [1.0] * (3 - len(factors))
    scale_x, scale_y, scale_z = factors

    out = img.copy()
    m_out = mask.copy()

    is_volume = img.ndim == 4
    if is_volume:
        depth, height, width = img.shape[:3]
    else:
        depth = None
        height, width = img.shape[:2]

    n_regions = int(np.random.randint(nb_iterations[0], nb_iterations[1] + 1))

    # Scalars already cast to the target dtypes.
    img_fill = np.array(cval, dtype=img.dtype)
    mask_fill = np.array(0, dtype=mask.dtype)

    def _extent(length, factor, fraction):
        # Region length along one axis, kept inside [1, length].
        wanted = int(round(length * fraction * factor))
        return max(1, min(wanted, int(length)))

    def _origin(length, extent):
        # Random start index such that the region stays inside the axis.
        if length == extent:
            return 0
        return np.random.randint(0, length - extent + 1)

    for _ in range(n_regions):
        fraction = random.uniform(size[0], size[1])

        rows = _extent(height, scale_y, fraction)
        cols = _extent(width, scale_x, fraction)
        top = _origin(height, rows)
        left = _origin(width, cols)
        plane_region = (slice(top, top + rows), slice(left, left + cols), slice(None))

        if is_volume:
            if z_size == -1:
                planes = slice(None)
            else:
                thickness = _extent(depth, scale_z, fraction)
                first = _origin(depth, thickness)
                planes = slice(first, first + thickness)
            region = (planes,) + plane_region
        else:
            region = plane_region

        out[region] = img_fill
        if apply_to_mask:
            m_out[region] = mask_fill

    return out, m_out
[docs]
def cutblur(
    img: NDArray,
    size: Tuple[float, float] = (0.2, 0.4),
    down_ratio_range: Tuple[int, int] = (2, 8),
    only_inside: bool = True,
) -> NDArray:
    """
    Apply CutBlur data augmentation.

    CutBlur was introduced in `Rethinking Data Augmentation for Image Super-resolution:
    A Comprehensive Analysis and a New Strategy <https://arxiv.org/pdf/2004.00448.pdf>`_ and this
    implementation is adapted from https://github.com/clovaai/cutblur . A rectangular region is degraded by
    downsampling (linear interpolation with anti-aliasing) and upsampling back (nearest neighbour).

    Parameters
    ----------
    img : Numpy array
        Image to transform. A 3D array is indexed as ``(y, x, channels)`` and a 4D array as
        ``(z, y, x, channels)``; for 4D input every z plane is processed with the same region and ratio.

    size : tuple of floats, optional
        Range, within ``(0, 1]``, from which the fraction of each in-plane axis covered by the region is drawn.

    down_ratio_range : tuple of ints, optional
        Inclusive range of the downsampling ratio. E.g. ``(2, 8)``.

    only_inside : bool, optional
        If ``True`` only the region inside is degraded (cut LR into HR image). If ``False``, half of the
        times the region inside is degraded and the other half the whole image is degraded while the region
        keeps its original content (cut HR into LR image). See Figure 1 of the official
        `paper <https://arxiv.org/pdf/2004.00448.pdf>`_.

    Returns
    -------
    out : Numpy array
        Transformed image with the same shape as ``img``. An empty input is returned as is.

    Examples
    --------
    Calling this function with ``size=(0.2,0.4)``, ``down_ratio_range=(2,8)``,
    ``only_inside=True`` may result in:

    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_cutblur.png      | .. figure:: ../../../img/cutblur.png           |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_cutblur2.png     | .. figure:: ../../../img/cutblur2.png          |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+

    The grid and the red square are painted for visualization purposes.
    """
    assert img.ndim in (3, 4), f"Image must be 3D or 4D, got shape {img.shape}"
    assert len(size) == 2 and 0.0 < size[0] <= size[1] <= 1.0, f"Invalid size range: {size}"
    assert len(down_ratio_range) == 2 and down_ratio_range[0] >= 1 and down_ratio_range[0] <= down_ratio_range[1], \
        f"Invalid down_ratio_range: {down_ratio_range}"

    if img.size == 0:
        return img

    # The last three axes are (y, x, channels) for both supported layouts.
    height, width, channels = img.shape[-3:]

    # Region extent (at least one pixel per axis) and its random position.
    fraction = float(random.uniform(size[0], size[1]))
    rows = max(1, int(round(height * fraction)))
    cols = max(1, int(round(width * fraction)))
    top = np.random.randint(0, height - rows + 1) if height != rows else 0
    left = np.random.randint(0, width - cols + 1) if width != cols else 0

    ratio = int(np.random.randint(down_ratio_range[0], down_ratio_range[1] + 1))
    full_low = (max(1, height // ratio), max(1, width // ratio))
    patch_low = (max(1, rows // ratio), max(1, cols // ratio))

    # The coin is only tossed when the inverse variant is allowed.
    blur_patch = True if only_inside else (random.uniform(0, 1) < 0.5)

    source_dtype = img.dtype
    window = (slice(top, top + rows), slice(left, left + cols))

    def _rescale(arr: NDArray, target_hw: Tuple[int, int], order: int, smooth: bool) -> NDArray:
        # Resize an (H, W, C) array; non-float inputs are cast back to their original dtype.
        scaled = resize(
            arr,
            (target_hw[0], target_hw[1], channels),
            order=order,
            mode="reflect",
            clip=True,
            preserve_range=True,
            anti_aliasing=smooth,
        )
        if source_dtype.kind == "f":
            return scaled
        return scaled.astype(source_dtype, copy=False)

    def _degrade(plane: NDArray, low_hw: Tuple[int, int]) -> NDArray:
        # Downsample to ``low_hw`` and upsample back to the plane's own size.
        small = _rescale(plane, low_hw, 1, True)
        return _rescale(small, plane.shape[:2], 0, False)

    if img.ndim == 3:
        if blur_patch:
            out = img.copy()
            out[window] = _degrade(img[window], patch_low)
        else:
            # The degraded full image becomes the output; the sharp region is pasted back.
            out = _degrade(img, full_low)
            out[window] = img[window]
        return out

    out = img.copy()
    for z in range(img.shape[0]):
        if blur_patch:
            out[z][window] = _degrade(img[z][window], patch_low)
        else:
            out[z] = _degrade(img[z], full_low)
            out[z][window] = img[z][window]
    return out
[docs]
def cutmix(
im1: NDArray,
im2: NDArray,
mask1: NDArray,
mask2: NDArray,
heat1: NDArray | None,
heat2: NDArray | None,
size: Tuple[float, float] = (0.2, 0.4),
) -> Tuple[NDArray, NDArray, NDArray | None]:
"""
Apply Cutmix data augmentation.
Cutmix augmentation introduced in `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features <https://arxiv.org/abs/1905.04899>`_.
With this augmentation a region of the image sample is filled with a given second image. This implementation is used for semantic segmentation so the masks of
the images are also needed. It assumes that the images are of the same shape.
Parameters
----------
im1 : Numpy array
Image to transform. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
im2 : Numpy array
Image to paste into the region of ``im1``. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
mask1 : Numpy array
Mask to transform (belongs to ``im1``). E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
mask2 : Numpy array
Mask to paste into the region of ``mask1``. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
heat1 : Numpy array or None
Heatmap to transform (belongs to ``im1``). E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
If ``None``, no heatmap is used.
heat2 : Numpy array or None
Heatmap to paste into the region of ``heat1``. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
If ``None``, no heatmap is used.
size : tuple of floats, optional
Range to choose the size of the areas to transform. E.g. ``(0.2, 0.4)``.
Returns
-------
out : Numpy array
Transformed image. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
m_out : Numpy array
Transformed mask. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
h_out : Numpy array or None
Transformed heatmap. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
Examples
--------
Calling this function with ``size=(0.2,0.4)`` may result in:
+----------------------------------------------+----------------------------------------------+
| .. figure:: ../../../img/orig_cutmix.png | .. figure:: ../../../img/orig_cutmix_mask.png|
| :width: 80% | :width: 80% |
| :align: center | :align: center |
| | |
| Input image | Corresponding mask |
+----------------------------------------------+----------------------------------------------+
| .. figure:: ../../../img/cutmix.png | .. figure:: ../../../img/cutmix_mask.png |
| :width: 80% | :width: 80% |
| :align: center | :align: center |
| | |
| Augmented image | Augmented mask |
+----------------------------------------------+----------------------------------------------+
The grid is painted for visualization purposes.
"""
assert im1.ndim in (3, 4) and im2.ndim == im1.ndim, f"Shape mismatch: {im1.shape} vs {im2.shape}"
assert mask1.ndim == im1.ndim and mask2.ndim == im1.ndim, "Mask dims must match image dims"
assert im1.shape[:-1] == im2.shape[:-1] == mask1.shape[:-1] == mask2.shape[:-1], "All inputs must share shape"
assert len(size) == 2 and 0.0 < size[0] <= size[1] <= 1.0, f"Invalid size range: {size}"
if im1.size == 0:
return im1, mask1, heat1
# Spatial dims (H, W)
if im1.ndim == 3: # (y, x, c)
H, W, C = im1.shape
Z = 1
else: # (z, y, x, c)
Z, H, W, C = im1.shape
# Sample patch size (at least 1×1)
frac = float(random.uniform(size[0], size[1]))
y_size = max(1, int(round(H * frac)))
x_size = max(1, int(round(W * frac)))
# Random top-left corners (inclusive bounds so the patch fits)
im1cy = 0 if H == y_size else np.random.randint(0, H - y_size + 1)
im1cx = 0 if W == x_size else np.random.randint(0, W - x_size + 1)
im2cy = 0 if H == y_size else np.random.randint(0, H - y_size + 1)
im2cx = 0 if W == x_size else np.random.randint(0, W - x_size + 1)
out = im1.copy()
m_out = mask1.copy()
h_out = heat1.copy() if heat1 is not None else None
if im1.ndim == 3:
# Vectorized over channels
out[im1cy:im1cy + y_size, im1cx:im1cx + x_size, :] = \
im2[im2cy:im2cy + y_size, im2cx:im2cx + x_size, :]
m_out[im1cy:im1cy + y_size, im1cx:im1cx + x_size, :] = \
mask2[im2cy:im2cy + y_size, im2cx:im2cx + x_size, :]
if h_out is not None and heat2 is not None:
h_out[im1cy:im1cy + y_size, im1cx:im1cx + x_size, :] = \
heat2[im2cy:im2cy + y_size, im2cx:im2cx + x_size, :]
else:
# Apply the same (y,x) patch across all z-slices
out[:, im1cy:im1cy + y_size, im1cx:im1cx + x_size, :] = \
im2[:, im2cy:im2cy + y_size, im2cx:im2cx + x_size, :]
m_out[:, im1cy:im1cy + y_size, im1cx:im1cx + x_size, :] = \
mask2[:, im2cy:im2cy + y_size, im2cx:im2cx + x_size, :]
if h_out is not None and heat2 is not None:
h_out[:, im1cy:im1cy + y_size, im1cx:im1cx + x_size, :] = \
heat2[:, im2cy:im2cy + y_size, im2cx:im2cx + x_size, :]
return out, m_out, h_out
[docs]
def cutnoise(
    img: NDArray,
    scale: Tuple[float, float] = (0.1, 0.2),
    nb_iterations: Tuple[int, int] = (1, 3),
    size: Tuple[float, float] = (0.2, 0.4),
) -> NDArray:
    """
    Apply Cutnoise data augmentation.

    Gaussian noise is added to randomly placed rectangular (2D) or cuboid (3D) regions of the image to
    force the model to learn denoising when making predictions. The same noise field is added to every
    channel of a region.

    Parameters
    ----------
    img : Numpy array
        Image to transform. A 3D array is indexed as ``(y, x, channels)`` and a 4D array as
        ``(z, y, x, channels)``.

    scale : tuple of floats, optional
        Range from which the noise scale is drawn. The standard deviation of the noise is the drawn value
        multiplied by the magnitude of the image maximum. E.g. ``(0.1, 0.2)``.

    nb_iterations : tuple of ints, optional
        Inclusive range from which the number of noisy regions is drawn. E.g. ``(1, 3)``.

    size : tuple of floats, optional
        Range, within ``(0, 1]``, from which the fraction of each axis covered by a region is drawn.
        E.g. ``(0.2, 0.4)``.

    Returns
    -------
    out : Numpy array
        Transformed copy of ``img`` with the same shape and dtype. For integer dtypes the noisy values are
        rounded and clipped to the dtype range. An empty input is returned as is.

    Examples
    --------
    Calling this function with ``scale=(0.1,0.2)``, ``nb_iterations=(1,3)`` and
    ``size=(0.2,0.4)`` may result in:

    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_cutnoise.png     | .. figure:: ../../../img/cutnoise.png          |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_cutnoise2.png    | .. figure:: ../../../img/cutnoise2.png         |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+

    The grid and the red squares are painted for visualization purposes.
    """
    assert img.ndim in (3, 4), f"Image must be 3D or 4D, got {img.shape}"
    assert len(scale) == 2 and scale[0] <= scale[1], f"Invalid scale range: {scale}"
    assert len(nb_iterations) == 2 and nb_iterations[0] <= nb_iterations[1], \
        f"Invalid nb_iterations range: {nb_iterations}"
    assert len(size) == 2 and 0.0 < size[0] <= size[1] <= 1.0, f"Invalid size range: {size}"

    if img.size == 0:
        return img

    out = img.copy()
    target_dtype = img.dtype
    is_float = np.issubdtype(target_dtype, np.floating)
    is_volume = img.ndim == 4

    # The last three axes are (y, x, channels) for both supported layouts.
    height, width = img.shape[-3], img.shape[-2]
    depth = img.shape[0] if is_volume else None

    # A standard deviation cannot be negative: use the magnitude of the maximum so that images whose
    # values are all negative (e.g. after normalization) do not make ``np.random.normal`` fail.
    peak = abs(float(img.max()))

    def _extent(length, fraction):
        # Region length along one axis, kept inside [1, length].
        return max(1, min(int(round(length * fraction)), length))

    def _origin(length, extent):
        # Random start index such that the region stays inside the axis.
        if length == extent:
            return 0
        return np.random.randint(0, length - extent + 1)

    def _add_noise(region: NDArray, noise: NDArray) -> None:
        # ``region`` is a view of ``out``; ``noise`` lacks the channel axis and is broadcast over it.
        if is_float:
            region += noise[..., None]
            return
        limits = np.iinfo(target_dtype)
        noisy = region.astype(np.float32) + noise[..., None]
        # Round before casting: plain truncation would bias the noisy pixels downwards.
        np.rint(noisy, out=noisy)
        np.clip(noisy, limits.min, limits.max, out=noisy)
        region[...] = noisy.astype(target_dtype, copy=False)

    n_regions = int(np.random.randint(nb_iterations[0], nb_iterations[1] + 1))
    for _ in range(n_regions):
        fraction = float(random.uniform(size[0], size[1]))

        rows = _extent(height, fraction)
        cols = _extent(width, fraction)
        top = _origin(height, rows)
        left = _origin(width, cols)

        sigma = float(random.uniform(scale[0], scale[1])) * peak

        if is_volume:
            # The block of z planes follows the same fraction as the in-plane extent.
            thickness = _extent(depth, fraction)
            first = _origin(depth, thickness)
            noise = np.random.normal(loc=0.0, scale=sigma, size=(thickness, rows, cols)).astype(np.float32)
            region = out[first:first + thickness, top:top + rows, left:left + cols, :]
        else:
            noise = np.random.normal(loc=0.0, scale=sigma, size=(rows, cols)).astype(np.float32)
            region = out[top:top + rows, left:left + cols, :]

        _add_noise(region, noise)

    return out
[docs]
def misalignment(
    img: NDArray,
    mask: NDArray,
    displacement: int = 16,
    rotate_ratio: float = 0.0,
) -> Tuple[NDArray, NDArray]:
    """
    Apply mis-alignment data augmentation.

    Mis-alignment data augmentation of image stacks, applied to both the image and its mask.
    Implementation based on `PyTorch Connectomics' misalign.py
    <https://github.com/zudi-lin/pytorch_connectomics/blob/master/connectomics/data/augmentation/misalign.py>`_.

    With probability ``rotate_ratio`` the mis-alignment is rotation based (the affected part is rotated
    by a small random angle with ``cv2.warpAffine``); otherwise it is translation based.

    * 2D input: the rows from a random row downwards are affected. The translation variant shifts them
      to the left by a random amount below ``displacement`` and fills the vacated columns with zeros.
    * 3D input: either a single random z plane ("slip") or every z plane from a random one onwards
      ("translation") is affected, each with probability ``0.5``. The translation variant keeps, for every
      plane, a window of ``(y - displacement, x - displacement)`` pixels and zeroes the rest; affected and
      unaffected planes use different random window offsets.

    Parameters
    ----------
    img : Numpy array
        Image to transform. A 3D array is indexed as ``(y, x, channels)`` and a 4D array as
        ``(z, y, x, channels)``.

    mask : Numpy array
        Mask to transform, with the same layout as ``img``.

    displacement : int, optional
        Maximum pixel displacement in ``xy``-plane. Must be positive.

    rotate_ratio : float, optional
        Ratio of rotation-based mis-alignment.

    Returns
    -------
    out : Numpy array
        Transformed image, same shape and dtype as ``img``.

    m_out : Numpy array
        Transformed mask, same shape and dtype as ``mask``.

    Examples
    --------
    Calling this function with ``displacement=16`` and ``rotate_ratio=0.5`` may result in:

    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_miss.png         | .. figure:: ../../../img/orig_miss_mask.png    |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Corresponding mask                           |
    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/miss.png              | .. figure:: ../../../img/miss_mask.png         |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Augmented image                              |   Augmented mask                               |
    +------------------------------------------------+------------------------------------------------+

    The grid is painted for visualization purposes.
    """
    assert img.ndim in [3, 4], f"Image must be 3D or 4D, got shape {img.shape}"
    assert mask.ndim in [3, 4], "Mask is supposed to be 3 or 4 dimensions but provided {} mask shape instead".format(
        mask.shape
    )

    def _randomrotate_matrix(height: int, displacement: int):
        """Generate random rotation matrix."""
        x = displacement / 2.0
        y = ((height - displacement) / 2.0) * 1.42
        # Largest rotation (in degrees) allowed for this displacement; asin needs x / y within [-1, 1].
        angle = math.asin(x / y) * 2.0 * 57.2958  # convert radians to degrees
        rand_angle = (random.uniform(0, 1) - 0.5) * 2.0 * angle
        return cv2.getRotationMatrix2D((height / 2, height / 2), rand_angle, 1)

    def _warp(plane: NDArray, matrix, out_wh: Tuple[int, int], interpolation: int) -> NDArray:
        # Rotate one single-channel plane; pixels coming from outside the plane are constant-filled.
        # NOTE(review): the positional 1.0 lands in warpAffine's ``dst`` slot; kept as in the
        # reference implementation.
        return cv2.warpAffine(
            plane,
            matrix,
            out_wh,
            1.0,  # type: ignore
            flags=interpolation,
            borderMode=cv2.BORDER_CONSTANT,
        )

    # 2D
    if img.ndim == 3:
        out = np.zeros(img.shape, img.dtype)
        m_out = np.zeros(mask.shape, mask.dtype)

        oy = np.random.randint(1, img.shape[0] - 1)
        d = np.random.randint(0, displacement)

        # Rows above ``oy`` are never altered.
        out[:oy] = img[:oy]
        m_out[:oy] = mask[:oy]

        if random.uniform(0, 1) < rotate_ratio:
            # Rotation-based: rotate the lower part, channel by channel.
            H, W = img.shape[:2]
            M = _randomrotate_matrix(H, displacement)
            lower_h = H - oy
            for c in range(img.shape[-1]):
                out[oy:, :, c] = _warp(img[oy:, :, c], M, (W, lower_h), cv2.INTER_LINEAR)
            for c in range(mask.shape[-1]):
                m_out[oy:, :, c] = _warp(mask[oy:, :, c], M, (W, lower_h), cv2.INTER_NEAREST)
        else:
            # Translation-based: shift the lower part ``d`` pixels to the left.
            out[oy:, : img.shape[1] - d, :] = img[oy:, d:, :]
            m_out[oy:, : mask.shape[1] - d, :] = mask[oy:, d:, :]
        return out, m_out

    # 3D: img, mask are (z, y, x, c)
    Z, H, W, C_img = img.shape
    C_mask = mask.shape[-1]

    mode = "slip" if random.uniform(0, 1) < 0.5 else "translation"
    # The affected plane for "slip", or the first affected plane for "translation".
    z_idx = np.random.randint(1, Z - 1) if Z >= 3 else 0

    if random.uniform(0, 1) < rotate_ratio:
        # Rotation-based: start from a copy and overwrite the affected planes.
        out = img.copy()
        m_out = mask.copy()
        M = _randomrotate_matrix(H, displacement)
        affected = [z_idx] if mode == "slip" else range(z_idx, Z)
        for z in affected:
            for c in range(C_img):
                out[z, :, :, c] = _warp(img[z, :, :, c], M, (W, H), cv2.INTER_LINEAR)
            for c in range(C_mask):
                m_out[z, :, :, c] = _warp(mask[z, :, :, c], M, (W, H), cv2.INTER_NEAREST)
        return out, m_out

    # Translation-based: keep a window per plane and leave everything else at zero.
    out = np.zeros(img.shape, img.dtype)
    m_out = np.zeros(mask.shape, mask.dtype)
    out_h = H - displacement
    out_w = W - displacement

    # Drawn from the global NumPy generator so that seeding ``np.random`` makes the result reproducible.
    x0 = np.random.randint(displacement)
    y0 = np.random.randint(displacement)
    x1 = np.random.randint(displacement)
    y1 = np.random.randint(displacement)
    rows0, cols0 = slice(y0, y0 + out_h), slice(x0, x0 + out_w)
    rows1, cols1 = slice(y1, y1 + out_h), slice(x1, x1 + out_w)

    # NOTE(review): source and destination windows are identical, so pixels keep their position and
    # only the zeroed margin differs between planes - confirm this is the intended effect.
    if mode == "slip":
        out[:, rows0, cols0, :] = img[:, rows0, cols0, :]
        m_out[:, rows0, cols0, :] = mask[:, rows0, cols0, :]
        # The chosen plane is rebuilt with the second window.
        out[z_idx, :, :, :] = 0
        out[z_idx, rows1, cols1, :] = img[z_idx, rows1, cols1, :]
        m_out[z_idx, :, :, :] = 0
        m_out[z_idx, rows1, cols1, :] = mask[z_idx, rows1, cols1, :]
    else:
        # Planes before ``z_idx`` use the first window, the remaining ones the second.
        out[:z_idx, rows0, cols0, :] = img[:z_idx, rows0, cols0, :]
        out[z_idx:, rows1, cols1, :] = img[z_idx:, rows1, cols1, :]
        m_out[:z_idx, rows0, cols0, :] = mask[:z_idx, rows0, cols0, :]
        m_out[z_idx:, rows1, cols1, :] = mask[z_idx:, rows1, cols1, :]
    return out, m_out
[docs]
def brightness(
    image: NDArray,
    brightness_factor: Tuple[float, float] = (0, 0),
) -> NDArray:
    """
    Randomly adjust brightness between a range.

    A single offset is drawn uniformly from ``brightness_factor`` and added to every element of the image.

    Parameters
    ----------
    image : Numpy array
        Image to transform; must have 3 or 4 dimensions.

    brightness_factor : tuple of 2 floats
        Range of brightness' intensity. E.g. ``(0.1, 0.3)``. The bounds are swapped if given in descending
        order. With ``(0, 0)`` the input is returned untouched.

    Returns
    -------
    image : Numpy array
        Transformed copy of the image, same shape and dtype. For integer dtypes the result is clipped to
        the dtype range before being cast back. An empty input is returned as is.

    Examples
    --------
    Calling this function with ``brightness_factor=(0.1,0.3)`` may result in:

    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_bright.png       | .. figure:: ../../../img/bright.png            |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_bright2.png      | .. figure:: ../../../img/bright2.png           |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+

    The grid is painted for visualization purposes.
    """
    assert image.ndim in (3, 4), f"Image must be 3D or 4D, got {image.shape}"

    low = float(brightness_factor[0])
    high = float(brightness_factor[1])
    disabled = low == 0.0 and high == 0.0
    if disabled:
        return image
    low, high = min(low, high), max(low, high)
    if image.size == 0:
        return image

    offset = float(np.random.uniform(low, high))

    if np.issubdtype(image.dtype, np.floating):
        shifted = image.copy()
        shifted += offset
        return shifted

    # Integer data: compute in float32, clip to the representable range and cast back.
    limits = np.iinfo(image.dtype)
    work = image.astype(np.float32) + offset
    np.clip(work, limits.min, limits.max, out=work)
    return work.astype(image.dtype, copy=False)
[docs]
def contrast(image: NDArray, contrast_factor: Tuple[float, float] = (0, 0)) -> NDArray:
    """
    Contrast augmentation.

    A single value is drawn uniformly from ``contrast_factor`` and every element of the image is
    multiplied by one plus that value.

    Parameters
    ----------
    image : Numpy array
        Image to transform; must have 3 or 4 dimensions.

    contrast_factor : tuple of 2 floats
        Range of contrast's intensity. E.g. ``(0.1, 0.3)``. The bounds are swapped if given in descending
        order. With ``(0, 0)`` the input is returned untouched.

    Returns
    -------
    image : Numpy array
        Transformed copy of the image, same shape and dtype. For integer dtypes the result is clipped to
        the dtype range before being cast back. An empty input is returned as is.

    Examples
    --------
    Calling this function with ``contrast_factor=(0.1,0.3)`` may result in:

    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_contrast.png     | .. figure:: ../../../img/contrast.png          |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_contrast2.png    | .. figure:: ../../../img/contrast2.png         |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+

    The grid is painted for visualization purposes.
    """
    assert image.ndim in (3, 4), f"Image must be 3D or 4D, got {image.shape}"

    low = float(contrast_factor[0])
    high = float(contrast_factor[1])
    disabled = low == 0.0 and high == 0.0
    if disabled:
        return image
    low, high = min(low, high), max(low, high)
    if image.size == 0:
        return image

    gain = 1.0 + float(np.random.uniform(low, high))

    if np.issubdtype(image.dtype, np.floating):
        scaled = image.copy()
        scaled *= gain
        return scaled

    # Integer data: compute in float32, clip to the representable range and cast back.
    limits = np.iinfo(image.dtype)
    work = image.astype(np.float32) * gain
    np.clip(work, limits.min, limits.max, out=work)
    return work.astype(image.dtype, copy=False)
[docs]
def missing_sections(img: NDArray, iterations: Tuple[int, int] = (30, 40), channel_prob: float = 0.5) -> NDArray:
    """
    Augment the image by blanking a thick line in a random position.

    A random straight line crossing the slice (top to bottom or left to right) is thickened by binary
    dilation, and the pixels it covers are replaced with the mean value of the slice.
    Implementation based on `PyTorch Connectomics' missing_parts.py
    <https://github.com/zudi-lin/pytorch_connectomics/blob/master/connectomics/data/augmentation/missing_parts.py>`_.

    Parameters
    ----------
    img : Numpy array
        Image to transform. A 3D array is indexed as ``(y, x, channels)`` and a 4D array as
        ``(z, y, x, channels)``.

    iterations : tuple of 2 ints, optional
        Inclusive range from which the number of dilation iterations is drawn. E.g. ``(30, 40)``.
        Both bounds may be equal to request a fixed value.

    channel_prob : float, optional
        Probability of altering each candidate slice. For 3D input the candidates are the channels; for
        4D input they are the z planes of every channel. After a slice is altered the next two candidates
        are skipped.

    Returns
    -------
    out : Numpy array
        Transformed copy of ``img``. Slices with fewer than 3 rows or columns are left unchanged.

    Examples
    --------
    Calling this function with ``iterations=(30,40)`` may result in:

    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_missing.png      | .. figure:: ../../../img/missing.png           |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+
    | .. figure:: ../../../img/orig_missing2.png     | .. figure:: ../../../img/missing2.png          |
    |   :width: 80%                                  |   :width: 80%                                  |
    |   :align: center                               |   :align: center                               |
    |                                                |                                                |
    |   Input image                                  |   Augmented image                              |
    +------------------------------------------------+------------------------------------------------+

    The grid is painted for visualization purposes.
    """
    assert img.ndim in (3, 4), f"Image must be 3D or 4D, got shape {img.shape}"
    assert len(iterations) == 2 and iterations[0] <= iterations[1], f"Invalid iterations range: {iterations}"

    # Inclusive upper bound, as for the other ranges of this module; this also allows equal bounds,
    # which ``randint(low, high)`` rejects.
    n_dilations = int(np.random.randint(iterations[0], iterations[1] + 1))
    out = img.copy()

    # The two axes right before the channel axis are (y, x) in both layouts.
    height, width = img.shape[-3], img.shape[-2]

    def _line_mask() -> NDArray:
        """Build a boolean (y, x) mask covering a dilated random line."""
        if height < 3 or width < 3:
            # Too small to sample interior end points: nothing is removed.
            return np.zeros((height, width), dtype=bool)

        if np.random.rand() < 0.5:
            # From the first row to the last one, at random interior columns.
            r0, c0 = 0, np.random.randint(1, width - 1)
            r1, c1 = height - 1, np.random.randint(1, width - 1)
        else:
            # From the first column to the last one, at random interior rows.
            r0, c0 = np.random.randint(1, height - 1), 0
            r1, c1 = np.random.randint(1, height - 1), width - 1

        mask = np.zeros((height, width), dtype=bool)
        rr, cc = line(r0, c0, r1, c1)
        mask[np.clip(rr, 0, height - 1), np.clip(cc, 0, width - 1)] = True

        # Thicken the line into a band.
        return binary_dilation_scipy(mask, iterations=n_dilations)  # type: ignore

    def _random_masks(count: int) -> dict:
        """Pick candidate indices at random and attach a line mask to each one."""
        masks = {}
        idx = 0
        while idx < count:
            if np.random.rand() < channel_prob:
                masks[idx] = _line_mask()
                idx += 2  # leave a gap of two candidates after an altered one
            idx += 1
        return masks

    if img.ndim == 3:
        for c, mask in _random_masks(img.shape[-1]).items():
            plane = out[..., c]  # view into ``out``
            plane[mask] = plane.mean()
    else:
        depth, channels = img.shape[0], img.shape[-1]
        for c in range(channels):
            for z, mask in _random_masks(depth).items():
                plane = out[z, :, :, c]  # view into ``out``
                plane[mask] = plane.mean()

    return out
[docs]
def shuffle_channels(img: NDArray) -> NDArray:
    """
    Augment the image by randomly permuting the entries of its last (channel) axis.

    Parameters
    ----------
    img : 3D/4D Numpy array
        Image to transform. The last axis is treated as the channel axis.

    Returns
    -------
    out : 3D/4D Numpy array
        New array with the same shape as ``img`` whose channels appear in a random order.

    Examples
    --------
    +-----------------------------------------------+-----------------------------------------------+
    | .. figure:: ../../../img/orig_chshuffle.png   | .. figure:: ../../../img/chshuffle.png        |
    |   :width: 80%                                 |   :width: 80%                                 |
    |   :align: center                              |   :align: center                              |
    |                                               |                                               |
    |   Input image                                 |   Augmented image                             |
    +-----------------------------------------------+-----------------------------------------------+

    The grid is painted for visualization purposes.
    """
    assert img.ndim in (3, 4), f"Image must be 3D or 4D, got {img.shape}"

    n_channels = img.shape[-1]
    # One draw from NumPy's global RNG decides the new channel order.
    order = np.random.permutation(n_channels)
    shuffled = img[..., order]
    return shuffled
[docs]
def grayscale(img: NDArray) -> NDArray:
    """
    Augment the image by replacing every channel with the mean over the three channels.

    Parameters
    ----------
    img : 3D/4D Numpy array
        Image to transform. The last axis must hold exactly 3 channels.

    Returns
    -------
    out : 3D/4D Numpy array
        Array with the same shape as ``img`` in which the three channels are identical and equal to the
        per-pixel channel mean. The dtype is the one produced by ``np.mean``.

    Raises
    ------
    ValueError
        If the last axis of ``img`` does not have 3 entries.

    Examples
    --------
    +-----------------------------------------------+-----------------------------------------------+
    | .. figure:: ../../../img/orig_grayscale.png   | .. figure:: ../../../img/grayscale.png        |
    |   :width: 80%                                 |   :width: 80%                                 |
    |   :align: center                              |   :align: center                              |
    |                                               |                                               |
    |   Input image                                 |   Augmented image                             |
    +-----------------------------------------------+-----------------------------------------------+

    The grid is painted for visualization purposes.
    """
    assert img.ndim in [3, 4], f"Image must be 3D or 4D, got shape {img.shape}"

    n_channels = img.shape[-1]
    if n_channels != 3:
        raise ValueError(
            "Image is supposed to have 3 channels (RGB). Provided {} image shape instead".format(img.shape)
        )

    # Collapse the channel axis to its mean, then replicate it back to three channels.
    luminance = np.mean(img, axis=-1, keepdims=True)
    return np.repeat(luminance, 3, axis=-1)
[docs]
def GridMask(
    img: NDArray,
    z_size: int,
    ratio: float = 0.6,
    d_range: Tuple[float, ...] = (30.0, 60.0),
    rotate: int = 1,
    invert: bool = False,
) -> NDArray:
    """
    Apply GridMask data augmentation presented in `GridMask Data Augmentation <https://arxiv.org/abs/2001.04086v1>`_.

    A regular grid of stripes is built, rotated by a random angle and multiplied with the image, so that
    grid-like regions are zeroed. Code adapted from
    `<https://github.com/dvlab-research/GridMask/blob/master/imagenet_grid/utils/grid.py>`_.

    Parameters
    ----------
    img : Numpy array
        Image to transform. 3D inputs are read as ``(y, x, channels)``; for 4D inputs the mask is built for
        axes 1 and 2 and broadcast over the first and last axes (i.e. ``(z, y, x, channels)``).

    z_size : int
        Only used for 4D inputs. ``-1`` applies the mask to every slice of the first axis; any other value
        applies it to one random contiguous block of slices.

    ratio : float, optional
        Fraction of each grid period covered by a stripe. Must be in ``(0, 1)``.

    d_range : tuple of floats, optional
        ``d_range[0]`` and ``d_range[1]`` bound (half-open) the grid period ``d`` of the original paper. If at
        least four values are given, ``d_range[2]`` and ``d_range[3]`` bound the size of the slice block used
        when ``z_size != -1``.

    rotate : float, optional
        Rotation of the mask in GridMask. Needs to be between ``[0,1]`` where 1 is 360 degrees.

    invert : bool, optional
        Whether to invert the mask.

    Returns
    -------
    out : Numpy array
        Transformed image with the same shape as ``img``.

    Examples
    --------
    Calling this function with the default settings may result in:

    +----------------------------------------------+----------------------------------------------+
    | .. figure:: ../../../img/orig_GridMask.png   | .. figure:: ../../../img/GridMask.png        |
    |   :width: 80%                                |   :width: 80%                                |
    |   :align: center                             |   :align: center                             |
    |                                              |                                              |
    |   Input image                                |   Augmented image                            |
    +----------------------------------------------+----------------------------------------------+

    The grid is painted for visualization purposes.
    """
    assert img.ndim in [3, 4], f"Image must be 3D or 4D, got shape {img.shape}"
    assert 0 <= rotate <= 1, "Rotate should be between 0 and 1. Provided {}".format(rotate)
    assert 0 < ratio < 1, "Ratio should be between 0 and 1. Provided {}".format(ratio)

    is_volume = img.ndim == 4
    if is_volume:
        h, w = img.shape[1], img.shape[2]
    else:
        h, w = img.shape[0], img.shape[1]

    # Side of the smallest square that still covers the image after any rotation (its diagonal).
    side = int(math.ceil(math.sqrt(h * h + w * w)))

    # Grid period and stripe thickness within each period.
    period = np.random.randint(int(d_range[0]), int(d_range[1]))
    band = int(math.ceil(period * ratio))

    grid = np.ones((side, side), np.float32)
    row_offset = np.random.randint(period)
    col_offset = np.random.randint(period)

    def _bands(offset: int):
        """Yield the clamped ``(start, stop)`` of every non-empty stripe for the given phase offset."""
        for k in range(-1, side // period + 1):
            first = period * k + offset
            lo = max(min(first, side), 0)
            hi = max(min(first + band, side), 0)
            if lo < hi:
                yield lo, hi

    for lo, hi in _bands(row_offset):
        grid[lo:hi, :] = 0
    for lo, hi in _bands(col_offset):
        grid[:, lo:hi] = 0

    # ``rotate`` in [0, 1] is a fraction of 360 degrees; the actual angle is drawn below that cap.
    degree_cap = max(1, int(round(rotate * 360)))
    turn = np.random.randint(degree_cap)
    rotated = Image.fromarray(np.uint8(grid * 255)).rotate(turn)
    grid = np.asarray(rotated, dtype=np.float32) / 255.0

    # Center-crop the oversized square back to the image footprint.
    top = (side - h) // 2
    left = (side - w) // 2
    grid = grid[top : top + h, left : left + w]

    if not invert:
        grid = 1.0 - grid

    if not is_volume:
        # Broadcast the 2D mask over the channel axis.
        return img * grid[:, :, None]

    depth = img.shape[0]
    grid_4d = grid[None, :, :, None]

    if z_size == -1:
        # Same mask on every slice.
        return img.copy() * grid_4d

    # Mask only one random contiguous block of slices.
    if len(d_range) >= 4:
        z_min = max(1, int(d_range[2]))
        z_max = min(max(z_min, int(d_range[3])), depth)
    else:
        z_min, z_max = 1, depth
    block = min(np.random.randint(z_min, z_max + 1), depth)
    start = np.random.randint(0, max(1, depth - block + 1))
    stop = start + block

    out = img.copy()
    out[start:stop, :, :, :] = out[start:stop, :, :, :] * grid_4d
    return out
[docs]
def random_crop_pair(
    image: NDArray,
    mask: NDArray,
    random_crop_size: Tuple[int, ...],
    val: bool = False,
    draw_prob_map_points: bool = False,
    img_prob: Optional[NDArray] = None,
    weight_map: Optional[NDArray] = None,
    scale: Tuple[int, ...] = (1, 1),
) -> Union[
    Tuple[NDArray, NDArray],
    Tuple[NDArray, NDArray, NDArray],
    Tuple[NDArray, NDArray, int, int, int, int],
]:
    """
    Apply random crop for an image and its mask.

    No crop is done in those dimensions that ``random_crop_size`` is greater than
    the input image shape in those dimensions. For instance, if an input image is ``400x150`` and
    ``random_crop_size`` is ``224x224`` the resulting image will be ``224x150``.

    Parameters
    ----------
    image : Numpy 3D array
        Image. E.g. ``(y, x, channels)``.

    mask : Numpy 3D array
        Image mask. E.g. ``(y, x, channels)``. Its crop window is the image window multiplied by ``scale``.

    random_crop_size : 2 int tuple
        Size of the crop. E.g. ``(height, width)``.

    val : bool, optional
        If the image provided is going to be used in the validation data. This forces to crop from the origin,
        e. g. ``(0, 0)`` point.

    draw_prob_map_points : bool, optional
        To return the pixel chosen to be the center of the crop.

    img_prob : Numpy 3D array, optional
        Probability of each pixel to be chosen as the center of the crop. E.g. ``(y, x, channels)``. Its
        values are passed to ``np.random.choice`` as ``p`` and therefore must sum to 1.

    weight_map : Numpy 3D array, optional
        Weight map of the given image. E.g. ``(y, x, channels)``. Cropped with the image window.

    scale : tuple of 2 ints, optional
        Scale factor the second image given. E.g. ``(2,2)``.

    Returns
    -------
    img : 3D Numpy array
        Crop of the given image. E.g. ``(y, x, channels)``.

    mask : 3D Numpy array
        Crop of the given mask. E.g. ``(y, x, channels)``.

    weight_map : 3D Numpy array, optional
        Crop of the given image's weight map. Returned only when ``weight_map`` is given and
        ``draw_prob_map_points`` is ``False``.

    oy : int, optional
        Y coordinate in the complete image of the chosen central pixel. Returned, together with ``ox``,
        ``y`` and ``x``, only when ``draw_prob_map_points`` is ``True``.

    ox : int, optional
        X coordinate in the complete image of the chosen central pixel.

    y : int, optional
        Y coordinate in the complete image where the crop starts.

    x : int, optional
        X coordinate in the complete image where the crop starts.

    Raises
    ------
    ValueError
        If ``scale`` is not all ones and the mask crop is not exactly ``scale`` times the image crop.
    """
    assert image.ndim == 3, f"Image must be 3D, got {image.shape}"
    assert mask.ndim == 3, f"Mask must be 3D, got {mask.shape}"
    assert len(random_crop_size) == 2, f"Random crop size must have 2 elements, got {random_crop_size}"

    # The weight map is a separate argument; the image itself must never be unpacked.
    img = image
    height, width = img.shape[0], img.shape[1]
    dy, dx = random_crop_size[0], random_crop_size[1]

    if val:
        y, x, oy, ox = 0, 0, 0, 0
    elif img_prob is not None:
        # Draw the crop center from the given probability map.
        prob = img_prob.ravel()
        index = np.random.choice(prob.size, size=1, p=prob)
        coordinates = np.unravel_index(index, img_prob.shape)
        oy = int(coordinates[0][0])
        ox = int(coordinates[1][0])
        # Turn the center into a crop origin and clamp it so the whole window stays inside the
        # image (also for odd crop sizes and for crops larger than the image).
        y = min(max(oy - dy // 2, 0), max(height - dy, 0))
        x = min(max(ox - dx // 2, 0), max(width - dx, 0))
    else:
        oy, ox = 0, 0
        x = np.random.randint(0, width - dx + 1) if width - dx + 1 > 0 else 0
        y = np.random.randint(0, height - dy + 1) if height - dy + 1 > 0 else 0

    img_crop = img[y : (y + dy), x : (x + dx)]
    mask_crop = mask[y * scale[0] : (y + dy) * scale[0], x * scale[1] : (x + dx) * scale[1]]

    # Super-resolution check: every spatial axis of the mask crop must be the scaled image crop.
    if any(s != 1 for s in scale):
        img_out_shape = img_crop.shape
        mask_out_shape = mask_crop.shape
        expected = (img_out_shape[0] * scale[0], img_out_shape[1] * scale[1])
        if any(e != m for e, m in zip(expected, mask_out_shape)):
            raise ValueError(
                "Images can not be cropped to a PATCH_SIZE of {}. Inputs: LR image shape={} "
                "and HR image shape={}. When cropping the output shapes are {} and {}, for LR and HR images respectively. "
                "Try to reduce DATA.PATCH_SIZE".format(
                    random_crop_size,
                    img.shape,
                    mask.shape,
                    img_out_shape,
                    mask_out_shape,
                )
            )

    if draw_prob_map_points:
        return img_crop, mask_crop, oy, ox, y, x
    if weight_map is not None:
        return img_crop, mask_crop, weight_map[y : (y + dy), x : (x + dx)]
    return img_crop, mask_crop
[docs]
def random_3D_crop_pair(
    image: NDArray,
    mask: NDArray,
    random_crop_size: Tuple[int, ...],
    val: bool = False,
    img_prob: Optional[NDArray] = None,
    weight_map: Optional[NDArray] = None,
    draw_prob_map_points: bool = False,
    scale: Tuple[int, ...] = (1, 1, 1),
) -> Union[
    Tuple[NDArray, NDArray],
    Tuple[NDArray, NDArray, NDArray],
    Tuple[NDArray, NDArray, int, int, int, int, int, int],
]:
    """
    Extract a random 3D patch from the given image and mask.

    No crop is done in those dimensions that ``random_crop_size`` is
    greater than the input image shape in those dimensions. For instance, if an input image is
    ``10x400x150`` and ``random_crop_size`` is ``10x224x224`` the resulting image will be ``10x224x150``.

    Parameters
    ----------
    image : 4D Numpy array
        Data to extract the patch from. E.g. ``(z, y, x, channels)``.

    mask : 4D Numpy array
        Data mask to extract the patch from. E.g. ``(z, y, x, channels)``. Its crop window is the image
        window multiplied by ``scale``.

    random_crop_size : 3D int tuple
        Shape of the patches to create. E.g. ``(z, y, x)``.

    val : bool, optional
        If the image provided is going to be used in the validation data. This forces to crop from the
        origin, e.g. ``(0, 0, 0)`` point.

    img_prob : Numpy 4D array, optional
        Probability of each voxel to be chosen as the center of the crop. E.g. ``(z, y, x, channels)``. Its
        values are passed to ``np.random.choice`` as ``p`` and therefore must sum to 1.

    weight_map : Numpy 4D array, optional
        Weight map of the given image. E.g. ``(z, y, x, channels)``. Cropped with the image window.

    draw_prob_map_points : bool, optional
        To return the voxel chosen to be the center of the crop.

    scale : tuple of 3 ints, optional
        Scale factor the second image given. E.g. ``(2,4,4)``.

    Returns
    -------
    img : 4D Numpy array
        Crop of the given image. E.g. ``(z, y, x, channels)``.

    mask : 4D Numpy array
        Crop of the given mask. E.g. ``(z, y, x, channels)``.

    weight_map : 4D Numpy array, optional
        Crop of the given image's weight map. Returned only when ``weight_map`` is given and
        ``draw_prob_map_points`` is ``False``.

    oz : int, optional
        Z coordinate in the complete image of the chosen central voxel. Returned, together with ``oy``,
        ``ox``, ``z``, ``y`` and ``x``, only when ``draw_prob_map_points`` is ``True``.

    oy : int, optional
        Y coordinate in the complete image of the chosen central voxel.

    ox : int, optional
        X coordinate in the complete image of the chosen central voxel.

    z : int, optional
        Z coordinate in the complete image where the crop starts.

    y : int, optional
        Y coordinate in the complete image where the crop starts.

    x : int, optional
        X coordinate in the complete image where the crop starts.

    Raises
    ------
    ValueError
        If ``scale`` is not all ones and the mask crop is not exactly ``scale`` times the image crop.
    """
    assert image.ndim == 4, f"Image must be 4D, got {image.shape}"
    assert mask.ndim == 4, f"Mask must be 4D, got {mask.shape}"
    assert len(random_crop_size) == 3, f"Random crop size must have 3 elements, got {random_crop_size}"

    # The weight map is a separate argument; the volume itself must never be unpacked.
    vol = image
    deep, cols, rows = vol.shape[0], vol.shape[1], vol.shape[2]
    dz, dy, dx = random_crop_size

    if val:
        z, y, x, oz, oy, ox = 0, 0, 0, 0, 0, 0
    elif img_prob is not None:
        # Draw the crop center from the given probability map.
        prob = img_prob.ravel()
        index = np.random.choice(prob.size, size=1, p=prob)
        coordinates = np.unravel_index(index, img_prob.shape)
        # Each entry is a length-1 array; index it explicitly instead of relying on implicit
        # array-to-scalar conversion.
        oz = int(coordinates[0][0])
        oy = int(coordinates[1][0])
        ox = int(coordinates[2][0])
        # Turn the center into a crop origin and clamp it so the whole window stays inside the
        # volume (also for odd crop sizes and for crops larger than the volume).
        z = min(max(oz - dz // 2, 0), max(deep - dz, 0))
        y = min(max(oy - dy // 2, 0), max(cols - dy, 0))
        x = min(max(ox - dx // 2, 0), max(rows - dx, 0))
    else:
        oz, oy, ox = 0, 0, 0
        z = np.random.randint(0, deep - dz + 1) if deep - dz + 1 > 0 else 0
        y = np.random.randint(0, cols - dy + 1) if cols - dy + 1 > 0 else 0
        x = np.random.randint(0, rows - dx + 1) if rows - dx + 1 > 0 else 0

    vol_crop = vol[z : (z + dz), y : (y + dy), x : (x + dx)]
    # The mask window is always the scaled image window, whatever is returned alongside it.
    mask_crop = mask[
        z * scale[0] : (z + dz) * scale[0],
        y * scale[1] : (y + dy) * scale[1],
        x * scale[2] : (x + dx) * scale[2],
    ]

    # Super-resolution check: every spatial axis of the mask crop must be the scaled image crop.
    if any(s != 1 for s in scale):
        img_out_shape = vol_crop.shape
        mask_out_shape = mask_crop.shape
        expected = (
            img_out_shape[0] * scale[0],
            img_out_shape[1] * scale[1],
            img_out_shape[2] * scale[2],
        )
        if any(e != m for e, m in zip(expected, mask_out_shape)):
            raise ValueError(
                "Images can not be cropped to a PATCH_SIZE of {}. Inputs: LR image shape={} "
                "and HR image shape={}. When cropping the output shapes are {} and {}, for LR and HR images respectively. "
                "Try to reduce DATA.PATCH_SIZE".format(
                    random_crop_size,
                    vol.shape,
                    mask.shape,
                    img_out_shape,
                    mask_out_shape,
                )
            )

    if draw_prob_map_points:
        return vol_crop, mask_crop, oz, oy, ox, z, y, x
    if weight_map is not None:
        return vol_crop, mask_crop, weight_map[z : (z + dz), y : (y + dy), x : (x + dx)]
    return vol_crop, mask_crop
[docs]
def random_crop_single(
    image: NDArray,
    random_crop_size: Tuple[int, ...],
    val: bool = False,
    draw_prob_map_points: bool = False,
    weight_map: Optional[NDArray] = None,
) -> Union[
    NDArray,
    Tuple[NDArray, NDArray],
    Tuple[NDArray, int, int, int, int],
]:
    """
    Random crop for a single image.

    No crop is done in those dimensions that ``random_crop_size`` is greater than
    the input image shape in those dimensions. For instance, if an input image is ``400x150`` and
    ``random_crop_size`` is ``224x224`` the resulting image will be ``224x150``.

    Parameters
    ----------
    image : Numpy 3D array
        Image. E.g. ``(y, x, channels)``.

    random_crop_size : 2 int tuple
        Size of the crop. E.g. ``(y, x)``.

    val : bool, optional
        If the image provided is going to be used in the validation data. This forces to crop from the origin,
        e. g. ``(0, 0)`` point.

    draw_prob_map_points : bool, optional
        To also return the crop center and origin coordinates. This function never samples a center, so the
        center coordinates are always ``0``.

    weight_map : Numpy 3D array, optional
        Weight map of the given image. E.g. ``(y, x, channels)``. Cropped with the image window.

    Returns
    -------
    img : 3D Numpy array
        Crop of the given image. E.g. ``(y, x, channels)``.

    weight_map : 3D Numpy array, optional
        Crop of the given image's weight map. Returned only when ``weight_map`` is given and
        ``draw_prob_map_points`` is ``False``.

    ox : int, optional
        X coordinate of the central pixel (always ``0`` here). When ``draw_prob_map_points`` is ``True`` the
        return value is ``(img, ox, oy, x, y)``, in that order.

    oy : int, optional
        Y coordinate of the central pixel (always ``0`` here).

    x : int, optional
        X coordinate in the complete image where the crop starts.

    y : int, optional
        Y coordinate in the complete image where the crop starts.
    """
    assert image.ndim == 3, f"Image must be 3D, got {image.shape}"

    # The weight map is a separate argument; the image itself must never be unpacked.
    img = image
    height, width = img.shape[0], img.shape[1]
    dy, dx = random_crop_size

    oy, ox = 0, 0
    if val:
        x, y = 0, 0
    else:
        x = np.random.randint(0, width - dx + 1) if width - dx + 1 > 0 else 0
        y = np.random.randint(0, height - dy + 1) if height - dy + 1 > 0 else 0

    crop = img[y : (y + dy), x : (x + dx)]
    if draw_prob_map_points:
        return crop, ox, oy, x, y
    if weight_map is not None:
        return crop, weight_map[y : (y + dy), x : (x + dx)]
    return crop
[docs]
def random_3D_crop_single(
    image: NDArray,
    random_crop_size: Tuple[int, ...],
    val: bool = False,
    draw_prob_map_points: bool = False,
    weight_map: Optional[NDArray] = None,
) -> Union[
    NDArray,
    Tuple[NDArray, NDArray],
    Tuple[NDArray, int, int, int, int, int, int],
]:
    """
    Random crop for a single volume.

    No crop is done in those dimensions that ``random_crop_size`` is greater than
    the input image shape in those dimensions. For instance, if an input image is ``50x400x150`` and
    ``random_crop_size`` is ``30x224x224`` the resulting image will be ``30x224x150``.

    Parameters
    ----------
    image : Numpy 3D/4D array
        Volume. E.g. ``(z, y, x, channels)``. A ``(z, y, x)`` array without channel axis is accepted too.

    random_crop_size : 3 int tuple
        Size of the crop. E.g. ``(z, y, x)``.

    val : bool, optional
        If the image provided is going to be used in the validation data. This forces to crop from the origin,
        e. g. ``(0, 0, 0)`` point.

    draw_prob_map_points : bool, optional
        To also return the crop center and origin coordinates. This function never samples a center, so the
        center coordinates are always ``0``.

    weight_map : Numpy array, optional
        Weight map of the given image. E.g. ``(z, y, x, channels)``. Cropped with the image window.

    Returns
    -------
    img : Numpy array
        Crop of the given image. E.g. ``(z, y, x, channels)``.

    weight_map : Numpy array, optional
        Crop of the given image's weight map. Returned only when ``weight_map`` is given and
        ``draw_prob_map_points`` is ``False``.

    oz : int, optional
        Z coordinate of the central voxel (always ``0`` here). When ``draw_prob_map_points`` is ``True`` the
        return value is ``(img, oz, oy, ox, z, y, x)``.

    oy : int, optional
        Y coordinate of the central voxel (always ``0`` here).

    ox : int, optional
        X coordinate of the central voxel (always ``0`` here).

    z : int, optional
        Z coordinate in the complete image where the crop starts.

    y : int, optional
        Y coordinate in the complete image where the crop starts.

    x : int, optional
        X coordinate in the complete image where the crop starts.
    """
    # The documented input is a 4D (z, y, x, channels) volume; 3D (z, y, x) input keeps working.
    assert image.ndim in (3, 4), f"Image must be 3D or 4D, got {image.shape}"

    # The weight map is a separate argument; the volume itself must never be unpacked.
    img = image
    deep, cols, rows = img.shape[0], img.shape[1], img.shape[2]
    dz, dy, dx = random_crop_size

    oz, oy, ox = 0, 0, 0
    if val:
        z, y, x = 0, 0, 0
    else:
        z = np.random.randint(0, deep - dz + 1) if deep - dz + 1 > 0 else 0
        y = np.random.randint(0, cols - dy + 1) if cols - dy + 1 > 0 else 0
        x = np.random.randint(0, rows - dx + 1) if rows - dx + 1 > 0 else 0

    crop = img[z : (z + dz), y : (y + dy), x : (x + dx)]
    if draw_prob_map_points:
        return crop, oz, oy, ox, z, y, x
    if weight_map is not None:
        return crop, weight_map[z : (z + dz), y : (y + dy), x : (x + dx)]
    return crop
[docs]
def center_crop_single(
    img: NDArray,
    crop_shape: Tuple[int, ...],
) -> NDArray:
    """
    Extract the central patch from a single image.

    Parameters
    ----------
    img : 3D/4D array
        Image. E.g. ``(y, x, channels)`` or ``(z, y, x, channels)``.

    crop_shape : 2/3 int tuple
        Size of the crop. E.g. ``(y, x)`` or ``(z, y, x)``. Only the leading entries matching the spatial
        axes of ``img`` are read; extra trailing entries are ignored.

    Returns
    -------
    img : 3D/4D Numpy array
        Center crop of the given image. E.g. ``(y, x, channels)`` or ``(z, y, x, channels)``. Axes that are
        smaller than the requested size are returned whole.
    """
    assert img.ndim in [3, 4], f"Image must be 3D or 4D, got shape {img.shape}"

    # Every axis except the last (channels) is cropped around its midpoint.
    n_spatial = img.ndim - 1
    window = []
    for axis in range(n_spatial):
        extent = crop_shape[axis]
        begin = max(img.shape[axis] // 2 - extent // 2, 0)
        window.append(slice(begin, begin + extent))
    return img[tuple(window)]
[docs]
def resize_img(img: NDArray, shape: Tuple[int, ...]) -> NDArray:
    """
    Resize input image to given shape using linear interpolation.

    Parameters
    ----------
    img : 3D/4D Numpy array
        Data to resize. E.g. ``(y, x, channels)`` for ``2D`` or ``(z, y, x, channels)`` for ``3D``.

    shape : 2D/3D int tuple
        Shape to resize the image to. E.g. ``(y, x)`` for ``2D`` ``(z, y, x)`` for ``3D``.

    Returns
    -------
    img : 3D/4D Numpy array
        Resized image. E.g. ``(y, x, channels)`` for ``2D`` or ``(z, y, x, channels)`` for ``3D``.
    """
    assert img.ndim in [3, 4], f"Image must be 3D or 4D, got shape {img.shape}"

    # The target shape lists spatial axes only, so it is one entry shorter than the image rank.
    rank_matches = (img.ndim == 3 and len(shape) == 2) or (img.ndim == 4 and len(shape) == 3)
    assert rank_matches, (
        "Shape is supposed to have 2 elements for 2D images and 3 elements for 3D images. "
        "Provided {} shape for {} image instead".format(shape, img.shape)
    )

    resize_options = dict(
        order=1,
        mode="reflect",
        clip=True,
        preserve_range=True,
        anti_aliasing=True,
    )
    return resize(img, shape, **resize_options)
[docs]
def rotation(
    img: NDArray,
    mask: Optional[NDArray] = None,
    heat: Optional[NDArray] = None,
    angles: Union[Tuple[int, int], List[int]] = [],
    mode: str = "reflect",
    mask_type: str = "as_mask",
) -> Union[
    NDArray,
    Tuple[NDArray, Optional[NDArray], Optional[NDArray]],
]:
    """
    Rotate ``img`` and, when given, ``mask`` and ``heat`` by the same random angle.

    The rotation is applied in the plane of axes ``(1, 0)`` for 3D inputs and ``(2, 1)`` for 4D inputs, and
    the output keeps the input shape.

    Parameters
    ----------
    img : 3D/4D Numpy array
        Image to rotate. Rotated with linear interpolation.

    mask : 3D/4D Numpy array, optional
        Mask to rotate.

    heat : 3D/4D Numpy array, optional
        Heatmap (float mask) to rotate. Rotated with linear interpolation.

    angles : tuple of 2 ints or list of ints, optional
        A 2-tuple is a ``(min, max)`` range in degrees sampled uniformly (swapped if reversed); a list is
        a set of candidate angles from which one is picked, e.g. ``[90, 180, 360]``. When empty, the angle is
        sampled uniformly in ``[0, 360)``.

    mode : str, optional
        How to fill up the new values created. Options: ``constant``, ``reflect``, ``wrap``, ``symmetric``.
        ``symmetric`` is forwarded to SciPy as ``mirror``; every other value is forwarded unchanged.

    mask_type : str, optional
        How to treat the mask during interpolation. Either as "as_mask" (order 0) or "as_image" (order 1).

    Returns
    -------
    img : 3D/4D Numpy array
        Rotated image. Returned alone when neither ``mask`` nor ``heat`` is given.

    mask : 3D/4D Numpy array, optional
        Rotated mask, or ``None`` if no mask was given. Part of the 3-tuple ``(img, mask, heat)`` returned
        when ``mask`` or ``heat`` is given.

    heat : 3D/4D Numpy array, optional
        Rotated heatmap, or ``None`` if no heatmap was given.

    Raises
    ------
    ValueError
        If ``angles`` is non-empty and neither a tuple nor a list.
    """
    assert img.ndim in (3, 4), f"Image must be 3D or 4D, got shape {img.shape}"
    if mask is not None:
        assert mask.ndim in (3, 4), f"Mask must be 3D or 4D, got shape {mask.shape}"
    if heat is not None:
        assert heat.ndim in (3, 4), f"Heat must be 3D or 4D, got shape {heat.shape}"

    # Sample the rotation angle (degrees) once; it is shared by image, mask and heatmap.
    if not angles:
        angle = float(np.random.uniform(0.0, 360.0))
    elif isinstance(angles, tuple):
        assert len(angles) == 2, "If a tuple is provided it must have length 2"
        first, second = float(angles[0]), float(angles[1])
        angle = float(np.random.uniform(min(first, second), max(first, second)))
    elif isinstance(angles, list):
        angle = float(random.choice(angles))
    else:
        raise ValueError("angles must be a list or a tuple")

    # SciPy has no "symmetric" boundary mode; it is forwarded as "mirror".
    # NOTE(review): SciPy's own "reflect" duplicates the edge sample (what np.pad calls "symmetric"),
    # so the two names may be swapped relative to np.pad — confirm this mapping is intended.
    ndimage_mode = mode if mode != "symmetric" else "mirror"

    # The rotation plane is chosen from the image rank and reused for mask and heatmap.
    plane = (2, 1) if img.ndim == 4 else (1, 0)

    def _spin(data: NDArray, interp_order: int) -> NDArray:
        """Rotate ``data`` in float32 and cast the result back to its original dtype."""
        reduce_axes = tuple(range(data.ndim - 1))
        channel_min = np.min(data, axis=reduce_axes)
        channel_max = np.max(data, axis=reduce_axes)
        source_dtype = data.dtype
        turned = rotate(
            data.astype(np.float32, copy=False),
            angle=angle,
            axes=plane,
            reshape=False,
            order=interp_order,
            mode=ndimage_mode,
        )
        if not np.issubdtype(source_dtype, np.floating):
            # Non-float data is kept inside its original per-channel value range before casting.
            np.clip(turned, channel_min, channel_max, out=turned)
        return turned.astype(source_dtype, copy=False)

    img_out = _spin(img, 1)

    mask_out = None
    if mask is not None:
        mask_out = _spin(mask, 0 if mask_type == "as_mask" else 1)

    heat_out = None
    if heat is not None:
        heat_out = _spin(heat, 1)

    if mask is None and heat is None:
        return img_out
    return img_out, mask_out, heat_out
[docs]
def zoom(
    img: NDArray,
    zoom_range: Tuple[float, ...],
    mask: Optional[NDArray] = None,
    heat: Optional[NDArray] = None,
    zoom_in_z: bool = False,
    mode: str = "reflect",
    mask_type: str = "as_mask",
) -> Union[
    NDArray,
    Tuple[NDArray, Optional[NDArray], Optional[NDArray]],
]:
    """
    Apply zoom to input ``image`` and ``mask`` (if provided), keeping the original shapes.

    A zoom factor is sampled uniformly from ``zoom_range``. The arrays are resized by that factor and then
    center-cropped (factor ``>= 1``) or padded (factor ``< 1``) back to their original shape.

    Parameters
    ----------
    img : 3D/4D Numpy array
        Image to zoom. Axes 0 and 1 are scaled by the sampled factor. For 4D inputs axis 2 is scaled only
        when ``zoom_in_z`` is ``True``, i.e. ``(y, x, z, channels)``. The last axis is never scaled.

    zoom_range : tuple of floats
        Defines minimum and maximum factors to scale the images. E.g. (0.8, 1.2).

    mask : 3D/4D Numpy array, optional
        Mask to zoom. Same layout as ``img``.

    heat : 3D/4D Numpy array, optional
        Heatmap (float mask) to zoom. Same layout as ``img``.

    zoom_in_z: bool, optional
        Whether to apply or not zoom in Z axis.

    mode : str, optional
        How to fill up the new values created. Options: ``constant``, ``reflect``, ``wrap``, ``symmetric``.
        Used both for resizing and for padding.

    mask_type : str, optional
        How to treat the mask during interpolation. Either as "as_mask" (order 0) or "as_image" (order 1).
        With "as_mask" no anti-aliasing filter is applied, so label values are never blended.

    Returns
    -------
    img : 3D/4D Numpy array
        Zoomed image. Returned alone when ``mask`` is ``None``.

    mask : 3D/4D Numpy array, optional
        Zoomed mask. Returned, inside the 3-tuple ``(img, mask, heat)``, when ``mask`` is provided.

    heat : 3D/4D Numpy array, optional
        Zoomed heatmap (or ``None``). Returned only if ``mask`` is provided.

    Notes
    -----
    When the sampled factor leaves every target shape equal to the current one, the inputs are returned
    untouched (same objects, same dtype) instead of being resampled.
    """
    assert img.ndim in [3, 4], f"Image must be 3D or 4D, got shape {img.shape}"
    if mask is not None:
        assert mask.ndim in [3, 4], f"Mask must be 3D or 4D, got shape {mask.shape}"
    if heat is not None:
        assert heat.ndim in [3, 4], f"Heatmap must be 3D or 4D, got shape {heat.shape}"
    assert len(zoom_range) == 2, f"Zoom range is supposed to have 2 elements but provided {zoom_range} instead"
    assert zoom_range[0] <= zoom_range[1], "First element of zoom range must be lower than the second one"
    assert zoom_range[0] > 0, "Zoom range values must be greater than 0"

    zoom_selected = random.uniform(zoom_range[0], zoom_range[1])
    mask_order = 0 if mask_type == "as_mask" else 1

    # Per-axis scale factors for the spatial axes; the rank of ``img`` decides how many there are.
    factors = [zoom_selected, zoom_selected]
    if img.ndim == 4:
        factors.append(zoom_selected if zoom_in_z else 1)

    def _zoomed_shape(arr: NDArray) -> Tuple[int, ...]:
        """Return the shape of ``arr`` after scaling its spatial axes (channel axis untouched)."""
        return tuple(int(arr.shape[i] * f) for i, f in enumerate(factors)) + (arr.shape[-1],)

    def _center_pad(target: Tuple[int, ...], current: Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
        """Return ``np.pad`` widths that grow ``current`` back to ``target``, centered, channels untouched."""
        widths = []
        for wanted, got in zip(target[:-1], current[:-1]):
            missing = wanted - got
            # The floor half goes before and the ceil half after, so both always add up to ``missing``.
            widths.append((missing // 2, missing - missing // 2))
        widths.append((0, 0))
        return tuple(widths)

    img_shape = _zoomed_shape(img)
    mask_shape = _zoomed_shape(mask) if mask is not None else None

    # Compare tuples with tuples: a list never equals ``ndarray.shape``, which used to force a
    # resample even when the zoom left every shape unchanged.
    shape_changes = img_shape != tuple(img.shape) or (mask is not None and mask_shape != tuple(mask.shape))

    if shape_changes:
        img_orig_shape = img.shape
        img = resize(
            img,
            img_shape,
            order=1,
            mode=mode,
            clip=True,
            preserve_range=True,
            anti_aliasing=True,
        )
        if mask is not None:
            mask_orig_shape = mask.shape
            mask = resize(
                mask,
                mask_shape,
                order=mask_order,
                mode=mode,
                clip=True,
                preserve_range=True,
                # Anti-aliasing blurs the input before downscaling, which would mix label values
                # for nearest-neighbour masks (and is rejected outright for boolean masks).
                anti_aliasing=mask_order != 0,
            )
        if heat is not None:
            heat = resize(
                heat,
                img_shape[:-1],
                order=1,
                mode=mode,
                clip=True,
                preserve_range=True,
                anti_aliasing=True,
            )

        if zoom_selected >= 1:
            # Zoom in: the arrays grew, so cut the center back to the original size.
            img = center_crop_single(img, img_orig_shape)
            if mask is not None:
                mask = center_crop_single(mask, mask_orig_shape)
            if heat is not None:
                heat = center_crop_single(heat, img_orig_shape[:-1])
        else:
            # Zoom out: the arrays shrank, so pad them symmetrically back to the original size.
            img_pad_tup = _center_pad(img_orig_shape, img_shape)
            img = np.pad(img, img_pad_tup, mode)  # type: ignore
            if mask is not None:
                mask = np.pad(mask, _center_pad(mask_orig_shape, mask_shape), mode)  # type: ignore
            if heat is not None:
                heat = np.pad(heat, img_pad_tup, mode)  # type: ignore

    if mask is None:
        return img
    return img, mask, heat
[docs]
def gamma_contrast(img: NDArray, gamma: Tuple[float, float] = (0.8, 1.3)) -> NDArray:
    """
    Apply a randomly sampled gamma correction to ``img``.

    The image is clipped to ``[0, 1]`` and then passed to ``skimage.exposure.adjust_gamma``
    with a gamma value drawn uniformly from the ``gamma`` range.

    Parameters
    ----------
    img : Numpy array
        Image to transform. Must have 3 or 4 dimensions. Values outside ``[0, 1]`` are
        clipped before the correction is applied.

    gamma : tuple of 2 floats, optional
        Range ``(min, max)`` of gamma intensity, e.g. ``(0.8, 1.3)``. ``min`` must be
        strictly positive and not larger than ``max``.

    Returns
    -------
    img : Numpy array
        Transformed image with the same shape as the input.
    """
    # NOTE: the previous default ``(0, 1)`` could never pass the ``gamma[0] > 0`` check
    # below, so calling the function without an explicit range always failed.
    assert img.ndim in [3, 4], f"Image must be 3D or 4D, got shape {img.shape}"
    assert len(gamma) == 2, "Gamma is supposed to have 2 elements but provided {} instead".format(gamma)
    assert gamma[0] <= gamma[1], "First element of gamma must be lower than the second one"
    assert gamma[0] > 0, "Gamma values must be greater than 0"

    sampled_gamma = random.uniform(gamma[0], gamma[1])
    return adjust_gamma(np.clip(img, 0, 1), gamma=sampled_gamma)  # type: ignore
[docs]
def shear(
    image: NDArray,
    shear: tuple,
    mask: Optional[NDArray] = None,
    heat: Optional[NDArray] = None,
    cval: float = 0,
    mask_type: str = "as_mask",
    mode: str = "constant",
):
    """
    Apply a random shear transformation to an image (and optional mask/heatmap).

    One integer angle per direction (x and y) is drawn from ``shear`` and the same pair
    of angles is used for the image, the mask and the heatmap. 4D inputs are warped
    slice by slice along their first axis with the same 2D transform.

    Parameters
    ----------
    image : 3D/4D Numpy array
        Image to shear. Handled as ``(y, x, channels)`` when 3D and as
        ``(z, y, x, channels)`` when 4D.

    shear : tuple
        Shear range ``(min, max)`` in degrees, used for both the x and y directions.
        The bounds must be integers because the angles are drawn with
        ``random.randint`` (both bounds inclusive).

    mask : 3D/4D Numpy array, optional
        Mask to shear, with the same layout as ``image``. For 4D inputs its
        ``(z, y, x)`` extent must match the image.

    heat : 3D/4D Numpy array, optional
        Heatmap (float mask) to shear, with the same layout as ``image``. For 4D inputs
        its ``(z, y, x)`` extent must match the image.

    cval : float
        Value used for points outside the boundaries.

    mask_type : str
        How to treat the mask during interpolation. Either as "as_mask" (order 0) or
        "as_image" (order 1).

    mode : str
        Points outside boundaries are filled according to this mode.

    Returns
    -------
    img : 3D/4D Numpy array
        Sheared image, same layout as the input.

    mask : 3D/4D Numpy array or None
        Sheared mask, or ``None`` when no mask was given.

    heat : 3D/4D Numpy array or None
        Sheared heatmap clipped to the per-channel value range of the input heatmap,
        or ``None`` when no heatmap was given.

    Raises
    ------
    ValueError
        If ``image`` is neither 3D nor 4D.
    """
    # Random shear angles (drawn in degrees, converted once to radians).
    shear_x_rad = np.deg2rad(random.randint(shear[0], shear[1]))
    shear_y_rad = np.deg2rad(random.randint(shear[0], shear[1]))
    mask_order = 0 if mask_type == "as_mask" else 1

    def _restore_channels(original, warped):
        # If a single-channel input comes back as (H, W), expand to (H, W, 1).
        if original is not None and original.ndim >= 3 and original.shape[-1] == 1 and warped.ndim == 2:
            return warped[..., np.newaxis]
        return warped

    def _warp_hwc(arr_hwc: NDArray, tform, order: int) -> NDArray:
        # Warp one (H, W, C) array, keeping its spatial size and value range.
        rows, cols = arr_hwc.shape[:2]
        orig_dtype = arr_hwc.dtype
        out = warp(
            arr_hwc,
            inverse_map=tform,
            cval=cval,
            mode=mode,
            order=order,
            output_shape=(rows, cols),
            preserve_range=True,  # keep original value range
        )
        # Non-float inputs (typically masks) are cast back to their original dtype.
        if orig_dtype.kind != "f":
            out = out.astype(orig_dtype, copy=False)
        return _restore_channels(arr_hwc, out)

    def _channel_range(arr: NDArray):
        # Per-channel min/max over every axis except the last one.
        reduce_axes = tuple(range(arr.ndim - 1))
        return np.min(arr, axis=reduce_axes), np.max(arr, axis=reduce_axes)

    if image.ndim == 3:  # (y, x, c)
        tform = _build_shear_matrix_skimage(image.shape[:2], shear_x_rad, shear_y_rad)
        img = _warp_hwc(image, tform, order=3)

        m = None
        if mask is not None:
            # The transform is centred on the array, so rebuild it for the mask size.
            tform_m = _build_shear_matrix_skimage(mask.shape[:2], shear_x_rad, shear_y_rad)
            m = _warp_hwc(mask, tform_m, order=mask_order)

        h = None
        if heat is not None:
            heat_mins, heat_maxes = _channel_range(heat)
            tform_h = _build_shear_matrix_skimage(heat.shape[:2], shear_x_rad, shear_y_rad)
            h = _warp_hwc(heat, tform_h, order=3)
            # Cubic interpolation can overshoot; keep the heatmap inside its input range.
            np.clip(h, heat_mins, heat_maxes, out=h)
        return img, m, h

    elif image.ndim == 4:  # (z, y, x, c)
        Z, H, W, _ = image.shape
        # Mask and heatmap must share the image's (z, y, x) extent, so one transform
        # serves all three arrays.
        tform = _build_shear_matrix_skimage((H, W), shear_x_rad, shear_y_rad)

        # Image
        img_out = np.empty_like(image)
        for z in range(Z):
            img_out[z] = _warp_hwc(image[z], tform, order=3)

        # Mask
        m_out = None
        if mask is not None:
            Zm, Hm, Wm, _ = mask.shape
            assert (Zm, Hm, Wm) == (Z, H, W), "mask shape must match image (z,y,x)"
            m_out = np.empty_like(mask)
            for z in range(Z):
                m_out[z] = _warp_hwc(mask[z], tform, order=mask_order)

        # Heat
        h_out = None
        if heat is not None:
            Zh, Hh, Wh, _ = heat.shape
            assert (Zh, Hh, Wh) == (Z, H, W), "heat shape must match image (z,y,x)"
            heat_mins, heat_maxes = _channel_range(heat)
            h_out = np.empty_like(heat)
            for z in range(Z):
                h_out[z] = _warp_hwc(heat[z], tform, order=3)
            # Same overshoot protection as in the 3D branch (previously missing here).
            np.clip(h_out, heat_mins, heat_maxes, out=h_out)
        return img_out, m_out, h_out

    else:
        raise ValueError(f"Unsupported image ndim: {image.ndim} (expected 3 or 4)")
[docs]
def shift(
    image: NDArray,
    mask: Optional[NDArray] = None,
    heat: Optional[NDArray] = None,
    shift_range: Optional[tuple] = None,
    cval: float = 0,
    mask_type: str = "as_mask",
    mode: str = "constant",
):
    """
    Translate an image (and optional mask/heatmap) by a random fraction of its size.

    A single fraction is sampled from ``shift_range`` and converted to a pixel offset
    for the y axis (fraction of the image height) and the x axis (fraction of the image
    width). The same pixel offsets are applied to the image, the mask and the heatmap.

    Parameters
    ----------
    image : 3D/4D Numpy array
        Image to shift. Handled as ``(y, x, channels)`` when 3D and as
        ``(z, y, x, channels)`` when 4D; only the y and x axes are shifted.

    mask : 3D/4D Numpy array, optional
        Mask to shift, same layout convention as ``image``.

    heat : 3D/4D Numpy array, optional
        Heatmap (float mask) to shift, same layout convention as ``image``.

    shift_range : Optional[tuple]
        Range ``(min, max)`` of the shift, expressed as a fraction of the image size.
        Required despite the ``None`` default.

    cval : float
        Value used for points outside the boundaries.

    mask_type : str
        How to treat the mask during interpolation. Either as "as_mask" (order 0) or
        "as_image" (order 1).

    mode : str
        Points outside boundaries are filled according to this mode.

    Returns
    -------
    img : 3D/4D Numpy array
        Shifted image.

    mask : 3D/4D Numpy array or None
        Shifted mask, or ``None`` when no mask was given.

    heat : 3D/4D Numpy array or None
        Shifted heatmap clipped to the per-channel value range of the input heatmap,
        or ``None`` when no heatmap was given.
    """
    for label, arr in (("Image", image), ("Mask", mask), ("Heat", heat)):
        if arr is not None:
            assert arr.ndim in (3, 4), f"{label} must be 3D or 4D, got {arr.shape}"
    assert shift_range is not None and len(shift_range) == 2, \
        f"shift_range must be (min, max); got {shift_range}"

    # Spatial extent of the image decides how many pixels the sampled fraction is worth.
    rows, cols = image.shape[:2] if image.ndim == 3 else image.shape[1:3]
    fraction = random.uniform(shift_range[0], shift_range[1])
    dx = int(round(fraction * cols))
    dy = int(round(fraction * rows))

    def _offsets(arr):
        # z and channel axes are never shifted.
        if arr.ndim == 3:  # (y, x, c)
            return (dy, dx, 0)
        if arr.ndim == 4:  # (z, y, x, c)
            return (0, dy, dx, 0)
        raise ValueError(f"Unsupported ndim: {arr.ndim}")

    img = shift_nd(image, _offsets(image), order=3, mode=mode, cval=cval)

    if mask is not None:
        interp_order = 0 if mask_type == "as_mask" else 1
        mask = shift_nd(mask, _offsets(mask), order=interp_order, mode=mode, cval=cval)

    if heat is not None:
        channel_axes = tuple(range(heat.ndim - 1))
        lower = np.min(heat, axis=channel_axes)
        upper = np.max(heat, axis=channel_axes)
        heat = shift_nd(heat, _offsets(heat), order=3, mode=mode, cval=cval)
        # Keep the interpolated heatmap inside the range of the original one.
        np.clip(heat, lower, upper, out=heat)

    return img, mask, heat
[docs]
def flip_horizontal(image: NDArray, mask: Optional[NDArray] = None, heat: Optional[NDArray] = None):
    """
    Flip an image (and optional mask/heatmap) along one spatial axis.

    The reversed axis is chosen from the dimensionality of ``image``: axis 0 for a 3D
    image and axis 1 for a 4D image. The mask and heatmap are reversed along that same
    axis. The results are reversed views of the inputs, not copies.

    NOTE(review): sibling functions in this file treat arrays as ``(y, x, c)`` /
    ``(z, y, x, c)``; under that layout the axis reversed here is ``y`` (an up-down
    flip), although the function name suggests a left-right flip. Confirm the intended
    axis against callers before changing it.

    Parameters
    ----------
    image : 3D/4D Numpy array
        Image to flip.

    mask : 3D/4D Numpy array, optional
        Mask to flip.

    heat : 3D/4D Numpy array, optional
        Heatmap (float mask) to flip.

    Returns
    -------
    img : 3D/4D Numpy array
        Flipped image.

    mask : 3D/4D Numpy array or None
        Flipped mask if provided, else None.

    heat : 3D/4D Numpy array or None
        Flipped heatmap if provided, else None.
    """
    assert image.ndim in (3, 4), f"Image must be 3D or 4D, got {image.shape}"
    if mask is not None:
        assert mask.ndim in (3, 4), f"Mask must be 3D or 4D, got {mask.shape}"
    if heat is not None:
        assert heat.ndim in (3, 4), f"Heatmap must be 3D or 4D, got {heat.shape}"

    # The image alone decides which axis is reversed for all three arrays.
    flip_axis = 0 if image.ndim == 3 else 1

    def _reverse(arr):
        return None if arr is None else np.flip(arr, axis=flip_axis)

    return _reverse(image), _reverse(mask), _reverse(heat)
[docs]
def flip_vertical(image: NDArray, mask: Optional[NDArray] = None, heat: Optional[NDArray] = None):
    """
    Flip an image (and optional mask/heatmap) along one spatial axis.

    The reversed axis is chosen from the dimensionality of ``image``: axis 1 for a 3D
    image and axis 2 for a 4D image. The mask and heatmap are reversed along that same
    axis. The results are reversed views of the inputs, not copies.

    NOTE(review): sibling functions in this file treat arrays as ``(y, x, c)`` /
    ``(z, y, x, c)``; under that layout the axis reversed here is ``x`` (a left-right
    flip), although the function name suggests an up-down flip. Confirm the intended
    axis against callers before changing it.

    Parameters
    ----------
    image : 3D/4D Numpy array
        Image to flip.

    mask : 3D/4D Numpy array, optional
        Mask to flip.

    heat : 3D/4D Numpy array, optional
        Heatmap (float mask) to flip.

    Returns
    -------
    img : 3D/4D Numpy array
        Flipped image.

    mask : 3D/4D Numpy array or None
        Flipped mask if provided, else None.

    heat : 3D/4D Numpy array or None
        Flipped heatmap if provided, else None.
    """
    assert image.ndim in (3, 4), f"Image must be 3D or 4D, got {image.shape}"
    if mask is not None:
        assert mask.ndim in (3, 4), f"Mask must be 3D or 4D, got {mask.shape}"
    if heat is not None:
        assert heat.ndim in (3, 4), f"Heatmap must be 3D or 4D, got {heat.shape}"

    # The image alone decides which axis is reversed for all three arrays.
    flip_axis = 1 if image.ndim == 3 else 2

    def _reverse(arr):
        return None if arr is None else np.flip(arr, axis=flip_axis)

    return _reverse(image), _reverse(mask), _reverse(heat)
[docs]
def gaussian_blur(image: NDArray, sigma: float | tuple = (0.5, 1.5)):
    """
    Apply Gaussian blur to an image.

    Parameters
    ----------
    image : Numpy array
        Image to blur. For arrays with 3 or more dimensions the last axis is treated
        as the channel axis and is not smoothed.

    sigma : float or tuple
        Standard deviation for Gaussian kernel. If a tuple, it is read as a
        ``(min, max)`` range and a random value is drawn uniformly from it.

    Returns
    -------
    img : NDArray
        Blurred image with the same shape as the input.
    """
    if isinstance(sigma, tuple):
        sigma = random.uniform(sigma[0], sigma[1])

    # Without ``channel_axis`` skimage smooths every axis, which blends neighbouring
    # channels into each other. Plain 2D arrays have no channel axis to protect.
    channel_axis = -1 if image.ndim >= 3 else None
    return gaussian(image, sigma=sigma, channel_axis=channel_axis)
[docs]
def motion_blur(image: NDArray, k_range: Optional[tuple] = None):
    """
    Apply motion blur to an image.

    A line-shaped kernel of random size, orientation and intensity profile is built
    and every channel (and, for 4D inputs, every slice along the first axis) is
    filtered with it.

    Parameters
    ----------
    image : 3D/4D Numpy array
        Image to blur. Handled as ``(y, x, channels)`` when 3D and as
        ``(z, y, x, channels)`` when 4D.

    k_range : Optional[tuple]
        Range ``(min, max)`` of integer kernel sizes (both bounds inclusive). An even
        sampled size is increased by one so the kernel always has a centre pixel.
        Required despite the ``None`` default.

    Returns
    -------
    img : NDArray
        Blurred image with the same shape and dtype as the input. An empty input is
        returned unchanged.

    Raises
    ------
    ValueError
        If ``k_range`` is missing or does not have exactly two elements.
    """
    assert image.ndim in (3, 4), f"Image must be 3D or 4D, got {image.shape}"
    if k_range is None or len(k_range) != 2:
        raise ValueError("k_range must be provided and have length 2")
    if image.size == 0:
        return image

    # Sample kernel size (must be odd)
    k = random.randint(k_range[0], k_range[1])
    if k % 2 == 0:
        k += 1

    # Sample motion angle and the intensity direction along the line.
    angle = random.randint(0, 359)
    # Map the sampled direction from [-1, 1] to [0, 1] so that the ramp below goes from
    # ``direction`` to ``1 - direction`` without ever producing negative weights.
    direction = (random.uniform(-1.0, 1.0) + 1.0) / 2.0

    # Build a vertical line kernel, then rotate it around its centre.
    base = np.zeros((k, k), dtype=np.float32)
    base[:, k // 2] = np.linspace(direction, 1.0 - direction, num=k)
    # reshape=False keeps the kernel k x k (odd, centred); order=1 avoids the negative
    # overshoot that higher-order spline interpolation adds around a thin line.
    kernel = rotate(base, angle=angle, reshape=False, order=1).astype(np.float32)

    total = kernel.sum()
    if total > 0.0:
        kernel /= total
    else:
        kernel[:] = 0.0
        kernel[k // 2, k // 2] = 1.0  # degenerate fallback: identity kernel

    # Apply blur channel by channel.
    out = np.empty_like(image)
    if image.ndim == 3:  # (y, x, c)
        for c in range(image.shape[-1]):
            out[..., c] = cv2.filter2D(image[..., c], ddepth=-1, kernel=kernel)
    else:  # (z, y, x, c)
        for z in range(image.shape[0]):
            for c in range(image.shape[-1]):
                out[z, :, :, c] = cv2.filter2D(image[z, :, :, c], ddepth=-1, kernel=kernel)
    return out
[docs]
def dropout(
    image: NDArray, drop_range: tuple = (0.1, 0.2), random_state: Optional[np.random.RandomState] = None
) -> NDArray:
    """
    Randomly zero out pixels of an image.

    A drop probability is sampled uniformly from ``drop_range``; each pixel position
    is then kept with probability ``1 - p``. One keep/drop decision is made per
    position over all axes except the last, so every channel of a dropped pixel is
    zeroed together.

    Parameters
    ----------
    image : 3D/4D Numpy array
        Image to apply dropout to. The last axis is treated as the channel axis.

    drop_range : tuple
        Range ``(min, max)`` for the dropout probability.

    random_state : Optional[np.random.RandomState]
        Random state for reproducibility. When ``None`` the global ``np.random``
        module is used.

    Returns
    -------
    img : 3D/4D Numpy array
        Image after dropout. When the sampled probability is exactly 0 the input
        object itself is returned.
    """
    assert image.ndim in (3, 4), f"Image must be 3D or 4D, got {image.shape}"
    assert len(drop_range) == 2, f"Drop range must have 2 elements, got {drop_range}"

    sampler = random_state if random_state is not None else np.random
    drop_prob = sampler.uniform(*drop_range)
    if drop_prob == 0:
        return image

    # One decision per pixel position; a plain 2D array has no channel axis to skip.
    position_shape = image.shape if image.ndim == 2 else image.shape[:-1]
    keep = sampler.binomial(1, 1.0 - drop_prob, size=position_shape).astype(image.dtype)
    if keep.ndim < image.ndim:
        # Add a trailing axis so the keep-mask broadcasts over the channels.
        keep = keep[..., np.newaxis]
    return image * keep
[docs]
def elastic(
image: NDArray,
mask: Optional[NDArray] = None,
heat: Optional[NDArray] = None,
alpha: float | tuple = 14,
sigma: float = 4,
mask_type: str = "as_mask",
cval: float = 0,
mode: str = "constant",
random_seed=None,
) -> Tuple[NDArray, Optional[NDArray], Optional[NDArray]]:
"""
Apply elastic deformation to an image (and optional mask/heatmap).
Parameters
----------
image : 3D/4D Numpy array
Image to deform. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
mask : 3D/4D Numpy array, optional
Mask to deform. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
heat : 3D/4D Numpy array, optional
Heatmap (float mask) to deform. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
alpha : float, optional
Scaling factor for deformation intensity.
sigma : float, optional
Standard deviation for Gaussian filter.
cval : float, optional
Value used for points outside the boundaries.
mode : str, optional
Points outside boundaries are filled according to this mode.
mask_type : str, optional
How to treat the mask during interpolation. Either as "as_mask" (order 0) or "as_image" (order 1).
random_seed : int, optional
Random seed for reproducibility.
Returns
-------
img : 3D/4D Numpy array
Deformed image. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
mask : 3D/4D Numpy array, optional
Deformed mask. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
heat : 3D/4D Numpy array, optional
Deformed heatmap. E.g. ``(y, x, channels)`` for ``2D`` or ``(y, x, z, channels)`` for ``3D``.
"""
assert image.ndim in (3, 4), f"Image must be 3D or 4D, got {image.shape}"
if mask is not None:
assert mask.ndim in (3, 4), f"Mask must be 3D or 4D, got {mask.shape}"
if heat is not None:
assert heat.ndim in (3, 4), f"Heatmap must be 3D or 4D, got {heat.shape}"
if random_seed is not None:
np.random.seed(random_seed)
def warp_with_new_displacement(
tensor: NDArray,
alpha: float,
sigma: float,
order: int,
cval: float,
mode: str
) -> NDArray:
"""
Apply elastic deformation to a tensor.
Works for (H, W, C) and (Z, H, W, C). The displacement is generated in (H, W)
and broadcast to all channels (and all z-slices if 4D).
"""
assert tensor.ndim in (3, 4), f"Expected 3D or 4D tensor, got {tensor.shape}"
# Pick spatial (H, W) depending on layout
if tensor.ndim == 3: # (H, W, C)
H, W = tensor.shape[:2]
else: # (Z, H, W, C)
H, W = tensor.shape[1:3]
# Choose padding kernel size based on sigma (same logic as before)
if sigma < 3.0:
ksize = 3.3 * sigma # ~99% weight
elif sigma < 5.0:
ksize = 2.9 * sigma # ~97% weight
else:
ksize = 2.6 * sigma # ~95% weight
ksize = int(max(ksize, 5))
ksize = ksize + 1 if (ksize % 2 == 0) else ksize
padding = ksize
# Build padded random fields, smooth them, then crop back to (H, W)
H_pad = H + 2 * padding
W_pad = W + 2 * padding
rng = np.random.rand(2 * H_pad, W_pad).astype(np.float32) * 2 - 1
dx_unsmoothed = rng[:H_pad, :]
dy_unsmoothed = rng[H_pad:, :]
# Use skimage.filters.gaussian (already imported as `gaussian`)
dx = gaussian(dx_unsmoothed, sigma=sigma).astype(np.float32) * alpha
dy = gaussian(dy_unsmoothed, sigma=sigma).astype(np.float32) * alpha
if padding > 0:
dx = dx[padding:-padding, padding:-padding]
dy = dy[padding:-padding, padding:-padding]
# Let _map_coordinates handle both 3D and 4D layouts
return _map_coordinates(tensor, dx, dy, order=order, cval=cval, mode=mode)
alphas, sigmas = _draw_samples(alpha, sigma, nb_images=1)
alpha_val = alphas[0]
sigma_val = sigmas[0]
img = warp_with_new_displacement(image, alpha_val, sigma_val, 3, cval, mode)
if mask is not None:
mask_order = 0 if mask_type == "as_mask" else 1
mask = warp_with_new_displacement(mask, alpha_val, sigma_val, mask_order, cval, mode)
if heat is not None:
heat_mins, heat_maxes = np.min(heat, axis=tuple(range(heat.ndim - 1))), np.max(heat, axis=tuple(range(heat.ndim - 1)))
heat = warp_with_new_displacement(heat, alpha_val, sigma_val, 3, cval, mode)
np.clip(heat, heat_mins, heat_maxes, out=heat)
return img, mask, heat
## Helpers
def _build_shear_matrix_skimage(image_shape: tuple, shear_x_rad: float, shear_y_rad: float, shift_add: tuple = (0.5, 0.5)) -> ProjectiveTransform:
    """
    Build the skimage transform used for shear augmentation.

    The transform is the left-to-right ``+`` composition of: a translation by minus the
    centre offset, a shear by ``shear_x_rad``, a rotation by ``-pi/2``, a shear by
    ``shear_y_rad``, a rotation by ``+pi/2`` and a translation back by the centre offset.
    Wrapping the second shear between the two rotations is how the y-direction shear
    is expressed.

    Parameters
    ----------
    image_shape : tuple
        Shape of the image; only the first two entries (height, width) are used.

    shear_x_rad : float
        Shear angle in radians for the x direction.

    shear_y_rad : float
        Shear angle in radians for the y direction.

    shift_add : tuple, optional
        ``(y, x)`` amounts subtracted from half the height/width when computing the
        centre offset.

    Returns
    -------
    matrix : AffineTransform
        Combined transform. An identity ``AffineTransform`` is returned when the
        height or the width is zero.
    """
    height, width = image_shape[:2]
    if height == 0 or width == 0:
        return AffineTransform()

    center_y = height / 2.0 - shift_add[0]
    center_x = width / 2.0 - shift_add[1]

    # Stages listed in the order in which they are chained with ``+``.
    stages = (
        AffineTransform(translation=[-center_x, -center_y]),
        AffineTransform(shear=shear_x_rad),
        AffineTransform(rotation=-np.pi / 2),
        AffineTransform(shear=shear_y_rad),
        AffineTransform(rotation=np.pi / 2),
        AffineTransform(translation=[center_x, center_y]),
    )
    combined = stages[0]
    for stage in stages[1:]:
        combined = combined + stage
    return combined
def _normalize_cv2_input_arr_(arr: NDArray) -> NDArray:
    """
    Return an array that owns its data and is C-contiguous, for use with cv2 functions.

    Parameters
    ----------
    arr : NDArray
        Input array.

    Returns
    -------
    arr : NDArray
        ``arr`` itself when it already owns its data and is C-contiguous; otherwise a
        copy that satisfies both conditions.
    """
    # Views (arrays that do not own their buffer) are copied first ...
    owned = arr if arr.flags["OWNDATA"] else np.copy(arr)
    # ... and the memory layout is fixed afterwards if it is still not C-ordered.
    return owned if owned.flags["C_CONTIGUOUS"] else np.ascontiguousarray(owned)
def _draw_samples(alpha: float | tuple, sigma: float | tuple, nb_images: int) -> tuple:
"""
Draw samples for alpha and sigma parameters.
Parameters
----------
alpha : float or tuple
Alpha parameter or range (min, max).
sigma : float or tuple
Sigma parameter or range (min, max).
nb_images : int
Number of samples to draw.
Returns
-------
alphas : NDArray
Array of drawn alpha values.
sigmas : NDArray
Array of drawn sigma values.
"""
# Use np.random for all randomness
def draw_param(param: float | tuple, size: tuple) -> NDArray:
if isinstance(param, (int, float)):
out = np.full(size, param)
elif isinstance(param, str):
out = np.array([param] * size[0], dtype=object)
elif isinstance(param, tuple):
out = np.random.uniform(param[0], param[1], size=size) if len(param) == 2 else np.full(size, param[0])
else:
out = np.full(size, param)
return out
alphas = draw_param(alpha, (nb_images,))
sigmas = draw_param(sigma, (nb_images,))
return alphas, sigmas
# Border-mode name -> OpenCV border flag, looked up by `_map_coordinates` for `cv2.remap`.
# "edge" and "nearest" are aliases for the same OpenCV flag.
_MAPPING_MODE_SCIPY_CV2 = dict(
    constant=cv2.BORDER_CONSTANT,
    edge=cv2.BORDER_REPLICATE,
    symmetric=cv2.BORDER_REFLECT,
    reflect=cv2.BORDER_REFLECT_101,
    wrap=cv2.BORDER_WRAP,
    nearest=cv2.BORDER_REPLICATE,
)

# Interpolation order -> OpenCV interpolation flag. Every order from 2 to 5 maps to
# cubic interpolation.
_MAPPING_ORDER_SCIPY_CV2 = {
    0: cv2.INTER_NEAREST,
    1: cv2.INTER_LINEAR,
    **dict.fromkeys(range(2, 6), cv2.INTER_CUBIC),
}
def _map_coordinates(image: NDArray, dx: NDArray, dy: NDArray, order: int = 1, cval: float = 0, mode: str = "constant") -> NDArray:
    """
    Remap an image according to the displacement fields ``dx`` and ``dy`` using ``cv2.remap``.

    The output pixel at ``(y, x)`` is sampled from the input at ``(y - dy, x - dx)``.

    Parameters
    ----------
    image : NDArray
        Input image array, either ``(H, W, C)`` or ``(Z, H, W, C)``.

    dx : NDArray
        Displacement field in x direction. Shape ``(H, W)``; for 4D images a
        per-slice field of shape ``(Z, H, W)`` is also accepted.

    dy : NDArray
        Displacement field in y direction, same shape rules as ``dx``.

    order : int
        Interpolation order; must be a key of ``_MAPPING_ORDER_SCIPY_CV2``.

    cval : float
        Value used for points outside the boundaries.

    mode : str
        Points outside boundaries are filled according to this mode; must be a key of
        ``_MAPPING_MODE_SCIPY_CV2``.

    Returns
    -------
    result : NDArray
        Transformed image array. An empty input is returned as a copy.

    Raises
    ------
    TypeError
        If ``image`` has dtype ``uint64`` or ``int64``.
    """
    if image.size == 0:
        return np.copy(image)

    dx = dx.astype(np.float32)
    dy = dy.astype(np.float32)

    # Remapping is done exclusively with cv2.remap, which does not handle 64-bit
    # integers at any interpolation order. The old guard only fired for order=0
    # (contradicting its own message); other orders failed later inside OpenCV.
    if image.dtype.name in ("uint64", "int64"):
        raise TypeError(
            "dtypes uint64 and int64 are not supported by the cv2.remap backend of "
            "_map_coordinates, got order=%d with dtype=%s." % (order, image.dtype.name)
        )

    assert image.ndim in (3, 4), f"Expected 3D or 4D image, got {image.ndim}D with shape {image.shape}"

    # cv2 params
    border_mode = _MAPPING_MODE_SCIPY_CV2[mode]
    interpolation = _MAPPING_ORDER_SCIPY_CV2[order]
    cval_cast = float(cval) if image.dtype.kind == "f" else int(cval)

    def _make_maps(h: int, w: int, dx2: NDArray, dy2: NDArray):
        """Build OpenCV remap maps for a single 2D field."""
        y, x = np.meshgrid(
            np.arange(h, dtype=np.float32),
            np.arange(w, dtype=np.float32),
            indexing="ij",
        )
        x_shifted = x - dx2
        y_shifted = y - dy2
        if interpolation == cv2.INTER_NEAREST:
            return x_shifted, y_shifted
        return cv2.convertMaps(x_shifted, y_shifted, cv2.CV_32FC1, nninterpolation=False)

    def _remap_hwcn(arr_hwc: NDArray, map1: NDArray, map2: NDArray) -> NDArray:
        """Apply cv2.remap to (H, W, C) for any C, in groups of at most 4 channels."""
        n_channels = arr_hwc.shape[2]
        chunks = []
        for start in range(0, n_channels, 4):
            sub = arr_hwc[:, :, start : start + 4]
            res = cv2.remap(
                _normalize_cv2_input_arr_(sub),
                map1,
                map2,
                interpolation=interpolation,
                borderMode=border_mode,
                borderValue=(cval_cast,) * sub.shape[-1],
            )
            # cv2 drops the channel axis of single-channel inputs; put it back.
            if res.ndim == 2:
                res = res[..., np.newaxis]
            chunks.append(res)
        return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=2)

    if image.ndim == 3:
        # (H, W, C)
        H, W, _ = image.shape
        assert dx.shape == (H, W) and dy.shape == (H, W), \
            f"For 3D image (H,W,C), dx/dy must be (H,W); got dx {dx.shape}, dy {dy.shape}"
        map1, map2 = _make_maps(H, W, dx, dy)
        return _remap_hwcn(np.copy(image), map1, map2)

    # (Z, H, W, C)
    Z, H, W, _ = image.shape
    result = np.empty_like(image)

    # dx/dy: either one (H, W) field shared by all slices or one field per slice (Z, H, W).
    per_slice = dx.ndim == 3 and dy.ndim == 3
    if per_slice:
        assert dx.shape == (Z, H, W) and dy.shape == (Z, H, W), \
            f"For per-slice fields, dx/dy must be (Z,H,W); got dx {dx.shape}, dy {dy.shape}"
        shared_maps = None
    else:
        assert dx.shape == (H, W) and dy.shape == (H, W), \
            f"For broadcast fields, dx/dy must be (H,W); got dx {dx.shape}, dy {dy.shape}"
        # Precompute the shared maps once instead of once per slice.
        shared_maps = _make_maps(H, W, dx, dy)

    for z in range(Z):
        map1, map2 = _make_maps(H, W, dx[z], dy[z]) if per_slice else shared_maps
        result[z] = _remap_hwcn(image[z], map1, map2)
    return result