Source code for biapy.engine.denoising

import math
import torch
import numpy as np
import numpy.ma as ma
from tqdm import tqdm

from biapy.data.data_2D_manipulation import crop_data_with_overlap, merge_data_with_overlap
from biapy.data.data_3D_manipulation import crop_3D_data_with_overlap, merge_3D_data_with_overlap
from biapy.data.post_processing.post_processing import ensemble8_2d_predictions, ensemble16_3d_predictions
from biapy.engine.base_workflow import Base_Workflow
from biapy.utils.util import save_tif, pad_and_reflect
from biapy.utils.misc import to_pytorch_format, to_numpy_format, is_main_process
from biapy.data.pre_processing import denormalize, undo_norm_range01
from biapy.engine.metrics import n2v_loss_mse

[docs]class Denoising_Workflow(Base_Workflow): """ Denoising workflow where the goal is to remove noise from an image. More details in `our documentation <https://biapy.readthedocs.io/en/latest/workflows/denoising.html>`_. Parameters ---------- cfg : YACS configuration Running configuration. Job_identifier : str Complete name of the running job. device : Torch device Device used. args : argpase class Arguments used in BiaPy's call. """ def __init__(self, cfg, job_identifier, device, args, **kwargs): super(Denoising_Workflow, self).__init__(cfg, job_identifier, device, args, **kwargs) # From now on, no modification of the cfg will be allowed self.cfg.freeze() # Activations for each output channel: # channel number : 'activation' self.activations = [{':': 'Linear'}] # Workflow specific training variables self.mask_path = None self.load_Y_val = False
[docs] def define_metrics(self): """ Definition of self.metrics, self.metric_names and self.loss variables. """ print("Overriding 'LOSS.TYPE' to set it to N2V loss (masked MSE)") self.metrics = [torch.nn.MSELoss()] self.metric_names = ["MSE"] self.loss = n2v_loss_mse
[docs] def metric_calculation(self, output, targets, metric_logger=None): """ Execution of the metrics defined in :func:`~define_metrics` function. Parameters ---------- output : Torch Tensor Prediction of the model. targets : Torch Tensor Ground truth to compare the prediction with. metric_logger : MetricLogger, optional Class to be updated with the new metric(s) value(s) calculated. Returns ------- value : float Value of the metric for the given prediction. """ with torch.no_grad(): train_mse = self.metrics[0](output.squeeze(), targets[:,0].squeeze()) train_mse = train_mse.item() if not torch.isnan(train_mse) else 0 if metric_logger is not None: metric_logger.meters[self.metric_names[0]].update(train_mse) else: return train_mse
[docs] def process_sample(self, norm): """ Function to process a sample in the inference phase. Parameters ---------- norm : List of dicts Normalization used during training. Required to denormalize the predictions of the model. """ # Reflect data to complete the needed shape if self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE: reflected_orig_shape = self._X.shape self._X = np.expand_dims(pad_and_reflect(self._X[0], self.cfg.DATA.PATCH_SIZE, verbose=self.cfg.TEST.VERBOSE),0) if self.cfg.DATA.TEST.LOAD_GT: self._Y = np.expand_dims(pad_and_reflect(self._Y[0], self.cfg.DATA.PATCH_SIZE, verbose=self.cfg.TEST.VERBOSE),0) original_data_shape = self._X.shape # Crop if necessary if self._X.shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]: if self.cfg.PROBLEM.NDIM == '2D': obj = crop_data_with_overlap(self._X, self.cfg.DATA.PATCH_SIZE, data_mask=self._Y, overlap=self.cfg.DATA.TEST.OVERLAP, padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE) if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST: self._X, self._Y = obj else: self._X = obj del obj else: if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST: self._Y = self._Y[0] if self.cfg.TEST.REDUCE_MEMORY: self._X = crop_3D_data_with_overlap(self._X[0], self.cfg.DATA.PATCH_SIZE, overlap=self.cfg.DATA.TEST.OVERLAP, padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE, median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING) self._Y = crop_3D_data_with_overlap(self._Y, self.cfg.DATA.PATCH_SIZE, overlap=self.cfg.DATA.TEST.OVERLAP, padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE, median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING) else: obj = crop_3D_data_with_overlap(self._X[0], self.cfg.DATA.PATCH_SIZE, data_mask=self._Y, overlap=self.cfg.DATA.TEST.OVERLAP, padding=self.cfg.DATA.TEST.PADDING, verbose=self.cfg.TEST.VERBOSE, median_padding=self.cfg.DATA.TEST.MEDIAN_PADDING) if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST: self._X, self._Y = obj else: self._X = obj del obj # Predict each patch if self.cfg.TEST.AUGMENTATION: for k in tqdm(range(self._X.shape[0]), leave=False, disable=not is_main_process()): if self.cfg.PROBLEM.NDIM == '2D': p = ensemble8_2d_predictions(self._X[k], axis_order_back=self.axis_order_back, pred_func=self.model_call_func, axis_order=self.axis_order, device=self.device) else: p = ensemble16_3d_predictions(self._X[k], batch_size_value=self.cfg.TRAIN.BATCH_SIZE, axis_order_back=self.axis_order_back, pred_func=self.model_call_func, axis_order=self.axis_order, device=self.device) p = self.apply_model_activations(p) p = to_numpy_format(p, self.axis_order_back) if 'pred' not in locals(): pred = np.zeros((self._X.shape[0],)+p.shape[1:], dtype=self.dtype) pred[k] = p else: self._X = to_pytorch_format(self._X, self.axis_order, self.device) l = int(math.ceil(self._X.shape[0]/self.cfg.TRAIN.BATCH_SIZE)) for k in tqdm(range(l), leave=False, disable=not is_main_process()): top = (k+1)*self.cfg.TRAIN.BATCH_SIZE if (k+1)*self.cfg.TRAIN.BATCH_SIZE < self._X.shape[0] else self._X.shape[0] with torch.cuda.amp.autocast(): p = self.model(self._X[k*self.cfg.TRAIN.BATCH_SIZE:top]) p = to_numpy_format(self.apply_model_activations(p), self.axis_order_back) if 'pred' not in locals(): pred = np.zeros((self._X.shape[0],)+p.shape[1:], dtype=self.dtype) pred[k*self.cfg.TRAIN.BATCH_SIZE:top] = p del self._X, p # Reconstruct the predictions if original_data_shape[1:-1] != self.cfg.DATA.PATCH_SIZE[:-1]: if self.cfg.PROBLEM.NDIM == '3D': original_data_shape = original_data_shape[1:] f_name = merge_data_with_overlap if self.cfg.PROBLEM.NDIM == '2D' else merge_3D_data_with_overlap if self.cfg.TEST.REDUCE_MEMORY: pred = f_name(pred, original_data_shape[:-1]+(pred.shape[-1],), padding=self.cfg.DATA.TEST.PADDING, overlap=self.cfg.DATA.TEST.OVERLAP, verbose=self.cfg.TEST.VERBOSE) self._Y = f_name(self._Y, original_data_shape[:-1]+(self._Y.shape[-1],), padding=self.cfg.DATA.TEST.PADDING, overlap=self.cfg.DATA.TEST.OVERLAP, verbose=self.cfg.TEST.VERBOSE) else: obj = f_name(pred, original_data_shape[:-1]+(pred.shape[-1],), data_mask=self._Y, padding=self.cfg.DATA.TEST.PADDING, overlap=self.cfg.DATA.TEST.OVERLAP, verbose=self.cfg.TEST.VERBOSE) if self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST: pred, self._Y = obj else: pred = obj del obj if self.cfg.PROBLEM.NDIM == '3D': pred = np.expand_dims(pred,0) if self._Y is not None: self._Y = np.expand_dims(self._Y,0) if self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE: if self.cfg.PROBLEM.NDIM == '2D': pred = pred[:,-reflected_orig_shape[1]:,-reflected_orig_shape[2]:] if self._Y is not None: self._Y = self._Y[:,-reflected_orig_shape[1]:,-reflected_orig_shape[2]:] else: pred = pred[:,-reflected_orig_shape[1]:,-reflected_orig_shape[2]:,-reflected_orig_shape[3]:] if self._Y is not None: self._Y = self._Y[:,-reflected_orig_shape[1]:,-reflected_orig_shape[2]:,-reflected_orig_shape[3]:] # Undo normalization x_norm = norm[0] if x_norm['type'] == 'div': pred = undo_norm_range01(pred, x_norm) else: pred = denormalize(pred, x_norm['mean'], x_norm['std']) if x_norm['orig_dtype'] not in [np.dtype('float64'), np.dtype('float32'), np.dtype('float16')]: pred = np.round(pred) minpred = np.min(pred) pred = pred+abs(minpred) pred = pred.astype(x_norm['orig_dtype']) # Save image if self.cfg.PATHS.RESULT_DIR.PER_IMAGE != "": save_tif(pred, self.cfg.PATHS.RESULT_DIR.PER_IMAGE, self.processing_filenames, verbose=self.cfg.TEST.VERBOSE)
[docs] def torchvision_model_call(self, in_img, is_train=False): """ Call a regular Pytorch model. Parameters ---------- in_img : Tensor Input image to pass through the model. is_train : bool, optional Whether if the call is during training or inference. Returns ------- prediction : Tensor Image prediction. """ pass
[docs] def after_merge_patches(self, pred): """ Steps need to be done after merging all predicted patches into the original image. Parameters ---------- pred : Torch Tensor Model prediction. """ pass
[docs] def after_merge_patches_by_chunks_proccess_patch(self, filename): """ Place any code that needs to be done after merging all predicted patches into the original image but in the process made chunk by chunk. This function will operate patch by patch defined by ``DATA.PATCH_SIZE``. Parameters ---------- filename : List of str Filename of the predicted image H5/Zarr. """ pass
[docs] def after_full_image(self, pred): """ Steps that must be executed after generating the prediction by supplying the entire image to the model. Parameters ---------- pred : Torch Tensor Model prediction. """ pass
[docs] def after_all_images(self): """ Steps that must be done after predicting all images. """ super().after_all_images()
[docs] def normalize_stats(self, image_counter): """ Normalize statistics. Parameters ---------- image_counter : int Number of images to average the metrics. """ pass
[docs] def print_stats(self, image_counter): """ Print statistics. Parameters ---------- image_counter : int Number of images to call ``normalize_stats``. """ self.normalize_stats(image_counter)
#################################### # Adapted from N2V code: # # https://github.com/juglab/n2v # ####################################
[docs]def get_subpatch(patch, coord, local_sub_patch_radius, crop_patch=True): crop_neg, crop_pos = 0, 0 if crop_patch: start = np.array(coord) - local_sub_patch_radius end = start + local_sub_patch_radius * 2 + 1 # compute offsets left/up ... crop_neg = np.minimum(start, 0) # and right/down crop_pos = np.maximum(0, end-patch.shape) # correct for offsets, patch size shrinks if crop_*!=0 start -= crop_neg end -= crop_pos else: start = np.maximum(0, np.array(coord) - local_sub_patch_radius) end = start + local_sub_patch_radius * 2 + 1 shift = np.minimum(0, patch.shape - end) start += shift end += shift slices = [slice(s, e) for s, e in zip(start, end)] # return crop vectors for deriving correct center pixel locations later return patch[tuple(slices)], crop_neg, crop_pos
[docs]def random_neighbor(shape, coord): rand_coords = sample_coords(shape, coord) while np.any(rand_coords == coord): rand_coords = sample_coords(shape, coord) return rand_coords
[docs]def sample_coords(shape, coord, sigma=4): return [normal_int(c, sigma, s) for c, s in zip(coord, shape)]
[docs]def normal_int(mean, sigma, w): return int(np.clip(np.round(np.random.normal(mean, sigma)), 0, w - 1))
[docs]def mask_center(local_sub_patch_radius, ndims=2): size = local_sub_patch_radius*2 + 1 patch_wo_center = np.ones((size, ) * ndims) if ndims == 2: patch_wo_center[local_sub_patch_radius, local_sub_patch_radius] = 0 elif ndims == 3: patch_wo_center[local_sub_patch_radius, local_sub_patch_radius, local_sub_patch_radius] = 0 else: raise NotImplementedError() return ma.make_mask(patch_wo_center)
[docs]def pm_normal_withoutCP(local_sub_patch_radius): def normal_withoutCP(patch, coords, dims, structN2Vmask=None): vals = [] for coord in zip(*coords): rand_coords = random_neighbor(patch.shape, coord) vals.append(patch[tuple(rand_coords)]) return vals return normal_withoutCP
[docs]def pm_mean(local_sub_patch_radius): def patch_mean(patch, coords, dims, structN2Vmask=None): patch_wo_center = mask_center(local_sub_patch_radius, ndims=dims) vals = [] for coord in zip(*coords): sub_patch, crop_neg, crop_pos = get_subpatch(patch, coord, local_sub_patch_radius) slices = [slice(-n, s-p) for n, p, s in zip(crop_neg, crop_pos, patch_wo_center.shape)] sub_patch_mask = (structN2Vmask or patch_wo_center)[tuple(slices)] vals.append(np.mean(sub_patch[sub_patch_mask])) return vals return patch_mean
[docs]def pm_median(local_sub_patch_radius): def patch_median(patch, coords, dims, structN2Vmask=None): patch_wo_center = mask_center(local_sub_patch_radius, ndims=dims) vals = [] for coord in zip(*coords): sub_patch, crop_neg, crop_pos = get_subpatch(patch, coord, local_sub_patch_radius) slices = [slice(-n, s-p) for n, p, s in zip(crop_neg, crop_pos, patch_wo_center.shape)] sub_patch_mask = (structN2Vmask or patch_wo_center)[tuple(slices)] vals.append(np.median(sub_patch[sub_patch_mask])) return vals return patch_median
[docs]def pm_uniform_withCP(local_sub_patch_radius): def random_neighbor_withCP_uniform(patch, coords, dims, structN2Vmask=None): vals = [] for coord in zip(*coords): sub_patch, _, _ = get_subpatch(patch, coord, local_sub_patch_radius) rand_coords = [np.random.randint(0, s) for s in sub_patch.shape[0:dims]] vals.append(sub_patch[tuple(rand_coords)]) return vals return random_neighbor_withCP_uniform
[docs]def pm_uniform_withoutCP(local_sub_patch_radius): def random_neighbor_withoutCP_uniform(patch, coords, dims, structN2Vmask=None): patch_wo_center = mask_center(local_sub_patch_radius, ndims=dims) vals = [] for coord in zip(*coords): sub_patch, crop_neg, crop_pos = get_subpatch(patch, coord, local_sub_patch_radius) slices = [slice(-n, s-p) for n, p, s in zip(crop_neg, crop_pos, patch_wo_center.shape)] sub_patch_mask = (structN2Vmask or patch_wo_center)[tuple(slices)] vals.append(np.random.permutation(sub_patch[sub_patch_mask])[0]) return vals return random_neighbor_withoutCP_uniform
[docs]def pm_normal_additive(pixel_gauss_sigma): def pixel_gauss(patch, coords, dims, structN2Vmask=None): vals = [] for coord in zip(*coords): vals.append(np.random.normal(patch[tuple(coord)], pixel_gauss_sigma)) return vals return pixel_gauss
[docs]def pm_normal_fitted(local_sub_patch_radius): def local_gaussian(patch, coords, dims, structN2Vmask=None): vals = [] for coord in zip(*coords): sub_patch, _, _ = get_subpatch(patch, coord, local_sub_patch_radius) axis = tuple(range(dims)) vals.append(np.random.normal(np.mean(sub_patch, axis=axis), np.std(sub_patch, axis=axis))) return vals return local_gaussian
[docs]def pm_identity(local_sub_patch_radius): def identity(patch, coords, dims, structN2Vmask=None): vals = [] for coord in zip(*coords): vals.append(patch[coord]) return vals return identity
[docs]def get_stratified_coords2D(box_size, shape): box_count_Y = int(np.ceil(shape[0] / box_size)) box_count_X = int(np.ceil(shape[1] / box_size)) x_coords = [] y_coords = [] for i in range(box_count_Y): for j in range(box_count_X): y, x = np.random.rand() * box_size, np.random.rand() * box_size y = int(i * box_size + y) x = int(j * box_size + x) if (y < shape[0] and x < shape[1]): y_coords.append(y) x_coords.append(x) return (y_coords, x_coords)
[docs]def get_stratified_coords3D(box_size, shape): box_count_z = int(np.ceil(shape[0] / box_size)) box_count_Y = int(np.ceil(shape[1] / box_size)) box_count_X = int(np.ceil(shape[2] / box_size)) x_coords = [] y_coords = [] z_coords = [] for i in range(box_count_z): for j in range(box_count_Y): for k in range(box_count_X): z, y, x = np.random.rand() * box_size, np.random.rand() * box_size, np.random.rand() * box_size z = int(i * box_size + z) y = int(j * box_size + y) x = int(k * box_size + x) if (z < shape[0] and y < shape[1] and x < shape[2]): z_coords.append(z) y_coords.append(y) x_coords.append(x) return (z_coords, y_coords, x_coords)
[docs]def apply_structN2Vmask(patch, coords, mask): """ each point in coords corresponds to the center of the mask. then for point in the mask with value=1 we assign a random value """ coords = np.array(coords, dtype=int) ndim = mask.ndim center = np.array(mask.shape)//2 ## leave the center value alone mask[tuple(center.T)] = 0 ## displacements from center dx = np.indices(mask.shape)[:,mask==1] - center[:,None] ## combine all coords (ndim, npts,) with all displacements (ncoords,ndim,) mix = (dx.T[...,None] + coords[None]) mix = mix.transpose([1,0,2]).reshape([ndim,-1]).T ## stay within patch boundary mix = mix.clip(min=np.zeros(ndim),max=np.array(patch.shape)-1).astype(np.uint) ## replace neighbouring pixels with random values from flat dist patch[tuple(mix.T)] = np.random.rand(mix.shape[0])*4 - 2
[docs]def apply_structN2Vmask3D(patch, coords, mask): """ each point in coords corresponds to the center of the mask. then for point in the mask with value=1 we assign a random value """ z_coords = coords[0] coords = coords[1:] for z in z_coords: coords = np.array(coords, dtype=int) ndim = mask.ndim center = np.array(mask.shape)//2 ## leave the center value alone mask[tuple(center.T)] = 0 ## displacements from center dx = np.indices(mask.shape)[:,mask==1] - center[:,None] ## combine all coords (ndim, npts,) with all displacements (ncoords,ndim,) mix = (dx.T[...,None] + coords[None]) mix = mix.transpose([1,0,2]).reshape([ndim,-1]).T ## stay within patch boundary mix = mix.clip(min=np.zeros(ndim),max=np.array(patch.shape[1:])-1).astype(np.uint) ## replace neighbouring pixels with random values from flat dist patch[z][tuple(mix.T)] = np.random.rand(mix.shape[0])*4 - 2
[docs]def manipulate_val_data(X_val, Y_val, perc_pix=0.198, shape=(64, 64), value_manipulation=pm_uniform_withCP(5)): dims = len(shape) if dims == 2: box_size = np.round(np.sqrt(100 / perc_pix), dtype=int) get_stratified_coords = get_stratified_coords2D elif dims == 3: box_size = np.round(np.sqrt(100 / perc_pix), dtype=int) get_stratified_coords = get_stratified_coords3D n_chan = X_val.shape[-1] Y_val *= 0 for j in tqdm(range(X_val.shape[0]), desc='Preparing validation data: ', disable=not is_main_process()): coords = get_stratified_coords(box_size=box_size, shape=np.array(X_val.shape)[1:-1]) for c in range(n_chan): indexing = (j,) + coords + (c,) indexing_mask = (j,) + coords + (c + n_chan,) y_val = X_val[indexing] x_val = value_manipulation(X_val[j, ..., c], coords, dims) Y_val[indexing] = y_val Y_val[indexing_mask] = 1 X_val[indexing] = x_val
[docs]def get_value_manipulation(n2v_manipulator, n2v_neighborhood_radius): return eval('pm_{0}({1})'.format(n2v_manipulator, str(n2v_neighborhood_radius)))