# Source code for biapy.data.pre_processing

"""
Pre-processing utilities for image and mask data in deep learning workflows.

This module provides pre-processing functions for instance segmentation, detection mask creation, self-supervised learning data generation, semantic segmentation probability maps, and general image processing operations such as resizing, blurring, edge detection, histogram matching, and CLAHE. It supports both 2D and 3D data formats and integrates with BiaPy configuration objects for flexible data pipelines.
"""
import os
import edt
import h5py
import zarr
import numpy as np
from tqdm import tqdm
import pandas as pd
from skimage.segmentation import clear_border, find_boundaries, watershed
from scipy.ndimage import (
    generic_filter, 
    generate_binary_structure, 
    grey_closing, 
    center_of_mass, 
    map_coordinates, 
    uniform_filter,
    binary_dilation as binary_dilation_scipy
)
from skimage.morphology import disk, binary_dilation, binary_erosion, skeletonize
from skimage.measure import label, regionprops_table
from skimage.transform import resize
from skimage.feature import canny
from skimage.exposure import equalize_adapthist
from skimage.color import rgb2gray
from skimage.filters import gaussian, median
from yacs.config import CfgNode as CN
from numpy.typing import NDArray
from typing import List, Optional, Dict, Tuple, Sequence, Union
from scipy.spatial import cKDTree

from biapy.data.dataset import BiaPyDataset
from biapy.utils.util import (
    seg2aff_pni,
    seg_widen_border,
)
from biapy.utils.misc import is_main_process, get_rank, get_world_size, os_walk_clean
from biapy.data.data_3D_manipulation import (
    load_3D_efficient_files,
    load_img_part_from_efficient_file,
    order_dimensions,
    read_chunked_data,
    read_chunked_nested_data,
    looks_like_hdf5,
    pick_chunks,
)
from biapy.data.data_manipulation import (
    read_img_as_ndarray,
    load_data_from_dir,
    save_tif,
    decide_dtype,
)

#########################
# INSTANCE SEGMENTATION #
#########################
def _instance_channels_from_gt(cfg: CN, tag: str, img: NDArray, img_path: str) -> NDArray:
    """
    Validate a ground-truth image and convert it into the configured instance channels.

    When ``DATA.N_CLASSES > 2`` the second input channel (classification map) is appended,
    unchanged, after the channels created by ``labels_into_channels``.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.
    tag : str
        Upper-case split name (``"TRAIN"``, ``"VAL"`` or ``"TEST"``).
    img : Numpy array
        Ground-truth image with the channel axis last.
    img_path : str
        Path of the image, only used in error messages.

    Returns
    -------
    Numpy array
        Instance representation (plus class channel when applicable).

    Raises
    ------
    ValueError
        If the image does not have the expected number of channels.
    """
    class_channel = None
    if cfg.DATA.N_CLASSES > 2:
        if img.shape[-1] != 2:
            raise ValueError(
                "In instance segmentation, when 'DATA.N_CLASSES' are more than 2 labels need to have two channels, "
                "e.g. (256,256,2), containing the instance segmentation map (first channel) and classification map "
                "(second channel)."
            )
        class_channel = np.expand_dims(img[..., 1].copy(), -1)
    elif img.shape[-1] != 1:
        raise ValueError(
            "Expected instance segmentation GT images to have a single channel containing the instance labels, "
            "but got image with shape {} ({} channels). Check the image file: {}".format(
                img.shape, img.shape[-1], img_path
            )
        )

    img = labels_into_channels(
        img,
        mode=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS,
        channel_extra_opts=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS_EXTRA_OPTS[0],
        save_dir=getattr(cfg.PATHS, tag + "_INSTANCE_CHANNELS_CHECK"),
    )
    if class_channel is not None:
        img = np.concatenate([img, class_channel], axis=-1)
    return img


def create_instance_channels(cfg: CN, data_type: str = "train"):
    """
    Create new data with the appropriate channels for instance segmentation.

    The channels are taken from ``PROBLEM.INSTANCE_SEG.DATA_CHANNELS``. Two layouts are handled:

    * 3D Zarr/N5/H5 data. Ground truth is processed patch by patch (patches distributed across
      ranks) and written into one new Zarr/H5 file per input file inside
      ``DATA.<SPLIT>.INSTANCE_CHANNELS_MASK_DIR``. Synapse data is delegated to
      ``synapse_channel_creation`` instead.
    * Any other case. Each file in ``DATA.<SPLIT>.GT_PATH`` is read, converted and saved as a
      TIFF in ``DATA.<SPLIT>.INSTANCE_CHANNELS_MASK_DIR``.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.
    data_type : str, optional
        Whether to create training (``"train"``), validation (``"val"``) or test (``"test"``)
        instance channels.

    Raises
    ------
    ValueError
        If no files are found, if synapse detection is requested on non Zarr/H5 data, if the
        distance channel is requested for Zarr/H5 data, or if the GT images do not have the
        expected number of channels.
    """
    assert data_type in ["train", "val", "test"]
    tag = data_type.upper()
    split_cfg = getattr(cfg.DATA, tag)

    # Checking if the user inputted Zarr/H5 files
    if split_cfg.INPUT_ZARR_MULTIPLE_DATA:
        data_path = split_cfg.PATH
        try:
            zarr_files = next(os_walk_clean(data_path))[1]
        except StopIteration:
            raise ValueError("No Zarr/N5 files found in the input path: {}".format(data_path))
        try:
            h5_files = next(os_walk_clean(data_path))[2]
        except StopIteration:
            raise ValueError("No H5 files found in the input path: {}".format(data_path))
    else:
        data_path = split_cfg.GT_PATH
        try:
            zarr_files = next(os_walk_clean(data_path))[1]
            h5_files = next(os_walk_clean(data_path))[2]
        except StopIteration:
            raise ValueError("No Zarr/N5 or H5 files found in the GT path: {}".format(data_path))

    has_zarr = len(zarr_files) > 0 and any(x in zarr_files[0] for x in [".zarr", ".n5"])
    has_h5 = len(h5_files) > 0 and looks_like_hdf5(h5_files[0])
    # Patch-by-patch processing (Zarr or H5) is only available for 3D data, so the 3D check
    # must apply to both formats.
    working_with_zarr_h5_files = cfg.PROBLEM.NDIM == "3D" and (has_zarr or has_h5)

    path_to_gt_data = None
    if working_with_zarr_h5_files:
        # Check if the raw images and labels are within the same file
        data_path = split_cfg.GT_PATH
        if split_cfg.INPUT_ZARR_MULTIPLE_DATA:
            data_path = split_cfg.PATH
            if cfg.PROBLEM.INSTANCE_SEG.TYPE == "synapses":
                path_to_gt_data = split_cfg.INPUT_ZARR_MULTIPLE_DATA_RAW_PATH
            else:
                path_to_gt_data = split_cfg.INPUT_ZARR_MULTIPLE_DATA_GT_PATH

        if has_zarr:
            print("Working with Zarr files . . .")
            img_files = [os.path.join(data_path, x) for x in zarr_files]
        else:  # has_h5 is guaranteed here by the condition above
            print("Working with H5 files . . .")
            img_files = [os.path.join(data_path, x) for x in h5_files]

        # Find patches info so we can iterate over them to create the instance mask
        Y, Y_total_patches = load_3D_efficient_files(
            img_files,
            input_axes=split_cfg.INPUT_IMG_AXES_ORDER,
            crop_shape=cfg.DATA.PATCH_SIZE,
            overlap=split_cfg.OVERLAP,
            padding=split_cfg.PADDING,
            data_within_zarr_path=path_to_gt_data,
        )

        zarr_data_information = {
            "raw_data_path": split_cfg.INPUT_ZARR_MULTIPLE_DATA_RAW_PATH,
            "axes_order": split_cfg.INPUT_IMG_AXES_ORDER,
            "z_axe_pos": split_cfg.INPUT_IMG_AXES_ORDER.index("Z"),
            "y_axe_pos": split_cfg.INPUT_IMG_AXES_ORDER.index("Y"),
            "x_axe_pos": split_cfg.INPUT_IMG_AXES_ORDER.index("X"),
            "id_path": split_cfg.INPUT_ZARR_MULTIPLE_DATA_ID_PATH,
            "partners_path": split_cfg.INPUT_ZARR_MULTIPLE_DATA_PARTNERS_PATH,
            "locations_path": split_cfg.INPUT_ZARR_MULTIPLE_DATA_LOCATIONS_PATH,
            "resolution_path": split_cfg.INPUT_ZARR_MULTIPLE_DATA_RESOLUTION_PATH,
        }
    else:
        if cfg.PROBLEM.INSTANCE_SEG.TYPE == "synapses":
            raise ValueError(
                "Synapse detection is only available for 3D Zarr/H5 data so please check your data in {}".format(
                    data_path
                )
            )
        Y = next(os_walk_clean(split_cfg.GT_PATH))[2]
    del zarr_files, h5_files

    print("Creating Y_{} channels . . .".format(data_type))
    rank = get_rank()
    world_size = get_world_size()

    # Create the mask patch by patch (Zarr/H5)
    if working_with_zarr_h5_files and isinstance(Y, dict):
        if "D" in cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS:
            raise ValueError("Currently distance creation using Zarr by chunks is not implemented.")
        # NOTE(review): other float channels (e.g. Db/Dc) would be cast to uint8 here; confirm.
        dtype_str = "uint8"

        # For synapses the process of the channel creation is different: not patch by patch but
        # painting each post-synaptic point for each pre-synaptic point
        if cfg.PROBLEM.INSTANCE_SEG.TYPE == "synapses" and len(Y) > 0:
            synapse_channel_creation(
                data_info=Y,
                zarr_data_information=zarr_data_information,
                savepath=split_cfg.INSTANCE_CHANNELS_MASK_DIR,
                mode=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS,
                channel_extra_opts=cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS_EXTRA_OPTS[0],
            )
        else:  # regular instances, not synapses
            mask = None  # output dataset/array of the file currently being written
            fid_mask = None  # open output H5 file handle (None for Zarr outputs)
            imgfile = None
            last_parallel_file = None
            out_data_shape, out_data_order, channel_pos = None, None, None
            axes_order = split_cfg.INPUT_IMG_AXES_ORDER

            for i in tqdm(range(rank, len(Y), world_size), disable=not is_main_process()):
                filepath = Y[i]["filepath"]
                # Extract the patch to process
                patch_coords = Y[i]["patch_coords"]
                img = load_img_part_from_efficient_file(
                    filepath,
                    patch_coords,
                    data_axes_order=axes_order,
                    data_path=split_cfg.INPUT_ZARR_MULTIPLE_DATA_GT_PATH,
                )
                if img.ndim == 3:
                    img = np.expand_dims(img, -1)

                # Create the instance mask
                img = _instance_channels_from_gt(cfg, tag, img, filepath)

                # Create the Zarr/H5 file where the mask will be placed when a new input file starts
                if mask is None or os.path.basename(filepath) != last_parallel_file:
                    last_parallel_file = os.path.basename(filepath)

                    # Close the output H5 file of the previous input file, if any
                    if isinstance(fid_mask, h5py.File):
                        fid_mask.close()
                    fid_mask = None

                    if path_to_gt_data:
                        imgfile, data = read_chunked_nested_data(filepath, path_to_gt_data)
                    else:
                        imgfile, data = read_chunked_data(filepath)

                    out_dir = split_cfg.INSTANCE_CHANNELS_MASK_DIR
                    fname = os.path.join(out_dir, os.path.basename(filepath))
                    os.makedirs(out_dir, exist_ok=True)

                    # Determine data shape: same as the input but with the new number of channels
                    out_data_shape = np.array(data.shape)
                    if "C" not in axes_order:
                        out_data_shape = tuple(out_data_shape) + (img.shape[-1],)
                        out_data_order = axes_order + "C"
                        channel_pos = -1
                    else:
                        channel_pos = axes_order.index("C")
                        out_data_shape[channel_pos] = img.shape[-1]
                        out_data_shape = tuple(out_data_shape)
                        out_data_order = axes_order

                    if looks_like_hdf5(fname):
                        fname, _ = os.path.splitext(fname)
                        fid_mask = h5py.File(fname + ".h5", "w")
                        mask = fid_mask.create_dataset(
                            "data",
                            shape=out_data_shape,
                            dtype=dtype_str,
                            chunks=pick_chunks(out_data_shape, dtype_str, target_mb=4.0),
                            compression="gzip",
                            compression_opts=4,
                            shuffle=True,
                        )
                    else:  # Zarr file
                        out_data_shape = tuple(int(s) for s in out_data_shape)
                        mask = zarr.open(fname, mode="w", shape=out_data_shape, dtype=dtype_str, zarr_format=3)

                    # Close H5 file read for the data shape
                    if isinstance(imgfile, h5py.File):
                        imgfile.close()
                    del data, fname

                slices = (
                    slice(patch_coords.z_start, patch_coords.z_end),
                    slice(patch_coords.y_start, patch_coords.y_end),
                    slice(patch_coords.x_start, patch_coords.x_end),
                    slice(0, out_data_shape[channel_pos]),
                )
                data_ordered_slices = tuple(
                    order_dimensions(
                        slices,
                        input_order="ZYXC",
                        output_order=out_data_order,
                        default_value=0,
                    )
                )

                # Adjust patch axes so it can be transposed before inserting it into the final data.
                # Axes missing from the output order come back as NaN and are dropped; the rest are
                # cast to int because ndarray.transpose rejects float axes.
                current_order = np.array(range(len(img.shape)))
                transpose_order = order_dimensions(
                    current_order,
                    input_order="ZYXC",
                    output_order=out_data_order,
                    default_value=np.nan,
                )
                transpose_order = [int(x) for x in np.array(transpose_order) if not np.isnan(x)]

                # Place the patch into the Zarr/H5 output
                mask[data_ordered_slices] = img.transpose(transpose_order)

            # Close the last open output H5 file so its content is flushed to disk
            if isinstance(fid_mask, h5py.File):
                fid_mask.close()
    else:
        for i in tqdm(range(rank, len(Y), world_size), disable=not is_main_process()):
            img_path = os.path.join(split_cfg.GT_PATH, Y[i])
            img = read_img_as_ndarray(img_path, is_3d=not cfg.PROBLEM.NDIM == "2D")
            img = _instance_channels_from_gt(cfg, tag, img, img_path)
            save_tif(
                np.expand_dims(img, 0),
                data_dir=split_cfg.INSTANCE_CHANNELS_MASK_DIR,
                filenames=[Y[i]],
                verbose=False,
            )
def unique_labels_fast(a: np.ndarray):
    """
    Find the unique labels in an integer array with values in ``[0, K]``.

    Runs in O(n) time and O(K) extra space by marking each present value in a boolean
    lookup table instead of sorting.

    Parameters
    ----------
    a : ndarray
        Input array of non-negative integers (any shape). Float values are truncated to int.

    Returns
    -------
    ndarray
        Sorted array of unique labels. Empty when ``a`` is empty.

    Raises
    ------
    ValueError
        If ``a`` contains negative values, which cannot be used as lookup indices.
    """
    a = np.asarray(a)
    if a.size == 0:
        return np.empty(0, dtype=np.intp)
    if a.min() < 0:
        # Negative values would silently index from the end of the lookup table
        raise ValueError("unique_labels_fast expects non-negative labels, got minimum {}".format(a.min()))

    K = int(a.max())
    present = np.zeros(K + 1, dtype=bool)
    present[a.ravel().astype(int)] = True  # O(n) and very cache-friendly
    return np.flatnonzero(present)  # sorted because we scan 0..K
def labels_into_channels(
    instance_labels: NDArray,
    mode: List[str] = ["I", "C"],
    channel_extra_opts: Dict = {},
    resolution: List[int | float] = [1, 1, 1],
    save_dir: Optional[str] = None,
) -> NDArray:
    """
    Convert instance segmentation masks into the channels used to train an instance segmentation problem.

    Parameters
    ----------
    instance_labels : 3D/4D Numpy array
        Instance labels to be used to extract the channels from, with the channel axis last
        (only the first channel is used as labels). E.g. ``(200, 1000, 1000, 1)``.

    mode : List, optional
        Channels to create, in output order. Handled values:

        - ``F``: foreground, binary mask of all instances. Optional per-instance
          ``erosion``/``dilation`` in ``channel_extra_opts["F"]``.
        - ``B``: background, binary mask of label-0 pixels. Optional ``erosion``/``dilation``
          in ``channel_extra_opts["B"]``.
        - ``P``: central part of each instance: centroid points (``type="centroid"``, default)
          or skeletons (``type="skeleton"``), then optional ``dilation``/``erosion`` (default 1).
        - ``C``: contours via ``find_boundaries`` with ``mode`` (default ``"thick"``), or
          ``"dense"`` for a dilated foreground rim.
        - ``Dc``: distance to the instance centroid or skeleton (``type``), optional ``norm``.
        - ``Db``: distance to the instance boundary. ``val_type`` is ``"norm"`` (default),
          ``"discretize"`` (with ``bin_size``) or anything else for raw distances.
        - ``Dn``: distance to neighbouring instances (``norm``, ``decline_power``, ``closing_size``).
        - ``D``: signed distance (foreground minus background, scaled by ``alpha``/``beta``),
          mapped with ``tanh`` unless ``norm`` is False.
        - ``H``/``V``/``Z``: horizontal/vertical/depth maps from ``create_HoVe_channels``.
        - ``Gv``/``Gh``/``Gz``: flow channels, ``gradient_type`` ``"cellpose"`` (default, ``niter``
          diffusion iterations) or ``"omnipose"``.
        - ``T``: touching area between instances.
        - ``A``: affinities; one channel per entry of ``z_affinities``/``y_affinities``/``x_affinities``.
        - ``R``: radial distances, ``nrays`` channels.
        - ``E_offset``: copy of the instance labels. ``E_sigma``/``E_seediness`` take no channel.
        - ``We``: U-Net border weight map.
        - ``M``: ``F`` + ``C`` clipped to ``[0, 1]`` (requires both ``F`` and ``C``).

    channel_extra_opts : dict, optional
        Additional options for each output channel (e.g., ``{"F": {"erosion": 1}}``).

    resolution : List of int/float
        Resolution of the data, in ``(z,y,x)``, used for anisotropic distances. E.g. ``[30,8,8]``.

    save_dir : str, optional
        Path to store samples of the created array, only to check it is correct.

    Returns
    -------
    new_mask : 3D/4D Numpy array
        Instance representations with the same spatial shape as ``instance_labels``. Channels are
        laid out in ``mode`` order, each entry taking as many consecutive channels as it produces.
        E.g. ``(200, 1000, 1000, 3)``.
    """
    assert len(resolution) == 3, "'resolution' must be a list of 3 int/float"
    assert instance_labels.ndim in [3, 4]

    # ---------- Channel layout ----------
    # Each entry of ``mode`` occupies a contiguous run of output channels: "R" takes ``nrays``,
    # "A" one per affinity, "E_sigma"/"E_seediness" none and every other entry one.
    ch_idx: Dict[str, int] = {}  # first output channel of each mode entry
    ch_len: Dict[str, int] = {}  # number of output channels of each mode entry
    c_number = 0
    for ch in mode:
        if ch == "R":
            n_ch = channel_extra_opts["R"]["nrays"]
        elif ch == "A":
            n_ch = (
                len(channel_extra_opts["A"]["z_affinities"])
                + len(channel_extra_opts["A"]["y_affinities"])
                + len(channel_extra_opts["A"]["x_affinities"])
            )
        elif ch in ["E_sigma", "E_seediness"]:
            n_ch = 0  # not special channels, just extra targets for embeddings
        else:
            n_ch = 1
        if ch not in ch_idx:
            ch_idx[ch] = c_number
            ch_len[ch] = n_ch
        c_number += n_ch

    def _channel(name: str) -> int:
        """Return the first output channel of ``name`` (ValueError if absent, like list.index)."""
        if name not in ch_idx:
            raise ValueError("'{}' is not in list".format(name))
        return ch_idx[name]

    def _is_active(k) -> bool:
        """Return True if a kernel size (int or per-axis list) has a positive entry."""
        return (isinstance(k, int) and k > 0) or (isinstance(k, list) and any(x > 0 for x in k))

    def _footprint(k, ndim: int):
        """Build an ellipse footprint from an isotropic int or a per-axis list of sizes."""
        return generate_ellipse_footprint([k] * ndim if isinstance(k, int) else k)

    if any(x in mode for x in ["Dc", "Dn", "D", "Z", "V", "H", "R", "We", "Gv", "Gh", "Gz"]):
        dtype = np.float32
    elif "Db" in mode:
        dtype = np.uint8 if channel_extra_opts.get("Db", {}).get("val_type", "norm") == "discretize" else np.float32
    elif "E_offset" in mode:
        dtype = instance_labels.dtype
        # Ensure that no floating-point dtype is used for the embeddings. This allows the
        # normalization module to correctly recognize them as integer or binary channels, ensuring
        # that the subsequent data augmentation processes the samples as intended.
        if np.issubdtype(dtype, np.floating):
            dtype = decide_dtype(instance_labels.max())
    else:
        dtype = np.uint8

    new_mask = np.zeros(instance_labels.shape[:-1] + (c_number,), dtype=dtype)
    vol = instance_labels[..., 0]
    if np.issubdtype(vol.dtype, np.floating):
        vol = vol.astype(np.uint32)

    # Precompute regionprops only when needed
    needs_props = (
        any(x in mode for x in ("Z", "V", "H", "Gv", "Gh", "Gz"))
        or ("P" in mode and channel_extra_opts.get("P", {}).get("type", "") == "skeleton")
        or ("Dc" in mode)
    )
    props_tbl = None
    if needs_props:
        # label, bbox, centroid; row i of the table matches instances[i] (0 is never a region)
        props_tbl = regionprops_table(vol, properties=("label", "bbox", "centroid"))
        instances = list(props_tbl["label"])
    else:
        instances = sorted(list(unique_labels_fast(vol)))
    instances = [inst for inst in instances if inst != 0]  # remove background
    if len(instances) == 0:
        return new_mask  # only background

    disable_tqdm = len(instances) < 2000
    fg_mask = (vol > 0).astype(np.uint8)
    bg_mask = (vol == 0).astype(np.uint8)

    # Precompute horizontal/vertical/depth channels if any of H/V/Z is requested
    if any(ch in mode for ch in ("Z", "V", "H")):
        norm_flag = True
        for ch in ("Z", "V", "H"):
            if ch in channel_extra_opts and "norm" in channel_extra_opts[ch]:
                norm_flag = bool(channel_extra_opts[ch]["norm"])
                break
        hv_channels = create_HoVe_channels(
            vol,
            ref_point="center",
            normalize_values=norm_flag,
            calc_props=props_tbl,
        )

    # ---------- Foreground (F) ----------
    # The (possibly label-valued) foreground is kept apart from new_mask so that label values
    # are not truncated by a narrow new_mask dtype before "We" uses them.
    f_labels = None
    if "F" in mode:
        er_k = channel_extra_opts.get("F", {}).get("erosion", 0)
        dil_k = channel_extra_opts.get("F", {}).get("dilation", 0)
        erode, dilate = _is_active(er_k), _is_active(dil_k)
        f_labels = fg_mask.astype(np.uint8)
        if erode or dilate:
            # Erosion/dilation is applied per instance so touching instances stay separated
            f_labels = np.zeros_like(fg_mask, dtype=vol.dtype)
            dil_fp = _footprint(dil_k, f_labels.ndim) if dilate else None
            er_fp = _footprint(er_k, f_labels.ndim) if erode else None
            for lb in tqdm(instances, disable=disable_tqdm):
                m = vol == lb
                if not np.any(m):
                    continue
                if dilate:
                    m = binary_dilation(m.astype(np.uint8), footprint=dil_fp).astype(np.uint8)
                if erode:
                    m = binary_erosion(m.astype(np.uint8), footprint=er_fp).astype(np.uint8)
                f_labels[m > 0] = lb
        new_mask[..., _channel("F")] = f_labels > 0

    # ---------- Background (B) ----------
    if "B" in mode:
        er_k = channel_extra_opts.get("B", {}).get("erosion", 0)
        dil_k = channel_extra_opts.get("B", {}).get("dilation", 0)
        erode, dilate = _is_active(er_k), _is_active(dil_k)
        mask = bg_mask.astype(np.uint8)
        if erode or dilate:
            # The background is built as the complement of the processed instances, so the
            # operations are inverted: dilating background == eroding instances and vice versa.
            mask = np.ones_like(fg_mask, dtype=np.uint8)
            dil_fp = _footprint(dil_k, mask.ndim) if dilate else None
            er_fp = _footprint(er_k, mask.ndim) if erode else None
            for lb in tqdm(instances, disable=disable_tqdm):
                m = vol == lb
                if not np.any(m):
                    continue
                if dilate:
                    m = binary_erosion(m.astype(np.uint8), footprint=dil_fp).astype(np.uint8)
                if erode:
                    m = binary_dilation(m.astype(np.uint8), footprint=er_fp).astype(np.uint8)
                mask[m > 0] = 0
        new_mask[..., _channel("B")] = mask

    # ---------- P (central part) ----------
    if "P" in mode:
        p_opts = channel_extra_opts.get("P", {})
        p_type = p_opts.get("type", "centroid")
        p_dil = p_opts.get("dilation", 1)
        p_ero = p_opts.get("erosion", 1)
        p_out = np.zeros_like(fg_mask, dtype=np.uint8)
        if p_type == "skeleton":
            for i, lb in tqdm(enumerate(instances), disable=disable_tqdm):
                slc = slice_from_props(props_tbl, i, vol.ndim)
                sub = vol[slc] == lb
                p_out[slc] += skeletonize(sub)
        else:
            com_list = center_of_mass(fg_mask, labels=vol, index=instances)
            # Mark each centroid (guard against rounding outside bounds)
            if p_out.ndim == 2:
                H, W = p_out.shape
                for cy, cx in com_list:
                    y, x = int(round(cy)), int(round(cx))
                    if 0 <= y < H and 0 <= x < W:
                        p_out[y, x] = 1
            elif p_out.ndim == 3:
                Z, Y, X = p_out.shape
                for cz, cy, cx in com_list:
                    z, y, x = int(round(cz)), int(round(cy)), int(round(cx))
                    if 0 <= z < Z and 0 <= y < Y and 0 <= x < X:
                        p_out[z, y, x] = 1
            else:
                raise ValueError(f"Unsupported ndim {p_out.ndim} for P[type='centroid']")

        # Optional dilation / erosion (in pixels / voxels)
        if _is_active(p_dil):
            p_out = binary_dilation(p_out, footprint=_footprint(p_dil, p_out.ndim)).astype(np.uint8)
        if _is_active(p_ero):
            p_out = binary_erosion(p_out, footprint=_footprint(p_ero, p_out.ndim)).astype(np.uint8)
        new_mask[..., _channel("P")] = p_out

    # ---------- C (contours) ----------
    if "C" in mode:
        c_mode = channel_extra_opts.get("C", {}).get("mode", "thick")
        if c_mode == "dense":
            # Synthetic "dense" edges: dilate FG and XOR with FG to thicken borders (per 2D slice in 3D)
            fg = fg_mask
            if fg.ndim == 2:
                new_mask[..., _channel("C")] = binary_dilation(fg, disk(1)).astype(np.uint8) ^ fg
            else:
                out = np.zeros_like(fg)
                for j in range(fg.shape[0]):
                    out[j] = binary_dilation(fg[j], disk(1)).astype(np.uint8) ^ fg[j]
                new_mask[..., _channel("C")] = out
        else:
            # valid skimage modes: inner|outer|thick|subpixel
            new_mask[..., _channel("C")] = find_boundaries(vol, mode=c_mode).astype(np.uint8)

    # ---------- Dc (distance to center/skeleton) ----------
    if "Dc" in mode:
        dc_type = channel_extra_opts.get("Dc", {}).get("type", "centroid")
        dc_channel = np.zeros_like(vol, dtype=np.float32)
        if vol.ndim == 3:
            cz_tab = props_tbl["centroid-0"]
            cy_tab = props_tbl["centroid-1"]
            cx_tab = props_tbl["centroid-2"]
        else:
            cy_tab = props_tbl["centroid-0"]
            cx_tab = props_tbl["centroid-1"]
        for i, lb in tqdm(enumerate(instances), disable=disable_tqdm):
            slc = slice_from_props(props_tbl, i, vol.ndim)
            sub = vol[slc] == lb
            if not sub.any():
                continue
            if dc_type == "skeleton":
                sk = skeletonize(sub)
                dist_to_sk = edt.edt(~sk, anisotropy=resolution, parallel=-1)
                dc_channel[slc][sub] = dist_to_sk[sub]
            elif vol.ndim == 3:
                cz, cy, cx = float(cz_tab[i]), float(cy_tab[i]), float(cx_tab[i])
                zz, yy, xx = np.ogrid[slc[0], slc[1], slc[2]]  # grids in global coords
                dist = np.sqrt((zz - cz) ** 2 + (yy - cy) ** 2 + (xx - cx) ** 2)
                dc_channel[slc][sub] = dist[sub]
            else:
                cy, cx = float(cy_tab[i]), float(cx_tab[i])
                yy, xx = np.ogrid[slc[0], slc[1]]
                dist = np.sqrt((yy - cy) ** 2 + (xx - cx) ** 2)
                dc_channel[slc][sub] = dist[sub]
        assert isinstance(dc_channel, np.ndarray), "Expected dc_channel to be a numpy array"
        if channel_extra_opts.get("Dc", {}).get("norm", False):
            dc_channel = norm_channel(dc_channel, vol, instances)
        new_mask[..., _channel("Dc")] = dc_channel

    # ---------- Db (distance to boundary) ----------
    if "Db" in mode:
        db_channel = edt.edt(vol, anisotropy=resolution, parallel=-1)
        assert isinstance(db_channel, np.ndarray), "Expected db to be a numpy array"
        val_type = channel_extra_opts.get("Db", {}).get("val_type", "norm")
        if val_type in ["norm", "discretize"]:
            db_channel = norm_channel(db_channel, vol, instances)
        if val_type == "discretize":
            db_dis_bin_size = channel_extra_opts.get("Db", {}).get("bin_size", 0.1)
            db_dis_K = int(round(1.0 / db_dis_bin_size))
            db_channel = np.clip(db_channel, 0.0, 1.0)
            fg = fg_mask.astype(bool)
            # bin index in [0, K-1]
            bin_idx = np.floor(db_channel / db_dis_bin_size).astype(np.uint8)
            bin_idx = np.clip(bin_idx, 0, db_dis_K - 1)
            labels = np.zeros(db_channel.shape, dtype=np.uint8)  # background = 0
            labels[fg] = bin_idx[fg] + 1  # foreground bins = 1..K
            db_channel = labels
        new_mask[..., _channel("Db")] = db_channel

    # ---------- Dn (distance to neighbor) ----------
    if "Dn" in mode:
        dn_opts = channel_extra_opts.get("Dn", {})
        dn_norm = bool(dn_opts.get("norm", False))
        power = float(dn_opts.get("decline_power", 3.0))
        closing_size = int(dn_opts.get("closing_size", 3))  # neighborhood size for grey closing
        dn_channel = np.zeros_like(vol, dtype=np.float32)
        # Pixels of cells that have at least one OTHER instance in the image
        has_neighbor_px = np.zeros_like(vol, dtype=bool)
        for lab in tqdm(instances, disable=disable_tqdm):
            cur = vol == lab
            if not np.any(cur):
                continue
            other_instances = fg_mask & (~cur)
            if not np.any(other_instances):
                # No other labeled object anywhere -> keep this cell at 0 (suppressed)
                continue
            has_neighbor_px[cur] = True
            # (selected cell ∪ background) = 1; "other cells" = 0 -> distance to other cells
            fg = bg_mask | cur
            d = edt.edt(fg, anisotropy=resolution, parallel=-1).astype(np.float32)
            if dn_norm:
                # cut -> normalize [0,1] -> invert (1 - ..)
                d_cell = d[cur].copy()
                m = d_cell.max()
                if m > 0:
                    d_cell /= m
                else:
                    d_cell.fill(0.0)
                dn_channel[cur] = 1.0 - d_cell
            else:
                # Store raw distances for now (inverted below)
                dn_channel[cur] = d[cur]

        invert_mask = has_neighbor_px  # operate only on cells that actually have neighbors
        # Unnormalized but still inverted; background & isolated cells remain 0
        if not dn_norm and np.any(invert_mask):
            M = float(dn_channel[invert_mask].max())
            if M > 0:
                dn_channel[invert_mask] = M - dn_channel[invert_mask]

        # Grayscale closing after merging & inversion
        if closing_size > 0:
            dn_channel = grey_closing(dn_channel, size=(closing_size,) * dn_channel.ndim)

        if power is not None and power != 1.0:
            if dn_norm:
                # values in [0,1] already -> direct power
                dn_channel = np.power(dn_channel, power, dtype=np.float32)
            elif np.any(invert_mask):
                # temporarily normalize on the meaningful support, power, then unnormalize
                M2 = float(dn_channel[invert_mask].max())
                if M2 > 0:
                    tmp = np.power(dn_channel[invert_mask] / M2, power, dtype=np.float32)
                    dn_channel[invert_mask] = tmp * M2
        new_mask[..., _channel("Dn")] = dn_channel.astype(np.float32)

    # ---------- D (signed distance, global) ----------
    if "D" in mode:
        alpha = channel_extra_opts.get("D", {}).get("alpha", 1.0)
        beta = channel_extra_opts.get("D", {}).get("beta", 1.0)
        sdist = (
            edt.edt(fg_mask, anisotropy=resolution, parallel=-1) / alpha
            - edt.edt(bg_mask, anisotropy=resolution, parallel=-1) / beta
        )
        assert isinstance(sdist, np.ndarray), "Expected sdist to be a numpy array"
        # Map GT to [-1, 1] with tanh (COSEM-style: "Whole-cell organelle segmentation in volume electron microscopy")
        if channel_extra_opts.get("D", {}).get("norm", True):
            sdist = np.tanh(sdist)
        new_mask[..., _channel("D")] = sdist

    # ---------- H / V / Z (horizontal/vertical/depth channels) ----------
    if "Z" in mode:
        new_mask[..., _channel("Z")] = hv_channels[..., 0]
    if "V" in mode:
        new_mask[..., _channel("V")] = hv_channels[..., 0 if vol.ndim == 2 else 1]
    if "H" in mode:
        new_mask[..., _channel("H")] = hv_channels[..., 1 if vol.ndim == 2 else 2]

    # ---------- Gv / Gh / Gz (flow-like channels) ----------
    if "Gv" in mode:
        niter = channel_extra_opts.get("Gv", {}).get("niter", 200)
        gtype = channel_extra_opts.get("Gv", {}).get("gradient_type", "cellpose")
        Gv = np.zeros_like(vol, dtype=np.float32)
        Gh = np.zeros_like(vol, dtype=np.float32)
        Gz = np.zeros_like(vol, dtype=np.float32) if vol.ndim == 3 else None

        if gtype == "omnipose":
            # 1. Per-cell distance transform (potential field). Computed per cell so touching
            #    cells do not merge into one region at shared interfaces.
            dist_field = np.zeros_like(vol, dtype=np.float32)
            for lb in instances:
                cell_mask = vol == lb
                dist_field[cell_mask] = edt.edt(cell_mask, anisotropy=resolution, parallel=-1)[cell_mask]
            # 2. Gradients with voxel spacing so anisotropic data gets correct directions
            if vol.ndim == 3:
                grads = np.gradient(dist_field, resolution[0], resolution[1], resolution[2])
            else:
                grads = np.gradient(dist_field, resolution[1], resolution[2])
            # 3. Normalize the flows into direction vectors
            mag = np.sqrt(sum(g**2 for g in grads) + 1e-6)
            grads = [g / mag for g in grads]
            # 4. Only write within foreground to leave background at 0
            fg = fg_mask > 0
            if vol.ndim == 3:
                Gz[fg] = grads[-3][fg]
            Gv[fg] = grads[-2][fg]
            Gh[fg] = grads[-1][fg]
        else:
            diff_coeff = 1.0 / (2 * vol.ndim)  # 2D (4 neighbors) -> 0.25 | 3D (6 neighbors) -> 1/6
            for i, lb in tqdm(enumerate(instances), disable=disable_tqdm):
                slc = slice_from_props(props_tbl, i, vol.ndim)
                cell = vol[slc] == lb
                # 1. Heat source: the mask pixel closest to the continuous centroid
                coords = np.nonzero(cell)
                centroid = np.array([c.mean() for c in coords])
                coord_stack = np.stack(coords, axis=1).astype(np.float32)  # (N, ndim)
                idx = int(np.argmin(np.sum((coord_stack - centroid) ** 2, axis=1)))
                centers = tuple(int(c[idx]) for c in coords)

                # 2. Iterative diffusion restricted to the cell mask
                heat = np.zeros_like(cell, dtype=np.float32)
                heat[centers] = 1.0
                for _ in range(niter):
                    # Slice-based shifts give zero (Dirichlet) boundary conditions; np.roll would wrap
                    neighbor_sum = np.zeros_like(heat)
                    for axis in range(heat.ndim):
                        src = [slice(None)] * heat.ndim
                        dst = [slice(None)] * heat.ndim
                        src[axis], dst[axis] = slice(None, -1), slice(1, None)
                        neighbor_sum[tuple(dst)] += heat[tuple(src)]
                        src[axis], dst[axis] = slice(1, None), slice(None, -1)
                        neighbor_sum[tuple(dst)] += heat[tuple(src)]
                    new_heat = neighbor_sum * diff_coeff
                    heat[coords] = new_heat[coords]
                    heat[centers] = 1.0  # keep the source hot

                # 3. np.gradient needs >= 2 elements per axis; thinner cells keep a zero flow
                if any(s < 2 for s in heat.shape):
                    continue
                if vol.ndim == 3:
                    grads = np.gradient(heat, resolution[0], resolution[1], resolution[2])
                else:
                    grads = np.gradient(heat, resolution[1], resolution[2])

                # 4. Normalize the flows
                mag = np.sqrt(sum(g**2 for g in grads) + 1e-6)
                grads = [g / mag for g in grads]

                # 5. Apply mask and save to global arrays (grads are [dz,] dy, dx)
                if vol.ndim == 3:
                    Gz[slc][cell] = grads[0][cell]
                    Gv[slc][cell] = grads[1][cell]
                    Gh[slc][cell] = grads[2][cell]
                else:
                    Gv[slc][cell] = grads[0][cell]
                    Gh[slc][cell] = grads[1][cell]

        if "Gz" in mode and Gz is not None:
            new_mask[..., _channel("Gz")] = Gz
        new_mask[..., _channel("Gv")] = Gv
        new_mask[..., _channel("Gh")] = Gh

    # ---------- T (touching area) ----------
    if "T" in mode:
        new_mask[..., _channel("T")] = touching_mask_nd(vol, connectivity=vol.ndim)

    # ---------- A (affinities) ----------
    if "A" in mode:
        a_opts = channel_extra_opts["A"]
        ins_vol = vol
        wb = int(a_opts.get("widen_borders", 1))
        if wb:
            ins_vol = seg_widen_border(vol, tsz_h=wb)
        a_start = _channel("A")
        for k, (zaff, yaff, xaff) in enumerate(
            zip(
                a_opts.get("z_affinities", []),
                a_opts.get("y_affinities", []),
                a_opts.get("x_affinities", []),
            )
        ):
            affs = seg2aff_pni(ins_vol, dz=zaff, dy=yaff, dx=xaff, dtype=dtype)  # (n_affs, Z, Y, X)
            affs = np.transpose(affs, (1, 2, 3, 0))  # (Z, Y, X, n_affs)
            new_mask[..., a_start + k * 3 : a_start + (k + 1) * 3] = affs

    # ---------- R (radial distances) ----------
    if "R" in mode:
        r_opts = channel_extra_opts.get("R", {})
        ndim = vol.ndim
        nrays = int(r_opts.get("nrays", 32 if ndim == 2 else 96))
        rays = generate_rays(n_rays=nrays, ndim=ndim).astype(np.float32)
        spacing = None if ndim == 2 else resolution
        r_start = _channel("R")
        new_mask[..., r_start : r_start + nrays] = radial_distances(vol, rays, spacing=spacing)

    # ---------- E (Embeddings) ----------
    # Only E_offset is used as extra target for the embeddings branch
    if "E_offset" in mode:
        new_mask[..., _channel("E_offset")] = vol.copy()

    # ---------- We (border weights) ----------
    if "We" in mode:
        if f_labels is not None:
            mask_to_use = f_labels
        elif "B" in mode:
            mask_to_use = new_mask[..., _channel("B")]
        else:
            mask_to_use = vol
        new_mask[..., _channel("We")] = unet_border_weight_map(mask_to_use, w0=10.0, sigma=5.0, resolution=resolution)

    # ---------- M (Legacy mask used in CartoCell) ----------
    if "M" in mode:
        # Binary mask = F + C
        f_ch = new_mask[..., _channel("F")]
        c_ch = new_mask[..., _channel("C")]
        new_mask[..., _channel("M")] = np.clip(f_ch + c_ch, 0, 1).astype(np.uint8)

    # Save examples of each channel
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        suffixes = {
            "B": "_background.tif",
            "F": "_foreground.tif",
            "P": "_central_part.tif",
            "C": "_contour.tif",
            "H": "_horizontal_distance.tif",
            "V": "_vertical_distance.tif",
            "Z": "_z_distance.tif",
            "Gv": "_vertical_flow.tif",
            "Gh": "_horizontal_flow.tif",
            "Gz": "_z_flow.tif",
            "Db": "_distance_to_border.tif",
            "Dc": "_distance_to_center.tif",
            "Dn": "_distance_to_neighbor.tif",
            "D": "_distance.tif",
            "R": "_radial_distances.tif",
            "T": "_touching.tif",
            "A": "_affinity.tif",
            "E_offset": "_embedding_instances.tif",
            "We": "_border_weights.tif",
            "M": "_CartoCell_M_channel.tif",
        }
        for mod in mode:
            if mod in ["E_sigma", "E_seediness"]:
                continue
            if mod not in suffixes:
                raise ValueError("Unknown channel type: {}".format(mod))
            start = ch_idx[mod]
            aux = new_mask[..., start : start + ch_len[mod]]
            save_tif(np.expand_dims(aux, 0), save_dir, filenames=["vol" + suffixes[mod]], verbose=False)
        save_tif(
            np.expand_dims(instance_labels, 0),
            save_dir,
            filenames=["vol_y.tif"],
            verbose=False,
        )

    return new_mask
def norm_channel(channel: NDArray, vol: NDArray, instances: list[int]) -> NDArray:
    """
    Min-max rescale ``channel`` independently inside every labelled instance.

    For each id ``k`` in ``instances`` (id 0 is skipped), the values of
    ``channel`` under ``vol == k`` are mapped linearly onto ``[0, 1]``. Voxels
    not covered by any listed instance, and instances whose values are all
    equal, are left at 0.

    Parameters
    ----------
    channel : NDArray
        Values to rescale (e.g. a distance channel).
    vol : NDArray
        Instance label volume with the same shape as ``channel``.
    instances : list[int]
        Instance ids to process. Background (0) is ignored.

    Returns
    -------
    NDArray
        ``float32`` array with the same shape as ``channel``.
    """
    out = np.zeros_like(channel, dtype=np.float32)
    for label_id in (lab for lab in instances if lab != 0):
        region = vol == label_id
        if not region.any():
            continue
        vals = channel[region]
        lo, hi = vals.min(), vals.max()
        # A flat instance would divide by zero, so it stays at 0.
        out[region] = 0 if hi == lo else (vals - lo) / (hi - lo)
    return out
[docs] def slice_from_props(props_tbl: pd.DataFrame | dict, i: int, ndim: int) -> tuple[slice, ...]: """ Get a slice representation from the properties table for a specific instance. Parameters ---------- props_tbl : pd.DataFrame | dict The properties table containing region properties. i : int The index of the instance in the properties table. ndim : int The number of dimensions (2 or 3). Returns ------- tuple[slice, ...] A tuple of slice objects representing the bounding box of the instance. """ if ndim == 2: y0 = int(props_tbl['bbox-0'][i]) x0 = int(props_tbl['bbox-1'][i]) y1 = int(props_tbl['bbox-2'][i]) x1 = int(props_tbl['bbox-3'][i]) return (slice(y0, y1), slice(x0, x1)) elif ndim == 3: z0 = int(props_tbl['bbox-0'][i]) y0 = int(props_tbl['bbox-1'][i]) x0 = int(props_tbl['bbox-2'][i]) z1 = int(props_tbl['bbox-3'][i]) y1 = int(props_tbl['bbox-4'][i]) x1 = int(props_tbl['bbox-5'][i]) return (slice(z0, z1), slice(y0, y1), slice(x0, x1)) else: raise ValueError("Only 2D or 3D volumes are supported.")
[docs] def unet_border_weight_map( instances: np.ndarray, w0: float = 10.0, sigma: float = 5.0, apply_only_background: bool = True, resolution: List[int|float] | None = None, ) -> np.ndarray: """ U-Net border-aware weight map (Ronneberger et al. 2015) for 2D or 3D labels. Parameters ---------- instances : np.ndarray, shape (H, W) or (D, H, W), dtype int 0/`background` for background, 1..N (or any ints != background) are instance ids. w0 : float Border weight magnitude. sigma : float Spatial decay (in same units as resolution). apply_only_background : bool If True, apply the exponential term only on background (as in the paper). resolution : List[int|float] | None Voxel spacing along each axis (z,y,x) or (y,x). If None, isotropic spacing of 1 is assumed. Returns ------- w : np.ndarray, same shape as `instances`, dtype float32 Border weight map. """ if instances.ndim not in (2, 3): raise ValueError(f"`instances` must be 2D or 3D, got shape {instances.shape}") inst = instances.astype(np.int32, copy=False) shp = inst.shape # collect unique instance ids excluding background ids = np.unique(inst) ids = ids[ids != 0] # Special handling when exactly one instance is present: # treat background as a pseudo-second instance so we still emphasize the object boundary. 
if ids.size == 1: lab = ids[0] # Distance to the (only) instance: zeros inside the instance d_obj = edt.edt(inst != lab, anisotropy=resolution, parallel=-1).astype(np.float32, copy=False) # Distance to background: zeros in background d_bg = edt.edt(inst != 0, anisotropy=resolution, parallel=-1).astype(np.float32, copy=False) denom = 2.0 * (sigma ** 2) w_border = w0 * np.exp(-((d_obj + d_bg) ** 2) / denom, dtype=np.float32) w_border = w_border.astype(np.float32, copy=False) if apply_only_background: w_border *= (inst == 0) return w_border # Need at least two distinct instances for the (d1 + d2) term to be meaningful if ids.size < 2: return np.zeros(shp, dtype=np.float32) # Compute distance-to-each-instance via EDT on the complement of that instance # distances[k, ...] = distance to instance ids[k] distances = np.empty((ids.size, *shp), dtype=np.float32) for k, lab in tqdm(enumerate(ids), total=len(ids)): # edt computes distance to zeros -> pass mask that's zero *inside* the object # equivalently: distance to the boundary of object `lab` distances[k] = edt.edt(inst != lab, anisotropy=resolution, parallel=-1) # nearest and second-nearest distances at each voxel/pixel d1 = distances.min(axis=0) d2 = np.partition(distances, 1, axis=0)[1] # Border emphasis term denom = 2.0 * (sigma ** 2) w_border = w0 * np.exp(-((d1 + d2) ** 2) / denom, dtype=np.float32) w_border = w_border.astype(np.float32, copy=False) if apply_only_background: w_border *= (inst == 0) return w_border
def touching_mask_nd(labels: NDArray, connectivity: int = 1) -> NDArray:
    """
    Create a binary mask of touching pixels/voxels for an N-D labeled instance mask.

    A foreground element is marked when at least one neighbour inside the
    ``generate_binary_structure`` footprint carries a different, non-zero
    label. Positions outside the array count as background (0), so nothing
    wraps around the borders.

    Parameters
    ----------
    labels : NDArray
        N-D array of instance labels (0 = background, 1..N = instances).
    connectivity : int, optional
        Neighborhood connectivity passed to `generate_binary_structure`.
        1 = 6-neigh for 3D / 4-neigh for 2D,
        2 = 18-neigh for 3D / 8-neigh for 2D,
        3 = 26-neigh for 3D (if ndim==3).

    Returns
    -------
    touch : NDArray
        ``uint8`` mask with 1 where an element touches at least one *different* instance.
    """
    labels = np.asarray(labels)
    footprint = generate_binary_structure(labels.ndim, connectivity)
    center = np.array(footprint.shape) // 2

    # Zero padding reproduces mode='constant', cval=0: out-of-bounds neighbours are background.
    padded = np.pad(labels, 1, mode="constant", constant_values=0)
    touch = np.zeros(labels.shape, dtype=bool)

    # Compare every element with each neighbour offset at once, instead of
    # calling a Python function per voxel (generic_filter) that also casts to float64.
    for offset in np.argwhere(footprint) - center:
        if not offset.any():
            continue  # the centre itself is not a neighbour
        window = tuple(slice(1 + o, 1 + o + n) for o, n in zip(offset, labels.shape))
        neigh = padded[window]
        touch |= (neigh != 0) & (neigh != labels)

    # Background is never touching.
    touch &= labels != 0
    return touch.astype(np.uint8)
def generate_rays(n_rays: int, ndim: int, jitter: bool=False, seed: int=0):
    """
    Produce ``n_rays`` unit direction vectors in 2D or 3D.

    - 2D: angles evenly spaced on the circle, rows are ``[cos(a), sin(a)]``.
    - 3D: Fibonacci-sphere sampling, rows are ``[x, y, z]`` with ``z`` the
      polar coordinate.

    Parameters
    ----------
    n_rays : int
        Number of rays to generate.
    ndim : int
        Dimensionality (2 or 3).
    jitter : bool, optional
        Randomly perturb the 3D azimuth and height (default: False). Ignored in 2D.
    seed : int, optional
        Seed of the jitter random generator (default: 0).

    Returns
    -------
    rays : (n_rays, 2) or (n_rays, 3) Numpy array
        ``float32`` unit vectors along which to compute distances.

    Raises
    ------
    ValueError
        If ``ndim`` is neither 2 nor 3.
    """
    if ndim not in (2, 3):
        raise ValueError("Only 2D and 3D are supported.")

    if ndim == 2:
        angles = np.linspace(0, 2*np.pi, n_rays, endpoint=False, dtype=np.float32)
        return np.stack([np.cos(angles), np.sin(angles)], axis=1).astype(np.float32)

    # Fibonacci sphere: evenly spaced heights, azimuth advancing by the golden angle.
    idx = np.arange(n_rays, dtype=np.float32)
    golden = (1 + np.sqrt(5.0)) / 2.0
    z = 1 - 2*(idx + 0.5) / n_rays
    theta = 2*np.pi*idx/golden
    if jitter:
        rng = np.random.default_rng(seed)
        theta += rng.uniform(-np.pi/n_rays, np.pi/n_rays, size=n_rays)
        z += rng.uniform(-1/n_rays, 1/n_rays, size=n_rays)
        z = np.clip(z, -1.0, 1.0)
    radius = np.sqrt(np.maximum(0.0, 1 - z*z))

    dirs = np.stack([radius * np.cos(theta), radius * np.sin(theta), z], axis=1).astype(np.float32)
    dirs /= (np.linalg.norm(dirs, axis=1, keepdims=True) + 1e-12)
    return dirs
def radial_distances(
    labels: NDArray,
    rays: NDArray,
    max_dist: Optional[float] = None,
    spacing: Optional[Sequence[float]] = None,
    max_iters: int = 50,
) -> NDArray:
    """
    Compute radial distances from each foreground pixel to the instance boundary along specified rays.

    Every foreground element marches in unit steps (index space) along each
    ray. It stops at the first sample, rounded to the nearest element, that is
    outside the array or carries a different label. The overshoot is then
    pulled back by ``1 - 0.5 / max(|u|)`` steps (StarDist-style correction).

    Parameters
    ----------
    labels : NDArray
        2D or 3D array of instance labels (0 = background, 1..N = instances).
    rays : (n_rays, 2) or (n_rays, 3) Numpy array
        Directions (in index order, one column per axis of ``labels``) along
        which to compute distances; they are normalized internally.
    max_dist : float, optional
        Maximum distance to cap at. If None, no capping is done.
    spacing : sequence of float, optional
        Physical spacing of the data in each dimension. If None, assumes isotropic spacing of 1.0.
    max_iters : int
        Step-cap multiplier: each ray marches at most
        ``max_iters * (max(labels.shape) + 2)`` unit steps. A ray that hits
        nothing within the cap keeps distance 0.

    Returns
    -------
    D : NDArray
        ``float32`` array of shape (H, W, n_rays) or (D, H, W, n_rays) with
        distances in physical units. Background pixels have distance 0 in all rays.
    """
    labels = np.asarray(labels)
    ndim = labels.ndim
    assert rays.ndim == 2 and rays.shape[1] == ndim
    spacing = np.ones(ndim, np.float32) if spacing is None else np.asarray(spacing, np.float32)
    shape = labels.shape
    n_rays = rays.shape[0]

    # Unit rays in index space (row/col[/z] units).
    rays_idx = rays.astype(np.float32)
    norms = np.linalg.norm(rays_idx, axis=1, keepdims=True) + 1e-12
    rays_idx /= norms
    # Physical length of one index-space unit step along each ray.
    ray_step_phys = np.linalg.norm(rays_idx * spacing, axis=1)  # shape (n_rays,)

    D = np.zeros(shape + (n_rays,), np.float32)
    fg = np.argwhere(labels > 0)
    if fg.size == 0 or n_rays == 0:
        return D

    inst_ids = labels[tuple(fg.T)]
    p0 = fg.astype(np.float32)  # element-centre start positions
    upper = np.asarray(shape)
    max_steps = max_iters * (max(shape) + 2)

    for k in range(n_rays):
        u = rays_idx[k]
        max_comp = np.max(np.abs(u)) + 1e-12
        t_corr = 1.0 - 0.5 / max_comp
        # After s steps the offset is identical for every start pixel, so all
        # pixels are marched together; `active` holds those still inside their instance.
        x = np.zeros(ndim, np.float32)
        active = np.arange(len(fg))
        for _ in range(max_steps):
            if active.size == 0:
                break
            x += u
            samp = np.rint(p0[active] + x).astype(np.intp)
            hit = np.any((samp < 0) | (samp >= upper), axis=1)
            inside = ~hit
            if inside.any():
                hit[inside] = labels[tuple(samp[inside].T)] != inst_ids[active[inside]]
            if not hit.any():
                continue
            # Pull back the overshoot along the dominant axis, then convert to physical units.
            dist = float(np.linalg.norm(x - t_corr * u)) * float(ray_step_phys[k])
            if max_dist is not None and dist > max_dist:
                dist = max_dist
            D[tuple(fg[active[hit]].T) + (k,)] = dist
            active = active[~hit]
    return D
def euler_integration(flow: NDArray, coords: NDArray, n_steps: int = 200, dt: float = 1.0, suppressed: bool = True):
    """
    Euler integration of flow field starting at coords.

    At every step the flow is sampled at the current positions with linear
    interpolation (``mode='nearest'`` at the borders). Each component is added
    to the matching coordinate, and positions are clipped to the volume.

    Parameters
    ----------
    flow : (2, H, W) or (3, D, H, W) Numpy array
        Flow field (y,x) or (z,y,x). The number of spatial axes defines the
        dimensionality; the first that many components are used.
    coords : (N, 2) or (N, 3) Numpy array
        Starting coordinates (y,x) or (z,y,x) in index space. Not modified.
    n_steps : int
        Number of integration steps.
    dt : float
        Integration step size.
    suppressed : bool
        Whether to use time-suppressed integration (dt/(t+1)) or not (constant dt).

    Returns
    -------
    pos : (N, 2) or (N, 3) Numpy array
        Final float positions after integration.

    Raises
    ------
    ValueError
        If ``flow`` has fewer components than spatial axes, or ``coords`` is
        not an ``(N, ndim)`` array.
    """
    flow = np.asarray(flow)
    spatial_shape = flow.shape[1:]
    ndim = len(spatial_shape)
    if ndim == 0 or flow.shape[0] < ndim:
        raise ValueError(f"`flow` must have shape (ndim, *spatial) with ndim components, got {flow.shape}")
    pos = np.array(coords, dtype=float)  # fresh copy: the caller's coords are never mutated
    if pos.ndim != 2 or pos.shape[1] != ndim:
        raise ValueError(f"`coords` must have shape (N, {ndim}), got {pos.shape}")

    upper = np.asarray(spatial_shape, dtype=float) - 1
    for t in range(n_steps):
        # Interpolate every flow component at the current positions
        sample_at = [pos[:, d] for d in range(ndim)]
        step = np.stack(
            [map_coordinates(flow[d], sample_at, order=1, mode='nearest') for d in range(ndim)],
            axis=1,
        )
        factor = dt / (t + 1) if suppressed else dt
        pos += factor * step
        # Keep inside bounds
        np.clip(pos, 0, upper, out=pos)
    return pos  # final positions for clustering
def _in_bounds(p: np.ndarray, shape_zyx: Tuple[int, int, int]) -> bool: """ Check if a point p (z,y,x) is within the bounds of a shape (Z,Y,X). Parameters ---------- p : np.ndarray A point in (z,y,x) coordinates. shape_zyx : Tuple[int, int, int] The shape of the volume in (Z,Y,X). Returns ------- bool True if p is within bounds, False otherwise. """ # p is (3,) z,y,x return bool(np.all((p >= 0) & (p < np.asarray(shape_zyx)))) def _bbox_from_points(points_zyx: np.ndarray, width_zyx: np.ndarray, shape_zyx: Tuple[int, int, int]) -> List[int]: """ Compute a bounding box [z0,z1,y0,y1,x0,x1] that contains all points in `points_zyx` with an additional `width_zyx` margin, while ensuring the box is within the bounds of `shape_zyx`. Parameters ---------- points_zyx : np.ndarray An array of shape (N, 3) containing points in (z,y,x) coordinates. width_zyx : np.ndarray An array of shape (3,) specifying the margin to add in each dimension (z,y,x). shape_zyx : Tuple[int, int, int] The shape of the volume in (Z,Y,X) to ensure the bounding box does not exceed these bounds. Returns ------- List[int] A list containing the bounding box coordinates [z0,z1,y0,y1,x0,x1]. Returns ------- List[int] A list containing the bounding box coordinates [z0,z1,y0,y1,x0,x1]. """ shape = np.asarray(shape_zyx) lo = np.maximum(0, points_zyx.min(axis=0) - width_zyx) hi = np.minimum(shape, points_zyx.max(axis=0) + width_zyx) # make hi exclusive and ints return [int(lo[0]), int(hi[0]), int(lo[1]), int(hi[1]), int(lo[2]), int(hi[2])] def _make_output_array(fname: str, data_shape_zyx, channels: int, zinfo: Dict, dtype_str="float32"): """ Create an output array for the generated mask, either as a Zarr or HDF5 dataset depending on the file extension of `fname`. Parameters ---------- fname : str The filename for the output array. If it ends with .h5 or .hdf the function will create an HDF5 dataset; otherwise, it will create a Zarr array. 
data_shape_zyx : Tuple[int, int, int] The shape of the data in (Z,Y,X) order. channels : int The number of channels in the output array. zinfo : Dict A dictionary containing metadata, including the "axes_order" key which specifies the order of axes in the output array (e.g. "ZYXC" or "ZCYX"). dtype_str : str The data type for the output array (default: "float32"). Returns ------- mask : Zarr array or HDF5 dataset The created output array for the generated mask. fid_mask : h5py.File or None The HDF5 file object if an HDF5 dataset was created, or None if a Zarr array was created. out_shape : Tuple[int, ...] The shape of the output array, including channels, in the order specified by `zinfo["axes_order"]`. out_order : str The order of axes in the output array, as specified by `zinfo["axes_order"]` with "C" for channels if not already included. c_pos : int or None The position of the channel axis in the output array, or None if channels are added as the last axis. """ os.makedirs(os.path.dirname(fname), exist_ok=True) out_data_shape = np.array(data_shape_zyx, dtype=int) axes = zinfo["axes_order"] if "C" not in axes: out_shape = tuple(out_data_shape) + (channels,) out_order = axes + "C" c_pos = -1 else: out_shape = out_data_shape.copy() out_shape[axes.index("C")] = channels out_shape = tuple(int(s) for s in out_shape) out_order = axes c_pos = axes.index("C") is_h5 = looks_like_hdf5(fname) fid_mask = None if is_h5: base, _ = os.path.splitext(fname) fid_mask = h5py.File(base + ".h5", "w") ds_kwargs = { "shape": out_shape, "dtype": dtype_str, "chunks": pick_chunks(out_shape, dtype_str, target_mb=4.0), "compression": "gzip", "compression_opts": 4, "shuffle": True, } mask = fid_mask.create_dataset("data", **ds_kwargs) else: mask = zarr.open(fname, mode="w", shape=out_shape, dtype=dtype_str, zarr_format=3) return mask, fid_mask, out_shape, out_order, c_pos def _clip_slices(z, y, x, hz, hy, hx, Z, Y, X): z0, z1 = max(0, z - hz), min(Z, z + hz + 1) y0, y1 = max(0, y - hy), 
min(Y, y + hy + 1) x0, x1 = max(0, x - hx), min(X, x + hx + 1) return slice(z0, z1), slice(y0, y1), slice(x0, x1) def _beam_score(smoothed, z, y, x, beam_hw, score_mode="p10"): hz, hy, hx = beam_hw Z, Y, X = smoothed.shape slz, sly, slx = _clip_slices(z, y, x, hz, hy, hx, Z, Y, X) slab = smoothed[slz, sly, slx] if slab.size == 0: return np.inf if score_mode == "min": return float(slab.min()) elif score_mode == "p10": return float(np.percentile(slab, 10)) else: raise ValueError(f"Unknown score_mode: {score_mode}")
def synapse_channel_creation(
    data_info: Dict,
    zarr_data_information: Dict,
    savepath: str,
    mode: Optional[List[str]] = None,
    channel_extra_opts: Optional[Dict[str, Dict]] = None,
    verbose: bool = False,
):
    """
    Create different channels that represent a synapse segmentation problem to train an instance segmentation problem.

    This function is only prepared to read an H5/Zarr file that follows
    `CREMI data format <https://cremi.org/data/>`__. Files are distributed
    across ranks (file ``idx`` is handled by rank ``idx % world_size``).

    Parameters
    ----------
    data_info : dict
        All patches that can be extracted from all the Zarr/H5 samples in ``data_path``.
        Keys created are:

        * ``"filepath"``: path to the file where the patch was extracted.
        * ``"full_shape"``: shape of the data within the file where the patch was extracted.
        * ``"patch_coords"``: coordinates of the data that represents the patch.

    zarr_data_information : dict
        Information when using Zarr/H5 files. Assumes that the H5/Zarr files contain the
        information according `CREMI data format <https://cremi.org/data/>`__. The following
        keys are expected:

        * ``"raw_data_path"``: path within the file where the raw data is stored (only
          read in ``["F_cleft"]`` mode). Reference in CREMI: ``volumes/raw``
        * ``"axes_order"``: order of the axes in the file. E.g. "ZYX" or "ZCYX".
        * ``"z_axe_pos"``: position of z axis of the data within the file.
        * ``"y_axe_pos"``: position of y axis of the data within the file.
        * ``"x_axe_pos"``: position of x axis of the data within the file.
        * ``"id_path"``: path within the file where the ``ids`` are stored.
          Reference in CREMI: ``annotations/ids``
        * ``"partners_path"``: path within the file where ``partners`` is stored.
          Reference in CREMI: ``annotations/partners``
        * ``"locations_path"``: path within the file where ``locations`` is stored.
          Reference in CREMI: ``annotations/locations``
        * ``"resolution_path"``: path of the object whose ``attrs["resolution"]`` holds the
          voxel size. Reference in CREMI: ``["volumes/raw"].attrs["resolution"]``

    savepath : str
        Directory where the output arrays are written (one per input file, same basename).

    mode : List[str], optional
        Channels to create. Supported combinations (order only matters for the position of
        each channel in the output):

        * ``["F_pre", "F_post"]`` (default): dilated pre/post-synaptic points, uint8.
        * ``["F_post", "Z", "V", "H"]``: post-synaptic mask plus distance channels, float32.
        * ``["F_cleft"]``: estimated synaptic cleft between each pre/post pair, uint8.
        * ``["F_post"]``: dilated post-synaptic points only, uint8.

    channel_extra_opts : dict, optional
        Extra options for specific channels. Expected keys are:

        * ``"F_pre"``: ``"dilation"`` list of 3 ints (z,y,x) for the "F_pre" channel (default: [1,3,3]).
        * ``"F_post"``: ``"dilation"`` list of 3 ints (z,y,x) for the "F_post" channel (default: [1,3,3]).
        * ``"H"``, ``"V"``, ``"Z"``: ``"norm"`` whether to normalize the distance channels per
          instance (default: True). ``"H"`` may also carry ``"dilation"`` (default: [3,25,25]),
          used as the pre-synaptic dilation in that mode.
        * ``"F_cleft"``: ``"dilation"``, ``"search_dilation"``, ``"n_samples"``, ``"t_range"``,
          ``"beam_halfwidth"``, ``"drop_rel"``, ``"drop_abs"``, ``"baseline_frac"``,
          ``"min_persist"`` controlling the cleft search.

    verbose : bool, optional
        Whether to print warnings about out-of-bounds synaptic points (default: False).

    Returns
    -------
    None
        Results are written to ``savepath`` through ``_make_output_array``.

    Raises
    ------
    ValueError
        On an unsupported ``mode``, a missing ``resolution`` attribute, or (point modes) a
        post-synaptic point falling outside its patch.
    """
    if mode is None:
        mode = ["F_pre", "F_post"]
    if channel_extra_opts is None:
        channel_extra_opts = {}

    # -------------------------------------------------------
    # 1. Determine channel count, dtype and options from the mode
    # -------------------------------------------------------
    if all(ch in mode for ch in ["F_pre", "F_post"]) and len(mode) == 2:
        channels, dtype_str, selected_mode = 2, "uint8", "simpsyn"
    elif all(ch in mode for ch in ["F_post", "Z", "V", "H"]) and len(mode) == 4:
        channels, dtype_str, selected_mode = 4, "float32", "synful"
        norm = any(channel_extra_opts.get(k, {}).get("norm", True) for k in ["Z", "V", "H"])
    elif all(ch in mode for ch in ["F_cleft"]) and len(mode) == 1:
        channels, dtype_str, selected_mode = 1, "uint8", "cleft"
        cleft_opts = channel_extra_opts.get("F_cleft", {})
        cleft_footprint = generate_ellipse_footprint(cleft_opts.get("dilation", [1, 3, 3]))
        cleft_search_dilation = cleft_opts.get("search_dilation", [1, 5, 5])
        cleft_n_samples = int(cleft_opts.get("n_samples", 51))
        cleft_t_range = cleft_opts.get("t_range", (0.15, 0.85))
        beam_halfwidth = cleft_opts.get("beam_halfwidth", [0, 3, 3])  # [hz, hy, hx]
        drop_rel = float(cleft_opts.get("drop_rel", 0.20))  # 20% drop from baseline
        drop_abs = float(cleft_opts.get("drop_abs", 10.0))  # absolute drop in intensity units
        baseline_frac = float(cleft_opts.get("baseline_frac", 0.15))  # first 15% of path = baseline
        min_persist = int(cleft_opts.get("min_persist", 2))  # must stay dark for N steps
        # Loop invariants of the cleft search, computed once
        t_start, t_end = cleft_t_range
        ts = np.linspace(t_start, t_end, cleft_n_samples, dtype=np.float32)
        beam_hw = [int(v) for v in beam_halfwidth]
        sz, sy, sx = [int(v) for v in cleft_search_dilation]
    elif all(ch in mode for ch in ["F_post"]) and len(mode) == 1:
        channels, dtype_str, selected_mode = 1, "uint8", "F_post_only"
    else:
        raise ValueError(f"Unsupported mode: {mode}")

    if selected_mode == "synful":
        presite_dilation = channel_extra_opts.get("H", {}).get("dilation", [3, 25, 25])
    else:
        presite_dilation = channel_extra_opts.get("F_pre", {}).get("dilation", [1, 3, 3])
    postsite_dilation = channel_extra_opts.get("F_post", {}).get("dilation", [1, 3, 3])
    pre_footprint = generate_ellipse_footprint(presite_dilation)
    post_footprint = generate_ellipse_footprint(postsite_dilation)
    # Margin added around the synaptic points when cutting each patch
    width = np.array([max(a, b) for a, b in zip(presite_dilation, postsite_dilation)])

    # data_info lists patches; several may come from the same file
    unique_files, unique_shapes, seen_files = [], [], set()
    for i in range(len(data_info)):
        fpath = data_info[i]["filepath"]
        if fpath not in seen_files:
            seen_files.add(fpath)
            unique_files.append(fpath)
            unique_shapes.append(data_info[i]["full_shape"])

    rank = get_rank()
    world_size = get_world_size()
    print("Collecting all pre/post-synaptic points")
    for idx in tqdm(range(rank, len(unique_files), world_size), disable=not is_main_process()):
        filename, data_shape = unique_files[idx], tuple(unique_shapes[idx])
        print(f"Processing file: {filename}")

        # Handles opened for this file; released in `finally` even when the file is skipped
        files = []
        data_file = None
        fid_mask = None
        try:
            # ---------------- Read annotations ----------------
            file, ids = read_chunked_nested_data(filename, zarr_data_information["id_path"])
            ids = list(np.array(ids))
            files.append(file)
            file, partners = read_chunked_nested_data(filename, zarr_data_information["partners_path"])
            partners = np.array(partners)
            files.append(file)
            file, locations = read_chunked_nested_data(filename, zarr_data_information["locations_path"])
            locations = np.array(locations)
            files.append(file)
            file, resolution = read_chunked_nested_data(filename, zarr_data_information["resolution_path"])
            files.append(file)
            try:
                resolution = resolution.attrs["resolution"]
            except (AttributeError, KeyError, TypeError) as err:
                raise ValueError(
                    "There is no 'resolution' attribute in '{}'. Add it like: data['{}'].attrs['resolution'] = (8,8,8)".format(
                        zarr_data_information["resolution_path"], zarr_data_information["resolution_path"]
                    )
                ) from err

            if selected_mode == "cleft":
                data_file, raw_data = read_chunked_nested_data(filename, zarr_data_information["raw_data_path"])

            # Annotations are already materialized as numpy arrays: release their handles early
            for f in files:
                if isinstance(f, h5py.File):
                    f.close()
            files = []

            # ---------------- Pair pre/post points ----------------
            id_to_pos = {sid: i for i, sid in enumerate(ids)}
            shape_zyx = tuple(int(v) for v in data_shape)

            pre_post_points = {}  # pre_loc tuple -> list[post_loc tuple]
            pre_seen, post_seen = set(), set()
            pre_missed, post_missed = set(), set()  # unique ids discarded for being out of bounds
            for pre_id, post_id in tqdm(partners, disable=not is_main_process()):
                pre_idx = id_to_pos.get(pre_id)
                post_idx = id_to_pos.get(post_id)
                if pre_idx is None or post_idx is None:
                    # inconsistent annotation; skip quietly
                    continue

                # Physical locations -> voxel indices
                pre_loc = (locations[pre_idx] // resolution).astype(int)
                post_loc = (locations[post_idx] // resolution).astype(int)
                pre_ok = _in_bounds(pre_loc, shape_zyx)
                post_ok = _in_bounds(post_loc, shape_zyx)
                if not pre_ok:
                    if verbose:
                        print(f"WARNING: discarding presynaptic point {pre_loc} out of shape: {shape_zyx}")
                    pre_missed.add(pre_id)
                if not post_ok:
                    if verbose:
                        print(f"WARNING: discarding postsynaptic point {post_loc} out of shape: {shape_zyx}")
                    post_missed.add(post_id)
                if pre_ok and post_ok:
                    pre_key = tuple(pre_loc.tolist())
                    pre_post_points.setdefault(pre_key, []).append(tuple(post_loc.tolist()))
                    pre_seen.add(pre_id)
                    post_seen.add(post_id)

            print(f"Total unique pre-synaptic points: {len(pre_seen)}")
            print(f"Total unique post-synaptic points: {len(post_seen)}")
            print(f"Total pre-synaptic points missed: {len(pre_missed)}")
            print(f"Total post-synaptic points missed: {len(post_missed)}\n")

            if not pre_post_points:
                continue

            # ---------------- Output ----------------
            out_fname = os.path.join(savepath, os.path.basename(filename))
            mask, fid_mask, out_shape, out_order, c_pos = _make_output_array(
                out_fname, shape_zyx, channels, zarr_data_information, dtype_str=dtype_str
            )

            if selected_mode == "synful":
                _process_synapses_by_chunks_of_data(
                    channel_mode=mode,
                    selected_mode=selected_mode,
                    resolution=resolution,
                    pre_post_points=pre_post_points,
                    shape_zyx=shape_zyx,
                    mask=mask,
                    fid_mask=fid_mask,
                    out_shape=out_shape,
                    out_order=out_order,
                    c_pos=c_pos,
                    width=width,
                    pre_footprint=pre_footprint,
                    post_footprint=post_footprint,
                    norm=norm,
                )
                continue

            if selected_mode not in ["simpsyn", "cleft", "F_post_only"]:
                raise NotImplementedError(f"Mode {selected_mode} not implemented for processing by synapse pairs.")

            print("Processing synapse pairs . . .")
            for pre_point_global, post_sites in tqdm(pre_post_points.items(), disable=not is_main_process()):
                if not post_sites:
                    continue
                pre_point_global = np.asarray(pre_point_global, dtype=int)
                post_sites_arr = np.asarray(post_sites, dtype=int)

                # Patch = bounding box of the pre point and all its partners, grown by `width`
                bbox_points = np.vstack([pre_point_global[None, :], post_sites_arr])
                bbox = _bbox_from_points(bbox_points, width, shape_zyx)
                origin = np.asarray([bbox[0], bbox[2], bbox[4]], dtype=int)
                patch_shape = (bbox[1] - bbox[0], bbox[3] - bbox[2], bbox[5] - bbox[4])
                data_slices = (
                    slice(bbox[0], bbox[1]),
                    slice(bbox[2], bbox[3]),
                    slice(bbox[4], bbox[5]),
                    slice(0, out_shape[c_pos]),
                )
                data_slices = tuple(
                    order_dimensions(data_slices, input_order="ZYXC", output_order=out_order, default_value=0)
                )
                pre_local = pre_point_global - origin

                out_map = np.zeros(patch_shape + (channels,), dtype=np.uint8)
                if selected_mode == "simpsyn":
                    # Pre point
                    out_map[
                        max(0, pre_local[0] - 1) : min(pre_local[0] + 1, out_map.shape[0]),
                        pre_local[1],
                        pre_local[2],
                        mode.index("F_pre"),
                    ] = 1
                    # Post points
                    for post_global in post_sites_arr:
                        post_local = post_global - origin
                        if not _in_bounds(post_local, patch_shape):
                            raise ValueError(f"Point {post_local.tolist()} out of shape: {patch_shape}")
                        out_map[
                            max(0, post_local[0] - 1) : min(post_local[0] + 1, out_map.shape[0]),
                            post_local[1],
                            post_local[2],
                            mode.index("F_post"),
                        ] = 1
                    # Dilate each channel with its own footprint
                    for c in range(out_map.shape[-1]):
                        structure = pre_footprint if c == mode.index("F_pre") else post_footprint
                        out_map[..., c] = binary_dilation_scipy(out_map[..., c], iterations=1, structure=structure)

                elif selected_mode == "cleft":
                    raw_data_slices = (slice(bbox[0], bbox[1]), slice(bbox[2], bbox[3]), slice(bbox[4], bbox[5]))
                    raw_data_slices = tuple(
                        order_dimensions(
                            raw_data_slices,
                            input_order="ZYX",
                            output_order=zarr_data_information["axes_order"],
                            default_value=0,
                        )
                    )
                    raw_patch = np.asarray(raw_data[raw_data_slices], dtype=np.uint8)
                    # Local-mean smoothing makes the "darkest point" search more stable
                    smoothed = uniform_filter(raw_patch, size=(2 * sz + 1, 2 * sy + 1, 2 * sx + 1), mode="nearest")

                    cleft_seed = np.zeros(patch_shape, dtype=np.uint8)
                    pre_f = pre_local.astype(np.float32)
                    Z, Y, X = patch_shape
                    for post_global in post_sites_arr:
                        post_local = post_global - origin
                        if not _in_bounds(post_local, patch_shape):
                            continue
                        post_f = post_local.astype(np.float32)
                        # Sample points from POST -> PRE (important for "first drop")
                        pts = post_f[None, :] + ts[:, None] * (pre_f[None, :] - post_f[None, :])
                        coords = np.round(pts).astype(int)

                        # Beam score along the path (inf where the sample leaves the patch)
                        prof = np.full((len(ts),), np.inf, dtype=np.float32)
                        for s, (z, y, x) in enumerate(coords):
                            if 0 <= z < Z and 0 <= y < Y and 0 <= x < X:
                                prof[s] = _beam_score(smoothed, z, y, x, beam_hw, score_mode="p10")

                        # Baseline = robust bright-ish level near the POST side
                        nb = max(3, int(np.ceil(baseline_frac * len(ts))))
                        base_vals = prof[:nb]
                        base_vals = base_vals[np.isfinite(base_vals)]
                        if base_vals.size == 0:
                            # fallback: midpoint
                            mid = np.round((pre_f + post_f) * 0.5).astype(int)
                            mid = np.clip(mid, [0, 0, 0], np.array(patch_shape) - 1)
                            cleft_seed[tuple(mid.tolist())] = 1
                            continue
                        baseline = float(np.median(base_vals))

                        # "First dark region": prof <= baseline - max(abs_drop, rel_drop * baseline),
                        # held for min_persist consecutive samples
                        thr = baseline - max(drop_abs, drop_rel * baseline)
                        hit_idx = None
                        for s in range(nb, len(ts)):
                            if prof[s] <= thr and np.all(prof[s:min(len(ts), s + min_persist)] <= thr):
                                hit_idx = s
                                break
                        if hit_idx is None:
                            # Fallback: darkest sample past the baseline window (the whole
                            # profile if that window already covers every sample)
                            start = nb if nb < len(ts) else 0
                            hit_idx = int(np.nanargmin(prof[start:])) + start

                        z, y, x = coords[hit_idx]
                        z, y, x = int(np.clip(z, 0, Z - 1)), int(np.clip(y, 0, Y - 1)), int(np.clip(x, 0, X - 1))
                        cleft_seed[z, y, x] = 1

                    # Paint the cleft "ball" around every seed (single output channel)
                    out_map[..., 0] = binary_dilation_scipy(
                        cleft_seed, iterations=1, structure=cleft_footprint
                    ).astype(np.uint8)

                else:  # "F_post_only"
                    for post_global in post_sites_arr:
                        post_local = post_global - origin
                        if not _in_bounds(post_local, patch_shape):
                            raise ValueError(f"Point {post_local.tolist()} out of shape: {patch_shape}")
                        out_map[
                            max(0, post_local[0] - 1) : min(post_local[0] + 1, out_map.shape[0]),
                            post_local[1],
                            post_local[2],
                            0,
                        ] = 1
                    out_map[..., 0] = binary_dilation_scipy(out_map[..., 0], iterations=1, structure=post_footprint)

                # Transpose into the output axis order before writing
                current_order = np.arange(out_map.ndim)
                transpose_order = order_dimensions(
                    current_order, input_order="ZYXC", output_order=out_order, default_value=np.nan
                )
                transpose_order = [int(v) for v in transpose_order if not np.isnan(v)]
                patch = out_map.transpose(transpose_order)

                # Write only where the output is still empty (earlier patches win)
                target = mask[data_slices]
                mask[data_slices] = target + patch * (target == 0)
        finally:
            for f in files:
                if isinstance(f, h5py.File):
                    f.close()
            if isinstance(fid_mask, h5py.File):
                fid_mask.close()
            # The raw-data handle may not be an h5py.File; close it only if it can be closed
            close_raw = getattr(data_file, "close", None)
            if callable(close_raw):
                close_raw()
def _process_synapses_by_chunks_of_data(
    channel_mode: List[str],
    selected_mode: str,
    resolution: Sequence[float],
    pre_post_points: Dict[Tuple[int, int, int], List[Tuple[int, int, int]]],
    shape_zyx: Tuple[int, int, int],
    mask: Union[h5py.Dataset, zarr.Array],
    fid_mask: Optional[h5py.File],
    out_shape: Tuple[int, ...],
    out_order: str,
    c_pos: Optional[int],
    width: np.ndarray,
    pre_footprint: np.ndarray,
    post_footprint: np.ndarray,
    norm: bool,
):
    """
    Process synapses chunk by chunk, writing the output mask incrementally.

    The volume is split into chunks of ``(64, 256, 256)`` voxels (Z, Y, X). For each chunk, only the synapses
    whose halo-extended bounding box touches it are rasterised. This happens inside a patch enlarged by a halo,
    so that dilations and flow fields near the chunk border are complete. Only the chunk core is then written
    to ``mask``. Already non-zero voxels of ``mask`` are never overwritten. This keeps memory bounded for large
    volumes with many synapses.

    Parameters
    ----------
    channel_mode : List[str]
        Channel names of the output mask, in output order, e.g. ``["F_post", "H", "V", "Z"]``. ``"F_post"`` is the
        dilated postsynaptic foreground; ``"H"``, ``"V"`` and ``"Z"`` are distance channels computed with
        ``create_HoVe_channels`` relative to each post site's presynaptic partner.
    selected_mode : str
        Mode of operation. Only ``"synful"`` is implemented here; any other value raises ``NotImplementedError``.
    resolution : Sequence[float]
        Physical resolution of the data in (Z, Y, X) order.
    pre_post_points : Dict[Tuple[int, int, int], List[Tuple[int, int, int]]]
        Maps each presynaptic point (z, y, x) to the list of its postsynaptic partner points (z, y, x).
    shape_zyx : Tuple[int, int, int]
        Shape of the data in (Z, Y, X) order.
    mask : Union[h5py.Dataset, zarr.Array]
        Output array where the generated mask is written.
    fid_mask : Optional[h5py.File]
        HDF5 file backing ``mask``, or ``None`` for Zarr output. It is closed by this function, also when an
        error is raised.
    out_shape : Tuple[int, ...]
        Shape of the output array, including channels, in ``out_order`` order.
    out_order : str
        Axis order of the output array (e.g. ``"ZYXC"`` or ``"ZCYX"``).
    c_pos : Optional[int]
        Position of the channel axis in ``out_shape``. If ``None``, ``len(channel_mode)`` channels are written.
    width : np.ndarray
        Per-axis (z, y, x) margin. ``width[0]`` is the half-height of the Z column painted as the seed of each
        postsynaptic point; all entries contribute to the halo size.
    pre_footprint : np.ndarray
        Structuring element used to grow the region over which distance channels are computed.
    post_footprint : np.ndarray
        Structuring element used to dilate the postsynaptic points (``"F_post"`` channel).
    norm : bool
        Whether to normalize the distance channels per instance.

    Raises
    ------
    NotImplementedError
        If ``selected_mode`` is not ``"synful"``.
    """
    print("Processing synapses by chunks of data . . .")
    try:
        if selected_mode != "synful":
            raise NotImplementedError(f"Mode {selected_mode} not implemented for processing by chunks.")

        def _footprint_radius(fp) -> np.ndarray:
            """Return the half-size of a structuring element along each axis."""
            return np.array([s // 2 for s in np.asarray(fp).shape], dtype=int)

        # Halo: the largest reach of any seeding/dilation operation, plus 2 voxels of safety margin
        # against rounding in the resolution-scaled nearest-neighbour assignment.
        seed_reach = np.asarray(width, dtype=int)
        halo = np.maximum.reduce([seed_reach, _footprint_radius(pre_footprint), _footprint_radius(post_footprint)]) + 2

        cz, cy, cx = 64, 256, 256
        z_max, y_max, x_max = map(int, shape_zyx)
        vol_shape = np.array([z_max, y_max, x_max])

        # Spatial index: chunk (iz, iy, ix) -> ids of the synapses whose halo-extended bbox touches it,
        # so every synapse is registered in every chunk it could bleed into.
        syn_items = list(pre_post_points.items())
        chunk_to_syn: Dict[Tuple[int, int, int], List[int]] = {}
        print("Building spatial index for synapses . . .")
        for syn_id, (pre_pt, post_list) in enumerate(syn_items):
            if not post_list:
                continue
            pts = np.vstack([np.asarray(pre_pt, dtype=int)[None, :], np.asarray(post_list, dtype=int)])
            zmin, ymin, xmin = np.min(pts, axis=0) - halo
            zmax, ymax, xmax = np.max(pts, axis=0) + halo
            cz0, cz1 = max(0, zmin // cz), min((z_max - 1) // cz, zmax // cz)
            cy0, cy1 = max(0, ymin // cy), min((y_max - 1) // cy, ymax // cy)
            cx0, cx1 = max(0, xmin // cx), min((x_max - 1) // cx, xmax // cx)
            for iz in range(int(cz0), int(cz1) + 1):
                for iy in range(int(cy0), int(cy1) + 1):
                    for ix in range(int(cx0), int(cx1) + 1):
                        chunk_to_syn.setdefault((iz, iy, ix), []).append(syn_id)

        # Loop invariants.
        res_array = np.asarray(resolution, dtype=np.float32)
        axis_order = "".join(c for c in channel_mode if c in ("H", "V", "Z"))
        # Permutation taking a ZYXC array to `out_order`.
        patch_perm = np.argsort([out_order.find(ax) for ax in "ZYXC"])
        n_out_channels = out_shape[c_pos] if c_pos is not None else len(channel_mode)

        def _render_chunk(core_lo: np.ndarray, core_hi: np.ndarray, syn_ids: List[int]) -> Optional[np.ndarray]:
            """Render the ZYXC channels of one chunk core; return None when there is nothing to write."""
            # Halo-extended patch around the chunk core, clipped to the volume.
            ext_lo = np.maximum(core_lo - halo, 0)
            ext_hi = np.minimum(core_hi + halo, vol_shape)
            patch_shape = tuple(int(v) for v in ext_hi - ext_lo)
            patch_limits = np.asarray(patch_shape)

            seeds = np.zeros(patch_shape, dtype=np.uint32)
            mask_to_grow = np.zeros(patch_shape, dtype=np.uint8)
            label_to_pre_site = {}
            label_count = 1
            for syn_id in syn_ids:
                pre_pt, post_list = syn_items[syn_id]
                pre_local = np.asarray(pre_pt) - ext_lo
                for post_pt in post_list:
                    post_local = np.asarray(post_pt) - ext_lo
                    # Points slightly outside the patch still matter: their footprint can bleed back in.
                    if not np.all((post_local >= -halo) & (post_local < patch_limits + halo)):
                        continue
                    z_idx, yy, xx = int(post_local[0]), int(post_local[1]), int(post_local[2])
                    z0s = max(0, z_idx - seed_reach[0])
                    z1s = min(patch_shape[0], z_idx + seed_reach[0] + 1)
                    # Only paint when the column centre (yy, xx) is inside the patch and the Z range is valid.
                    if not (0 <= yy < patch_shape[1] and 0 <= xx < patch_shape[2]) or z0s >= z1s:
                        continue
                    col = seeds[z0s:z1s, yy, xx]  # basic slicing -> view into `seeds`
                    empty = col == 0
                    if not empty.any():
                        continue
                    col[empty] = label_count
                    label_to_pre_site[label_count] = pre_local.tolist()
                    # Single-voxel seed that is later dilated into the post-site region.
                    if 0 <= z_idx < patch_shape[0]:
                        mask_to_grow[z_idx, yy, xx] = 1
                    label_count += 1

            if label_count == 1:
                return None

            post_site_map = binary_dilation_scipy(mask_to_grow, structure=post_footprint)
            mask_grow = binary_dilation_scipy(mask_to_grow, structure=pre_footprint)
            seed_coords = np.argwhere(seeds > 0)
            mask_coords = np.argwhere(mask_grow > 0)
            # Nothing grown means nothing to write (guards against using undefined/stale labels).
            if seed_coords.size == 0 or mask_coords.size == 0:
                return None

            # Give each grown voxel the label of its nearest seed in physical (resolution-scaled) space.
            tree = cKDTree(seed_coords * res_array)
            _, nn = tree.query(mask_coords * res_array, workers=-1)
            nearest = seed_coords[nn]
            grown_labels = np.zeros_like(seeds)
            grown_labels[mask_coords[:, 0], mask_coords[:, 1], mask_coords[:, 2]] = seeds[
                nearest[:, 0], nearest[:, 1], nearest[:, 2]
            ]

            out_flow = create_HoVe_channels(
                grown_labels,
                ref_point="presynaptic",
                label_to_pre_site=label_to_pre_site,
                normalize_values=norm,
                resolution=resolution,
                axis_order=axis_order,
            )

            stack = []
            for c in channel_mode:
                if c == "F_post":
                    stack.append(post_site_map[..., None])
                else:
                    j = axis_order.index(c)
                    stack.append(out_flow[..., j : j + 1])
            out_map = np.concatenate(stack, axis=-1)

            # Crop the chunk core out of the halo-extended patch.
            lo = core_lo - ext_lo
            hi = lo + (core_hi - core_lo)
            return out_map[lo[0] : hi[0], lo[1] : hi[1], lo[2] : hi[2], :]

        ncz, ncy, ncx = (z_max + cz - 1) // cz, (y_max + cy - 1) // cy, (x_max + cx - 1) // cx
        for iz in tqdm(range(ncz), disable=not is_main_process()):
            for iy in range(ncy):
                for ix in range(ncx):
                    syn_ids = chunk_to_syn.get((iz, iy, ix))
                    if not syn_ids:
                        continue
                    core_lo = np.array([iz * cz, iy * cy, ix * cx])
                    core_hi = np.minimum(core_lo + np.array([cz, cy, cx]), vol_shape)
                    out_core = _render_chunk(core_lo, core_hi, syn_ids)
                    if out_core is None:
                        continue

                    patch = out_core.transpose(patch_perm)
                    axis_slices = {
                        "Z": slice(int(core_lo[0]), int(core_hi[0])),
                        "Y": slice(int(core_lo[1]), int(core_hi[1])),
                        "X": slice(int(core_lo[2]), int(core_hi[2])),
                        "C": slice(0, n_out_channels),
                    }
                    data_slices = tuple(axis_slices[ax] for ax in out_order if ax in axis_slices)

                    # Accumulate: only fill voxels that are still empty, so neighbouring chunks are never zeroed.
                    target = mask[data_slices]
                    mask[data_slices] = target + patch.astype(target.dtype) * (target == 0)
    finally:
        if isinstance(fid_mask, h5py.File):
            fid_mask.close()
def create_HoVe_channels(
    data: NDArray,
    ref_point: str = "center",
    label_to_pre_site: Optional[Dict] = None,
    normalize_values: bool = True,
    calc_props: Optional[Dict] = None,
    axis_order: str = "ZYX",
    resolution: Sequence[int | float] = (1, 1, 1),
):
    """
    Obtain the horizontal/vertical (and depth, for 3D) distance maps of each instance.

    For every labelled instance, each voxel inside it stores its coordinate relative to a reference point.
    The reference point is the instance centroid or the instance's presynaptic site, rounded to the nearest
    voxel, and coordinates are taken on a 1-based grid over the instance's enlarged bounding box. Voxels outside
    any instance are 0.

    Parameters
    ----------
    data : 2D/3D Numpy array
        Instance mask to create horizontal/vertical/depth channels from. E.g. ``(500, 500)`` for 2D and
        ``(200, 1000, 1000)`` for 3D.

    ref_point : str, optional
        Reference point used to create the channels. Possible values:

          - ``"center"``: point to the centroid.
          - ``"presynaptic"``: point to the presynaptic site. ``label_to_pre_site`` must be provided.

    label_to_pre_site : dict, optional
        Maps each label in ``data`` to its presynaptic site coordinates (in ``data`` coordinates). Required when
        ``ref_point`` is ``"presynaptic"``. The dictionary and its values are not modified.

    normalize_values : bool, optional
        If ``True``, the negative and positive values of each instance are scaled independently to ``[-1, 0)``
        and ``(0, 1]``. If ``False``, values are multiplied by ``resolution``.

    calc_props : dict, optional
        Precomputed ``regionprops_table`` output (``label``, ``bbox-*`` and ``centroid-*`` keys), to avoid
        recalculation.

    axis_order : str, optional
        Order of the output channels. ``"Z"`` selects the depth map (3D only; ignored for 2D data), ``"V"`` or
        ``"Y"`` the vertical (row) map, and ``"H"`` or ``"X"`` the horizontal (column) map. Other characters are
        ignored.

    resolution : sequence of int or float, optional
        Physical resolution of the data. Its last ``data.ndim`` entries scale the values when
        ``normalize_values`` is ``False``. Default is ``(1, 1, 1)`` (isotropic).

    Returns
    -------
    new_mask : 3D/4D Numpy array
        Stacked float32 maps, channels last in ``axis_order`` order. E.g. ``(500, 500, 2)`` for 2D and
        ``(200, 1000, 1000, 3)`` for 3D with the default ``axis_order``.

    Raises
    ------
    ValueError
        If ``ref_point`` is ``"presynaptic"`` and ``label_to_pre_site`` is missing or lacks a label.
    """
    assert ref_point in ["center", "presynaptic"]
    if ref_point == "presynaptic" and label_to_pre_site is None:
        raise ValueError("'label_to_pre_site' must be provided when 'ref_point' is 'presynaptic'")

    dim = data.ndim
    # One output map per spatial axis, in array-axis order: (Y, X) for 2D, (Z, Y, X) for 3D.
    axis_maps = [np.zeros(data.shape, dtype=np.float32) for _ in range(dim)]
    spacing = list(resolution)[-dim:]

    props = regionprops_table(data, properties=("label", "bbox", "centroid")) if calc_props is None else calc_props

    for k, inst_id in tqdm(enumerate(props["label"]), total=len(props["label"]), leave=False):
        # Bounding box enlarged by 2 voxels per side, clipped to the volume.
        lo = [max(0, int(props[f"bbox-{d}"][k]) - 2) for d in range(dim)]
        hi = [min(data.shape[d], int(props[f"bbox-{d + dim}"][k]) + 2) for d in range(dim)]
        if any(h - l < 2 for l, h in zip(lo, hi)):
            continue
        box = tuple(slice(l, h) for l, h in zip(lo, hi))
        # Crop before comparing so the cost is proportional to the bbox, not the whole volume.
        inst_mask = data[box] == inst_id

        if ref_point == "center":
            ref = [float(props[f"centroid-{d}"][k]) for d in range(dim)]
        else:
            if inst_id not in label_to_pre_site:
                raise ValueError(f"Label {inst_id} not in 'label_to_pre_site'")
            # Copy into a fresh list so the caller's mapping is never mutated (and tuples work too).
            ref = [float(v) for v in label_to_pre_site[inst_id]][:dim]

        # Move the reference point into bbox coordinates.
        rel = [r - l for r, l in zip(ref, lo)]
        if any(np.isnan(rel)):
            continue
        # Round half up. np.floor keeps this right for negative offsets (e.g. a presynaptic site
        # outside the bbox), where int() would truncate toward zero.
        center = [int(np.floor(r + 0.5)) for r in rel]

        grids = np.meshgrid(
            *[np.arange(1, inst_mask.shape[d] + 1) - center[d] for d in range(dim)],
            indexing="ij",
        )
        for d in range(dim):
            values = grids[d].astype(np.float32)
            values[~inst_mask] = 0
            if normalize_values:
                neg = values < 0
                if neg.any():
                    values[neg] /= -values[neg].min()
                pos = values > 0
                if pos.any():
                    values[pos] /= values[pos].max()
            else:
                values = values * spacing[d]
            # `axis_maps[d][box]` is a view, so this writes into the full-size map.
            axis_maps[d][box][inst_mask] = values[inst_mask]

    # Channel letter -> spatial axis. "Z" only exists for 3D data.
    channel_to_axis = {"V": dim - 2, "Y": dim - 2, "H": dim - 1, "X": dim - 1}
    if dim == 3:
        channel_to_axis["Z"] = 0
    stack = [axis_maps[channel_to_axis[c]] for c in axis_order if c in channel_to_axis]
    hv_map = np.stack(stack, axis=-1)
    return hv_map
############# # DETECTION # #############
def generate_ellipse_footprint(
    shape: Sequence[float] = (1, 1, 1),
) -> NDArray:
    """
    Generate a boolean n-dimensional ellipse (ellipsoid) footprint.

    Parameters
    ----------
    shape : sequence of numbers, optional
        Semi-axis length of the ellipse along each dimension. A positive value ``s`` gives an axis of
        ``2 * int(s) + 1`` samples (for integer ``s``); a value ``<= 0`` collapses that axis to a single sample.
        A ``0`` semi-axis is treated as ``1`` in the distance computation to avoid dividing by zero.

    Returns
    -------
    distances : NDArray
        Boolean footprint that is ``True`` where the normalized squared distance to the center is ``<= 1``.
    """
    semi_axes = list(shape)
    center = (np.array(semi_axes) / 2).astype(int)
    ranges = [
        np.arange(int(c - s), int(c + s) + 1) if s > 0 else np.array([c])
        for c, s in zip(center, semi_axes)
    ]
    grids = np.meshgrid(*ranges, indexing="ij")
    # Avoid a zero denominator for collapsed axes.
    denominators = [1 if s == 0 else s for s in semi_axes]
    distances = np.sum(
        np.array([((g - c) ** 2) / d**2 for g, c, d in zip(grids, center, denominators)]),
        axis=0,
    ) <= 1
    return distances.astype(bool)
def create_detection_masks(cfg: CN, data_type: str = "train"):
    """
    Create detection masks from CSV point annotations.

    For every CSV file in ``cfg.DATA.<TAG>.GT_PATH``, a mask with the spatial shape of its matching image in
    ``cfg.DATA.<TAG>.PATH`` is written to ``cfg.DATA.<TAG>.DETECTION_MASK_DIR``. Each point is painted as an
    ellipse with semi-axes ``cfg.PROBLEM.DETECTION.CENTRAL_POINT_DILATION``. Channel 0 marks occupancy. When
    ``cfg.DATA.N_CLASSES > 2``, a second channel stores the class from the ``class`` column. Voxels already
    painted by an earlier point are not overwritten. Existing output files are skipped, and files are
    distributed across ranks.

    If the image folder holds no regular files, its sub-directories are treated as chunked data and the mask
    is written to a disk-backed Zarr/HDF5 array. Otherwise the mask is built in memory and saved with
    ``save_tif``.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.

    data_type: str, optional
        Whether to create train, validation or test masks. One of ``"train"``, ``"val"``, ``"test"``.

    Raises
    ------
    ValueError
        If no images are found, the number of CSV files and images differ, or a CSV lacks a required column.
    """
    assert data_type in ["train", "val", "test"]
    tag = data_type.upper()
    img_dir = getattr(cfg.DATA, tag).PATH
    label_dir = getattr(cfg.DATA, tag).GT_PATH
    out_dir = getattr(cfg.DATA, tag).DETECTION_MASK_DIR

    img_ids = next(os_walk_clean(img_dir))[2]
    working_with_chunked_data = False
    if len(img_ids) == 0:
        img_ids = next(os_walk_clean(img_dir))[1]
        working_with_chunked_data = True
    if len(img_ids) == 0:
        raise ValueError(f"No data found in folder {img_dir}")
    img_ext = "." + img_ids[0].split(".")[-1]
    ids = next(os_walk_clean(label_dir))[2]
    channels = 2 if cfg.DATA.N_CLASSES > 2 else 1
    dtype_str = "uint8" if cfg.DATA.N_CLASSES < 255 else "uint16"

    if len(img_ids) != len(ids):
        raise ValueError(
            "Different number of CSV files and images found ({} vs {}). "
            "Please check that every image has one and only one CSV file".format(len(ids), len(img_ids))
        )
    if cfg.PROBLEM.NDIM == "2D":
        req_columns = ["axis-0", "axis-1"] if channels == 1 else ["axis-0", "axis-1", "class"]
    else:
        req_columns = ["axis-0", "axis-1", "axis-2"] if channels == 1 else ["axis-0", "axis-1", "axis-2", "class"]

    cpd = cfg.PROBLEM.DETECTION.CENTRAL_POINT_DILATION
    ellipse_footprint = generate_ellipse_footprint(cpd)

    # Distribute files by rank
    rank = get_rank()
    world_size = get_world_size()
    it = range(rank, len(ids), world_size)

    print(f"Rank {rank}: Creating {data_type} detection masks . . .")
    for i in tqdm(it, disable=not is_main_process()):
        img_filename = os.path.splitext(ids[i])[0] + img_ext
        file_path = os.path.join(label_dir, ids[i])
        if not os.path.exists(os.path.join(img_dir, img_filename)):
            print(
                "WARNING: No image found for CSV file: {}. Using the image that's in the same spot (within the CSV files list) where "
                "the CSV file is in its own list of CSV files. Check if it is correct!".format(file_path)
            )
            img_filename = img_ids[i]
        out_path = os.path.join(out_dir, img_filename)
        if os.path.exists(out_path):
            continue

        # 1. Read and validate the CSV *before* creating the output. A bad CSV must not leave an empty
        #    mask behind, because that file would be silently skipped on the next run.
        # reset_index: dropna() keeps the original (now gapped) index labels, which broke positional access.
        df = pd.read_csv(file_path).dropna().reset_index(drop=True)
        df = df.rename(columns=lambda x: x.strip())
        cols_not_in_file = [x for x in req_columns if x not in df.columns]
        if len(cols_not_in_file) > 0:
            if len(cols_not_in_file) == 1:
                m = f"'{cols_not_in_file[0]}' column is not present in CSV file: {file_path}"
            else:
                m = f"{cols_not_in_file} columns are not present in CSV file: {file_path}"
            raise ValueError(m)

        # Coordinates (axis-0, axis-1[, axis-2]) as plain arrays so that [j] is positional.
        coords = [df["axis-0"].astype(int).to_numpy(), df["axis-1"].astype(int).to_numpy()]
        if cfg.PROBLEM.NDIM == "3D":
            coords.append(df["axis-2"].astype(int).to_numpy())
        if "class" in df.columns:
            class_points = df["class"].astype(int).to_numpy()
        else:
            class_points = np.ones(len(coords[0]), dtype=int)

        # 2. Get shape information.
        if img_ext not in [".zarr", ".n5"]:
            img_info = read_img_as_ndarray(os.path.join(img_dir, img_filename), is_3d=not cfg.PROBLEM.NDIM == "2D")
            shape = img_info.shape[:-1]
            is_h5 = False
        else:
            img_zarr_file, img_data = read_chunked_data(os.path.join(img_dir, img_filename))
            shape = img_data.shape if img_data.ndim == 3 else img_data.shape[:-1]
            if isinstance(img_zarr_file, h5py.File):
                img_zarr_file.close()
            is_h5 = looks_like_hdf5(img_filename)

        # 3. Create the output: disk-backed (Zarr or HDF5) for chunked data, in memory otherwise.
        os.makedirs(out_dir, exist_ok=True)
        out_shape = shape + (channels,)
        fid_mask = None
        if working_with_chunked_data:
            if is_h5:
                fid_mask = h5py.File(out_path, "w")
                ds_kwargs = {
                    "shape": out_shape,
                    "dtype": dtype_str,
                    "chunks": pick_chunks(out_shape, dtype_str, target_mb=4.0),
                    "compression": "gzip",
                    "compression_opts": 4,
                    "shuffle": True,
                }
                mask = fid_mask.create_dataset("data", **ds_kwargs)
            else:
                mask = zarr.open(out_path, mode="w", shape=out_shape, dtype=dtype_str, zarr_format=3)
        else:
            mask = np.zeros(out_shape, dtype=dtype_str)

        # 4. Paint every point, one small neighbourhood at a time.
        try:
            ndim = len(shape)
            for j in range(len(coords[0])):
                c_point = class_points[j]
                global_c = [int(coords[d][j]) for d in range(ndim)]
                # Skip points whose center lies outside the array.
                if any(global_c[d] < 0 or global_c[d] >= shape[d] for d in range(ndim)):
                    continue

                slices = []
                rel_coords = []
                for d in range(ndim):
                    start = max(0, global_c[d] - 1 - cpd[d])
                    end = min(shape[d], global_c[d] + 2 + cpd[d])
                    slices.append(slice(start, end))
                    rel_coords.append(global_c[d] - start)
                target_slice = tuple(slices)

                # Local dilation of the single point inside the neighbourhood.
                local_chunk = mask[target_slice]
                patch = np.zeros(local_chunk.shape[:-1], dtype=np.uint8)
                patch[tuple(rel_coords)] = 1
                update = binary_dilation_scipy(patch, iterations=1, structure=ellipse_footprint)

                # Channel 0: occupancy. Only fill voxels that are still empty.
                mask_to_fill = (local_chunk[..., 0] == 0) & (update > 0)
                local_chunk[mask_to_fill, 0] = 1
                # Channel 1: class id.
                if channels > 1:
                    local_chunk[mask_to_fill, 1] = c_point
                mask[target_slice] = local_chunk
        finally:
            # Close the HDF5 file also when painting fails.
            if fid_mask is not None:
                fid_mask.close()

        if not working_with_chunked_data:
            save_tif(np.expand_dims(mask, 0), out_dir, [img_filename])
####### # SSL # #######
def create_ssl_source_data_masks(cfg: CN, data_type: str = "train"):
    """
    Create the degraded ("crappified") source images used for self-supervised learning.

    Every file in ``cfg.DATA.<TAG>.PATH`` that has no counterpart yet in ``cfg.DATA.<TAG>.SSL_SOURCE_DIR``
    is read, passed through ``crappify`` and saved there under the same name. Existing outputs are skipped.

    Parameters
    ----------
    cfg : YACS CN object
        Configuration.

    data_type: str, optional
        Whether to create train, validation or test source data. One of ``"train"``, ``"val"``, ``"test"``.
    """
    assert data_type in ["train", "val", "test"]
    split_cfg = getattr(cfg.DATA, data_type.upper())
    src_dir = split_cfg.PATH
    dst_dir = split_cfg.SSL_SOURCE_DIR
    filenames = next(os_walk_clean(src_dir))[2]
    ssl_cfg = cfg.PROBLEM.SELF_SUPERVISED
    noise_enabled = bool(ssl_cfg.NOISE > 0)

    print("Creating {} SSL source. . .".format(data_type))
    for fname in filenames:
        src_path = os.path.join(src_dir, fname)
        if os.path.exists(os.path.join(dst_dir, fname)):
            print("Source file {} found".format(src_path))
            continue

        print("Crappifying file {} to create SSL source".format(src_path))
        degraded = crappify(
            read_img_as_ndarray(src_path, is_3d=not cfg.PROBLEM.NDIM == "2D"),
            resizing_factor=ssl_cfg.RESIZING_FACTOR,
            add_noise=noise_enabled,
            noise_level=ssl_cfg.NOISE,
        )
        save_tif(np.expand_dims(degraded, 0), dst_dir, [fname])
def crappify(
    input_img: NDArray,
    resizing_factor: float,
    add_noise: bool = True,
    noise_level: Optional[float] = None,
    Down_up: bool = True,
):
    """
    Degrade an image by optionally adding Gaussian noise and then downsampling (and upsampling) it.

    Parameters
    ----------
    input_img : 3D/4D Numpy array
        Data to be modified. E.g. ``(y, x, channels)`` for 2D images or ``(z, y, x, channels)`` for 3D.

    resizing_factor : float
        Downsizing factor, must be positive. Each spatial axis is divided by ``sqrt(resizing_factor)``
        (and never shrunk below 1 sample).

    add_noise : bool, optional
        Whether to add Gaussian noise before resizing.

    noise_level : float, optional
        Number in ``[0, 1]`` giving the noise std as a fraction of the image maximum (see
        ``add_gaussian_noise``). Required (non-zero) when ``add_noise`` is ``True``.

    Down_up : bool, optional
        Whether to upsample back to the original size after downsampling. The result keeps the original size
        but has the quality loss of the down/up-sampling.

    Returns
    -------
    img : 3D/4D Numpy array
        Degraded image with the same dtype as ``input_img``.

    Raises
    ------
    ValueError
        If ``input_img`` is not 3D/4D, ``resizing_factor`` is not positive, or noise is requested without a
        ``noise_level``.
    """
    if input_img.ndim not in (3, 4):
        raise ValueError(f"Expected a 3D (y, x, c) or 4D (z, y, x, c) image, got {input_img.ndim} dimensions")
    if resizing_factor <= 0:
        raise ValueError(f"'resizing_factor' must be positive, got {resizing_factor}")

    org_sz = input_img.shape[:-1]
    scale = np.sqrt(resizing_factor)
    # Clamp to 1 so very large factors cannot produce a zero-length axis (which resize rejects).
    targ_sz = tuple(max(1, int(s / scale)) for s in org_sz)

    img = input_img
    if add_noise:
        if not noise_level:
            raise ValueError("'noise_level' must be a positive number when 'add_noise' is True")
        img = add_gaussian_noise(img, noise_level)

    img = resize(
        img,
        targ_sz,
        order=1,
        mode="reflect",
        clip=True,
        preserve_range=True,
        anti_aliasing=False,
    )
    if Down_up:
        img = resize(
            img,
            org_sz,
            order=1,
            mode="reflect",
            clip=True,
            preserve_range=True,
            anti_aliasing=False,
        )
    return img.astype(input_img.dtype)
def add_gaussian_noise(image: NDArray, percentage_of_noise: float) -> NDArray:
    """
    Add zero-mean Gaussian noise scaled to the image maximum.

    Parameters
    ----------
    image : Numpy array
        Image to perturb, e.g. ``(y, x, channels)``. Any shape works.

    percentage_of_noise : float
        Fraction of the image maximum used as the std of the Gaussian noise.

    Returns
    -------
    out : Numpy array
        Noisy image, clipped to ``[0, max(image)]`` and cast back to ``image.dtype``.
    """
    peak = np.max(image)
    sigma = percentage_of_noise * peak
    perturbed = image + np.random.normal(loc=0, scale=sigma, size=image.shape)
    return np.clip(perturbed, 0, peak).astype(image.dtype)
################ # SEMANTIC SEG # ################
def calculate_volume_prob_map(
    Y: BiaPyDataset, is_3d: bool = False, w_foreground: float = 0.94, w_background: float = 0.06, save_dir=None
) -> List[NDArray] | NDArray:
    """
    Calculate a per-pixel sampling probability map for each sample of the given data.

    For every channel, objects touching the image border are removed with ``clear_border`` (slice by slice in
    3D). Then every foreground pixel (``> 0``) gets ``w_foreground / n_foreground`` and every background pixel
    (``== 0``) gets ``w_background / n_background``. Finally the channel is normalized to sum to 1, or made
    uniform if it sums to 0.

    Parameters
    ----------
    Y : BiaPyDataset
        Dataset whose ``sample_list`` entries provide the data. A sample's ``img`` is used if it is already
        loaded; otherwise the image at ``Y.dataset_info[sample.fid].path`` is read. Data is expected as
        ``(y, x, channels)`` in 2D and ``(z, y, x, channels)`` in 3D.

    is_3d : bool, optional
        Whether the data is 3D.

    w_foreground : float, optional
        Weight of the foreground. This value plus ``w_background`` should equal ``1``.

    w_background : float, optional
        Weight of the background. This value plus ``w_foreground`` should equal ``1``.

    save_dir : str, optional
        Directory where the probability map(s) will be stored as ``.npy`` files.

    Returns
    -------
    maps : NDArray or list of NDArray
        A single stacked array if all samples share a shape, otherwise a list with one map per sample.

    Raises
    ------
    ValueError
        If ``Y`` contains no samples.
    """
    print("Constructing the probability map . . .")
    Ylen = len(Y.sample_list)
    if Ylen == 0:
        raise ValueError("Cannot construct a probability map from an empty dataset")

    maps = []
    diff_shape = False
    first_shape = None
    for i in tqdm(range(Ylen), disable=not is_main_process()):
        sample = Y.sample_list[i]
        if sample.img_is_loaded():
            _map = sample.img.copy().astype(np.float32)
        else:
            _map = read_img_as_ndarray(Y.dataset_info[sample.fid].path, is_3d=is_3d).astype(np.float32)

        for k in range(_map.shape[-1]):
            # Remove artifacts connected to the image border.
            if is_3d:
                for j in range(_map.shape[0]):
                    _map[j, ..., k] = clear_border(_map[j, ..., k])
            else:
                _map[..., k] = clear_border(_map[..., k])

            channel = _map[..., k]  # view: in-place writes update _map
            # Compute both masks before writing any weight, so foreground weights (e.g. 0 when
            # w_foreground == 0) are never mistaken for background.
            fg = channel > 0
            bg = channel == 0
            n_fg = int(fg.sum())
            n_bg = int(bg.sum())
            channel[fg] = w_foreground / n_fg if n_fg > 0 else 0
            channel[bg] = w_background / n_bg if n_bg > 0 else 0

            # Make all probabilities sum to 1 (uniform if everything is zero).
            total = channel.sum()
            if total == 0:
                channel.fill(1 / channel.size)
            else:
                channel /= total

        if first_shape is None:
            first_shape = _map.shape
        if first_shape != _map.shape:
            diff_shape = True
        maps.append(_map)

    if not diff_shape:
        maps = np.concatenate([np.expand_dims(m, 0) for m in maps])

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        if not diff_shape:
            print("Saving the probability map in {}".format(save_dir))
            np.save(os.path.join(save_dir, "prob_map.npy"), maps)
            return maps
        else:
            print(
                "As the files loaded have different shapes, the probability map for each one will be stored"
                " separately in {}".format(save_dir)
            )
            d = len(str(Ylen))
            for i in range(Ylen):
                f = os.path.join(save_dir, "prob_map" + str(i).zfill(d) + ".npy")
                np.save(f, maps[i])
    return maps
########### # GENERAL # ###########
def resize_images(images: List[NDArray], **kwargs) -> List[NDArray]:
    """
    Resize all the images using the given ``skimage.transform.resize`` parameters.

    Parameters
    ----------
    images: list of Numpy arrays
        The input images to resize.

    output_shape: iterable
        Size of the generated output image. E.g. `(256,256)`

    (kwargs): optional
        `skimage.transform.resize() <https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize>`__
        parameters are also allowed.

    Returns
    -------
    resized_images: list of Numpy arrays
        The resized images, with the same dtype as each input image. Integer images keep their intensity range
        even when ``preserve_range`` is ``False``.
    """

    def _resize_one(img: NDArray) -> NDArray:
        out = resize(img, **kwargs)
        if not kwargs.get("preserve_range", False) and np.issubdtype(img.dtype, np.integer):
            # Without preserve_range, skimage divides integer input by its dtype max (values in [0, 1], or
            # [-1, 1] for signed types). Casting that straight back to int gave all-zero images, so map it back.
            info = np.iinfo(img.dtype)
            out = np.clip(np.rint(out * info.max), info.min, info.max)
        return out.astype(img.dtype)

    return [_resize_one(img) for img in images]
def apply_gaussian_blur(images: List[NDArray], **kwargs) -> List[NDArray]:
    """
    Apply a Gaussian blur to all images.

    Parameters
    ----------
    images: list of Numpy arrays
        The input images to blur.

    (kwargs): optional
        `skimage.filters.gaussian() <https://scikit-image.org/docs/stable/api/skimage.filters.html#skimage.filters.gaussian>`__
        parameters are also allowed.

    Returns
    -------
    blurred_images: list of Numpy arrays
        The blurred images, with the same dtype and intensity range as each input image.
    """

    def _blur_one(image: NDArray) -> NDArray:
        out = gaussian(image, **kwargs)
        if np.issubdtype(image.dtype, np.integer):
            info = np.iinfo(image.dtype)
            # gaussian() rescales integer input to float unless preserve_range=True; undo that only when it
            # happened (scaling a preserved range again would overflow).
            if not kwargs.get("preserve_range", False):
                out = out * info.max
            # Round instead of truncating (200/255*255 -> 199.999... would otherwise become 199).
            out = np.clip(np.rint(out), info.min, info.max)
        return out.astype(image.dtype)

    return [_blur_one(img) for img in images]
def apply_median_blur(images: List[NDArray], **kwargs) -> List[NDArray]:
    """
    Apply a median filter to every image.

    Parameters
    ----------
    images: list of Numpy arrays
        The input images to filter.

    (kwargs): optional
        `skimage.filters.median() <https://scikit-image.org/docs/stable/api/skimage.filters.html#skimage.filters.median>`__
        parameters are also allowed.

    Returns
    -------
    blurred_images: list of Numpy arrays
        The median-filtered images, each cast to the dtype of its input.
    """
    filtered = []
    for image in images:
        filtered.append(median(image, **kwargs).astype(image.dtype))
    return filtered
def detect_edges(images: List[NDArray], **kwargs) -> List[NDArray]:
    """
    Detect edges in 2D images with the Canny algorithm.

    Each image is first reduced to grayscale (RGB via ``rgb2gray``, single-channel by dropping the channel
    axis) and then passed to ``canny``.

    Parameters
    ----------
    images: list of Numpy arrays
        2D input images of shape ``(height, width, 3)`` (RGB) or ``(height, width, 1)`` (grayscale).

    (kwargs): optional
        `skimage.feature.canny() <https://scikit-image.org/docs/stable/api/skimage.feature.html#skimage.feature.canny>`__
        parameters are also allowed.

    Returns
    -------
    edges: list of Numpy arrays
        One ``(height, width, 1)`` uint8 array per image: background 0, edges 255.

    Raises
    ------
    ValueError
        If an image has a number of channels other than 1 or 3.
    """

    def _grayscale(image: NDArray) -> NDArray:
        n_channels = image.shape[-1]
        if n_channels == 3:
            return rgb2gray(image)
        if n_channels == 1:
            return image[..., 0]
        raise ValueError(
            f"Detect edges function does not allow given ammount of channels ({n_channels} channels). "
            "Only accepts grayscale and RGB 2D images (1 or 3 channels)."
        )

    def _as_uint8_edges(edge_mask: NDArray) -> NDArray:
        # bool -> {0, 1} uint8, add a channel axis, then scale to {0, 255}.
        return edge_mask.astype(np.uint8)[..., np.newaxis] * 255

    return [_as_uint8_edges(canny(_grayscale(img), **kwargs)) for img in images]
def _histogram_matching(source_imgs: List[NDArray], target_imgs: List[NDArray]) -> List[NDArray]: """ Apply histogram matching to a set of source images based on the mean histogram of target images. Given a set of target images, it will obtain their mean histogram and applies histogram matching to all images from source images. Parameters ---------- source_imgs: list of Numpy arrays The images of the source domain, to which the histogram matching is to be applied. target_imgs: list of Numpy array The target domain images, from which mean histogram will be obtained. Returns ------- matched_images : list of Numpy arrays A set of source images with target's histogram """ # Concatenate all target images to compute the reference histogram target_concat = np.concatenate([img.ravel() for img in target_imgs]) # Get the data type from the first image (assuming all have same dtype) dtype = target_imgs[0].dtype hist_mean, _ = np.histogram(target_concat, bins=np.arange(np.iinfo(dtype).max + 2)) # +2 because bins are edges # calculate normalized quantiles tmpl_size = np.sum(hist_mean) tmpl_quantiles = np.cumsum(hist_mean) / tmpl_size del target_imgs # based on scikit implementation. # source: https://github.com/scikit-image/scikit-image/blob/v0.18.0/skimage/exposure/histogram_matching.py#L22-L70 def _match_cumulative_cdf(source, tmpl_quantiles): src_values, src_unique_indices, src_counts = np.unique(source.ravel(), return_inverse=True, return_counts=True) # calculate normalized quantiles src_size = source.size # number of pixels src_quantiles = np.cumsum(src_counts) / src_size # normalize interp_a_values = np.interp(src_quantiles, tmpl_quantiles, np.arange(len(tmpl_quantiles))) return interp_a_values[src_unique_indices].reshape(source.shape) # apply histogram matching results = [_match_cumulative_cdf(image, tmpl_quantiles).astype(image.dtype) for image in source_imgs] return results
def apply_histogram_matching(images: List[NDArray], reference_path: str, is_2d: bool = True) -> List[NDArray]:
    """
    Match the histogram of each image to the histogram of a set of reference images.

    The reference images are loaded from ``reference_path``. Their intensities, pooled together, act
    as the target distribution to which every image in ``images`` is matched.

    Parameters
    ----------
    images: list of Numpy arrays
        The images whose histogram needs to be matched to the reference histogram.

    reference_path: str
        Directory containing the reference images. The reference histogram, i.e. the desired
        distribution of pixel intensities in the output images, is extracted from them.

    is_2d: bool, optional
        Whether the data in ``reference_path`` is 2D (``is_2d = True``) or 3D (``is_2d = False``).
        Defaults to True.

    Returns
    -------
    matched_images : list of Numpy arrays
        The input images with their histograms matched to the reference histogram. Each returned
        image uses the same data type as the corresponding input image.
    """
    references = load_data_from_dir(reference_path, is_3d=not is_2d)
    matched_images = _histogram_matching(images, references)
    return matched_images
def apply_clahe(images: List[NDArray], **kwards) -> List[NDArray]:
    """
    Apply Contrast Limited Adaptive Histogram Equalization (CLAHE) to a list of images.

    Each image is processed independently with ``skimage.exposure.equalize_adapthist``. That function
    returns floating-point values in the [0, 1] range. For integer images this result is rescaled to
    the full range of the image's integer type and rounded to the nearest level. Every result is then
    cast back to the input data type.

    Parameters
    ----------
    images: list of Numpy arrays
        The images to which the CLAHE algorithm will be applied.

    (kwards): optional
        Extra keyword arguments forwarded unchanged to
        `skimage.exposure.equalize_adapthist() <https://scikit-image.org/docs/stable/api/skimage.exposure.html#skimage.exposure.equalize_adapthist>`__
        (e.g. ``kernel_size``, ``clip_limit``).

    Returns
    -------
    processed_images: list of Numpy arrays
        The images after applying CLAHE. Each returned image uses the same data type as its input.
    """

    def _process(image, **kwards):
        im = equalize_adapthist(image, **kwards)  # returns 0-1 range
        if np.issubdtype(image.dtype, np.integer):
            # Round to the nearest integer level: a plain cast would truncate (e.g. 0.5 * 255 -> 127)
            # and bias every value downwards.
            im = np.rint(im * np.iinfo(image.dtype).max)
        return im.astype(image.dtype)

    processed_images = [_process(img, **kwards) for img in images]
    return processed_images
def preprocess_data(
    cfg: CN,
    x_data: Optional[List[NDArray]] = None,
    y_data: Optional[List[NDArray]] = None,
    is_2d: bool = True,
    is_y_mask: bool = False,
) -> List[NDArray] | Tuple[List[NDArray], List[NDArray]]:
    """
    Pre-process data by applying the image processing techniques enabled in ``cfg``.

    Target data (``y_data``) is only resized, when ``cfg.RESIZE.ENABLE`` is set. Input data
    (``x_data``) goes through, in this order and each only when enabled in ``cfg``:

    1. resizing;
    2. Gaussian blur;
    3. median blur;
    4. histogram matching;
    5. CLAHE;
    6. Canny edge detection.

    Parameters
    ----------
    cfg: yacs CfgNode
        Configuration object. The ``RESIZE``, ``GAUSSIAN_BLUR``, ``MEDIAN_BLUR``, ``MATCH_HISTOGRAM``,
        ``CLAHE`` and ``CANNY`` sections control which steps run and with which settings.

    x_data: list of Numpy arrays, optional
        The input images to pre-process. Each item in the list corresponds to a different image,
        e.g. ``(y, x, channels)`` in 2D or ``(z, y, x, channels)`` in 3D. Defaults to no images.

    y_data: list of Numpy arrays, optional
        The target data corresponding to ``x_data``, with one item per image. Defaults to no data.

    is_2d: bool, optional
        Whether the reference data for histogram matching (``cfg.MATCH_HISTOGRAM.REFERENCE_PATH``) is
        2D or not. Defaults to True.

    is_y_mask: bool, optional
        Whether ``y_data`` is a mask. If True, ``y_data`` is resized with nearest neighbor
        interpolation (``order=0``) and without anti-aliasing, so label values are not blended.
        Otherwise it uses ``cfg.RESIZE.ORDER`` and ``cfg.RESIZE.ANTI_ALIASING``. Defaults to False.

    Returns
    -------
    x_data, y_data: tuple of lists of Numpy arrays
        Returned when both ``x_data`` and ``y_data`` are non-empty.

    y_data: list of Numpy arrays
        Returned when only ``y_data`` is non-empty.

    x_data: list of Numpy arrays
        Returned otherwise (possibly an empty list).
    """
    # ``None`` instead of a mutable ``[]`` default so no list object is shared between calls.
    x_data = [] if x_data is None else x_data
    y_data = [] if y_data is None else y_data

    def _resize(data, order, anti_aliasing):
        # Resize with the settings shared by images and targets; only order/anti-aliasing differ.
        return resize_images(
            data,
            output_shape=cfg.RESIZE.OUTPUT_SHAPE,
            order=order,
            mode=cfg.RESIZE.MODE,
            cval=cfg.RESIZE.CVAL,
            clip=cfg.RESIZE.CLIP,
            preserve_range=cfg.RESIZE.PRESERVE_RANGE,
            anti_aliasing=anti_aliasing,
        )

    if len(y_data) > 0 and cfg.RESIZE.ENABLE:
        if is_y_mask:
            # Masks: nearest neighbor and no Gaussian anti-aliasing, which would otherwise blend
            # neighboring label values before sampling.
            y_data = _resize(y_data, order=0, anti_aliasing=False)
        else:
            y_data = _resize(y_data, order=cfg.RESIZE.ORDER, anti_aliasing=cfg.RESIZE.ANTI_ALIASING)

    if len(x_data) > 0:
        if cfg.RESIZE.ENABLE:
            x_data = _resize(x_data, order=cfg.RESIZE.ORDER, anti_aliasing=cfg.RESIZE.ANTI_ALIASING)
        if cfg.GAUSSIAN_BLUR.ENABLE:
            x_data = apply_gaussian_blur(
                x_data,
                sigma=cfg.GAUSSIAN_BLUR.SIGMA,
                mode=cfg.GAUSSIAN_BLUR.MODE,
                channel_axis=cfg.GAUSSIAN_BLUR.CHANNEL_AXIS,
            )
        if cfg.MEDIAN_BLUR.ENABLE:
            x_data = apply_median_blur(
                x_data,
                footprint=np.ones(cfg.MEDIAN_BLUR.KERNEL_SIZE, dtype=np.uint8).tolist(),
            )
        if cfg.MATCH_HISTOGRAM.ENABLE:
            x_data = apply_histogram_matching(
                x_data,
                reference_path=cfg.MATCH_HISTOGRAM.REFERENCE_PATH,
                is_2d=is_2d,
            )
        if cfg.CLAHE.ENABLE:
            x_data = apply_clahe(
                x_data,
                kernel_size=cfg.CLAHE.KERNEL_SIZE,
                clip_limit=cfg.CLAHE.CLIP_LIMIT,
            )
        if cfg.CANNY.ENABLE:
            x_data = detect_edges(
                x_data,
                low_threshold=cfg.CANNY.LOW_THRESHOLD,
                high_threshold=cfg.CANNY.HIGH_THRESHOLD,
            )

    if len(x_data) > 0 and len(y_data) > 0:
        return x_data, y_data
    if len(y_data) > 0:
        return y_data
    return x_data