# Source code for biapy.data.data_manipulation

"""
Data Manipulation Module for BiaPy.

This module provides a collection of functions for loading, processing, and manipulating
biological image data for deep learning applications. It supports both 2D and 3D data
formats, including common file types like TIFF, HDF5, Zarr, and NumPy arrays.

Key Functionalities:

- Loading training, validation, and test data from various formats
- Data preprocessing and normalization
- Image cropping and patching with overlap
- Data filtering based on various properties
- Cross-validation and train-test splitting
- Data augmentation and shape manipulation
- Format conversion (e.g., to one-hot encoding)
- Data saving in multiple formats

The module supports:

- Both 2D and 3D image data
- Multiple input formats (TIFF, HDF5, Zarr, NumPy arrays)
- Classification and segmentation workflows
- Memory-efficient loading of large datasets
- Parallel processing capabilities
- Data validation and consistency checks

Main Classes and Functions:

- load_and_prepare_train_data(): Main function for loading training data
- load_and_prepare_test_data(): Function for loading test data
- load_and_prepare_cls_test_data(): For classification test data
- samples_from_image_list(): Creates dataset from image list
- samples_from_zarr(): Handles Zarr/HDF5 datasets
- filter_samples_by_properties(): Filters data based on conditions
- img_to_onehot_encoding(): Converts masks to one-hot format
- save_tif(), save_npy_files(): Data saving utilities

Typical Workflow:

1. Load data using one of the load_and_prepare_* functions
2. Apply preprocessing/normalization
3. Filter or augment data as needed
4. Use in training or save processed data
"""

import os
import h5py
import torch
import tifffile
import imageio
import numpy as np
from typing import (
    List,
    Tuple,
    Dict,
    Optional,
    Callable,
    Any,
)
from numpy.typing import NDArray
from yacs.config import CfgNode as CN
from tqdm import tqdm
from sklearn.model_selection import train_test_split, StratifiedKFold
import torch.nn.functional as F
from skimage.transform import resize as sk_resize
import nibabel as nib

from biapy.data.dataset import BiaPyDataset, DatasetFile, DataSample, PatchCoords
from biapy.data.norm import normalize_image, normalize_mask
from biapy.utils.misc import is_main_process, os_walk_clean
from biapy.data.data_2D_manipulation import crop_data_with_overlap, ensure_2d_shape
from biapy.data.data_3D_manipulation import (
    crop_3D_data_with_overlap,
    extract_3D_patch_with_overlap_and_padding_yield,
    order_dimensions,
    ensure_3d_shape,
    looks_like_hdf5,
)


def _list_files_and_folders(path: str) -> Tuple[List[str], List[str]]:
    """Return ``(files, folders)`` found directly under ``path`` using a single ``os_walk_clean`` step."""
    entry = next(os_walk_clean(path))
    return entry[2], entry[1]


def _upscaled_gt_shape(
    crop_shape: Tuple[int, ...],
    y_upscaling: Tuple[int, ...],
    is_3d: bool,
    channels: int,
) -> Tuple[int, ...]:
    """
    Compute the ground-truth crop shape by multiplying the spatial axes of ``crop_shape`` by ``y_upscaling``.

    Parameters
    ----------
    crop_shape : 3D/4D int tuple
        Raw crop shape, ``(y, x, channels)`` for 2D or ``(z, y, x, channels)`` for 3D.

    y_upscaling : 2D/3D int tuple
        Per-axis factor applied to the spatial dimensions of ``crop_shape``.

    is_3d : bool
        Whether ``crop_shape`` describes 3D data.

    channels : int
        Value placed as the last (channel) dimension of the returned shape.

    Returns
    -------
    shape : tuple of int
        ``(y*fy, x*fx, channels)`` for 2D or ``(z*fz, y*fy, x*fx, channels)`` for 3D.
    """
    if is_3d:
        assert len(crop_shape) == 4 and len(y_upscaling) == 3
        spatial = crop_shape[:3]
    else:
        assert len(crop_shape) == 3 and len(y_upscaling) == 2
        spatial = crop_shape[:2]
    return tuple(size * factor for size, factor in zip(spatial, y_upscaling)) + (channels,)


def load_and_prepare_train_data(
    train_path: str,
    train_mask_path: str,
    train_in_memory: bool,
    train_ov: Tuple[float, ...],
    train_padding: Tuple[int, ...],
    val_path: str,
    val_mask_path: str,
    val_in_memory: bool,
    val_ov: Tuple[float, ...],
    val_padding: Tuple[int, ...],
    norm_module: Dict,
    crop_shape: Tuple[int, ...],
    cross_val: bool = False,
    cross_val_nsplits: int = 5,
    cross_val_fold: int = 1,
    val_split: float = 0.1,
    seed: int = 0,
    shuffle_val: bool = True,
    train_preprocess_f: Optional[Callable] = None,
    train_preprocess_cfg: Optional[CN] = None,
    train_filter_props: Optional[List[List[str]]] = None,
    train_filter_vals: Optional[List[List[float]]] = None,
    train_filter_signs: Optional[List[List[str]]] = None,
    val_preprocess_f: Optional[Callable] = None,
    val_preprocess_cfg: Optional[CN] = None,
    val_filter_props: Optional[List[List[str]]] = None,
    val_filter_vals: Optional[List[List[float]]] = None,
    val_filter_signs: Optional[List[List[str]]] = None,
    filter_by_entire_image: bool = True,
    norm_before_filter: bool = False,
    random_crops_in_DA: bool = False,
    y_upscaling: Tuple[int, ...] = (1, 1),
    gt_channels_expected: int = 1,
    reflect_to_complete_shape: bool = False,
    convert_to_rgb: bool = False,
    is_y_mask: bool = False,
    is_3d: bool = False,
    train_zarr_data_information: Optional[Dict] = None,
    val_zarr_data_information: Optional[Dict] = None,
    multiple_raw_images: bool = False,
    save_filtered_images: bool = True,
    save_filtered_images_dir: Optional[str] = None,
    save_filtered_images_num: int = 3,
) -> Tuple[BiaPyDataset, BiaPyDataset, BiaPyDataset, BiaPyDataset]:
    """
    Load training and validation data.

    Raw and ground truth (GT) samples are gathered from ``train_path``/``train_mask_path``. Validation data is either
    created from the training data (when ``val_split > 0`` or ``cross_val`` is ``True``) or gathered from
    ``val_path``/``val_mask_path``. Samples can optionally be filtered by their properties, and the images can be
    loaded into memory.

    Parameters
    ----------
    train_path : str
        Path to the training data.

    train_mask_path : str
        Path to the training data masks.

    train_in_memory : bool
        Whether the training data must be loaded in memory or not.

    train_ov : 2D/3D float tuple
        Amount of minimum overlap on each spatial dimension for train data. The values must be in range ``[0, 1)``,
        that is, from ``0%`` to ``99%`` of overlap. Shape is ``(y, x)`` for 2D or ``(z, y, x)`` for 3D.

    train_padding : 2D/3D int tuple
        Size of padding to be added on each axis to the train data. Shape is ``(y, x)`` for 2D or ``(z, y, x)``
        for 3D.

    val_path : str
        Path to the validation data. Ignored when validation is created from the training data.

    val_mask_path : str
        Path to the validation data masks. Ignored when validation is created from the training data.

    val_in_memory : bool
        Whether the validation data must be loaded in memory or not.

    val_ov : 2D/3D float tuple
        Amount of minimum overlap on each spatial dimension for val data. The values must be in range ``[0, 1)``,
        that is, from ``0%`` to ``99%`` of overlap. Shape is ``(y, x)`` for 2D or ``(z, y, x)`` for 3D.

    val_padding : 2D/3D int tuple
        Size of padding to be added on each axis to the val data. Shape is ``(y, x)`` for 2D or ``(z, y, x)``
        for 3D.

    norm_module : Dict
        Information about the normalization.

    crop_shape : 3D/4D int tuple
        Shape of the crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D.

    cross_val : bool, optional
        Whether to use cross validation or not.

    cross_val_nsplits : int, optional
        Number of folds for the cross validation.

    cross_val_fold : int, optional
        Number of the fold to be used as validation, in range ``[1, cross_val_nsplits]``.

    val_split : float, optional
        Fraction of the train data used as validation (value between ``0`` and ``1``).

    seed : int, optional
        Seed value. It is only used for the cross-validation split when ``shuffle_val`` is ``True``.

    shuffle_val : bool, optional
        Take random training examples to create validation data.

    train_preprocess_f : function, optional
        The train preprocessing function, necessary in case you want to apply any preprocessing.

    train_preprocess_cfg : dict, optional
        Configuration parameters for train preprocessing. Required when ``train_preprocess_f`` is given.

    train_filter_props : list of lists of str, optional
        Filter conditions to be applied to the train data. The three variables, ``train_filter_props``,
        ``train_filter_vals`` and ``train_filter_signs``, compose a list of conditions to remove samples.
        They are lists of lists of conditions. For instance, the conditions can be like this:
        ``[['A'], ['B','C']]``. A sample is removed if it satisfies the first list of conditions (only 'A' here)
        or if it satisfies both 'B' and 'C' (from the ``['B','C']`` list). In each sublist all the conditions
        must be satisfied. Each property description:

        * ``'foreground'`` is defined as the mask foreground percentage.
        * ``'mean'`` is defined as the mean value.
        * ``'min'`` is defined as the min value.
        * ``'max'`` is defined as the max value.
        * ``'diff'`` is defined as the difference between ground truth and raw images. Requires GT data.
        * ``'diff_by_min_max_ratio'`` is defined as the difference between ground truth and raw images multiplied
          by the ratio between raw image max and min.
        * ``'target_mean'`` is defined as the mean intensity value of the raw image targets. Requires GT data.
        * ``'target_min'`` is defined as the min intensity value of the raw image targets. Requires GT data.
        * ``'target_max'`` is defined as the max intensity value of the raw image targets. Requires GT data.
        * ``'diff_by_target_min_max_ratio'`` is defined as the difference between ground truth and raw images
          multiplied by the ratio between ground truth image max and min.

    train_filter_vals : list of lists of int/float, optional
        Values to compare each property in ``train_filter_props`` against.

    train_filter_signs : list of lists of str, optional
        Signs to do the comparison for train data filtering. Options: [``'gt'``, ``'ge'``, ``'lt'``, ``'le'``]
        that correspond to "greater than" (">"), "greater or equal" (">="), "less than" ("<") and
        "less or equal" ("<=") comparisons.

    val_preprocess_f : function, optional
        The validation preprocessing function, necessary in case you want to apply any preprocessing.

    val_preprocess_cfg : dict, optional
        Configuration parameters for validation preprocessing. Required when ``val_preprocess_f`` is given.

    val_filter_props : list of lists of str, optional
        Filter conditions to be applied to the validation data, with the same format and available properties as
        ``train_filter_props``. Only applied when validation data is read from ``val_path``. When validation is
        created from the training data, the train filters (applied before the split) affect it instead.

    val_filter_vals : list of lists of int/float, optional
        Values to compare each property in ``val_filter_props`` against.

    val_filter_signs : list of lists of str, optional
        Signs to do the comparison for validation data filtering. Same options as ``train_filter_signs``.

    filter_by_entire_image : bool, optional
        If filtering is done this will decide how the filtering will be done:

        * ``True``: apply filter image by image.
        * ``False``: apply filtering sample by sample. Each sample represents a patch within an image.

        Forced to ``True`` when ``random_crops_in_DA`` is selected.

    norm_before_filter : bool, optional
        Whether to apply normalization before filtering. Be aware then that the values for filtering may change.

    random_crops_in_DA : bool, optional
        To advise the method that no preparation of the data must be done, as random subvolumes will be created on
        DA, and the whole volume will be used for that.

    y_upscaling : 2D/3D int tuple, optional
        Upscaling to be done when loading Y data. Used for the super-resolution workflow. Must have 3 elements for
        3D data.

    gt_channels_expected : int, optional
        Expected number of channels in the GT.

    reflect_to_complete_shape : bool, optional
        Whether to increase the shape of the dimensions that have less size than the selected patch size, padding
        them with 'reflect'.

    convert_to_rgb : bool, optional
        In case RGB images are expected, e.g. if ``crop_shape`` channel is 3, those images that are grayscale are
        converted into RGB.

    is_y_mask : bool, optional
        Whether the data are masks. It is used to control the preprocessing of the data.

    is_3d : bool, optional
        Whether the expected images to read are 3D or not.

    train_zarr_data_information : dict, optional
        Additional information when using Zarr/H5 files for training. The following keys are expected:

        * ``"raw_path"``, str: path where the raw images reside within the zarr (used when
          ``multiple_data_within_zarr`` is ``True``).
        * ``"gt_path"``, str: path where the mask images reside within the zarr (used when
          ``multiple_data_within_zarr`` is ``True``).
        * ``"use_gt_path"``, bool: whether the GT path should be used or not.
        * ``"multiple_data_within_zarr"``, bool: whether your input Zarr contains the raw images and labels
          together or not.
        * ``"input_img_axes"``, tuple of int: order of the axes of the images.
        * ``"input_mask_axes"``, tuple of int: order of the axes of the masks.

    val_zarr_data_information : dict, optional
        Additional information when using Zarr/H5 files for validation. Same keys as
        ``train_zarr_data_information`` are expected.

    multiple_raw_images : bool, optional
        When a folder of folders for each image is expected. In each of those subfolders different versions of the
        same image are placed. Visit the following tutorial for a real use case and a more detailed description:
        `Light My Cells <https://biapy.readthedocs.io/en/latest/tutorials/image-to-image/lightmycells.html>`_.
        This is used when ``PROBLEM.IMAGE_TO_IMAGE.MULTIPLE_RAW_ONE_TARGET_LOADER`` is selected.

    save_filtered_images : bool, optional
        Whether to save or not filtered images.

    save_filtered_images_dir : str, optional
        Directory to save filtered images.

    save_filtered_images_num : int, optional
        Number of filtered images to save. Only works when ``save_filtered_images`` is ``True``.

    Returns
    -------
    X_train : BiaPyDataset
        Loaded train X dataset.

    Y_train : BiaPyDataset
        Loaded train Y dataset. When no GT is available it is a copy of ``X_train``.

    X_val : BiaPyDataset
        Loaded validation X dataset.

    Y_val : BiaPyDataset
        Loaded validation Y dataset. When no GT is available it is a copy of ``X_val``.

    Raises
    ------
    ValueError
        If no images are found, raw and GT data do not match, a preprocessing configuration is missing,
        ``cross_val_fold`` is out of range or no samples remain after loading.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    train_filter_props = [] if train_filter_props is None else train_filter_props
    train_filter_vals = [] if train_filter_vals is None else train_filter_vals
    train_filter_signs = [] if train_filter_signs is None else train_filter_signs
    val_filter_props = [] if val_filter_props is None else val_filter_props
    val_filter_vals = [] if val_filter_vals is None else val_filter_vals
    val_filter_signs = [] if val_filter_signs is None else val_filter_signs

    # The preprocessing function is only forwarded to the sample-gathering helpers when it resizes the images
    # (RESIZE.ENABLE). Presumably this is so the crop coordinates are computed on the resized shape.
    train_shape_will_change = False
    if train_preprocess_f:
        if train_preprocess_cfg is None:
            raise ValueError("'train_preprocess_cfg' needs to be provided with 'train_preprocess_f'")
        if train_preprocess_cfg.RESIZE.ENABLE:
            train_shape_will_change = True

    val_shape_will_change = False
    if val_preprocess_f:
        if val_preprocess_cfg is None:
            raise ValueError("'val_preprocess_cfg' needs to be provided with 'val_preprocess_f'")
        if val_preprocess_cfg.RESIZE.ENABLE:
            val_shape_will_change = True

    print("### LOAD ###")

    # Disable crops when random_crops_in_DA is selected
    crop = not random_crops_in_DA

    # Validation is carved out of the training data when a split fraction or cross validation is requested
    create_val_from_train = bool(val_split > 0 or cross_val)

    X_train, Y_train, X_val, Y_val = None, None, None, None

    # Create X_train and Y_train
    train_using_zarr = False
    if not multiple_raw_images:
        ids, fids = _list_files_and_folders(train_path)
        print("Gathering raw images for training data . . .")
        if len(ids) == 0 or looks_like_hdf5(ids[0]):  # Zarr (folders) or HDF5 files
            if len(ids) == 0 and len(fids) == 0:  # Trying Zarr
                raise ValueError("No images found in dir {}".format(train_path))

            # Working with Zarr
            if not is_3d:
                raise ValueError(
                    "Please check you data as only folders where found. In this case BiaPy expects Zarr "
                    "format images, but using these is only available for 3D problems"
                )
            train_using_zarr = True
            assert train_zarr_data_information
            X_train = samples_from_zarr(
                list_of_data=fids if len(ids) == 0 else ids,
                data_path=train_path,
                zarr_data_info=train_zarr_data_information,
                crop_shape=crop_shape,
                ov=train_ov,
                padding=train_padding,
                is_mask=False,
                is_3d=is_3d,
            )
        else:
            X_train = samples_from_image_list(
                list_of_data=ids,
                data_path=train_path,
                crop=crop,
                crop_shape=crop_shape,
                ov=train_ov,
                padding=train_padding,
                norm_module=norm_module,
                is_mask=False,
                is_3d=is_3d,
                reflect_to_complete_shape=reflect_to_complete_shape,
                convert_to_rgb=convert_to_rgb,
                preprocess_f=train_preprocess_f if train_shape_will_change else None,
                preprocess_cfg=train_preprocess_cfg if train_shape_will_change else None,
            )

        # Extract a list of all training gt images
        if train_mask_path:
            print("Gathering labels for training data . . .")
            ids, fids = _list_files_and_folders(train_mask_path)
            if len(ids) == 0 or looks_like_hdf5(ids[0]):  # Zarr (folders) or HDF5 files
                if len(ids) == 0 and len(fids) == 0:  # Trying Zarr
                    raise ValueError("No images found in dir {}".format(train_mask_path))
                assert train_zarr_data_information
                Y_train = samples_from_zarr(
                    list_of_data=fids if len(ids) == 0 else ids,
                    data_path=train_mask_path,
                    zarr_data_info=train_zarr_data_information,
                    crop_shape=crop_shape,
                    ov=train_ov,
                    padding=train_padding,
                    is_mask=True,
                    is_3d=is_3d,
                )
            else:
                # Calculate shape with upsampling, keeping the raw channel count
                real_shape = _upscaled_gt_shape(crop_shape, y_upscaling, is_3d, crop_shape[-1])
                Y_train = samples_from_image_list(
                    list_of_data=ids,
                    data_path=train_mask_path,
                    crop=crop,
                    crop_shape=real_shape,
                    ov=train_ov,
                    padding=train_padding,
                    norm_module=norm_module,
                    is_mask=True,
                    is_3d=is_3d,
                    reflect_to_complete_shape=reflect_to_complete_shape,
                    convert_to_rgb=convert_to_rgb,
                    preprocess_f=train_preprocess_f if train_shape_will_change else None,
                    preprocess_cfg=train_preprocess_cfg if train_shape_will_change else None,
                )
    else:
        if train_mask_path is None:
            raise ValueError("Implementation error. Contact BiaPy team")
        print("Gathering raw and label images train information . . .")
        X_train, Y_train = samples_from_image_list_multiple_raw_one_gt(
            data_path=train_path,
            gt_path=train_mask_path,
            crop_shape=crop_shape,
            ov=train_ov,
            padding=train_padding,
            norm_module=norm_module,
            crop=crop,
            is_3d=is_3d,
            reflect_to_complete_shape=reflect_to_complete_shape,
            convert_to_rgb=convert_to_rgb,
            preprocess_f=train_preprocess_f if train_shape_will_change else None,
            preprocess_cfg=train_preprocess_cfg if train_shape_will_change else None,
        )

    # Check that the shape of all images match
    if train_mask_path and Y_train:
        print("Checking training raw and label images' shapes . . .")
        if not multiple_raw_images and len(X_train.sample_list) != len(Y_train.sample_list):
            mistmatch_message = shape_mismatch_message(X_train, Y_train)
            m = (
                "Mistmatch between number of raw samples ({}) and number of corresponding masks ({}). Please, check "
                "that the raw format and labels have same shape. {}".format(
                    len(X_train.sample_list), len(Y_train.sample_list), mistmatch_message
                )
            )
            raise ValueError(m)
        for i, x_sample in enumerate(X_train.sample_list):
            xshape = x_sample.get_shape()
            # With multiple raw images several raw samples may point to the same GT sample
            gt_associated_id = x_sample.get_gt_associated_id()
            if gt_associated_id is not None:
                yshape = Y_train.sample_list[gt_associated_id].get_shape()
            else:
                yshape = Y_train.sample_list[i].get_shape()

            if not xshape:
                xshape = crop_shape[:-1]
            if not yshape:
                yshape = crop_shape[:-1]
            if is_3d:
                assert len(y_upscaling) == 3 and len(xshape) == 3
            else:
                assert len(y_upscaling) == 2 and len(xshape) == 2
            upsampled_x_shape = tuple(size * factor for size, factor in zip(xshape, y_upscaling))
            # Only the spatial part of the GT shape is compared (channels may differ)
            if upsampled_x_shape != yshape[: len(upsampled_x_shape)]:
                filepath = X_train.dataset_info[x_sample.fid]
                raise ValueError(
                    f"There is a mismatch between input image and its corresponding ground truth ({upsampled_x_shape} vs "
                    f"{yshape}). Please check the images. Specifically, the sample that doesn't match is within "
                    f"the file: {filepath})"
                )

    if len(train_filter_props) > 0:
        save_example_dir = None
        if save_filtered_images and save_filtered_images_dir:
            save_example_dir = os.path.join(save_filtered_images_dir, "train")
        filter_samples_by_properties(
            X_train,
            is_3d,
            train_filter_props,
            train_filter_vals,
            train_filter_signs,
            y_dataset=Y_train,
            crop_shape=crop_shape,
            reflect_to_complete_shape=reflect_to_complete_shape,
            filter_by_entire_image=filter_by_entire_image if not random_crops_in_DA else True,
            norm_before_filter=norm_before_filter,
            norm_module=norm_module,
            zarr_data_information=train_zarr_data_information if train_using_zarr else None,
            save_filtered_images=save_filtered_images,
            save_filtered_images_dir=save_example_dir,
            save_filtered_images_num=save_filtered_images_num,
        )

    val_using_zarr = False
    if create_val_from_train:
        print("Creating validation data from train . . .")
        has_train_gt = bool(train_mask_path and Y_train)
        # Create IDs based on images or samples, depending if we are working with Zarr images or not. This is
        # required to create the validation data
        x_train_files = [x.path for x in X_train.dataset_info]
        if len(x_train_files) == 1:
            print(
                "As only one sample was found BiaPy will assume that it is big enough to hold multiple training samples "
                "so the validation will be created extracting samples from it too."
            )
        if train_using_zarr or len(x_train_files) == 1:
            x_train_ids = np.array(range(0, len(X_train.sample_list)))
            if has_train_gt:
                y_train_ids = np.array(range(0, len(Y_train.sample_list)))
                if not multiple_raw_images and len(x_train_ids) != len(y_train_ids):
                    raise ValueError(
                        f"Raw image number ({len(x_train_ids)}) and ground truth file mismatch ({len(y_train_ids)}). Please check the data!"
                    )
            clean_by = "sample"
        else:
            x_train_ids = np.array(range(0, len(x_train_files)))
            if has_train_gt:
                y_train_ids = np.array(range(0, len(Y_train.dataset_info)))
                if not multiple_raw_images and len(x_train_ids) != len(y_train_ids):
                    raise ValueError(
                        f"Raw image number ({len(x_train_ids)}) and ground truth file mismatch ({len(y_train_ids)}). Please check the data!"
                    )
            clean_by = "image"

        val_path = train_path
        val_mask_path = train_mask_path
        val_zarr_data_information = train_zarr_data_information
        val_using_zarr = train_using_zarr
        if not cross_val:
            if has_train_gt:
                x_train_ids, x_val_ids, y_train_ids, y_val_ids = train_test_split(
                    x_train_ids,
                    y_train_ids,
                    test_size=val_split,
                    shuffle=shuffle_val,
                    random_state=seed,
                )
            else:
                x_train_ids, x_val_ids = train_test_split(
                    x_train_ids, test_size=val_split, shuffle=shuffle_val, random_state=seed
                )
        else:
            if not 1 <= cross_val_fold <= cross_val_nsplits:
                raise ValueError(
                    f"'cross_val_fold' ({cross_val_fold}) must be between 1 and 'cross_val_nsplits' "
                    f"({cross_val_nsplits})"
                )
            # scikit-learn rejects a random_state when shuffle=False, so only seed the shuffled case
            skf = StratifiedKFold(
                n_splits=cross_val_nsplits,
                shuffle=shuffle_val,
                random_state=seed if shuffle_val else None,
            )
            fold = 1
            y_len = len(y_train_ids) if has_train_gt else len(x_train_ids)
            for t_index, te_index in skf.split(np.zeros(len(x_train_ids)), np.zeros(y_len)):
                if cross_val_fold == fold:
                    x_train_ids, x_val_ids = x_train_ids[t_index], x_train_ids[te_index]
                    if has_train_gt:
                        y_train_ids, y_val_ids = y_train_ids[t_index], y_train_ids[te_index]
                    test_index = te_index.copy()
                    break
                fold += 1
            if len(test_index) > 5:
                print("Fold number {}. Printing the first 5 ids: {}".format(fold, test_index[:5]))
            else:
                print("Fold number {}. Indexes used in cross validation: {}".format(fold, test_index))

            x_val_ids = test_index.copy()

        # It's important to sort them in order to speed up load_images_to_dataset() process
        x_val_ids.sort()
        x_train_ids.sort()

        # Create validation data from train. Raw and GT ids are identical ranges, so x ids index both datasets.
        X_val = X_train.copy()
        X_val.clean_dataset(x_val_ids, clean_by=clean_by)
        if Y_train:
            Y_val = Y_train.copy()
            Y_val.clean_dataset(x_val_ids, clean_by=clean_by)

        # Remove val samples from train.
        X_train.clean_dataset(x_train_ids, clean_by=clean_by)
        if Y_train:
            Y_train.clean_dataset(x_train_ids, clean_by=clean_by)

        if clean_by == "sample":
            print(
                "Raw samples chosen for training (first 10 only): {}".format(str(x_train_ids[:10]).replace("]", "...]"))
            )
            print(
                "Raw samples chosen for validation (first 10 only): {}".format(
                    str(x_val_ids[:10]).replace("]", " ...]")
                )
            )
        else:
            print("Raw images chosen for training: {}".format([x.path for x in X_train.dataset_info]))
            print("Raw images chosen for validation: {}".format([x.path for x in X_val.dataset_info]))
    else:
        if not multiple_raw_images:
            print("Gathering raw images for validation data . . .")
            # Extract a list of all validation images
            val_ids, val_fids = _list_files_and_folders(val_path)
            # NOTE(review): unlike the training branch, HDF5 files are not routed to samples_from_zarr here;
            # confirm this is intended.
            if len(val_ids) == 0:
                if len(val_fids) == 0:  # Trying Zarr
                    raise ValueError("No images found in dir {}".format(val_path))

                # Working with Zarr
                if not is_3d:
                    raise ValueError("Zarr image handle is only available for 3D problems")
                val_using_zarr = True
                assert val_zarr_data_information
                X_val = samples_from_zarr(
                    list_of_data=val_fids,
                    data_path=val_path,
                    zarr_data_info=val_zarr_data_information,
                    crop_shape=crop_shape,
                    ov=val_ov,
                    padding=val_padding,
                    is_mask=False,
                    is_3d=is_3d,
                )
            else:
                X_val = samples_from_image_list(
                    list_of_data=val_ids,
                    data_path=val_path,
                    crop=crop,
                    crop_shape=crop_shape,
                    ov=val_ov,
                    padding=val_padding,
                    norm_module=norm_module,
                    is_mask=False,
                    is_3d=is_3d,
                    reflect_to_complete_shape=reflect_to_complete_shape,
                    convert_to_rgb=convert_to_rgb,
                    preprocess_f=val_preprocess_f if val_shape_will_change else None,
                    preprocess_cfg=val_preprocess_cfg if val_shape_will_change else None,
                )

            # Extract a list of all validation gt images
            if val_mask_path:
                print("Gathering labels for validation data . . .")
                val_gt_ids, val_gt_fids = _list_files_and_folders(val_mask_path)
                if len(val_gt_ids) == 0:
                    if len(val_gt_fids) == 0:  # Trying Zarr
                        raise ValueError("No images found in dir {}".format(val_mask_path))

                    # Working with Zarr
                    if not is_3d:
                        raise ValueError("Zarr image handle is only available for 3D problems")
                    assert val_zarr_data_information
                    # The GT entries must be listed from the GT directory, not the raw one
                    Y_val = samples_from_zarr(
                        list_of_data=val_gt_fids,
                        data_path=val_mask_path,
                        zarr_data_info=val_zarr_data_information,
                        crop_shape=crop_shape,
                        ov=val_ov,
                        padding=val_padding,
                        is_mask=True,
                        is_3d=is_3d,
                    )
                else:
                    if len(val_gt_ids) != len(val_ids):
                        raise ValueError(
                            "Number of validation raw images ({}) and validation labels ({}) must be the same. "
                            "Please, check your data. Raw directory: {} . Label directory: {}.".format(
                                len(val_ids), len(val_gt_ids), val_path, val_mask_path
                            )
                        )
                    # Calculate shape with upsampling, keeping the raw channel count
                    real_shape = _upscaled_gt_shape(crop_shape, y_upscaling, is_3d, crop_shape[-1])
                    Y_val = samples_from_image_list(
                        list_of_data=val_gt_ids,
                        data_path=val_mask_path,
                        crop=crop,
                        crop_shape=real_shape,
                        ov=val_ov,
                        padding=val_padding,
                        norm_module=norm_module,
                        is_mask=True,
                        is_3d=is_3d,
                        reflect_to_complete_shape=reflect_to_complete_shape,
                        convert_to_rgb=convert_to_rgb,
                        preprocess_f=val_preprocess_f if val_shape_will_change else None,
                        preprocess_cfg=val_preprocess_cfg if val_shape_will_change else None,
                    )
        else:
            if val_mask_path is None:
                raise ValueError("Implementation error. Contact BiaPy team")
            print("Gathering raw and label images for validation data . . .")
            X_val, Y_val = samples_from_image_list_multiple_raw_one_gt(
                data_path=val_path,
                gt_path=val_mask_path,
                crop_shape=crop_shape,
                ov=val_ov,
                padding=val_padding,
                norm_module=norm_module,
                crop=crop,
                is_3d=is_3d,
                reflect_to_complete_shape=reflect_to_complete_shape,
                convert_to_rgb=convert_to_rgb,
                preprocess_f=val_preprocess_f if val_shape_will_change else None,
                preprocess_cfg=val_preprocess_cfg if val_shape_will_change else None,
            )

        # Check that the shape of all images match
        if val_mask_path and Y_val:
            print("Checking validation raw and label images' shapes . . .")
            for i, x_sample in enumerate(X_val.sample_list):
                xshape = x_sample.get_shape()
                gt_associated_id = x_sample.get_gt_associated_id()
                if gt_associated_id is not None:
                    yshape = Y_val.sample_list[gt_associated_id].get_shape()
                else:
                    yshape = Y_val.sample_list[i].get_shape()

                if not xshape:
                    xshape = crop_shape[:-1]
                if not yshape:
                    yshape = crop_shape[:-1]
                if is_3d:
                    assert len(y_upscaling) == 3 and len(xshape) == 3
                else:
                    assert len(y_upscaling) == 2 and len(xshape) == 2
                upsampled_x_shape = tuple(size * factor for size, factor in zip(xshape, y_upscaling))
                if upsampled_x_shape != yshape[: len(upsampled_x_shape)]:
                    filepath = X_val.dataset_info[x_sample.fid]
                    raise ValueError(
                        f"There is a mismatch between input image and its corresponding ground truth ({upsampled_x_shape} vs "
                        f"{yshape}). Please check the images. Specifically, the sample that doesn't match is within "
                        f"the file {filepath})"
                    )

        if len(val_filter_props) > 0:
            save_example_dir = None
            if save_filtered_images and save_filtered_images_dir:
                save_example_dir = os.path.join(save_filtered_images_dir, "val")
            filter_samples_by_properties(
                X_val,
                is_3d,
                val_filter_props,
                val_filter_vals,
                val_filter_signs,
                y_dataset=Y_val,
                crop_shape=crop_shape,
                reflect_to_complete_shape=reflect_to_complete_shape,
                filter_by_entire_image=filter_by_entire_image if not random_crops_in_DA else True,
                norm_before_filter=norm_before_filter,
                norm_module=norm_module,
                zarr_data_information=val_zarr_data_information if val_using_zarr else None,
                save_filtered_images=save_filtered_images,
                save_filtered_images_dir=save_example_dir,
                save_filtered_images_num=save_filtered_images_num,
            )

        x_val_ids = np.array(range(0, len(X_val.sample_list)))
        if val_mask_path and Y_val:
            y_val_ids = np.array(range(0, len(Y_val.sample_list)))
            if not multiple_raw_images and len(x_val_ids) != len(y_val_ids):
                raise ValueError(
                    f"Raw image number ({len(x_val_ids)}) and ground truth file mismatch ({len(y_val_ids)}). Please check the data!"
                )

    if train_in_memory:
        print("* Loading train images . . .")
        load_images_to_dataset(
            dataset=X_train,
            crop_shape=crop_shape,
            reflect_to_complete_shape=reflect_to_complete_shape,
            convert_to_rgb=convert_to_rgb,
            preprocess_cfg=train_preprocess_cfg,
            preprocess_f=train_preprocess_f,
            is_3d=is_3d,
            zarr_data_information=train_zarr_data_information if train_using_zarr else None,
        )

        if train_mask_path and Y_train:
            print("* Loading train GT . . .")
            real_shape = _upscaled_gt_shape(crop_shape, y_upscaling, is_3d, gt_channels_expected)
            load_images_to_dataset(
                dataset=Y_train,
                crop_shape=real_shape,
                reflect_to_complete_shape=reflect_to_complete_shape,
                convert_to_rgb=convert_to_rgb,
                preprocess_cfg=train_preprocess_cfg,
                is_mask=is_y_mask,
                preprocess_f=train_preprocess_f,
                is_3d=is_3d,
                zarr_data_information=train_zarr_data_information if train_using_zarr else None,
            )

    if val_in_memory:
        print("* Loading validation images . . .")
        load_images_to_dataset(
            dataset=X_val,
            crop_shape=crop_shape,
            reflect_to_complete_shape=reflect_to_complete_shape,
            convert_to_rgb=convert_to_rgb,
            preprocess_cfg=val_preprocess_cfg,
            preprocess_f=val_preprocess_f,
            is_3d=is_3d,
            zarr_data_information=val_zarr_data_information if val_using_zarr else None,
        )

        if val_mask_path and Y_val:
            print("* Loading validation GT . . .")
            real_shape = _upscaled_gt_shape(crop_shape, y_upscaling, is_3d, gt_channels_expected)
            load_images_to_dataset(
                dataset=Y_val,
                crop_shape=real_shape,
                reflect_to_complete_shape=reflect_to_complete_shape,
                convert_to_rgb=convert_to_rgb,
                preprocess_cfg=val_preprocess_cfg,
                is_mask=is_y_mask,
                preprocess_f=val_preprocess_f,
                is_3d=is_3d,
                zarr_data_information=val_zarr_data_information if val_using_zarr else None,
            )

    print("### LOAD RESULTS ###")
    if not X_train.sample_list:
        raise ValueError("No training samples remain after loading. Please check the data and the filtering options.")
    if not X_val.sample_list:
        raise ValueError(
            "No validation samples remain after loading. Please check the data and the filtering options."
        )

    if X_train.sample_list[0].coords is None:
        print(
            "The samples have not been cropped so they may have different shapes. Because of that only first sample's shape will be printed!"
        )
    sample_shape = X_train.sample_list[0].get_shape()
    if not sample_shape:
        sample_shape = crop_shape
    X_data_shape = (len(X_train.sample_list),) + sample_shape
    print("*** Loaded train data shape is: {}".format(X_data_shape))
    if Y_train:
        sample_shape = Y_train.sample_list[0].get_shape()
        if not sample_shape:
            sample_shape = crop_shape
        Y_data_shape = (len(Y_train.sample_list),) + sample_shape
        print("*** Loaded train GT shape is: {}".format(Y_data_shape))
    else:
        Y_train = X_train.copy()

    sample_shape = X_val.sample_list[0].get_shape()
    if not sample_shape:
        sample_shape = crop_shape
    X_data_shape = (len(X_val.sample_list),) + sample_shape
    print("*** Loaded validation data shape is: {}".format(X_data_shape))
    if Y_val:
        sample_shape = Y_val.sample_list[0].get_shape()
        if not sample_shape:
            sample_shape = crop_shape
        Y_data_shape = (len(Y_val.sample_list),) + sample_shape
        print("*** Loaded validation GT shape is: {}".format(Y_data_shape))
    else:
        Y_val = X_val.copy()
    print("### END LOAD ###")
    return X_train, Y_train, X_val, Y_val
def load_and_prepare_test_data(
    test_path: str,
    test_mask_path: Optional[str],
    multiple_raw_images: Optional[bool] = False,
    test_zarr_data_information: Optional[Dict] = None,
) -> Tuple[BiaPyDataset, Optional[BiaPyDataset], List]:
    """
    Load test data.

    The data itself is not read: only the file list is gathered into ``BiaPyDataset`` objects whose samples have
    no crop coordinates.

    Parameters
    ----------
    test_path : str
        Path to the test data.

    test_mask_path : str, optional
        Path to the test data masks. When it is ``None``/empty no ground truth dataset is created.

    multiple_raw_images : bool, optional
        When a folder of folders for each image is expected. In each of those subfolder different versions of the
        same image are placed. Visit the following tutorial for a real use case and a more detailed description:
        `Light My Cells <https://biapy.readthedocs.io/en/latest/tutorials/image-to-image/lightmycells.html>`_.
        This is used when ``PROBLEM.IMAGE_TO_IMAGE.MULTIPLE_RAW_ONE_TARGET_LOADER`` is selected.

    test_zarr_data_information : dict, optional
        Additional information when using Zarr/H5 files for test. The following keys are expected:
            * ``"raw_path"``, str: path where the raw images reside within the zarr.
            * ``"gt_path"``, str: path where the mask images reside within the zarr.
            * ``"use_gt_path"``, bool: whether ``"gt_path"`` should be used for the masks. When ``False``
              ``"raw_path"`` is used for them too.

    Returns
    -------
    X_test : BiaPyDataset
        Test raw data: one ``DatasetFile`` per file found and one ``DataSample`` pointing to it.

    Y_test : BiaPyDataset, optional
        Test ground truth, built like ``X_test``. ``None`` when ``test_mask_path`` is not provided.

    test_filenames : list of str
        Names of the entries found in ``test_path``: its files, or its subfolders when there are no files
        (e.g. Zarr stores) or when ``multiple_raw_images`` is enabled.

    Raises
    ------
    ValueError
        If a given path does not exist or no images/folders are found in it.
    """
    print("### LOAD ###")

    if not os.path.exists(test_path):
        raise ValueError(f"{test_path} doesn't exist")
    if test_mask_path and not os.path.exists(test_mask_path):
        raise ValueError(f"{test_mask_path} doesn't exist")

    sample_list = []
    dataset_info = []
    y_sample_list = []
    y_dataset_info = []

    # Walk the test folder once: index 2 holds its files and index 1 its subfolders
    test_walk = next(os_walk_clean(test_path))
    ids, fids = test_walk[2], test_walk[1]

    if not multiple_raw_images or len(ids) > 0:
        # One sample per entry. Folders (e.g. Zarr stores) are only used when there are no plain files
        test_filenames = ids if len(ids) > 0 else fids
        if len(test_filenames) == 0:
            raise ValueError("No images found in dir {}".format(test_path))

        for i, filename in enumerate(test_filenames):
            dataset_info.append(DatasetFile(path=os.path.join(test_path, filename)))
            sample_data = DataSample(fid=i, coords=None)
            if test_zarr_data_information:
                sample_data.path_in_zarr = test_zarr_data_information["raw_path"]
            sample_list.append(sample_data)

        # Extract a list of all gt images
        if test_mask_path:
            mask_walk = next(os_walk_clean(test_mask_path))
            selected_ids = mask_walk[2] if len(mask_walk[2]) > 0 else mask_walk[1]
            if len(selected_ids) == 0:
                raise ValueError("No images found in dir {}".format(test_mask_path))

            for i, filename in enumerate(selected_ids):
                y_dataset_info.append(DatasetFile(path=os.path.join(test_mask_path, filename)))
                sample_data = DataSample(fid=i, coords=None)
                if test_zarr_data_information:
                    zarr_key = "gt_path" if test_zarr_data_information["use_gt_path"] else "raw_path"
                    sample_data.path_in_zarr = test_zarr_data_information[zarr_key]
                y_sample_list.append(sample_data)
    else:

        def _add_subfolder_files(root: str, folders: List[str], infos: List, samples: List) -> None:
            """Append every file found in each ``root/folder`` to ``infos`` and ``samples``."""
            for folder in folders:
                sample_path = os.path.join(root, folder)
                files = next(os_walk_clean(sample_path))[2]
                if len(files) == 0:
                    raise ValueError("No images found in dir {}".format(sample_path))
                for filename in files:
                    # ``fid`` must index the flat ``infos`` list shared by every subfolder, not the position
                    # of the file inside its own folder
                    samples.append(DataSample(fid=len(infos), coords=None))
                    infos.append(DatasetFile(path=os.path.join(sample_path, filename)))

        # Folder of folders: each subfolder holds several versions of the same image
        test_filenames = fids
        if len(test_filenames) == 0:
            raise ValueError("No folders found in dir {}".format(test_path))
        _add_subfolder_files(test_path, test_filenames, dataset_info, sample_list)

        # Extract a list of all test gt images
        if test_mask_path:
            mask_folders = next(os_walk_clean(test_mask_path))[1]
            if len(mask_folders) == 0:
                raise ValueError("No folders found in dir {}".format(test_mask_path))
            _add_subfolder_files(test_mask_path, mask_folders, y_dataset_info, y_sample_list)

    X_test = BiaPyDataset(dataset_info=dataset_info, sample_list=sample_list)
    Y_test = BiaPyDataset(dataset_info=y_dataset_info, sample_list=y_sample_list) if test_mask_path else None

    return X_test, Y_test, test_filenames
def load_and_prepare_cls_test_data(
    test_path: str,
    norm_module: Dict,
    use_val_as_test: bool,
    expected_classes: int,
    crop_shape: Tuple[int, ...],
    is_3d: bool = True,
    reflect_to_complete_shape: bool = True,
    convert_to_rgb: bool = False,
    use_val_as_test_info: Optional[Dict] = None,
):
    """
    Load classification test data.

    Parameters
    ----------
    test_path : str
        Path to the test data. Ignored when ``use_val_as_test`` is ``True``.

    norm_module : Dict
        Information about the normalization.

    use_val_as_test : bool
        Whether to use validation data as test.

    expected_classes : int
        Expected number of classes to be loaded.

    crop_shape : 3D/4D int tuple
        Shape of the crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D.

    is_3d: bool, optional
        Whether the data to load is expected to be 3D or not.

    reflect_to_complete_shape : bool, optional
        Whether to increase the shape of the dimensions that are smaller than the selected patch size by
        padding them with 'reflect'.

    convert_to_rgb : bool, optional
        In case RGB images are expected, e.g. if ``crop_shape`` channel is 3, those images that are grayscale are
        converted into RGB.

    use_val_as_test_info : dict, optional
        Additional information to create the test set based on the validation. Used when ``use_val_as_test`` is
        ``True``. The expected keys of the dictionary are as follows:

            * ``"cross_val_samples_ids"``, list of int: ids of the validation samples (out of the cross
              validation). When ``None`` they are recomputed here and stored back in this dictionary.
            * ``"train_path"``, str: training path, as the data must be extracted from there.
            * ``"selected_fold``", int: fold selected in cross validation (starting at ``1``).
            * ``"n_splits"``, int: folds to create in cross validation.
            * ``"shuffle"``, bool: whether to shuffle the data or not.
            * ``"seed"``, int: mathematical seed.

    Returns
    -------
    X_test : BiaPyDataset
        Loaded test data, as created by ``samples_from_class_list``.

    test_filenames : list
        The ``dataset_info`` entry of each test sample, in sample order.

    Raises
    ------
    ValueError
        If ``"selected_fold"`` does not correspond to any fold created by the cross validation.
    """
    print("### LOAD ###")

    if not use_val_as_test:
        path_to_process = test_path
    else:
        assert use_val_as_test_info, "'use_val_as_test_info' can not be None when 'use_val_as_test' is 'True'"
        path_to_process = use_val_as_test_info["train_path"]

    X_test = samples_from_class_list(
        data_path=path_to_process,
        norm_module=norm_module,
        expected_classes=expected_classes,
        crop_shape=crop_shape,
        is_3d=is_3d,
        reflect_to_complete_shape=reflect_to_complete_shape,
        convert_to_rgb=convert_to_rgb,
    )
    test_filenames = [X_test.dataset_info[x.fid] for x in X_test.sample_list]

    if use_val_as_test:
        # The test is the validation, and as it is only available when validation is obtained from train and when
        # cross validation is enabled, the test set files reside in the train folder
        assert use_val_as_test_info
        if use_val_as_test_info["cross_val_samples_ids"] is None:
            # Recreate the validation fold of the training phase. It must use the same class labels the training
            # split stratified on, otherwise a different fold would be selected.
            labels = np.array([X_test.dataset_info[s.fid].class_num for s in X_test.sample_list])
            shuffle = use_val_as_test_info["shuffle"]
            skf = StratifiedKFold(
                n_splits=use_val_as_test_info["n_splits"],
                shuffle=shuffle,
                # sklearn rejects a random_state when shuffle is disabled
                random_state=use_val_as_test_info["seed"] if shuffle else None,
            )
            selected_fold = use_val_as_test_info["selected_fold"]
            for fold, (_, te_index) in enumerate(skf.split(np.zeros(len(labels)), labels), start=1):
                if fold == selected_fold:
                    use_val_as_test_info["cross_val_samples_ids"] = te_index.copy()
                    break
            else:
                raise ValueError(
                    "Fold number {} not found among the {} folds created".format(
                        selected_fold, use_val_as_test_info["n_splits"]
                    )
                )

            if len(use_val_as_test_info["cross_val_samples_ids"]) > 5:
                print(
                    "Fold number {} used for test data. Printing the first 5 ids: {}".format(
                        selected_fold, use_val_as_test_info["cross_val_samples_ids"][:5]
                    )
                )
            else:
                print(
                    "Fold number {}. Indexes used in cross validation: {}".format(
                        selected_fold, use_val_as_test_info["cross_val_samples_ids"]
                    )
                )

        selected_ids = use_val_as_test_info["cross_val_samples_ids"]
        X_test.clean_dataset(selected_ids)
        test_filenames = [test_filenames[i] for i in selected_ids]

    return X_test, test_filenames
def load_data_from_dir(data_path: str, is_3d: bool = False) -> List[NDArray]:
    """
    Read every image found in ``data_path`` into memory.

    The files placed directly inside ``data_path`` are read. If there are none, its subfolders are read instead
    (e.g. Zarr stores, which are directories).

    Parameters
    ----------
    data_path : str
        Path to read the images from.

    is_3d : bool, optional
        Whether if the expected images to read are 3D or not.

    Returns
    -------
    all_images : list of ndarray
        Images as returned by ``read_img_as_ndarray``, in directory-listing order.

    Raises
    ------
    ValueError
        If ``data_path`` does not exist or contains neither files nor folders.
    """
    if not os.path.exists(data_path):
        raise ValueError(f"{data_path} folder does not exist")

    print(f"Loading images from {data_path} . . .")
    # Walk the folder once: index 2 holds its files and index 1 its subfolders
    walk_entry = next(os_walk_clean(data_path))
    files, folders = walk_entry[2], walk_entry[1]
    list_of_images = files if len(files) > 0 else folders
    if len(list_of_images) == 0:
        raise ValueError("No images found in dir {}".format(data_path))

    return [
        read_img_as_ndarray(os.path.join(data_path, id_), is_3d=is_3d)
        for id_ in tqdm(list_of_images, total=len(list_of_images))
    ]
def load_cls_data_from_dir(
    data_path: str,
    norm_module: Dict,
    expected_classes: int,
    crop_shape: Optional[Tuple[int, ...]],
    is_3d: bool = True,
    reflect_to_complete_shape: bool = True,
    convert_to_rgb: bool = False,
    preprocess_f: Optional[Callable] = None,
    preprocess_cfg: Optional[Dict] = None,
) -> BiaPyDataset:
    """
    Build a dataset from a classification directory tree and load its images into memory.

    The samples are first listed with ``samples_from_class_list`` and then filled with
    ``load_images_to_dataset``.

    Parameters
    ----------
    data_path : str
        Path to read the images from.

    norm_module : Dict
        Information about the normalization.

    expected_classes : int
        Expected number of classes to be loaded.

    crop_shape : 3D/4D int tuple, optional
        Shape of the crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D.

    is_3d : bool, optional
        Whether if the expected images to read are 3D or not.

    reflect_to_complete_shape : bool, optional
        Whether to increase the shape of the dimensions that are smaller than the selected patch size by
        padding them with 'reflect'.

    convert_to_rgb : bool, optional
        In case RGB images are expected, e.g. if ``crop_shape`` channel is 3, those images that are grayscale are
        converted into RGB.

    preprocess_f : function, optional
        Preprocessing function to apply while loading the images.

    preprocess_cfg : dict, optional
        Configuration parameters for ``preprocess_f``.

    Returns
    -------
    data_samples : BiaPyDataset
        Dataset created out of ``data_path``.
    """
    # Options that both the sample listing and the image loading receive unchanged
    shared_kwargs: Dict[str, Any] = {
        "crop_shape": crop_shape,
        "is_3d": is_3d,
        "reflect_to_complete_shape": reflect_to_complete_shape,
        "convert_to_rgb": convert_to_rgb,
    }

    dataset = samples_from_class_list(
        data_path=data_path,
        norm_module=norm_module,
        expected_classes=expected_classes,
        **shared_kwargs,
    )

    print(f"Loading images from {data_path}")
    load_images_to_dataset(
        dataset=dataset,
        preprocess_cfg=preprocess_cfg,
        preprocess_f=preprocess_f,
        **shared_kwargs,
    )
    return dataset
def load_and_prepare_train_data_cls(
    train_path: str,
    train_in_memory: bool,
    val_path: str,
    val_in_memory: bool,
    expected_classes: int,
    norm_module: Dict,
    crop_shape: Tuple[int, ...],
    cross_val: bool = False,
    cross_val_nsplits: int = 5,
    cross_val_fold: int = 1,
    val_split: float = 0.1,
    seed: int = 0,
    shuffle_val: bool = True,
    train_preprocess_f: Optional[Callable] = None,
    train_preprocess_cfg: Optional[Dict] = None,
    train_filter_props: Optional[List[List[str]]] = None,
    train_filter_vals: Optional[List[List[float | int]]] = None,
    train_filter_signs: Optional[List[List[str]]] = None,
    val_preprocess_f: Optional[Callable] = None,
    val_preprocess_cfg: Optional[Dict] = None,
    val_filter_props: Optional[List[List[str]]] = None,
    val_filter_vals: Optional[List[List[int | float]]] = None,
    val_filter_signs: Optional[List[List[str]]] = None,
    norm_before_filter: bool = False,
    reflect_to_complete_shape: bool = False,
    convert_to_rgb: bool = False,
    is_3d: bool = False,
):
    """
    Load data to train classification methods.

    Parameters
    ----------
    train_path : str
        Path to the training data.

    train_in_memory : bool
        Whether the train data must be loaded in memory or not.

    val_path : str
        Path to the validation data. Only used when the validation is not created from the training data.

    val_in_memory : bool
        Whether the validation data must be loaded in memory or not.

    expected_classes : int
        Expected number of classes to be loaded.

    norm_module : Dict
        Information about the normalization.

    crop_shape : 3D/4D int tuple
        Shape of the crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D.

    cross_val : bool, optional
        Whether to use cross validation or not.

    cross_val_nsplits : int, optional
        Number of folds for the cross validation.

    cross_val_fold : int, optional
        Number of the fold to be used as validation (starting at ``1``).

    val_split : float, optional
        % of the train data used as validation (value between ``0`` and ``1``). The validation is created from
        the training data when this is greater than ``0`` or ``cross_val`` is enabled.

    seed : int, optional
        Seed value. Only used to shuffle when ``shuffle_val`` is ``True``.

    shuffle_val : bool, optional
        Take random training examples to create validation data.

    train_preprocess_f : function, optional
        The train preprocessing function, is necessary in case you want to apply any preprocessing.

    train_preprocess_cfg : dict, optional
        Configuration parameters for train preprocessing, is necessary in case you want to apply any preprocessing.

    train_filter_props : list of lists of str, optional
        Filter conditions to be applied to the train data. The three variables, ``filter_props``,
        ``filter_vals`` and ``filter_signs`` will compose a list of conditions to remove the samples from the
        list. They are list of list of conditions. For instance, the conditions can be like this:
        ``[['A'], ['B','C']]``. Then, if the sample satisfies the first list of conditions, only 'A' in this
        first case (from ['A'] list), or satisfies 'B' and 'C' (from ['B','C'] list) it will be removed. In each
        sublist all the conditions must be satisfied. Available properties are: [``'foreground'``, ``'mean'``,
        ``'min'``, ``'max'``].

        Each property description:
        * ``'foreground'`` is defined as the mask foreground percentage.
        * ``'mean'`` is defined as the mean value.
        * ``'min'`` is defined as the min value.
        * ``'max'`` is defined as the max value.
        * ``'diff'`` is defined as the difference between ground truth and raw images. Require ``y_dataset``
          to be provided.
        * ``'diff_by_min_max_ratio'`` is defined as the difference between ground truth and raw images
          multiplied by the ratio between raw image max and min.
        * ``'target_mean'`` is defined as the mean intensity value of the raw image targets. Require
          ``y_dataset`` to be provided.
        * ``'target_min'`` is defined as the min intensity value of the raw image targets. Require
          ``y_dataset`` to be provided.
        * ``'target_max'`` is defined as the max intensity value of the raw image targets. Require
          ``y_dataset`` to be provided.
        * ``'diff_by_target_min_max_ratio'`` is defined as the difference between ground truth and raw images
          multiplied by the ratio between ground truth image max and min.

    train_filter_vals : list of lists of int/float, optional
        Represent the values of the properties listed in ``train_filter_props`` that the images need to satisfy
        to not be dropped.

    train_filter_signs : list of list of str, optional
        Signs to do the comparison for train data filtering. Options: [``'gt'``, ``'ge'``, ``'lt'``, ``'le'``]
        that correspond to "greater than", e.g. ">", "greater equal", e.g. ">=", "less than", e.g. "<", and
        "less equal" e.g. "<=" comparisons.

    val_preprocess_f : function, optional
        The validation preprocessing function, is necessary in case you want to apply any preprocessing.

    val_preprocess_cfg : dict, optional
        Configuration parameters for validation preprocessing, is necessary in case you want to apply any
        preprocessing.

    val_filter_props : list of lists of str, optional
        Filter conditions to be applied to the validation data. Same format and available properties as
        ``train_filter_props``. Only applied when the validation is not created from the training data.

    val_filter_vals : list of lists of int/float, optional
        Represent the values of the properties listed in ``val_filter_props`` that the images need to satisfy to
        not be dropped.

    val_filter_signs : list of list of str, optional
        Signs to do the comparison for validation data filtering. Same options as ``train_filter_signs``.

    norm_before_filter : bool, optional
        Forwarded to ``filter_samples_by_properties``.

    reflect_to_complete_shape : bool, optional
        Whether to increase the shape of the dimensions that are smaller than the selected patch size by
        padding them with 'reflect'.

    convert_to_rgb : bool, optional
        In case RGB images are expected, e.g. if ``crop_shape`` channel is 3, those images that are grayscale are
        converted into RGB.

    is_3d : bool, optional
        Whether if the expected images to read are 3D or not.

    Returns
    -------
    X_train : BiaPyDataset
        Loaded train data. Images are loaded into it when ``train_in_memory`` is ``True``.

    X_val : BiaPyDataset
        Loaded validation data. Images are loaded into it when ``val_in_memory`` is ``True``.

    x_val_ids : ndarray of int
        Indexes of the samples belonging to the validation. Used in cross-validation.

    Raises
    ------
    ValueError
        If ``cross_val`` is enabled and ``cross_val_fold`` does not correspond to any created fold.
    """
    print("### LOAD ###")

    # Avoid shared mutable defaults: None means "no filtering"
    train_filter_props = train_filter_props or []
    train_filter_vals = train_filter_vals or []
    train_filter_signs = train_filter_signs or []
    val_filter_props = val_filter_props or []
    val_filter_vals = val_filter_vals or []
    val_filter_signs = val_filter_signs or []

    # The validation is carved out of the training data when a split or cross validation is requested
    create_val_from_train = val_split > 0 or cross_val

    X_train = samples_from_class_list(
        data_path=train_path,
        norm_module=norm_module,
        expected_classes=expected_classes,
        crop_shape=crop_shape,
        is_3d=is_3d,
        reflect_to_complete_shape=reflect_to_complete_shape,
        convert_to_rgb=convert_to_rgb,
    )
    if len(train_filter_props) > 0:
        filter_samples_by_properties(
            X_train,
            is_3d,
            train_filter_props,
            train_filter_vals,
            train_filter_signs,
            crop_shape=crop_shape,
            reflect_to_complete_shape=reflect_to_complete_shape,
            norm_before_filter=norm_before_filter,
            norm_module=norm_module,
        )

    x_train_ids = np.arange(len(X_train.sample_list))
    if create_val_from_train:
        if not cross_val:
            x_train_ids, x_val_ids = train_test_split(
                x_train_ids, test_size=val_split, shuffle=shuffle_val, random_state=seed
            )
        else:
            # Class of each sample (ids index sample_list, so the labels must too)
            y_train_ids = np.array([X_train.dataset_info[s.fid].class_num for s in X_train.sample_list])
            # sklearn rejects a random_state when shuffle is disabled
            skf = StratifiedKFold(
                n_splits=cross_val_nsplits, shuffle=shuffle_val, random_state=seed if shuffle_val else None
            )
            test_index = None
            for fold, (t_index, te_index) in enumerate(skf.split(x_train_ids, y_train_ids), start=1):
                if fold == cross_val_fold:
                    x_train_ids, test_index = x_train_ids[t_index], te_index.copy()
                    break
            if test_index is None:
                raise ValueError(
                    "'cross_val_fold' ({}) must be between 1 and 'cross_val_nsplits' ({})".format(
                        cross_val_fold, cross_val_nsplits
                    )
                )

            if len(test_index) > 5:
                print("Fold number {}. Printing the first 5 ids: {}".format(cross_val_fold, test_index[:5]))
            else:
                print("Fold number {}. Indexes used in cross validation: {}".format(cross_val_fold, test_index))
            x_val_ids = test_index.copy()

        # Create validation data from train. It's important to sort them in order to speed up
        # load_images_to_dataset() process
        x_val_ids.sort()
        x_train_ids.sort()
        X_val = X_train.copy()
        X_train.clean_dataset(x_train_ids)
        X_val.clean_dataset(x_val_ids)
    else:
        X_val = samples_from_class_list(
            data_path=val_path,
            expected_classes=expected_classes,
            crop_shape=crop_shape,
            is_3d=is_3d,
            reflect_to_complete_shape=reflect_to_complete_shape,
            convert_to_rgb=convert_to_rgb,
            norm_module=norm_module,
        )
        if len(val_filter_props) > 0:
            filter_samples_by_properties(
                X_val,
                is_3d,
                val_filter_props,
                val_filter_vals,
                val_filter_signs,
                crop_shape=crop_shape,
                reflect_to_complete_shape=reflect_to_complete_shape,
                norm_before_filter=norm_before_filter,
                norm_module=norm_module,
            )
        x_val_ids = np.arange(len(X_val.sample_list))

    if train_in_memory:
        print("* Loading train images . . .")
        load_images_to_dataset(
            dataset=X_train,
            crop_shape=crop_shape,
            reflect_to_complete_shape=reflect_to_complete_shape,
            convert_to_rgb=convert_to_rgb,
            preprocess_cfg=train_preprocess_cfg,
            preprocess_f=train_preprocess_f,
            is_3d=is_3d,
        )

    if val_in_memory:
        print("* Loading validation images . . .")
        load_images_to_dataset(
            dataset=X_val,
            crop_shape=crop_shape,
            reflect_to_complete_shape=reflect_to_complete_shape,
            convert_to_rgb=convert_to_rgb,
            preprocess_cfg=val_preprocess_cfg,
            preprocess_f=val_preprocess_f,
            is_3d=is_3d,
        )

    print("### LOAD RESULTS ###")
    X_data_shape = (len(X_train.sample_list),) + tuple(crop_shape)
    print("*** Loaded train data shape is: {}".format(X_data_shape))
    X_data_shape = (len(X_val.sample_list),) + tuple(crop_shape)
    print("*** Loaded validation data shape is: {}".format(X_data_shape))
    print("### END LOAD ###")

    return X_train, X_val, x_val_ids
def samples_from_image_list(
    list_of_data: List[str],
    data_path: str,
    crop_shape: Tuple[int, ...],
    ov: Tuple[float, ...],
    padding: Tuple[int, ...],
    norm_module: Dict,
    crop: bool = True,
    is_mask: bool = False,
    is_3d: bool = True,
    reflect_to_complete_shape: bool = True,
    convert_to_rgb: bool = False,
    preprocess_f: Optional[Callable] = None,
    preprocess_cfg: Optional[Dict] = None,
) -> BiaPyDataset:
    """
    Create dataset samples from the given list. This function does not load the data.

    Each image is read once to validate it (channels, data range) and to compute its crop coordinates and
    normalization information. Only that metadata is kept in the returned dataset.

    Parameters
    ----------
    list_of_data : list of str
        Filenames of the images to read.

    data_path : str
        Directory of the images to read.

    crop_shape : 3D/4D int tuple
        Shape of the crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D.

    ov : 2D/3D float tuple
        Amount of minimum overlap on x and y dimensions. The values must be on range ``[0, 1)``, that is, ``0%``
        or ``99%`` of overlap. Shape is ``(y, x)`` for 2D or ``(z, y, x)`` for 3D.

    padding : 2D/3D int tuple
        Size of padding to be added on each axis. Shape is ``(y, x)`` for 2D or ``(z, y, x)`` for 3D.

    norm_module : Dict
        Information about the normalization.

    crop : bool, optional
        Whether if the data needs to be cropped or not. When ``False`` one sample without coordinates is created
        per image.

    is_mask : bool, optional
        Whether the data are masks. It is used to control the preprocessing of the data. Masks skip the
        ``crop_shape`` channel check and the data range check.

    is_3d: bool, optional
        Whether the data to load is expected to be 3D or not.

    reflect_to_complete_shape : bool, optional
        Whether to increase the shape of the dimensions that are smaller than the selected patch size by
        padding them with 'reflect'.

    convert_to_rgb : bool, optional
        In case RGB images are expected, e.g. if ``crop_shape`` channel is 3, those images that are grayscale are
        converted into RGB.

    preprocess_f : function, optional
        The preprocessing function, is necessary in case you want to apply any preprocessing.

    preprocess_cfg : dict, optional
        Configuration parameters for preprocessing, is necessary in case you want to apply any preprocessing.

    Returns
    -------
    dataset : BiaPyDataset
        Dataset.

    Raises
    ------
    ValueError
        If ``preprocess_f`` is given without ``preprocess_cfg``, or the images disagree in number of channels
        (with each other or with ``crop_shape``) or in data range.
    """
    if preprocess_f and preprocess_cfg is None:
        raise ValueError("'preprocess_cfg' needs to be provided with 'preprocess_f'")

    crop_funct = crop_3D_data_with_overlap if is_3d else crop_data_with_overlap

    sample_list = []
    dataset_info = []
    channel_expected = -1
    data_range_expected = -1
    for i, filename in enumerate(list_of_data):
        # Read image
        img_path = os.path.join(data_path, filename)
        img, _ = load_img_data(img_path, is_3d=is_3d)

        # Apply preprocessing
        if preprocess_f:
            if is_mask:
                img = preprocess_f(preprocess_cfg, y_data=[img], is_2d=not is_3d, is_y_mask=is_mask)[0]
            else:
                img = preprocess_f(preprocess_cfg, x_data=[img], is_2d=not is_3d)[0]

        if reflect_to_complete_shape:
            img = pad_and_reflect(img, crop_shape, verbose=False)

        if crop_shape[-1] == 3 and convert_to_rgb and img.shape[-1] != 3:
            img = np.repeat(img, 3, axis=-1)

        # Channel check within dataset images (the first image sets the reference)
        num_channels = img.shape[-1]
        if channel_expected == -1:
            channel_expected = num_channels
        if num_channels != channel_expected:
            raise ValueError(
                f"All images need to have the same number of channels and represent same information to "
                "ensure the deep learning model can be trained correctly. However, the current image (with "
                f"{num_channels} channels) appears to have a different number of channels than the first image "
                f"(with {channel_expected} channels) in the folder. Current image: {img_path}"
            )

        if not is_mask:
            # Channel check compared with crop_shape
            if crop_shape[-1] != num_channels:
                raise ValueError(
                    "Channel of the patch size given {} does not correspond with the loaded image {}. "
                    "Please, check the channels of the images!".format(crop_shape[-1], num_channels)
                )

            # Data range check (the first image sets the reference)
            drange = data_range(img)
            if data_range_expected == -1:
                data_range_expected = drange
            if data_range_expected != drange:
                raise ValueError(
                    f"All images must be within the same data range. However, the current image (with a "
                    f"range of {drange}) appears to be in a different data range than the first image (with a range "
                    f"of {data_range_expected}) in the folder. Current image: {img_path}"
                )

        original_data_shape = img.shape
        crop_coords = None
        if crop:
            # Only the coordinates are computed (load_data=False); 2D images get a leading batch axis
            crop_coords = crop_funct(
                np.expand_dims(img, axis=0) if not is_3d else img,
                crop_shape[:-1] + (num_channels,),
                overlap=ov,
                padding=padding,
                verbose=False,
                load_data=False,
            )

        # apply_norm=False: only the normalization information is needed, not the normalized image
        if is_mask:
            _, norm_info = normalize_mask(img, norm_module=norm_module, apply_norm=False)
        else:
            _, norm_info = normalize_image(img, norm_module=norm_module, apply_norm=False)

        dataset_info.append(DatasetFile(path=img_path, shape=original_data_shape, norm_info=norm_info))

        if crop_coords is None:
            sample_list.append(DataSample(fid=i, coords=None))
        else:
            sample_list.extend(DataSample(fid=i, coords=coords) for coords in crop_coords)  # type: ignore

    return BiaPyDataset(dataset_info=dataset_info, sample_list=sample_list)
def samples_from_zarr(
    list_of_data: List[str],
    data_path: str,
    zarr_data_info: Dict,
    crop_shape: Tuple[int, ...],
    ov: Tuple[float, ...],
    padding: Tuple[int, ...],
    is_mask: bool = False,
    is_3d: bool = True,
) -> BiaPyDataset:
    """
    Create dataset samples from the given list. This function does not load the data.

    Each file is opened to read its shape and to compute the coordinates of every patch. One ``DataSample``
    is created per patch.

    Parameters
    ----------
    list_of_data : list of str
        Filenames of the images to read.

    data_path : str
        Directory of the images to read.

    zarr_data_info : dict
        Additional information when using Zarr/H5 files for training. The following keys are expected:
            * ``"raw_path"``: path where the raw images reside within the zarr (used when
              ``multiple_data_within_zarr`` is ``True``).
            * ``"gt_path"``: path where the mask images reside within the zarr (used when
              ``multiple_data_within_zarr`` and ``use_gt_path`` are ``True``).
            * ``"use_gt_path"``: whether ``"gt_path"`` should be used to locate the masks within the zarr.
            * ``"multiple_data_within_zarr"``: Whether if your input Zarr contains the raw images and labels
              together or not.
            * ``"input_img_axes"``: order of the axes of the images.
            * ``"input_mask_axes"``: order of the axes of the masks.

    crop_shape : 3D/4D int tuple
        Shape of the crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D.

    ov : 2D/3D float tuple
        Amount of minimum overlap on x and y dimensions. The values must be on range ``[0, 1)``, that is, ``0%``
        or ``99%`` of overlap. Shape is ``(y, x)`` for 2D or ``(z, y, x)`` for 3D.

    padding : 2D/3D int tuple
        Size of padding to be added on each axis. Shape is ``(y, x)`` for 2D or ``(z, y, x)`` for 3D.

    is_mask : bool, optional
        Whether the data are masks. It selects the axes key and the path within the zarr, and skips the
        ``crop_shape`` channel check.

    is_3d: bool, optional
        Whether the data to load is expected to be 3D or not.

    Returns
    -------
    dataset : BiaPyDataset
        Dataset.

    Raises
    ------
    ValueError
        If the images disagree in number of channels (with each other or with ``crop_shape``), or no patch
        statistics can be computed for a file.
    """
    # Extract a list of all training samples within the Zarr
    sample_list = []
    dataset_info = []
    channel_expected = -1
    for i, filename in enumerate(list_of_data):
        sample_path = os.path.join(data_path, filename)
        data_within_zarr_path = None
        if zarr_data_info["multiple_data_within_zarr"]:
            if not is_mask:
                data_within_zarr_path = zarr_data_info["raw_path"]
            elif zarr_data_info["use_gt_path"]:
                data_within_zarr_path = zarr_data_info["gt_path"]

        data, file = load_img_data(sample_path, is_3d=is_3d, data_within_zarr_path=data_within_zarr_path)
        try:
            input_axes = zarr_data_info["input_img_axes"] if not is_mask else zarr_data_info["input_mask_axes"]
            if "C" in input_axes:
                pos = input_axes.index("C")
                channel = data.shape[pos] if pos < len(data.shape) else 1
            else:
                channel = 1

            # Channel check within dataset images (the first image sets the reference)
            if channel_expected == -1:
                channel_expected = channel
            if channel != channel_expected:
                raise ValueError(
                    f"All images need to have the same number of channels and represent same information to "
                    "ensure the deep learning model can be trained correctly. However, the current image (with "
                    f"{channel} channels) appears to have a different number of channels than the first image "
                    f"(with {channel_expected} channels) in the folder. Current image: {sample_path}"
                )

            # Channel check compared with crop_shape
            if not is_mask and crop_shape[-1] != channel:
                raise ValueError(
                    "Channel of the patch size given {} does not correspond with the loaded image {}. "
                    "Please, check the channels of the images!".format(crop_shape[-1], channel)
                )

            # Kept in a separate variable: crop_shape must stay untouched for the next files
            patch_shape = crop_shape[:-1] + (channel,)

            # Get the total patches so we can use tqdm so the user can see the time. With
            # return_only_stats=True the last item yielded holds the statistics.
            stats = None
            for stats in extract_3D_patch_with_overlap_and_padding_yield(
                data,
                patch_shape,
                input_axes,
                overlap=ov,
                padding=padding,
                total_ranks=1,
                rank=0,
                return_only_stats=True,
                load_data=False,
                verbose=False,
            ):
                pass
            if stats is None:
                raise ValueError(f"Could not compute the patch statistics of {sample_path}")
            total_patches, _, _ = stats  # type: ignore

            patch_generator = extract_3D_patch_with_overlap_and_padding_yield(
                data,
                patch_shape,
                input_axes,
                overlap=ov,
                padding=padding,
                total_ranks=1,
                load_data=False,
                rank=0,
                verbose=False,
            )

            dataset_info.append(
                DatasetFile(
                    path=sample_path,
                    shape=data.shape,
                    parallel_data=True,
                    input_axes=input_axes,
                )
            )

            for patch in tqdm(patch_generator, total=total_patches, disable=not is_main_process()):  # type: ignore
                coords, _, _, _ = patch  # type: ignore
                assert isinstance(coords, PatchCoords)
                sample = DataSample(fid=i, coords=coords)
                if data_within_zarr_path:
                    sample.path_in_zarr = data_within_zarr_path
                sample_list.append(sample)
        finally:
            # Release H5 handles even when a check above fails
            if isinstance(file, h5py.File):
                file.close()

    return BiaPyDataset(dataset_info=dataset_info, sample_list=sample_list)
def samples_from_image_list_multiple_raw_one_gt(
    data_path: str,
    gt_path: str,
    crop_shape: Tuple[int, ...],
    ov: Tuple[float, ...],
    padding: Tuple[int, ...],
    norm_module: Dict,
    crop: bool = True,
    is_3d: bool = True,
    reflect_to_complete_shape: bool = True,
    convert_to_rgb: bool = False,
    preprocess_f: Optional[Callable] = None,
    preprocess_cfg: Optional[Dict] = None,
) -> Tuple[BiaPyDataset, BiaPyDataset]:
    """
    Create dataset samples where several raw images share the same ground truth image.

    Every subfolder ``<id>`` of ``gt_path`` holds the ground truth image of one sample (only the first file
    listed in it is used). The folder ``data_path/<id>`` holds all the raw images associated with that ground
    truth. This function does not keep the data. It only records paths, shapes, normalization information
    and, if ``crop`` is enabled, the coordinates of the patches.

    Parameters
    ----------
    data_path : str
        Directory of the images to read.

    gt_path : str
        Directory to read ground truth images from.

    crop_shape : 3D/4D int tuple
        Shape of the crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D.

    ov : 2D/3D float tuple
        Amount of minimum overlap on x and y dimensions. The values must be on range ``[0, 1)``, that is,
        ``0%`` or ``99%`` of overlap. Shape is ``(y, x)`` for 2D or ``(z, y, x)`` for 3D.

    padding : 2D/3D int tuple
        Size of padding to be added on each axis. Shape is ``(y, x)`` for 2D or ``(z, y, x)`` for 3D.

    norm_module : Dict
        Information about the normalization.

    crop : bool, optional
        Whether the data needs to be cropped or not. An image whose shape already equals the patch shape is
        never cropped.

    is_3d: bool, optional
        Whether the data to load is expected to be 3D or not.

    reflect_to_complete_shape : bool, optional
        Whether to increase the shape of the dimensions that have less size than the selected patch size,
        padding them with 'reflect'.

    convert_to_rgb : bool, optional
        In case RGB images are expected, e.g. if ``crop_shape`` channel is 3, those images that are grayscale are
        converted into RGB.

    preprocess_f : function, optional
        The preprocessing function, is necessary in case you want to apply any preprocessing.

    preprocess_cfg : dict, optional
        Configuration parameters for preprocessing, is necessary in case you want to apply any preprocessing.

    Returns
    -------
    dataset : BiaPyDataset
        X dataset. Each of its samples stores in ``gt_associated_id`` the index of the paired sample within
        ``gt_dataset.sample_list``.

    gt_dataset : BiaPyDataset
        Y dataset.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        * ``preprocess_f`` is given without ``preprocess_cfg``.
        * A folder or image is missing.
        * The images differ in number of channels or data range.
        * Their channels do not match ``crop_shape``.
        * A raw image and its ground truth are not split into the same number of samples.
    """
    if preprocess_f and preprocess_cfg is None:
        raise ValueError("'preprocess_cfg' needs to be provided with 'preprocess_f'")

    crop_funct = crop_3D_data_with_overlap if is_3d else crop_data_with_overlap
    data_gt_path = next(os_walk_clean(gt_path))[1]
    if len(data_gt_path) == 0:
        raise ValueError("No image folder found in dir {}".format(gt_path))

    sample_list = []
    dataset_info = []
    gt_sample_list = []
    gt_dataset_info = []

    # Reference values taken from the first image, used to check that the remaining ones are consistent
    gt_sample_channel_expected = -1
    gt_sample_data_range_expected = -1
    raw_sample_channel_expected = -1
    raw_sample_data_range_expected = -1
    for id_ in tqdm(data_gt_path, total=len(data_gt_path), disable=not is_main_process()):
        # Read the ground truth image: the first file listed within its folder
        gt_folder = os.path.join(gt_path, id_)
        gt_files = next(os_walk_clean(gt_folder))[2]
        if len(gt_files) == 0:
            raise ValueError("No image found in dir {}".format(gt_folder))
        gt_sample_path = os.path.join(gt_path, id_, gt_files[0])
        gt_sample, _ = load_img_data(gt_sample_path, is_3d=is_3d)

        # Apply preprocessing
        if preprocess_f:
            gt_sample = preprocess_f(preprocess_cfg, x_data=[gt_sample], is_2d=not is_3d)[0]

        if reflect_to_complete_shape:
            gt_sample = pad_and_reflect(gt_sample, crop_shape, verbose=False)

        if crop_shape[-1] == 3 and convert_to_rgb and gt_sample.shape[-1] != 3:
            gt_sample = np.repeat(gt_sample, 3, axis=-1)

        # Channel check within dataset images
        if gt_sample_channel_expected == -1:
            gt_sample_channel_expected = gt_sample.shape[-1]
        if gt_sample.shape[-1] != gt_sample_channel_expected:
            raise ValueError(
                f"All images need to have the same number of channels and represent same information to "
                "ensure the deep learning model can be trained correctly. However, the current image (with "
                f"{gt_sample.shape[-1]} channels) appears to have a different number of channels than the first image"
                f" (with {gt_sample_channel_expected} channels) in the folder. Current image: {gt_sample_path}"
            )

        # Channel check compared with crop_shape
        if crop_shape and crop_shape[-1] != gt_sample.shape[-1]:
            raise ValueError(
                "Channel of the patch size given {} does not correspond with the loaded image {}. "
                "Please, check the channels of the images!".format(crop_shape[-1], gt_sample.shape[-1])
            )

        # Data range check
        drange = data_range(gt_sample)
        if gt_sample_data_range_expected == -1:
            gt_sample_data_range_expected = drange
        if gt_sample_data_range_expected != drange:
            raise ValueError(
                f"All images must be within the same data range. However, the current image (with a "
                f"range of {drange}) appears to be in a different data range than the first image (with a range "
                f"of {gt_sample_data_range_expected}) in the folder. Current image: {gt_sample_path}"
            )

        # Locate all the raw images associated with the current gt sample
        associated_raw_image_dir = os.path.join(data_path, id_)
        if not os.path.exists(associated_raw_image_dir):
            raise ValueError(f"Folder {associated_raw_image_dir} with multiple raw images not found.")
        raw_samples = next(os_walk_clean(associated_raw_image_dir))[2]
        if len(raw_samples) == 0:
            raise ValueError("No image found in dir {}".format(associated_raw_image_dir))

        # Patches of the gt image. The criterion to crop is the same one used for the raw images below, so
        # both are split in the same way (the former tuple check "<= or >=" was always true).
        original_data_shape = gt_sample.shape
        gt_patch_shape = crop_shape[:-1] + (gt_sample.shape[-1],)
        gt_crop_coords = None
        if crop and gt_sample.shape != gt_patch_shape:
            gt_crop_coords = crop_funct(
                np.expand_dims(gt_sample, axis=0) if not is_3d else gt_sample,
                gt_patch_shape,
                overlap=ov,
                padding=padding,
                verbose=False,
                load_data=False,
            )
            gt_tot_samples_to_insert = len(gt_crop_coords)
        else:
            gt_tot_samples_to_insert = 1

        _, norm_info = normalize_image(gt_sample, norm_module=norm_module, apply_norm=False)
        gt_dataset_info.append(DatasetFile(path=gt_sample_path, shape=original_data_shape, norm_info=norm_info))
        gt_fid = len(gt_dataset_info) - 1

        # Position, within gt_sample_list, of the first gt sample of this folder
        gt_first_id = len(gt_sample_list)
        for i in range(gt_tot_samples_to_insert):
            coords = None
            if gt_crop_coords is not None:
                coords = gt_crop_coords[i]
                assert isinstance(coords, PatchCoords)
            gt_sample_list.append(DataSample(fid=gt_fid, coords=coords))

        # For each gt sample there are multiple raw images
        for raw_sample_id in raw_samples:
            raw_sample_path = os.path.join(associated_raw_image_dir, raw_sample_id)
            raw_sample, _ = load_img_data(raw_sample_path, is_3d=is_3d)

            # Apply preprocessing
            if preprocess_f:
                raw_sample = preprocess_f(preprocess_cfg, x_data=[raw_sample], is_2d=not is_3d)[0]

            if reflect_to_complete_shape:
                raw_sample = pad_and_reflect(raw_sample, crop_shape, verbose=False)

            if crop_shape[-1] == 3 and convert_to_rgb and raw_sample.shape[-1] != 3:
                raw_sample = np.repeat(raw_sample, 3, axis=-1)

            # Channel check within dataset images
            if raw_sample_channel_expected == -1:
                raw_sample_channel_expected = raw_sample.shape[-1]
            if raw_sample.shape[-1] != raw_sample_channel_expected:
                raise ValueError(
                    f"All images need to have the same number of channels and represent same information to "
                    "ensure the deep learning model can be trained correctly. However, the current image (with "
                    f"{raw_sample.shape[-1]} channels) appears to have a different number of channels than the first image"
                    f" (with {raw_sample_channel_expected} channels) in the folder. Current image: {raw_sample_path}"
                )

            # Channel check compared with crop_shape
            if crop_shape and crop_shape[-1] != raw_sample.shape[-1]:
                raise ValueError(
                    "Channel of the patch size given {} does not correspond with the loaded image {}. "
                    "Please, check the channels of the images!".format(crop_shape[-1], raw_sample.shape[-1])
                )

            # Data range check
            drange = data_range(raw_sample)
            if raw_sample_data_range_expected == -1:
                raw_sample_data_range_expected = drange
            if raw_sample_data_range_expected != drange:
                raise ValueError(
                    f"All images must be within the same data range. However, the current image (with a "
                    f"range of {drange}) appears to be in a different data range than the first image (with a range "
                    f"of {raw_sample_data_range_expected}) in the folder. Current image: {raw_sample_path}"
                )

            original_data_shape = raw_sample.shape
            raw_patch_shape = crop_shape[:-1] + (raw_sample.shape[-1],)
            raw_crop_coords = None
            if crop and raw_sample.shape != raw_patch_shape:
                raw_crop_coords = crop_funct(
                    np.expand_dims(raw_sample, axis=0) if not is_3d else raw_sample,
                    raw_patch_shape,
                    overlap=ov,
                    padding=padding,
                    verbose=False,
                    load_data=False,
                )
                tot_samples_to_insert = len(raw_crop_coords)
            else:
                tot_samples_to_insert = 1

            # The i-th raw patch is paired with the i-th gt patch of this folder, so both images must be split
            # into the same number of samples; otherwise the pairing would point to unrelated gt samples
            if tot_samples_to_insert != gt_tot_samples_to_insert:
                raise ValueError(
                    f"Raw image {raw_sample_path} is split into {tot_samples_to_insert} samples but its ground truth "
                    f"{gt_sample_path} is split into {gt_tot_samples_to_insert}. They need to be split into the same "
                    "number of samples to be paired. Please, check that both images have the same shape."
                )

            _, norm_info = normalize_image(raw_sample, norm_module=norm_module, apply_norm=False)
            dataset_info.append(DatasetFile(path=raw_sample_path, shape=original_data_shape, norm_info=norm_info))
            raw_fid = len(dataset_info) - 1
            for i in range(tot_samples_to_insert):
                sample_list.append(
                    DataSample(
                        fid=raw_fid,
                        coords=raw_crop_coords[i] if raw_crop_coords is not None else None,  # type: ignore
                        gt_associated_id=gt_first_id + i,  # this extra variable is added
                    )
                )

    return (
        BiaPyDataset(dataset_info=dataset_info, sample_list=sample_list),
        BiaPyDataset(dataset_info=gt_dataset_info, sample_list=gt_sample_list),
    )
def samples_from_class_list(
    data_path: str,
    norm_module: Dict,
    crop_shape: Optional[Tuple[int, ...]] = None,
    expected_classes: int = -1,
    is_3d: bool = True,
    reflect_to_complete_shape: bool = True,
    convert_to_rgb: bool = False,
) -> BiaPyDataset:
    """
    Create dataset samples from the given path taking into account that each subfolder represents a class.

    This function does not load the data: it records, for every image, its path, shape, class and
    normalization information, and creates one (uncropped) sample per image.

    Parameters
    ----------
    data_path : str
        Directory of the images to read.

    norm_module : Dict
        Information about the normalization.

    crop_shape : 3D/4D int tuple, optional
        Shape of the crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D.

    expected_classes : int, optional
        Expected number of classes (subfolders of ``data_path``) to be loaded. Set to ``-1`` if you don't
        expect any: then ``data_path`` itself is treated as one folder of images and every sample gets
        ``class_num=-1``. When it is ``0`` the number of found classes is not checked.

    is_3d: bool, optional
        Whether the data to load is expected to be 3D or not.

    reflect_to_complete_shape : bool, optional
        Whether to increase the shape of the dimensions that have less size than the selected patch size,
        padding them with 'reflect'. Only applied when ``crop_shape`` is given.

    convert_to_rgb : bool, optional
        In case RGB images are expected, e.g. if ``crop_shape`` channel is 3, those images that are grayscale are
        converted into RGB.

    Returns
    -------
    dataset : BiaPyDataset
        Dataset with the samples generated out of ``data_path``.

    Raises
    ------
    ValueError
        If no class folder or no image is found, if the number of classes does not match
        ``expected_classes``, or if the images (of any class) are inconsistent in number of channels or data
        range, or do not match the channel of ``crop_shape``.
    """
    if expected_classes != -1:
        # Each subfolder of data_path is a class
        list_of_classes = next(os_walk_clean(data_path))[1]
        if len(list_of_classes) < 1:
            raise ValueError("There is no folder/class in {}".format(data_path))
        if expected_classes:
            if expected_classes != len(list_of_classes):
                raise ValueError(
                    "Found {} number of classes (folders: {}) but 'DATA.N_CLASSES' was set to {}. They must match. "
                    "Aborting...".format(len(list_of_classes), list_of_classes, expected_classes)
                )
            print("Found {} classes".format(len(list_of_classes)))
        gt_loaded = True
    else:
        # No classes expected: data_path itself is the only folder to read
        list_of_classes = [os.path.basename(data_path)]
        data_path = os.path.dirname(data_path)
        gt_loaded = False

    xsample_list = []
    xdataset_info = []
    # Reference values shared by all classes: every image, whatever its class, must agree with the first one
    channel_expected = -1
    data_range_expected = -1
    for c_num, class_name in enumerate(list_of_classes):
        class_folder = os.path.join(data_path, class_name)
        ids = next(os_walk_clean(class_folder))[2]
        if len(ids) == 0:
            raise ValueError("There are no images in class {}".format(class_folder))

        for id_ in ids:
            # Read image
            img_path = os.path.join(class_folder, id_)
            img, _ = load_img_data(img_path, is_3d=is_3d)

            if reflect_to_complete_shape and crop_shape:
                img = pad_and_reflect(img, crop_shape, verbose=False)

            if crop_shape and crop_shape[-1] == 3 and convert_to_rgb and img.shape[-1] != 3:
                img = np.repeat(img, 3, axis=-1)

            # Channel check within dataset images
            if channel_expected == -1:
                channel_expected = img.shape[-1]
            if img.shape[-1] != channel_expected:
                raise ValueError(
                    f"All images need to have the same number of channels and represent same information to "
                    "ensure the deep learning model can be trained correctly. However, the current image (with "
                    f"{img.shape[-1]} channels) appears to have a different number of channels than the first image"
                    f" (with {channel_expected} channels) in the dataset. Current image: {img_path}"
                )

            # Channel check compared with crop_shape
            if crop_shape and crop_shape[-1] != img.shape[-1]:
                raise ValueError(
                    "Channel of the patch size given {} does not correspond with the loaded image {}. "
                    "Please, check the channels of the images!".format(crop_shape[-1], img.shape[-1])
                )

            # Data range check
            drange = data_range(img)
            if data_range_expected == -1:
                data_range_expected = drange
            if data_range_expected != drange:
                raise ValueError(
                    f"All images must be within the same data range. However, the current image (with a "
                    f"range of {drange}) appears to be in a different data range than the first image (with a range "
                    f"of {data_range_expected}) in the dataset. Current image: {img_path}"
                )

            img, norm_info = normalize_image(img, norm_module=norm_module, apply_norm=False)
            xdataset_info.append(
                DatasetFile(
                    path=img_path,
                    shape=img.shape,
                    class_name=class_name,
                    class_num=c_num if gt_loaded else -1,
                    norm_info=norm_info,
                )
            )
            # One sample per image, covering it entirely
            xsample_list.append(DataSample(fid=len(xdataset_info) - 1, coords=None))

    return BiaPyDataset(dataset_info=xdataset_info, sample_list=xsample_list)
def filter_samples_by_properties(
    x_dataset: BiaPyDataset,
    is_3d: bool,
    filter_props: List[List[str]],
    filter_vals: List[List[int | float]],
    filter_signs: List[List[str]],
    crop_shape: Tuple[int, ...],
    reflect_to_complete_shape: bool = False,
    filter_by_entire_image: bool = True,
    norm_before_filter: bool = False,
    norm_module: Optional[Dict] = None,
    y_dataset: Optional[BiaPyDataset] = None,
    zarr_data_information: Optional[Dict] = None,
    save_filtered_images: bool = True,
    save_filtered_images_dir: Optional[str] = None,
    save_filtered_images_num: int = 3,
):
    """
    Filter samples from ``x_dataset`` (and ``y_dataset``) using defined conditions.

    When ``zarr_data_information`` is not provided and ``filter_by_entire_image`` is ``True``, each image is
    loaded and evaluated as a whole. Otherwise (Zarr/H5 data, or ``filter_by_entire_image=False``) the filtering
    is performed sample by sample, i.e. patch by patch. The datasets are filtered by calling their
    ``clean_dataset`` method with the indices of the kept images/samples.

    Parameters
    ----------
    x_dataset : BiaPyDataset
        X dataset to filter samples from.

    is_3d: bool
        Whether the data to load is expected to be 3D or not.

    filter_props : list of lists of str
        Filter conditions to be applied. The three variables, ``filter_props``, ``filter_vals`` and
        ``filter_signs`` will compose a list of conditions to remove the images from the list. They are list of
        list of conditions. For instance, the conditions can be like this: ``[['A'], ['B','C']]``. Then, if the
        sample satisfies the first list of conditions, only 'A' in this first case (from ['A'] list), or
        satisfies 'B' and 'C' (from ['B','C'] list) it will be removed. In each sublist all the conditions must
        be satisfied. Available properties are: [``'foreground'``, ``'mean'``, ``'min'``, ``'max'``, ``'diff'``,
        ``'diff_by_min_max_ratio'``, ``'target_mean'``, ``'target_min'``, ``'target_max'``,
        ``'diff_by_target_min_max_ratio'``]. Each property description:

        * ``'foreground'`` is defined as the mask foreground percentage.
        * ``'mean'`` is defined as the mean value.
        * ``'min'`` is defined as the min value.
        * ``'max'`` is defined as the max value.
        * ``'diff'`` is defined as the difference between ground truth and raw images.
        * ``'diff_by_min_max_ratio'`` is defined as the difference between ground truth and raw images
          multiplied by the ratio between raw image max and min.
        * ``'target_mean'`` is defined as the mean intensity value of the raw image targets.
        * ``'target_min'`` is defined as the min intensity value of the raw image targets.
        * ``'target_max'`` is defined as the max intensity value of the raw image targets.
        * ``'diff_by_target_min_max_ratio'`` is defined as the difference between ground truth and raw images
          multiplied by the ratio between ground truth image max and min.

        All properties except ``'mean'``, ``'min'`` and ``'max'`` require ``y_dataset`` to be provided.

    filter_vals : list of lists of int/float
        Values each property of ``filter_props`` is compared against.

    filter_signs : list of lists of str
        Signs to do the comparison. Options: [``'gt'``, ``'ge'``, ``'lt'``, ``'le'``] that corresponds to
        "greater than", e.g. ">", "greater equal", e.g. ">=", "less than", e.g. "<", and "less equal" e.g. "<="
        comparisons.

    crop_shape : 3D/4D int tuple
        Shape of the crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D. Only used to pad
        the data in the sample-by-sample mode when ``reflect_to_complete_shape`` is enabled.

    reflect_to_complete_shape : bool, optional
        Whether to increase the shape of the dimensions that have less size than the selected patch size,
        padding them with 'reflect' (sample-by-sample mode only).

    filter_by_entire_image : bool, optional
        This decides how the filtering is done when ``zarr_data_information`` is not provided:

        * ``True``: apply filter image by image.
        * ``False``: apply filtering sample by sample. Each sample represents a patch within an image.

    norm_before_filter : bool, optional
        Whether to apply normalization before filtering. Be aware then that the values for filtering may change.

    norm_module : Dict, optional
        Information about the normalization. Required when ``norm_before_filter`` is ``True``. It is used
        directly in the image-by-image mode; in the sample-by-sample mode each file's own ``norm_info`` is used.

    y_dataset : BiaPyDataset, optional
        Y dataset to filter samples from.

    zarr_data_information : dict, optional
        Additional information when using Zarr/H5 files for training. The following keys are expected:

        * ``"raw_path"``: path where the raw images reside within the zarr (used when
          ``multiple_data_within_zarr`` is ``True``).
        * ``"gt_path"``: path where the mask images reside within the zarr (used when
          ``multiple_data_within_zarr`` and ``use_gt_path`` are ``True``).
        * ``"use_gt_path"``: whether ``gt_path`` must be used to locate the masks within the zarr.
        * ``"multiple_data_within_zarr"``: Whether your input Zarr contains the raw images and labels together.
        * ``"input_img_axes"``: order of the axes of the images.
        * ``"input_mask_axes"``: order of the axes of the masks.

    save_filtered_images : bool, optional
        Whether to save example images. In the image-by-image mode up to ``save_filtered_images_num`` kept
        images are saved in a ``not-filtered`` folder and up to ``save_filtered_images_num`` discarded ones in a
        ``filtered`` folder. In the sample-by-sample mode up to ``save_filtered_images_num`` images are saved
        where only the kept patches are drawn (the rest is zero).

    save_filtered_images_dir : str, optional
        Directory to save filtered images. Required when ``save_filtered_images`` is ``True``.

    save_filtered_images_num : int, optional
        Number of filtered images to save. Only work when ``save_filtered_images`` is ``True``.

    Returns
    -------
    None
        ``x_dataset`` (and ``y_dataset``, if given) are filtered through their ``clean_dataset`` method.

    Raises
    ------
    ValueError
        If a required argument is missing (``norm_module``, ``save_filtered_images_dir``, ``y_dataset``), or if
        the filtering leaves less than two samples.
    """
    if norm_before_filter and norm_module is None:
        raise ValueError("'norm_module' can not be None when 'norm_before_filter' is active")
    if save_filtered_images and not save_filtered_images_dir:
        raise ValueError("'save_filtered_images_dir' can not be None when 'save_filtered_images' is enabled")
    save_filtered_images_count = 0
    save_not_filtered_images_count = 0

    # Filter samples by properties
    print("Applying filtering to data samples . . .")
    # Properties that are (at least partially) computed from the Y data
    y_dependent_props = {
        "foreground",
        "diff",
        "diff_by_min_max_ratio",
        "diff_by_target_min_max_ratio",
        "target_mean",
        "target_min",
        "target_max",
    }
    use_Y_data = any(prop in y_dependent_props for cond in filter_props for prop in cond)
    if use_Y_data and y_dataset is None:
        raise ValueError("Check filtering conditions as some of them require 'y_dataset' that was not provided")

    using_zarr = bool(zarr_data_information)
    if using_zarr:
        print("Assuming we are working with Zarr/H5 images so the filtering will be done patch by patch.")

    if not using_zarr and filter_by_entire_image:
        clean_by = "image"
        # Images are visited sorted by path so that they pair, by position, with the sorted masks. Each kept
        # image is recorded with its position within x_dataset.dataset_info (not within the sorted list).
        # NOTE(review): assumes clean_dataset(..., clean_by="image") indexes dataset_info; confirm.
        image_order = sorted(range(len(x_dataset.dataset_info)), key=lambda k: x_dataset.dataset_info[k].path)
        images = [x_dataset.dataset_info[k].path for k in image_order]
        masks = sorted(x.path for x in y_dataset.dataset_info) if use_Y_data and y_dataset else []
        print(f"Number of samples before filtering: {len(images)}")

        samples_to_maintain = []
        for n, image_path in tqdm(enumerate(images), total=len(images), disable=not is_main_process()):
            # Load X data
            img, _ = load_img_data(image_path, is_3d=is_3d)

            # Load Y data
            mask = None
            if use_Y_data:
                mask, _ = load_img_data(masks[n], is_3d=is_3d)

            if norm_before_filter:
                assert norm_module is not None
                img, _ = normalize_image(img, norm_module=norm_module)
                if use_Y_data:
                    assert mask is not None
                    mask, _ = normalize_mask(mask, norm_module=norm_module)
                    assert isinstance(mask, np.ndarray)
            assert isinstance(img, np.ndarray)

            satisfy_conds = sample_satisfy_conds(
                img,
                filter_props,
                filter_vals,
                filter_signs,
                mask=mask,
                img_ratio=float(img.max()) - float(img.min()),
                mask_ratio=(float(mask.max()) - float(mask.min())) if mask is not None else 0,
            )

            if not satisfy_conds:
                samples_to_maintain.append(image_order[n])
                if (
                    save_filtered_images
                    and save_filtered_images_dir
                    and save_not_filtered_images_count < save_filtered_images_num
                ):
                    save_tif(
                        np.expand_dims(img, 0),
                        os.path.join(save_filtered_images_dir, "not-filtered"),
                        [os.path.basename(image_path)],
                        verbose=False,
                    )
                    save_not_filtered_images_count += 1
            else:
                print(f"Discarding file {image_path}")
                if (
                    save_filtered_images
                    and save_filtered_images_dir
                    and save_filtered_images_count < save_filtered_images_num
                ):
                    save_tif(
                        np.expand_dims(img, 0),
                        os.path.join(save_filtered_images_dir, "filtered"),
                        [os.path.basename(image_path)],
                        verbose=False,
                    )
                    save_filtered_images_count += 1
        # Keep the kept indices ascending, as they were before visiting the images in path order
        samples_to_maintain.sort()
    else:
        clean_by = "sample"
        print(f"Number of samples before filtering: {len(x_dataset.sample_list)}")

        def _patch_slices(patch_coords, axes_key: str) -> Any:
            """Build the slices that extract the patch ``patch_coords`` from loaded data."""
            if is_3d:
                slices = (
                    slice(None),
                    slice(patch_coords.z_start, patch_coords.z_end),
                    slice(patch_coords.y_start, patch_coords.y_end),
                    slice(patch_coords.x_start, patch_coords.x_end),
                    slice(None),
                )
            else:
                slices = (
                    slice(None),
                    slice(patch_coords.y_start, patch_coords.y_end),
                    slice(patch_coords.x_start, patch_coords.x_end),
                    slice(None),
                )
            if zarr_data_information:
                # Reorder the slices to match the axes order of the Zarr/H5 data
                return order_dimensions(
                    slices,
                    input_order="TZYXC",
                    output_order=zarr_data_information[axes_key],
                    default_value=0,
                )
            # Otherwise drop the unbounded entries, leaving only the spatial slices
            return tuple(s for s in slices if s != slice(None))

        samples_to_maintain = []
        img_path = ""
        file, mfile = None, None
        xdata, ydata, mask = None, None, None
        # (max - min) of the whole image/mask, computed once per file when the data is a NumPy array
        x_full_ratio, y_full_ratio = None, None
        # Canvas of the current image where its kept patches are pasted, to be saved as an example
        fil_example, fil_example_name = None, ""
        for n, sample in tqdm(
            enumerate(x_dataset.sample_list), total=len(x_dataset.sample_list), disable=not is_main_process()
        ):
            # Load X data only when the sample belongs to a different file than the previous one
            filepath = x_dataset.dataset_info[sample.fid].path
            if img_path != filepath:
                img_path = filepath
                if file and isinstance(file, h5py.File):
                    file.close()
                data_within_zarr_path = (
                    zarr_data_information["raw_path"]
                    if zarr_data_information and zarr_data_information["multiple_data_within_zarr"]
                    else None
                )
                xdata, file = load_img_data(img_path, is_3d=is_3d, data_within_zarr_path=data_within_zarr_path)
                if reflect_to_complete_shape and crop_shape:
                    xdata = pad_and_reflect(xdata, crop_shape, verbose=False)

                # Load Y data
                if use_Y_data:
                    assert y_dataset is not None
                    mask_path = y_dataset.dataset_info[sample.fid].path
                    if mfile and isinstance(mfile, h5py.File):
                        mfile.close()
                    data_within_zarr_path = None
                    if zarr_data_information and zarr_data_information["multiple_data_within_zarr"]:
                        data_within_zarr_path = (
                            zarr_data_information["gt_path"] if zarr_data_information["use_gt_path"] else None
                        )
                    ydata, mfile = load_img_data(mask_path, is_3d=is_3d, data_within_zarr_path=data_within_zarr_path)
                    if reflect_to_complete_shape and crop_shape:
                        ydata = pad_and_reflect(ydata, crop_shape, verbose=False)
                else:
                    ydata, mfile = None, None

                if norm_before_filter:
                    assert norm_module is not None
                    xdata, _ = normalize_image(xdata, norm_module=x_dataset.dataset_info[sample.fid].norm_info)
                    if use_Y_data:
                        assert ydata is not None and y_dataset is not None
                        # Pass the normalization info of the file, not the DatasetFile itself
                        ydata, _ = normalize_mask(ydata, norm_module=y_dataset.dataset_info[sample.fid].norm_info)

                # Whole-file ranges (hoisted out of the per-patch loop)
                x_full_ratio = (float(xdata.max()) - float(xdata.min())) if isinstance(xdata, np.ndarray) else None
                y_full_ratio = (float(ydata.max()) - float(ydata.min())) if isinstance(ydata, np.ndarray) else None

                if save_filtered_images and save_filtered_images_dir:
                    # Save the canvas of the previous image before starting a new one
                    if fil_example is not None:
                        save_tif(
                            np.expand_dims(fil_example, 0),
                            save_filtered_images_dir,
                            [os.path.basename(fil_example_name)],
                            verbose=True,
                        )
                        save_filtered_images_count += 1
                        fil_example = None
                    if save_filtered_images_count < save_filtered_images_num:
                        fil_example = np.zeros(xdata.shape, dtype=xdata.dtype)  # type: ignore
                        fil_example_name = img_path

            # Capture patches within image/mask
            assert sample.coords is not None
            xdata_ordered_slices = _patch_slices(sample.coords, "input_img_axes")
            img = xdata[xdata_ordered_slices]  # type: ignore
            if use_Y_data:
                assert y_dataset is not None and ydata is not None
                mcoords = y_dataset.sample_list[n].coords
                assert mcoords is not None
                mask = ydata[_patch_slices(mcoords, "input_mask_axes")]  # type: ignore
                assert isinstance(mask, np.ndarray)
            assert isinstance(img, np.ndarray)

            img_ratio = x_full_ratio if x_full_ratio is not None else float(img.max()) - float(img.min())
            mask_ratio = None
            if ydata is not None and mask is not None:
                mask_ratio = y_full_ratio if y_full_ratio is not None else float(mask.max()) - float(mask.min())

            satisfy_conds = sample_satisfy_conds(
                img,
                filter_props,
                filter_vals,
                filter_signs,
                mask=mask,
                img_ratio=img_ratio,
                mask_ratio=mask_ratio,
            )

            if not satisfy_conds:
                samples_to_maintain.append(n)
                if fil_example is not None:
                    fil_example[xdata_ordered_slices] = img

        # Save the canvas of the last image, if any
        if fil_example is not None and save_filtered_images_dir:
            save_tif(
                np.expand_dims(fil_example, 0),
                save_filtered_images_dir,
                [os.path.basename(fil_example_name)],
                verbose=True,
            )
        fil_example = None

        # Release the last opened files
        if isinstance(file, h5py.File):
            file.close()
        if isinstance(mfile, h5py.File):
            mfile.close()

    x_dataset.clean_dataset(samples_to_maintain, clean_by=clean_by)
    if y_dataset:
        y_dataset.clean_dataset(samples_to_maintain, clean_by=clean_by)

    number_of_samples = len(samples_to_maintain)
    if number_of_samples == 0:
        raise ValueError(
            "Filters set with 'DATA.TRAIN.FILTER_SAMPLES.*' variables led to discard all training samples. Aborting!"
        )
    elif number_of_samples == 1:
        raise ValueError(
            "Filters set with 'DATA.TRAIN.FILTER_SAMPLES.*' variables led to discard all training samples but one. Aborting!"
        )

    print(f"Number of samples after filtering: {number_of_samples}")
[docs] def sample_satisfy_conds( img: NDArray, filter_props: List[List[str]], filter_vals: List[List[float | int]], filter_signs: List[List[str]], mask: Optional[NDArray] = None, img_ratio: float = 0, mask_ratio: Optional[float] = 0, ) -> bool: """ Whether ``img`` satisfy at least one of the conditions composed by ``filter_props``, ``filter_vals``, ``filter_sings``. Parameters ---------- img : 4D/5D Numpy array Image to check if satisfy conditions. E.g. ``(z, y, x, num_classes)`` for 3D or ``(y, x, num_classes)`` for 2D. filter_props : list of lists of str Filter conditions to be applied. The three variables, ``filter_props``, ``filter_vals`` and ``filter_vals`` will compose a list of conditions to remove the images from the list. They are list of list of conditions. For instance, the conditions can be like this: ``[['A'], ['B','C']]``. Then, if the sample satisfies the first list of conditions, only 'A' in this first case (from ['A'] list), or satisfy 'B' and 'C' (from ['B','C'] list) it will be removed. In each sublist all the conditions must be satisfied. Available properties are: [``'foreground'``, ``'mean'``, ``'min'``, ``'max'``]. Each property descrition: * ``'foreground'`` is defined as the mask foreground percentage. * ``'mean'`` is defined as the mean value of the input. * ``'min'`` is defined as the min value of the input. * ``'max'`` is defined as the max value of the input. * ``'diff'`` is defined as the difference between ground truth and raw images. Require ``y_dataset`` to be provided. * ``'diff_by_min_max_ratio'`` is defined as the difference between ground truth and raw images multiplied by the ratio between raw image max and min. * ``'target_mean'`` is defined as the mean intensity value of the raw image targets. Require ``y_dataset`` to be provided. * ``'target_min'`` is defined as the min intensity value of the raw image targets. Require ``y_dataset`` to be provided. 
* ``'target_max'`` is defined as the max intensity value of the raw image targets. Require ``y_dataset`` to be provided. * ``'diff_by_target_min_max_ratio'`` is defined as the difference between ground truth and raw images multiplied by the ratio between ground truth image max and min. filter_vals : list of int/float Represent the values of the properties listed in ``filter_props`` that the images need to satisfy to not be dropped. filter_signs : list of list of str Signs to do the comparison. Options: [``'gt'``, ``'ge'``, ``'lt'``, ``'le'``] that corresponds to "greather than", e.g. ">", "greather equal", e.g. ">=", "less than", e.g. "<", and "less equal" e.g. "<=" comparisons. mask : 4D/5D Numpy array, optional Mask to check if satisfy "foreground" condition in ``filter_props``. E.g. ``(z, y, x, num_classes)`` for 3D or ``(y, x, num_classes)`` for 2D. img_ratio : float, optional Ratio of the input image. Expected to be ``(img.max - img.min)`` of the entire image. mask_ratio : float, optional Minimum value of the entire image. Expected to be ``(mask.max - mask.min)`` of the entire image. Returns ------- satisfy_conds : bool Whether if the sample satisfy one of the conditions or not. 
""" satisfy_conds = False # Check if the sample satisfies a condition for i, cond in enumerate(filter_props): comps = [] for j, c in enumerate(cond): if c == "foreground": assert mask is not None labels, npixels = np.unique((mask > 0).astype(np.uint8), return_counts=True) total_pixels = 1 for val in list(mask.shape): total_pixels *= val if labels[0] == 0: npixels = npixels[1:] value_to_compare = sum(npixels) / total_pixels elif c == "diff": assert mask is not None value_to_compare = np.sum(abs(img - mask)) elif c == "diff_by_min_max_ratio": assert mask is not None value_to_compare = np.sum(abs(img - mask)) * img_ratio elif c == "diff_by_target_min_max_ratio": assert mask is not None and mask_ratio is not None value_to_compare = np.sum(abs(img - mask)) * mask_ratio elif c == "min": value_to_compare = img.min() elif c == "max": value_to_compare = img.max() elif c == "mean": value_to_compare = img.mean() elif c == "target_min": assert mask is not None value_to_compare = mask.min() elif c == "target_max": assert mask is not None value_to_compare = mask.max() elif c == "target_mean": assert mask is not None value_to_compare = mask.mean() # Check each list of conditions if filter_signs[i][j] == "gt": if value_to_compare > filter_vals[i][j]: comps.append(True) else: comps.append(False) elif filter_signs[i][j] == "ge": if value_to_compare >= filter_vals[i][j]: comps.append(True) else: comps.append(False) elif filter_signs[i][j] == "lt": if value_to_compare < filter_vals[i][j]: comps.append(True) else: comps.append(False) elif filter_signs[i][j] == "le": if value_to_compare <= filter_vals[i][j]: comps.append(True) else: comps.append(False) # Check if the conditions where satified if all(comps): satisfy_conds = True break return satisfy_conds
def load_images_to_dataset(
    dataset: BiaPyDataset,
    crop_shape: Optional[Tuple[int, ...]],
    reflect_to_complete_shape: bool = False,
    convert_to_rgb: bool = False,
    is_mask: bool = False,
    is_3d: bool = False,
    preprocess_cfg: Optional[Dict] = None,
    preprocess_f: Optional[Callable] = None,
    zarr_data_information: Optional[Dict] = None,
):
    """
    Load images into the ``dataset`` by setting the ``img`` attribute of every sample.

    An image file is only (re)read when the path of the current sample differs from the previous
    one, so the process is faster if the samples extracted from the same image are in continuous
    positions within the list.

    Parameters
    ----------
    dataset : BiaPyDataset
        Loaded data.

    crop_shape : 3D/4D int tuple
        Shape of the expected crops. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D.

    reflect_to_complete_shape : bool, optional
        Whether to increase the shape of the dimension that have less size than selected patch size padding
        it with 'reflect'.

    convert_to_rgb : bool, optional
        In case RGB images are expected, e.g. if ``crop_shape`` channel is 3, those images that are grayscale
        are converted into RGB.

    is_mask : bool, optional
        Whether the data are masks. It is used to control the preprocessing of the data.

    is_3d: bool, optional
        Whether the data to load is expected to be 3D or not.

    preprocess_cfg : dict, optional
        Configuration parameters for preprocessing, is necessary in case you want to apply any preprocessing.

    preprocess_f : function, optional
        The preprocessing function, is necessary in case you want to apply any preprocessing.

    zarr_data_information : dict, optional
        Additional information of where to find the data within the Zarr files.

    Raises
    ------
    ValueError
        If ``preprocess_f`` is given without ``preprocess_cfg``, if the images do not share the same number
        of channels, if the channels do not match ``crop_shape``, or if the (non-mask) images are not in the
        same data range.
    """
    if preprocess_f and preprocess_cfg is None:
        raise ValueError("The preprocessing configuration ('preprocess_cfg') is missing.")

    channel_expected = -1
    data_range_expected = -1
    img_path = ""
    file = None
    for sample in tqdm(dataset.sample_list, total=len(dataset.sample_list), disable=not is_main_process()):
        # Read image if it is different from the last sample's
        filepath = dataset.dataset_info[sample.fid].path
        if img_path != filepath:
            img_path = filepath
            # Release the previous HDF5 handle before opening the next file
            if file and isinstance(file, h5py.File):
                file.close()

            data_within_zarr_path = None
            if zarr_data_information and zarr_data_information["multiple_data_within_zarr"]:
                if not is_mask:
                    data_within_zarr_path = zarr_data_information["raw_path"]
                else:
                    data_within_zarr_path = (
                        zarr_data_information["gt_path"] if zarr_data_information["use_gt_path"] else None
                    )

            data, file = load_img_data(img_path, is_3d=is_3d, data_within_zarr_path=data_within_zarr_path)

            # Disable channel checking if it is not present. Can happen with a Zarr/H5 dataset, as in the load_img_data()
            # the axis were not checked to not have the data loaded in memory
            check_channel = True
            key = "input_img_axes" if not is_mask else "input_mask_axes"
            if zarr_data_information and "C" not in zarr_data_information[key]:
                check_channel = False

            # Channel check within dataset images
            if channel_expected == -1:
                channel_expected = data.shape[-1] if not convert_to_rgb else 3
            if check_channel and not convert_to_rgb and data.shape[-1] != channel_expected:
                raise ValueError(
                    "All images need to have the same number of channels and represent same information to "
                    "ensure the deep learning model can be trained correctly. However, the current image (with "
                    f"{data.shape[-1]} channels) appears to have a different number of channels than the first image"
                    f" (with {channel_expected} channels) in the folder. Current image: {img_path}"
                )

            # Channel check compared with crop_shape
            if check_channel and crop_shape and not is_mask:
                channel_to_compare = data.shape[-1] if not convert_to_rgb else 3
                if crop_shape[-1] != channel_to_compare:
                    if not convert_to_rgb:
                        raise ValueError(
                            "Channel of the patch size given {} does not correspond with the loaded image {}. "
                            "Please, check the channels of the images!".format(crop_shape[-1], channel_to_compare)
                        )
                    else:
                        raise ValueError(
                            "Channel of the patch size given {} does not correspond with the loaded image {} "
                            "(remember that 'DATA.FORCE_RGB' was selected). Please, check the channels of the "
                            "images!".format(crop_shape[-1], channel_to_compare)
                        )

            # Data range check (only for in-memory, non-mask data)
            if not is_mask and isinstance(data, np.ndarray):
                drange = data_range(data)
                if data_range_expected == -1:
                    data_range_expected = drange
                elif data_range_expected != drange:
                    raise ValueError(
                        f"All images must be within the same data range. However, the current image (with a "
                        f"range of {drange}) appears to be in a different data range than the first image (with a range "
                        f"of {data_range_expected}) in the folder. Current image: {img_path}"
                    )

            # Apply preprocessing once per loaded image, not once per sample
            if preprocess_f:
                if is_mask:
                    data = preprocess_f(preprocess_cfg, y_data=[data], is_2d=not is_3d, is_y_mask=is_mask)[0]
                else:
                    data = preprocess_f(preprocess_cfg, x_data=[data], is_2d=not is_3d)[0]

        # Prepare slices to extract the patch
        if sample.coords:
            coords = sample.coords
            if is_3d:
                xslices = (
                    slice(None),
                    slice(coords.z_start, coords.z_end),
                    slice(coords.y_start, coords.y_end),
                    slice(coords.x_start, coords.x_end),
                    slice(None),
                )
            else:
                xslices = (
                    slice(None),
                    slice(coords.y_start, coords.y_end),
                    slice(coords.x_start, coords.x_end),
                    slice(None),
                )

            if zarr_data_information:
                # Reorder the TZYXC slices to the axis order stored in the Zarr/H5 file
                data_ordered_slices = order_dimensions(
                    xslices,
                    input_order="TZYXC",
                    output_order=zarr_data_information[key],
                    default_value=0,
                )
            else:
                # In-memory data has no leading T axis
                data_ordered_slices = xslices[1:]

            # Extract the patch within the image
            img = data[data_ordered_slices]  # type: ignore

            if zarr_data_information:
                img = ensure_3d_shape(img.squeeze(), path=filepath)
        else:
            img = data

        if crop_shape and reflect_to_complete_shape:
            img = pad_and_reflect(img, crop_shape, verbose=False)

        if crop_shape and crop_shape[-1] == 3 and convert_to_rgb and not is_mask and img.shape[-1] != 3:
            img = np.repeat(img, 3, axis=-1)

        # Insert the image
        sample.img = img

    # Guard against an empty sample list instead of failing with an IndexError
    sshape = dataset.sample_list[0].get_shape() if dataset.sample_list else None
    if sshape:
        data_shape = (len(dataset.sample_list),) + sshape
        print("*** Loaded data shape is {}".format(data_shape))
    else:
        print(
            "Samples of shape {} will be randomly extracted. Number of samples: {}".format(
                crop_shape, len(dataset.sample_list)
            )
        )
def pad_and_reflect(img: NDArray, crop_shape: Tuple[int, ...], verbose: bool = False) -> NDArray:
    """
    Pad the spatial axes of ``img`` with 'reflect' so that they are at least as large as ``crop_shape``.

    Only the axes smaller than the corresponding ``crop_shape`` value are padded, and the padding is
    added at the beginning of the axis. The last (channel) axis is never padded.

    Parameters
    ----------
    img : 3D/4D Numpy array
        Image to pad. E.g. ``(y, x, channels)`` or ``(z, y, x, channels)``.

    crop_shape : Tuple of 3/4 int
        Shape of the subvolumes to create when cropping. E.g. ``(y, x, channels)`` or ``(z, y, x, channels)``.

    verbose : bool, optional
        Whether to print the shape before and after each padding step.

    Returns
    -------
    img : 3D/4D Numpy array
        Image padded. E.g. ``(y, x, channels)`` for 2D and ``(z, y, x, channels)`` for 3D. If no axis
        needs padding the input object is returned unchanged.

    Raises
    ------
    ValueError
        If ``img`` is 4D and ``crop_shape`` has fewer than 4 values, or ``img`` is 3D and ``crop_shape``
        has fewer than 3 values.
    """
    if img.ndim == 4 and len(crop_shape) < 4:
        raise ValueError(
            f"'crop_shape' needs to have 4 at least values as the input array has 4 dims. Provided crop_shape: {crop_shape}"
        )
    if img.ndim == 3 and len(crop_shape) < 3:
        raise ValueError(
            f"'crop_shape' needs to have 3 at least values as the input array has 3 dims. Provided crop_shape: {crop_shape}"
        )

    # 4D inputs are (z, y, x, c) -> pad 3 leading axes; otherwise (y, x[, c]) -> pad 2 leading axes.
    n_spatial = 3 if img.ndim == 4 else 2
    for axis in range(n_spatial):
        diff = crop_shape[axis] - img.shape[axis]
        if diff <= 0:
            continue
        o_shape = img.shape
        # Pad only this axis, and only at its start
        pad_width = [(0, 0)] * img.ndim
        pad_width[axis] = (diff, 0)
        img = np.pad(img, pad_width, "reflect")
        if verbose:
            print("Reflected from {} to {}".format(o_shape, img.shape))
    return img
def extract_patch_within_image(img: NDArray, coords: "PatchCoords", is_3d: bool = False) -> NDArray:
    """
    Extract patch within the image.

    Parameters
    ----------
    img : 3D/4D Numpy array
        Input image to extract the patch from. E.g. ``(y, x, channels)`` in ``2D`` and
        ``(z, y, x, channels)`` in ``3D``.

    coords : PatchCoords
        Coordinates of the crop, read through attribute access:

        * ``z_start`` / ``z_end``: start/end of the patch in Z axis (only used when ``is_3d``).
        * ``y_start`` / ``y_end``: start/end of the patch in Y axis.
        * ``x_start`` / ``x_end``: start/end of the patch in X axis.

    is_3d : bool, optional
        Whether if the expected image to read is 3D or not.

    Returns
    -------
    img : 3D/4D Numpy array
        Extracted patch (a view of ``img``); any trailing axes such as channels are kept whole.
        E.g. ``(y, x, channels)`` in ``2D`` and ``(z, y, x, channels)`` in ``3D``.
    """
    if is_3d:
        return img[
            coords.z_start : coords.z_end,
            coords.y_start : coords.y_end,
            coords.x_start : coords.x_end,
        ]
    return img[
        coords.y_start : coords.y_end,
        coords.x_start : coords.x_end,
    ]
def img_to_onehot_encoding(img: NDArray, num_classes: int = 2) -> NDArray:
    """
    Convert a single-channel label image into one-hot encoding.

    The opposite function is :func:`~onehot_encoding_to_img`.

    Parameters
    ----------
    img : Numpy 3D/4D array
        Image. E.g. ``(y, x, channels)`` or ``(z, y, x, channels)``. Its number of elements must equal
        the number of spatial positions (i.e. a single channel), otherwise a reshape error is raised.

    num_classes : int, optional
        Number of classes to distinguish. Class ``i`` is encoded in output channel ``i``.

    Returns
    -------
    one_hot_labels : Numpy 3D/4D array
        ``int8`` data one-hot encoded. E.g. ``(y, x, num_classes)`` or ``(z, y, x, num_classes)``.
        Positions whose value is outside ``[0, num_classes)`` are all zeros.
    """
    spatial = img.shape[:3] if img.ndim == 4 else img.shape[:2]
    out_shape = spatial + (num_classes,)
    encoded_image = np.zeros(out_shape, dtype=np.int8)
    # Compare every value against every class id at once: (positions, 1) vs (num_classes,)
    matches = img.reshape((-1, 1)) == np.arange(num_classes)
    encoded_image[...] = matches.reshape(out_shape)
    return encoded_image
def onehot_encoding_to_img(encoded_image: NDArray) -> NDArray:
    """
    Convert a one-hot encoded image into an image with just one channel where every class is an integer.

    The opposite function is :func:`~img_to_onehot_encoding`.

    Parameters
    ----------
    encoded_image : Numpy 3D/4D array
        One-hot encoded image. E.g. ``(y, x, num_classes)`` or ``(z, y, x, num_classes)``.

    Returns
    -------
    img : Numpy 3D/4D array
        ``int8`` label image with a single channel. E.g. ``(y, x, 1)`` or ``(z, y, x, 1)``. Positions
        with no channel set to 1 get class 0; if several channels are set, the highest class index wins.
    """
    if encoded_image.ndim == 4:
        shape = encoded_image.shape[:3] + (1,)
    else:
        shape = encoded_image.shape[:2] + (1,)

    img = np.zeros(shape, dtype=np.int8)
    # Iterate over the class channels of the *encoded* input (the output only has one channel)
    for i in range(encoded_image.shape[-1]):
        img[encoded_image[..., i] == 1] = i

    return img
def load_img_data(
    path: str, is_3d: bool = False, data_within_zarr_path: Optional[str] = None
) -> Tuple[NDArray[Any], str]:
    """
    Load data from a given path, lazily for chunked formats and eagerly for everything else.

    Parameters
    ----------
    path : str
        Path to the image to read.

    is_3d : bool, optional
        Whether if the expected image to read is 3D or not.

    data_within_zarr_path : str, optional
        Path to find the data within the Zarr file. E.g. 'volumes.labels.neuron_ids'.

    Returns
    -------
    data : Zarr, H5 or Numpy 3D/4D array
        Data read. E.g. ``(z, y, x, channels)`` for 3D or ``(y, x, channels)`` for 2D.

    file : str or file handle
        For HDF5/Zarr/N5 inputs, the handle returned by the chunked readers (useful to close it
        in case it is an H5 file); for any other format, ``path`` itself.
    """
    is_chunked = looks_like_hdf5(path) or path.endswith((".zarr", "n5", ".n5"))
    if not is_chunked:
        # Regular image formats are read fully into memory; the path stands in for the "file"
        return read_img_as_ndarray(path, is_3d=is_3d), path  # type: ignore

    from biapy.data.data_3D_manipulation import (
        read_chunked_data,
        read_chunked_nested_data,
    )

    if data_within_zarr_path:
        file, data = read_chunked_nested_data(path, data_within_zarr_path)
    else:
        file, data = read_chunked_data(path)
    return data, file  # type: ignore
def read_img_as_ndarray(path: str, is_3d: bool = False) -> NDArray:
    """
    Read an image from a given path into a NumPy array.

    The reader is chosen from the path: ``.npy``, ``.pt`` (torch tensor), ``.nii.gz`` (NIfTI),
    HDF5 (first dataset in the file), ``.zarr``/``.n5``, and anything else through :func:`imread`.
    The result is squeezed and then normalized with ``ensure_2d_shape`` / ``ensure_3d_shape``.

    Parameters
    ----------
    path : str
        Path to the image to read.

    is_3d : bool, optional
        Whether if the expected image to read is 3D or not.

    Returns
    -------
    img : Numpy 3D/4D array
        Image read. E.g. ``(z, y, x, channels)`` for 3D or ``(y, x, channels)`` for 2D.

    Raises
    ------
    ValueError
        Wrapping any error raised while reading or reshaping the image (the original is chained).
    """
    try:
        # Read image
        axes_position = None
        if path.endswith(".npy"):
            img = np.load(path)
        elif path.endswith(".pt"):
            img = torch.load(path, weights_only=True, map_location="cpu").numpy()
        elif path.endswith(".nii.gz"):
            # nib.load returns an image object, not an array: materialize its voxel data
            # (keeping the stored dtype) so the shape handling below works on an ndarray.
            img = np.asanyarray(nib.load(path).dataobj)
        elif looks_like_hdf5(path):
            # Context manager releases the HDF5 handle once the first dataset has been copied out
            with h5py.File(path, "r") as h5file:
                img = np.array(h5file[list(h5file)[0]])
        elif path.endswith((".zarr", "n5")):
            from biapy.data.data_3D_manipulation import read_chunked_data

            _, img = read_chunked_data(path)
            img = np.array(img)
        else:
            img, axes_position = imread(path)

        img = np.squeeze(img)

        if not is_3d:
            img = ensure_2d_shape(img, path)
        else:
            img = ensure_3d_shape(img, path, data_axes_order=axes_position)
    except Exception as e:
        raise ValueError(f"Error reading image from path {path}. Error message: {e}") from e

    return img
def imread(path: str) -> NDArray | Tuple[NDArray, Optional[str]]:
    """
    Read an image from a given path.

    In the past ``from skimage.io import imread`` was used but now it is deprecated.

    Parameters
    ----------
    path : str
        Path to the image to read.

    Returns
    -------
    img : Numpy array
        Image read.

    axes : str or None
        Axes string of the first TIFF series as reported by ``tifffile`` (e.g. ``"ZYX"``), or ``None``
        for non-TIFF files and for TIFFs whose series could not be parsed.
    """
    if path.lower().endswith((".tiff", ".tif")):
        try:
            with tifffile.TiffFile(path) as tif:
                series = tif.series[0]
                return series.asarray(), series.axes
        except Exception:
            # Fall back to the plain reader when the series/axes metadata cannot be used.
            # (A bare ``except`` would also swallow KeyboardInterrupt/SystemExit.)
            return tifffile.imread(path), None
    return imageio.imread(path), None
def imwrite(path: str, image: NDArray):
    """
    Write ``image`` to the given ``path``.

    In the past ``from skimage.io import imsave`` was used but now it is deprecated.

    TIFF files (``.tif``/``.tiff``) are written with ImageJ metadata and axes ``TZCYXS``, trying zlib
    compression first and falling back to an uncompressed write. Other extensions are delegated to
    ``imageio.imwrite``.

    Parameters
    ----------
    path : str
        Path where the image will be written.

    image : Numpy array
        Image to store. Must be 6D (``TZCYXS``) when writing a TIFF.
    """
    image = np.array(image)
    if path.lower().endswith((".tiff", ".tif")):
        assert image.ndim == 6, f"Image to write needs to have 6 dimensions (axes: TZCYXS). Image shape: {image.shape}"
        try:
            tifffile.imwrite(
                path,
                image,
                imagej=True,
                metadata={"axes": "TZCYXS"},
                compression="zlib",
                compressionargs={"level": 8},
            )
        except Exception:
            # Presumably for tifffile versions/dtypes that reject these compression options:
            # retry without compression instead of failing.
            tifffile.imwrite(path, image, imagej=True, metadata={"axes": "TZCYXS"})
    else:
        imageio.imwrite(path, image)
[docs] def check_value( value: int | float | Tuple[int | float] | List[int | float] | NDArray, value_range: Tuple[int | float, int | float] = (0, 1), ) -> bool: """ Check whether a value or a collection of values falls within a specified range. This function supports individual values (int, float), lists or tuples of values, and NumPy arrays. If `value` is a list or tuple, all elements must fall within the specified `value_range`. For NumPy arrays, both the minimum and maximum values of the array must be within the range. Parameters ---------- value : int, float, list, tuple or np.ndarray The value or collection of values to check. value_range : tuple of (int or float), optional A (min, max) tuple specifying the inclusive range of valid values. Default is (0, 1). Returns ------- bool True if all values are within the specified range; False otherwise. """ if isinstance(value, list) or isinstance(value, tuple): for i in range(len(value)): if isinstance(value[i], np.ndarray): if value_range[0] <= np.min(value[i]) or np.max(value[i]) <= value_range[1]: return False else: if not (value_range[0] <= value[i] <= value_range[1]): return False return True else: if isinstance(value, np.ndarray): if value_range[0] <= np.min(value) and np.max(value) <= value_range[1]: return True else: if value_range[0] <= value <= value_range[1]: return True return False
def data_range(x: NDArray) -> str:
    """
    Classify the value range of a NumPy array into a common image intensity range.

    The candidate ranges are tried from the narrowest to the widest: [0, 1] (normalized float),
    [0, 255] (8-bit) and [0, 65535] (16-bit unsigned), and the first one containing all values wins.

    Parameters
    ----------
    x : np.ndarray
        The input array whose range is to be determined.

    Returns
    -------
    str
        ``"01 range"``, ``"uint8 range"``, ``"uint16 range"``, or ``"none_range"`` when the values fall
        outside all of these ranges.

    Raises
    ------
    ValueError
        If the input is not a NumPy array.
    """
    if not isinstance(x, np.ndarray):
        raise ValueError("Input array of type {} and not numpy array".format(type(x)))

    candidate_ranges = (
        ((0, 1), "01 range"),
        ((0, 255), "uint8 range"),
        ((0, 65535), "uint16 range"),
    )
    for bounds, label in candidate_ranges:
        if check_value(x, bounds):
            return label
    return "none_range"
def check_masks(path: str, n_classes: int = 2, is_3d: bool = False):
    """
    Check that the masks stored in ``path`` do not contain more classes than ``n_classes``.

    Every file listed in ``path`` is read and its unique values are inspected. Per-image violations
    are printed as they are found; if any violation exists (per image or across all images) a single
    ``ValueError`` is raised at the end. If the function gives no error one should assume that the
    masks are correct.

    Parameters
    ----------
    path : str
        Path to the data mask.

    n_classes : int, optional
        Maximum classes that the masks must contain.

    is_3d : bool, optional
        Whether if the expected image to read is 3D or not.

    Raises
    ------
    ValueError
        If a Zarr/HDF5/N5 file is found (not supported here), or if any mask, or the union of all
        masks, contains more than ``n_classes`` distinct values.
    """
    print("Checking ground truth classes in {} . . .".format(path))

    ids = next(os_walk_clean(path))[2]
    classes_found = set()
    m = ""
    error = False
    for mask_id in tqdm(ids):
        if looks_like_hdf5(mask_id) or mask_id.endswith((".zarr", "n5", ".n5")):
            raise ValueError(
                "Mask checking with Zarr not implemented in BiaPy yet. Disable 'DATA.*.CHECK_DATA' variables to continue"
            )

        mask_path = os.path.join(path, mask_id)
        img = read_img_as_ndarray(mask_path, is_3d=is_3d)
        values = np.unique(img)
        if len(values) > n_classes:
            print(
                "Error: given mask ({}) has more classes than specified in 'DATA.N_CLASSES'. "
                "Values found: {}".format(mask_path, values)
            )
            error = True
        classes_found.update(values.tolist())

    if len(classes_found) > n_classes:
        # Sorted so the reported classes are deterministic and easy to read
        formated_classes = sorted(int(c) for c in classes_found)
        m += (
            "Number of classes found across images is greater than the value specified in 'DATA.N_CLASSES'. "
            f"Classes found: {formated_classes}\n"
        )
        error = True

    if error:
        m += (
            "'DATA.N_CLASSES' variable value must be set taking into account the background class. E.g. if mask has [0,1,2] "
            "values 'DATA.N_CLASSES' should be 3.\nCorrect the errors in the masks above to continue"
        )
        raise ValueError(m)
def shape_mismatch_message(X_data: BiaPyDataset, Y_data: BiaPyDataset) -> str:
    """
    Describe every raw/label sample pair whose spatial shapes differ.

    Samples are paired positionally; pairs where either shape is unknown (falsy) are ignored, and the
    last (channel) axis is excluded from the comparison.

    Parameters
    ----------
    X_data : BiaPyDataset
        X data.

    Y_data : BiaPyDataset
        Y data.

    Returns
    -------
    mistmatch_message : str
        Human-readable report of the mismatching pairs, or an empty string when all pairs match.
    """
    entries = []
    for xsample, ysample in zip(X_data.sample_list, Y_data.sample_list):
        xshape = xsample.get_shape()
        yshape = ysample.get_shape()
        if not (xshape and yshape) or xshape[:-1] == yshape[:-1]:
            continue
        entries.append(
            "\n"
            + "Raw file: '{}'\n".format(X_data.dataset_info[xsample.fid].path)
            + "Corresponding label file: '{}'\n".format(Y_data.dataset_info[ysample.fid].path)
            + "Raw shape: {}\n".format(xsample.get_shape())
            + "Label shape: {}\n".format(ysample.get_shape())
            + "--\n"
        )

    if not entries:
        return ""
    return "Here is a list of the pair raw and label that does not match in shape:\n" + "".join(entries)
def save_tif(X: NDArray, data_dir: str, filenames: Optional[List[str]] = None, verbose: bool = True):
    """
    Save every sample of ``X`` as a TIFF file (axes ``TZCYXS``) inside ``data_dir``.

    Samples whose dtype is not ``np.uint8``, ``np.uint16`` or ``np.float32`` are cast to ``np.float32``,
    because otherwise the axes are not correctly set when opening the resulting images in Fiji/ImageJ.

    Parameters
    ----------
    X : 4D/5D numpy array or list of arrays
        Data to save as images. The first dimension must be the number of images. E.g.
        ``(num_of_images, y, x, channels)`` or ``(num_of_images, z, y, x, channels)``. When a list is
        given, element ``i`` is expected to hold the sample at index 0 (``X[i][0]`` is written).

    data_dir : str
        Path to store X images.

    filenames : List, optional
        Filenames that should be used when saving each image (their extension is replaced by ``.tif``).
        When not given, zero-padded indices are used.

    verbose : bool, optional
        To print saving information.
    """
    as_list = isinstance(X, list)
    if verbose:
        shape_to_report = X[0].shape if as_list else X.shape
        print("Saving {} data as .tif in folder: {}".format(shape_to_report, data_dir))

    os.makedirs(data_dir, exist_ok=True)
    if filenames and len(filenames) != len(X):
        raise ValueError(
            "Filenames array and length of X have different shapes: {} vs {}".format(len(filenames), len(X))
        )

    reference = X[0] if as_list else X
    out_dtype = reference.dtype if reference.dtype in [np.uint8, np.uint16, np.float32] else np.float32
    ndims = reference.ndim
    pad_width = len(str(len(X)))

    for idx in tqdm(range(len(X)), leave=False, disable=not is_main_process()):
        if filenames is None:
            out_name = str(idx).zfill(pad_width) + ".tif"
        else:
            out_name = os.path.splitext(filenames[idx])[0] + ".tif"
        out_path = os.path.join(data_dir, out_name)

        sample = X[idx][0] if as_list else X[idx]
        if ndims == 4:
            # 2D sample (y, x, c): add a Z axis so both cases share the (z, y, x, c) layout
            sample = np.expand_dims(sample, 0)
        # (z, y, x, c) -> (z, c, y, x) plus a trailing S axis, then a leading T axis -> TZCYXS
        zcyxs = np.expand_dims(sample.transpose((0, 3, 1, 2)), -1).astype(out_dtype)
        imwrite(out_path, np.expand_dims(zcyxs, 0))
def save_tif_pair_discard(
    X: NDArray,
    Y: NDArray,
    data_dir: str,
    suffix: str = "",
    filenames: Optional[List] = None,
    discard: bool = True,
    verbose: bool = True,
):
    """
    Save image/mask pairs as TIFF files (axes ``TZCYXS``) in ``data_dir/x{suffix}`` and ``data_dir/y{suffix}``.

    Each array is written with its own dtype when it is ``np.uint8``, ``np.uint16`` or ``np.float32``,
    and cast to ``np.float32`` otherwise. Masks are not cast to the image dtype, so label values are
    never truncated (e.g. ``uint16`` labels next to a ``uint8`` image).

    Parameters
    ----------
    X : 4D/5D numpy array
        Data to save as images. The first dimension must be the number of images. E.g.
        ``(num_of_images, y, x, channels)`` or ``(num_of_images, z, y, x, channels)``.

    Y : 4D/5D numpy array
        Data mask to save. The first dimension must be the number of images. E.g.
        ``(num_of_images, y, x, channels)`` or ``(num_of_images, z, y, x, channels)``.

    data_dir : str
        Path to store X images.

    suffix : str, optional
        Suffix to apply on output directory.

    filenames : List, optional
        Filenames that should be used when saving each image.

    discard : bool, optional
        Whether to discard image/mask pairs if the mask has no label information (fewer than two
        distinct values).

    verbose : bool, optional
        To print saving information.
    """
    if verbose:
        s = X.shape if not isinstance(X, list) else X[0].shape
        print("Saving {} data as .tif in folder: {}".format(s, data_dir))

    x_dir = os.path.join(data_dir, "x" + suffix)
    y_dir = os.path.join(data_dir, "y" + suffix)
    os.makedirs(x_dir, exist_ok=True)
    os.makedirs(y_dir, exist_ok=True)
    if filenames and len(filenames) != len(X):
        raise ValueError(
            "Filenames array and length of X have different shapes: {} vs {}".format(len(filenames), len(X))
        )

    def _tif_dtype(arr):
        # dtypes Fiji/ImageJ handles correctly are kept; anything else becomes float32
        return arr.dtype if arr.dtype in [np.uint8, np.uint16, np.float32] else np.float32

    def _to_tzcyxs(sample, dtype):
        # (y, x, c) or (z, y, x, c) -> (1, z, c, y, x, 1)
        if sample.ndim == 3:
            sample = np.expand_dims(sample, 0)
        return np.expand_dims(np.expand_dims(sample.transpose((0, 3, 1, 2)), -1).astype(dtype), 0)

    x_dtype = _tif_dtype(X)
    y_dtype = _tif_dtype(Y)
    d = len(str(len(X)))
    for i in tqdm(range(X.shape[0]), leave=False, disable=not is_main_process()):
        # Skip pairs whose mask carries no label information (only computed when discarding)
        if discard and len(np.unique(Y[i])) < 2:
            continue
        if filenames is None:
            fname = str(i).zfill(d) + ".tif"
        else:
            fname = os.path.splitext(filenames[i])[0] + ".tif"

        imwrite(os.path.join(x_dir, fname), _to_tzcyxs(X[i], x_dtype))
        imwrite(os.path.join(y_dir, fname), _to_tzcyxs(Y[i], y_dtype))
def save_npy_files(X: NDArray, data_dir: str, filenames: Optional[List[str]] = None, verbose: bool = True):
    """
    Save every sample of ``X`` as a ``.npy`` file inside ``data_dir``.

    Parameters
    ----------
    X : 4D/5D numpy array or list of arrays
        Data to save as images. The first dimension must be the number of images. E.g.
        ``(num_of_images, y, x, channels)`` or ``(num_of_images, z, y, x, channels)``. When a list is
        given, ``X[i][0]`` is the array written for sample ``i``.

    data_dir : str
        Path to store X images.

    filenames : List, optional
        Filenames that should be used when saving each image (their extension is replaced by ``.npy``).
        When not given, zero-padded indices are used.

    verbose : bool, optional
        To print saving information.
    """
    is_list = isinstance(X, list)
    if verbose:
        print("Saving {} data as .npy in folder: {}".format(X[0].shape if is_list else X.shape, data_dir))

    os.makedirs(data_dir, exist_ok=True)
    if filenames and len(filenames) != len(X):
        raise ValueError(
            "Filenames array and length of X have different shapes: {} vs {}".format(len(filenames), len(X))
        )

    width = len(str(len(X)))
    for idx in tqdm(range(len(X)), leave=False, disable=not is_main_process()):
        stem = str(idx).zfill(width) if filenames is None else os.path.splitext(filenames[idx])[0]
        np.save(os.path.join(data_dir, stem + ".npy"), X[idx][0] if is_list else X[idx])
def reduce_dtype(
    x: NDArray,
    x_min: float,
    x_max: float,
    out_min: float = 0,
    out_max: float = 1,
    out_type: str = "float32",
    eps: float = 1e-6,
) -> NDArray:
    """
    Rescale ``x`` from ``[x_min, x_max]`` to ``[out_min, out_max]`` and cast it to ``out_type``.

    It uses the following formula:

    ``results = ((x - x_min) / (x_max - x_min + eps)) * (out_max - out_min) + out_min``

    Parameters
    ----------
    x : 3D/4D Numpy array or torch.Tensor
        Image to reduce its data type. E.g. ``(y, x, channels)`` in ``2D`` and ``(z, y, x, channels)`` in ``3D``.
        Anything that is not a ``np.ndarray`` is handled as a torch tensor.

    x_min: float
        ``x_min`` in the formula above.

    x_max: float
        ``x_max`` in the formula above.

    out_min: float, optional
        ``out_min`` in the formula above.

    out_max: float, optional
        ``out_max`` in the formula above.

    out_type : str, optional
        Type of the output data; key of ``torch_numpy_dtype_dict`` from ``biapy.data.norm``.

    eps : float, optional
        Epsilon to use in order to avoid zero division.

    Returns
    -------
    x : 3D/4D Numpy array or torch.Tensor
        Data type reduced image, same container type as the input.
    """
    from biapy.data.norm import torch_numpy_dtype_dict

    if isinstance(x, np.ndarray):
        # ``isinstance(x, np.floating)`` is never true for an array, so the check is done on the dtype.
        # float32/float64 arrays are used as-is (no full copy, float64 keeps its precision);
        # integers and other dtypes (e.g. float16) are promoted to float32 for the arithmetic.
        if x.dtype not in (np.float32, np.float64):
            x = x.astype(np.float32)
        scaled = ((x - x_min) / (x_max - x_min + eps)) * (out_max - out_min) + out_min
        return np.asarray(scaled).astype(torch_numpy_dtype_dict[out_type][1])

    # Tensor considered
    if not torch.is_floating_point(x):
        x = x.to(torch.float32)
    return ((((x - x_min) / (x_max - x_min + eps)) * (out_max - out_min)) + out_min).to(
        torch_numpy_dtype_dict[out_type][0]
    )
# Lookup table from an interpolation mode name to its backend-specific equivalent:
#   - "torch":   the ``mode`` string for ``torch.nn.functional.interpolate``
#   - "skimage": the spline ``order`` int for scikit-image resizing
interp_mode_map = {
    mode_name: {"torch": torch_mode, "skimage": sk_order}
    for mode_name, torch_mode, sk_order in (
        ("nearest", "nearest", 0),
        ("linear", "linear", 1),  # PyTorch only accepts "linear" for 3D (N, C, L) inputs
        ("bilinear", "bilinear", 1),
        ("bicubic", "bicubic", 3),
        ("trilinear", "trilinear", 1),  # linear order as the skimage counterpart for 3D
        ("area", "area", 1),  # skimage has no area mode; linear order is an approximation
        ("nearest-exact", "nearest-exact", 0),
    )
}
def resize(input_data, size, mode="bilinear", **kwargs):
    """
    Resize 2D or 3D image data held in a PyTorch tensor or a NumPy array.

    Supported input formats:

    - PyTorch tensor of shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D data. Only ``size[2:]``
      (the spatial part) is passed to ``torch.nn.functional.interpolate``.
    - NumPy array of shape (B, H, W, C) for 2D or (B, D, H, W, C) for 3D data. The full ``size`` is
      passed to the scikit-image resize function.

    Parameters
    ----------
    input_data : torch.Tensor or np.ndarray
        The image data to be resized.

    size : tuple of int
        Target size for each dimension. Must have as many entries as ``input_data`` has dimensions.

    mode : str, optional
        Interpolation mode to use. Must be one of the keys in `interp_mode_map`. Defaults to 'bilinear'.

    **kwargs : dict
        Additional arguments passed to `torch.nn.functional.interpolate` or `skimage.transform.resize`.

    Returns
    -------
    torch.Tensor or np.ndarray
        The resized image data in the same format as the input.

    Raises
    ------
    ValueError
        If the length of `size` does not match the number of dimensions in `input_data`, or if an
        unsupported interpolation mode is specified.

    TypeError
        If `input_data` is neither a PyTorch tensor nor a NumPy array.
    """
    if len(size) != input_data.ndim:
        raise ValueError(
            "The size provided ({}) needs to be of the same size as the dimensions of the input_data ({})".format(
                size, input_data.ndim
            )
        )

    backend_modes = interp_mode_map.get(mode)
    if backend_modes is None:
        raise ValueError(f"Unsupported interpolation mode: {mode}")

    if isinstance(input_data, torch.Tensor):
        # Channels-first (B, C, ...): skip the batch and channel entries of ``size``
        return F.interpolate(input_data, size=size[2:], mode=backend_modes["torch"], **kwargs)
    if isinstance(input_data, np.ndarray):
        # Channels-last (B, ..., C): skimage receives the full target shape
        return sk_resize(input_data, size, order=backend_modes["skimage"], **kwargs)
    raise TypeError("Input must be a torch.Tensor or a numpy.ndarray")
def decide_dtype(num_values: int) -> np.dtype:
    """
    Pick the smallest unsigned integer dtype able to represent ``num_values`` distinct values.

    Parameters
    ----------
    num_values : int
        The number of distinct values that need to be represented.

    Returns
    -------
    np.dtype
        ``np.uint8`` for up to 256 values, ``np.uint16`` for up to 65536 and ``np.uint32`` for up to
        4294967296.

    Raises
    ------
    ValueError
        If `num_values` is negative or exceeds the maximum representable by np.uint32.
    """
    if num_values < 0:
        raise ValueError("Number of values must be non-negative.")

    # (capacity, dtype) from the narrowest to the widest type
    for capacity, candidate in ((256, np.uint8), (65536, np.uint16), (4294967296, np.uint32)):
        if num_values <= capacity:
            return candidate
    raise ValueError("Number of values exceeds the maximum representable by uint32.")