"""
Pre-processing utilities for image and mask data in deep learning workflows.
This module provides pre-processing functions for instance segmentation, detection mask creation, self-supervised learning data generation, semantic segmentation probability maps, and general image processing operations such as resizing, blurring, edge detection, histogram matching, and CLAHE. It supports both 2D and 3D data formats and integrates with BiaPy configuration objects for flexible data pipelines.
"""
import os
import edt
import h5py
import zarr
import numpy as np
from tqdm import tqdm
import pandas as pd
from skimage.segmentation import clear_border, find_boundaries, watershed
from scipy.ndimage import (
generic_filter,
generate_binary_structure,
grey_closing,
center_of_mass,
map_coordinates,
uniform_filter,
binary_dilation as binary_dilation_scipy
)
from skimage.morphology import disk, binary_dilation, binary_erosion, skeletonize
from skimage.measure import label, regionprops_table
from skimage.transform import resize
from skimage.feature import canny
from skimage.exposure import equalize_adapthist
from skimage.color import rgb2gray
from skimage.filters import gaussian, median
from yacs.config import CfgNode as CN
from numpy.typing import NDArray
from typing import List, Optional, Dict, Tuple, Sequence, Union
from scipy.spatial import cKDTree
from biapy.data.dataset import BiaPyDataset
from biapy.utils.util import (
seg2aff_pni,
seg_widen_border,
)
from biapy.utils.misc import is_main_process, get_rank, get_world_size, os_walk_clean
from biapy.data.data_3D_manipulation import (
load_3D_efficient_files,
load_img_part_from_efficient_file,
order_dimensions,
read_chunked_data,
read_chunked_nested_data,
looks_like_hdf5,
pick_chunks,
)
from biapy.data.data_manipulation import (
read_img_as_ndarray,
load_data_from_dir,
save_tif,
decide_dtype,
)
#########################
# INSTANCE SEGMENTATION #
#########################
[docs]
def create_instance_channels(cfg: CN, data_type: str = "train"):
    """
    Create the instance segmentation target channels for one data split.

    The channels are the ones requested in ``PROBLEM.INSTANCE_SEG.DATA_CHANNELS``. Two paths exist:

    - 3D Zarr/N5/H5 inputs are processed patch by patch and written into a new Zarr/H5 file
      placed in ``DATA.<SPLIT>.INSTANCE_CHANNELS_MASK_DIR`` (synapse data is delegated to
      ``synapse_channel_creation``).
    - Any other input is read image by image from ``DATA.<SPLIT>.GT_PATH`` and saved as TIFF
      in ``DATA.<SPLIT>.INSTANCE_CHANNELS_MASK_DIR``.

    Samples are distributed among processes by striding over them with the process rank.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.

    data_type: str, optional
        Whether to create training (``"train"``), validation (``"val"``) or test (``"test"``)
        instance channels.

    Raises
    ------
    ValueError
        If the input folders can not be listed, if synapse detection is requested on data that
        is not 3D Zarr/H5, if distance channels are requested for chunked Zarr/H5 data, or if a
        label image does not have the expected number of channels.
    """
    assert data_type in ["train", "val", "test"]
    tag = data_type.upper()
    data_cfg = getattr(cfg.DATA, tag)

    # Checking if the user inputted Zarr/H5 files
    if data_cfg.INPUT_ZARR_MULTIPLE_DATA:
        data_path = data_cfg.PATH
        try:
            zarr_files = next(os_walk_clean(data_path))[1]
        except StopIteration:
            raise ValueError("No Zarr/N5 files found in the input path: {}".format(data_path))
        try:
            h5_files = next(os_walk_clean(data_path))[2]
        except StopIteration:
            raise ValueError("No H5 files found in the input path: {}".format(data_path))
    else:
        data_path = data_cfg.GT_PATH
        try:
            zarr_files = next(os_walk_clean(data_path))[1]
            h5_files = next(os_walk_clean(data_path))[2]
        except StopIteration:
            raise ValueError("No Zarr/N5 or H5 files found in the GT path: {}".format(data_path))

    # Only the first entry of each listing is inspected to decide the input format
    has_zarr = len(zarr_files) > 0 and any(ext in zarr_files[0] for ext in (".zarr", ".n5"))
    has_h5 = len(h5_files) > 0 and looks_like_hdf5(h5_files[0])

    # Find patches info so we can iterate over them to create the instance mask.
    # The 3D requirement applies to both formats: written without parentheses, "and" binds
    # tighter than "or", so H5 inputs used to bypass the 3D check.
    working_with_zarr_h5_files = False
    if cfg.PROBLEM.NDIM == "3D" and (has_zarr or has_h5):
        working_with_zarr_h5_files = True

        # Check if the raw images and labels are within the same file
        data_path = data_cfg.GT_PATH
        path_to_gt_data = None
        if data_cfg.INPUT_ZARR_MULTIPLE_DATA:
            data_path = data_cfg.PATH
            if cfg.PROBLEM.INSTANCE_SEG.TYPE == "synapses":
                path_to_gt_data = data_cfg.INPUT_ZARR_MULTIPLE_DATA_RAW_PATH
            else:
                path_to_gt_data = data_cfg.INPUT_ZARR_MULTIPLE_DATA_GT_PATH

        if has_zarr:
            print("Working with Zarr files . . .")
            img_files = [os.path.join(data_path, x) for x in zarr_files]
        else:
            print("Working with H5 files . . .")
            img_files = [os.path.join(data_path, x) for x in h5_files]

        Y, _ = load_3D_efficient_files(
            img_files,
            input_axes=data_cfg.INPUT_IMG_AXES_ORDER,
            crop_shape=cfg.DATA.PATCH_SIZE,
            overlap=data_cfg.OVERLAP,
            padding=data_cfg.PADDING,
            data_within_zarr_path=path_to_gt_data,
        )

        zarr_data_information = {
            "raw_data_path": data_cfg.INPUT_ZARR_MULTIPLE_DATA_RAW_PATH,
            "axes_order": data_cfg.INPUT_IMG_AXES_ORDER,
            "z_axe_pos": data_cfg.INPUT_IMG_AXES_ORDER.index("Z"),
            "y_axe_pos": data_cfg.INPUT_IMG_AXES_ORDER.index("Y"),
            "x_axe_pos": data_cfg.INPUT_IMG_AXES_ORDER.index("X"),
            "id_path": data_cfg.INPUT_ZARR_MULTIPLE_DATA_ID_PATH,
            "partners_path": data_cfg.INPUT_ZARR_MULTIPLE_DATA_PARTNERS_PATH,
            "locations_path": data_cfg.INPUT_ZARR_MULTIPLE_DATA_LOCATIONS_PATH,
            "resolution_path": data_cfg.INPUT_ZARR_MULTIPLE_DATA_RESOLUTION_PATH,
        }
    else:
        if cfg.PROBLEM.INSTANCE_SEG.TYPE == "synapses":
            raise ValueError(
                "Synapse detection is only available for 3D Zarr/H5 data so please check your data in {}".format(
                    data_path
                )
            )
        Y = next(os_walk_clean(data_cfg.GT_PATH))[2]
    del zarr_files, h5_files

    print("Creating Y_{} channels . . .".format(data_type))

    # Create the mask patch by patch (Zarr/H5)
    if working_with_zarr_h5_files and isinstance(Y, dict):
        if "D" in cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS:
            raise ValueError("Currently distance creation using Zarr by chunks is not implemented.")
        dtype_str = "uint8"

        # For synapses the process of the channel creation is different: not patch by patch but
        # painting each post-synaptic point for each pre-synaptic point
        if cfg.PROBLEM.INSTANCE_SEG.TYPE == "synapses" and len(Y) > 0:
            synapse_channel_creation(
                data_info=Y,
                zarr_data_information=zarr_data_information,
                savepath=data_cfg.INSTANCE_CHANNELS_MASK_DIR,
                mode=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS,
                channel_extra_opts=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS_EXTRA_OPTS[0],
            )
        else:  # regular instances, not synapses
            mask = None
            # Handle of the H5 file currently being written; stays None for Zarr outputs
            fid_mask = None
            last_parallel_file = None

            rank = get_rank()
            world_size = get_world_size()
            N = len(Y)
            it = range(rank, N, world_size)

            try:
                for i in tqdm(it, disable=not is_main_process()):
                    # Extract the patch to process
                    filepath = Y[i]["filepath"]
                    patch_coords = Y[i]["patch_coords"]
                    img = load_img_part_from_efficient_file(
                        filepath,
                        patch_coords,
                        data_axes_order=data_cfg.INPUT_IMG_AXES_ORDER,
                        data_path=data_cfg.INPUT_ZARR_MULTIPLE_DATA_GT_PATH,
                    )
                    if img.ndim == 3:
                        img = np.expand_dims(img, -1)

                    # Create the instance mask
                    if cfg.DATA.N_CLASSES > 2:
                        if img.shape[-1] != 2:
                            raise ValueError(
                                "In instance segmentation, when 'DATA.N_CLASSES' are more than 2 labels need to have two channels, "
                                "e.g. (256,256,2), containing the instance segmentation map (first channel) and classification map (second channel)."
                            )
                        class_channel = np.expand_dims(img[..., 1].copy(), -1)
                    else:
                        if img.shape[-1] != 1:
                            raise ValueError(
                                "Expected instance segmentation GT images to have a single channel containing the instance labels, "
                                "but got image with shape {} ({} channels). Check the image file: {}".format(
                                    img.shape, img.shape[-1], filepath
                                )
                            )

                    img = labels_into_channels(
                        img,
                        mode=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS,
                        channel_extra_opts=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS_EXTRA_OPTS[0],
                        save_dir=getattr(cfg.PATHS, tag + "_INSTANCE_CHANNELS_CHECK"),
                    )
                    if cfg.DATA.N_CLASSES > 2:
                        img = np.concatenate([img, class_channel], axis=-1)

                    # Create the Zarr/H5 file where the mask will be placed
                    if mask is None or os.path.basename(filepath) != last_parallel_file:
                        last_parallel_file = os.path.basename(filepath)

                        # Close the previously written H5 file, if there was one
                        if fid_mask is not None:
                            fid_mask.close()
                            fid_mask = None

                        if path_to_gt_data:
                            imgfile, data = read_chunked_nested_data(filepath, path_to_gt_data)
                        else:
                            imgfile, data = read_chunked_data(filepath)

                        fname = os.path.join(data_cfg.INSTANCE_CHANNELS_MASK_DIR, os.path.basename(filepath))
                        os.makedirs(data_cfg.INSTANCE_CHANNELS_MASK_DIR, exist_ok=True)

                        # Determine data shape
                        out_data_shape = np.array(data.shape)
                        if "C" not in data_cfg.INPUT_IMG_AXES_ORDER:
                            out_data_shape = tuple(out_data_shape) + (img.shape[-1],)
                            out_data_order = data_cfg.INPUT_IMG_AXES_ORDER + "C"
                            channel_pos = -1
                        else:
                            out_data_shape[data_cfg.INPUT_IMG_AXES_ORDER.index("C")] = img.shape[-1]
                            out_data_shape = tuple(out_data_shape)
                            out_data_order = data_cfg.INPUT_IMG_AXES_ORDER
                            channel_pos = data_cfg.INPUT_IMG_AXES_ORDER.index("C")

                        is_h5 = looks_like_hdf5(fname)
                        if is_h5:
                            fname, _ = os.path.splitext(fname)
                            fid_mask = h5py.File(fname + ".h5", "w")
                            ds_kwargs = {
                                "shape": out_data_shape,
                                "dtype": dtype_str,
                                "chunks": pick_chunks(out_data_shape, dtype_str, target_mb=4.0),
                                "compression": "gzip",
                                "compression_opts": 4,
                                "shuffle": True,
                            }
                            mask = fid_mask.create_dataset("data", **ds_kwargs)
                        else:  # Zarr file
                            out_data_shape = tuple(int(s) for s in out_data_shape)
                            mask = zarr.open(fname, mode="w", shape=out_data_shape, dtype=dtype_str, zarr_format=3)

                        # Close H5 file read for the data shape
                        if isinstance(imgfile, h5py.File):
                            imgfile.close()
                        del data, fname

                    slices = (
                        slice(patch_coords.z_start, patch_coords.z_end),
                        slice(patch_coords.y_start, patch_coords.y_end),
                        slice(patch_coords.x_start, patch_coords.x_end),
                        slice(0, out_data_shape[channel_pos]),
                    )
                    data_ordered_slices = tuple(
                        order_dimensions(
                            slices,
                            input_order="ZYXC",
                            output_order=out_data_order,
                            default_value=0,
                        )
                    )

                    # Adjust patch slice to transpose it before inserting into the final data
                    current_order = np.array(range(len(img.shape)))
                    transpose_order = order_dimensions(
                        current_order,
                        input_order="ZYXC",
                        output_order=out_data_order,
                        default_value=np.nan,
                    )
                    transpose_order = [x for x in np.array(transpose_order) if not np.isnan(x)]

                    # Place the patch into the Zarr/H5
                    mask[data_ordered_slices] = img.transpose(transpose_order)
            finally:
                # Close the last H5 file written (also when an error interrupts the loop)
                if fid_mask is not None:
                    fid_mask.close()
    else:
        rank = get_rank()
        world_size = get_world_size()
        N = len(Y)
        it = range(rank, N, world_size)
        for i in tqdm(it, disable=not is_main_process()):
            img_path = os.path.join(data_cfg.GT_PATH, Y[i])
            img = read_img_as_ndarray(img_path, is_3d=not cfg.PROBLEM.NDIM == "2D")
            if cfg.DATA.N_CLASSES > 2:
                if img.shape[-1] != 2:
                    raise ValueError(
                        "In instance segmentation, when 'DATA.N_CLASSES' are more than 2 labels need to have two channels, "
                        "e.g. (256,256,2), containing the instance segmentation map (first channel) and classification map "
                        "(second channel)."
                    )
                class_channel = np.expand_dims(img[..., 1].copy(), -1)
            else:
                if img.shape[-1] != 1:
                    raise ValueError(
                        "Expected instance segmentation GT images to have a single channel containing the instance labels, "
                        "but got image with shape {} ({} channels). Check the image file: {}".format(
                            img.shape, img.shape[-1], img_path
                        )
                    )

            img = labels_into_channels(
                img,
                mode=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS,
                channel_extra_opts=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS_EXTRA_OPTS[0],
                save_dir=getattr(cfg.PATHS, tag + "_INSTANCE_CHANNELS_CHECK"),
            )
            if cfg.DATA.N_CLASSES > 2:
                img = np.concatenate([img, class_channel], axis=-1)

            save_tif(
                np.expand_dims(img, 0),
                data_dir=data_cfg.INSTANCE_CHANNELS_MASK_DIR,
                filenames=[Y[i]],
                verbose=False,
            )
[docs]
def unique_labels_fast(a: np.ndarray):
    """
    Find the unique labels in an integer array a in [0, K] in O(n) time and O(K) space.

    Values are used directly as indices into a presence table, so they must be
    non-negative (a negative value would index the table from its end).

    Parameters
    ----------
    a : ndarray
        Input array of integers, of any shape.

    Returns
    -------
    ndarray
        Sorted 1D array of the unique labels. Empty if ``a`` has no elements.
    """
    a = np.asarray(a)
    if a.size == 0:
        # max() of a zero-size array raises, and there are no labels to report anyway
        return np.empty(0, dtype=np.intp)

    K = int(a.max())
    present = np.zeros(K + 1, dtype=bool)
    present[a.ravel().astype(int)] = True  # one pass over the data
    return np.flatnonzero(present)  # sorted because the table is scanned from 0 to K
[docs]
def labels_into_channels(
    instance_labels: NDArray,
    mode: List[str] = ["I", "C"],
    channel_extra_opts: Dict = {},
    resolution: List[int|float] = [1,1,1],
    save_dir: Optional[str] = None,
) -> NDArray:
    """
    Convert an instance label image into the target channels used to train an instance segmentation model.

    Parameters
    ----------
    instance_labels : 3D/4D Numpy array
        Instance labels to be used to extract the channels from. E.g. ``(200, 1000, 1000, 1)``.
        Only the first channel is used.

    mode : List, optional
        Channels to create, in output order. Channels handled by this function:

        - 'F': foreground mask (optional per-instance erosion/dilation).
        - 'B': background mask (optional per-instance erosion/dilation).
        - 'P': central part of each instance (centroid or skeleton).
        - 'C': instance contours.
        - 'Dc': distance to the instance center/skeleton.
        - 'Db': distance to the instance boundary.
        - 'Dn': distance to the neighbor instances.
        - 'D': signed distance (foreground minus background distance).
        - 'Z', 'V', 'H': depth/vertical/horizontal channels.
        - 'Gz', 'Gv', 'Gh': flow-like channels (created when 'Gv' is requested).
        - 'T': touching area.
        - 'A': affinities.
        - 'R': radial distances.
        - 'E_offset': copy of the instance labels, extra target for embeddings.
          'E_sigma' and 'E_seediness' are accepted but produce no channel.
        - 'We': border weight map.
        - 'M': legacy mask, union of the 'F' and 'C' channels (both must be requested).

    channel_extra_opts : dict, optional
        Additional options for each output channel, indexed by channel name
        (e.g., ``{"F": {"erosion": 1}}``).

    resolution : Tuple of int/float
        Resolution of the data, in ``(z,y,x)`` to calibrate coordinates. E.g. ``[30,8,8]``.

    save_dir : str, optional
        Path to store samples of the created array just to debug it is correct.

    Returns
    -------
    new_mask : 3D/4D Numpy array
        Instance representations. The shape will be as the input ``instance_labels`` but with the amount of channels
        requested. E.g. ``(200, 1000, 1000, 3)``. It is all zeros when the labels contain no instance.
    """
    assert len(resolution) == 3, "'resolution' must be a list of 3 int/float"
    assert instance_labels.ndim in [3, 4]

    # Number of output channels: 'R' and 'A' span several channels, embeddings extras none
    c_number = 0
    for ch in mode:
        if ch == "R":
            nrays = channel_extra_opts["R"]["nrays"]
            c_number += nrays
        elif ch == "A":
            affs = (
                len(channel_extra_opts["A"]["z_affinities"])
                + len(channel_extra_opts["A"]["y_affinities"])
                + len(channel_extra_opts["A"]["x_affinities"])
            )
            c_number += affs
        elif ch in ["E_sigma", "E_seediness"]:
            continue  # not special channels, just extra targets for embeddings
        else:
            c_number += 1

    if any(x for x in ["Dc", "Dn", "D", "Z", "V", "H", "R", "We", "Gv", "Gh", "Gz"] if x in mode):
        dtype = np.float32
    elif "Db" in mode:
        dtype = np.uint8 if channel_extra_opts.get("Db", {}).get("val_type", "norm") == "discretize" else np.float32
    elif "E_offset" in mode:
        dtype = instance_labels.dtype
        # Ensure that no floating-point dtype is used for the embeddings.
        # This allows the normalization module to correctly recognize them
        # as integer or binary channels, ensuring that the subsequent
        # data augmentation processes the samples as intended.
        if np.issubdtype(dtype, np.floating):
            dtype = decide_dtype(instance_labels.max())
    else:
        dtype = np.uint8

    new_mask = np.zeros(instance_labels.shape[:-1] + (c_number,), dtype=dtype)
    vol = instance_labels[..., 0]
    if np.issubdtype(vol.dtype, np.floating):
        vol = vol.astype(np.uint32)

    # Precompute regionprops only when needed
    needs_props = (
        any(x in mode for x in ("Z", "V", "H", "Gv", "Gh", "Gz"))
        or ("P" in mode and channel_extra_opts.get("P", {}).get("type", "") == "skeleton")
        or ("Dc" in mode)
    )
    if needs_props:
        # label, bbox, centroid. Row i of the table corresponds to instances[i] below
        props_tbl = regionprops_table(vol, properties=("label", "bbox", "centroid"))
        instances = list(props_tbl["label"])
    else:
        instances = sorted(list(unique_labels_fast(vol)))
    instances = [inst for inst in instances if inst != 0]  # remove background

    # Decide on the real instances only: counting labels before dropping the background made a
    # volume with exactly one instance look like "only background" on the regionprops path.
    if len(instances) == 0:
        return new_mask  # only background

    disable_tqdm = len(instances) < 2000

    fg_mask = (vol > 0).astype(np.uint8)
    bg_mask = (vol == 0).astype(np.uint8)

    # Precompute horizontal/vertical/depth channels if any of H/V/Z is requested
    if any(ch in mode for ch in ("Z", "V", "H")):
        norm_flag = True
        for ch in ("Z", "V", "H"):
            if ch in channel_extra_opts and "norm" in channel_extra_opts[ch]:
                norm_flag = bool(channel_extra_opts[ch]["norm"])
                break
        hv_channels = create_HoVe_channels(
            vol,
            ref_point="center",
            normalize_values=norm_flag,
            calc_props=props_tbl,
        )

    # ---------- Foreground (F) ----------
    if "F" in mode:
        # Check if erosion/dilation is requested as the process needs the original volume
        # to make it per-instance
        er_k = channel_extra_opts.get("F", {}).get("erosion", 0)
        dil_k = channel_extra_opts.get("F", {}).get("dilation", 0)
        mask = fg_mask.astype(np.uint8)
        erode, dilate = False, False
        if (isinstance(er_k, int) and er_k > 0) or (isinstance(er_k, list) and any([x for x in er_k if x > 0])):
            erode = True
        if (isinstance(dil_k, int) and dil_k > 0) or (isinstance(dil_k, list) and any([x for x in dil_k if x > 0])):
            dilate = True

        if erode or dilate:
            # Keep the instance ids here; the channel is binarized at the end, after 'We'
            mask = np.zeros_like(fg_mask, dtype=vol.dtype)
            dil_k = [dil_k,] * mask.ndim if isinstance(dil_k, int) else dil_k
            dil_k = generate_ellipse_footprint(dil_k)
            er_k = [er_k,] * mask.ndim if isinstance(er_k, int) else er_k
            er_k = generate_ellipse_footprint(er_k)
            for lb in tqdm(instances, disable=disable_tqdm):
                m = (vol == lb)
                if not np.any(m):
                    continue
                if dilate:
                    m = binary_dilation(m.astype(np.uint8), footprint=dil_k).astype(np.uint8)
                if erode:
                    m = binary_erosion(m.astype(np.uint8), footprint=er_k).astype(np.uint8)
                mask[m > 0] = lb
        # NOTE(review): new_mask can be uint8 at this point, in which case instance ids above
        # 255 are stored modulo 256 until the binarization below - confirm this is acceptable.
        new_mask[..., mode.index("F")] = mask

    # ---------- Background (B) ----------
    if "B" in mode:
        # Check if erosion/dilation is requested as the process needs the original volume
        # to make it per-instance
        er_k = channel_extra_opts.get("B", {}).get("erosion", 0)
        dil_k = channel_extra_opts.get("B", {}).get("dilation", 0)
        erode, dilate = False, False
        mask = bg_mask.astype(np.uint8)
        if (isinstance(er_k, int) and er_k > 0) or (isinstance(er_k, list) and any([x for x in er_k if x > 0])):
            erode = True
        if (isinstance(dil_k, int) and dil_k > 0) or (isinstance(dil_k, list) and any([x for x in dil_k if x > 0])):
            dilate = True

        if erode or dilate:
            mask = np.zeros_like(fg_mask, dtype=np.uint8)
            dil_k = [dil_k,] * mask.ndim if isinstance(dil_k, int) else dil_k
            dil_k = generate_ellipse_footprint(dil_k)
            er_k = [er_k,] * mask.ndim if isinstance(er_k, int) else er_k
            er_k = generate_ellipse_footprint(er_k)
            for lb in tqdm(instances, disable=disable_tqdm):
                m = (vol == lb)
                if not np.any(m):
                    continue
                # As the background mask is going to be created using the instances,
                # we need to invert the operations
                if dilate:
                    m = binary_erosion(m.astype(np.uint8), footprint=dil_k).astype(np.uint8)
                if erode:
                    m = binary_dilation(m.astype(np.uint8), footprint=er_k).astype(np.uint8)
                mask[m > 0] = 1
        new_mask[..., mode.index("B")] = mask

    # ---------- P (central part) ----------
    if "P" in mode:
        p_opts = channel_extra_opts.get("P", {})
        p_type = p_opts.get("type", "centroid")
        p_dil = p_opts.get("dilation", 1)
        p_ero = p_opts.get("erosion", 1)
        p_out = np.zeros_like(fg_mask, dtype=np.uint8)

        if p_type == "skeleton":
            for i, lb in tqdm(enumerate(instances), disable=disable_tqdm):
                slc = slice_from_props(props_tbl, i, vol.ndim)
                sub = (vol[slc] == lb)
                sk = skeletonize(sub)
                p_out[slc] += sk
        else:
            com_list = center_of_mass(fg_mask, labels=vol, index=instances)
            # Mark each centroid (guard against rounding outside bounds)
            if p_out.ndim == 2:
                H, W = p_out.shape
                for cy, cx in com_list:
                    y = int(round(cy))
                    x = int(round(cx))
                    if 0 <= y < H and 0 <= x < W:
                        p_out[y, x] = 1
            elif p_out.ndim == 3:
                Z, Y, X = p_out.shape
                for cz, cy, cx in com_list:
                    z = int(round(cz))
                    y = int(round(cy))
                    x = int(round(cx))
                    if 0 <= z < Z and 0 <= y < Y and 0 <= x < X:
                        p_out[z, y, x] = 1
            else:
                raise ValueError(f"Unsupported ndim {p_out.ndim} for P[type='centroid']")

        # Optional dilation (in pixels / voxels)
        if (isinstance(p_dil, int) and p_dil > 0) or (isinstance(p_dil, list) and any([x for x in p_dil if x > 0])):
            p_dil = [p_dil,] * p_out.ndim if isinstance(p_dil, int) else p_dil
            p_out = binary_dilation(p_out, footprint=generate_ellipse_footprint(p_dil)).astype(np.uint8)

        # Optional erosion (in pixels / voxels)
        if (isinstance(p_ero, int) and p_ero > 0) or (isinstance(p_ero, list) and any([x for x in p_ero if x > 0])):
            p_ero = [p_ero,] * p_out.ndim if isinstance(p_ero, int) else p_ero
            p_out = binary_erosion(p_out, footprint=generate_ellipse_footprint(p_ero)).astype(np.uint8)

        # Write the channel
        new_mask[..., mode.index("P")] = p_out

    # ---------- C (contours) ----------
    if "C" in mode:
        c_mode = channel_extra_opts.get("C", {}).get("mode", "thick")
        if c_mode == "dense":
            # synthetic "dense" edges: dilate FG and XOR with FG to thicken borders on both sides
            fg = fg_mask
            if fg.ndim == 2:
                rim = binary_dilation(fg, disk(1)).astype(np.uint8) ^ fg
                new_mask[..., mode.index("C")] = rim
            else:
                # 3D: processed slice by slice along the first axis
                out = np.zeros_like(fg)
                for j in range(fg.shape[0]):
                    out[j] = (binary_dilation(fg[j], disk(1)).astype(np.uint8) ^ fg[j])
                new_mask[..., mode.index("C")] = out
        else:
            # valid skimage modes: inner|outer|thick|subpixel
            new_mask[..., mode.index("C")] = find_boundaries(vol, mode=c_mode).astype(np.uint8)

    # ---------- Dc (distance to center/skeleton) ----------
    if "Dc" in mode:
        dc_type = channel_extra_opts.get("Dc", {}).get("type", "centroid")
        dc_channel = np.zeros_like(vol, dtype=np.float32)
        if vol.ndim == 3:
            cz_tab = props_tbl['centroid-0']
            cy_tab = props_tbl['centroid-1']
            cx_tab = props_tbl['centroid-2']
        else:
            cy_tab = props_tbl['centroid-0']
            cx_tab = props_tbl['centroid-1']

        for i, lb in tqdm(enumerate(instances), disable=disable_tqdm):
            slc = slice_from_props(props_tbl, i, vol.ndim)
            sub = (vol[slc] == lb)
            if not sub.any():
                continue

            if dc_type == "skeleton":
                sk = skeletonize(sub)
                dist_to_sk = edt.edt(~sk, anisotropy=resolution, parallel=-1)
                dc_channel[slc][sub] = dist_to_sk[sub]
            else:
                if vol.ndim == 3:
                    cz = float(cz_tab[i])
                    cy = float(cy_tab[i])
                    cx = float(cx_tab[i])
                    zz, yy, xx = np.ogrid[slc[0], slc[1], slc[2]]  # grids in global coords
                    dist = np.sqrt((zz - cz)**2 + (yy - cy)**2 + (xx - cx)**2)
                    dc_channel[slc][sub] = dist[sub]
                else:
                    cy = float(cy_tab[i])
                    cx = float(cx_tab[i])
                    yy, xx = np.ogrid[slc[0], slc[1]]
                    dist = np.sqrt((yy - cy)**2 + (xx - cx)**2)
                    dc_channel[slc][sub] = dist[sub]

        assert isinstance(dc_channel, np.ndarray), "Expected dc_channel to be a numpy array"

        # Normalization
        if channel_extra_opts.get("Dc", {}).get("norm", False):
            dc_channel = norm_channel(
                dc_channel,
                vol,
                instances,
            )
        new_mask[..., mode.index("Dc")] = dc_channel

    # ---------- Db (distance to boundary) ----------
    if "Db" in mode:
        db_channel = edt.edt(vol, anisotropy=resolution, parallel=-1)
        assert isinstance(db_channel, np.ndarray), "Expected db to be a numpy array"

        # Normalization
        val_type = channel_extra_opts.get("Db", {}).get("val_type", "norm")
        if val_type in ["norm", "discretize"]:
            db_channel = norm_channel(
                db_channel,
                vol,
                instances,
            )
        if val_type == "discretize":
            db_dis_bin_size = channel_extra_opts.get("Db", {}).get("bin_size", 0.1)
            db_dis_K = int(round(1.0 / db_dis_bin_size))  # 10 with the default bin size
            db_channel = np.clip(db_channel, 0.0, 1.0)
            fg = fg_mask.astype(bool)
            # bin index in [0, K-1]
            bin_idx = np.floor(db_channel / db_dis_bin_size).astype(np.uint8)
            bin_idx = np.clip(bin_idx, 0, db_dis_K - 1)
            labels = np.zeros(db_channel.shape, dtype=np.uint8)  # background = 0
            labels[fg] = bin_idx[fg] + 1  # foreground bins = 1..K
            db_channel = labels
        new_mask[..., mode.index("Db")] = db_channel

    # ---------- Dn (distance to neighbor) ----------
    if "Dn" in mode:
        dn_opts = channel_extra_opts.get("Dn", {})
        dn_norm = bool(dn_opts.get("norm", False))
        power = float(dn_opts.get("decline_power", 3.0))
        closing_size = int(dn_opts.get("closing_size", 3))  # neighborhood size for grey closing
        dn_channel = np.zeros_like(vol, dtype=np.float32)

        # Mask to remember which cell pixels belong to cells that have at least one OTHER instance
        has_neighbor_px = np.zeros_like(vol, dtype=bool)
        for lab in tqdm(instances, disable=disable_tqdm):
            cur = (vol == lab)
            if not np.any(cur):
                continue
            other_instances = fg_mask & (~cur)
            if not np.any(other_instances):
                # No other labeled object anywhere -> keep this cell at 0 (suppressed)
                continue
            has_neighbor_px[cur] = True

            # Per paper: (selected cell ∪ background) = 1; "other cells" = 0 (distance to other cells)
            fg = (bg_mask | cur)
            d = edt.edt(fg, anisotropy=resolution, parallel=-1).astype(np.float32)
            if dn_norm:
                # Paper path: cut->normalize [0,1]->invert (1 - ..)
                d_cell = d[cur].copy()
                m = d_cell.max()
                if m > 0:
                    d_cell /= m
                else:
                    d_cell.fill(0.0)
                dn_channel[cur] = 1.0 - d_cell
            else:
                # Store raw distances for now (we'll handle per-image norm or unnormalized inversion later)
                dn_channel[cur] = d[cur]

        invert_mask = has_neighbor_px  # operate only on cells that actually have neighbors

        # --- Unnormalized but still inverted ---
        if not dn_norm:
            if np.any(invert_mask):
                M = float(dn_channel[invert_mask].max())
                if M > 0:
                    dn_channel[invert_mask] = M - dn_channel[invert_mask]
            # background & isolated cells remain 0

        # Grayscale closing after merging & inversion
        if closing_size > 0:
            size = (closing_size,) * dn_channel.ndim
            dn_channel = grey_closing(dn_channel, size=size)

        if power is not None and power != 1.0:
            if dn_norm:
                # values in [0,1] already → direct power
                dn_channel = np.power(dn_channel, power, dtype=np.float32)
            else:
                # unnormalized: temporarily normalize on the meaningful support, power, then unnormalize
                if np.any(invert_mask):
                    M2 = float(dn_channel[invert_mask].max())
                    if M2 > 0:
                        tmp = dn_channel[invert_mask] / M2
                        tmp = np.power(tmp, power, dtype=np.float32)
                        dn_channel[invert_mask] = tmp * M2
        new_mask[..., mode.index("Dn")] = dn_channel.astype(np.float32)

    # ---------- D (signed distance, global) ----------
    if "D" in mode:
        alpha = channel_extra_opts.get("D", {}).get("alpha", 1.0)
        beta = channel_extra_opts.get("D", {}).get("beta", 1.0)

        # 1) Signed distance
        sdist = (
            edt.edt(fg_mask, anisotropy=resolution, parallel=-1) / alpha
            - edt.edt(bg_mask, anisotropy=resolution, parallel=-1) / beta
        )
        assert isinstance(sdist, np.ndarray), "Expected sdist to be a numpy array"

        # 2) Map GT to [-1, 1] with tanh (COSEM-style: "Whole-cell organelle segmentation in volume electron microscopy")
        tanh_on = channel_extra_opts.get("D", {}).get("norm", True)
        if tanh_on:
            sdist = np.tanh(sdist)
        new_mask[..., mode.index("D")] = sdist

    # ---------- H / V / Z (horizontal/vertical/depth channels) ----------
    if "Z" in mode:
        new_mask[..., mode.index("Z")] = hv_channels[..., 0]
    if "V" in mode:
        ch_pos = 0 if new_mask[..., mode.index("V")].ndim == 2 else 1
        new_mask[..., mode.index("V")] = hv_channels[..., ch_pos]
    if "H" in mode:
        ch_pos = 1 if new_mask[..., mode.index("H")].ndim == 2 else 2
        new_mask[..., mode.index("H")] = hv_channels[..., ch_pos]

    # ---------- Gv / Gh / Gz (flow-like channels) ----------
    if 'Gv' in mode:
        niter = channel_extra_opts.get("Gv", {}).get("niter", 200)
        gtype = channel_extra_opts.get("Gv", {}).get("gradient_type", "cellpose")

        Gv = np.zeros_like(vol, dtype=np.float32)
        Gh = np.zeros_like(vol, dtype=np.float32)
        # Always bound so the check before writing 'Gz' below works for 2D data too
        Gz = np.zeros_like(vol, dtype=np.float32) if vol.ndim == 3 else None

        if gtype == "omnipose":
            # 1. Calculate Distance Transform (The "Potential Field")
            # Must be computed per-cell so each pixel holds its distance to its
            # own cell boundary. Using the full multi-label vol as foreground
            # merges touching cells into one region, giving wrong EDT values at
            # shared interfaces.
            dist_field = np.zeros_like(vol, dtype=np.float32)
            for lb in instances:
                cell_mask = (vol == lb)
                dist_field[cell_mask] = edt.edt(cell_mask, anisotropy=resolution, parallel=-1)[cell_mask]

            # 2. Calculate Gradients of the Distance Field
            # np.gradient returns [dz, dy, dx] for 3D or [dy, dx] for 2D
            # Pass voxel spacing so gradients are physically correct for anisotropic data.
            if vol.ndim == 3:
                grads = np.gradient(dist_field, resolution[0], resolution[1], resolution[2])
            else:
                grads = np.gradient(dist_field, resolution[1], resolution[2])

            # 3. Normalize the flows
            # Omnipose uses normalized gradients as direction vectors
            mag = np.sqrt(sum(g**2 for g in grads) + 1e-6)
            grads = [g / mag for g in grads]

            # 4. Apply mask and save to global arrays
            # Only write within foreground to leave background at 0.
            # Negative indices pick dy/dx in both the 2D and the 3D gradient lists.
            fg = fg_mask > 0
            if vol.ndim == 3:
                Gz[fg] = grads[-3][fg]
            Gv[fg] = grads[-2][fg]
            Gh[fg] = grads[-1][fg]
        else:
            for i, lb in tqdm(enumerate(instances), disable=disable_tqdm):
                slc = slice_from_props(props_tbl, i, vol.ndim)
                mask = (vol[slc] == lb)

                # 1. Define the center (source of the heat)
                # Cellpose uses the mask pixel closest to the continuous centroid,
                # not the median of sorted coordinates (which can differ for irregular shapes).
                coords = np.nonzero(mask)
                centroid = np.array([c.mean() for c in coords])  # continuous centroid
                coord_stack = np.stack(coords, axis=1).astype(np.float32)  # (N, ndim)
                idx = int(np.argmin(np.sum((coord_stack - centroid) ** 2, axis=1)))
                centers = tuple(int(c[idx]) for c in coords)

                # 2. Iterative Diffusion Process
                # We initialize the center with a value and let it diffuse strictly within the mask
                heat = np.zeros_like(mask, dtype=np.float32)
                heat[centers] = 1.0

                # Pre-calculate diffusion normalization
                # 2D (4 neighbors) -> 0.25 | 3D (6 neighbors) -> 0.1666...
                diff_coeff = 1.0 / (2 * vol.ndim)

                for _ in range(niter):
                    # Accumulate neighbor values with zero-boundary conditions.
                    # np.roll wraps around bounding-box edges (periodic BC), which
                    # is wrong; explicit slice-based shifts give Dirichlet BC (0 outside).
                    neighbor_sum = np.zeros_like(heat)
                    for axis in range(heat.ndim):
                        slc_src = [slice(None)] * heat.ndim
                        slc_dst = [slice(None)] * heat.ndim
                        slc_src[axis] = slice(None, -1)  # value at index i
                        slc_dst[axis] = slice(1, None)   # goes to neighbor i+1
                        neighbor_sum[tuple(slc_dst)] += heat[tuple(slc_src)]
                        slc_src[axis] = slice(1, None)   # value at index i+1
                        slc_dst[axis] = slice(None, -1)  # goes to neighbor i
                        neighbor_sum[tuple(slc_dst)] += heat[tuple(slc_src)]

                    new_heat = neighbor_sum * diff_coeff
                    heat[coords] = new_heat[coords]
                    # Keep the source hot
                    heat[centers] = 1.0

                # 3. Calculate Gradients
                # np.gradient needs at least 2 elements per axis. Skip cells that
                # are only 1 pixel wide along any axis (their flow stays 0).
                if any(s < 2 for s in heat.shape):
                    continue

                # np.gradient returns a list of arrays: [d0, d1, d2] -> [dz, dy, dx] in 3D
                # Pass voxel spacing so anisotropic data produces correct flow directions.
                if vol.ndim == 3:
                    grads = np.gradient(heat, resolution[0], resolution[1], resolution[2])
                else:
                    grads = np.gradient(heat, resolution[1], resolution[2])

                # 4. Normalize the flows
                mag = np.sqrt(sum(g**2 for g in grads) + 1e-6)
                grads = [g / mag for g in grads]

                # 5. Apply mask and save to global arrays
                if vol.ndim == 3:
                    # grads[0]=dz, grads[1]=dy, grads[2]=dx
                    Gz[slc][mask] = grads[0][mask]
                    Gv[slc][mask] = grads[1][mask]
                    Gh[slc][mask] = grads[2][mask]
                else:
                    # grads[0]=dy, grads[1]=dx
                    Gv[slc][mask] = grads[0][mask]
                    Gh[slc][mask] = grads[1][mask]

        # Map back to the new_mask channels
        if "Gz" in mode and Gz is not None:
            new_mask[..., mode.index("Gz")] = Gz
        new_mask[..., mode.index("Gv")] = Gv
        new_mask[..., mode.index("Gh")] = Gh

    # ---------- T (touching area) ----------
    if "T" in mode:
        new_mask[..., mode.index("T")] = touching_mask_nd(
            vol,
            connectivity=new_mask[..., mode.index("T")].ndim
        )

    # ---------- A (affinities) ----------
    if "A" in mode:
        ins_vol = vol
        wb = int(channel_extra_opts["A"].get("widen_borders", 1))
        if wb:
            ins_vol = seg_widen_border(vol, tsz_h=wb)
        k = 0
        for zaff, yaff, xaff in zip(
            channel_extra_opts["A"].get("z_affinities", []),
            channel_extra_opts["A"].get("y_affinities", []),
            channel_extra_opts["A"].get("x_affinities", []),
        ):
            affs = seg2aff_pni(ins_vol, dz=zaff, dy=yaff, dx=xaff, dtype=dtype)  # shape: (n_affs, Z, Y, X)
            affs = np.transpose(affs, (1, 2, 3, 0))  # shape: (Z, Y, X, n_affs)
            # NOTE(review): affinities are written starting at channel 0 whatever the position
            # of "A" in mode - confirm "A" is always the first/only channel requested.
            new_mask[..., k*3:(k+1)*3] = affs
            k += 1

    # ---------- R (radial distances) ----------
    if "R" in mode:
        r_opts = channel_extra_opts.get("R", {})
        ndim = 2 if new_mask[..., mode.index("R")].ndim == 2 else 3
        nrays = int(r_opts.get("nrays", 32 if ndim == 2 else 96))
        rays = generate_rays(n_rays=nrays, ndim=ndim).astype(np.float32)
        spacing = None if new_mask[..., mode.index("R")].ndim == 2 else resolution
        new_mask[..., mode.index("R"):mode.index("R")+nrays] = radial_distances(vol, rays, spacing=spacing)

    # ---------- E (Embeddings) ----------
    # Here we only use E_offset as extra target for the embeddings branch
    if "E_offset" in mode:
        new_mask[..., mode.index("E_offset")] = vol.copy()

    if "We" in mode:
        mask_to_use = vol
        if "F" in mode:
            mask_to_use = new_mask[..., mode.index("F")]
        elif "B" in mode:
            mask_to_use = new_mask[..., mode.index("B")]
        new_mask[..., mode.index("We")] = unet_border_weight_map(mask_to_use, w0=10.0, sigma=5.0, resolution=resolution)

    # Binarize foreground channel at this point and not before because
    # they need to be used in We channel creation
    if "F" in mode:
        new_mask[..., mode.index("F")] = new_mask[..., mode.index("F")] > 0

    # ---------- M (Legacy mask used in CartoCell) ----------
    if "M" in mode:
        # Binary mask = F + C
        f_ch = new_mask[..., mode.index("F")]
        c_ch = new_mask[..., mode.index("C")]
        new_mask[..., mode.index("M")] = np.clip(f_ch + c_ch, 0, 1).astype(np.uint8)

    # Save examples of each channel
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        suffixes = {
            "B": "_background.tif",
            "F": "_foreground.tif",
            "P": "_central_part.tif",
            "C": "_contour.tif",
            "H": "_horizontal_distance.tif",
            "V": "_vertical_distance.tif",
            "Z": "_z_distance.tif",
            "Gv": "_vertical_flow.tif",
            "Gh": "_horizontal_flow.tif",
            "Gz": "_z_flow.tif",
            "Db": "_distance_to_border.tif",
            "Dc": "_distance_to_center.tif",
            "Dn": "_distance_to_neighbor.tif",
            "D": "_distance.tif",
            "R": "_radial_distances.tif",
            "T": "_touching.tif",
            "A": "_affinity.tif",
            "E_offset": "_embedding_instances.tif",
            "We": "_border_weights.tif",
            "M": "_CartoCell_M_channel.tif",
        }
        for j, mod in enumerate(mode):
            if mod in ["E_sigma", "E_seediness"]:
                continue  # no channel was created for these
            if mod not in suffixes:
                raise ValueError("Unknown channel type: {}".format(mod))

            aux = new_mask[..., j]
            aux = np.expand_dims(np.expand_dims(aux, -1), 0)
            # Use the whole suffix: indexing it with j yielded a single character
            save_tif(aux, save_dir, filenames=["vol" + suffixes[mod]], verbose=False)

        save_tif(
            np.expand_dims(instance_labels, 0),
            save_dir,
            filenames=["vol_y.tif"],
            verbose=False,
        )

    return new_mask
[docs]
def norm_channel(channel: NDArray, vol: NDArray, instances: list[int]) -> NDArray:
    """
    Min-max rescale a channel independently inside every instance.

    Parameters
    ----------
    channel : NDArray
        Values to rescale (e.g. a distance channel).

    vol : NDArray
        Instance label volume used to select the voxels of each instance. It is used as a
        boolean mask over ``channel``, so it must have the same shape.

    instances : list[int]
        Instance ids to process. The id ``0`` (background) is skipped.

    Returns
    -------
    NDArray
        ``float32`` array shaped like ``channel``. Inside each processed instance the values are
        mapped to ``[0, 1]``; everything else (background, ids not listed, ids not present in
        ``vol``, and instances whose values are all equal) is ``0``.
    """
    result = np.zeros_like(channel, dtype=np.float32)
    for label_id in instances:
        if label_id == 0:
            # Background is never normalized
            continue
        region = vol == label_id
        if not region.any():
            continue
        region_values = channel[region]
        lowest = region_values.min()
        highest = region_values.max()
        if highest == lowest:
            # Constant region: leave the zeros in place to avoid a division by zero
            continue
        result[region] = (region_values - lowest) / (highest - lowest)
    return result
[docs]
def slice_from_props(props_tbl: pd.DataFrame | dict, i: int, ndim: int) -> tuple[slice, ...]:
"""
Get a slice representation from the properties table for a specific instance.
Parameters
----------
props_tbl : pd.DataFrame | dict
The properties table containing region properties.
i : int
The index of the instance in the properties table.
ndim : int
The number of dimensions (2 or 3).
Returns
-------
tuple[slice, ...]
A tuple of slice objects representing the bounding box of the instance.
"""
if ndim == 2:
y0 = int(props_tbl['bbox-0'][i])
x0 = int(props_tbl['bbox-1'][i])
y1 = int(props_tbl['bbox-2'][i])
x1 = int(props_tbl['bbox-3'][i])
return (slice(y0, y1), slice(x0, x1))
elif ndim == 3:
z0 = int(props_tbl['bbox-0'][i])
y0 = int(props_tbl['bbox-1'][i])
x0 = int(props_tbl['bbox-2'][i])
z1 = int(props_tbl['bbox-3'][i])
y1 = int(props_tbl['bbox-4'][i])
x1 = int(props_tbl['bbox-5'][i])
return (slice(z0, z1), slice(y0, y1), slice(x0, x1))
else:
raise ValueError("Only 2D or 3D volumes are supported.")
[docs]
def unet_border_weight_map(
instances: np.ndarray,
w0: float = 10.0,
sigma: float = 5.0,
apply_only_background: bool = True,
resolution: List[int|float] | None = None,
) -> np.ndarray:
"""
U-Net border-aware weight map (Ronneberger et al. 2015) for 2D or 3D labels.
Parameters
----------
instances : np.ndarray, shape (H, W) or (D, H, W), dtype int
0/`background` for background, 1..N (or any ints != background) are instance ids.
w0 : float
Border weight magnitude.
sigma : float
Spatial decay (in same units as resolution).
apply_only_background : bool
If True, apply the exponential term only on background (as in the paper).
resolution : List[int|float] | None
Voxel spacing along each axis (z,y,x) or (y,x). If None, isotropic spacing of 1 is assumed.
Returns
-------
w : np.ndarray, same shape as `instances`, dtype float32
Border weight map.
"""
if instances.ndim not in (2, 3):
raise ValueError(f"`instances` must be 2D or 3D, got shape {instances.shape}")
inst = instances.astype(np.int32, copy=False)
shp = inst.shape
# collect unique instance ids excluding background
ids = np.unique(inst)
ids = ids[ids != 0]
# Special handling when exactly one instance is present:
# treat background as a pseudo-second instance so we still emphasize the object boundary.
if ids.size == 1:
lab = ids[0]
# Distance to the (only) instance: zeros inside the instance
d_obj = edt.edt(inst != lab, anisotropy=resolution, parallel=-1).astype(np.float32, copy=False)
# Distance to background: zeros in background
d_bg = edt.edt(inst != 0, anisotropy=resolution, parallel=-1).astype(np.float32, copy=False)
denom = 2.0 * (sigma ** 2)
w_border = w0 * np.exp(-((d_obj + d_bg) ** 2) / denom, dtype=np.float32)
w_border = w_border.astype(np.float32, copy=False)
if apply_only_background:
w_border *= (inst == 0)
return w_border
# Need at least two distinct instances for the (d1 + d2) term to be meaningful
if ids.size < 2:
return np.zeros(shp, dtype=np.float32)
# Compute distance-to-each-instance via EDT on the complement of that instance
# distances[k, ...] = distance to instance ids[k]
distances = np.empty((ids.size, *shp), dtype=np.float32)
for k, lab in tqdm(enumerate(ids), total=len(ids)):
# edt computes distance to zeros -> pass mask that's zero *inside* the object
# equivalently: distance to the boundary of object `lab`
distances[k] = edt.edt(inst != lab, anisotropy=resolution, parallel=-1)
# nearest and second-nearest distances at each voxel/pixel
d1 = distances.min(axis=0)
d2 = np.partition(distances, 1, axis=0)[1]
# Border emphasis term
denom = 2.0 * (sigma ** 2)
w_border = w0 * np.exp(-((d1 + d2) ** 2) / denom, dtype=np.float32)
w_border = w_border.astype(np.float32, copy=False)
if apply_only_background:
w_border *= (inst == 0)
return w_border
[docs]
def touching_mask_nd(labels: NDArray, connectivity: int = 1) -> NDArray:
    """
    Mark the pixels/voxels of an N-D instance mask that are adjacent to a different instance.

    Parameters
    ----------
    labels : NDArray
        N-D array of instance labels (0 = background, any other value = instance id).

    connectivity : int, optional
        Neighborhood connectivity passed to `generate_binary_structure`.
        1 = 6-neigh for 3D / 4-neigh for 2D,
        2 = 18-neigh for 3D / 8-neigh for 2D,
        3 = 26-neigh for 3D (if ndim==3).

    Returns
    -------
    touch : NDArray
        ``uint8`` mask with 1 where a foreground element has at least one neighbor that belongs
        to a *different* instance. Background elements are always 0; outside the array is
        treated as background.
    """
    # The structuring element is symmetric, so the element under test is the middle entry of
    # the flattened neighborhood that generic_filter hands to the callback
    neighborhood = generate_binary_structure(labels.ndim, connectivity)

    def _touches_other_instance(values):
        own_label = values[values.size // 2]
        if own_label == 0:
            # Background never counts as touching
            return 0
        foreign = (values != 0) & (values != own_label)
        return int(np.any(foreign))

    hits = generic_filter(
        labels,
        _touches_other_instance,
        footprint=neighborhood,
        mode="constant",
        cval=0,
    )
    return hits.astype(np.uint8)
[docs]
def generate_rays(n_rays: int, ndim: int, jitter: bool=False, seed: int=0):
    """
    Generate unit direction vectors in 2D or 3D.

    - 2D: evenly spaced angles on the circle, returned as ``(cos, sin)`` pairs.
    - 3D: Fibonacci-sphere layout, returned as ``(x, y, z)`` triplets.

    Parameters
    ----------
    n_rays : int
        Number of rays to generate.

    ndim : int
        Dimensionality (2 or 3).

    jitter : bool, optional
        Whether to randomly perturb the azimuth and height of the 3D rays (default: False).
        Ignored in 2D.

    seed : int, optional
        Seed of the random generator used for the jitter (default: 0).

    Returns
    -------
    rays : (n_rays, 2) or (n_rays, 3) Numpy array
        ``float32`` unit vectors.

    Raises
    ------
    ValueError
        If ``ndim`` is neither 2 nor 3.
    """
    if ndim not in (2, 3):
        raise ValueError("Only 2D and 3D are supported.")

    if ndim == 2:
        angles = np.linspace(0, 2*np.pi, n_rays, endpoint=False, dtype=np.float32)
        return np.stack([np.cos(angles), np.sin(angles)], axis=1).astype(np.float32)

    # 3D: Fibonacci sphere
    index = np.arange(n_rays, dtype=np.float32)
    golden_ratio = (1 + np.sqrt(5.0)) / 2.0
    height = 1 - 2*(index + 0.5) / n_rays
    azimuth = 2*np.pi*index/golden_ratio

    if jitter:
        rng = np.random.default_rng(seed)
        # Draw order matters for reproducibility: azimuth first, then height
        azimuth += rng.uniform(-np.pi/n_rays, np.pi/n_rays, size=n_rays)
        height += rng.uniform(-1/n_rays, 1/n_rays, size=n_rays)

    # Without jitter the heights already lie strictly inside [-1, 1], so the clip is a no-op
    height = np.clip(height, -1.0, 1.0)
    radius = np.sqrt(np.maximum(0.0, 1 - height*height))

    dirs = np.stack(
        [radius * np.cos(azimuth), radius * np.sin(azimuth), height],
        axis=1,
    ).astype(np.float32)
    dirs /= (np.linalg.norm(dirs, axis=1, keepdims=True) + 1e-12)
    return dirs
[docs]
def radial_distances(
    labels: NDArray,
    rays: NDArray,
    max_dist: Optional[float] = None,
    spacing: Optional[Sequence[float]] = None,
    max_iters: int = 50,
) -> NDArray:
    """
    Compute radial distances from each foreground pixel to the instance boundary along specified rays.

    For every foreground pixel/voxel a ray is marched in unit steps (in index space) until the
    rounded sample position leaves the array or lands on a different label. The travelled length
    is then pulled back by a sub-step correction and converted to physical units.

    Parameters
    ----------
    labels : NDArray
        2D or 3D array of instance labels (0 = background, >0 = instances).

    rays : (n_rays, 2) or (n_rays, 3) Numpy array
        Directions along which to compute distances, expressed in the index order of ``labels``.
        They are re-normalized to unit length internally; the input is not modified.

    max_dist : float, optional
        Maximum distance to cap at. If None, no capping is done.

    spacing : sequence of float, optional
        Physical spacing of the data in each dimension. If None, assumes isotropic spacing of 1.0.

    max_iters : int
        Factor of the safety cap on the number of marching steps per ray: at most
        ``max_iters * (max(labels.shape) + 2)`` steps are taken.

    Returns
    -------
    D : NDArray
        ``float32`` array of shape ``labels.shape + (n_rays,)`` with distances in physical units.
        Background pixels have distance 0 in all rays. A ray that never hits a boundary within
        the step cap also keeps distance 0.
    """
    labels = np.asarray(labels)
    ndim = labels.ndim
    assert rays.ndim == 2 and rays.shape[1] == ndim

    spacing = np.ones(ndim, np.float32) if spacing is None else np.asarray(spacing, np.float32)
    shape = labels.shape
    n_rays = rays.shape[0]

    # Unit rays in index space (astype copies, so the caller's array is left untouched)
    rays_idx = rays.astype(np.float32)
    rays_idx /= np.linalg.norm(rays_idx, axis=1, keepdims=True) + 1e-12

    # Physical length covered by one unit step along each ray
    ray_step_phys = np.linalg.norm(rays_idx * spacing, axis=1)  # shape (n_rays,)

    D = np.zeros(shape + (n_rays,), np.float32)

    # Safety cap on the marching loop; it does not depend on the pixel or the ray
    step_cap = max_iters * (max(shape) + 2)

    for pixel in np.argwhere(labels > 0):
        origin = tuple(int(c) for c in pixel)
        inst_id = int(labels[origin])
        p0 = pixel.astype(np.float32)  # pixel center reference

        for k in range(n_rays):
            u = rays_idx[k]  # unit direction in index space
            x = np.zeros(ndim, np.float32)  # accumulated offset in index space

            for _ in range(step_cap):
                x += u
                # Nearest-neighbor (rounded) sampling of the current position
                sample = tuple(int(v) for v in np.rint(p0 + x))
                out = any(c < 0 or c >= size for c, size in zip(sample, shape))
                changed = (not out) and (labels[sample] != inst_id)
                if not (out or changed):
                    continue

                # Pull back along the ray so that the boundary sits half a pixel before the
                # first foreign sample, measured on the dominant axis of the ray
                max_comp = np.max(np.abs(u)) + 1e-12
                t_corr = 1.0 - 0.5 / max_comp
                x = x - t_corr * u

                # Length in index units, then converted to physical units
                dist_idx = float(np.linalg.norm(x))
                dist = dist_idx * float(ray_step_phys[k])
                if max_dist is not None and dist > max_dist:
                    dist = max_dist

                D[origin + (k,)] = dist
                break

    return D
[docs]
def euler_integration(flow: NDArray, coords: NDArray, n_steps: int = 200, dt: float = 1.0, suppressed: bool = True):
    """
    Euler integration of a flow field starting at ``coords``.

    Parameters
    ----------
    flow : (2, H, W) or (3, D, H, W) Numpy array
        Flow field. Component ``flow[k]`` is the displacement along spatial axis ``k``, i.e.
        (y,x) in 2D or (z,y,x) in 3D.

    coords : (N, 2) or (N, 3) Numpy array
        Starting coordinates in index space, with one column per spatial axis of ``flow``.
        The array is not modified.

    n_steps : int
        Number of integration steps.

    dt : float
        Integration step size.

    suppressed : bool
        Whether to use time-suppressed integration (``dt/(t+1)`` at step ``t``) or a constant ``dt``.

    Returns
    -------
    pos : (N, 2) or (N, 3) Numpy array
        Final ``float`` positions after integration, clipped to the bounds of the flow field.

    Raises
    ------
    ValueError
        If the flow field is not 2D/3D or ``coords`` does not have one column per spatial axis.
    """
    pos = coords.astype(float).copy()

    spatial_shape = flow.shape[1:]
    ndim = len(spatial_shape)
    if ndim not in (2, 3):
        raise ValueError(f"`flow` must have shape (2, H, W) or (3, D, H, W), got {flow.shape}")
    if pos.ndim != 2 or pos.shape[1] != ndim:
        raise ValueError(f"`coords` must have shape (N, {ndim}), got {coords.shape}")

    # Largest valid index on each axis, used to keep the positions inside the field
    upper = np.asarray(spatial_shape) - 1

    for t in range(n_steps):
        # Linearly interpolate every flow component at the current positions
        sample_at = [pos[:, axis] for axis in range(ndim)]
        step = np.stack(
            [map_coordinates(flow[axis], sample_at, order=1, mode="nearest") for axis in range(ndim)],
            axis=1,
        )

        # Suppression factor
        factor = dt / (t + 1) if suppressed else dt
        pos += factor * step

        # Keep inside bounds
        np.clip(pos, 0, upper, out=pos)

    return pos  # final positions for clustering
def _in_bounds(p: np.ndarray, shape_zyx: Tuple[int, int, int]) -> bool:
    """
    Tell whether a point lies inside a volume of the given shape.

    Parameters
    ----------
    p : np.ndarray
        Point coordinates, one value per axis of ``shape_zyx`` (z,y,x for a 3D volume).

    shape_zyx : Tuple[int, int, int]
        Shape of the volume, in the same axis order as ``p``.

    Returns
    -------
    bool
        True when every coordinate satisfies ``0 <= coordinate < size``, False otherwise.
    """
    upper = np.asarray(shape_zyx)
    inside = np.logical_and(p >= 0, p < upper)
    return bool(inside.all())
def _bbox_from_points(points_zyx: np.ndarray, width_zyx: np.ndarray, shape_zyx: Tuple[int, int, int]) -> List[int]:
    """
    Compute a bounding box ``[z0,z1,y0,y1,x0,x1]`` that contains all points in ``points_zyx`` plus a
    ``width_zyx`` margin on both sides, clipped to the bounds of ``shape_zyx``.

    The upper limits are exclusive, so the box can be used directly as ``slice(z0, z1)`` etc.

    Parameters
    ----------
    points_zyx : np.ndarray
        An array of shape (N, 3) containing points in (z,y,x) coordinates.

    width_zyx : np.ndarray
        An array of shape (3,) specifying the margin to add in each dimension (z,y,x).

    shape_zyx : Tuple[int, int, int]
        The shape of the volume in (Z,Y,X), used to keep the bounding box inside the volume.

    Returns
    -------
    List[int]
        A list containing the bounding box coordinates ``[z0,z1,y0,y1,x0,x1]``.
    """
    shape = np.asarray(shape_zyx)
    lo = np.maximum(0, points_zyx.min(axis=0) - width_zyx)
    # +1 makes the upper limit exclusive: without it the farthest point falls outside the box
    # when the margin is 0, and the margin is one voxel short on the high side otherwise
    hi = np.minimum(shape, points_zyx.max(axis=0) + width_zyx + 1)
    return [int(lo[0]), int(hi[0]), int(lo[1]), int(hi[1]), int(lo[2]), int(hi[2])]
def _make_output_array(fname: str, data_shape_zyx, channels: int, zinfo: Dict, dtype_str="float32"):
    """
    Create the on-disk output array for a generated mask, as an HDF5 dataset or a Zarr array.

    Parameters
    ----------
    fname : str
        The filename for the output array. When ``looks_like_hdf5(fname)`` is true an HDF5 file is
        created (its extension is replaced by ``.h5``) with a gzip-compressed ``"data"`` dataset;
        otherwise a Zarr (format 3) array is created at ``fname``. Missing parent directories are
        created.

    data_shape_zyx : Tuple[int, ...]
        Shape of the data. When ``zinfo["axes_order"]`` contains ``"C"`` this shape is indexed with
        the position of ``"C"`` in that string, so it is expected to have one entry per axis of
        ``axes_order`` (NOTE(review): the name suggests (Z,Y,X) only - confirm against callers).

    channels : int
        The number of channels in the output array.

    zinfo : Dict
        A dictionary containing metadata, including the "axes_order" key which specifies the order of
        axes in the output array (e.g. "ZYX" or "ZCYX").

    dtype_str : str
        The data type for the output array (default: "float32").

    Returns
    -------
    mask : Zarr array or HDF5 dataset
        The created output array for the generated mask.

    fid_mask : h5py.File or None
        The open HDF5 file object if an HDF5 dataset was created (the caller must close it), or
        None if a Zarr array was created.

    out_shape : Tuple[int, ...]
        The shape of the output array, including channels.

    out_order : str
        The order of axes in the output array: ``zinfo["axes_order"]``, with "C" appended when it
        was not already included.

    c_pos : int
        The position of the channel axis in the output array; ``-1`` when the channel axis was
        appended as the last axis.
    """
    # os.makedirs("") raises, so only create the parent when the path actually has one
    parent_dir = os.path.dirname(fname)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    out_data_shape = np.array(data_shape_zyx, dtype=int)

    axes = zinfo["axes_order"]
    if "C" not in axes:
        # No channel axis in the source layout: append it at the end
        out_shape = tuple(out_data_shape) + (channels,)
        out_order = axes + "C"
        c_pos = -1
    else:
        # Reuse the existing channel axis, overriding its size
        out_shape = out_data_shape.copy()
        out_shape[axes.index("C")] = channels
        out_shape = tuple(int(s) for s in out_shape)
        out_order = axes
        c_pos = axes.index("C")

    is_h5 = looks_like_hdf5(fname)
    fid_mask = None

    if is_h5:
        base, _ = os.path.splitext(fname)
        fid_mask = h5py.File(base + ".h5", "w")

        ds_kwargs = {
            "shape": out_shape,
            "dtype": dtype_str,
            "chunks": pick_chunks(out_shape, dtype_str, target_mb=4.0),
            "compression": "gzip",
            "compression_opts": 4,
            "shuffle": True,
        }
        mask = fid_mask.create_dataset("data", **ds_kwargs)
    else:
        mask = zarr.open(fname, mode="w", shape=out_shape, dtype=dtype_str, zarr_format=3)

    return mask, fid_mask, out_shape, out_order, c_pos
def _clip_slices(z, y, x, hz, hy, hx, Z, Y, X):
    """
    Build the slices of a window centered on ``(z, y, x)`` and clipped to a ``(Z, Y, X)`` volume.

    Parameters
    ----------
    z, y, x : int
        Center of the window.

    hz, hy, hx : int
        Half-width of the window on each axis (the window spans ``center - half`` to
        ``center + half``, both included, before clipping).

    Z, Y, X : int
        Size of the volume on each axis.

    Returns
    -------
    tuple[slice, slice, slice]
        Slices for the z, y and x axes.
    """
    def _window(center, half, size):
        # Inclusive [center - half, center + half] range expressed as a half-open slice
        return slice(max(0, center - half), min(size, center + half + 1))

    return _window(z, hz, Z), _window(y, hy, Y), _window(x, hx, X)
def _beam_score(smoothed, z, y, x, beam_hw, score_mode="p10"):
    """
    Score a position of a 3D volume from the values in a small window around it.

    Parameters
    ----------
    smoothed : np.ndarray
        3D volume to sample.

    z, y, x : int
        Center of the window.

    beam_hw : sequence of 3 ints
        Half-width of the window in (z, y, x).

    score_mode : str
        ``"min"`` returns the minimum of the window; ``"p10"`` returns its 10th percentile.

    Returns
    -------
    float
        The window score, or ``np.inf`` when the clipped window is empty.

    Raises
    ------
    ValueError
        If ``score_mode`` is unknown (only checked when the window is not empty).
    """
    half_z, half_y, half_x = beam_hw
    depth, height, width = smoothed.shape
    window = smoothed[_clip_slices(z, y, x, half_z, half_y, half_x, depth, height, width)]

    if window.size == 0:
        return np.inf
    if score_mode == "min":
        return float(window.min())
    if score_mode == "p10":
        return float(np.percentile(window, 10))
    raise ValueError(f"Unknown score_mode: {score_mode}")
[docs]
def _read_synapse_annotations(filename: str, zarr_data_information: Dict, load_raw: bool = False):
    """
    Read the synapse annotations (ids, partners, locations, resolution) of one Zarr/H5 file.

    Every annotation is copied into memory, and the HDF5 handles opened to read them are closed
    before returning, also when an error occurs. When ``load_raw`` is True the raw data is opened
    as well and its handle is returned still open: the caller is responsible for closing it.

    Returns
    -------
    tuple
        ``(ids, partners, locations, resolution, data_file, raw_data)``; the last two are ``None``
        when ``load_raw`` is False.

    Raises
    ------
    ValueError
        If the node at ``resolution_path`` has no ``'resolution'`` attribute.
    """
    handles = []
    data_file, raw_data = None, None
    try:
        file, ids = read_chunked_nested_data(filename, zarr_data_information["id_path"])
        handles.append(file)
        ids = list(np.array(ids))

        file, partners = read_chunked_nested_data(filename, zarr_data_information["partners_path"])
        handles.append(file)
        partners = np.array(partners)

        file, locations = read_chunked_nested_data(filename, zarr_data_information["locations_path"])
        handles.append(file)
        locations = np.array(locations)

        file, resolution = read_chunked_nested_data(filename, zarr_data_information["resolution_path"])
        handles.append(file)
        try:
            resolution = resolution.attrs["resolution"]
        except (KeyError, AttributeError) as err:
            raise ValueError(
                "There is no 'resolution' attribute in '{}'. Add it like: data['{}'].attrs['resolution'] = (8,8,8)".format(
                    zarr_data_information["resolution_path"], zarr_data_information["resolution_path"]
                )
            ) from err

        if load_raw:
            data_file, raw_data = read_chunked_nested_data(filename, zarr_data_information["raw_data_path"])
    finally:
        for handle in handles:
            if isinstance(handle, h5py.File):
                handle.close()

    return ids, partners, locations, resolution, data_file, raw_data


def synapse_channel_creation(
    data_info: Dict,
    zarr_data_information: Dict,
    savepath: str,
    mode: List[str] = ["F_pre", "F_post"],
    channel_extra_opts: Dict[str, Dict] = {},
    verbose: bool = False,
):
    """
    Create different channels that represent a synapse segmentation problem to train an instance segmentation problem.

    This function is only prepared to read an H5/Zarr file that follows `CREMI data format <https://cremi.org/data/>`__.
    One output file per input file is written into ``savepath``; the files are distributed among
    the running processes by rank.

    Parameters
    ----------
    data_info : dict
        All patches that can be extracted from all the Zarr/H5 samples. It is accessed by integer
        position (``data_info[i]`` for ``i`` in ``range(len(data_info))``) and each entry is expected
        to contain:
            * ``"filepath"``: path to the file where the patch was extracted.
            * ``"full_shape"``: shape of the data within the file where the patch was extracted.

    zarr_data_information : dict
        Information when using Zarr/H5 files. Assumes that the H5/Zarr files contain the information according
        `CREMI data format <https://cremi.org/data/>`__. The following keys are read:
            * ``"raw_data_path"``: path within the file where the raw data is stored (only used in the
              ``["F_cleft"]`` mode). Reference in CREMI: ``volumes/raw``
            * ``"axes_order"``: order of the axes in the file. E.g. "ZYX" or "ZCYX".
            * ``"id_path"``: path within the file where the ``ids`` are stored.
              Reference in CREMI: ``annotations/ids``
            * ``"partners_path"``: path within the file where ``partners`` is stored.
              Reference in CREMI: ``annotations/partners``
            * ``"locations_path"``: path within the file where ``locations`` is stored.
              Reference in CREMI: ``annotations/locations``
            * ``"resolution_path"``: path within the file of the node whose ``"resolution"`` attribute
              holds the data resolution.

    savepath : str
        Path to save the data created.

    mode : List, optional
        Channels to create. Supported combinations (in any order):
            * ``["F_pre", "F_post"]``: binary pre/post-synaptic site channels (``uint8``).
            * ``["F_post", "Z", "V", "H"]``: post-synaptic sites plus distance channels (``float32``).
            * ``["F_cleft"]``: binary cleft channel seeded at the first dark region found on the
              segment that joins each post-synaptic site to its pre-synaptic site (``uint8``).
            * ``["F_post"]``: binary post-synaptic site channel (``uint8``).

    channel_extra_opts : dict, optional
        Extra options for specific channels. Keys read are:
            * ``"F_pre"``: ``"dilation"``, list of 3 ints in z,y,x (default: [1,3,3]). Not used in the
              ``["F_post", "Z", "V", "H"]`` mode.
            * ``"F_post"``: ``"dilation"``, list of 3 ints in z,y,x (default: [1,3,3]).
            * ``"H"``: ``"dilation"``, pre-synaptic dilation used in the ``["F_post", "Z", "V", "H"]``
              mode (default: [3,25,25]).
            * ``"H"``, ``"V"``, ``"Z"``: ``"norm"``, whether to normalize the distance channels
              (default: True). Normalization is enabled if any of the three enables it.
            * ``"F_cleft"``: ``"dilation"`` (default: [1,3,3]), ``"search_dilation"`` (default: [1,5,5]),
              ``"n_samples"`` (default: 51), ``"t_range"`` (default: (0.15, 0.85)), ``"beam_halfwidth"``
              (default: [0,3,3]), ``"drop_rel"`` (default: 0.20), ``"drop_abs"`` (default: 10.0),
              ``"baseline_frac"`` (default: 0.15) and ``"min_persist"`` (default: 2).

    verbose : bool, optional
        Whether to print warnings about out-of-bounds synaptic points (default: False).

    Returns
    -------
    None
        The generated masks are written to disk; nothing is returned.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported combinations, if the resolution attribute is
        missing, or if a post-synaptic point falls outside its patch.
    """
    # NOTE: the mutable defaults of `mode` and `channel_extra_opts` are kept for interface
    # compatibility; they are only read, never modified, inside this function.

    # -------------------------------------------------------
    # 1. Determine Channel Count based on Mode
    # -------------------------------------------------------
    selected_mode = ""
    if all(ch in mode for ch in ["F_pre", "F_post"]) and len(mode) == 2:
        channels = 2
        dtype_str = "uint8"
        selected_mode = "simpsyn"
    elif all(ch in mode for ch in ["F_post", "Z", "V", "H"]) and len(mode) == 4:
        channels = 4
        dtype_str = "float32"
        norm = any(channel_extra_opts.get(k, {}).get("norm", True) for k in ["Z", "V", "H"])
        selected_mode = "synful"
    elif all(ch in mode for ch in ["F_cleft"]) and len(mode) == 1:
        channels = 1
        dtype_str = "uint8"
        selected_mode = "cleft"
        cleft_opts = channel_extra_opts.get("F_cleft", {})
        cleft_dilation = cleft_opts.get("dilation", [1, 3, 3])
        cleft_footprint = generate_ellipse_footprint(cleft_dilation)
        cleft_search_dilation = cleft_opts.get("search_dilation", [1, 5, 5])
        cleft_n_samples = int(cleft_opts.get("n_samples", 51))
        cleft_t_range = cleft_opts.get("t_range", (0.15, 0.85))
        beam_halfwidth = cleft_opts.get("beam_halfwidth", [0, 3, 3])  # [hz, hy, hx]
        drop_rel = float(cleft_opts.get("drop_rel", 0.20))  # 20% drop from baseline
        drop_abs = float(cleft_opts.get("drop_abs", 10.0))  # absolute drop in intensity units
        baseline_frac = float(cleft_opts.get("baseline_frac", 0.15))  # first 15% of path = baseline
        min_persist = int(cleft_opts.get("min_persist", 2))  # must stay dark for N steps

        # Values that do not depend on the synapse being processed
        sz, sy, sx = (int(v) for v in cleft_search_dilation)
        smooth_size = (2 * sz + 1, 2 * sy + 1, 2 * sx + 1)
        t_start, t_end = cleft_t_range
        ts = np.linspace(t_start, t_end, cleft_n_samples, dtype=np.float32)
        beam_hw = [int(v) for v in beam_halfwidth]
    elif all(ch in mode for ch in ["F_post"]) and len(mode) == 1:
        channels = 1
        dtype_str = "uint8"
        selected_mode = "F_post_only"
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    if selected_mode == "synful":
        presite_dilation = channel_extra_opts.get("H", {}).get("dilation", [3, 25, 25])
    else:
        presite_dilation = channel_extra_opts.get("F_pre", {}).get("dilation", [1, 3, 3])
    postsite_dilation = channel_extra_opts.get("F_post", {}).get("dilation", [1, 3, 3])

    # Structuring elements used to dilate the pre/post sites
    pre_footprint = generate_ellipse_footprint(presite_dilation)
    post_footprint = generate_ellipse_footprint(postsite_dilation)

    # Unique files (and their shapes), in order of first appearance
    unique_files = []
    unique_shapes = []
    for i in range(len(data_info)):
        if data_info[i]["filepath"] not in unique_files:
            unique_files.append(data_info[i]["filepath"])
            unique_shapes.append(data_info[i]["full_shape"])

    # Each process handles its own subset of files
    rank = get_rank()
    world_size = get_world_size()
    it = range(rank, len(unique_files), world_size)

    # Margin around the synaptic points: the largest dilation on each axis
    width = np.array([max(a, b) for a, b in zip(presite_dilation, postsite_dilation)])

    print("Collecting all pre/post-synaptic points")
    for idx in tqdm(it, disable=not is_main_process()):
        filename, data_shape = unique_files[idx], tuple(unique_shapes[idx])
        print(f"Processing file: {filename}")

        data_file = None  # raw data handle (cleft mode only)
        fid_mask = None  # output HDF5 handle, if any
        try:
            # Take all the information within the dataset
            ids, partners, locations, resolution, data_file, raw_data = _read_synapse_annotations(
                filename, zarr_data_information, load_raw=(selected_mode == "cleft")
            )

            id_to_pos = {sid: i for i, sid in enumerate(ids)}
            shape_zyx = tuple(int(x) for x in data_shape)

            pre_post_points = {}  # pre_loc tuple -> list[post_loc tuple]
            pre_seen, post_seen = set(), set()
            pre_missed_ids, post_missed_ids = set(), set()

            for pre_id, post_id in tqdm(partners, disable=not is_main_process()):
                pre_idx = id_to_pos.get(pre_id)
                post_idx = id_to_pos.get(post_id)
                if pre_idx is None or post_idx is None:
                    # inconsistent annotation; skip quietly
                    continue

                # Physical locations -> voxel coordinates
                pre_loc = (locations[pre_idx] // resolution).astype(int)
                post_loc = (locations[post_idx] // resolution).astype(int)

                pre_ok = _in_bounds(pre_loc, shape_zyx)
                post_ok = _in_bounds(post_loc, shape_zyx)

                if not pre_ok:
                    if verbose:
                        print(f"WARNING: discarding presynaptic point {pre_loc} out of shape: {shape_zyx}")
                    # A set keeps the count unique even if the point has several partners
                    pre_missed_ids.add(pre_id)

                if not post_ok:
                    if verbose:
                        print(f"WARNING: discarding postsynaptic point {post_loc} out of shape: {shape_zyx}")
                    post_missed_ids.add(post_id)

                if pre_ok and post_ok:
                    pre_key = tuple(pre_loc.tolist())
                    pre_post_points.setdefault(pre_key, []).append(tuple(post_loc.tolist()))
                    pre_seen.add(pre_id)
                    post_seen.add(post_id)

            print(f"Total unique pre-synaptic points: {len(pre_seen)}")
            print(f"Total unique post-synaptic points: {len(post_seen)}")
            print(f"Total pre-synaptic points missed: {len(pre_missed_ids)}")
            print(f"Total post-synaptic points missed: {len(post_missed_ids)}\n")

            if not pre_post_points:
                continue

            # output file
            out_fname = os.path.join(savepath, os.path.basename(filename))
            mask, fid_mask, out_shape, out_order, c_pos = _make_output_array(
                out_fname, shape_zyx, channels, zarr_data_information, dtype_str=dtype_str
            )

            if selected_mode == "synful":
                _process_synapses_by_chunks_of_data(
                    channel_mode=mode,
                    selected_mode=selected_mode,
                    resolution=resolution,
                    pre_post_points=pre_post_points,
                    shape_zyx=shape_zyx,
                    mask=mask,
                    fid_mask=fid_mask,
                    out_shape=out_shape,
                    out_order=out_order,
                    c_pos=c_pos,
                    width=width,
                    pre_footprint=pre_footprint,
                    post_footprint=post_footprint,
                    norm=norm,
                )
                continue

            print("Processing synapse pairs . . .")
            for pre_point_global, post_sites in tqdm(pre_post_points.items(), disable=not is_main_process()):
                if not post_sites:
                    continue

                pre_point_global = np.asarray(pre_point_global, dtype=int)
                post_sites_arr = np.asarray(post_sites, dtype=int)

                # Define patch bounding box to load/write
                bbox_points = np.vstack([pre_point_global[None, :], post_sites_arr])
                bbox = _bbox_from_points(bbox_points, width, shape_zyx)
                patch_shape = (bbox[1] - bbox[0], bbox[3] - bbox[2], bbox[5] - bbox[4])
                patch_origin = np.asarray([bbox[0], bbox[2], bbox[4]], dtype=int)

                data_slices = (
                    slice(bbox[0], bbox[1]),
                    slice(bbox[2], bbox[3]),
                    slice(bbox[4], bbox[5]),
                    slice(0, out_shape[c_pos]),
                )
                data_slices = tuple(
                    order_dimensions(data_slices, input_order="ZYXC", output_order=out_order, default_value=0)
                )

                pre_local = pre_point_global - patch_origin

                # -------------------------------------------------------
                # MODE: SimpSyn (F_pre, F_post)
                # -------------------------------------------------------
                if selected_mode == "simpsyn":
                    out_map = np.zeros(patch_shape + (channels,), dtype=np.uint8)

                    # Pre point
                    out_map[
                        max(0, pre_local[0] - 1) : min(pre_local[0] + 1, out_map.shape[0]),
                        pre_local[1],
                        pre_local[2],
                        mode.index("F_pre"),
                    ] = 1

                    # Post points
                    for post_global in post_sites_arr:
                        post_local = post_global - patch_origin
                        if not _in_bounds(post_local, patch_shape):
                            raise ValueError(f"Point {post_local.tolist()} out of shape: {patch_shape}")
                        out_map[
                            max(0, post_local[0] - 1) : min(post_local[0] + 1, out_map.shape[0]),
                            post_local[1],
                            post_local[2],
                            mode.index("F_post"),
                        ] = 1

                    # Dilate each channel
                    for c in range(out_map.shape[-1]):
                        structure = pre_footprint if c == mode.index("F_pre") else post_footprint
                        out_map[..., c] = binary_dilation_scipy(out_map[..., c], iterations=1, structure=structure)

                # -------------------------------------------------------
                # MODE: Cleft (F_cleft)
                # -------------------------------------------------------
                elif selected_mode == "cleft":
                    cleft_pos = 0
                    out_map = np.zeros(patch_shape + (channels,), dtype=np.uint8)

                    # 1) Extract raw patch for cleft processing
                    raw_data_slices = (
                        slice(bbox[0], bbox[1]),
                        slice(bbox[2], bbox[3]),
                        slice(bbox[4], bbox[5]),
                    )
                    raw_data_slices = tuple(
                        order_dimensions(
                            raw_data_slices,
                            input_order="ZYX",
                            output_order=zarr_data_information["axes_order"],
                            default_value=0,
                        )
                    )
                    raw_patch = np.asarray(raw_data[raw_data_slices], dtype=np.uint8)

                    # 2) Pre-smooth with a local-mean filter to make "darkest" more stable
                    smoothed = uniform_filter(raw_patch, size=smooth_size, mode="nearest")

                    # 3) For each (pre, post), find the first dark region along the segment and seed it
                    cleft_seed = np.zeros(patch_shape, dtype=np.uint8)
                    pre_f = pre_local.astype(np.float32)
                    Z, Y, X = patch_shape

                    for post_global in post_sites_arr:
                        post_local = post_global - patch_origin
                        if not _in_bounds(post_local, patch_shape):
                            continue

                        post_f = post_local.astype(np.float32)

                        # Sample points from POST -> PRE (important for "first drop")
                        pts = post_f[None, :] + ts[:, None] * (pre_f[None, :] - post_f[None, :])

                        # Compute beam score along the path (inf where the sample is out of the patch)
                        prof = np.full((len(ts),), np.inf, dtype=np.float32)
                        coords = np.round(pts).astype(int)
                        for s, (z, y, x) in enumerate(coords):
                            if 0 <= z < Z and 0 <= y < Y and 0 <= x < X:
                                prof[s] = _beam_score(smoothed, z, y, x, beam_hw, score_mode="p10")

                        # Baseline = robust bright-ish level near the POST side
                        nb = max(3, int(np.ceil(baseline_frac * len(ts))))
                        base_vals = prof[:nb]
                        base_vals = base_vals[np.isfinite(base_vals)]
                        if base_vals.size == 0:
                            # fallback: midpoint
                            mid = np.round((pre_f + post_f) * 0.5).astype(int)
                            mid = np.clip(mid, [0, 0, 0], np.array(patch_shape) - 1)
                            cleft_seed[tuple(mid.tolist())] = 1
                            continue

                        baseline = float(np.median(base_vals))

                        # Threshold for "first dark region"
                        # Condition: prof <= baseline - max(abs_drop, rel_drop * baseline)
                        thr = baseline - max(drop_abs, drop_rel * baseline)

                        # Find earliest index that crosses threshold and stays low for min_persist steps
                        hit_idx = None
                        for s in range(nb, len(ts)):
                            if prof[s] <= thr:
                                # persist check (hysteresis)
                                s_end = min(len(ts), s + min_persist)
                                if np.all(prof[s:s_end] <= thr):
                                    hit_idx = s
                                    break

                        if hit_idx is None:
                            # Fallback: choose the minimum after the baseline region
                            # (still safer than the global minimum)
                            hit_idx = int(np.nanargmin(prof[nb:len(ts)])) + nb

                        z, y, x = coords[hit_idx]
                        z, y, x = int(np.clip(z, 0, Z - 1)), int(np.clip(y, 0, Y - 1)), int(np.clip(x, 0, X - 1))
                        cleft_seed[z, y, x] = 1

                    # 4) Paint the cleft "ball" using the dilation footprint
                    out_map[..., cleft_pos] = binary_dilation_scipy(
                        cleft_seed, iterations=1, structure=cleft_footprint
                    ).astype(np.uint8)

                # -------------------------------------------------------
                # MODE: F_post only
                # -------------------------------------------------------
                elif selected_mode == "F_post_only":
                    out_map = np.zeros(patch_shape + (channels,), dtype=np.uint8)

                    # Post points only
                    for post_global in post_sites_arr:
                        post_local = post_global - patch_origin
                        if not _in_bounds(post_local, patch_shape):
                            raise ValueError(f"Point {post_local.tolist()} out of shape: {patch_shape}")
                        out_map[
                            max(0, post_local[0] - 1) : min(post_local[0] + 1, out_map.shape[0]),
                            post_local[1],
                            post_local[2],
                            0,
                        ] = 1

                    # Dilate the channel
                    out_map[..., 0] = binary_dilation_scipy(out_map[..., 0], iterations=1, structure=post_footprint)
                else:
                    raise NotImplementedError(f"Mode {selected_mode} not implemented for processing by synapse pairs.")

                # transpose into output order before writing
                current_order = np.arange(out_map.ndim)
                transpose_order = order_dimensions(
                    current_order, input_order="ZYXC", output_order=out_order, default_value=np.nan
                )
                transpose_order = [int(x) for x in transpose_order if not np.isnan(x)]
                patch = out_map.transpose(transpose_order)

                # write only where empty (background check)
                target = mask[data_slices]
                mask[data_slices] = target + patch * (target == 0)
        finally:
            # Release the per-file handles on every exit path (normal end, `continue` or error)
            if isinstance(fid_mask, h5py.File):
                fid_mask.close()
            closer = getattr(data_file, "close", None)
            if callable(closer):
                closer()
def _process_synapses_by_chunks_of_data(
    channel_mode: List[str],
    selected_mode: str,
    resolution: Sequence[float],
    pre_post_points: Dict[Tuple[int, int, int], List[Tuple[int, int, int]]],
    shape_zyx: Tuple[int, int, int],
    mask: Union[h5py.Dataset, zarr.Array],
    fid_mask: Optional[h5py.File],
    out_shape: Tuple[int, ...],
    out_order: str,
    c_pos: Optional[int],
    width: np.ndarray,
    pre_footprint: np.ndarray,
    post_footprint: np.ndarray,
    norm: bool,
):
    """
    Build the synapse mask chunk by chunk and write it incrementally into ``mask``.

    The volume is traversed in fixed ``(64, 256, 256)`` (Z, Y, X) chunks. For every chunk
    only the synapses whose halo-extended bounding box touches it are rasterised, on a
    patch enlarged by a halo, and only the core of the patch is written back.

    Parameters
    ----------
    channel_mode : List[str]
        Channel names of the output mask, in output order. ``"F_post"`` selects the dilated
        post-synaptic map; ``"H"``, ``"V"`` and ``"Z"`` select the corresponding channel
        returned by ``create_HoVe_channels``.

    selected_mode : str
        Mode of operation. Only ``"synful"`` is implemented here.

    resolution : Sequence[float]
        Physical resolution in (Z, Y, X) order. Used to scale coordinates for the
        nearest-seed search and forwarded to ``create_HoVe_channels``.

    pre_post_points : Dict[Tuple[int, int, int], List[Tuple[int, int, int]]]
        Maps each pre-synaptic point (z, y, x) to the list of its post-synaptic partners.
        Entries with an empty partner list are ignored.

    shape_zyx : Tuple[int, int, int]
        Shape of the data in (Z, Y, X) order.

    mask : Union[h5py.Dataset, zarr.Array]
        Output array the generated mask is written to.

    fid_mask : Optional[h5py.File]
        HDF5 file backing ``mask`` (or ``None``). It is closed before returning, also when
        chunk processing fails.

    out_shape : Tuple[int, ...]
        Shape of the output array, in ``out_order``.

    out_order : str
        Axis order of the output array (e.g. ``"ZYXC"`` or ``"ZCYX"``).

    c_pos : Optional[int]
        Position of the channel axis in ``out_shape``; used to size the channel slice.

    width : np.ndarray
        Per-axis (z, y, x) reach of a seed. The Z component defines the half-height of the
        seed column painted for each post-synaptic point; all components contribute to the halo.

    pre_footprint : np.ndarray
        Structuring element used to dilate the post-synaptic seed voxels into the region
        over which labels (and therefore the H/V/Z channels) are defined.

    post_footprint : np.ndarray
        Structuring element used to dilate the post-synaptic seed voxels into the
        ``"F_post"`` channel.

    norm : bool
        Forwarded to ``create_HoVe_channels`` as ``normalize_values``.

    Raises
    ------
    NotImplementedError
        If ``selected_mode`` is not ``"synful"``. In that case ``fid_mask`` is left untouched.
    """
    print("Processing synapses by chunks of data . . .")

    if selected_mode != "synful":
        raise NotImplementedError(f"Mode {selected_mode} not implemented for processing by chunks.")

    def _footprint_radius(fp):
        # Half-extent of a structuring element along each axis
        fp = np.asarray(fp)
        return np.array([s // 2 for s in fp.shape], dtype=int)

    # From here on the HDF5 handle is owned by this function: release it whatever happens
    try:
        # ---------
        # 1. Halo: largest reach of any operation, plus a 2-voxel safety margin
        # ---------
        seed_reach = np.asarray(width, dtype=int)
        pre_reach = _footprint_radius(pre_footprint)
        post_reach = _footprint_radius(post_footprint)
        halo = np.maximum.reduce([seed_reach, pre_reach, post_reach]) + 2

        # ---------
        # 2. Chunk size and loop invariants
        # ---------
        cz, cy, cx = 64, 256, 256
        axis_order = "".join([x for x in channel_mode if x in ["H", "V", "Z"]])
        # Permutation taking a ZYXC patch into 'out_order'
        to_out_order = np.argsort([out_order.find(ax) for ax in "ZYXC"])

        # ---------
        # 3. Spatial index: chunk -> synapses that may reach into it
        # ---------
        syn_items = list(pre_post_points.items())
        chunk_to_syn = {}

        print("Building spatial index for synapses . . .")
        for syn_id, (pre_pt, post_list) in enumerate(syn_items):
            if not post_list:
                continue
            pre_pt = np.asarray(pre_pt, dtype=int)
            posts = np.asarray(post_list, dtype=int)
            pts = np.vstack([pre_pt[None, :], posts])

            # Bounding box including the halo, so the synapse is registered in every
            # chunk it could bleed into
            zmin, ymin, xmin = np.min(pts, axis=0) - halo
            zmax, ymax, xmax = np.max(pts, axis=0) + halo

            cz0, cz1 = max(0, zmin // cz), min((shape_zyx[0] - 1) // cz, zmax // cz)
            cy0, cy1 = max(0, ymin // cy), min((shape_zyx[1] - 1) // cy, ymax // cy)
            cx0, cx1 = max(0, xmin // cx), min((shape_zyx[2] - 1) // cx, xmax // cx)

            for iz in range(int(cz0), int(cz1) + 1):
                for iy in range(int(cy0), int(cy1) + 1):
                    for ix in range(int(cx0), int(cx1) + 1):
                        chunk_to_syn.setdefault((iz, iy, ix), []).append(syn_id)

        # ---------
        # 4. Processing loop
        # ---------
        z_max, y_max, x_max = map(int, shape_zyx)
        ncz, ncy, ncx = (z_max + cz - 1) // cz, (y_max + cy - 1) // cy, (x_max + cx - 1) // cx
        res_array = np.asarray(resolution, dtype=np.float32)

        for iz in tqdm(range(ncz), disable=not is_main_process()):
            z0c, z1c = iz * cz, min((iz + 1) * cz, z_max)
            for iy in range(ncy):
                y0c, y1c = iy * cy, min((iy + 1) * cy, y_max)
                for ix in range(ncx):
                    x0c, x1c = ix * cx, min((ix + 1) * cx, x_max)

                    syn_ids = chunk_to_syn.get((iz, iy, ix), [])
                    if not syn_ids:
                        continue

                    # Patch = chunk enlarged by the halo and clipped to the volume
                    ext_bbox = [
                        max(0, z0c - halo[0]), min(z_max, z1c + halo[0]),
                        max(0, y0c - halo[1]), min(y_max, y1c + halo[1]),
                        max(0, x0c - halo[2]), min(x_max, x1c + halo[2]),
                    ]
                    patch_shape = (
                        ext_bbox[1] - ext_bbox[0],
                        ext_bbox[3] - ext_bbox[2],
                        ext_bbox[5] - ext_bbox[4],
                    )
                    offset = np.array([ext_bbox[0], ext_bbox[2], ext_bbox[4]])

                    seeds = np.zeros(patch_shape, dtype=np.uint32)
                    mask_to_grow = np.zeros(patch_shape, dtype=np.uint8)
                    label_to_pre_site = {}
                    label_count = 1

                    for syn_id in syn_ids:
                        pre_pt, post_list = syn_items[syn_id]
                        pre_local = np.asarray(pre_pt) - offset

                        for post_pt in post_list:
                            post_local = np.asarray(post_pt) - offset

                            # Points up to one halo outside the patch are still considered
                            is_relevant = np.all(
                                (post_local >= -halo) &
                                (post_local < np.asarray(patch_shape) + halo)
                            )
                            if not is_relevant:
                                continue

                            # Z extent of the seed column around the post point
                            z_idx = int(post_local[0])
                            z0s = max(0, z_idx - seed_reach[0])
                            z1s = min(patch_shape[0], z_idx + seed_reach[0] + 1)
                            yy, xx = int(post_local[1]), int(post_local[2])

                            # The column itself must lie inside the patch in Y and X
                            if not (0 <= yy < patch_shape[1] and 0 <= xx < patch_shape[2]):
                                continue
                            if z0s >= z1s:
                                continue

                            # Only fill voxels of the column not already claimed by another label
                            col = seeds[z0s:z1s, yy, xx]
                            empty = (col == 0)
                            if not empty.any():
                                continue
                            col[empty] = label_count
                            seeds[z0s:z1s, yy, xx] = col
                            label_to_pre_site[label_count] = pre_local.tolist()

                            # Single-voxel seed used for the dilations below
                            if 0 <= z_idx < patch_shape[0]:
                                mask_to_grow[z_idx, yy, xx] = 1

                            label_count += 1

                    # No label was painted in this patch
                    if label_count == 1:
                        continue

                    # --- Dilation & flow ---
                    post_site_map = binary_dilation_scipy(mask_to_grow, structure=post_footprint)
                    mask_grow = binary_dilation_scipy(mask_to_grow, structure=pre_footprint)

                    seed_coords = np.argwhere(seeds > 0)
                    mask_coords = np.argwhere(mask_grow > 0)
                    if seed_coords.size == 0 or mask_coords.size == 0:
                        continue

                    # Give every grown voxel the label of its nearest seed (physical distance)
                    tree = cKDTree(seed_coords * res_array)
                    _, nn = tree.query(mask_coords * res_array, workers=-1)

                    grown_labels = np.zeros_like(seeds)
                    grown_labels[mask_coords[:, 0], mask_coords[:, 1], mask_coords[:, 2]] = seeds[
                        seed_coords[nn, 0], seed_coords[nn, 1], seed_coords[nn, 2]
                    ]

                    out_flow = create_HoVe_channels(
                        grown_labels,
                        ref_point="presynaptic",
                        label_to_pre_site=label_to_pre_site,
                        normalize_values=norm,
                        resolution=resolution,
                        axis_order=axis_order,
                    )

                    # Assemble the channels in the requested order
                    stack = []
                    for c in channel_mode:
                        if c == "F_post":
                            stack.append(post_site_map[..., None])
                        else:
                            j = axis_order.index(c)
                            stack.append(out_flow[..., j:j + 1])
                    out_map = np.concatenate(stack, axis=-1)

                    # Crop: keep only the core chunk of the halo-extended patch
                    z0p, z1p = z0c - ext_bbox[0], (z0c - ext_bbox[0]) + (z1c - z0c)
                    y0p, y1p = y0c - ext_bbox[2], (y0c - ext_bbox[2]) + (y1c - y0c)
                    x0p, x1p = x0c - ext_bbox[4], (x0c - ext_bbox[4]) + (x1c - x0c)
                    out_core = out_map[z0p:z1p, y0p:y1p, x0p:x1p, :]

                    # Transpose into the output axis order
                    patch = out_core.transpose(to_out_order)

                    # Slices of the target dataset, in 'out_order'
                    data_slices = []
                    for ax in out_order:
                        if ax == 'Z':
                            data_slices.append(slice(z0c, z1c))
                        elif ax == 'Y':
                            data_slices.append(slice(y0c, y1c))
                        elif ax == 'X':
                            data_slices.append(slice(x0c, x1c))
                        elif ax == 'C':
                            data_slices.append(slice(0, out_shape[c_pos]))

                    # Write only where the target is still empty
                    target = mask[tuple(data_slices)]
                    mask[tuple(data_slices)] = target + patch.astype(target.dtype) * (target == 0)
    finally:
        if isinstance(fid_mask, h5py.File):
            fid_mask.close()
[docs]
def create_HoVe_channels(
    data: NDArray,
    ref_point: str = "center",
    label_to_pre_site: Optional[Dict] = None,
    normalize_values: bool = True,
    calc_props: Optional[Dict] = None,
    axis_order: str = "ZYX",
    resolution: List[int|float] = [1,1,1],
):
    """
    Obtain the horizontal and vertical distance maps for each instance.

    Depth distance is also calculated if the ``data`` provided is 3D. For every instance,
    each of its pixels receives its signed offset (per axis) to the instance reference point.

    Parameters
    ----------
    data : 2D/3D Numpy array
        Instance mask to create horizontal/vertical/depth channels from. E.g. ``(500, 500)`` for 2D and
        ``(200, 1000, 1000)`` for 3D.

    ref_point : str, optional
        Reference point to be used to create the horizontal/vertical/depth channels. Possible values: ``center``,
        ``presynaptic``. Details:

        - 'center': point to the centroid.
        - 'presynaptic': point to the presynaptic site. To use this ``label_to_pre_site`` must be provided.

    label_to_pre_site : dict, optional
        Reference of the presynaptic site for each label within the provided volume (``data``).
        The entries are read only; they are not modified by this function.

    normalize_values : bool, optional
        Whether to normalize the values or not. When enabled, negative offsets of each instance are scaled so
        their minimum is ``-1`` and positive offsets so their maximum is ``+1``; ``resolution`` is not applied.

    calc_props : dict, optional
        If region properties have already been calculated, they can be provided here to avoid recalculation.
        The keys read are ``"label"``, ``"bbox-*"`` and, for ``ref_point='center'``, ``"centroid-*"``.

    axis_order : str, optional
        Channels to return and their order. ``"Z"`` selects the depth map, ``"V"`` the vertical (Y axis) map and
        ``"H"`` the horizontal (X axis) map. Any other character is ignored.
        NOTE(review): with the default ``"ZYX"`` only the ``"Z"`` channel is therefore selected; callers are
        expected to pass something like ``"HV"`` or ``"HVZ"``.

    resolution : list of int or float, optional
        Physical resolution of the data in each dimension. Only used when ``normalize_values`` is ``False``, to
        scale the horizontal/vertical/depth values to physical units. Default is [1,1,1] (isotropic).

    Returns
    -------
    new_mask : 3D/4D Numpy array
        Selected channels stacked on the last axis, ``float32``. E.g. ``(500, 500, 2)`` for 2D with
        ``axis_order="HV"`` and ``(200, 1000, 1000, 3)`` for 3D with ``axis_order="HVZ"``.

    Raises
    ------
    ValueError
        If ``ref_point`` is ``'presynaptic'`` and ``label_to_pre_site`` is missing or lacks a label present in
        ``data``, or if ``"Z"`` is requested in ``axis_order`` for 2D data.
    """
    assert ref_point in ["center", "presynaptic"]

    if ref_point == "presynaptic" and label_to_pre_site is None:
        raise ValueError("'label_to_pre_site' must be provided when 'ref_point' is 'presynaptic'")

    dim = data.ndim
    # One offset map per spatial axis, in array-axis order: (y, x) in 2D, (z, y, x) in 3D
    axis_maps = [np.zeros(data.shape, dtype=np.float32) for _ in range(dim)]
    axis_letters = ("V", "H") if dim == 2 else ("Z", "V", "H")
    # Index into 'resolution' for each array axis (Y and X are always the last two entries)
    res_index = (-2, -1) if dim == 2 else (0, -2, -1)

    if calc_props is None:
        props = regionprops_table(data, properties=("label", "bbox", "centroid"))
    else:
        props = calc_props

    for k, inst_id in tqdm(enumerate(props["label"]), total=len(props["label"]), leave=False):
        # Bounding box padded by 2 pixels and clipped to the array
        lower = [max(0, props[f"bbox-{a}"][k] - 2) for a in range(dim)]
        upper = [min(data.shape[a], props[f"bbox-{a + dim}"][k] + 2) for a in range(dim)]
        box = tuple(slice(lo, hi) for lo, hi in zip(lower, upper))

        # Compare only inside the box instead of over the whole volume
        inside = data[box] == inst_id
        if any(n < 2 for n in inside.shape):
            continue

        # Reference point of the instance
        if ref_point == "center":
            ref = [props[f"centroid-{a}"][k] for a in range(dim)]
        else:  # presynaptic
            assert label_to_pre_site is not None
            if inst_id not in label_to_pre_site:
                raise ValueError(f"Label {inst_id} not in 'label_to_pre_site'")
            # Work on a copy so the caller's dictionary is left untouched
            ref = list(label_to_pre_site[inst_id])

        # Move reference point inside bbox
        ref = [ref[a] - lower[a] for a in range(dim)]
        if any(np.isnan(ref)):
            continue
        ref = [int(c + 0.5) for c in ref]

        # 1-based pixel grid of the box, shifted so the reference point is the origin
        ranges = [np.arange(1, n + 1) - c for n, c in zip(inside.shape, ref)]
        grids = np.meshgrid(*ranges, indexing="ij")

        for a in range(dim):
            offsets = grids[a].astype("float32")
            # remove coord outside of instance
            offsets[~inside] = 0

            if normalize_values:
                # normalize min into -1 scale
                negative = offsets < 0
                if negative.any():
                    offsets[negative] /= -np.amin(offsets[negative])
                # normalize max into +1 scale
                positive = offsets > 0
                if positive.any():
                    offsets[positive] /= np.amax(offsets[positive])
            else:
                offsets = offsets * resolution[res_index[a]]

            # Basic slicing returns a view, so this writes into the full-size map
            axis_maps[a][box][inside] = offsets[inside]

    by_letter = dict(zip(axis_letters, axis_maps))
    stack = []
    for letter in axis_order:
        if letter == "Z" and dim == 2:
            raise ValueError("'Z' channel requested in 'axis_order' but 'data' is 2D")
        if letter in by_letter:
            stack.append(by_letter[letter])

    hv_map = np.stack(stack, axis=-1)
    return hv_map
#############
# DETECTION #
#############
[docs]
def create_detection_masks(cfg: CN, data_type: str = "train"):
    """
    Create detection masks based on CSV files.

    For every CSV file in the ground-truth folder a mask with the spatial shape of its
    image is created, each listed point is painted (dilated with an ellipse footprint
    derived from ``cfg.PROBLEM.DETECTION.CENTRAL_POINT_DILATION``) and the result is stored
    in the detection-mask folder. Files whose output already exists are skipped. The CSV
    files are distributed among ranks.

    The CSV must contain the columns ``axis-0`` and ``axis-1`` (plus ``axis-2`` in 3D) and,
    when more than two classes are configured, ``class``. Rows with missing values are dropped.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.

    data_type: str, optional
        Whether to create train, validation or test masks.

    Raises
    ------
    ValueError
        If no image is found, if the number of images and CSV files differ, or if a CSV
        file lacks a required column.
    """
    assert data_type in ["train", "val", "test"]

    tag = data_type.upper()
    img_dir = getattr(cfg.DATA, tag).PATH
    label_dir = getattr(cfg.DATA, tag).GT_PATH
    out_dir = getattr(cfg.DATA, tag).DETECTION_MASK_DIR

    # Images are either plain files or, if none is found, directories (chunked data)
    img_ids = next(os_walk_clean(img_dir))[2]
    working_with_chunked_data = False
    if len(img_ids) == 0:
        img_ids = next(os_walk_clean(img_dir))[1]
        working_with_chunked_data = True
    if len(img_ids) == 0:
        raise ValueError(f"No data found in folder {img_dir}")
    img_ext = "." + img_ids[0].split(".")[-1]

    ids = next(os_walk_clean(label_dir))[2]
    channels = 2 if cfg.DATA.N_CLASSES > 2 else 1
    dtype_str = "uint8" if cfg.DATA.N_CLASSES < 255 else "uint16"

    if len(img_ids) != len(ids):
        raise ValueError(
            "Different number of CSV files and images found ({} vs {}). "
            "Please check that every image has one and only one CSV file".format(len(ids), len(img_ids))
        )

    if cfg.PROBLEM.NDIM == "2D":
        req_columns = ["axis-0", "axis-1"] if channels == 1 else ["axis-0", "axis-1", "class"]
    else:
        req_columns = ["axis-0", "axis-1", "axis-2"] if channels == 1 else ["axis-0", "axis-1", "axis-2", "class"]

    cpd = cfg.PROBLEM.DETECTION.CENTRAL_POINT_DILATION
    ellipse_footprint = generate_ellipse_footprint(cpd)

    # Distribute files by rank
    rank = get_rank()
    world_size = get_world_size()
    it = range(rank, len(ids), world_size)

    print(f"Rank {rank}: Creating {data_type} detection masks . . .")
    for i in tqdm(it, disable=not is_main_process()):
        img_filename = os.path.splitext(ids[i])[0] + img_ext
        file_path = os.path.join(label_dir, ids[i])

        if not os.path.exists(os.path.join(img_dir, img_filename)):
            print("WARNING: No image found for CSV file: {}. Using the image that's in the same spot (within the CSV files list) where "
                "the CSV file is in its own list of CSV files. Check if it is correct!".format(file_path)
            )
            img_filename = img_ids[i]

        out_path = os.path.join(out_dir, img_filename)
        if os.path.exists(out_path):
            continue

        # 1. Read and validate the CSV first, so that a malformed file does not leave an
        #    empty mask on disk that later runs would silently skip
        df = pd.read_csv(file_path).dropna()
        df = df.rename(columns=lambda x: x.strip())
        cols_not_in_file = [x for x in req_columns if x not in df.columns]
        if len(cols_not_in_file) > 0:
            if len(cols_not_in_file) == 1:
                m = f"'{cols_not_in_file[0]}' column is not present in CSV file: {file_path}"
            else:
                m = f"{cols_not_in_file} columns are not present in CSV file: {file_path}"
            raise ValueError(m)

        # Obtaining coords (axis-0: Z, axis-1: Y, axis-2: X). Plain arrays are used so the
        # points are addressed by position: after dropna() the Series index may have gaps
        coords = [df["axis-0"].astype(int).to_numpy(), df["axis-1"].astype(int).to_numpy()]
        if cfg.PROBLEM.NDIM == "3D":
            coords.append(df["axis-2"].astype(int).to_numpy())
        class_points = df["class"].astype(int).to_numpy() if "class" in df.columns else [1] * len(coords[0])

        # 2. Get shape information
        if img_ext not in [".zarr", ".n5"]:
            img_info = read_img_as_ndarray(os.path.join(img_dir, img_filename), is_3d=not cfg.PROBLEM.NDIM == "2D")
            shape = img_info.shape[:-1]
            is_h5 = False
        else:
            img_zarr_file, img_data = read_chunked_data(os.path.join(img_dir, img_filename))
            shape = img_data.shape if img_data.ndim == 3 else img_data.shape[:-1]
            if isinstance(img_zarr_file, h5py.File):
                img_zarr_file.close()
            is_h5 = looks_like_hdf5(img_filename)

        # 3. Initialize the mask: disk-backed (Zarr or H5) for chunked data, in memory otherwise
        os.makedirs(out_dir, exist_ok=True)
        out_shape = shape + (channels,)

        if working_with_chunked_data:
            if is_h5:
                fid_mask = h5py.File(out_path, "w")
                ds_kwargs = {
                    "shape": out_shape,
                    "dtype": dtype_str,
                    "chunks": pick_chunks(out_shape, dtype_str, target_mb=4.0),
                    "compression": "gzip",
                    "compression_opts": 4,
                    "shuffle": True
                }
                mask = fid_mask.create_dataset("data", **ds_kwargs)
            else:
                mask = zarr.open(out_path, mode="w", shape=out_shape, dtype=dtype_str, zarr_format=3)
        else:
            mask = np.zeros(out_shape, dtype=dtype_str)

        # 4. Paint points directly into the mask
        for j in range(len(coords[0])):
            c_point = class_points[j]

            # --- A. Coordinate setup ---
            global_c = [int(coords[d][j]) for d in range(len(shape))]

            # Skip if center point is outside array boundaries
            if any(global_c[d] < 0 or global_c[d] >= shape[d] for d in range(len(shape))):
                continue

            # --- B. Local window around the point (handles 2D and 3D) ---
            slices = []
            rel_coords = []
            for d in range(len(shape)):
                start = max(0, global_c[d] - 1 - cpd[d])
                end = min(shape[d], global_c[d] + 2 + cpd[d])
                slices.append(slice(start, end))
                rel_coords.append(global_c[d] - start)
            target_slice = tuple(slices)

            # --- C. Create the patch (local dilation) ---
            local_chunk = mask[target_slice]
            patch_shape = local_chunk.shape[:-1]  # Exclude channel dimension
            patch = np.zeros(patch_shape, dtype=np.uint8)
            patch[tuple(rel_coords)] = 1
            update = binary_dilation_scipy(patch, iterations=1, structure=ellipse_footprint)

            # --- D. Multi-channel update ---
            # Channel 0: occupancy (binary). Only update where the mask is currently 0
            mask_to_fill = (local_chunk[..., 0] == 0) & (update > 0)
            local_chunk[mask_to_fill, 0] = 1

            # Channel 1: class of the point
            if channels > 1:
                local_chunk[mask_to_fill, 1] = c_point

            # Write the modified window back
            mask[target_slice] = local_chunk

        # Finalize
        if is_h5 and working_with_chunked_data:
            fid_mask.close()
        elif not working_with_chunked_data:
            save_tif(np.expand_dims(mask, 0), out_dir, [img_filename])
#######
# SSL #
#######
[docs]
def create_ssl_source_data_masks(cfg: CN, data_type: str = "train"):
    """
    Create SSL source data.

    Every file found in the image folder of the selected split is degraded with
    :func:`crappify` and saved under the same name in the SSL source folder. Files that
    already exist in the output folder are left as they are.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.

    data_type: str, optional
        Whether to create train, validation or test source data.
    """
    assert data_type in ["train", "val", "test"]

    split_cfg = getattr(cfg.DATA, data_type.upper())
    img_dir = split_cfg.PATH
    out_dir = split_cfg.SSL_SOURCE_DIR
    filenames = next(os_walk_clean(img_dir))[2]

    # Noise is only injected when a positive level is configured
    add_noise = cfg.PROBLEM.SELF_SUPERVISED.NOISE > 0

    print("Creating {} SSL source. . .".format(data_type))
    for name in filenames:
        src_path = os.path.join(img_dir, name)

        if os.path.exists(os.path.join(out_dir, name)):
            print("Source file {} found".format(src_path))
            continue

        print("Crappifying file {} to create SSL source".format(src_path))
        degraded = crappify(
            read_img_as_ndarray(src_path, is_3d=not cfg.PROBLEM.NDIM == "2D"),
            resizing_factor=cfg.PROBLEM.SELF_SUPERVISED.RESIZING_FACTOR,
            add_noise=add_noise,
            noise_level=cfg.PROBLEM.SELF_SUPERVISED.NOISE,
        )
        save_tif(np.expand_dims(degraded, 0), out_dir, [name])
[docs]
def crappify(
    input_img: NDArray,
    resizing_factor: float,
    add_noise: bool = True,
    noise_level: Optional[float] = None,
    Down_up: bool = True,
):
    """
    Crappify input image by adding Gaussian noise and downsampling and upsampling it so the resolution gets worse.

    Parameters
    ----------
    input_img : 3D/4D Numpy array
        Data to be modified. E.g. ``(y, x, channels)`` if working with 2D images or
        ``(z, y, x, channels)`` if working with 3D.

    resizing_factor : float
        Downsizing factor to reshape the image. Each spatial dimension is divided by
        ``sqrt(resizing_factor)``.

    add_noise : boolean, optional
        Indicating whether to add gaussian noise before applying the resizing.

    noise_level: float, optional
        Number between ``[0,1]`` indicating the std of the Gaussian noise N(0,std).
        Must be set (and non-zero) when ``add_noise`` is ``True``.

    Down_up : bool, optional
        Indicating whether to perform a final upsampling operation to obtain an image of the
        same size as the original but with the corresponding loss of quality of downsizing and
        upsizing.

    Returns
    -------
    img : 3D/4D Numpy array
        Degraded image with the dtype of ``input_img``. E.g. ``(y, x, channels)`` if working with
        2D images or ``(z, y, x, channels)`` if working with 3D. When ``Down_up`` is ``False`` the
        spatial size is the reduced one.
    """
    # Spatial size of the input (the channel axis is left untouched by the resizing)
    if input_img.ndim == 3:
        rows, cols, _ = input_img.shape
        org_sz = (rows, cols)
    else:
        depth, rows, cols, _ = input_img.shape
        org_sz = (depth, rows, cols)

    targ_sz = tuple(int(extent / np.sqrt(resizing_factor)) for extent in org_sz)

    img = input_img.copy()
    if add_noise:
        assert noise_level
        img = add_gaussian_noise(img, noise_level)

    # Same interpolation settings for the way down and the way back up
    resize_opts = dict(
        order=1,
        mode="reflect",
        clip=True,
        preserve_range=True,
        anti_aliasing=False,
    )

    img = resize(img, targ_sz, **resize_opts)
    if Down_up:
        img = resize(img, org_sz, **resize_opts)

    return img.astype(input_img.dtype)
[docs]
def add_gaussian_noise(image: NDArray, percentage_of_noise: float) -> NDArray:
    """
    Add zero-mean Gaussian noise to an input image.

    Parameters
    ----------
    image : Numpy array
        Image to add the noise to. E.g. ``(y, x, channels)``.

    percentage_of_noise : float
        Fraction of the maximum value of the image that is used as the std of the Gaussian
        noise distribution.

    Returns
    -------
    out : Numpy array
        Noisy image, clipped to ``[0, max(image)]`` and cast back to the dtype of ``image``.
    """
    peak = np.max(image)
    sigma = percentage_of_noise * peak
    perturbation = np.random.normal(loc=0, scale=sigma, size=image.shape)
    perturbed = image + perturbation
    # Keep the result inside the value range spanned by zero and the original maximum
    return np.clip(perturbed, 0, peak).astype(image.dtype)
################
# SEMANTIC SEG #
################
[docs]
def calculate_volume_prob_map(
    Y: BiaPyDataset, is_3d: bool = False, w_foreground: float = 0.94, w_background: float = 0.06, save_dir=None
) -> List[NDArray] | NDArray:
    """
    Calculate the probability map of the given data.

    For every channel of every sample, objects touching the border are removed (slice by
    slice in 3D), foreground pixels share ``w_foreground`` and background pixels share
    ``w_background``, and the channel is finally normalized to sum ``1``.

    Parameters
    ----------
    Y : BiaPyDataset
        Dataset to calculate the probability map from. For each item of ``Y.sample_list`` the image is taken
        from its ``img`` attribute if ``img_is_loaded()`` is true; otherwise it is read from the path stored in
        ``Y.dataset_info[sample.fid].path``.

    is_3d : bool, optional
        Whether the data is 3D. Forwarded to the image reader and used to clear borders slice by slice.

    w_foreground : float, optional
        Weight of the foreground. This value plus ``w_background`` must be equal ``1``.

    w_background : float, optional
        Weight of the background. This value plus ``w_foreground`` must be equal ``1``.

    save_dir : str, optional
        Directory where the probability map(s) will be stored. A single ``prob_map.npy`` is written when all
        samples share the same shape, one ``prob_map<i>.npy`` per sample otherwise. Nothing is saved if not set.

    Returns
    -------
    maps : NDArray or list of NDArray
        Probability map(s) of all samples in ``Y.sample_list``: one stacked array if every sample has the same
        shape, a list of arrays otherwise. Returned whether or not ``save_dir`` is given.
    """
    print("Constructing the probability map . . .")
    maps = []
    diff_shape = False
    first_shape = None
    Ylen = len(Y.sample_list)
    for i in tqdm(range(Ylen), disable=not is_main_process()):
        sample = Y.sample_list[i]
        if sample.img_is_loaded():
            _map = sample.img.copy().astype(np.float32)
        else:
            path = Y.dataset_info[sample.fid].path
            _map = read_img_as_ndarray(path, is_3d=is_3d).astype(np.float32)

        for k in range(_map.shape[-1]):
            # View on the k-th channel: writing into it updates '_map'
            channel = _map[..., k]

            # Remove artifacts connected to image border
            if is_3d:
                for j in range(channel.shape[0]):
                    channel[j] = clear_border(channel[j])
            else:
                channel[...] = clear_border(channel)

            foreground_pixels = (channel > 0).sum()
            background_pixels = (channel == 0).sum()

            # Spread each weight evenly over its class (nothing to do for an empty class)
            if foreground_pixels != 0:
                channel[channel > 0] = w_foreground / foreground_pixels
            if background_pixels != 0:
                channel[channel == 0] = w_background / background_pixels

            # Necessary to get all probs sum 1
            s = channel.sum()
            if s == 0:
                channel.fill(1 / channel.size)
            else:
                channel[...] = channel / s

        if first_shape is None:
            first_shape = _map.shape
        if first_shape != _map.shape:
            diff_shape = True
        maps.append(_map)

    # Samples of equal shape are stacked into one array
    if not diff_shape:
        maps = np.concatenate([np.expand_dims(m, 0) for m in maps])

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        if not diff_shape:
            print("Saving the probability map in {}".format(save_dir))
            np.save(os.path.join(save_dir, "prob_map.npy"), maps)
        else:
            print(
                "As the files loaded have different shapes, the probability map for each one will be stored"
                " separately in {}".format(save_dir)
            )
            d = len(str(Ylen))
            for i in range(Ylen):
                f = os.path.join(save_dir, "prob_map" + str(i).zfill(d) + ".npy")
                np.save(f, maps[i])

    # Always hand the maps back, also when nothing had to be saved
    return maps
###########
# GENERAL #
###########
[docs]
def resize_images(images: List[NDArray], **kwards) -> List[NDArray]:
    """
    Resize every image of a list with ``skimage.transform.resize``.

    Parameters
    ----------
    images: list of Numpy arrays
        Input images to resize.

    output_shape: iterable
        Size of the generated output image. E.g. `(256,256)`. Passed through ``kwards``.

    (kwards): optional
        `skimage.transform.resize() <https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize>`__
        parameters are also allowed. All keyword arguments are forwarded unchanged.

    Returns
    -------
    resized_images: list of Numpy arrays
        The resized images, each cast back to the data type of its input image.
    """
    resized_images = []
    for image in images:
        rescaled = resize(image, **kwards)
        resized_images.append(rescaled.astype(image.dtype))
    return resized_images
[docs]
def apply_gaussian_blur(images: List[NDArray], **kwards) -> List[NDArray]:
    """
    Blur every image of a list with ``skimage.filters.gaussian``.

    Parameters
    ----------
    images: list of Numpy arrays
        The input images on which the Gaussian blur will be applied.

    (kwards): optional
        `skimage.filters.gaussian() <https://scikit-image.org/docs/stable/api/skimage.filters.html#skimage.filters.gaussian>`__
        parameters are also allowed. All keyword arguments are forwarded unchanged.

    Returns
    -------
    blurred_images: list of Numpy arrays
        The blurred images, each cast back to the data type of its input image.
    """
    blurred_images = []
    for image in images:
        smoothed = gaussian(image, **kwards)
        if np.issubdtype(image.dtype, np.integer):
            # Presumably gaussian() hands back integer inputs rescaled to the unit float
            # range, hence the stretch to the dtype maximum.
            # NOTE(review): confirm this still holds if callers pass ``preserve_range=True``.
            smoothed = smoothed * np.iinfo(image.dtype).max
        blurred_images.append(smoothed.astype(image.dtype))
    return blurred_images
[docs]
def detect_edges(images: List[NDArray], **kwards) -> List[NDArray]:
    """
    Detect edges in the given images using the Canny edge detection algorithm.

    Each 2D image is reduced to a single grayscale plane (RGB images are converted, single-channel
    images have their channel axis dropped) and passed to ``skimage.feature.canny``.

    Parameters
    ----------
    images: list of Numpy arrays
        The list of all input images on which the edge detection will be performed. Each one can be either a
        color image with shape (height, width, 3) or a grayscale image with shape (height, width, 1).

    (kwards): optional
        `skimage.feature.canny() <https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny>`__
        parameters are also allowed.

    Returns
    -------
    edges: list of Numpy arrays
        The edges of the input images, as uint8 arrays of shape (height, width, 1) where background is
        black (0) and edges white (255).

    Raises
    ------
    ValueError
        If an image has a number of channels other than 1 or 3.
    """
    edges = []
    for img in images:
        c = img.shape[-1]
        if c == 3:
            gray = rgb2gray(img)
        elif c == 1:
            gray = img[..., 0]
        else:
            raise ValueError(
                f"Detect edges function does not allow given ammount of channels ({c} channels). "
                "Only accepts grayscale and RGB 2D images (1 or 3 channels)."
            )

        # Boolean edge map -> uint8 with a trailing channel axis, edges set to 255
        edge_map = canny(gray, **kwards).astype(np.uint8)
        edges.append(edge_map[..., np.newaxis] * 255)
    return edges
def _histogram_matching(source_imgs: List[NDArray], target_imgs: List[NDArray]) -> List[NDArray]:
    """
    Apply histogram matching to a set of source images based on the pooled histogram of target images.

    The histograms of all target images are summed into a single reference histogram, and every
    source image is remapped so that its cumulative distribution follows the reference one.

    Parameters
    ----------
    source_imgs: list of Numpy arrays
        The images of the source domain, to which the histogram matching is to be applied.

    target_imgs: list of Numpy arrays
        The target domain images, from which the reference histogram is obtained. They must have an
        integer data type; the data type of the first image is used to size the histogram.

    Returns
    -------
    matched_images : list of Numpy arrays
        The source images remapped to the target histogram. Each returned image keeps the shape and
        data type of its corresponding source image.

    Raises
    ------
    ValueError
        If ``target_imgs`` is empty, if no target pixel falls inside the range ``[0, max]`` of the
        target data type, or (raised by ``np.iinfo``) if the target data type is not an integer type.
    """
    if len(target_imgs) == 0:
        raise ValueError("At least one target image is needed to compute the reference histogram")

    # Get the data type from the first image (assuming all have same dtype)
    dtype = target_imgs[0].dtype
    bin_edges = np.arange(np.iinfo(dtype).max + 2)  # +2 because bins are edges

    # Sum per-image counts; this gives the same histogram as concatenating every target pixel
    # but without building that (potentially huge) temporary array.
    hist = np.zeros(len(bin_edges) - 1, dtype=np.int64)
    for target in target_imgs:
        hist += np.histogram(target, bins=bin_edges)[0]

    # Keep only the intensities that really occur in the targets. Empty bins create flat segments
    # in the cumulative histogram, and np.interp resolves a quantile lying on such a segment to its
    # right end, i.e. to an intensity that is absent from the targets (the brightest source value
    # would always be sent to the maximum of the data type).
    tmpl_values = np.flatnonzero(hist)
    if tmpl_values.size == 0:
        raise ValueError("Target images do not contain any pixel within the range of their data type")

    # calculate normalized quantiles
    tmpl_quantiles = np.cumsum(hist[tmpl_values]) / hist.sum()

    # based on scikit implementation.
    # source: https://github.com/scikit-image/scikit-image/blob/v0.18.0/skimage/exposure/histogram_matching.py#L22-L70
    def _match_cumulative_cdf(source):
        _, src_unique_indices, src_counts = np.unique(source.ravel(), return_inverse=True, return_counts=True)
        # calculate normalized quantiles
        src_quantiles = np.cumsum(src_counts) / source.size
        # intensity of the target distribution found at each source quantile
        interp_a_values = np.interp(src_quantiles, tmpl_quantiles, tmpl_values)
        return interp_a_values[src_unique_indices].reshape(source.shape)

    # apply histogram matching
    return [_match_cumulative_cdf(image).astype(image.dtype) for image in source_imgs]
[docs]
def apply_histogram_matching(images: List[NDArray], reference_path: str, is_2d: bool):
    """
    Match the histogram of a list of images to the histogram of reference images stored on disk.

    The reference images are loaded from ``reference_path`` with ``load_data_from_dir`` and handed,
    together with ``images``, to ``_histogram_matching``.

    Parameters
    ----------
    images: list of Numpy arrays
        The input images whose histogram needs to be matched to the reference histogram.

    reference_path: str
        Directory path of the reference images. The reference histogram, i.e. the desired
        distribution of pixel intensities in the output images, is extracted from them.

    is_2d: bool
        Whether the data stored in ``reference_path`` is 2D (``is_2d = True``) or 3D
        (``is_2d = False``). The reference images are loaded with ``is_3d=not is_2d``.

    Returns
    -------
    matched_images : list of Numpy arrays
        The result of matching the histogram of the input images to the histogram of the reference
        images. The returned data will use the same data type as the given ``images``.
    """
    reference_imgs = load_data_from_dir(reference_path, is_3d=not is_2d)
    return _histogram_matching(images, reference_imgs)
[docs]
def apply_clahe(images: List[NDArray], **kwards) -> List[NDArray]:
    """
    Run Contrast Limited Adaptive Histogram Equalization (CLAHE) over every image of a list.

    Each image is equalized independently with ``skimage.exposure.equalize_adapthist`` and then
    converted back to its original data type.

    Parameters
    ----------
    images: list of Numpy arrays
        The input images to equalize.

    (kwards): optional
        `skimage.exposure.equalize_adapthist() <https://scikit-image.org/docs/stable/api/skimage.exposure.html#skimage.exposure.equalize_adapthist>`__
        parameters are also allowed. They are forwarded unchanged for every image.

    Returns
    -------
    processed_images: list of Numpy arrays
        The equalized images, one per input image and in the same order. Each one uses the data
        type of its input image.
    """
    processed_images = []
    for img in images:
        equalized = equalize_adapthist(img, **kwards)  # returns 0-1 range
        if np.issubdtype(img.dtype, np.integer):
            # stretch the 0-1 output to the full range of the integer type before casting
            equalized = equalized * np.iinfo(img.dtype).max
        processed_images.append(equalized.astype(img.dtype))
    return processed_images
[docs]
def preprocess_data(
    cfg: CN,
    x_data: Optional[List[NDArray]] = None,
    y_data: Optional[List[NDArray]] = None,
    is_2d: bool = True,
    is_y_mask: bool = False,
) -> List[NDArray] | Tuple[List[NDArray], List[NDArray]]:
    """
    Pre-process data by applying various image processing techniques.

    ``y_data`` is only resized. ``x_data`` goes, in this order, through every enabled step among:
    resize, Gaussian blur, median blur, histogram matching, CLAHE and Canny edge detection.

    Parameters
    ----------
    cfg: CN
        Configuration node that controls the pre-processing. The ``RESIZE``, ``GAUSSIAN_BLUR``,
        ``MEDIAN_BLUR``, ``MATCH_HISTOGRAM``, ``CLAHE`` and ``CANNY`` sub-nodes are read; each step
        is applied only when its ``ENABLE`` flag is set.

    x_data: list of 3D/4D Numpy arrays, optional
        The input data (images) to be preprocessed. The first dimension must be the number of images.
        E.g. ``(num_of_images, y, x, channels)`` or ``(num_of_images, z, y, x, channels)``.
        In case of using a list, the format of the images remains the same. Each item in the list
        corresponds to a different image. Omitting it (or passing ``None``) means "no images".

    y_data: list of 3D/4D Numpy arrays, optional
        The target data that corresponds to the x_data. The first dimension must be the number of images.
        E.g. ``(num_of_images, y, x, channels)`` or ``(num_of_images, z, y, x, channels)``.
        In case of using a list, the format of the images remains the same. Each item in the list
        corresponds to a different image. Omitting it (or passing ``None``) means "no targets".

    is_2d: bool, optional
        A boolean flag indicating whether the reference data for histogram matching is 2D or not.
        Defaults to True.

    is_y_mask: bool, optional
        Whether ``y_data`` is a mask. If True, the resize operation for ``y_data`` uses nearest
        neighbor interpolation (``order=0``); otherwise it uses ``cfg.RESIZE.ORDER``.
        Defaults to False.

    Returns
    -------
    x_data: list of 3D/4D Numpy arrays, optional
        Preprocessed images. Returned alone when ``y_data`` is empty (also when both inputs are empty).

    y_data: list of 3D/4D Numpy arrays, optional
        Preprocessed targets. Returned alone when ``x_data`` is empty.

    When both inputs contain data, the tuple ``(x_data, y_data)`` is returned.
    """
    # Fresh lists per call: a shared mutable default could be handed back to the caller and,
    # if modified there, leak into later calls.
    if x_data is None:
        x_data = []
    if y_data is None:
        y_data = []

    if len(y_data) > 0:
        if cfg.RESIZE.ENABLE:
            # if y is a mask, then use nearest
            y_order = 0 if is_y_mask else cfg.RESIZE.ORDER
            y_data = resize_images(
                y_data,
                output_shape=cfg.RESIZE.OUTPUT_SHAPE,
                order=y_order,
                mode=cfg.RESIZE.MODE,
                cval=cfg.RESIZE.CVAL,
                clip=cfg.RESIZE.CLIP,
                preserve_range=cfg.RESIZE.PRESERVE_RANGE,
                anti_aliasing=cfg.RESIZE.ANTI_ALIASING,
            )

    if len(x_data) > 0:
        if cfg.RESIZE.ENABLE:
            x_data = resize_images(
                x_data,
                output_shape=cfg.RESIZE.OUTPUT_SHAPE,
                order=cfg.RESIZE.ORDER,
                mode=cfg.RESIZE.MODE,
                cval=cfg.RESIZE.CVAL,
                clip=cfg.RESIZE.CLIP,
                preserve_range=cfg.RESIZE.PRESERVE_RANGE,
                anti_aliasing=cfg.RESIZE.ANTI_ALIASING,
            )
        if cfg.GAUSSIAN_BLUR.ENABLE:
            x_data = apply_gaussian_blur(
                x_data,
                sigma=cfg.GAUSSIAN_BLUR.SIGMA,
                mode=cfg.GAUSSIAN_BLUR.MODE,
                channel_axis=cfg.GAUSSIAN_BLUR.CHANNEL_AXIS,
            )
        if cfg.MEDIAN_BLUR.ENABLE:
            x_data = apply_median_blur(
                x_data,
                # all-ones footprint whose shape is given by the configured kernel size
                footprint=np.ones(cfg.MEDIAN_BLUR.KERNEL_SIZE, dtype=np.uint8).tolist(),
            )
        if cfg.MATCH_HISTOGRAM.ENABLE:
            x_data = apply_histogram_matching(
                x_data,
                reference_path=cfg.MATCH_HISTOGRAM.REFERENCE_PATH,
                is_2d=is_2d,
            )
        if cfg.CLAHE.ENABLE:
            x_data = apply_clahe(
                x_data,
                kernel_size=cfg.CLAHE.KERNEL_SIZE,
                clip_limit=cfg.CLAHE.CLIP_LIMIT,
            )
        if cfg.CANNY.ENABLE:
            x_data = detect_edges(
                x_data,
                low_threshold=cfg.CANNY.LOW_THRESHOLD,
                high_threshold=cfg.CANNY.HIGH_THRESHOLD,
            )

    if len(x_data) > 0 and len(y_data) > 0:
        return x_data, y_data
    if len(y_data) > 0:
        return y_data
    return x_data