# Source code for biapy.data.data_3D_manipulation

"""
Module for 3D data manipulation utilities.

This module provides functions to process and manipulate 3D data volumes, including:

- Cropping/merging with overlap
- Padding and resizing
- Efficient loading of large 3D files
"""
from __future__ import annotations

import os
import math
from typing import Any, List, Optional, Sequence, Tuple, Union
import scipy.signal
import h5py
import zarr
try:
    import z5py
    _Z5PY_AVAILABLE = True
except ImportError:
    z5py = None  # type: ignore[assignment]
    _Z5PY_AVAILABLE = False
import numpy as np
from numpy.typing import NDArray
from tqdm import tqdm
from biapy.utils.misc import is_main_process
from biapy.data.dataset import PatchCoords

# Type alias for an opened chunked-file object: a Zarr group, a Zarr array, or an h5py file.
ZarrOrH5File = Union[zarr.Group, zarr.Array, h5py.File]
# Type alias for an array-like dataset: a Zarr array or an h5py dataset.
ZarrOrH5Array = Union[zarr.Array, h5py.Dataset]

def load_3D_efficient_files(
    data_path: List[str],
    input_axes: str,
    crop_shape: Tuple[int, ...],
    overlap: Tuple[float, ...],
    padding: Tuple[int, ...],
    check_channel: bool = True,
    data_within_zarr_path: Optional[str] = None,
):
    """
    Efficiently index 3D patches from Zarr or HDF5 image volumes for training or inference.

    This function computes and returns metadata about all the 3D patches that can be extracted from a list
    of multidimensional microscopy volumes stored in Zarr or HDF5 formats. Patches are produced by
    :func:`extract_3D_patch_with_overlap_and_padding_yield` one at a time (each one is read to validate its
    channel count), so the full volume is never loaded into memory at once.

    Parameters
    ----------
    data_path : list of str
        List of paths to Zarr or HDF5 files containing the raw 3D image volumes.
    input_axes : str
        Axes layout of the image data in the files. Must be one of ['TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'].
    crop_shape : tuple of int
        Shape of the 3D patches to be extracted, in the form (z, y, x, channels). When ``input_axes``
        contains a "C" axis, the channel entry is replaced by the channel size of each volume.
    overlap : tuple of float
        Minimum fractional overlap between neighboring patches in the z, y, and x dimensions.
        Values must be in the range [0.0, 1.0).
    padding : tuple of int
        Number of voxels to pad along each spatial axis (z, y, x) when patching.
    check_channel : bool, optional
        If True, verify that the channel dimension of ``crop_shape`` matches the channel dimension of every
        extracted patch. Default is True.
    data_within_zarr_path : str, optional
        Internal path to the dataset inside the Zarr or HDF5 file, e.g., 'volumes/raw' or
        'volumes/labels/neuron_ids'. If None, the top-level dataset is used.

    Returns
    -------
    data_info : dict
        Dictionary mapping a running patch index (counted across all files) to patch metadata, with keys:

        - "filepath": path to the source file.
        - "full_shape": shape of the complete data volume as stored in the file.
        - "patch_coords": coordinates (start and end) of the extracted patch.
    data_total_patches : list of int
        Number of patches extracted from each file in ``data_path``.

    Raises
    ------
    ValueError
        If ``crop_shape`` is not 4D, or if ``check_channel`` is True and a patch's channel dimension does
        not match ``crop_shape[-1]``.
    """
    if len(crop_shape) != 4:
        raise ValueError(f"Provided crop_shape is not a 4D tuple: {crop_shape}")

    data_info = {}
    data_total_patches = []
    c = 0  # global patch counter across all files
    for filename in data_path:
        print(f"Reading Zarr/H5 file: {filename}")
        if data_within_zarr_path:
            file, data = read_chunked_nested_data(filename, data_within_zarr_path)
        else:
            file, data = read_chunked_data(filename)

        try:
            # Take the channel size from the data itself when the axes declare a channel axis.
            # str.index raises ValueError when there is no "C"; an out-of-range axis raises IndexError.
            try:
                c_index = input_axes.index("C")
                crop_shape = tuple(crop_shape[:-1]) + (data.shape[c_index],)
            except (ValueError, IndexError):
                pass

            # First pass only yields the statistics, so tqdm can display an accurate total.
            total_patches, _, _ = next(
                extract_3D_patch_with_overlap_and_padding_yield(
                    data,
                    crop_shape,
                    input_axes,
                    overlap=overlap,
                    padding=padding,
                    total_ranks=1,
                    rank=0,
                    return_only_stats=True,
                    verbose=True,
                )
            )

            patch_iter = extract_3D_patch_with_overlap_and_padding_yield(
                data,
                crop_shape,
                input_axes,
                overlap=overlap,
                padding=padding,
                total_ranks=1,
                rank=0,
                verbose=False,
            )
            for obj in tqdm(patch_iter, total=total_patches, disable=not is_main_process()):  # type: ignore
                img, patch_coords, _, _, _ = obj  # type: ignore
                data_info[c] = {
                    "filepath": filename,
                    "full_shape": data.shape,
                    "patch_coords": patch_coords,
                }
                c += 1

                assert isinstance(img, np.ndarray)
                if check_channel and crop_shape[-1] != img.shape[-1]:
                    raise ValueError(
                        "Channel of the patch size given {} does not correspond with the loaded image {}. "
                        "Please, check the channels of the images!".format(crop_shape[-1], img.shape[-1])
                    )
        finally:
            # Close HDF5 handles even if patch extraction failed, so the file is not left open.
            if isinstance(file, h5py.File):
                file.close()

        data_total_patches.append(total_patches)
    return data_info, data_total_patches
def load_img_part_from_efficient_file(
    filepath: str, patch_coords: PatchCoords, data_axes_order: str = "ZYXC", data_path: Optional[str] = None
):
    """
    Load from ``filepath`` the patch determined by ``patch_coords``.

    The file is opened, the patch is read with :func:`extract_patch_from_efficient_file` and, for HDF5
    files, the handle is closed again, even when the extraction raises.

    Parameters
    ----------
    filepath : str
        Path to the Zarr/H5 file to read the patch from.
    patch_coords : PatchCoords
        Coordinates of the crop.
    data_axes_order : str
        Order of axes of the stored data. E.g. 'TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'.
    data_path : str, optional
        Path to find the data within the Zarr file. E.g. 'volumes.labels.neuron_ids'.

    Returns
    -------
    img : Numpy array
        Extracted patch. E.g. ``(z, y, x, channels)``.
    """
    if data_path:
        imgfile, img = read_chunked_nested_data(filepath, data_path)
    else:
        imgfile, img = read_chunked_data(filepath)

    try:
        return extract_patch_from_efficient_file(img, patch_coords, data_axes_order=data_axes_order)
    finally:
        # Release the HDF5 handle on both the success and the error path.
        if isinstance(imgfile, h5py.File):
            imgfile.close()
def extract_patch_from_efficient_file(
    data: zarr.Array | h5py.Dataset,
    patch_coords: PatchCoords,
    data_axes_order: str = "ZYXC",
) -> NDArray:
    """
    Extract the patch determined by ``patch_coords`` from an already opened Zarr/H5 dataset.

    Parameters
    ----------
    data : Zarr/H5 data
        Data to extract the patch from.
    patch_coords : PatchCoords
        Coordinates of the crop.
    data_axes_order : str
        Order of axes of ``data``. E.g. 'TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'.

    Returns
    -------
    img : Numpy array
        Extracted patch, passed through ``ensure_3d_shape``. E.g. ``(z, y, x, channels)``.

    Raises
    ------
    ValueError
        If ``data`` cannot be indexed with the slices built for ``data_axes_order``.
    """
    spatial_ranges = (
        (patch_coords.z_start, patch_coords.z_end),
        (patch_coords.y_start, patch_coords.y_end),
        (patch_coords.x_start, patch_coords.x_end),
    )

    # Slices in "ZYXC" order (all channels), then reordered to match the axis order of ``data``.
    # Axes of ``data_axes_order`` that are not in "ZYXC" (e.g. T) receive default_value=0
    # (presumably a fixed index 0 — behavior of order_dimensions, defined elsewhere).
    slices = tuple(slice(start, end) for start, end in spatial_ranges) + (slice(None),)
    data_ordered_slices = order_dimensions(slices, input_order="ZYXC", output_order=data_axes_order, default_value=0)

    try:
        img = np.squeeze(np.array(data[data_ordered_slices]))
    except Exception as err:
        raise ValueError(
            f"Read data axes ({data.shape}) do not match the expected axis order ({data_axes_order})"
        ) from err

    # np.squeeze also dropped spatial axes whose extent is <= 1; re-insert them at their Z/Y/X position.
    if img.ndim != len(data_axes_order):
        axes_to_add = tuple(i for i, (start, end) in enumerate(spatial_ranges) if (end - start) <= 1)
        if axes_to_add:
            img = np.expand_dims(img, axis=axes_to_add)

    return ensure_3d_shape(img, data_axes_order=data_axes_order)
def insert_patch_in_efficient_file(
    data: zarr.Array | h5py.Dataset,
    patch: NDArray,
    patch_coords: PatchCoords,
    data_axes_order: str = "ZYXC",
    patch_axes_order: str = "ZYXC",
    mode="replace",
):
    """
    Insert ``patch`` in ``data`` at ``patch_coords``.

    Parameters
    ----------
    data : Zarr/H5 data
        Data to insert the patch into. Modified in place.
    patch : NDArray
        Patch to insert into ``data``.
    patch_coords : PatchCoords
        Coordinates of the patch.
    data_axes_order : str, optional
        Order of axes of ``data``. E.g. 'TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'.
    patch_axes_order : str, optional
        Order of axes of ``patch``. E.g. 'TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'.
    mode : str, optional
        What to do with the patch data when inserting it. ``"replace"`` overwrites the target region;
        ``"add"`` (alias ``"sum"``) adds the patch to the values already stored there.

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``"replace"``, ``"add"`` or ``"sum"``.
    """
    if mode == "sum":
        mode = "add"
    if mode not in ("add", "replace"):
        raise ValueError(f"'mode' must be one of ['add', 'replace'] ('sum' is an alias of 'add'), given {mode}")

    # Target region in "ZYXC" order. The channel slice is open so any number of channels can be inserted.
    slices = (
        slice(patch_coords.z_start, patch_coords.z_end),
        slice(patch_coords.y_start, patch_coords.y_end),
        slice(patch_coords.x_start, patch_coords.x_end),
        slice(None),
    )
    data_ordered_slices = tuple(
        order_dimensions(
            slices,
            input_order="ZYXC",
            output_order=data_axes_order,
            default_value=0,
        )
    )

    # Permutation that moves the patch axes into the axis order of ``data``; axes that ``data`` has but
    # the patch does not come back as NaN and are dropped.
    current_order = np.arange(len(patch.shape))
    transpose_order = order_dimensions(
        current_order,
        input_order=patch_axes_order,
        output_order=data_axes_order,
        default_value=np.nan,
    )
    transpose_order = [x for x in transpose_order if not np.isnan(x)]  # type: ignore
    patch = patch.transpose(transpose_order)  # type: ignore

    # Insert the patch into the corresponding position
    if mode == "replace":
        data[data_ordered_slices] = patch  # type: ignore
    else:  # add
        data[data_ordered_slices] += patch  # type: ignore
def crop_3D_data_with_overlap(
    data: NDArray,
    vol_shape: Tuple[int, ...],
    data_mask: Optional[NDArray] = None,
    overlap: Tuple[float, ...] = (0, 0, 0),
    padding: Tuple[int, ...] = (0, 0, 0),
    verbose: bool = True,
    median_padding: bool = False,
    load_data: bool = True,
) -> Union[Tuple[NDArray, NDArray, List["PatchCoords"]], Tuple[NDArray, List["PatchCoords"]], List["PatchCoords"]]:
    """
    Crop 3D data into smaller volumes with a defined overlap.

    The opposite function is :func:`~merge_3D_data_with_overlap`.

    Parameters
    ----------
    data : 4D Numpy array
        Data to crop. E.g. ``(z, y, x, channels)``.
    vol_shape : 4D int tuple
        Shape of the volumes to create. E.g. ``(z, y, x, channels)``.
    data_mask : 4D Numpy array, optional
        Data mask to crop. E.g. ``(z, y, x, channels)``.
    overlap : Tuple of 3 floats, optional
        Amount of minimum overlap on z, y and x dimensions. The values must be on range ``[0, 1)``,
        that is, ``0%`` or ``99%`` of overlap. E.g. ``(z, y, x)``.
    padding : tuple of ints, optional
        Size of padding to be added on each axis ``(z, y, x)``. E.g. ``(24, 24, 24)``.
    verbose : bool, optional
        To print information about the crop to be made.
    median_padding : bool, optional
        If ``True``, each padded border of ``data`` is filled with the median of the adjacent edge slice of
        ``data``. If ``False``, the border is filled by reflection. ``data_mask`` is always reflect-padded.
    load_data : bool, optional
        Whether to create the patches or not. It saves memory in case you only need the coordinates of the
        cropped patches.

    Returns
    -------
    cropped_data : 5D Numpy array, optional
        Cropped image data. E.g. ``(vol_number, z, y, x, channels)``. Returned if ``load_data`` is ``True``.
    cropped_data_mask : 5D Numpy array, optional
        Cropped image data masks. E.g. ``(vol_number, z, y, x, channels)``. Returned if ``load_data`` is
        ``True`` and ``data_mask`` is provided.
    crop_coords : list of PatchCoords
        Coordinates of each crop within the padded data, with the fields ``z_start``, ``z_end``,
        ``y_start``, ``y_end``, ``x_start`` and ``x_end``.

    Raises
    ------
    ValueError
        If ``data``/``data_mask`` are not 4D or their spatial shapes differ, if ``vol_shape`` is not of
        length 4 or larger than ``data``, if a padding value is not smaller than half of ``vol_shape``,
        or if an overlap value is outside ``[0, 1)``.

    Examples
    --------
    ::

        # EXAMPLE 1
        # Crop a volume of shape (165, 768, 1024) and its mask into 80x80x80 patches with 50% overlap:
        X_train = np.ones((165, 768, 1024, 1))
        Y_train = np.ones((165, 768, 1024, 1))
        X_train, Y_train, coords = crop_3D_data_with_overlap(
            X_train, (80, 80, 80, 1), data_mask=Y_train, overlap=(0.5,0.5,0.5))

        # The function will print the shape of the generated arrays. In this example:
        #     **** New data shape is: (2600, 80, 80, 80, 1)

    A visual explanation of the process:

    .. image:: ../../img/crop_3D_ov.png
        :width: 80%
        :align: center

    Note: this image do not respect the proportions.
    ::

        # EXAMPLE 2
        # Same data crop but without overlap
        X_train, Y_train, coords = crop_3D_data_with_overlap(
            X_train, (80, 80, 80, 1), data_mask=Y_train, overlap=(0,0,0))

        # The function will print the shape of the generated arrays. In this example:
        #     **** New data shape is: (390, 80, 80, 80, 1)
        #
        # Notice how differs the amount of subvolumes created compared to the first example

        # EXAMPLE 3
        # In the same way, if the addition of (64,64,64) padding is required, the call should be:
        X_train, Y_train, coords = crop_3D_data_with_overlap(
            X_train, (80, 80, 80, 1), data_mask=Y_train, overlap=(0.5,0.5,0.5), padding=(64,64,64))
    """
    if verbose:
        print("### 3D-OV-CROP ###")
        print("Cropping {} images into {} with overlapping . . .".format(data.shape, vol_shape))
        print("Minimum overlap selected: {}".format(overlap))
        print("Padding: {}".format(padding))

    if data.ndim != 4:
        raise ValueError("data expected to be 4 dimensional, given {}".format(data.shape))
    if data_mask is not None:
        if data_mask.ndim != 4:
            raise ValueError("data_mask expected to be 4 dimensional, given {}".format(data_mask.shape))
        if data.shape[:-1] != data_mask.shape[:-1]:
            raise ValueError(
                "data and data_mask shapes mismatch: {} vs {}".format(data.shape[:-1], data_mask.shape[:-1])
            )
    if len(vol_shape) != 4:
        raise ValueError("vol_shape expected to be of length 4, given {}".format(vol_shape))
    for i, p in enumerate(padding):
        if p >= vol_shape[i] // 2:
            raise ValueError(
                "'Padding' can not be greater than half of 'vol_shape'. Max value for the given input shape {} is {}".format(
                    vol_shape, ((vol_shape[0] // 2) - 1, (vol_shape[1] // 2) - 1, (vol_shape[2] // 2) - 1)
                )
            )
    for ax in range(3):
        if vol_shape[ax] > data.shape[ax]:
            raise ValueError(
                "'vol_shape[{}]' {} greater than {} (you can reduce 'DATA.PATCH_SIZE' or use "
                "'DATA.REFLECT_TO_COMPLETE_SHAPE')".format(ax, vol_shape[ax], data.shape[ax])
            )
    if any(not (0 <= ov < 1) for ov in overlap[:3]):
        raise ValueError("'overlap' values must be floats between range [0, 1)")

    pad_width = (
        (padding[0], padding[0]),
        (padding[1], padding[1]),
        (padding[2], padding[2]),
        (0, 0),
    )
    padded_data = np.pad(data, pad_width, "reflect")
    if data_mask is not None:
        padded_data_mask = np.pad(data_mask, pad_width, "reflect")

    def _along(ax, index):
        # Index tuple that applies ``index`` on axis ``ax`` and takes everything on the other axes.
        sel = [slice(None)] * 4
        sel[ax] = index
        return tuple(sel)

    if median_padding:
        # Overwrite the leading/trailing border slab of each spatial axis with the median of the
        # corresponding edge slice of ``data``. Axes are processed Z, Y, X so later axes win at corners.
        for ax in range(3):
            lead = slice(0, padding[ax])
            trail = slice(padding[ax] + data.shape[ax], 2 * padding[ax] + data.shape[ax])
            padded_data[_along(ax, lead)] = np.median(data[_along(ax, 0)])
            padded_data[_along(ax, trail)] = np.median(data[_along(ax, -1)])

    inner_sizes = [vol_shape[ax] - padding[ax] * 2 for ax in range(3)]  # patch size without its padding

    def _plan_axis(ax):
        # Returns (step, number of patches, shift of the last patch, overlap per block) for axis ``ax``.
        ov_factor = 1 if overlap[ax] == 0 else 1 - overlap[ax]
        step = int(inner_sizes[ax] * ov_factor)
        n_vols = math.ceil(data.shape[ax] / step)
        last = 0 if n_vols == 1 else (((n_vols - 1) * step) + vol_shape[ax]) - padded_data.shape[ax]
        # Spread the excess evenly across blocks so the last patch overruns as little as possible.
        ov_per_block = last // (n_vols - 1) if n_vols > 1 else 0
        step -= ov_per_block
        last -= ov_per_block * (n_vols - 1)
        return step, n_vols, last, ov_per_block

    steps, n_vols, lasts, ov_blocks = zip(*(_plan_axis(ax) for ax in range(3)))

    # Real overlap calculation for printing
    real_ov = tuple(ov_blocks[ax] / inner_sizes[ax] for ax in range(3))
    if verbose:
        print("Real overlapping (%): {}".format(real_ov))
        print("Real overlapping (pixels): {}".format(tuple(inner_sizes[ax] * real_ov[ax] for ax in range(3))))
        print("{} patches per (z,y,x) axis".format(tuple(n_vols)))

    total_vol = n_vols[0] * n_vols[1] * n_vols[2]
    if load_data:
        cropped_data = np.zeros((total_vol,) + tuple(vol_shape), dtype=data.dtype)
        if data_mask is not None:
            cropped_data_mask = np.zeros(
                (total_vol,) + tuple(vol_shape[:3]) + (data_mask.shape[-1],),
                dtype=data_mask.dtype,
            )

    def _window(ax, idx):
        # Start/end (in padded coordinates) of the idx-th patch along ``ax``; a patch that would reach
        # the padded border is shifted back by the axis' ``last`` value.
        start = idx * steps[ax]
        shift = 0 if (start + vol_shape[ax]) < padded_data.shape[ax] else lasts[ax]
        return start - shift, start + vol_shape[ax] - shift

    crop_coords = []
    # np.ndindex iterates z-major, x fastest: the same order as three nested loops.
    for c, (z, y, x) in enumerate(np.ndindex(*n_vols)):
        z0, z1 = _window(0, z)
        y0, y1 = _window(1, y)
        x0, x1 = _window(2, x)
        if load_data:
            cropped_data[c] = padded_data[z0:z1, y0:y1, x0:x1]
            if data_mask is not None:
                cropped_data_mask[c] = padded_data_mask[z0:z1, y0:y1, x0:x1]
        crop_coords.append(PatchCoords(z_start=z0, z_end=z1, y_start=y0, y_end=y1, x_start=x0, x_end=x1))

    if verbose:
        if load_data:
            print("**** New data shape is: {}".format(cropped_data.shape))
        print("### END 3D-OV-CROP ###")

    if not load_data:
        return crop_coords
    if data_mask is not None:
        return cropped_data, cropped_data_mask, crop_coords
    return cropped_data, crop_coords
def _get_spline_window_3D(crop_shape: Tuple[int, ...], overlap_pixels: Tuple[int, int, int], power: int = 2) -> NDArray: """ Generate a 3D squared spline window for smooth blending. The window is designed to have values close to 1 in the center of the patch and smoothly taper to 0 towards the edges, with the transition controlled by the `power` parameter. Parameters ---------- crop_shape : tuple of int Shape of the 3D patch for which to generate the window, in the form (z, y, x, channels). overlap_pixels: tuple of int The exact number of overlapping pixels in the (z, y, x) dimensions to apply the taper across. power : int, optional Power to control the steepness of the window. Higher values will create a sharper transition from 1 to 0. Default is 2, which creates a smooth quadratic tapering. Returns ------- window : 4D Numpy array A 3D window of shape (z, y, x, 1) that can be applied to the patch for smooth blending. The values are normalized so that the average is 1, ensuring that the overall intensity of the patch is preserved when blended with others. 
""" def _spline_window_1D(size, ov_pixels, power=2): wind = np.ones(size, dtype=np.float32) if ov_pixels > 0: ov_pixels = min(ov_pixels, size // 2) x = np.linspace(0, 1, ov_pixels + 2)[1:-1] taper = (x ** power) / (x ** power + (1 - x) ** power + 1e-8) wind[:ov_pixels] = taper wind[-ov_pixels:] = taper[::-1] return wind # Generate 1D splines for Z, Y, and X dimensions wind_z = _spline_window_1D(crop_shape[0], overlap_pixels[0], power) wind_y = _spline_window_1D(crop_shape[1], overlap_pixels[1], power) wind_x = _spline_window_1D(crop_shape[2], overlap_pixels[2], power) # Expand dims to perform outer product for 3D window wind_z = wind_z[:, None, None] wind_y = wind_y[None, :, None] wind_x = wind_x[None, None, :] # Broadcast multiply to get shape (z, y, x) wind_3d = wind_z * wind_y * wind_x # Expand to match the channel dimension of the data: (z, y, x, 1) wind_3d = np.expand_dims(wind_3d, -1) return wind_3d.astype(np.float32)
def merge_3D_data_with_overlap(
    data: NDArray,
    orig_vol_shape: Tuple,
    data_mask: Optional[NDArray] = None,
    overlap: Tuple[float, ...] = (0, 0, 0),
    padding: Tuple[int, ...] = (0, 0, 0),
    verbose: bool = True,
) -> Union[NDArray, Tuple[NDArray, Optional[NDArray]]]:
    """
    Merge 3D subvolumes in a 3D volume with a defined overlap.

    The opposite function is :func:`~crop_3D_data_with_overlap`. Each patch, with its padding removed,
    is multiplied by a spline window from ``_get_spline_window_3D`` and accumulated. The sum is then
    divided by the accumulated window weights, so overlapping regions are smoothly blended.

    Parameters
    ----------
    data : 5D Numpy array
        Patches to merge. E.g. ``(volume_number, z, y, x, channels)``.
    orig_vol_shape : 4D int tuple
        Shape of the merged (unpadded) volume. E.g. ``(z, y, x, channels)``.
    data_mask : 5D Numpy array, optional
        Mask patches to merge alongside ``data``. E.g. ``(volume_number, z, y, x, channels)``.
    overlap : Tuple of 3 floats, optional
        Amount of minimum overlap on z, y and x dimensions. Should be the same as used in
        :func:`~crop_3D_data_with_overlap`. The values must be on range ``[0, 1)``, that is, ``0%`` or
        ``99%`` of overlap. E.g. ``(z, y, x)``.
    padding : tuple of ints, optional
        Padding that was added on each axis ``(z, y, x)`` when cropping; it is removed from every patch
        before merging. E.g. ``(24, 24, 24)``.
    verbose : bool, optional
        To print information about the merge to be made.

    Returns
    -------
    merged_data : 4D Numpy array
        Merged data, cast back to ``data.dtype``. E.g. ``(z, y, x, channels)``.
    merged_data_mask : 4D Numpy array, optional
        Merged mask, cast back to ``data_mask.dtype``. E.g. ``(z, y, x, channels)``. Only returned when
        ``data_mask`` is provided.

    Raises
    ------
    ValueError
        If ``data`` is not 5D, ``orig_vol_shape`` is not of length 4, ``data`` and ``data_mask`` differ
        in all but the channel dimension, or an overlap value is outside ``[0, 1)``.
    """
    if data.ndim != 5:
        raise ValueError(f"data expected to be 5 dimensional, given {data.shape}")
    if len(orig_vol_shape) != 4:
        raise ValueError(f"orig_vol_shape expected to be 4 dimensional, given {orig_vol_shape}")
    if data_mask is not None and data.shape[:-1] != data_mask.shape[:-1]:
        raise ValueError(
            "data and data_mask shapes mismatch: {} vs {}".format(data.shape[:-1], data_mask.shape[:-1])
        )
    if any(not (0 <= ov < 1) for ov in overlap[:3]):
        raise ValueError("'overlap' values must be floats between range [0, 1)")

    if verbose:
        print("### MERGE-3D-OV-CROP ###")
        print("Merging {} images into {} with smooth blending . . .".format(data.shape, orig_vol_shape))
        print("Minimum overlap selected: {}".format(overlap))
        print("Padding: {}".format(padding))

    # Remove the padding from every patch
    pad_input_shape = data.shape
    unpad = (
        slice(None),
        slice(padding[0], pad_input_shape[1] - padding[0]),
        slice(padding[1], pad_input_shape[2] - padding[1]),
        slice(padding[2], pad_input_shape[3] - padding[2]),
        slice(None),
    )
    data = data[unpad]
    merged_data = np.zeros(orig_vol_shape, dtype=np.float32)
    if data_mask is not None:
        data_mask = data_mask[unpad]
        merged_data_mask = np.zeros(tuple(orig_vol_shape[:3]) + (data_mask.shape[-1],), dtype=np.float32)

    # float32 weight map to accurately accumulate the spline weights
    weight_map_counter = np.zeros(tuple(orig_vol_shape[:3]) + (1,), dtype=np.float32)

    inner_sizes = [pad_input_shape[ax + 1] - padding[ax] * 2 for ax in range(3)]  # patch size w/o padding

    def _plan_axis(ax):
        # Returns (step, number of patches, shift of the last patch) along axis ``ax``; mirrors the
        # computation done in crop_3D_data_with_overlap.
        ov_factor = 1 if overlap[ax] == 0 else 1 - overlap[ax]
        padded_len = orig_vol_shape[ax] + 2 * padding[ax]
        step = int(inner_sizes[ax] * ov_factor)
        n_vols = math.ceil(orig_vol_shape[ax] / step)
        last = 0 if n_vols == 1 else (((n_vols - 1) * step) + pad_input_shape[ax + 1]) - padded_len
        ov_per_block = last // (n_vols - 1) if n_vols > 1 else 0
        return step - ov_per_block, n_vols, last - ov_per_block * (n_vols - 1)

    steps, n_vols, lasts = zip(*(_plan_axis(ax) for ax in range(3)))

    # Exact overlap in pixels, used to size the taper of the blending window
    overlap_pixels = tuple(inner_sizes[ax] - steps[ax] for ax in range(3))
    spline_window = _get_spline_window_3D(tuple(data.shape[1:4]), overlap_pixels)

    def _region(ax, idx):
        # Slice of the idx-th patch along ``ax``; the patch reaching the border is shifted back.
        start = idx * steps[ax]
        shift = 0 if (start + data.shape[ax + 1]) < orig_vol_shape[ax] else lasts[ax]
        return slice(start - shift, start + data.shape[ax + 1] - shift)

    # np.ndindex iterates z-major, x fastest: the same order used when cropping.
    for c, (z, y, x) in enumerate(np.ndindex(*n_vols)):
        region = (_region(0, z), _region(1, y), _region(2, x))
        merged_data[region] += data[c] * spline_window
        if data_mask is not None:
            merged_data_mask[region] += data_mask[c] * spline_window
        weight_map_counter[region] += spline_window

    # Normalize by the accumulated weights; the epsilon only guards against division by zero
    merged_data = np.true_divide(merged_data, weight_map_counter + 1e-18).astype(data.dtype)

    if verbose:
        print("**** New data shape is: {}".format(merged_data.shape))
        print("### END MERGE-3D-OV-CROP ###")

    if data_mask is not None:
        merged_data_mask = np.true_divide(merged_data_mask, weight_map_counter + 1e-18).astype(data_mask.dtype)
        return merged_data, merged_data_mask
    return merged_data
[docs] def extract_3D_patch_with_overlap_and_padding_yield( data: zarr.Array | h5py.Dataset, vol_shape: Tuple[int, ...], axes_order: str, overlap: Tuple[float, ...] = (0, 0, 0), padding: Tuple[int, ...] = (0, 0, 0), total_ranks: int = 1, rank: int = 0, return_only_stats: bool = False, load_data: bool = True, verbose: bool = False, ): """ Extract 3D patches into smaller patches with a defined overlap. Supports multi-GPU inference by setting ``total_ranks`` and ``rank`` variables. Each GPU will process an even number of volumes in the ``Z`` axis. If the number of volumes is not divisible by the number of GPUs, the first GPUs will process one more volume. Parameters ---------- data : Zarr array or H5 dataset Data to extract patches from. E.g. ``(z, y, x, channels)``. vol_shape : 4D int tuple Shape of the patches to create. E.g. ``(z, y, x, channels)``. axes_order : str Order of axes of ``data``. One of ['TZCYX', 'TZYXC', 'ZCYX', 'ZYXC']. overlap : Tuple of 3 floats, optional Amount of minimum overlap on x, y and z dimensions. Should be the same as used in :func:`~crop_3D_data_with_overlap`. Values must be in range ``[0, 1)``, representing 0% to 99% overlap. E.g. ``(z, y, x)``. padding : tuple of ints, optional Size of padding to be added on each axis ``(z, y, x)``. E.g. ``(24, 24, 24)``. total_ranks : int, optional Total number of GPUs. rank : int, optional Rank of the current GPU. return_only_stats : bool, optional Whether to just return crop statistics without yielding patches. Useful for precalculating the number of patches. load_data: bool, optional Whether to load data from file. Speeds up process if only patch coordinates are needed. verbose : bool, optional Whether to print debugging information. Yields ------ img : 4D Numpy array, optional Extracted patch from ``data``. E.g. ``(z, y, x, channels)``. Only returned if ``load_data`` is ``True``. real_patch_in_data : Tuple of tuples of ints Coordinates where patch should be inserted in original data. E.g. 
``((0, 20), (0, 8), (16, 24))`` means the patch belongs at position [0:20,0:8,16:24] in the original data. total_vol : int Total number of crops to extract. z_vol_info : dict, optional Mapping of volume positions in original data. E.g. ``{0: [0, 20], 1: [20, 40]}`` means first volume goes at [0:20], second at [20:40]. list_of_vols_in_z : list of list of int, optional Volumes assigned to each GPU. E.g. ``[[0, 1, 2], [3, 4]]`` means GPU 0 processes volumes 0-2, GPU 1 processes volumes 3-4. """ if verbose and rank == 0: print("### 3D-OV-CROP ###") print( "Cropping {} images into {} with overlapping (axis order: {}). . .".format( data.shape, vol_shape, axes_order ) ) print("Minimum overlap selected: {}".format(overlap)) print("Padding: {}".format(padding)) if len(vol_shape) != 4: raise ValueError("vol_shape expected to be of length 4, given {}".format(vol_shape)) _, z_dim, c_dim, y_dim, x_dim = order_dimensions(data.shape, axes_order) assert isinstance(z_dim, int) and isinstance(x_dim, int) and isinstance(y_dim, int) and isinstance(c_dim, int) if vol_shape[0] > z_dim: raise ValueError( "'vol_shape[0]' {} greater than {} (you can reduce 'DATA.PATCH_SIZE')".format(vol_shape[0], z_dim) ) if vol_shape[1] > y_dim: raise ValueError( "'vol_shape[1]' {} greater than {} (you can reduce 'DATA.PATCH_SIZE')".format(vol_shape[1], y_dim) ) if vol_shape[2] > x_dim: raise ValueError( "'vol_shape[2]' {} greater than {} (you can reduce 'DATA.PATCH_SIZE')".format(vol_shape[2], x_dim) ) if ( (overlap[0] >= 1 or overlap[0] < 0) or (overlap[1] >= 1 or overlap[1] < 0) or (overlap[2] >= 1 or overlap[2] < 0) ): raise ValueError("'overlap' values must be floats between range [0, 1)") if padding[0] >= vol_shape[0] // 2: raise ValueError( "'Padding' can not be greater than half of 'vol_shape'. 
Max value for the given input shape {} is {}".format( vol_shape, ((vol_shape[0] // 2) - 1, (vol_shape[1] // 2) - 1, (vol_shape[2] // 2) - 1) ) ) if padding[1] >= vol_shape[1] // 2: raise ValueError( "'Padding' can not be greater than half of 'vol_shape'. Max value for the given input shape {} is {}".format( vol_shape, ((vol_shape[0] // 2) - 1, (vol_shape[1] // 2) - 1, (vol_shape[2] // 2) - 1) ) ) if padding[2] >= vol_shape[2] // 2: raise ValueError( "'Padding' can not be greater than half of 'vol_shape'. Max value for the given input shape {} is {}".format( vol_shape, ((vol_shape[0] // 2) - 1, (vol_shape[1] // 2) - 1, (vol_shape[2] // 2) - 1) ) ) padded_data_shape = [ z_dim + padding[0] * 2, y_dim + padding[1] * 2, x_dim + padding[2] * 2, c_dim, ] # Calculate overlapping variables overlap_z = 1 if overlap[0] == 0 else 1 - overlap[0] overlap_y = 1 if overlap[1] == 0 else 1 - overlap[1] overlap_x = 1 if overlap[2] == 0 else 1 - overlap[2] # Z step_z = int((vol_shape[0] - padding[0] * 2) * overlap_z) vols_per_z = math.ceil(z_dim / step_z) last_z = 0 if vols_per_z == 1 else (((vols_per_z - 1) * step_z) + vol_shape[0]) - padded_data_shape[0] ovz_per_block = last_z // (vols_per_z - 1) if vols_per_z > 1 else 0 step_z -= ovz_per_block last_z -= ovz_per_block * (vols_per_z - 1) # Y step_y = int((vol_shape[1] - padding[1] * 2) * overlap_y) vols_per_y = math.ceil(y_dim / step_y) last_y = 0 if vols_per_y == 1 else (((vols_per_y - 1) * step_y) + vol_shape[1]) - padded_data_shape[1] ovy_per_block = last_y // (vols_per_y - 1) if vols_per_y > 1 else 0 step_y -= ovy_per_block last_y -= ovy_per_block * (vols_per_y - 1) # X step_x = int((vol_shape[2] - padding[2] * 2) * overlap_x) vols_per_x = math.ceil(x_dim / step_x) last_x = 0 if vols_per_x == 1 else (((vols_per_x - 1) * step_x) + vol_shape[2]) - padded_data_shape[2] ovx_per_block = last_x // (vols_per_x - 1) if vols_per_x > 1 else 0 step_x -= ovx_per_block last_x -= ovx_per_block * (vols_per_x - 1) # Real overlap calculation for 
printing real_ov_z = ovz_per_block / (vol_shape[0] - padding[0] * 2) real_ov_y = ovy_per_block / (vol_shape[1] - padding[1] * 2) real_ov_x = ovx_per_block / (vol_shape[2] - padding[2] * 2) if verbose and rank == 0: print("Real overlapping (%): {}".format((real_ov_z, real_ov_y, real_ov_x))) print( "Real overlapping (pixels): {}".format( ( (vol_shape[0] - padding[0] * 2) * real_ov_z, (vol_shape[1] - padding[1] * 2) * real_ov_y, (vol_shape[2] - padding[2] * 2) * real_ov_x, ) ) ) print("{} patches per (z,y,x) axis".format((vols_per_z, vols_per_x, vols_per_y))) vols_in_z = vols_per_z // total_ranks vols_per_z_per_rank = vols_in_z if vols_per_z % total_ranks > rank: vols_per_z_per_rank += 1 total_vol = vols_per_z_per_rank * vols_per_y * vols_per_x c = 0 list_of_vols_in_z = [] z_vol_info = {} for i in range(total_ranks): vols = (vols_per_z // total_ranks) + 1 if vols_per_z % total_ranks > i else vols_in_z for j in range(vols): z = c + j real_start_z = z * step_z real_finish_z = min(real_start_z + step_z + ovz_per_block, z_dim) z_vol_info[z] = [real_start_z, real_finish_z] list_of_vols_in_z.append(list(range(c, c + vols))) c += vols if verbose and rank == 0: print(f"List of volume IDs to be processed by each GPU: {list_of_vols_in_z}") print(f"Positions of each volume in Z axis: {z_vol_info}") print( "Rank {}: Total number of patches: {} - {} patches per (z,y,x) axis (per GPU)".format( rank, total_vol, (vols_per_z_per_rank, vols_per_x, vols_per_y) ) ) if return_only_stats: yield total_vol, z_vol_info, list_of_vols_in_z return for _z in range(vols_per_z_per_rank): z = list_of_vols_in_z[rank][0] + _z for y in range(vols_per_y): for x in range(vols_per_x): d_z = 0 if (z * step_z + vol_shape[0]) < padded_data_shape[0] else last_z d_y = 0 if (y * step_y + vol_shape[1]) < padded_data_shape[1] else last_y d_x = 0 if (x * step_x + vol_shape[2]) < padded_data_shape[2] else last_x start_z = max(0, z * step_z - d_z - padding[0]) finish_z = min(z * step_z + vol_shape[0] - d_z - 
padding[0], z_dim) start_y = max(0, y * step_y - d_y - padding[1]) finish_y = min(y * step_y + vol_shape[1] - d_y - padding[1], y_dim) start_x = max(0, x * step_x - d_x - padding[2]) finish_x = min(x * step_x + vol_shape[2] - d_x - padding[2], x_dim) slices = [ slice(start_z, finish_z), slice(start_y, finish_y), slice(start_x, finish_x), slice(None), # Channel ] data_ordered_slices = order_dimensions( slices, input_order="ZYXC", output_order=axes_order, default_value=0 ) real_patch_in_data = PatchCoords( z_start=z * step_z - d_z, z_end=(z * step_z) + vol_shape[0] - d_z - (padding[0] * 2), y_start=y * step_y - d_y, y_end=(y * step_y) + vol_shape[1] - d_y - (padding[1] * 2), x_start=x * step_x - d_x, x_end=(x * step_x) + vol_shape[2] - d_x - (padding[2] * 2), ) if load_data: img = data[tuple(data_ordered_slices)] # The image should have the channel dimension at the end current_order = np.array(range(len(img.shape))) transpose_order = order_dimensions( current_order, # input_order="ZYXC", output_order=axes_order, default_value=np.nan, ) # determine the transpose order transpose_order = [x for x in transpose_order if not np.isnan(x)] # type: ignore transpose_order = np.argsort(transpose_order) # type: ignore transpose_order = current_order[transpose_order] img = np.transpose(img, transpose_order) pad_z_left = padding[0] - z * step_z - d_z if start_z <= 0 else 0 pad_z_right = (start_z + vol_shape[0]) - z_dim if start_z + vol_shape[0] > z_dim else 0 pad_y_left = padding[1] - y * step_y - d_y if start_y <= 0 else 0 pad_y_right = (start_y + vol_shape[1]) - y_dim if start_y + vol_shape[1] > y_dim else 0 pad_x_left = padding[2] - x * step_x - d_x if start_x <= 0 else 0 pad_x_right = (start_x + vol_shape[2]) - x_dim if start_x + vol_shape[2] > x_dim else 0 if img.ndim == 3: img = np.pad( img, ( (pad_z_left, pad_z_right), (pad_y_left, pad_y_right), (pad_x_left, pad_x_right), ), "reflect", ) img = np.expand_dims(img, -1) else: img = np.pad( img, ( (pad_z_left, pad_z_right), 
(pad_y_left, pad_y_right), (pad_x_left, pad_x_right), (0, 0), ), "reflect", ) assert ( img.shape[:-1] == vol_shape[:-1] ), f"Image shape and expected shape differ: {img.shape} vs {vol_shape}" if rank == 0: yield img, real_patch_in_data, total_vol, z_vol_info, list_of_vols_in_z else: yield img, real_patch_in_data, total_vol else: if rank == 0: yield real_patch_in_data, total_vol, z_vol_info, list_of_vols_in_z else: yield real_patch_in_data, total_vol
def order_dimensions(
    data: Sequence[slice] | List[str | int] | Tuple[int, ...] | NDArray,
    input_order: str,
    output_order: str = "TZCYX",
    default_value: int | float = 1,
) -> Sequence[slice] | List[str | int] | Tuple[int, ...] | NDArray:
    """
    Reorder per-axis values from one axes layout to another.

    Parameters
    ----------
    data : Numpy array like
        One value per axis of ``input_order``. E.g. ``(z, y, x, channels)``.
    input_order : str
        Axes letters describing ``data``. E.g. ``ZYXC``.
    output_order : str, optional
        Axes letters of the desired result. E.g. ``TZCYX``.
    default_value : int or float, optional
        Value emitted for every axis of ``output_order`` that is absent from
        ``input_order``.

    Returns
    -------
    shape : Tuple
        Reordered values, e.g. ``(t, z, channel, y, x)``. When both orders are
        identical, ``data`` itself is returned untouched (not converted to a tuple).
    """
    if input_order != output_order:
        # Pick each requested axis from ``data``; fill missing axes with the default.
        return tuple(
            data[input_order.index(axis)] if axis in input_order else default_value
            for axis in output_order
        )
    return data
def ensure_3d_shape(
    img: NDArray,
    path: Optional[str] = None,
    data_axes_order: Optional[str] = None,
) -> NDArray:
    """
    Normalise the axes of an in-memory volume to ``(z, y, x, channels)``.

    Nothing is read from disk here: the array is only reordered (and, for 3D
    inputs, given a trailing channel axis). ``path`` is used solely in the error
    message raised for 2D inputs.

    Parameters
    ----------
    img : NDArray
        Volume to reorder. Must have at least 3 dimensions.
    path : str, optional
        Source path of ``img``; only included in the error message.
    data_axes_order : str, optional
        Axes letters of ``img``, e.g. 'TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'.
        'T' is removed. If no 'Z' remains, the first of 'C', 'I', 'Q' (checked
        in that order) is renamed to 'Z'; if none of them is present, 'Z' is
        inserted where 'T' was, provided the string is now shorter than
        ``img.ndim``. If any letter outside 'ZYXC' remains, the order is
        discarded and the axes are guessed from the shape instead.

    Returns
    -------
    img : NDArray
        - 5D input whose first axis is not of size 1: returned unchanged (5D).
        - 3D input: transposed to ``(z, y, x)`` plus a trailing size-1 channel axis.
        - 4D input (or 5D input whose size-1 first axis was dropped):
          transposed to ``(z, y, x, channels)``.

    Raises
    ------
    ValueError
        If ``img`` has fewer than 3 dimensions.
    """
    if img.ndim < 3:
        if path:
            m = "Read image seems to be 2D: {}. Path: {}".format(img.shape, path)
        else:
            m = "Read image seems to be 2D: {}".format(img.shape)
        raise ValueError(m)
    elif img.ndim == 5:
        if img.shape[0] != 1:
            # It is assumed that the image is already prepared
            return img
        else:
            # Drop the singleton leading (T) axis
            img = img[0]

    # pop T in data_axes_order
    if data_axes_order is not None:
        # Remember where T was so a Z axis can take its place below
        T_post = data_axes_order.index("T") if "T" in data_axes_order else None
        data_axes_order = data_axes_order.replace("T", "")
        if "Z" not in data_axes_order:
            # No explicit Z: reinterpret another stacking axis as Z
            if "C" in data_axes_order:
                data_axes_order = data_axes_order.replace("C", "Z")
            elif "I" in data_axes_order:
                data_axes_order = data_axes_order.replace("I", "Z")
            elif "Q" in data_axes_order:
                data_axes_order = data_axes_order.replace("Q", "Z")
            else:
                # The removed T axis was the stacking axis: relabel it as Z
                if len(data_axes_order) < img.ndim and T_post is not None:
                    data_axes_order = data_axes_order[:T_post] + "Z" + data_axes_order[T_post:]
        # Unknown axis letters: fall back to shape-based guessing
        if any([x for x in data_axes_order if x not in "ZYXC"]):
            data_axes_order = None

    new_pos = list(range(img.ndim))
    if img.ndim == 3:
        if data_axes_order is None:
            # Ensure Z axis is always in the first position
            # (the smallest axis, first one on ties, is taken as Z)
            min_val = min(img.shape)
            z_pos = img.shape.index(min_val)
            if z_pos != 0:
                new_pos = [
                    z_pos,
                ] + [x for x in range(3) if x != z_pos]
        else:
            # Follows the axes order provided in data_axes_order
            # NOTE(review): if len(data_axes_order) != img.ndim the transpose below
            # will likely fail — TODO confirm callers guarantee matching lengths.
            new_pos = order_dimensions(
                np.array(range(len(data_axes_order))),
                input_order=data_axes_order,
                output_order="ZYX",
                default_value=np.nan,
            )
            # Axes missing from data_axes_order come back as NaN; drop them
            new_pos = [x for x in new_pos if not np.isnan(x)]  # type: ignore
        img = img.transpose(new_pos)  # type: ignore
        # Add the channel axis so the result is always (z, y, x, channels)
        img = np.expand_dims(img, -1)
    else:
        if data_axes_order is None:
            # Ensure channel axis is always in the last position (assuming Z is already set)
            # Here ``z_pos`` holds the smallest axis, which is treated as channels.
            min_val = min(img.shape)
            z_pos = img.shape.index(min_val)
            if z_pos != 3:
                new_pos = [x for x in range(4) if x != z_pos] + [
                    z_pos,
                ]
        else:
            # Follows the axes order provided in data_axes_order
            new_pos = order_dimensions(
                np.array(range(len(data_axes_order))),
                input_order=data_axes_order,
                output_order="ZYXC",
                default_value=np.nan,
            )
            # Axes missing from data_axes_order come back as NaN; drop them
            new_pos = [x for x in new_pos if not np.isnan(x)]  # type: ignore
        img = img.transpose(new_pos)  # type: ignore
    return img
def _first_leaf_in_group(group: Any, group_type: Any, leaf_type: Any) -> Optional[Any]:
    """Return the first ``leaf_type`` member below ``group`` in depth-first, key-sorted order.

    Sub-groups (instances of ``group_type``) are searched recursively and the
    search backtracks, so an empty sub-group no longer hides leaves stored
    under later keys. Members that are neither groups nor leaves are skipped.

    Returns ``None`` when no leaf exists anywhere below ``group``.
    """
    for key in sorted(group.keys()):
        member = group[key]
        if isinstance(member, group_type):
            found = _first_leaf_in_group(member, group_type, leaf_type)
            if found is not None:
                return found
        elif isinstance(member, leaf_type):
            return member
    return None


def _first_array_in_group(g: zarr.Group) -> zarr.Array:
    """Descend into the first array found (sorted by key, depth-first) in a Zarr group.

    Raises
    ------
    ValueError
        If ``g`` has no members at all, or no array exists at any depth.
    """
    if not list(g.keys()):
        raise ValueError("Zarr group is empty (no arrays or subgroups).")
    array = _first_leaf_in_group(g, zarr.Group, zarr.Array)
    if array is None:
        raise ValueError("Zarr group contains no arrays (only empty groups).")
    return array


def _first_dataset_in_h5_group(g: h5py.Group) -> h5py.Dataset:
    """Descend into the first dataset found (sorted by key, depth-first) in an HDF5 group.

    Raises
    ------
    ValueError
        If ``g`` has no members at all, or no dataset exists at any depth.
    """
    if not list(g.keys()):
        raise ValueError("HDF5 group is empty (no datasets or subgroups).")
    dataset = _first_leaf_in_group(g, h5py.Group, h5py.Dataset)
    if dataset is None:
        raise ValueError("HDF5 group contains no datasets (only empty groups).")
    return dataset
def read_chunked_nested_data(
    file: str, data_path: str = ""
) -> Tuple[ZarrOrH5File, ZarrOrH5Array]:
    """Open an HDF5 or Zarr/N5 container and locate a (possibly nested) dataset in it.

    The container type is decided from the file name alone: names accepted by
    ``looks_like_hdf5`` are delegated to ``read_chunked_nested_h5``; names ending
    in ``.zarr``, ``.n5`` or ``n5`` are delegated to ``read_chunked_nested_zarr``.

    Parameters
    ----------
    file : str
        Path to the input file. Supported formats: .h5, .hdf5, .hdf, .n5, .zarr
    data_path : str, optional
        Internal path of the data inside the container, forwarded unchanged to
        the format-specific reader. Default: "" (root level).

    Returns
    -------
    tuple
        ``(file_handler, dataset)`` exactly as returned by the delegated reader:
        ``(h5py.File, h5py.Dataset)`` for HDF5, a Zarr/N5 root plus its array otherwise.

    Raises
    ------
    ValueError
        If the file name matches neither the HDF5 nor the Zarr/N5 naming rules
        (the delegated readers may raise their own errors as well).

    Examples
    --------
    >>> file_handler, dataset = read_chunked_nested_data("data.h5")
    >>> zarr_group, zarr_array = read_chunked_nested_data("data.zarr")
    """
    if looks_like_hdf5(file):
        reader = read_chunked_nested_h5
    elif file.endswith((".n5", "n5", ".zarr")):
        reader = read_chunked_nested_zarr
    else:
        raise ValueError("Input file seems to not be either Zarr/H5/n5. Supported formats: .h5, .hdf5, .hdf, .n5, .zarr")
    return reader(file, data_path)
def _zarr_member_keys(node: Any, method: str) -> List[str]:
    """Return the member names of a Zarr/N5 node via ``node.<method>()``.

    ``method`` is ``"group_keys"`` or ``"array_keys"``. Backends lacking that
    method (presumably z5py objects — TODO confirm) fall back to ``node.keys()``;
    nodes offering neither (e.g. arrays) are treated as having no members.
    """
    lister = getattr(node, method, None) or getattr(node, "keys", None)
    return list(lister()) if lister is not None else []


def _resolve_zarr_path(root: Any, data_path: str) -> Optional[Any]:
    """Follow ``data_path`` from ``root`` and return the array it names, or ``None``.

    Segments may be separated by ``"."`` or ``"/"`` (leading/trailing separators
    are ignored). Every segment but the last must name a sub-group; the last must
    name an array. An empty path resolves to ``None``.
    """
    normalized = data_path.replace("/", ".").strip(".")
    if not normalized:
        return None
    *group_names, array_name = normalized.split(".")
    node = root
    for name in group_names:
        if name not in _zarr_member_keys(node, "group_keys"):
            return None
        node = node[name]
    if array_name not in _zarr_member_keys(node, "array_keys"):
        return None
    return node[array_name]


def read_chunked_nested_zarr(zarrfile: str, data_path: str = "") -> Tuple[zarr.Group, zarr.Array]:
    """Open a Zarr/N5 container and locate an array at a nested path.

    Parameters
    ----------
    zarrfile : str
        Path to the Zarr/N5 container. Must end with ``.zarr``, ``.n5`` or ``n5``.
        N5 containers are opened with z5py, everything else with ``zarr.open``.
    data_path : str, optional
        Internal path to the array, using ``"."`` or ``"/"`` between nested
        groups (e.g. ``"group1.subgroup.data"`` or ``"group1/subgroup/data"``).
        Default: "" (no lookup is performed).

    Returns
    -------
    tuple
        ``(root, array)`` where ``root`` is the opened container. ``array`` is
        ``None`` when ``data_path`` is empty.

    Raises
    ------
    ValueError
        If the file name is not a Zarr/N5 name, or a non-empty ``data_path``
        cannot be resolved to an array.
    ImportError
        If an N5 container is requested but z5py is not installed.

    Examples
    --------
    >>> group, array = read_chunked_nested_zarr("data.zarr", "volumes.raw")
    >>> group, dataset = read_chunked_nested_zarr("experiment.n5", "images/channel1")
    """
    if not zarrfile.endswith((".n5", "n5", ".zarr")):
        raise ValueError("Not implemented for other filetypes than Zarr")

    if zarrfile.endswith("n5"):  # also covers ".n5"
        if not _Z5PY_AVAILABLE:
            raise ImportError(
                "z5py is required for N5 format support but is not installed. "
                "Install it via conda: conda install -c conda-forge z5py"
            )
        fid = z5py.File(zarrfile, "r")
    else:
        fid = zarr.open(zarrfile, mode="r")

    data = _resolve_zarr_path(fid, data_path)
    if data is None and data_path != "":
        raise ValueError(f"'{data_path}' not found in Zarr: {zarrfile}.")
    return fid, data  # type: ignore
def read_chunked_nested_h5(h5file: str, data_path: str = "") -> Tuple[h5py.File, h5py.Dataset]:
    """Open an HDF5 file and locate a dataset, optionally at a nested path.

    Parameters
    ----------
    h5file : str
        Path to the HDF5 file. Must be accepted by ``looks_like_hdf5``.
    data_path : str, optional
        Internal path to the dataset. Both ``"a/b/c"`` and dot notation
        ``"a.b.c"`` are accepted; dots are converted to ``/``, so member names
        that themselves contain dots cannot be addressed. If empty, or if the
        path names a group, the dataset picked by ``_first_dataset_in_h5_group``
        is returned. Default: "" (root level).

    Returns
    -------
    tuple
        ``(h5py.File, h5py.Dataset)``. The caller owns the open file handle and
        is responsible for closing it. On any error the handle is closed before
        the exception propagates.

    Raises
    ------
    ValueError
        If the file name does not look like HDF5, or ``data_path`` cannot be
        resolved (including a path that continues below a dataset).
    TypeError
        If ``data_path`` resolves to something that is neither a group nor a dataset.

    Examples
    --------
    >>> file, dataset = read_chunked_nested_h5("data.h5")
    >>> file, subgroup_data = read_chunked_nested_h5("experiment.hdf5", "images/channel1")
    """
    if not looks_like_hdf5(h5file):
        raise ValueError("Not implemented for other filetypes than H5")
    fid = h5py.File(h5file, "r")
    try:
        if not data_path:
            return fid, _first_dataset_in_h5_group(fid)

        # allow both "a.b.c" and "a/b/c"
        normalized = data_path.replace(".", "/").strip("/")
        obj: h5py.Group | h5py.Dataset = fid
        for part in normalized.split("/"):
            if not isinstance(obj, h5py.Group):
                # Testing membership on a Dataset would iterate (i.e. read) its
                # rows; a dataset has no members, so the path cannot continue.
                raise ValueError(
                    f"'{data_path}' not found in H5: {h5file}. "
                    f"'{part}' is requested below a dataset, which has no members."
                )
            if part not in obj:
                raise ValueError(f"'{data_path}' not found in H5: {h5file}. Available keys at this level: {list(obj.keys())}")
            obj = obj[part]

        if isinstance(obj, h5py.Group):
            return fid, _first_dataset_in_h5_group(obj)
        if isinstance(obj, h5py.Dataset):
            return fid, obj
        raise TypeError(f"Unexpected HDF5 object type at '{data_path}': {type(obj)}")
    except Exception:
        fid.close()
        raise
def read_chunked_data(
    filename: str,
) -> Tuple[ZarrOrH5File, ZarrOrH5Array]:
    """Open an HDF5 or Zarr file and return it together with its first array/dataset.

    Parameters
    ----------
    filename : str
        Path to the input file. Supported formats:

        - HDF5: any name accepted by ``looks_like_hdf5`` (e.g. .h5, .hdf5, .hdf)
        - Zarr: .zarr or .n5 directories (a trailing path separator is allowed)

    Returns
    -------
    tuple
        ``(h5py.File, h5py.Dataset)`` for HDF5 files, where the dataset is the
        one chosen by ``_first_dataset_in_h5_group``. For Zarr, ``(root, array)``:
        if the opened root is a group, the array chosen by
        ``_first_array_in_group``; if the root is itself an array, it is returned
        in both positions.

    Raises
    ------
    ValueError
        If ``filename`` is not a str, does not exist, has an unrecognized
        extension, cannot be opened as Zarr, or an HDF5 file holds no dataset.
        The original error is chained as ``__cause__`` where one exists.

    Examples
    --------
    >>> file_handler, dataset = read_chunked_data("data.h5")
    >>> zarr_group, zarr_array = read_chunked_data("data.zarr")
    """
    if not isinstance(filename, str):
        raise ValueError("'filename' is expected to be a str")
    if not os.path.exists(filename):
        raise ValueError(f"File {filename} does not exist.")

    if looks_like_hdf5(filename):
        fid = h5py.File(filename, "r")
        try:
            data = _first_dataset_in_h5_group(fid)
        except Exception as err:
            # Do not leak the handle when the file holds nothing usable.
            fid.close()
            raise ValueError(f"No datasets found in HDF5 file {filename}.") from err
        return fid, data

    # Zarr/N5 containers are directories, so tolerate "data.zarr/" as well.
    if filename.rstrip("/\\").endswith((".zarr", ".n5")):
        try:
            fid = zarr.open(filename, mode="r")
        except Exception as err:
            raise ValueError(f"Error opening Zarr file {filename}: {err}") from err
        data = _first_array_in_group(fid) if isinstance(fid, zarr.Group) else fid
        return fid, data

    raise ValueError(f"File extension {filename} not recognized")
def looks_like_hdf5(path: str) -> bool:
    """
    Decide from the file name alone whether ``path`` refers to an HDF5 file.

    The check is case-insensitive and recognises ``.h5``, ``.hdf5``, ``.hdf``
    and ``.he5``, also when followed by a single compression suffix
    (``.gz``, ``.bz2``, ``.xz`` or ``.zip``), e.g. ``volume.h5.gz``.

    Parameters
    ----------
    path : str
        The file path to check.

    Returns
    -------
    bool
        True if the name carries an HDF5 extension, False otherwise.
    """
    hdf5_suffixes = (".h5", ".hdf5", ".hdf", ".he5")
    compression_suffixes = (".gz", ".bz2", ".xz", ".zip")
    lowered = path.lower()
    stem, last_suffix = os.path.splitext(lowered)
    return lowered.endswith(hdf5_suffixes) or (
        last_suffix in compression_suffixes and stem.endswith(hdf5_suffixes)
    )
def pick_chunks(shape: Tuple[int, ...], dtype: str, target_mb: float = 4.0) -> Tuple[int, ...]:
    """
    Choose an HDF5 chunk shape for a dataset of ``shape`` and ``dtype``.

    Every axis starts at ``min(size, 256)``. For 4D+ shapes, the last axis
    (assumed to be channels) starts at ``min(size, 16)``. While a chunk is
    larger than ``target_mb``, the currently-largest axis is halved (the
    lowest index wins ties). The channel axis is only halved when it is the
    last axis that can still shrink.

    Parameters
    ----------
    shape : tuple of int
        Shape of the dataset.
    dtype : str
        Data type of the dataset (anything ``np.dtype`` accepts).
    target_mb : float, optional
        Target chunk size in megabytes. Default is 4.0 MB.

    Returns
    -------
    tuple of int
        Chunk sizes for each dimension of the dataset.
    """
    bytes_per_item = np.dtype(dtype).itemsize
    budget = int(target_mb * 1024 * 1024)
    has_channels = len(shape) >= 4

    chunk = [min(int(size), 256) for size in shape]
    if has_channels:
        chunk[-1] = min(int(shape[-1]), 16)
    channel_axis = len(chunk) - 1

    def footprint() -> int:
        # Zero-length axes count as 1 so the estimate never collapses to 0.
        return math.prod(max(1, int(size)) for size in chunk) * bytes_per_item

    while footprint() > budget:
        shrinkable = [axis for axis, size in enumerate(chunk) if size > 1]
        if not shrinkable:
            break
        if has_channels and channel_axis in shrinkable and len(shrinkable) > 1:
            shrinkable.remove(channel_axis)
        widest = max(shrinkable, key=chunk.__getitem__)
        chunk[widest] = max(1, chunk[widest] // 2)

    return tuple(map(int, chunk))
def load_synapse_gt_points(
    locations_path: str, resolution_path: str, partners_path: str, id_path: str, data_filename: str
) -> dict[str, list]:
    """
    Load synapse ground truth points from the given paths.

    Each partner pair ``(pre_id, post_id)`` is mapped to voxel coordinates via
    the synapse ids and locations; coordinates are floor-divided by the
    resolution. Pre-, post- and cleft points are de-duplicated by value,
    keeping first-seen order.

    Parameters
    ----------
    locations_path : str
        Path to the synapse locations within the data file.
    resolution_path : str
        Path to the dataset whose ``resolution`` attribute holds the resolution.
    partners_path : str
        Path to the synapse partners (pairs of ids) within the data file.
    id_path : str
        Path to the synapse ids within the data file.
    data_filename : str
        Path to the data file.

    Returns
    -------
    dict
        ``{"pre": ..., "post": ..., "cleft": ..., "resolution": ...}`` where
        ``pre``/``post`` are lists of unique pre-/post-synaptic coordinates,
        ``cleft`` is the list of unique midpoints of each *connected* pre/post
        pair, and ``resolution`` is the resolution as a list.

    Raises
    ------
    ValueError
        If the resolution dataset has no ``resolution`` attribute, or a partner
        id does not appear in the ids dataset.
    """
    opened_files = []

    def _read(internal_path: str):
        # Every lookup opens its own handle; remember it so it can be closed.
        fid, dataset = read_chunked_nested_data(data_filename, internal_path)
        opened_files.append(fid)
        return dataset

    try:
        ids = np.array(_read(id_path))
        partners = np.array(_read(partners_path))
        locations = np.array(_read(locations_path))
        resolution_ds = _read(resolution_path)
        try:
            resolution = resolution_ds.attrs["resolution"]
        except (KeyError, AttributeError) as err:
            raise ValueError(
                "There is no 'resolution' attribute in '{}'. Add it like: data['{}'].attrs['resolution'] = (8,8,8)".format(
                    resolution_path, resolution_path
                )
            ) from err
        resolution = list(resolution)
    finally:
        # All data is materialized above, so the HDF5 handles can be released.
        for fid in opened_files:
            if isinstance(fid, h5py.File):
                fid.close()

    # First occurrence wins, mirroring the previous list.index() lookup.
    first_position = {}
    for position, synapse_id in enumerate(ids):
        first_position.setdefault(synapse_id, position)

    def _position(synapse_id):
        try:
            return first_position[synapse_id]
        except KeyError:
            raise ValueError(
                f"Synapse id {synapse_id} (from '{partners_path}') not found in '{id_path}'."
            ) from None

    pre_points, post_points, cleft_points = {}, {}, {}
    for pre_id, post_id in tqdm(partners, disable=not is_main_process()):
        pre_coord = locations[_position(pre_id)] // resolution
        post_coord = locations[_position(post_id)] // resolution
        pre_points.setdefault(str(pre_coord), pre_coord)
        post_points.setdefault(str(post_coord), post_coord)
        # The cleft lies between *this* pre/post pair. Pairing the de-duplicated
        # pre and post lists instead would match unrelated points whenever one
        # pre-synaptic site has several partners.
        cleft_coord = (pre_coord + post_coord) / 2
        cleft_points.setdefault(str(cleft_coord), cleft_coord)

    return {
        "pre": list(pre_points.values()),
        "post": list(post_points.values()),
        "cleft": list(cleft_points.values()),
        "resolution": resolution,
    }