"""
Module for 3D data manipulation utilities.
This module provides functions to process and manipulate 3D data volumes, including:
- Cropping/merging with overlap
- Padding and resizing
- Efficient loading of large 3D files
"""
from __future__ import annotations
import os
import math
from typing import Any, List, Optional, Sequence, Tuple, Union
import scipy.signal
import h5py
import zarr
try:
import z5py
_Z5PY_AVAILABLE = True
except ImportError:
z5py = None # type: ignore[assignment]
_Z5PY_AVAILABLE = False
import numpy as np
from numpy.typing import NDArray
from tqdm import tqdm
from biapy.utils.misc import is_main_process
from biapy.data.dataset import PatchCoords
# Union of the container types used for chunked files in this module:
# a Zarr group, a Zarr array, or an h5py file handle.
ZarrOrH5File = Union[zarr.Group, zarr.Array, h5py.File]
# Union of the array-like dataset types: a Zarr array or an h5py dataset.
ZarrOrH5Array = Union[zarr.Array, h5py.Dataset]
[docs]
def load_3D_efficient_files(
    data_path: List[str],
    input_axes: str,
    crop_shape: Tuple[int, ...],
    overlap: Tuple[float, ...],
    padding: Tuple[int, ...],
    check_channel: bool = True,
    data_within_zarr_path: Optional[str] = None,
):
    """
    Efficiently index 3D patches from Zarr or HDF5 image volumes for training or inference.

    This function computes and returns metadata about all the 3D patches that can be
    extracted from a list of multidimensional microscopy volumes stored in Zarr or HDF5
    files. Patch positions follow the overlap/padding strategy of
    :func:`extract_3D_patch_with_overlap_and_padding_yield`. Volumes are read patch by
    patch (never as a whole), so large datasets can be indexed without loading them
    fully into memory.

    Parameters
    ----------
    data_path : list of str
        List of paths to Zarr or HDF5 files containing the raw 3D image volumes.

    input_axes : str
        Axes layout of the image data in the files. Must be one of ['TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'].

    crop_shape : tuple of int
        Shape of the 3D patches to be extracted, in the form (z, y, x, channels). When
        ``input_axes`` contains a ``C`` axis, the channel entry is replaced by the number of
        channels found in each file.

    overlap : tuple of float
        Minimum fractional overlap between neighboring patches in the z, y, and x dimensions.
        Values must be in the range [0.0, 1.0).

    padding : tuple of int
        Number of voxels to pad along each spatial axis (z, y, x) when patching.

    check_channel : bool, optional
        If True, verify that the channel dimension in the `crop_shape` matches the actual
        number of channels of every extracted patch. Default is True.

    data_within_zarr_path : str, optional
        Optional internal path to the dataset inside the Zarr or HDF5 file, e.g.,
        'volumes/raw' or 'volumes/labels/neuron_ids'. If None, the top-level dataset is used.

    Returns
    -------
    data_info : dict
        Dictionary mapping a global patch index (counted across all files) to patch metadata,
        with the following keys:
            - "filepath": path to the source file.
            - "full_shape": shape of the complete data volume.
            - "patch_coords": coordinates (start and end) of the extracted patch.

    data_total_patches : list of int
        List with the number of patches extracted from each file in `data_path`.

    Raises
    ------
    ValueError
        If ``crop_shape`` does not have 4 entries, or if ``check_channel`` is True and the
        channel dimension of a loaded patch does not match ``crop_shape``.
    """
    # Validate with an exception (not an assert) so the check survives ``python -O``
    if len(crop_shape) != 4:
        raise ValueError(f"Provided crop_shape is not a 4D tuple: {crop_shape}")

    data_info = {}
    data_total_patches = []
    patch_id = 0  # global patch counter shared by all files
    for filename in data_path:
        print(f"Reading Zarr/H5 file: {filename}")
        if data_within_zarr_path:
            file, data = read_chunked_nested_data(filename, data_within_zarr_path)
        else:
            file, data = read_chunked_data(filename)

        try:
            # Make the channel entry of crop_shape match the channels stored in this file
            if "C" in input_axes:
                crop_shape = tuple(crop_shape[:-1]) + (data.shape[input_axes.index("C")],)

            # With return_only_stats=True the generator yields exactly one
            # (total_patches, z_vol_info, list_of_vols_in_z) tuple; it is used to size tqdm
            stats_gen = extract_3D_patch_with_overlap_and_padding_yield(
                data,
                crop_shape,
                input_axes,
                overlap=overlap,
                padding=padding,
                total_ranks=1,
                rank=0,
                return_only_stats=True,
                verbose=True,
            )
            total_patches, _, _ = next(stats_gen)  # type: ignore

            for obj in tqdm(
                extract_3D_patch_with_overlap_and_padding_yield(
                    data,
                    crop_shape,
                    input_axes,
                    overlap=overlap,
                    padding=padding,
                    total_ranks=1,
                    rank=0,
                    verbose=False,
                ),
                total=total_patches,  # type: ignore
                disable=not is_main_process(),
            ):  # type: ignore
                img, patch_coords, _, _, _ = obj  # type: ignore
                data_info[patch_id] = {
                    "filepath": filename,
                    "full_shape": data.shape,
                    "patch_coords": patch_coords,
                }
                patch_id += 1

                assert isinstance(img, np.ndarray)
                if check_channel and crop_shape[-1] != img.shape[-1]:
                    raise ValueError(
                        "Channel of the patch size given {} does not correspond with the loaded image {}. "
                        "Please, check the channels of the images!".format(crop_shape[-1], img.shape[-1])
                    )
        finally:
            # Release the h5py handle even if patch extraction fails
            if isinstance(file, h5py.File):
                file.close()

        data_total_patches.append(total_patches)

    return data_info, data_total_patches
[docs]
def load_img_part_from_efficient_file(
    filepath: str, patch_coords: PatchCoords, data_axes_order: str = "ZYXC", data_path: Optional[str] = None
):
    """
    Load from ``filepath`` the patch determined by ``patch_coords``.

    Parameters
    ----------
    filepath : str
        Path to the Zarr/H5 file to read the patch from.

    patch_coords : PatchCoords
        Coordinates of the crop.

    data_axes_order : str
        Order of axes of the stored data. E.g. 'TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'.

    data_path : str, optional
        Path to find the data within the Zarr file. E.g. 'volumes.labels.neuron_ids'.

    Returns
    -------
    img : Numpy array
        Extracted patch. E.g. ``(z, y, x, channels)``.
    """
    if data_path:
        imgfile, img = read_chunked_nested_data(filepath, data_path)
    else:
        imgfile, img = read_chunked_data(filepath)

    try:
        return extract_patch_from_efficient_file(img, patch_coords, data_axes_order=data_axes_order)
    finally:
        # Only h5py file handles are closed here; make sure that happens even if extraction raises
        if isinstance(imgfile, h5py.File):
            imgfile.close()
[docs]
def insert_patch_in_efficient_file(
    data: zarr.Array | h5py.Dataset,
    patch: NDArray,
    patch_coords: PatchCoords,
    data_axes_order: str = "ZYXC",
    patch_axes_order: str = "ZYXC",
    mode="replace",
):
    """
    Insert ``patch`` in ``data`` at ``patch_coords``.

    Parameters
    ----------
    data : Zarr/H5 data
        Data to insert the patch into. Modified in place.

    patch : NDArray
        Patch to insert into ``data``.

    patch_coords : PatchCoords
        Coordinates of the patch (Z, Y and X start/end).

    data_axes_order : str, optional
        Order of axes of ``data``. E.g. 'TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'. Axes of ``data`` that are
        not in 'ZYXC' (e.g. 'T') are indexed at position 0.

    patch_axes_order : str, optional
        Order of axes of ``patch``. E.g. 'TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'.

    mode : str, optional
        What to do with the patch data when inserting it. Options: ["add", "replace"].
        ``"replace"`` overwrites the region, ``"add"`` accumulates into it.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported options.
    """
    # Raise instead of assert: under ``python -O`` an invalid mode would otherwise fall into "add"
    if mode not in ("add", "replace"):
        raise ValueError(f"'mode' must be one of ['add', 'replace'], given '{mode}'")

    # Region in ZYXC order. The channel slice is open, so any number of channels can be inserted
    slices = (
        slice(patch_coords.z_start, patch_coords.z_end),
        slice(patch_coords.y_start, patch_coords.y_end),
        slice(patch_coords.x_start, patch_coords.x_end),
        slice(None),
    )
    # Reorder the region to the data layout; axes absent from ZYXC (e.g. T) are indexed with 0
    data_ordered_slices = tuple(
        order_dimensions(
            slices,
            input_order="ZYXC",
            output_order=data_axes_order,
            default_value=0,
        )
    )

    # For every axis of the data layout, pick the patch axis holding it; NaN marks axes the
    # patch does not have (e.g. T), which are dropped before transposing
    current_order = np.array(range(len(patch.shape)))
    transpose_order = order_dimensions(
        current_order,
        input_order=patch_axes_order,
        output_order=data_axes_order,
        default_value=np.nan,
    )
    transpose_order = [x for x in transpose_order if not np.isnan(x)]  # type: ignore
    patch_in_data_order = patch.transpose(transpose_order)  # type: ignore

    # Insert the patch into the corresponding position
    if mode == "replace":
        data[data_ordered_slices] = patch_in_data_order  # type: ignore
    else:  # add
        data[data_ordered_slices] += patch_in_data_order  # type: ignore
[docs]
def crop_3D_data_with_overlap(
    data: NDArray,
    vol_shape: Tuple[int, ...],
    data_mask: Optional[NDArray] = None,
    overlap: Tuple[float, ...] = (0, 0, 0),
    padding: Tuple[int, ...] = (0, 0, 0),
    verbose: bool = True,
    median_padding: bool = False,
    load_data: bool = True,
) -> Union[Tuple[NDArray, NDArray, List['PatchCoords']], Tuple[NDArray, List['PatchCoords']], List['PatchCoords']]:
    """
    Crop 3D data into smaller volumes with a defined overlap. The opposite function is :func:`~merge_3D_data_with_overlap`.

    Parameters
    ----------
    data : 4D Numpy array
        Data to crop. E.g. ``(z, y, x, channels)``.

    vol_shape : 4D int tuple
        Shape of the volumes to create. E.g. ``(z, y, x, channels)``.

    data_mask : 4D Numpy array, optional
        Data mask to crop. E.g. ``(z, y, x, channels)``.

    overlap : Tuple of 3 floats, optional
        Amount of minimum overlap on z, y and x dimensions. The values must be on range ``[0, 1)``, that is, ``0%``
        or ``99%`` of overlap. E.g. ``(z, y, x)``.

    padding : tuple of ints, optional
        Size of padding to be added on each axis ``(z, y, x)``. E.g. ``(24, 24, 24)``. The padding is filled by
        reflection.

    verbose : bool, optional
        To print information about the crop to be made.

    median_padding : bool, optional
        If ``True``, the padded border of ``data`` (not of ``data_mask``) is overwritten with the median of the
        corresponding outermost slice. If ``False``, the reflected values are kept.

    load_data : bool, optional
        Whether to create the patches or not. It saves memory in case you only need the coordinates of the
        cropped patches.

    Returns
    -------
    cropped_data : 5D Numpy array, optional
        Cropped image data. E.g. ``(vol_number, z, y, x, channels)``. Returned if ``load_data`` is ``True``.

    cropped_data_mask : 5D Numpy array, optional
        Cropped image data masks. E.g. ``(vol_number, z, y, x, channels)``. Returned if ``load_data`` is ``True``
        and ``data_mask`` is provided.

    crop_coords : list of PatchCoords
        Coordinates of each crop, in the padded data space, with the following fields:
            * ``z_start``: starting point of the patch in Z axis.
            * ``z_end``: end point of the patch in Z axis.
            * ``y_start``: starting point of the patch in Y axis.
            * ``y_end``: end point of the patch in Y axis.
            * ``x_start``: starting point of the patch in X axis.
            * ``x_end``: end point of the patch in X axis.

    Raises
    ------
    ValueError
        If the shapes are not 4D, do not match, the padding is too large, the patch is bigger than the data or
        the overlap is out of range.

    Examples
    --------
    ::

        # EXAMPLE 1
        # Following the example introduced in load_and_prepare_3D_data function, the cropping of a volume with shape
        # (165, 768, 1024) should be done by the following call:
        X_train = np.ones((165, 768, 1024, 1))
        Y_train = np.ones((165, 768, 1024, 1))
        X_train, Y_train, coords = crop_3D_data_with_overlap(X_train, (80, 80, 80, 1), data_mask=Y_train,
                                                             overlap=(0.5,0.5,0.5))
        # The function will print the shape of the generated arrays. In this example:
        #     **** New data shape is: (2600, 80, 80, 80, 1)

    A visual explanation of the process:

    .. image:: ../../img/crop_3D_ov.png
        :width: 80%
        :align: center

    Note: this image does not respect the proportions.

    ::

        # EXAMPLE 2
        # Same data crop but without overlap
        X_train, Y_train, coords = crop_3D_data_with_overlap(X_train, (80, 80, 80, 1), data_mask=Y_train,
                                                             overlap=(0,0,0))
        # The function will print the shape of the generated arrays. In this example:
        #     **** New data shape is: (390, 80, 80, 80, 1)
        #
        # Notice how differs the amount of subvolumes created compared to the first example

        # EXAMPLE 3
        # In the same way, if the addition of (64,64,64) padding is required, the call should be done as shown:
        X_train, Y_train, coords = crop_3D_data_with_overlap(
            X_train, (80, 80, 80, 1), data_mask=Y_train, overlap=(0.5,0.5,0.5), padding=(64,64,64))
    """
    if verbose:
        print("### 3D-OV-CROP ###")
        print("Cropping {} images into {} with overlapping . . .".format(data.shape, vol_shape))
        print("Minimum overlap selected: {}".format(overlap))
        print("Padding: {}".format(padding))

    if data.ndim != 4:
        raise ValueError("data expected to be 4 dimensional, given {}".format(data.shape))
    if data_mask is not None:
        if data_mask.ndim != 4:
            raise ValueError("data_mask expected to be 4 dimensional, given {}".format(data_mask.shape))
        if data.shape[:-1] != data_mask.shape[:-1]:
            raise ValueError(
                "data and data_mask shapes mismatch: {} vs {}".format(data.shape[:-1], data_mask.shape[:-1])
            )
    if len(vol_shape) != 4:
        raise ValueError("vol_shape expected to be of length 4, given {}".format(vol_shape))
    for i, p in enumerate(padding):
        if p >= vol_shape[i] // 2:
            raise ValueError(
                "'Padding' can not be greater than half of 'vol_shape'. Max value for the given input shape {} is {}".format(
                    vol_shape, ((vol_shape[0] // 2) - 1, (vol_shape[1] // 2) - 1, (vol_shape[2] // 2) - 1)
                )
            )
    for i in range(3):
        if vol_shape[i] > data.shape[i]:
            raise ValueError(
                "'vol_shape[{}]' {} greater than {} (you can reduce 'DATA.PATCH_SIZE' or use 'DATA.REFLECT_TO_COMPLETE_SHAPE')".format(
                    i, vol_shape[i], data.shape[i]
                )
            )
    if (
        (overlap[0] >= 1 or overlap[0] < 0)
        or (overlap[1] >= 1 or overlap[1] < 0)
        or (overlap[2] >= 1 or overlap[2] < 0)
    ):
        raise ValueError("'overlap' values must be floats between range [0, 1)")

    pad_width = (
        (padding[0], padding[0]),
        (padding[1], padding[1]),
        (padding[2], padding[2]),
        (0, 0),
    )
    padded_data = np.pad(data, pad_width, "reflect")
    if data_mask is not None:
        padded_data_mask = np.pad(data_mask, pad_width, "reflect")

    if median_padding:
        # End (exclusive) of the original data inside the padded array, per axis
        z_end = padding[0] + data.shape[0]
        y_end = padding[1] + data.shape[1]
        x_end = padding[2] + data.shape[2]
        padded_data[0 : padding[0], :, :, :] = np.median(data[0, :, :, :])
        padded_data[z_end : 2 * padding[0] + data.shape[0], :, :, :] = np.median(data[-1, :, :, :])
        padded_data[:, 0 : padding[1], :, :] = np.median(data[:, 0, :, :])
        # The right Y border must be bounded by the Y size (data.shape[1]), not the Z size
        padded_data[:, y_end : 2 * padding[1] + data.shape[1], :, :] = np.median(data[:, -1, :, :])
        padded_data[:, :, 0 : padding[2], :] = np.median(data[:, :, 0, :])
        padded_data[:, :, x_end : 2 * padding[2] + data.shape[2], :] = np.median(data[:, :, -1, :])

    # Accept any sequence for vol_shape; it is concatenated with tuples below
    padded_vol_shape = tuple(vol_shape)

    # Calculate overlapping variables
    overlap_z = 1 if overlap[0] == 0 else 1 - overlap[0]
    overlap_y = 1 if overlap[1] == 0 else 1 - overlap[1]
    overlap_x = 1 if overlap[2] == 0 else 1 - overlap[2]

    # For each axis: nominal step, number of patches, and how far the last patch overshoots the padded
    # data. The overshoot is spread evenly over the steps; the remainder (last_*) shifts the last patch back.
    # Z
    step_z = int((vol_shape[0] - padding[0] * 2) * overlap_z)
    vols_per_z = math.ceil(data.shape[0] / step_z)
    last_z = 0 if vols_per_z == 1 else (((vols_per_z - 1) * step_z) + vol_shape[0]) - padded_data.shape[0]
    ovz_per_block = last_z // (vols_per_z - 1) if vols_per_z > 1 else 0
    step_z -= ovz_per_block
    last_z -= ovz_per_block * (vols_per_z - 1)

    # Y
    step_y = int((vol_shape[1] - padding[1] * 2) * overlap_y)
    vols_per_y = math.ceil(data.shape[1] / step_y)
    last_y = 0 if vols_per_y == 1 else (((vols_per_y - 1) * step_y) + vol_shape[1]) - padded_data.shape[1]
    ovy_per_block = last_y // (vols_per_y - 1) if vols_per_y > 1 else 0
    step_y -= ovy_per_block
    last_y -= ovy_per_block * (vols_per_y - 1)

    # X
    step_x = int((vol_shape[2] - padding[2] * 2) * overlap_x)
    vols_per_x = math.ceil(data.shape[2] / step_x)
    last_x = 0 if vols_per_x == 1 else (((vols_per_x - 1) * step_x) + vol_shape[2]) - padded_data.shape[2]
    ovx_per_block = last_x // (vols_per_x - 1) if vols_per_x > 1 else 0
    step_x -= ovx_per_block
    last_x -= ovx_per_block * (vols_per_x - 1)

    # Real overlap calculation for printing
    real_ov_z = ovz_per_block / (vol_shape[0] - padding[0] * 2)
    real_ov_y = ovy_per_block / (vol_shape[1] - padding[1] * 2)
    real_ov_x = ovx_per_block / (vol_shape[2] - padding[2] * 2)
    if verbose:
        print("Real overlapping (%): {}".format((real_ov_z, real_ov_y, real_ov_x)))
        print(
            "Real overlapping (pixels): {}".format(
                (
                    (vol_shape[0] - padding[0] * 2) * real_ov_z,
                    (vol_shape[1] - padding[1] * 2) * real_ov_y,
                    (vol_shape[2] - padding[2] * 2) * real_ov_x,
                )
            )
        )
        print("{} patches per (z,y,x) axis".format((vols_per_z, vols_per_y, vols_per_x)))

    total_vol = vols_per_z * vols_per_y * vols_per_x
    if load_data:
        cropped_data = np.zeros((total_vol,) + padded_vol_shape, dtype=data.dtype)
        if data_mask is not None:
            cropped_data_mask = np.zeros(
                (total_vol,) + padded_vol_shape[:3] + (data_mask.shape[-1],),
                dtype=data_mask.dtype,
            )

    c = 0
    crop_coords = []
    for z in range(vols_per_z):
        for y in range(vols_per_y):
            for x in range(vols_per_x):
                # Patches that would exceed the padded data are shifted back by the remaining overshoot
                d_z = 0 if (z * step_z + vol_shape[0]) < padded_data.shape[0] else last_z
                d_y = 0 if (y * step_y + vol_shape[1]) < padded_data.shape[1] else last_y
                d_x = 0 if (x * step_x + vol_shape[2]) < padded_data.shape[2] else last_x

                z0 = z * step_z - d_z
                y0 = y * step_y - d_y
                x0 = x * step_x - d_x
                window = (
                    slice(z0, z0 + vol_shape[0]),
                    slice(y0, y0 + vol_shape[1]),
                    slice(x0, x0 + vol_shape[2]),
                )

                if load_data:
                    cropped_data[c] = padded_data[window]
                    if data_mask is not None:
                        cropped_data_mask[c] = padded_data_mask[window]

                crop_coords.append(
                    PatchCoords(
                        z_start=z0,
                        z_end=z0 + vol_shape[0],
                        y_start=y0,
                        y_end=y0 + vol_shape[1],
                        x_start=x0,
                        x_end=x0 + vol_shape[2],
                    )
                )
                c += 1

    if verbose:
        if load_data:
            print("**** New data shape is: {}".format(cropped_data.shape))
        print("### END 3D-OV-CROP ###")

    if load_data:
        if data_mask is not None:
            return cropped_data, cropped_data_mask, crop_coords
        else:
            return cropped_data, crop_coords
    else:
        return crop_coords
def _get_spline_window_3D(crop_shape: Tuple[int, ...], overlap_pixels: Tuple[int, int, int], power: int = 2) -> NDArray:
"""
Generate a 3D squared spline window for smooth blending. The window is designed to have values close
to 1 in the center of the patch and smoothly taper to 0 towards the edges, with the transition controlled
by the `power` parameter.
Parameters
----------
crop_shape : tuple of int
Shape of the 3D patch for which to generate the window, in the form (z, y, x, channels).
overlap_pixels: tuple of int
The exact number of overlapping pixels in the (z, y, x) dimensions to apply the taper across.
power : int, optional
Power to control the steepness of the window. Higher values will create a sharper transition from 1 to 0.
Default is 2, which creates a smooth quadratic tapering.
Returns
-------
window : 4D Numpy array
A 3D window of shape (z, y, x, 1) that can be applied to the patch for smooth blending. The values are normalized
so that the average is 1, ensuring that the overall intensity of the patch is preserved when blended with others.
"""
def _spline_window_1D(size, ov_pixels, power=2):
wind = np.ones(size, dtype=np.float32)
if ov_pixels > 0:
ov_pixels = min(ov_pixels, size // 2)
x = np.linspace(0, 1, ov_pixels + 2)[1:-1]
taper = (x ** power) / (x ** power + (1 - x) ** power + 1e-8)
wind[:ov_pixels] = taper
wind[-ov_pixels:] = taper[::-1]
return wind
# Generate 1D splines for Z, Y, and X dimensions
wind_z = _spline_window_1D(crop_shape[0], overlap_pixels[0], power)
wind_y = _spline_window_1D(crop_shape[1], overlap_pixels[1], power)
wind_x = _spline_window_1D(crop_shape[2], overlap_pixels[2], power)
# Expand dims to perform outer product for 3D window
wind_z = wind_z[:, None, None]
wind_y = wind_y[None, :, None]
wind_x = wind_x[None, None, :]
# Broadcast multiply to get shape (z, y, x)
wind_3d = wind_z * wind_y * wind_x
# Expand to match the channel dimension of the data: (z, y, x, 1)
wind_3d = np.expand_dims(wind_3d, -1)
return wind_3d.astype(np.float32)
[docs]
def merge_3D_data_with_overlap(
    data: NDArray,
    orig_vol_shape: Tuple,
    data_mask: Optional[NDArray] = None,
    overlap: Tuple[float, ...] = (0, 0, 0),
    padding: Tuple[int, ...] = (0, 0, 0),
    verbose: bool = True,
) -> Union[NDArray, Tuple[NDArray, Optional[NDArray]]]:
    """
    Merge 3D subvolumes in a 3D volume with a defined overlap.

    Overlapping regions are blended with a 3D spline window (see ``_get_spline_window_3D``) and the result is
    normalized by the accumulated window weights. The opposite function is :func:`~crop_3D_data_with_overlap`.

    Parameters
    ----------
    data : 5D Numpy array
        Patches to merge. E.g. ``(volume_number, z, y, x, channels)``.

    orig_vol_shape : 4D int tuple
        Shape of the volume to create. E.g. ``(z, y, x, channels)``.

    data_mask : 5D Numpy array, optional
        Mask patches to merge. E.g. ``(volume_number, z, y, x, channels)``.

    overlap : Tuple of 3 floats, optional
        Amount of minimum overlap on z, y and x dimensions. Should be the same as used in
        :func:`~crop_3D_data_with_overlap`. The values must be on range ``[0, 1)``, that is, ``0%`` or ``99%`` of
        overlap. E.g. ``(z, y, x)``.

    padding : tuple of ints, optional
        Size of padding that was added on each axis ``(z, y, x)``; it is removed from every patch before
        merging. E.g. ``(24, 24, 24)``.

    verbose : bool, optional
        To print information about the merge to be made.

    Returns
    -------
    merged_data : 4D Numpy array
        Merged image data with the dtype of ``data``. E.g. ``(z, y, x, channels)``. Integer dtypes are rounded
        to the nearest integer after blending.

    merged_data_mask : 4D Numpy array, optional
        Merged mask data with the dtype of ``data_mask``. E.g. ``(z, y, x, channels)``. Only returned when
        ``data_mask`` is provided.

    Raises
    ------
    ValueError
        If the input dimensions are wrong, ``data`` and ``data_mask`` do not match or the overlap is out of range.
    """
    # Raise (not assert) so validation is kept under ``python -O``
    if data.ndim != 5:
        raise ValueError(f"data expected to be 5 dimensional, given {data.shape}")
    if len(orig_vol_shape) != 4:
        raise ValueError(f"orig_vol_shape expected to be 4 dimensional, given {orig_vol_shape}")
    if data_mask is not None:
        if data.shape[:-1] != data_mask.shape[:-1]:
            raise ValueError(
                "data and data_mask shapes mismatch: {} vs {}".format(data.shape[:-1], data_mask.shape[:-1])
            )
    if (
        (overlap[0] >= 1 or overlap[0] < 0)
        or (overlap[1] >= 1 or overlap[1] < 0)
        or (overlap[2] >= 1 or overlap[2] < 0)
    ):
        raise ValueError("'overlap' values must be floats between range [0, 1)")

    if verbose:
        print("### MERGE-3D-OV-CROP ###")
        print("Merging {} images into {} with smooth blending . . .".format(data.shape, orig_vol_shape))
        print("Minimum overlap selected: {}".format(overlap))
        print("Padding: {}".format(padding))

    orig_vol_shape = tuple(orig_vol_shape)

    # Remove the padding (pad_input_shape keeps the padded patch shape for the grid computation)
    pad_input_shape = data.shape
    data = data[
        :,
        padding[0] : data.shape[1] - padding[0],
        padding[1] : data.shape[2] - padding[1],
        padding[2] : data.shape[3] - padding[2],
        :,
    ]
    merged_data = np.zeros(orig_vol_shape, dtype=np.float32)
    if data_mask is not None:
        data_mask = data_mask[
            :,
            padding[0] : data_mask.shape[1] - padding[0],
            padding[1] : data_mask.shape[2] - padding[1],
            padding[2] : data_mask.shape[3] - padding[2],
            :,
        ]
        merged_data_mask = np.zeros(orig_vol_shape[:3] + (data_mask.shape[-1],), dtype=np.float32)

    # Using float32 for the weight map to accurately accumulate spline weights
    weight_map_counter = np.zeros(orig_vol_shape[:-1] + (1,), dtype=np.float32)

    # Calculate overlapping variables (same grid as crop_3D_data_with_overlap)
    overlap_z = 1 if overlap[0] == 0 else 1 - overlap[0]
    overlap_y = 1 if overlap[1] == 0 else 1 - overlap[1]
    overlap_x = 1 if overlap[2] == 0 else 1 - overlap[2]
    padded_vol_shape = [
        orig_vol_shape[0] + 2 * padding[0],
        orig_vol_shape[1] + 2 * padding[1],
        orig_vol_shape[2] + 2 * padding[2],
    ]

    # Z
    step_z = int((pad_input_shape[1] - padding[0] * 2) * overlap_z)
    vols_per_z = math.ceil(orig_vol_shape[0] / step_z)
    last_z = 0 if vols_per_z == 1 else (((vols_per_z - 1) * step_z) + pad_input_shape[1]) - padded_vol_shape[0]
    ovz_per_block = last_z // (vols_per_z - 1) if vols_per_z > 1 else 0
    step_z -= ovz_per_block
    last_z -= ovz_per_block * (vols_per_z - 1)

    # Y
    step_y = int((pad_input_shape[2] - padding[1] * 2) * overlap_y)
    vols_per_y = math.ceil(orig_vol_shape[1] / step_y)
    last_y = 0 if vols_per_y == 1 else (((vols_per_y - 1) * step_y) + pad_input_shape[2]) - padded_vol_shape[1]
    ovy_per_block = last_y // (vols_per_y - 1) if vols_per_y > 1 else 0
    step_y -= ovy_per_block
    last_y -= ovy_per_block * (vols_per_y - 1)

    # X
    step_x = int((pad_input_shape[3] - padding[2] * 2) * overlap_x)
    vols_per_x = math.ceil(orig_vol_shape[2] / step_x)
    last_x = 0 if vols_per_x == 1 else (((vols_per_x - 1) * step_x) + pad_input_shape[3]) - padded_vol_shape[2]
    ovx_per_block = last_x // (vols_per_x - 1) if vols_per_x > 1 else 0
    step_x -= ovx_per_block
    last_x -= ovx_per_block * (vols_per_x - 1)

    # Calculate exact overlap in pixels for the dynamic window
    overlap_pixels_z = (pad_input_shape[1] - padding[0] * 2) - step_z
    overlap_pixels_y = (pad_input_shape[2] - padding[1] * 2) - step_y
    overlap_pixels_x = (pad_input_shape[3] - padding[2] * 2) - step_x

    # Generate the smooth blending window for the 3D patches
    patch_shape = (data.shape[1], data.shape[2], data.shape[3])
    spline_window = _get_spline_window_3D(patch_shape, (overlap_pixels_z, overlap_pixels_y, overlap_pixels_x))

    c = 0
    for z in range(vols_per_z):
        for y in range(vols_per_y):
            for x in range(vols_per_x):
                d_z = 0 if (z * step_z + data.shape[1]) < orig_vol_shape[0] else last_z
                d_y = 0 if (y * step_y + data.shape[2]) < orig_vol_shape[1] else last_y
                d_x = 0 if (x * step_x + data.shape[3]) < orig_vol_shape[2] else last_x

                z_start = z * step_z - d_z
                z_end = z * step_z + data.shape[1] - d_z
                y_start = y * step_y - d_y
                y_end = y * step_y + data.shape[2] - d_y
                x_start = x * step_x - d_x
                x_end = x * step_x + data.shape[3] - d_x

                # Multiply patch by 3D spline window before adding
                merged_data[z_start:z_end, y_start:y_end, x_start:x_end] += data[c] * spline_window
                if data_mask is not None:
                    merged_data_mask[z_start:z_end, y_start:y_end, x_start:x_end] += data_mask[c] * spline_window

                # Accumulate the 3D weights
                weight_map_counter[z_start:z_end, y_start:y_end, x_start:x_end] += spline_window
                c += 1

    def _normalize(accumulated: NDArray, dtype) -> NDArray:
        # Divide by the accumulated weights; the epsilon prevents division by zero
        blended = np.true_divide(accumulated, weight_map_counter + 1e-18)
        if np.issubdtype(dtype, np.integer):
            # Round instead of truncating: float blending can yield e.g. 2.9999998 for a voxel whose value is 3
            blended = np.rint(blended)
        return blended.astype(dtype)

    merged_data = _normalize(merged_data, data.dtype)

    if verbose:
        print("**** New data shape is: {}".format(merged_data.shape))
        print("### END MERGE-3D-OV-CROP ###")

    if data_mask is not None:
        merged_data_mask = _normalize(merged_data_mask, data_mask.dtype)
        return merged_data, merged_data_mask
    else:
        return merged_data
[docs]
def extract_3D_patch_with_overlap_and_padding_yield(
    data: zarr.Array | h5py.Dataset,
    vol_shape: Tuple[int, ...],
    axes_order: str,
    overlap: Tuple[float, ...] = (0, 0, 0),
    padding: Tuple[int, ...] = (0, 0, 0),
    total_ranks: int = 1,
    rank: int = 0,
    return_only_stats: bool = False,
    load_data: bool = True,
    verbose: bool = False,
):
    """
    Extract 3D patches into smaller patches with a defined overlap.

    Supports multi-GPU inference by setting ``total_ranks`` and ``rank`` variables.
    Each GPU will process an even number of volumes in the ``Z`` axis. If the number
    of volumes is not divisible by the number of GPUs, the first GPUs will process
    one more volume. Patches touching the data borders are completed with reflect
    padding so every loaded patch has the spatial shape of ``vol_shape``.

    Parameters
    ----------
    data : Zarr array or H5 dataset
        Data to extract patches from. E.g. ``(z, y, x, channels)``.

    vol_shape : 4D int tuple
        Shape of the patches to create. E.g. ``(z, y, x, channels)``.

    axes_order : str
        Order of axes of ``data``. One of ['TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'].

    overlap : Tuple of 3 floats, optional
        Amount of minimum overlap on z, y and x dimensions. Should be the same as used
        in :func:`~crop_3D_data_with_overlap`. Values must be in range ``[0, 1)``,
        representing 0% to 99% overlap. E.g. ``(z, y, x)``.

    padding : tuple of ints, optional
        Size of padding to be added on each axis ``(z, y, x)``. E.g. ``(24, 24, 24)``.

    total_ranks : int, optional
        Total number of GPUs.

    rank : int, optional
        Rank of the current GPU.

    return_only_stats : bool, optional
        Whether to just yield crop statistics ``(total_vol, z_vol_info, list_of_vols_in_z)``
        once, without yielding patches. Useful for precalculating the number of patches.

    load_data: bool, optional
        Whether to load data from file. Speeds up process if only patch coordinates
        are needed.

    verbose : bool, optional
        Whether to print debugging information (only on rank 0).

    Yields
    ------
    img : 4D Numpy array, optional
        Extracted patch from ``data``. E.g. ``(z, y, x, channels)``. Only yielded if
        ``load_data`` is ``True``.

    real_patch_in_data : PatchCoords
        Coordinates where the patch (without padding) should be inserted in the original
        data. E.g. ``z_start=0, z_end=20, y_start=0, y_end=8, x_start=16, x_end=24`` means
        the patch belongs at position [0:20,0:8,16:24] in the original data.

    total_vol : int
        Total number of crops to extract by this rank.

    z_vol_info : dict, optional
        Mapping of volume positions in original data. E.g. ``{0: [0, 20], 1: [20, 40]}``
        means first volume goes at [0:20], second at [20:40]. Only yielded on rank 0.

    list_of_vols_in_z : list of list of int, optional
        Volumes assigned to each GPU. E.g. ``[[0, 1, 2], [3, 4]]`` means GPU 0 processes
        volumes 0-2, GPU 1 processes volumes 3-4. Only yielded on rank 0.
    """
    if verbose and rank == 0:
        print("### 3D-OV-CROP ###")
        print(
            "Cropping {} images into {} with overlapping (axis order: {}). . .".format(
                data.shape, vol_shape, axes_order
            )
        )
        print("Minimum overlap selected: {}".format(overlap))
        print("Padding: {}".format(padding))

    if len(vol_shape) != 4:
        raise ValueError("vol_shape expected to be of length 4, given {}".format(vol_shape))

    _, z_dim, c_dim, y_dim, x_dim = order_dimensions(data.shape, axes_order)
    assert isinstance(z_dim, int) and isinstance(x_dim, int) and isinstance(y_dim, int) and isinstance(c_dim, int)

    for i, dim in enumerate((z_dim, y_dim, x_dim)):
        if vol_shape[i] > dim:
            raise ValueError(
                "'vol_shape[{}]' {} greater than {} (you can reduce 'DATA.PATCH_SIZE')".format(i, vol_shape[i], dim)
            )
    if (
        (overlap[0] >= 1 or overlap[0] < 0)
        or (overlap[1] >= 1 or overlap[1] < 0)
        or (overlap[2] >= 1 or overlap[2] < 0)
    ):
        raise ValueError("'overlap' values must be floats between range [0, 1)")
    max_padding = ((vol_shape[0] // 2) - 1, (vol_shape[1] // 2) - 1, (vol_shape[2] // 2) - 1)
    for i in range(3):
        if padding[i] >= vol_shape[i] // 2:
            raise ValueError(
                "'Padding' can not be greater than half of 'vol_shape'. Max value for the given input shape {} is {}".format(
                    vol_shape, max_padding
                )
            )

    padded_data_shape = [
        z_dim + padding[0] * 2,
        y_dim + padding[1] * 2,
        x_dim + padding[2] * 2,
        c_dim,
    ]

    # Calculate overlapping variables (same grid as crop_3D_data_with_overlap)
    overlap_z = 1 if overlap[0] == 0 else 1 - overlap[0]
    overlap_y = 1 if overlap[1] == 0 else 1 - overlap[1]
    overlap_x = 1 if overlap[2] == 0 else 1 - overlap[2]

    # Z
    step_z = int((vol_shape[0] - padding[0] * 2) * overlap_z)
    vols_per_z = math.ceil(z_dim / step_z)
    last_z = 0 if vols_per_z == 1 else (((vols_per_z - 1) * step_z) + vol_shape[0]) - padded_data_shape[0]
    ovz_per_block = last_z // (vols_per_z - 1) if vols_per_z > 1 else 0
    step_z -= ovz_per_block
    last_z -= ovz_per_block * (vols_per_z - 1)

    # Y
    step_y = int((vol_shape[1] - padding[1] * 2) * overlap_y)
    vols_per_y = math.ceil(y_dim / step_y)
    last_y = 0 if vols_per_y == 1 else (((vols_per_y - 1) * step_y) + vol_shape[1]) - padded_data_shape[1]
    ovy_per_block = last_y // (vols_per_y - 1) if vols_per_y > 1 else 0
    step_y -= ovy_per_block
    last_y -= ovy_per_block * (vols_per_y - 1)

    # X
    step_x = int((vol_shape[2] - padding[2] * 2) * overlap_x)
    vols_per_x = math.ceil(x_dim / step_x)
    last_x = 0 if vols_per_x == 1 else (((vols_per_x - 1) * step_x) + vol_shape[2]) - padded_data_shape[2]
    ovx_per_block = last_x // (vols_per_x - 1) if vols_per_x > 1 else 0
    step_x -= ovx_per_block
    last_x -= ovx_per_block * (vols_per_x - 1)

    # Real overlap calculation for printing
    real_ov_z = ovz_per_block / (vol_shape[0] - padding[0] * 2)
    real_ov_y = ovy_per_block / (vol_shape[1] - padding[1] * 2)
    real_ov_x = ovx_per_block / (vol_shape[2] - padding[2] * 2)
    if verbose and rank == 0:
        print("Real overlapping (%): {}".format((real_ov_z, real_ov_y, real_ov_x)))
        print(
            "Real overlapping (pixels): {}".format(
                (
                    (vol_shape[0] - padding[0] * 2) * real_ov_z,
                    (vol_shape[1] - padding[1] * 2) * real_ov_y,
                    (vol_shape[2] - padding[2] * 2) * real_ov_x,
                )
            )
        )
        print("{} patches per (z,y,x) axis".format((vols_per_z, vols_per_y, vols_per_x)))

    # Distribute the Z volumes among ranks; the first (vols_per_z % total_ranks) ranks get one extra
    vols_in_z = vols_per_z // total_ranks
    vols_per_z_per_rank = vols_in_z
    if vols_per_z % total_ranks > rank:
        vols_per_z_per_rank += 1
    total_vol = vols_per_z_per_rank * vols_per_y * vols_per_x

    c = 0
    list_of_vols_in_z = []
    z_vol_info = {}
    for i in range(total_ranks):
        vols = (vols_per_z // total_ranks) + 1 if vols_per_z % total_ranks > i else vols_in_z
        for j in range(vols):
            z = c + j
            real_start_z = z * step_z
            real_finish_z = min(real_start_z + step_z + ovz_per_block, z_dim)
            z_vol_info[z] = [real_start_z, real_finish_z]
        list_of_vols_in_z.append(list(range(c, c + vols)))
        c += vols
    if verbose and rank == 0:
        print(f"List of volume IDs to be processed by each GPU: {list_of_vols_in_z}")
        print(f"Positions of each volume in Z axis: {z_vol_info}")
        print(
            "Rank {}: Total number of patches: {} - {} patches per (z,y,x) axis (per GPU)".format(
                rank, total_vol, (vols_per_z_per_rank, vols_per_y, vols_per_x)
            )
        )

    if return_only_stats:
        yield total_vol, z_vol_info, list_of_vols_in_z
        return

    for _z in range(vols_per_z_per_rank):
        z = list_of_vols_in_z[rank][0] + _z
        for y in range(vols_per_y):
            for x in range(vols_per_x):
                d_z = 0 if (z * step_z + vol_shape[0]) < padded_data_shape[0] else last_z
                d_y = 0 if (y * step_y + vol_shape[1]) < padded_data_shape[1] else last_y
                d_x = 0 if (x * step_x + vol_shape[2]) < padded_data_shape[2] else last_x

                # Padded patch bounds in data coordinates; they may fall outside [0, dim)
                raw_start_z = z * step_z - d_z - padding[0]
                raw_start_y = y * step_y - d_y - padding[1]
                raw_start_x = x * step_x - d_x - padding[2]
                raw_end_z = raw_start_z + vol_shape[0]
                raw_end_y = raw_start_y + vol_shape[1]
                raw_end_x = raw_start_x + vol_shape[2]

                # Region that actually exists in the data
                start_z, finish_z = max(0, raw_start_z), min(raw_end_z, z_dim)
                start_y, finish_y = max(0, raw_start_y), min(raw_end_y, y_dim)
                start_x, finish_x = max(0, raw_start_x), min(raw_end_x, x_dim)

                slices = [
                    slice(start_z, finish_z),
                    slice(start_y, finish_y),
                    slice(start_x, finish_x),
                    slice(None),  # Channel
                ]
                data_ordered_slices = order_dimensions(
                    slices, input_order="ZYXC", output_order=axes_order, default_value=0
                )

                # Position of the patch without its padding in the original data
                real_patch_in_data = PatchCoords(
                    z_start=raw_start_z + padding[0],
                    z_end=raw_end_z - padding[0],
                    y_start=raw_start_y + padding[1],
                    y_end=raw_end_y - padding[1],
                    x_start=raw_start_x + padding[2],
                    x_end=raw_end_x - padding[2],
                )

                if load_data:
                    img = data[tuple(data_ordered_slices)]

                    # The image should have the channel dimension at the end
                    current_order = np.array(range(len(img.shape)))
                    transpose_order = order_dimensions(
                        current_order,
                        input_order="ZYXC",
                        output_order=axes_order,
                        default_value=np.nan,
                    )

                    # determine the transpose order
                    transpose_order = [x for x in transpose_order if not np.isnan(x)]  # type: ignore
                    transpose_order = np.argsort(transpose_order)  # type: ignore
                    transpose_order = current_order[transpose_order]
                    img = np.transpose(img, transpose_order)

                    # Reflect-pad whatever part of the requested window lies outside the data
                    pad_z_left, pad_z_right = max(0, -raw_start_z), max(0, raw_end_z - z_dim)
                    pad_y_left, pad_y_right = max(0, -raw_start_y), max(0, raw_end_y - y_dim)
                    pad_x_left, pad_x_right = max(0, -raw_start_x), max(0, raw_end_x - x_dim)
                    spatial_pad = (
                        (pad_z_left, pad_z_right),
                        (pad_y_left, pad_y_right),
                        (pad_x_left, pad_x_right),
                    )
                    if img.ndim == 3:
                        img = np.pad(img, spatial_pad, "reflect")
                        img = np.expand_dims(img, -1)
                    else:
                        img = np.pad(img, spatial_pad + ((0, 0),), "reflect")

                    assert tuple(img.shape[:-1]) == tuple(
                        vol_shape[:-1]
                    ), f"Image shape and expected shape differ: {img.shape} vs {vol_shape}"

                    if rank == 0:
                        yield img, real_patch_in_data, total_vol, z_vol_info, list_of_vols_in_z
                    else:
                        yield img, real_patch_in_data, total_vol
                else:
                    if rank == 0:
                        yield real_patch_in_data, total_vol, z_vol_info, list_of_vols_in_z
                    else:
                        yield real_patch_in_data, total_vol
[docs]
def order_dimensions(
    data: Sequence[slice] | List[str | int] | Tuple[int, ...] | NDArray,
    input_order: str,
    output_order: str = "TZCYX",
    default_value: int | float = 1,
) -> Sequence[slice] | List[str | int] | Tuple[int, ...] | NDArray:
    """
    Rearrange the per-axis entries of ``data`` from ``input_order`` to ``output_order``.

    Parameters
    ----------
    data : Numpy array like
        One entry per axis of ``input_order``. E.g. ``(z, y, x, channels)``.
    input_order : str
        Axis letters describing ``data``. E.g. ``ZYXC``.
    output_order : str, optional
        Axis letters of the result. E.g. ``TZCYX``.
    default_value : int or float, optional
        Value emitted for every axis of ``output_order`` missing from ``input_order``.

    Returns
    -------
    shape : Tuple
        A tuple with one entry per letter of ``output_order``. E.g.
        ``(t, z, channel, y, x)``. When both orders are equal, ``data`` itself is
        returned untouched (same object, same type).
    """
    # Nothing to reorder: hand back the caller's object as-is.
    if input_order == output_order:
        return data
    return tuple(
        data[input_order.index(axis)] if axis in input_order else default_value
        for axis in output_order
    )
[docs]
def ensure_3d_shape(
    img: NDArray,
    path: Optional[str] = None,
    data_axes_order: Optional[str] = None,
) -> NDArray:
    """
    Reorder a 3D, 4D or 5D image into ``(z, y, x, channels)`` layout.

    Parameters
    ----------
    img : NDArray
        Image read. Must have at least 3 dimensions.
    path : str, optional
        Path of the image (only used to build the error message).
    data_axes_order : str, optional
        Order of axes of ``img``. E.g. 'TZCYX', 'TZYXC', 'ZCYX', 'ZYXC'. It is
        normalized first: ``T`` is removed, and a missing ``Z`` is taken from
        ``C``/``I``/``Q`` or inserted where ``T`` was. If any axis other than Z, Y,
        X, C remains, the order is discarded and the layout is guessed from the
        image shape instead.

    Returns
    -------
    img : Numpy 4D array
        Image with layout ``(z, y, x, channels)``. A 3D input gets a trailing
        singleton channel axis. A 5D input whose first axis is not 1 is returned
        unchanged.

    Raises
    ------
    ValueError
        If ``img`` has fewer than 3 dimensions.
    """
    if img.ndim < 3:
        if path:
            m = "Read image seems to be 2D: {}. Path: {}".format(img.shape, path)
        else:
            m = "Read image seems to be 2D: {}".format(img.shape)
        raise ValueError(m)
    elif img.ndim == 5:
        if img.shape[0] != 1:
            # It is assumed that the image is already prepared
            return img
        else:
            # Drop the leading singleton axis, so the image becomes 4D
            img = img[0]
    # pop T in data_axes_order
    # NOTE(review): nesting reconstructed from an indentation-stripped copy; this
    # normalization is assumed to run for every input, not only 5D ones.
    if data_axes_order is not None:
        T_post = data_axes_order.index("T") if "T" in data_axes_order else None
        data_axes_order = data_axes_order.replace("T", "")
        if "Z" not in data_axes_order:
            # No explicit Z: reinterpret C (then I, then Q) as the Z axis
            if "C" in data_axes_order:
                data_axes_order = data_axes_order.replace("C", "Z")
            elif "I" in data_axes_order:
                data_axes_order = data_axes_order.replace("I", "Z")
            elif "Q" in data_axes_order:
                data_axes_order = data_axes_order.replace("Q", "Z")
            else:
                # The image has more axes than letters left: put Z where T was
                if len(data_axes_order) < img.ndim and T_post is not None:
                    data_axes_order = data_axes_order[:T_post] + "Z" + data_axes_order[T_post:]
        # Unknown axis letters: fall back to the shape-based heuristics below
        if any([x for x in data_axes_order if x not in "ZYXC"]):
            data_axes_order = None
    # Identity permutation unless one of the branches below changes it
    new_pos = list(range(img.ndim))
    if img.ndim == 3:
        if data_axes_order is None:
            # Ensure Z axis is always in the first position
            # Heuristic: the smallest axis (first one on ties) is taken as Z
            min_val = min(img.shape)
            z_pos = img.shape.index(min_val)
            if z_pos != 0:
                new_pos = [
                    z_pos,
                ] + [x for x in range(3) if x != z_pos]
        else:
            # Follows the axes order provided in data_axes_order
            # Reordering the index vector [0..n) yields the transpose permutation;
            # axes missing from data_axes_order come back as NaN and are dropped.
            # NOTE(review): if Z, Y or X is missing, new_pos is shorter than
            # img.ndim and transpose will raise.
            new_pos = order_dimensions(
                np.array(range(len(data_axes_order))),
                input_order=data_axes_order,
                output_order="ZYX",
                default_value=np.nan,
            )
            new_pos = [x for x in new_pos if not np.isnan(x)]  # type: ignore
        img = img.transpose(new_pos)  # type: ignore
        # Add the trailing channel axis
        img = np.expand_dims(img, -1)
    else:
        # NOTE(review): inputs with ndim > 5 also land here; the 4-axis
        # permutations built below do not match them, so transpose will raise.
        if data_axes_order is None:
            # Ensure channel axis is always in the last position (assuming Z is already set)
            # Heuristic: the smallest axis (first one on ties) is taken as the
            # channel axis; ``z_pos`` holds that channel position despite its name.
            min_val = min(img.shape)
            z_pos = img.shape.index(min_val)
            if z_pos != 3:
                new_pos = [x for x in range(4) if x != z_pos] + [
                    z_pos,
                ]
        else:
            # Follows the axes order provided in data_axes_order
            new_pos = order_dimensions(
                np.array(range(len(data_axes_order))),
                input_order=data_axes_order,
                output_order="ZYXC",
                default_value=np.nan,
            )
            new_pos = [x for x in new_pos if not np.isnan(x)]  # type: ignore
        img = img.transpose(new_pos)  # type: ignore
    return img
def _first_array_in_group(g: zarr.Group) -> zarr.Array:
    """Return the first zarr array reached by following alphabetically-first members.

    Raises ``ValueError`` when ``g`` or any nested group on the way down is empty.
    Raises ``TypeError`` when the node reached is not a ``zarr.Array``.
    """
    names = sorted(g.keys())
    if len(names) == 0:
        raise ValueError("Zarr group is empty (no arrays or subgroups).")
    node = g[names[0]]
    # Keep descending while we are still looking at a group.
    while isinstance(node, zarr.Group):
        names = sorted(node.keys())
        if len(names) == 0:
            raise ValueError("Zarr group contains no arrays (only empty groups).")
        node = node[names[0]]
    if isinstance(node, zarr.Array):
        return node
    raise TypeError(f"Expected zarr.Array, found {type(node)}")
def _first_dataset_in_h5_group(g: h5py.Group) -> h5py.Dataset:
    """Return the first HDF5 dataset reached by following alphabetically-first members.

    Raises ``ValueError`` when ``g`` or any nested group on the way down is empty.
    Raises ``TypeError`` when the node reached is not an ``h5py.Dataset``.
    """
    names = sorted(g.keys())
    if len(names) == 0:
        raise ValueError("HDF5 group is empty (no datasets or subgroups).")
    node = g[names[0]]
    # Keep descending while we are still looking at a group.
    while isinstance(node, h5py.Group):
        names = sorted(node.keys())
        if len(names) == 0:
            raise ValueError("HDF5 group contains no datasets (only empty groups).")
        node = node[names[0]]
    if isinstance(node, h5py.Dataset):
        return node
    raise TypeError(f"Expected h5py.Dataset, found {type(node)}")
[docs]
def read_chunked_nested_data(
    file: str, data_path: str = ""
) -> Tuple[ZarrOrH5File, ZarrOrH5Array]:
    """Locate raw or ground-truth data inside an HDF5 or Zarr/N5 container.

    The container type is chosen from the file name. HDF5 names (as recognized by
    ``looks_like_hdf5``) go to ``read_chunked_nested_h5``. Names ending in
    ``.n5``/``n5``/``.zarr`` go to ``read_chunked_nested_zarr``.

    Parameters
    ----------
    file : str
        Path to the input file. Supported formats: .h5, .hdf5, .hdf, .n5, .zarr
    data_path : str, optional
        Internal path within the file where data is stored. Default: "" (root level)

    Returns
    -------
    tuple
        Returns one of:
        - (zarr.Group, zarr.core.Array) for Zarr/N5 files
        - (h5py.File, h5py.Dataset) for HDF5 files

    Raises
    ------
    ValueError
        If the input file format is neither Zarr nor HDF5

    Examples
    --------
    >>> file_handler, dataset = read_chunked_nested_data("data.h5")
    >>> zarr_group, zarr_array = read_chunked_nested_data("data.zarr")
    """
    if looks_like_hdf5(file):
        return read_chunked_nested_h5(file, data_path)
    zarr_like_suffixes = (".n5", "n5", ".zarr")
    if file.endswith(zarr_like_suffixes):
        return read_chunked_nested_zarr(file, data_path)
    raise ValueError("Input file seems to not be either Zarr/H5/n5. Supported formats: .h5, .hdf5, .hdf, .n5, .zarr")
[docs]
def read_chunked_nested_zarr(zarrfile: str, data_path: str = "") -> Tuple[zarr.Group, zarr.Array]:
    """Find recursively raw and ground truth data within a Zarr/N5 file.

    This function searches through a Zarr/N5 file hierarchy to locate array data
    at the specified path. It supports nested group structures.

    Parameters
    ----------
    zarrfile : str
        Path to the Zarr/N5 file. Must end with ``.zarr`` or ``n5``. N5 files are
        opened with ``z5py``; everything else with ``zarr.open``.
    data_path : str, optional
        Internal path to the dataset within the Zarr hierarchy, using dot notation
        for nested groups (e.g., "group1.subgroup.data"). When empty (the default,
        root level), the first array is returned, found by descending into the
        alphabetically-first member of each group. If the root is already an
        array, the root itself is returned.

    Returns
    -------
    tuple
        A tuple containing:
        - The opened root object (zarr group/array, or z5py file for N5)
        - The found array data

    Raises
    ------
    ValueError
        If the file extension is not .zarr or .n5.
        If the specified data_path is not found in the Zarr hierarchy.
        If data_path is empty and the file contains no arrays.
    ImportError
        If an N5 file is requested but z5py is not installed.

    Examples
    --------
    >>> group, array = read_chunked_nested_zarr("data.zarr")
    >>> subgroup, dataset = read_chunked_nested_zarr("experiment.n5", "images.channel1")
    """
    if not any(zarrfile.endswith(x) for x in [".n5", "n5", ".zarr"]):
        raise ValueError("Not implemented for other filetypes than Zarr")
    if zarrfile.endswith("n5"):  # also covers ".n5"
        if not _Z5PY_AVAILABLE:
            raise ImportError(
                "z5py is required for N5 format support but is not installed. "
                "Install it via conda: conda install -c conda-forge z5py"
            )
        fid = z5py.File(zarrfile, "r")
    else:
        fid = zarr.open(zarrfile, mode="r")

    def _is_group(node: Any) -> bool:
        """Tell whether ``node`` is a container whose children can be listed."""
        if isinstance(node, zarr.Group):
            return True
        # z5py (N5) objects: groups expose ``keys()``, datasets expose ``shape``.
        return hasattr(node, "keys") and not hasattr(node, "shape")

    def _child_names(node: Any, kind: str) -> List[str]:
        """List ``node``'s ``kind`` ('group'/'array') children, or all keys if unsupported."""
        try:
            return list(getattr(node, f"{kind}_keys")())
        except Exception:
            # z5py groups have no group_keys()/array_keys(); use every key instead.
            return list(node.keys())

    def _find(path: str) -> Any:
        """Follow the dot-separated ``path`` from the root; ``None`` if any part is missing."""
        *group_names, leaf_name = path.split(".")
        node: Any = fid
        for name in group_names:
            if not _is_group(node) or name not in _child_names(node, "group"):
                return None
            node = node[name]
        if not _is_group(node) or leaf_name not in _child_names(node, "array"):
            return None
        return node[leaf_name]

    if not data_path:
        # Root level: mirror read_chunked_nested_h5 and return the first array.
        node: Any = fid
        while _is_group(node):
            names = sorted(node.keys())
            if not names:
                raise ValueError(f"No arrays found in Zarr: {zarrfile}.")
            node = node[names[0]]
        return fid, node  # type: ignore

    data = _find(data_path)
    if data is None:
        raise ValueError(f"'{data_path}' not found in Zarr: {zarrfile}.")
    return fid, data  # type: ignore
[docs]
def read_chunked_nested_h5(h5file: str, data_path: str = "") -> Tuple[h5py.File, h5py.Dataset]:
    """Find recursively raw and ground truth data within an HDF5 file.

    This function searches through an HDF5 file hierarchy to locate dataset objects
    at the specified path. It supports nested group structures. If the path points
    to a group, the first dataset inside it is returned, found by descending into
    the alphabetically-first member at each level.

    Parameters
    ----------
    h5file : str
        Path to the HDF5 file, with an extension recognized by ``looks_like_hdf5``
        (e.g. .h5, .hdf5, .hdf).
    data_path : str, optional
        Internal path to the dataset within the HDF5 hierarchy. Both dot
        ("group1.subgroup.data") and slash ("group1/subgroup/data") separators are
        accepted. Dots are always treated as separators, so member names
        containing dots cannot be addressed. Empty components are ignored.
        Default: "" (root level: the first dataset in the file).

    Returns
    -------
    tuple
        A tuple containing:
        - h5py.File: The opened HDF5 file object (the caller must close it)
        - h5py.Dataset: The found dataset object

    Raises
    ------
    ValueError
        If the file extension is not recognized as HDF5.
        If the specified data_path is not found in the HDF5 hierarchy, including
        paths that try to descend through a dataset.
        If no dataset can be found under the resolved group.
    TypeError
        If the object at ``data_path`` is neither a group nor a dataset.

    Examples
    --------
    >>> file, dataset = read_chunked_nested_h5("data.h5")
    >>> file, subgroup_data = read_chunked_nested_h5("experiment.hdf5", "images/channel1")
    """
    if not looks_like_hdf5(h5file):
        raise ValueError("Not implemented for other filetypes than H5")
    fid = h5py.File(h5file, "r")
    try:
        # allow both "a.b.c" and "a/b/c"; empty components (e.g. "a//b", "/a/") are skipped
        parts = [p for p in data_path.replace(".", "/").split("/") if p]
        if not parts:
            return fid, _first_dataset_in_h5_group(fid)
        obj: h5py.Group | h5py.Dataset = fid
        for part in parts:
            if not isinstance(obj, h5py.Group):
                # Datasets have no members (and no keys()), so report "not found"
                # instead of failing inside the membership test / error message.
                raise ValueError(
                    f"'{data_path}' not found in H5: {h5file}. '{obj.name}' is a dataset, not a group."
                )
            if part not in obj:
                raise ValueError(f"'{data_path}' not found in H5: {h5file}. Available keys at this level: {list(obj.keys())}")
            obj = obj[part]
        if isinstance(obj, h5py.Group):
            return fid, _first_dataset_in_h5_group(obj)
        if isinstance(obj, h5py.Dataset):
            return fid, obj
        raise TypeError(f"Unexpected HDF5 object type at '{data_path}': {type(obj)}")
    except Exception:
        # Never leak the handle when the lookup fails.
        fid.close()
        raise
[docs]
def read_chunked_data(
    filename: str,
) -> Tuple[ZarrOrH5File, ZarrOrH5Array]:
    """Read and return the first dataset found in an HDF5 or Zarr file.

    This function detects the file format (HDF5 or Zarr/N5) from the file name. It
    returns the file handler along with the first available dataset.

    Parameters
    ----------
    filename : str
        Path to the input file. Supported formats:
        - HDF5: .h5, .hdf5, .hdf (anything accepted by ``looks_like_hdf5``)
        - Zarr: .zarr, .n5 (both opened with ``zarr.open``)

    Returns
    -------
    tuple
        Returns one of:
        - (h5py.File, h5py.Dataset) for HDF5 files
        - (zarr.Group or zarr.Array, zarr.Array) for Zarr files
        The first dataset found in the file will be returned.

    Raises
    ------
    ValueError
        If the input is not a string.
        If the file doesn't exist.
        If the file extension is not recognized.
        If the Zarr file cannot be opened.
        If no datasets are found in the file.

    Examples
    --------
    >>> file_handler, dataset = read_chunked_data("data.h5")
    >>> zarr_group, zarr_array = read_chunked_data("data.zarr")

    Notes
    -----
    The dataset is found by descending into the alphabetically-first member of
    each group until an array/dataset is reached. If a Zarr path opens directly as
    an array, that array is returned as the data.
    """
    if not isinstance(filename, str):
        raise ValueError("'filename' is expected to be a str")
    if not os.path.exists(filename):
        raise ValueError(f"File {filename} does not exist.")

    if looks_like_hdf5(filename):
        fid = h5py.File(filename, "r")
        try:
            data = _first_dataset_in_h5_group(fid)
        except Exception as e:
            # Close the handle before reporting; keep the original error as the cause.
            fid.close()
            raise ValueError(f"No datasets found in HDF5 file {filename}.") from e
        return fid, data

    if filename.endswith((".zarr", ".n5")):
        try:
            fid = zarr.open(filename, mode="r")
        except Exception as e:
            raise ValueError(f"Error opening Zarr file {filename}: {e}") from e
        data = _first_array_in_group(fid) if isinstance(fid, zarr.Group) else fid
        return fid, data

    raise ValueError(f"File extension {filename} not recognized")
[docs]
def looks_like_hdf5(path: str) -> bool:
    """
    Decide from the file name alone whether ``path`` refers to an HDF5 file.

    The comparison is case-insensitive. A single compression suffix (``.gz``,
    ``.bz2``, ``.xz`` or ``.zip``) after the HDF5 extension is also accepted,
    e.g. ``volume.h5.gz``.

    Parameters
    ----------
    path : str
        The file path to check.

    Returns
    -------
    bool
        True if the file has an HDF5 extension, False otherwise.
    """
    hdf5_suffixes = (".h5", ".hdf5", ".hdf", ".he5")
    lowered = path.lower()
    stem, last_suffix = os.path.splitext(lowered)
    is_compressed = last_suffix in (".gz", ".bz2", ".xz", ".zip")
    return lowered.endswith(hdf5_suffixes) or (is_compressed and stem.endswith(hdf5_suffixes))
[docs]
def pick_chunks(
    shape: Tuple[int, ...],
    dtype: str,
    target_mb: float = 4.0,
    max_chunk_dim: int = 256,
    max_channel_chunk: int = 16,
) -> Tuple[int, ...]:
    """
    Pick chunk sizes for HDF5 datasets based on the shape and data type.

    Each dimension's chunk starts at ``min(dim, max_chunk_dim)``. For shapes with
    4 or more dimensions, the last one is treated as channels and capped at
    ``max_channel_chunk``. The largest chunk dimension is then halved until one
    chunk fits in ``target_mb``. Channels are only shrunk when no other dimension
    can be.

    Parameters
    ----------
    shape : tuple of int
        Shape of the dataset.
    dtype : str
        Data type of the dataset (anything accepted by ``np.dtype``).
    target_mb : float, optional
        Target chunk size in megabytes. Default is 4.0 MB.
    max_chunk_dim : int, optional
        Initial per-dimension cap on the chunk size. Default is 256.
    max_channel_chunk : int, optional
        Initial cap for the channel (last) dimension when ``len(shape) >= 4``.
        Default is 16.

    Returns
    -------
    tuple of int
        Chunk sizes for each dimension of the dataset. Every entry is at least 1,
        even for zero-length axes, because HDF5 rejects zero-sized chunks.
    """
    itemsize = np.dtype(dtype).itemsize
    target_bytes = int(target_mb * 1024 * 1024)
    # Start from a conservative per-dimension cap (keeps metadata manageable);
    # clamp to >= 1 so zero-length axes do not produce an invalid 0 chunk.
    chunks = [max(1, min(int(d), max_chunk_dim)) for d in shape]
    channel_axis = len(chunks) - 1 if len(shape) >= 4 else None
    if channel_axis is not None:
        # keep channels small-ish if present
        chunks[channel_axis] = max(1, min(int(shape[-1]), max_channel_chunk))

    def chunk_bytes() -> int:
        return math.prod(chunks) * itemsize

    # shrink largest dims until under target
    while chunk_bytes() > target_bytes:
        candidates = [i for i, c in enumerate(chunks) if c > 1]
        if not candidates:
            break
        # Prefer shrinking spatial dims; touch channels only as a last resort.
        if channel_axis in candidates and len(candidates) > 1:
            candidates.remove(channel_axis)
        # Halve the currently-largest candidate (first one on ties).
        i = max(candidates, key=lambda j: chunks[j])
        chunks[i] = max(1, chunks[i] // 2)
    return tuple(int(c) for c in chunks)
[docs]
def load_synapse_gt_points(
    locations_path: str,
    resolution_path: str,
    partners_path: str,
    id_path: str,
    data_filename: str,
) -> dict[str, list]:
    """
    Load synapse ground truth points from the given paths.

    Parameters
    ----------
    locations_path : str
        Path to the synapse locations within the data file.
    resolution_path : str
        Path to the dataset, within the data file, that carries a ``resolution``
        attribute.
    partners_path : str
        Path to the synapse partners, i.e. (pre_id, post_id) rows, within the data file.
    id_path : str
        Path to the synapse ids within the data file.
    data_filename : str
        Path to the data file.

    Returns
    -------
    dict
        Dictionary with keys:
        - ``"pre"``: list of unique pre-synaptic point coordinates (numpy arrays),
          in order of first appearance in the partners table.
        - ``"post"``: list of unique post-synaptic point coordinates.
        - ``"cleft"``: list of unique synaptic cleft points. Each is the midpoint
          of one (pre, post) partner pair.
        - ``"resolution"``: the ``resolution`` attribute as a list.
        Coordinates are computed as ``locations // resolution``.

    Raises
    ------
    ValueError
        If ``resolution_path`` has no ``resolution`` attribute, or a partner id is
        missing from the ids dataset.
    """
    # read_chunked_nested_data opens a new handle per call (a fresh h5py.File for
    # HDF5), so every handle is tracked and closed, even when an error is raised.
    handles: List[Any] = []

    def _read(internal_path: str) -> Any:
        handle, dataset = read_chunked_nested_data(data_filename, internal_path)
        handles.append(handle)
        return dataset

    try:
        ids = list(np.array(_read(id_path)))
        partners = np.array(_read(partners_path))
        locations = np.array(_read(locations_path))
        resolution_data = _read(resolution_path)
        try:
            resolution = resolution_data.attrs["resolution"]
        except (KeyError, AttributeError) as err:
            raise ValueError(
                "There is no 'resolution' attribute in '{}'. Add it like: data['{}'].attrs['resolution'] = (8,8,8)".format(
                    resolution_path, resolution_path
                )
            ) from err
        resolution = list(resolution)

        # Keyed by str(coord) so repeated coordinates are stored once, keeping
        # first-appearance order.
        pre_points: dict = {}
        post_points: dict = {}
        cleft_points: dict = {}
        for pre_id, post_id in tqdm(partners, disable=not is_main_process()):
            pre_coord = locations[ids.index(pre_id)] // resolution
            post_coord = locations[ids.index(post_id)] // resolution
            pre_points.setdefault(str(pre_coord), pre_coord)
            post_points.setdefault(str(post_coord), post_coord)
            # The cleft is the midpoint of *this* partner pair. Zipping the
            # separately de-duplicated pre/post lists would mismatch partners
            # whenever one pre-synaptic site has several post-synaptic partners.
            cleft_coord = (pre_coord + post_coord) / 2
            cleft_points.setdefault(str(cleft_coord), cleft_coord)

        return {
            "pre": list(pre_points.values()),
            "post": list(post_points.values()),
            "cleft": list(cleft_points.values()),
            "resolution": resolution,
        }
    finally:
        # All data was materialized into numpy/lists above, so closing is safe.
        for handle in handles:
            if isinstance(handle, h5py.File):
                handle.close()