Source code for biapy.engine.denoising

"""
Denoising workflow and utilities for BiaPy.

This module provides the Denoising_Workflow class for training and inference on image denoising tasks,
as well as utility functions for patch manipulation, stratified coordinate sampling, and structN2V masking.
It supports both 2D and 3D data, and includes implementations of various pixel manipulation strategies
used in self-supervised denoising approaches such as Noise2Void (N2V).
"""
import math
import torch
import numpy as np
import numpy.ma as ma
from tqdm import tqdm
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
from typing import Tuple, Callable, Dict, Optional
from numpy.typing import NDArray

from biapy.data.data_2D_manipulation import (
    crop_data_with_overlap,
    merge_data_with_overlap,
)
from biapy.data.data_3D_manipulation import (
    crop_3D_data_with_overlap,
    merge_3D_data_with_overlap,
)
from biapy.engine.base_workflow import Base_Workflow
from biapy.data.data_manipulation import save_tif
from biapy.utils.misc import to_pytorch_format, is_main_process, MetricLogger
from biapy.engine.metrics import n2v_loss_mse, loss_encapsulation
from biapy.data.norm import undo_image_norm

class Denoising_Workflow(Base_Workflow):
    """
    Denoising workflow where the goal is to remove noise from an image.

    More details in `our documentation <https://biapy.readthedocs.io/en/latest/workflows/denoising.html>`_.

    Parameters
    ----------
    cfg : YACS configuration
        Running configuration.

    job_identifier : str
        Complete name of the running job.

    device : torch.device
        Device used.

    system_dict : dict
        Forwarded unchanged to :class:`Base_Workflow`.

    args : argparse.Namespace
        Arguments used in BiaPy's call.
    """

    def __init__(self, cfg, job_identifier, device, system_dict, args, **kwargs):
        """
        Initialize the Denoising_Workflow.

        Sets up configuration, device, job identifier, and initializes workflow-specific attributes for denoising tasks.

        Parameters
        ----------
        cfg : YACS configuration
            Running configuration.

        job_identifier : str
            Complete name of the running job.

        device : torch.device
            Device used.

        system_dict : dict
            Forwarded unchanged to :class:`Base_Workflow`.

        args : argparse.Namespace
            Arguments used in BiaPy's call.

        **kwargs : dict
            Additional keyword arguments, forwarded to :class:`Base_Workflow`.
        """
        super(Denoising_Workflow, self).__init__(cfg, job_identifier, device, system_dict, args, **kwargs)

        # From now on, no modification of the cfg will be allowed
        self.cfg.freeze()

        # Workflow specific training variables. GT data is only read when
        # PROBLEM.DENOISING.LOAD_GT_DATA is enabled.
        self.mask_path = cfg.DATA.TRAIN.GT_PATH if cfg.PROBLEM.DENOISING.LOAD_GT_DATA else None
        self.is_y_mask = False
        self.load_Y_val = cfg.PROBLEM.DENOISING.LOAD_GT_DATA
        # Normalize the target ("mask") data with the image normalization
        self.norm_module["mask_norm"] = "as_image"
        self.test_norm_module["mask_norm"] = "as_image"

    def define_activations_and_channels(self):
        """
        Define the activations to be applied to the model output and the channels that the model will output.

        This function must define the following variables:

        self.model_output_channels : List of int
            Number of channels for each output head of the model. E.g. [3] for a model with one head outputting 3 channels,
            [1, 5] for a model with two heads outputting 1 and 5 channels respectively, etc.

        self.model_output_channel_info : List of str
            Information about the output channels. A value per output head of the model must be defined.

        self.separated_class_channel : bool
            Whether if we should expect a separated output channel for classification.

        self.head_activations : List of str
            Activations to be applied to the model output. A value per output channel (not output head) of the model
            must be defined. "linear" and "ce_sigmoid" will not be applied. E.g. ["linear"] for a model with one channel,
            ["linear", "sigmoid"] for a model with two channels, etc.

        For denoising the model has a single head with ``DATA.PATCH_SIZE[-1]`` (the channel count of the input)
        channels, all of them with a "linear" activation.

        Example of a correct definition of the function for a model with two output heads: 1) the first one will be
        predicting foreground and contours; 2) the second one will classify into 3 classes the predicted objects. In
        this case the following definition would be correct::

            self.model_output_channels = [1, 3]
            self.model_output_channel_info = ["mask", "class"]
            self.separated_class_channel = True
            self.head_activations = ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"]
        """
        self.model_output_channels = [self.cfg.DATA.PATCH_SIZE[-1]]
        self.gt_channels_expected = self.model_output_channels[0]
        self.separated_class_channel = False
        self.head_activations = ["linear"] * self.model_output_channels[0]
        self.model_output_channel_info = ["pred{}".format(i) for i in range(len(self.model_output_channels))]
        super().define_activations_and_channels()

    def define_metrics(self):
        """
        Define the metrics to be used during training and test/inference.

        This function must define the following variables:

        self.train_metrics : List of functions
            Metrics to be calculated during model's training.

        self.train_metric_names : List of str
            Names of the metrics calculated during training.

        self.train_metric_best : List of str
            To know which value should be considered as the best one. Options must be: "max" or "min".

        self.test_metrics : List of functions
            Metrics to be calculated during model's test/inference.

        self.test_metric_names : List of str
            Names of the metrics calculated during test/inference.

        self.loss : Function
            Loss function used during training and test.

        Only "mse" and "mae" are recognised; duplicated entries in the configuration are counted once and the
        configured order is kept.
        """
        self.train_metrics = []
        self.train_metric_names = []
        self.train_metric_best = []
        # dict.fromkeys removes duplicates while keeping the configured order. Iterating
        # over a set would make the metric order depend on string hash randomisation.
        for metric in dict.fromkeys(self.cfg.TRAIN.METRICS):
            if metric == "mse":
                self.train_metrics.append(
                    MeanSquaredError().to(self.device),
                )
                self.train_metric_names.append("MSE")
                self.train_metric_best.append("min")
            elif metric == "mae":
                self.train_metrics.append(
                    MeanAbsoluteError().to(self.device),
                )
                self.train_metric_names.append("MAE")
                self.train_metric_best.append("min")

        self.test_metrics = []
        self.test_metric_names = []
        for metric in dict.fromkeys(self.cfg.TEST.METRICS):
            if metric == "mse":
                self.test_metrics.append(
                    MeanSquaredError().to(self.test_device),
                )
                self.test_metric_names.append("MSE")
            elif metric == "mae":
                self.test_metrics.append(
                    MeanAbsoluteError().to(self.test_device),
                )
                self.test_metric_names.append("MAE")

        # When LOSS.TYPE is "MSE" the N2V loss (masked MSE) is used instead
        if self.cfg.LOSS.TYPE == "MSE":
            self.loss = loss_encapsulation(n2v_loss_mse)

        super().define_metrics()

    def metric_calculation(
        self,
        output: NDArray | torch.Tensor,
        targets: NDArray | torch.Tensor,
        train: bool = True,
        metric_logger: Optional[MetricLogger] = None,
    ) -> Dict:
        """
        Execute the calculation of metrics defined in :func:`~define_metrics` function.

        Parameters
        ----------
        output : NDArray, torch.Tensor or dict
            Prediction of the model. When a dict is given its ``"pred"`` entry is used.

        targets : NDArray or torch.Tensor
            Ground truth to compare the prediction with. NumPy arrays are converted with ``to_pytorch_format``.

        train : bool, optional
            Whether to calculate train or test metrics. It also selects the device NumPy inputs are moved to
            (``self.device`` or ``self.test_device``).

        metric_logger : MetricLogger, optional
            Class to be updated with the new metric(s) value(s) calculated.

        Returns
        -------
        out_metrics : dict
            Value of the metrics for the given prediction, keyed by metric name. NaN values are reported as ``0``.
        """
        if isinstance(output, dict):
            output = output["pred"]

        if isinstance(output, np.ndarray):
            _output = to_pytorch_format(
                output.copy(),
                self.axes_order,
                self.device if train else self.test_device,
                dtype=self.loss_dtype,
            )
        else:  # torch.Tensor
            # Outside training, work on a copy of the given tensor
            _output = output.clone() if not train else output

        if isinstance(targets, np.ndarray):
            _targets = to_pytorch_format(
                targets.copy(),
                self.axes_order,
                self.device if train else self.test_device,
                dtype=self.loss_dtype,
            )
        else:  # torch.Tensor
            _targets = targets.clone() if not train else targets

        out_metrics = {}
        list_to_use = self.train_metrics if train else self.test_metrics
        list_names_to_use = self.train_metric_names if train else self.test_metric_names

        with torch.no_grad():
            for i, metric in enumerate(list_to_use):
                # NOTE(review): assuming dim 1 is the channel axis, this compares the prediction with the target
                # channels *after* the first C (C = predicted channels). In the N2V layout written by
                # ``manipulate_val_data`` those are the 0/1 mask channels, not the values. Confirm against the
                # data generator whether ``_targets[:, :C]`` was intended.
                val = metric(_output.contiguous(), _targets[:, _output.shape[1]:].contiguous())
                val = val.item() if not torch.isnan(val) else 0
                out_metrics[list_names_to_use[i]] = val

                if metric_logger:
                    metric_logger.meters[list_names_to_use[i]].update(val)

        return out_metrics

    def process_test_sample(self):
        """
        Process a sample in the test/inference phase.

        The sample in ``self.current_sample["X"]`` is cropped into patches when its spatial shape differs from
        ``DATA.PATCH_SIZE[:-1]``, predicted, and merged back. Reflection padding is then removed (when
        ``DATA.REFLECT_TO_COMPLETE_SHAPE`` is set) and the input normalization is undone. The raw prediction is
        saved when ``TEST.SAVE_MODEL_RAW_OUTPUT`` is enabled.

        Returns
        -------
        bool or None
            ``True`` when the sample is marked to be discarded (nothing is processed), ``None`` otherwise.
        """
        assert self.model is not None

        # Skip processing image
        if "discard" in self.current_sample and self.current_sample["discard"]:
            return True

        original_data_shape = self.current_sample["X"].shape

        # Crop if necessary
        if self.current_sample["X"].shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]:
            if self.cfg.PROBLEM.NDIM == "2D":
                self.current_sample["X"], _ = crop_data_with_overlap(  # type: ignore
                    self.current_sample["X"],
                    self.cfg.DATA.PATCH_SIZE,
                    overlap=self.cfg.DATA.TEST.OVERLAP,
                    padding=self.cfg.DATA.TEST.PADDING,
                    verbose=self.cfg.TEST.VERBOSE,
                )
            else:
                # The 3D cropper receives the first element along the leading axis; that axis is
                # re-added to the prediction after merging below.
                self.current_sample["X"], _ = crop_3D_data_with_overlap(  # type: ignore
                    self.current_sample["X"][0],
                    self.cfg.DATA.PATCH_SIZE,
                    overlap=self.cfg.DATA.TEST.OVERLAP,
                    padding=self.cfg.DATA.TEST.PADDING,
                    verbose=self.cfg.TEST.VERBOSE,
                    median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING,
                )

        pred = self.predict_batches_in_test(self.current_sample["X"], None)
        # The (possibly cropped) input is not needed anymore
        del self.current_sample["X"]

        # Reconstruct the predictions
        if original_data_shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]:
            if self.cfg.PROBLEM.NDIM == "3D":
                original_data_shape = original_data_shape[1:]
            f_name = merge_data_with_overlap if self.cfg.PROBLEM.NDIM == "2D" else merge_3D_data_with_overlap

            # The merge call is the same whatever TEST.REDUCE_MEMORY is set to
            pred = f_name(
                pred,
                original_data_shape[:-1] + (pred.shape[-1],),
                padding=self.cfg.DATA.TEST.PADDING,
                overlap=self.cfg.DATA.TEST.OVERLAP,
                verbose=self.cfg.TEST.VERBOSE,
            )
            if self.cfg.PROBLEM.NDIM == "3D":
                assert isinstance(pred, np.ndarray)
                pred = np.expand_dims(pred, 0)

        if self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE:
            reflected_orig_shape = (1,) + self.current_sample["reflected_orig_shape"]
            if reflected_orig_shape != pred.shape:
                # Keep the trailing region of each spatial axis, i.e. drop the reflected part
                if self.cfg.PROBLEM.NDIM == "2D":
                    pred = pred[:, -reflected_orig_shape[1] :, -reflected_orig_shape[2] :]  # type: ignore
                else:
                    pred = pred[
                        :,
                        -reflected_orig_shape[1] :,
                        -reflected_orig_shape[2] :,
                        -reflected_orig_shape[3] :,
                    ]  # type: ignore

        # Undo normalization
        pred = undo_image_norm(pred, self.current_sample["X_norm"])
        assert isinstance(pred, np.ndarray)

        # Save image
        if self.cfg.PATHS.RESULT_DIR.PER_IMAGE != "" and self.cfg.TEST.SAVE_MODEL_RAW_OUTPUT:
            assert isinstance(pred, np.ndarray)
            save_tif(
                pred,
                self.cfg.PATHS.RESULT_DIR.PER_IMAGE,
                [self.current_sample["X_filename"]],
                verbose=self.cfg.TEST.VERBOSE,
            )

    def torchvision_model_call(self, in_img: torch.Tensor, is_train: bool = False) -> torch.Tensor | None:
        """
        Call a regular Pytorch model.

        Not implemented for this workflow: the method does nothing and returns ``None``.

        Parameters
        ----------
        in_img : torch.Tensor
            Input image to pass through the model.

        is_train : bool, optional
            Whether if the call is during training or inference.

        Returns
        -------
        prediction : torch.Tensor
            Image prediction.
        """
        pass

    def after_merge_patches(self, pred: torch.Tensor):
        """
        Execute steps needed after merging all predicted patches into the original image.

        Nothing is done for this workflow.

        Parameters
        ----------
        pred : Torch Tensor
            Model prediction.
        """
        pass

    def after_full_image(self, pred: NDArray):
        """
        Execute steps needed after generating the prediction by supplying the entire image to the model.

        Nothing is done for this workflow.

        Parameters
        ----------
        pred : NDArray
            Model prediction.
        """
        pass

    def after_all_images(self):
        """Execute steps that must be done after predicting all images."""
        super().after_all_images()

    def after_all_chunk_prediction_workflow_process(self):
        """
        Place any code that needs to be done after predicting all patches in "by chunks" setting.

        This function is called on all ranks.
        """
        pass

    def after_all_chunk_prediction_workflow_process_master_rank(self):
        """
        Place any code that needs to be done after predicting all patches in "by chunks" setting, but only on the master rank.

        This function is called only on the master rank.
        """
        pass
####################################
#     Adapted from N2V code:       #
#  https://github.com/juglab/n2v   #
####################################
def get_subpatch(patch, coord, local_sub_patch_radius, crop_patch=True):
    """
    Extract the ``(2 * local_sub_patch_radius + 1)``-wide window centered at ``coord``.

    Parameters
    ----------
    patch : np.ndarray
        Array to read the window from.
    coord : tuple of int
        Center of the window, one entry per axis of ``patch``.
    local_sub_patch_radius : int
        Half width of the window (the center excluded).
    crop_patch : bool, optional
        ``True`` (default): cut the window where it leaves the patch, so it may be smaller than
        the full width. ``False``: keep the full width and shift the window back inside the patch.

    Returns
    -------
    subpatch : np.ndarray
        View of ``patch`` covered by the window.
    crop_neg : np.ndarray or int
        Per-axis amount (<= 0) removed at the low border. It is ``0`` when ``crop_patch`` is False.
    crop_pos : np.ndarray or int
        Per-axis amount (>= 0) removed at the high border. It is ``0`` when ``crop_patch`` is False.

    Notes
    -----
    NOTE(review): with ``crop_patch=False`` and a patch narrower than the window, the start index
    can become negative and wrap around.
    """
    center = np.asarray(coord)
    width = 2 * local_sub_patch_radius + 1
    limits = np.asarray(patch.shape)

    if not crop_patch:
        lo = np.maximum(0, center - local_sub_patch_radius)
        hi = lo + width
        # slide the window back inside when it runs past the far border
        offset = np.minimum(0, limits - hi)
        lo, hi = lo + offset, hi + offset
        crop_neg = crop_pos = 0
    else:
        lo = center - local_sub_patch_radius
        hi = lo + width
        crop_neg = np.minimum(lo, 0)
        crop_pos = np.maximum(0, hi - limits)
        # the window shrinks wherever it was cut; the crop vectors let callers locate the center
        lo, hi = lo - crop_neg, hi - crop_pos

    window = tuple(slice(a, b) for a, b in zip(lo, hi))
    return patch[window], crop_neg, crop_pos
def random_neighbor(shape, coord):
    """
    Sample a random neighbor coordinate different from the given coordinate.

    Candidates are drawn with :func:`sample_coords` until one differs from ``coord``.

    Parameters
    ----------
    shape : tuple of int
        Shape of the patch.
    coord : tuple of int
        Center coordinate.

    Returns
    -------
    rand_coords : list of int
        Random neighbor coordinate. When every sampled axis has size 1 no other pixel exists,
        and the only reachable coordinate is returned instead of looping forever.
    """
    rand_coords = sample_coords(shape, coord)
    # Only one pixel is reachable: rejection sampling would never terminate
    if all(size <= 1 for _, size in zip(coord, shape)):
        return rand_coords
    center = [int(c) for c in coord]
    # Compare whole coordinates as lists. The previous ``list == tuple`` test was always
    # False, so the centre pixel was never rejected.
    while rand_coords == center:
        rand_coords = sample_coords(shape, coord)
    return rand_coords


def sample_coords(shape, coord, sigma=4):
    """
    Sample random coordinates from a normal distribution centered at coord.

    Parameters
    ----------
    shape : tuple of int
        Shape of the patch.
    coord : tuple of int
        Center coordinate.
    sigma : float, optional
        Standard deviation for the normal distribution (default: 4).

    Returns
    -------
    coords : list of int
        Sampled coordinates, one per axis of ``coord``, each clipped to ``[0, shape[i] - 1]``.
    """
    return [normal_int(c, sigma, s) for c, s in zip(coord, shape)]


def normal_int(mean, sigma, w):
    """
    Sample an integer from a normal distribution and clip to valid range.

    Parameters
    ----------
    mean : float
        Mean of the normal distribution.
    sigma : float
        Standard deviation.
    w : int
        Maximum allowed value (exclusive).

    Returns
    -------
    int
        ``round(N(mean, sigma))`` clipped to ``[0, w - 1]``.
    """
    return int(np.clip(np.round(np.random.normal(mean, sigma)), 0, w - 1))
def mask_center(local_sub_patch_radius, ndims=2):
    """
    Build a boolean window of side ``2 * local_sub_patch_radius + 1`` that is ``False`` only at its center.

    Parameters
    ----------
    local_sub_patch_radius : int
        Half width of the window.
    ndims : int, optional
        Number of dimensions of the window, 2 (default) or 3.

    Returns
    -------
    mask : np.ndarray
        Result of ``numpy.ma.make_mask`` applied to the window.

    Raises
    ------
    NotImplementedError
        If ``ndims`` is neither 2 nor 3.
    """
    if ndims not in (2, 3):
        raise NotImplementedError()
    side = 2 * local_sub_patch_radius + 1
    neighbourhood = np.ones((side,) * ndims)
    # knock out the central element
    neighbourhood[(local_sub_patch_radius,) * ndims] = 0
    return ma.make_mask(neighbourhood)
def pm_normal_withoutCP(local_sub_patch_radius):
    """
    Return a function that replaces each selected pixel by the value of a random neighbor.

    The neighbor of every coordinate is drawn with :func:`random_neighbor`. ``local_sub_patch_radius``
    is accepted for interface compatibility and is not used.

    Parameters
    ----------
    local_sub_patch_radius : int
        Radius of the local subpatch (unused).

    Returns
    -------
    Callable
        Function ``(patch, coords, dims, structN2Vmask=None) -> list`` returning one value per
        coordinate in ``zip(*coords)``. ``dims`` and ``structN2Vmask`` are ignored.
    """

    def normal_withoutCP(patch, coords, dims, structN2Vmask=None):
        return [patch[tuple(random_neighbor(patch.shape, point))] for point in zip(*coords)]

    return normal_withoutCP
def pm_mean(local_sub_patch_radius):
    """
    Return a function that computes the mean of the local neighborhood (excluding center pixel).

    Parameters
    ----------
    local_sub_patch_radius : int
        Radius of the local subpatch.

    Returns
    -------
    Callable
        Function ``(patch, coords, dims, structN2Vmask=None) -> list`` returning, for each
        coordinate in ``zip(*coords)``, the mean of the selected neighbors in its window. By
        default the neighbors are the whole window without its center (:func:`mask_center`).
        A given ``structN2Vmask`` is used as the (boolean) selector instead. It is expected to
        have the window shape ``(2 * radius + 1,) * dims`` and is cropped at patch borders
        like the window.
    """

    def patch_mean(patch, coords, dims, structN2Vmask=None):
        patch_wo_center = mask_center(local_sub_patch_radius, ndims=dims)
        # ``is None`` rather than ``or``: the truth value of a multi-element array is ambiguous,
        # so ``structN2Vmask or ...`` raised for any real mask. Cast to bool so an integer 0/1
        # mask selects elements instead of being used as fancy indices.
        selector = patch_wo_center if structN2Vmask is None else np.asarray(structN2Vmask, dtype=bool)
        vals = []
        for coord in zip(*coords):
            sub_patch, crop_neg, crop_pos = get_subpatch(patch, coord, local_sub_patch_radius)
            # crop the selector exactly like the window was cropped at the borders
            slices = [slice(-n, s - p) for n, p, s in zip(crop_neg, crop_pos, patch_wo_center.shape)]  # type: ignore
            sub_patch_mask = selector[tuple(slices)]
            vals.append(np.mean(sub_patch[sub_patch_mask]))
        return vals

    return patch_mean
def pm_median(local_sub_patch_radius):
    """
    Return a function that computes the median of the local neighborhood (excluding center pixel).

    Parameters
    ----------
    local_sub_patch_radius : int
        Radius of the local subpatch.

    Returns
    -------
    Callable
        Function ``(patch, coords, dims, structN2Vmask=None) -> list`` returning, for each
        coordinate in ``zip(*coords)``, the median of the selected neighbors in its window. By
        default the neighbors are the whole window without its center (:func:`mask_center`).
        A given ``structN2Vmask`` is used as the (boolean) selector instead. It is expected to
        have the window shape ``(2 * radius + 1,) * dims`` and is cropped at patch borders
        like the window.
    """

    def patch_median(patch, coords, dims, structN2Vmask=None):
        patch_wo_center = mask_center(local_sub_patch_radius, ndims=dims)
        # ``is None`` rather than ``or``: the truth value of a multi-element array is ambiguous,
        # so ``structN2Vmask or ...`` raised for any real mask. Cast to bool so an integer 0/1
        # mask selects elements instead of being used as fancy indices.
        selector = patch_wo_center if structN2Vmask is None else np.asarray(structN2Vmask, dtype=bool)
        vals = []
        for coord in zip(*coords):
            sub_patch, crop_neg, crop_pos = get_subpatch(patch, coord, local_sub_patch_radius)
            # crop the selector exactly like the window was cropped at the borders
            slices = [slice(-n, s - p) for n, p, s in zip(crop_neg, crop_pos, patch_wo_center.shape)]  # type: ignore
            sub_patch_mask = selector[tuple(slices)]
            vals.append(np.median(sub_patch[sub_patch_mask]))
        return vals

    return patch_median
def pm_uniform_withCP(local_sub_patch_radius):
    """
    Return a function that picks a uniformly random pixel of the local window (center included).

    Parameters
    ----------
    local_sub_patch_radius : int
        Radius of the window read with :func:`get_subpatch` (cropped at patch borders).

    Returns
    -------
    Callable
        Function ``(patch, coords, dims, structN2Vmask=None) -> list`` returning one value per
        coordinate in ``zip(*coords)``. Only the first ``dims`` axes of the window are sampled.
        ``structN2Vmask`` is ignored.
    """

    def random_neighbor_withCP_uniform(patch, coords, dims, structN2Vmask=None):
        picked = []
        for point in zip(*coords):
            window, _, _ = get_subpatch(patch, point, local_sub_patch_radius)
            position = tuple(np.random.randint(0, size) for size in window.shape[:dims])
            picked.append(window[position])
        return picked

    return random_neighbor_withCP_uniform
def pm_uniform_withoutCP(local_sub_patch_radius):
    """
    Return a function that samples a random value from the local neighborhood (excluding center pixel).

    Parameters
    ----------
    local_sub_patch_radius : int
        Radius of the local subpatch.

    Returns
    -------
    Callable
        Function ``(patch, coords, dims, structN2Vmask=None) -> list`` returning, for each
        coordinate in ``zip(*coords)``, one randomly chosen selected neighbor of its window. By
        default the neighbors are the whole window without its center (:func:`mask_center`).
        A given ``structN2Vmask`` is used as the (boolean) selector instead. It is expected to
        have the window shape ``(2 * radius + 1,) * dims`` and is cropped at patch borders
        like the window.
    """

    def random_neighbor_withoutCP_uniform(patch, coords, dims, structN2Vmask=None):
        patch_wo_center = mask_center(local_sub_patch_radius, ndims=dims)
        # ``is None`` rather than ``or``: the truth value of a multi-element array is ambiguous,
        # so ``structN2Vmask or ...`` raised for any real mask. Cast to bool so an integer 0/1
        # mask selects elements instead of being used as fancy indices.
        selector = patch_wo_center if structN2Vmask is None else np.asarray(structN2Vmask, dtype=bool)
        vals = []
        for coord in zip(*coords):
            sub_patch, crop_neg, crop_pos = get_subpatch(patch, coord, local_sub_patch_radius)
            # crop the selector exactly like the window was cropped at the borders
            slices = [slice(-n, s - p) for n, p, s in zip(crop_neg, crop_pos, patch_wo_center.shape)]  # type: ignore
            sub_patch_mask = selector[tuple(slices)]
            vals.append(np.random.permutation(sub_patch[sub_patch_mask])[0])
        return vals

    return random_neighbor_withoutCP_uniform
def pm_normal_additive(pixel_gauss_sigma):
    """
    Return a function that perturbs each selected pixel with additive Gaussian noise.

    Parameters
    ----------
    pixel_gauss_sigma : float
        Standard deviation of the noise.

    Returns
    -------
    Callable
        Function ``(patch, coords, dims, structN2Vmask=None) -> list`` returning, for each
        coordinate in ``zip(*coords)``, a draw from ``N(patch[coord], pixel_gauss_sigma)``.
        ``dims`` and ``structN2Vmask`` are ignored.
    """

    def pixel_gauss(patch, coords, dims, structN2Vmask=None):
        return [np.random.normal(patch[tuple(point)], pixel_gauss_sigma) for point in zip(*coords)]

    return pixel_gauss
def pm_normal_fitted(local_sub_patch_radius):
    """
    Return a function that draws each value from a Gaussian fitted to the pixel's local window.

    Parameters
    ----------
    local_sub_patch_radius : int
        Radius of the window read with :func:`get_subpatch` (cropped at patch borders).

    Returns
    -------
    Callable
        Function ``(patch, coords, dims, structN2Vmask=None) -> list``. For each coordinate in
        ``zip(*coords)`` it samples ``N(mean, std)``, where mean and std are computed over the
        first ``dims`` axes of the window (center included). ``structN2Vmask`` is ignored.
    """

    def local_gaussian(patch, coords, dims, structN2Vmask=None):
        spatial_axes = tuple(range(dims))
        samples = []
        for point in zip(*coords):
            window, _, _ = get_subpatch(patch, point, local_sub_patch_radius)
            mu = np.mean(window, axis=spatial_axes)
            spread = np.std(window, axis=spatial_axes)
            samples.append(np.random.normal(mu, spread))
        return samples

    return local_gaussian
def pm_identity(local_sub_patch_radius):
    """
    Return a function that leaves the selected pixels unchanged.

    Parameters
    ----------
    local_sub_patch_radius : int
        Accepted for interface compatibility; not used.

    Returns
    -------
    Callable
        Function ``(patch, coords, dims, structN2Vmask=None) -> list`` returning ``patch[coord]``
        for each coordinate in ``zip(*coords)``. ``dims`` and ``structN2Vmask`` are ignored.
    """

    def identity(patch, coords, dims, structN2Vmask=None):
        return [patch[point] for point in zip(*coords)]

    return identity
def get_stratified_coords2D(box_size, shape):
    """
    Draw one uniformly placed point per ``box_size`` x ``box_size`` box tiling a 2D image.

    Boxes are visited row by row. Points that fall outside ``shape`` (possible in the last,
    partial boxes) are dropped.

    Parameters
    ----------
    box_size : int
        Side of each box.
    shape : tuple of int
        Shape of the 2D image (rows, columns).

    Returns
    -------
    tuple of lists
        ``(y_coords, x_coords)`` of the kept points.
    """
    rows = int(np.ceil(shape[0] / box_size))
    cols = int(np.ceil(shape[1] / box_size))
    y_coords, x_coords = [], []
    for i in range(rows):
        for j in range(cols):
            # the y offset is drawn before the x offset
            y = int(i * box_size + np.random.rand() * box_size)
            x = int(j * box_size + np.random.rand() * box_size)
            if y >= shape[0] or x >= shape[1]:
                continue
            y_coords.append(y)
            x_coords.append(x)
    return (y_coords, x_coords)
def get_stratified_coords3D(box_size, shape):
    """
    Draw one uniformly placed point per cubic box of side ``box_size`` tiling a 3D volume.

    Boxes are visited in z, then y, then x order. Points that fall outside ``shape`` (possible
    in the last, partial boxes) are dropped.

    Parameters
    ----------
    box_size : int
        Side of each box.
    shape : tuple of int
        Shape of the 3D volume (z, y, x).

    Returns
    -------
    tuple of lists
        ``(z_coords, y_coords, x_coords)`` of the kept points.
    """
    n_z = int(np.ceil(shape[0] / box_size))
    n_y = int(np.ceil(shape[1] / box_size))
    n_x = int(np.ceil(shape[2] / box_size))
    z_coords, y_coords, x_coords = [], [], []
    for i in range(n_z):
        for j in range(n_y):
            for k in range(n_x):
                # the z, y and x offsets are drawn in that order
                dz = np.random.rand() * box_size
                dy = np.random.rand() * box_size
                dx = np.random.rand() * box_size
                z = int(i * box_size + dz)
                y = int(j * box_size + dy)
                x = int(k * box_size + dx)
                inside = z < shape[0] and y < shape[1] and x < shape[2]
                if not inside:
                    continue
                z_coords.append(z)
                y_coords.append(y)
                x_coords.append(x)
    return (z_coords, y_coords, x_coords)
def apply_structN2Vmask(patch, coords, mask):
    """
    Apply a structN2V mask to a 2D patch, in place.

    Every point of ``coords`` is taken as the center of ``mask``. Each patch pixel covered by
    a ``1`` in the mask (the center excluded) is overwritten with a value drawn uniformly from
    ``[-2, 2)``. Positions falling outside the patch are clipped to its border.

    Parameters
    ----------
    patch : np.ndarray
        Input patch to modify.
    coords : np.ndarray or list
        Coordinates of mask centers, one sequence per axis (``(ys, xs)``).
    mask : np.ndarray
        Binary mask to apply. Note: its center element is set to ``0`` in place.
    """
    points = np.array(coords, dtype=int).T  # (npts, ndim)
    half = np.array(mask.shape) // 2
    # leave the center value alone (this writes into the caller's mask)
    mask[tuple(half)] = 0
    # offsets of every active mask element from the center, in C order: (nmask, ndim)
    offsets = np.argwhere(mask == 1) - half
    # every (offset, point) pair, offset-major: (nmask * npts, ndim)
    targets = (offsets[:, None, :] + points[None, :, :]).reshape(-1, mask.ndim)
    # stay within patch boundary
    targets = np.clip(targets, 0, np.array(patch.shape) - 1).astype(np.uint)
    # replace neighbouring pixels with random values from flat dist
    patch[tuple(targets.T)] = np.random.rand(targets.shape[0]) * 4 - 2
def apply_structN2Vmask3D(patch, coords, mask):
    """
    Apply a 2D structN2V mask slice-wise to a 3D patch, in place.

    Each point ``(z, y, x)`` of ``coords`` is the center of ``mask`` inside slice ``patch[z]``.
    Every pixel of that slice covered by a ``1`` in the mask (the center excluded) is
    overwritten with a value drawn uniformly from ``[-2, 2)``. In-slice positions falling
    outside the patch are clipped to its border. Only the slice a point belongs to is modified.
    Previously every (y, x) point was applied to the slice of every z in ``coords``.

    Parameters
    ----------
    patch : np.ndarray
        Input 3D patch (z, y, x) to modify.
    coords : np.ndarray or list
        Coordinates of mask centers as three parallel sequences ``(zs, ys, xs)``.
    mask : np.ndarray
        2D binary mask to apply. Note: when ``coords`` holds at least one point, its center
        element is set to ``0`` in place.
    """
    points = np.array(coords, dtype=int)  # (3, npts): z, y, x
    if points.size == 0:
        return
    z_coords, yx_coords = points[0], points[1:]

    ndim = mask.ndim
    center = np.array(mask.shape) // 2
    ## leave the center value alone
    mask[tuple(center.T)] = 0
    ## displacements from center: (ndim, nmask)
    dx = np.indices(mask.shape)[:, mask == 1] - center[:, None]
    ## combine all (y, x) coords with all displacements -> (nmask * npts, ndim), point index fastest
    mix = dx.T[..., None] + yx_coords[None]
    mix = mix.transpose([1, 0, 2]).reshape([ndim, -1]).T
    ## stay within the slice boundary
    mix = mix.clip(min=np.zeros(ndim), max=np.array(patch.shape[1:]) - 1).astype(np.uint)
    ## every displaced pixel stays in the slice of the point it comes from
    mix_z = np.tile(z_coords, dx.shape[1])
    ## replace neighbouring pixels with random values from flat dist
    patch[(mix_z,) + tuple(mix.T)] = np.random.rand(mix.shape[0]) * 4 - 2
def manipulate_val_data(
    X_val: NDArray,
    Y_val: NDArray,
    perc_pix: float = 0.198,
    shape: Tuple[int, ...] = (64, 64),
    value_manipulation: Callable = pm_uniform_withCP(5),
):
    """
    Manipulate validation data for self-supervised denoising.

    Applies a value manipulation strategy (e.g., uniform, mean, median) to a percentage of pixels in the
    validation set, as used in Noise2Void/structN2V validation. Both arrays are modified in place.

    For every sample ``j`` and channel ``c`` (``n_chan = X_val.shape[-1]``), one pixel per stratification box
    is selected, and then:

    - ``Y_val[j, ..., c]`` receives the original value of the selected pixels.
    - ``Y_val[j, ..., c + n_chan]`` is set to ``1`` at the selected pixels, so ``Y_val`` must have
      ``2 * n_chan`` channels.
    - ``X_val[j, ..., c]`` receives the output of ``value_manipulation`` at the selected pixels.

    All other entries of ``Y_val`` are set to ``0``.

    Parameters
    ----------
    X_val : NDArray
        Validation input data.
    Y_val : NDArray
        Validation target data (will be overwritten).
    perc_pix : float, optional
        Percentage of pixels to manipulate (default: 0.198). Must be positive.
    shape : tuple of int, optional
        Shape of the patch (default: (64, 64)). Only its length (2 or 3) is used, to pick the sampler.
    value_manipulation : Callable, optional
        Function to manipulate pixel values (default: pm_uniform_withCP(5)).

    Raises
    ------
    ValueError
        If ``shape`` is neither 2D nor 3D, or ``perc_pix`` is not positive.
    """
    dims = len(shape)
    if dims == 2:
        get_stratified_coords = get_stratified_coords2D
    elif dims == 3:
        get_stratified_coords = get_stratified_coords3D
    else:
        raise ValueError("Only 2D and 3D patches are supported, got shape {}".format(shape))
    if perc_pix <= 0:
        raise ValueError("'perc_pix' must be positive, got {}".format(perc_pix))

    # One pixel is drawn per box, so a box side of sqrt(100 / perc_pix) selects ~perc_pix % of
    # the pixels of a 2D image; the same formula is also used for 3D patches. ``np.round`` has
    # no ``dtype`` argument, so the rounded value is converted explicitly. The side is kept
    # >= 1 to avoid a zero division for very large perc_pix.
    box_size = max(1, int(np.round(np.sqrt(100 / perc_pix))))

    n_chan = X_val.shape[-1]

    # Assignment (not ``*= 0``) so NaN/inf entries are cleared too
    Y_val[...] = 0
    for j in tqdm(
        range(X_val.shape[0]),
        desc="Preparing validation data: ",
        disable=not is_main_process(),
    ):
        coords = get_stratified_coords(box_size=box_size, shape=np.array(X_val.shape)[1:-1])
        for c in range(n_chan):
            indexing = (j,) + coords + (c,)
            indexing_mask = (j,) + coords + (c + n_chan,)
            y_val = X_val[indexing]
            x_val = value_manipulation(X_val[j, ..., c], coords, dims)

            Y_val[indexing] = y_val
            Y_val[indexing_mask] = 1
            X_val[indexing] = x_val
def get_value_manipulation(n2v_manipulator, n2v_neighborhood_radius):
    """
    Return a value manipulation function for N2V/structN2V based on the given strategy.

    The factory ``pm_<n2v_manipulator>`` defined in this module is called with
    ``n2v_neighborhood_radius``.

    Parameters
    ----------
    n2v_manipulator : str
        Name of the manipulation strategy (e.g., 'uniform_withCP').
    n2v_neighborhood_radius : int
        Neighborhood radius for the manipulation. A string is parsed as a Python literal,
        matching the previous ``eval``-based behavior for values like ``"5"``.

    Returns
    -------
    Callable
        Value manipulation function.

    Raises
    ------
    ValueError
        If no ``pm_<n2v_manipulator>`` factory exists in this module.
    """
    import ast

    # Look the factory up by name instead of eval-ing a string built from the configuration,
    # which would run any code embedded in it.
    factory = globals().get("pm_{0}".format(n2v_manipulator))
    if not callable(factory):
        available = sorted(
            name[len("pm_"):] for name, obj in globals().items() if name.startswith("pm_") and callable(obj)
        )
        raise ValueError(
            "Unknown N2V manipulator '{}'. Available options: {}".format(n2v_manipulator, available)
        )

    radius = n2v_neighborhood_radius
    if isinstance(radius, str):
        radius = ast.literal_eval(radius)
    return factory(radius)