"""
Instance segmentation workflow for BiaPy.
This module defines the Instance_Segmentation_Workflow class, which implements the
training, validation, and inference pipeline for instance segmentation tasks in BiaPy.
It handles data preparation, model setup, metrics, predictions, post-processing,
and result saving for assigning unique IDs to each object in 2D and 3D images.
"""
import os
import math
import torch
import h5py
import numpy as np
import pandas as pd
from tqdm import tqdm
from skimage.transform import resize
from skimage.morphology import ball, dilation
import torch.distributed as dist
from typing import Dict, Optional, List, Tuple
from numpy.typing import NDArray
from scipy.spatial import distance_matrix
from skimage.filters import threshold_otsu
from biapy.data.post_processing.post_processing import (
watershed_by_channels,
voronoi_on_mask,
measure_morphological_props_and_filter,
repare_large_blobs,
apply_binary_mask,
create_synapses_from_point_probs,
extract_points_in_predictions,
remove_close_points,
remove_close_points_by_mask,
Embedding_cluster,
apply_label_refinement,
extract_synapse_connectivity,
collect_point_type_csv_files,
extract_synful_synapses,
connect_pre_post_synapse_points_by_distance,
)
from biapy.data.post_processing.polygon_nms import stardist_instances_from_prediction
from biapy.data.post_processing.gradient_tracking import flows_to_instances
from biapy.data.pre_processing import create_instance_channels
from biapy.utils.matching import matching, wrapper_matching_dataset_lazy
from biapy.engine.metrics import (
jaccard_index,
instance_segmentation_loss,
multiple_metrics,
detection_metrics,
ContrastCELoss,
SpatialEmbLoss,
)
import zarr
from biapy.engine.base_workflow import Base_Workflow
from biapy.utils.misc import (
is_main_process,
is_dist_avail_and_initialized,
to_pytorch_format,
to_numpy_format,
MetricLogger,
os_walk_clean,
get_rank,
get_world_size,
)
from biapy.data.data_manipulation import read_img_as_ndarray, save_tif
from biapy.data.data_3D_manipulation import (
read_chunked_data,
read_chunked_nested_data,
ensure_3d_shape,
load_synapse_gt_points,
extract_patch_from_efficient_file,
insert_patch_in_efficient_file,
order_dimensions,
)
from biapy.data.dataset import PatchCoords
[docs]
class Instance_Segmentation_Workflow(Base_Workflow):
"""
Instance segmentation workflow where the goal is to assign an unique id, i.e. integer, to each object of the input image.
More details in `our documentation <https://biapy.readthedocs.io/en/latest/workflows/instance_segmentation.html>`_.
Parameters
----------
cfg : YACS configuration
Running configuration.
job_identifier : str
Complete name of the running job.
device : Torch device
Device used.
args : argparse.Namespace
Arguments used in BiaPy's call.
"""
def __init__(self, cfg, job_identifier, device, system_dict, args, **kwargs):
    """
    Set up the instance segmentation workflow.

    After the common initialisation performed by ``Base_Workflow``, this constructor stores the
    two paths returned by ``prepare_instance_data()``, creates the containers in which matching
    (and, for multi-head models, classification) statistics are accumulated, freezes the
    configuration and decides which workflow-specific post-processing steps are active.

    Parameters
    ----------
    cfg : YACS configuration
        Running configuration.

    job_identifier : str
        Complete name of the running job.

    device : torch.device
        Device used.

    system_dict : dict (presumably)
        Forwarded untouched to ``Base_Workflow.__init__``. Its content is not used here;
        confirm its exact type against the base class.

    args : argparse.Namespace
        Arguments used in BiaPy's call.

    **kwargs : dict
        Additional keyword arguments, forwarded to ``Base_Workflow.__init__``.

    Raises
    ------
    ValueError
        In synapse mode, when ``PROBLEM.INSTANCE_SEG.DATA_CHANNELS`` does not correspond to any
        known synapse prediction method.
    """
    super().__init__(cfg, job_identifier, device, system_dict, args, **kwargs)

    self.original_train_input_mask_axes_order = self.cfg.DATA.TRAIN.INPUT_MASK_AXES_ORDER
    self.original_test_path, self.original_test_mask_path = self.prepare_instance_data()

    # Multi-head models (instances + classification) keep class statistics
    # next to the instance matching ones.
    multi_head = self.separated_class_channel

    # Merging the image
    self.all_matching_stats_merge_patches = []
    self.all_matching_stats_merge_patches_post = []
    for stats_key in ("inst_stats_merge_patches", "inst_stats_merge_patches_post"):
        self.stats[stats_key] = None
    if multi_head:
        self.all_class_stats_merge_patches = []
        self.all_class_stats_merge_patches_post = []
        for stats_key in ("class_stats_merge_patches", "class_stats_merge_patches_post"):
            self.stats[stats_key] = None

    # As 3D stack
    self.all_matching_stats_as_3D_stack = []
    self.all_matching_stats_as_3D_stack_post = []
    for stats_key in ("inst_stats_as_3D_stack", "inst_stats_as_3D_stack_post"):
        self.stats[stats_key] = None
    if multi_head:
        self.all_class_stats_as_3D_stack = []
        self.all_class_stats_as_3D_stack_post = []
        for stats_key in ("class_stats_as_3D_stack", "class_stats_as_3D_stack_post"):
            self.stats[stats_key] = None

    # Full image
    self.all_matching_stats = []
    self.all_matching_stats_post = []
    for stats_key in ("inst_stats", "inst_stats_post"):
        self.stats[stats_key] = None
    if multi_head:
        self.all_class_stats = []
        self.all_class_stats_post = []
        for stats_key in ("class_stats", "class_stats_post"):
            self.stats[stats_key] = None

    # From now on, no modification of the cfg will be allowed
    self.cfg.freeze()

    # Workflow specific training variables
    self.mask_path = cfg.DATA.TRAIN.GT_PATH
    self.is_y_mask = True
    self.load_Y_val = True

    if self.cfg.TEST.ENABLE and self.cfg.DATA.TEST.LOAD_GT:
        # Third element of the first walk entry; when it is empty the second element
        # of a fresh walk is used instead (files vs. directories if ``os_walk_clean``
        # mirrors ``os.walk`` -- confirm against the helper).
        self.test_gt_filenames = next(os_walk_clean(self.original_test_mask_path))[2]
        if len(self.test_gt_filenames) == 0:
            self.test_gt_filenames = next(os_walk_clean(self.original_test_mask_path))[1]

    # Specific instance segmentation post-processing
    seg_cfg = self.cfg.PROBLEM.INSTANCE_SEG
    self.post_processing["instance_post"] = False
    if seg_cfg.TYPE == "regular":
        post_cfg = self.cfg.TEST.POST_PROCESSING
        if (
            post_cfg.VORONOI_ON_MASK
            or post_cfg.INSTANCE_REFINEMENT.ENABLE
            or post_cfg.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.ENABLE
            or post_cfg.REPARE_LARGE_BLOBS_SIZE != -1
        ):
            self.post_processing["instance_post"] = True
    else:  # synapses
        if (
            seg_cfg.SYNAPSES.REMOVE_CLOSE_PRE_POINTS_RADIUS > 0
            or seg_cfg.SYNAPSES.REMOVE_CLOSE_POST_POINTS_RADIUS > 0
        ):
            # "instance_post" is tied to the matching metrics computed afterwards, which belong
            # to the regular instance segmentation workflow. Synapse detection has its own
            # metrics, so the generic "per_image" flag is raised instead.
            self.post_processing["per_image"] = True

        # The synapse prediction method is identified by the exact set of data channels.
        known_layouts = (
            (["F_pre", "F_post"], "simpsyn"),
            (["F_post", "Z", "V", "H"], "synful"),
            (["F_cleft"], "cleft"),
            (["F_post"], "F_post_only"),
        )
        self.synapse_method = ""
        for required, method_name in known_layouts:
            if all(ch in seg_cfg.DATA_CHANNELS for ch in required) and len(seg_cfg.DATA_CHANNELS) == len(required):
                self.synapse_method = method_name
                break
        else:
            raise ValueError("Unknown synapse prediction method for the given channels. Please check the documentation for more details.")

    self.instances_already_created = False
[docs]
def define_activations_and_channels(self):
    """
    Define the activations to be applied to the model output and the channels that the model will output.

    This function must define the following variables:

    self.model_output_channels : List of int
        Number of channels for each output head of the model. E.g. [3] for a model with one head outputting 3 channels,
        [1, 5] for a model with two heads outputting 1 and 5 channels respectively, etc.

    self.model_output_channel_info : List of str
        Information about the output channels. A value per output head of the model must be defined.

    self.separated_class_channel : bool
        Whether if we should expect a separated output channel for classification.

    self.head_activations : List of str
        Activations to be applied to the model output. A value per output channel (not output head) of the model must be defined.
        "linear" and "ce_sigmoid" will not be applied. E.g. ["linear"] for a model with one channel, ["linear", "sigmoid"] for a
        model with two channels, etc.

    Example of what this function produces with ``DATA_CHANNELS = ["F", "C"]``, no per-head channel information and
    ``N_CLASSES = 3`` (two output heads: the first one with the two data channels, the second one classifying the
    predicted objects into 3 classes)::

        self.model_output_channels = [2, 3]
        self.model_output_channel_info = ["F+C", "class"]
        self.separated_class_channel = True
        self.head_activations = ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"]

    Besides those, ``self.gt_channels_expected`` and ``self.stardist_grid`` are set here.
    ``self.head_activations`` is extended in place, so it must exist before this method is called.

    Raises
    ------
    ValueError
        If ``PROBLEM.INSTANCE_SEG.DATA_CHANNELS`` contains an unknown channel.
    """
    inst_cfg = self.cfg.PROBLEM.INSTANCE_SEG
    data_channels = inst_cfg.DATA_CHANNELS

    if inst_cfg.CHANNELS_PER_HEAD_INFO != []:
        # Heads fixed by the configuration: each head takes the next ``head_channels``
        # entries of DATA_CHANNELS and is described by their names joined with "+".
        set_model_output_channels = False
        self.model_output_channels = []
        # Reset together with ``model_output_channels`` so the description never
        # depends on (or accumulates over) a previous value of the attribute.
        self.model_output_channel_info = []
        count = 0
        for head_channels in inst_cfg.CHANNELS_PER_HEAD_INFO:
            self.model_output_channels.append(head_channels)
            self.model_output_channel_info.append("+".join(data_channels[count:count + head_channels]))
            count += head_channels
    else:
        # Single head whose size and description are accumulated channel by channel below
        self.model_output_channels = [0]
        self.model_output_channel_info = [""]
        set_model_output_channels = True

    dst = inst_cfg.DATA_CHANNELS_EXTRA_OPTS[0]

    def add_output(activation, label):
        # Register one model output channel. The activation is always recorded; the
        # single-head bookkeeping is only touched when the heads were not fixed above.
        self.head_activations.append(activation)
        if set_model_output_channels:
            self.model_output_channels[0] += 1
            self.model_output_channel_info[0] += "+" + label

    def regression_activation(position, channel):
        # "linear" only when the channel is trained with a regression loss and no
        # sigmoid activation was explicitly requested; "ce_sigmoid" otherwise.
        if (
            inst_cfg.DATA_CHANNELS_LOSSES[position] not in ["mse", "l1", "mae"]
            or dst.get(channel, {}).get("act", "") == "sigmoid"
        ):
            return "ce_sigmoid"
        return "linear"

    for ch_pos, channel in enumerate(data_channels):
        if channel in ["B", "F", "P", "C", "T", "M", "F_pre", "F_post", "F_cleft"]:
            add_output("ce_sigmoid", channel)
        elif channel in ["Dc", "Dn", "D", "Z", "V", "H"]:
            # "D" is handled here with the rest of these channels; a later, separate
            # ``channel == "D"`` branch could never be reached and has been dropped.
            add_output(regression_activation(ch_pos, channel), channel)
        elif channel in ["Gv", "Gh", "Gz"]:
            # Cellpose flow targets are unit vectors in [-1, 1]; tanh constrains
            # predictions to the same range and stabilises MSE training.
            add_output("tanh", channel)
        elif channel == "Db":
            if dst.get(channel, {}).get("val_type", "norm") == "discretize":
                for bin_idx in range(11):  # Default 10 bins + background
                    add_output("ce_softmax", channel + "_bin{}".format(bin_idx))
            else:
                add_output(regression_activation(ch_pos, channel), channel)
        elif channel == "A":
            # One output per configured affinity offset, z first, then y, then x
            for axis in ("z", "y", "x"):
                for offset in dst.get("A", {}).get(axis + "_affinities", [1]):
                    add_output("ce_sigmoid", channel + axis + "_{}".format(offset))
        elif channel == "R":
            for ray in range(dst.get("R", {}).get("nrays", 32 if self.dims == 2 else 96)):
                add_output("linear", channel + "_{}".format(ray))
        elif channel in ["E_offset", "E_sigma"]:
            # One output per spatial dimension
            for dim in range(self.dims):
                add_output("ce_sigmoid", channel + "_{}".format(dim))
        elif channel == "E_seediness":
            add_output("ce_sigmoid", channel)
        elif channel == "We":
            # This channel adds no model output
            continue
        else:
            raise ValueError("Unknown channel: {}".format(channel))

    # The single-head description was built as "+a+b+...": drop the leading "+"
    for head_idx, info in enumerate(self.model_output_channel_info):
        self.model_output_channel_info[head_idx] = info.lstrip("+")

    # Multi-head: instances + classification
    self.gt_channels_expected = len(self.head_activations)
    if self.cfg.DATA.N_CLASSES > 2:
        self.head_activations += ["ce_softmax"] * self.cfg.DATA.N_CLASSES
        self.model_output_channels += [self.cfg.DATA.N_CLASSES,]
        self.model_output_channel_info += ["class"]
        # The model outputs N_CLASSES channels but the ground truth carries a single class channel
        self.gt_channels_expected += 1
        self.separated_class_channel = True
    else:
        self.separated_class_channel = False

    super().define_activations_and_channels()

    self.stardist_grid = (1, 1)
[docs]
def define_metrics(self):
    """
    Define the metrics to be used in the instance segmentation workflow.

    This function must define the following variables:

    self.train_metrics : List of functions
        Metrics to be calculated during model's training.

    self.train_metric_names : List of str
        Names of the metrics calculated during training.

    self.train_metric_best : List of str
        To know which value should be considered as the best one. Options must be: "max" or "min".

    self.test_metrics : List of functions
        Metrics to be calculated during model's test/inference.

    self.test_metric_names : List of str
        Names of the metrics calculated during test/inference.

    self.loss : Function
        Loss function used during training and test.

    Additionally, ``self.jaccard_index_matching`` is created for multi-head models,
    ``self.test_extra_metrics`` for the synapse workflow and ``self.embedding_cluster``
    when the ``"E_offset"`` channel is used.

    Raises
    ------
    ValueError
        If ``PROBLEM.INSTANCE_SEG.DATA_CHANNELS`` contains an unknown channel.
    """
    inst_cfg = self.cfg.PROBLEM.INSTANCE_SEG
    data_channels = inst_cfg.DATA_CHANNELS
    extra_opts = inst_cfg.DATA_CHANNELS_EXTRA_OPTS[0]
    uses_embeddings = "E_offset" in data_channels

    self.train_metrics = []
    self.train_metric_names = []
    self.train_metric_best = []
    for channel in data_channels:
        if channel in ["B", "F", "P", "C", "T", "A", "M"]:
            m = "IoU ({} channel)".format(channel) if channel != "A" else "IoU ({} channels)".format(channel)
            self.train_metric_names += [m]
            self.train_metric_best += ["max"]
        elif channel in ["Db", "Dc", "Dn", "D", "Z", "V", "H", "R", "Gv", "Gh", "Gz"]:
            # Plural form for "R", mirroring what is done for "A" above. The name was
            # already being computed here but a hard-coded singular one was stored.
            m = "L1 ({} channel)".format(channel) if channel != "R" else "L1 ({} channels)".format(channel)
            self.train_metric_names += [m]
            self.train_metric_best += ["min"]
        elif channel == "E_offset":
            self.train_metric_names += ["IoU"]
            self.train_metric_best += ["max"]
        elif channel in ["E_sigma", "E_seediness"]:
            continue  # No metrics for these channels
        elif channel == "We":
            continue
        # Extra channels for the synapse detection branch
        elif channel == "F_pre":
            self.train_metric_names += ["IoU (pre-sites)"]
            self.train_metric_best += ["max"]
        elif channel == "F_post":
            self.train_metric_names += ["IoU (post-sites)"]
            self.train_metric_best += ["max"]
        elif channel == "F_cleft":
            self.train_metric_names += ["IoU (clefts)"]
            self.train_metric_best += ["max"]
        else:
            raise ValueError("Unknown channel: {}".format(channel))

    # Multi-head: instances + classification
    if inst_cfg.TYPE == "regular" and self.separated_class_channel:
        self.train_metric_names += ["IoU (classes)"]
        self.train_metric_best += ["max"]
        # Used to calculate IoU with the classification results
        self.jaccard_index_matching = jaccard_index(
            device=self.device,
            num_classes=self.cfg.DATA.N_CLASSES,
            ndim=self.dims,
            ignore_index=self.cfg.LOSS.IGNORE_INDEX,
        )

    if uses_embeddings:
        # No metric for the embedding representation during training as the IoU is calculated together with the loss
        self.train_metrics.append("none")
    else:
        self.train_metrics.append(
            multiple_metrics(
                num_classes=self.cfg.DATA.N_CLASSES,
                metric_names=self.train_metric_names,
                device=self.device,
                out_channels=data_channels,
                channel_extra_opts=extra_opts,
                model_source=self.cfg.MODEL.SOURCE,
                ignore_index=self.cfg.LOSS.IGNORE_INDEX,
                ndim=self.dims,
            )
        )

    self.test_metrics = []
    self.test_metric_names = self.train_metric_names.copy()
    # Multi-head: instances + classification
    if self.separated_class_channel:
        # In the regular workflow the name was already inherited from the training
        # names; only add it when missing so it is never listed twice.
        if "IoU (classes)" not in self.test_metric_names:
            self.test_metric_names.append("IoU (classes)")
        # Used to calculate IoU with the classification results (now on the test device)
        self.jaccard_index_matching = jaccard_index(
            device=self.test_device,
            num_classes=self.cfg.DATA.N_CLASSES,
            ndim=self.dims,
            ignore_index=self.cfg.LOSS.IGNORE_INDEX,
        )

    if uses_embeddings:
        # No metric for the embedding representation during test either
        self.test_metrics.append("none")
    else:
        self.test_metrics.append(
            multiple_metrics(
                num_classes=self.cfg.DATA.N_CLASSES,
                metric_names=self.test_metric_names,
                device=self.test_device,
                out_channels=data_channels,
                channel_extra_opts=extra_opts,
                model_source=self.cfg.MODEL.SOURCE,
                ndim=self.dims,
                ignore_index=self.cfg.LOSS.IGNORE_INDEX,
            )
        )

    if inst_cfg.TYPE == "synapses":
        # Point-detection metrics, grouped per point type
        self.test_extra_metrics = [
            f"{stat} ({x}-points)"
            for x in ["pre", "post", "cleft"]
            for stat in ["Precision", "Recall", "F1", "TP", "FP", "FN"]
        ]
        # In-place extension: the list object handed to ``multiple_metrics`` above is the same one
        self.test_metric_names += self.test_extra_metrics

    if uses_embeddings:
        instance_loss = SpatialEmbLoss(
            patch_size=self.cfg.DATA.PATCH_SIZE,
            ndims=self.dims,
            anisotropy=self.resolution,
            channel_weights=inst_cfg.DATA_CHANNEL_WEIGHTS,
            center_mode=extra_opts.get("E_offset", {}).get("center_mode", "centroid"),
            medoid_max_points=extra_opts.get("E_offset", {}).get("medoid_max_points", 10000),
        ).to(self.device, non_blocking=True)

        self.embedding_cluster = Embedding_cluster(
            device=self.test_device,
            patch_size=self.cfg.DATA.PATCH_SIZE,
            ndims=self.dims,
            anisotropy=self.resolution,
        )
    else:
        instance_loss = instance_segmentation_loss(
            channel_weights=inst_cfg.DATA_CHANNEL_WEIGHTS,
            class_rebalance_within_channels=inst_cfg.CLASS_REBALANCE_WITHIN_CHANNELS,
            separated_class_channel=self.separated_class_channel,
            ndim=self.dims,
            out_channels=data_channels,
            losses_to_use=inst_cfg.DATA_CHANNELS_LOSSES,
            channel_extra_opts=extra_opts,
            gt_channels_expected=self.gt_channels_expected,
            class_rebalance=self.cfg.LOSS.CLASS_REBALANCE,
            class_weights=self.cfg.LOSS.CLASS_WEIGHTS,
            ignore_index=self.cfg.LOSS.IGNORE_INDEX,
            device=self.device,
        )

    if self.cfg.LOSS.CONTRAST.ENABLE:
        # The instance loss becomes the main term of the contrastive loss wrapper
        self.loss = ContrastCELoss(
            main_loss=instance_loss,  # type: ignore
            ndim=self.dims,
            ignore_index=self.cfg.LOSS.IGNORE_INDEX,
        )
    else:
        self.loss = instance_loss

    super().define_metrics()
[docs]
def metric_calculation(
    self,
    output: NDArray | torch.Tensor,
    targets: NDArray | torch.Tensor,
    train: bool = True,
    metric_logger: Optional[MetricLogger] = None,
) -> Dict:
    """
    Evaluate the metrics created in :func:`~define_metrics` for one prediction.

    Parameters
    ----------
    output : NDArray or Torch Tensor
        Prediction of the model.

    targets : NDArray or Torch Tensor
        Ground truth to compare the prediction with.

    train : bool, optional
        Whether to calculate train or test metrics.

    metric_logger : MetricLogger, optional
        Class to be updated with the new metric(s) value(s) calculated.

    Returns
    -------
    out_metrics : dict
        Value of the metrics for the given prediction, keyed by metric name.
    """

    def _prepare(data):
        # NumPy input is copied and converted with ``to_pytorch_format``. Tensors are
        # used as they are while training and cloned otherwise.
        if isinstance(data, np.ndarray):
            return to_pytorch_format(
                data.copy(),
                self.axes_order,
                self.device if train else self.test_device,
                dtype=self.loss_dtype,
            )
        return data if train else data.clone()

    def _as_number(value):
        # Tensor results become Python numbers, with NaN reported as 0
        if not isinstance(value, torch.Tensor):
            return value
        return 0 if torch.isnan(value) else value.item()

    _output = _prepare(output)

    if not train and self.separated_class_channel and self.gt_channels_expected != _output.shape[1]:
        # The class head still holds one channel per class: collapse it into a single
        # channel with the arg-max so the channel count matches the ground truth.
        if "class" in self.model_output_channel_info:
            class_idx = self.model_output_channel_info.index("class")
        else:
            class_idx = -1
        n_class_channels = self.model_output_channels[class_idx]
        _output = torch.cat(
            (
                _output[:, :-n_class_channels],
                torch.argmax(_output[:, -n_class_channels:], dim=1).unsqueeze(1),
            ),
            dim=1,
        )

    _targets = _prepare(targets)

    metrics = self.train_metrics if train else self.test_metrics
    names = self.train_metric_names if train else self.test_metric_names
    out_metrics = {}

    def _store(name, value):
        out_metrics[name] = value
        if metric_logger:
            metric_logger.meters[name].update(value)

    with torch.no_grad():
        # Entries of a dict result are matched to the names by position (``k``);
        # a single-value result takes the name at the metric's own position.
        k = 0
        for position, metric in enumerate(metrics):
            if metric == "none":
                continue
            val = metric(_output, _targets)
            if isinstance(val, dict):
                for entry in val.values():
                    _store(names[k], _as_number(entry))
                    k += 1
            else:
                _store(names[position], _as_number(val))
    return out_metrics
def _effective_halo(self) -> tuple:
    """Return per-axis halo sizes ``(hz, hy, hx)`` for chunk-boundary watershed.

    ``TEST.BY_CHUNKS.WORKFLOW_PROCESS.INSTANCE_SEG_HALO`` set to -1 means
    "automatic": each axis gets ``PATCH_SIZE[axis] // 8``, never less than 1.
    Any other value is converted to ``int`` and used for the three axes alike.

    Deriving the halo per axis avoids over-extending a small Z axis (e.g. patch
    Z=20 → hz=2) while keeping a generous halo in Y/X (e.g. patch 256 → h=32).
    """
    halo_setting = self.cfg.TEST.BY_CHUNKS.WORKFLOW_PROCESS.INSTANCE_SEG_HALO
    if halo_setting == -1:
        patch_size = self.cfg.DATA.PATCH_SIZE  # (Z, Y, X[, C])
        return tuple(max(1, int(patch_size[axis]) // 8) for axis in range(3))
    return (int(halo_setting),) * 3
def _create_instance_labels(self, pred: NDArray, save_dir: Optional[str] = None, verbose: bool = False):
    """
    Turn raw prediction channels into an instance label map (no I/O, no metrics).

    The method is chosen from ``PROBLEM.INSTANCE_SEG.DATA_CHANNELS``, first match wins:
    ``"R"`` -> ``stardist_instances_from_prediction``; ``"E_offset"`` ->
    ``self.embedding_cluster.create_instances``; any of ``"Gv"``/``"Gh"``/``"Gz"`` ->
    ``flows_to_instances``; anything else -> ``watershed_by_channels``.

    Parameters
    ----------
    pred : NDArray
        4-D array ``(Z, Y, X, C)``. When the model has a separated class channel, the
        last channel is discarded before creating the instances.

    save_dir : str, optional
        Directory for watershed debug data. ``None`` disables debug output.

    verbose : bool, optional
        Forwarded to ``watershed_by_channels``; unused by the other methods.

    Returns
    -------
    pred_labels : NDArray
        3-D uint32 array ``(Z, Y, X)`` with unique integer label per instance.
    """
    assert pred.ndim == 4, f"Expected 4D pred, got shape {pred.shape}"

    seg_cfg = self.cfg.PROBLEM.INSTANCE_SEG
    data_channels = seg_cfg.DATA_CHANNELS

    if self.separated_class_channel:
        # Drop the class channel: only the instance channels are needed here
        pred = pred[..., :-1]

    if "R" in data_channels:
        # Everything before the position of "R" is passed as first argument,
        # the remaining channels as second one.
        r_pos = data_channels.index("R")
        pred_labels, _ = stardist_instances_from_prediction(
            pred[..., :r_pos].squeeze(),
            pred[..., r_pos:].squeeze(),
            prob_thresh=seg_cfg.STARDIST.PROB_THRESH,
            nms_iou_thresh=seg_cfg.STARDIST.NMS_IOU_THRESH,
            anisotropy=self.resolution[-self.dims:],
            grid=self.stardist_grid,
        )
    elif "E_offset" in data_channels:
        embed_cfg = seg_cfg.EMBEDSEG
        pred_labels = self.embedding_cluster.create_instances(
            pred=pred if self.dims == 3 else pred[0],
            fg_thresh=embed_cfg.SEED_THRESH,
            min_mask_sum=embed_cfg.MIN_MASK_SUM,
            min_unclustered_sum=embed_cfg.MIN_UNCLUSTERED_SUM,
            min_object_size=embed_cfg.MIN_OBJECT_SIZE,
        )
        if self.dims == 2:
            pred_labels = np.expand_dims(pred_labels, 0)
    elif any(ch in data_channels for ch in ("Gv", "Gh", "Gz")):
        flow_cfg = seg_cfg.CELLPOSE
        channels = list(data_channels)
        # First of "F", "M" or "B" found in the configured channels, "" when none is present
        fg_channel = next((ch for ch in channels if ch in ("F", "M", "B")), "")
        _pred_in = pred if self.dims == 3 else pred[0]
        pred_labels = flows_to_instances(
            pred=_pred_in,
            channels=channels,
            flow_type=flow_cfg.TYPE,
            fg_channel=fg_channel,
            fg_thresh=flow_cfg.FG_THRESH,
            flow_threshold=flow_cfg.FLOW_THRESHOLD,
            n_steps=flow_cfg.N_STEPS,
            dt=flow_cfg.DT,
            suppressed=flow_cfg.SUPPRESSED,
            min_size=flow_cfg.MIN_SIZE,
            max_cluster_dist=flow_cfg.MAX_CLUSTER_DIST,
            resolution=list(self.resolution[-self.dims:]),
        )
        if self.dims == 2:
            pred_labels = np.expand_dims(pred_labels, 0)
    else:
        ws_cfg = seg_cfg.WATERSHED
        pred_labels = watershed_by_channels(
            data=pred,
            channels=data_channels,
            seed_channels=ws_cfg.SEED_CHANNELS,
            seed_channel_ths=ws_cfg.SEED_CHANNELS_THRESH,
            topo_surface_channel=ws_cfg.TOPOGRAPHIC_SURFACE_CHANNEL,
            growth_mask_channels=ws_cfg.GROWTH_MASK_CHANNELS,
            growth_mask_channel_ths=ws_cfg.GROWTH_MASK_CHANNELS_THRESH,
            remove_before=ws_cfg.DATA_REMOVE_BEFORE_MW,
            thres_small_before=ws_cfg.DATA_REMOVE_SMALL_OBJ_BEFORE,
            seed_morph_sequence=ws_cfg.SEED_MORPH_SEQUENCE,
            seed_morph_radius=ws_cfg.SEED_MORPH_RADIUS,
            erode_and_dilate_growth_mask=ws_cfg.ERODE_AND_DILATE_GROWTH_MASK,
            fore_erosion_radius=ws_cfg.FORE_EROSION_RADIUS,
            fore_dilation_radius=ws_cfg.FORE_DILATION_RADIUS,
            resolution=self.resolution,
            save_dir=save_dir,
            watershed_by_2d_slices=ws_cfg.BY_2D_SLICES,
            verbose=verbose,
        )

    # Whatever the method, a 2-D result gets a leading axis so the output is always 3-D
    if pred_labels.ndim == 2:
        pred_labels = np.expand_dims(pred_labels, 0)
    return pred_labels.astype(np.uint32)
@staticmethod
def _compute_global_id_remap(edges: List[Tuple[int, int]]) -> Dict[int, int]:
    """Resolve boundary merge edges into a ``{old_id: canonical_id}`` mapping.

    The edges are merged with a union-find structure and every connected
    component is represented by the smallest ID it contains, which avoids
    collisions with isolated instance IDs that are not part of any edge.
    IDs that already are the canonical one of their component are left out
    of the returned dict, so only the IDs that must change appear in it.
    """
    if not edges:
        return {}

    all_ids: set = {uid for a, b in edges for uid in (a, b)}
    parent: Dict[int, int] = dict(zip(all_ids, all_ids))

    def root_of(node: int) -> int:
        # Locate the representative, then hang every visited node directly under it
        root = node
        while parent[root] != root:
            root = parent[root]
        while parent[node] != root:
            parent[node], node = root, parent[node]
        return root

    for a, b in edges:
        root_a, root_b = root_of(a), root_of(b)
        if root_a != root_b:
            parent[root_b] = root_a

    # Smallest member ID of each component, indexed by the component's root
    smallest: Dict[int, int] = {}
    for uid in all_ids:
        root = root_of(uid)
        if root not in smallest or uid < smallest[root]:
            smallest[root] = uid

    return {
        uid: canonical
        for uid in all_ids
        if (canonical := smallest[root_of(uid)]) != uid
    }
@staticmethod
def _apply_id_remap(patch: NDArray, remap: Dict[int, int]) -> NDArray:
    """Relabel ``patch`` according to ``remap`` and return the result as ``uint64``.

    The dictionary is consulted once per distinct ID found in the patch (IDs
    missing from ``remap`` keep their value); the translated IDs are then
    scattered back with the inverse index from ``np.unique``. The input array
    is not modified and the output has the same shape.
    """
    present_ids, positions = np.unique(patch.ravel().astype(np.uint64), return_inverse=True)
    # One translated value per distinct ID, in the same order as ``present_ids``
    lookup = np.array([remap.get(key, key) for key in map(int, present_ids)], dtype=np.uint64)
    relabelled = lookup[positions]
    return relabelled.reshape(patch.shape)
[docs]
def _report_matching_stats(self, stats_per_th, gt_labels, pred_labels, base_name, colored_img_ths, file_tag=""):
    """
    Save the GT/prediction association tables of a matching run and, optionally, a colored summary.

    For every threshold entry two CSV files are written into ``PATHS.RESULT_DIR.INST_ASSOC_POINTS``
    (``*_gt_assoc.csv`` and ``*_fp.csv``). When the threshold equals the value at the same position of
    ``colored_img_ths`` an RGB image is also saved (TPs green, FNs red, FPs blue).

    Parameters
    ----------
    stats_per_th : sequence of dict
        Per-threshold output of ``matching(..., report_matches=True)``. The bulky ``matched_scores``,
        ``matched_tps``, ``matched_pairs``, ``pred_ids`` and ``gt_ids`` entries are removed in place.

    gt_labels : NDArray
        Ground truth instance labels.

    pred_labels : NDArray
        Predicted instance labels, with the same shape as ``gt_labels``.

    base_name : str
        Filename (without extension) used as prefix of every output file.

    colored_img_ths : list
        Thresholds for which the colored image is requested, ``-1`` meaning "none" for that position.

    file_tag : str, optional
        Text inserted right after ``base_name`` in the output names (e.g. ``"_post-proc"``).
    """
    assoc_dir = self.cfg.PATHS.RESULT_DIR.INST_ASSOC_POINTS
    for i, r_stats in enumerate(stats_per_th):
        # Extract TPs, FPs and FNs from the resulting matching data structure
        thr = r_stats["thresh"]

        # TP and FN. The first ID is skipped (presumably the background entry -- TODO confirm in ``matching``)
        gt_ids = r_stats["gt_ids"][1:]
        matched_pairs = r_stats["matched_pairs"]
        gt_match = [x[0] for x in matched_pairs]
        gt_unmatch = [x for x in gt_ids if x not in gt_match]
        matched_scores = list(r_stats["matched_scores"]) + [0 for _ in gt_unmatch]
        pred_match = [x[1] for x in matched_pairs] + [-1 for _ in gt_unmatch]
        tag = ["TP" if score >= thr else "FN" for score in matched_scores]

        # FPs: predictions never matched plus those matched below the threshold
        # NOTE(review): unmatched GT rows carry the ``-1`` placeholder, which also lands in this list
        pred_ids = r_stats["pred_ids"][1:]
        fp_instances = [x for x in pred_ids if x not in pred_match]
        fp_instances += [pred_id for score, pred_id in zip(matched_scores, pred_match) if score < thr]

        # Save csv files
        df = pd.DataFrame(
            zip(gt_match + gt_unmatch, pred_match, matched_scores, tag),
            columns=["gt_id", "pred_id", "iou", "tag"],
        )
        df = df.sort_values(by=["gt_id"])
        df_fp = pd.DataFrame(zip(fp_instances), columns=["pred_id"])
        os.makedirs(assoc_dir, exist_ok=True)
        df.to_csv(
            os.path.join(assoc_dir, base_name + file_tag + "_th_{}_gt_assoc.csv".format(thr)),
            index=False,
        )
        df_fp.to_csv(
            os.path.join(assoc_dir, base_name + file_tag + "_th_{}_fp.csv".format(thr)),
            index=False,
        )

        # Drop the per-instance entries so only the summary numbers are printed/returned
        del r_stats["matched_scores"]
        del r_stats["matched_tps"]
        del r_stats["matched_pairs"]
        del r_stats["pred_ids"]
        del r_stats["gt_ids"]
        print("DatasetMatching: {}".format(r_stats))

        if colored_img_ths[i] != -1 and colored_img_ths[i] == thr:
            print("Creating the image with a summary of detected points and false positives with colors . . .")
            colored_result = np.zeros(pred_labels.shape + (3,), dtype=np.uint8)

            print("Painting TPs and FNs . . .")
            for j in tqdm(range(len(gt_match)), disable=not is_main_process()):
                color = (0, 255, 0) if tag[j] == "TP" else (255, 0, 0)  # Green or red
                colored_result[np.where(gt_labels == gt_match[j])] = color
            for j in tqdm(range(len(gt_unmatch)), disable=not is_main_process()):
                colored_result[np.where(gt_labels == gt_unmatch[j])] = (255, 0, 0)  # Red

            print("Painting FPs . . .")
            for j in tqdm(range(len(fp_instances)), disable=not is_main_process()):
                colored_result[np.where(pred_labels == fp_instances[j])] = (0, 0, 255)  # Blue

            save_tif(
                np.expand_dims(colored_result, 0),
                assoc_dir,
                [base_name + file_tag + "_th_{}.tif".format(thr)],
                verbose=self.cfg.TEST.VERBOSE,
            )
            del colored_result


def instance_seg_process(self, pred, filenames, out_dir, out_dir_post_proc, calculate_metrics: bool = True):
    """
    Instance segmentation workflow engine for test/inference.

    Process model's prediction to prepare instance segmentation output and calculate metrics.

    Parameters
    ----------
    pred : 4D array
        Model predictions. E.g. ``(z, y, x, channels)`` for both 2D and 3D. When
        ``self.instances_already_created`` is set it is taken as the instance labels directly.

    filenames : List of str
        Predicted image's filenames.

    out_dir : path
        Output directory to save the instances.

    out_dir_post_proc : path
        Output directory to save the post-processed instances.

    calculate_metrics : bool, optional
        Whether to calculate or not the metrics.

    Returns
    -------
    results : sequence of dict or None
        Matching stats per threshold before post-processing (``None`` if not calculated).

    results_post_proc : sequence of dict or None
        Matching stats per threshold after post-processing (``None`` if not calculated).

    results_class : float or None
        Class IoU before post-processing (only with a separated class channel).

    results_class_post_proc : float or None
        Class IoU after post-processing (only with a separated class channel).
    """
    assert pred.ndim == 4, f"Prediction doesn't have 4 dim: {pred.shape}"

    base_name = os.path.splitext(filenames[0])[0]
    is_2d = self.cfg.PROBLEM.NDIM == "2D"

    #############################
    ### INSTANCE SEGMENTATION ###
    #############################
    if not self.instances_already_created:
        # Multi-head: capture class channel before _create_instance_labels strips it
        if self.separated_class_channel:
            class_channel = np.expand_dims(pred[..., -1], -1)

        w_dir = os.path.join(self.cfg.PATHS.WATERSHED_DIR, filenames[0])
        check_wa = w_dir if self.cfg.PROBLEM.INSTANCE_SEG.WATERSHED.DATA_CHECK_MW else None
        pred_labels = self._create_instance_labels(pred, save_dir=check_wa, verbose=self.cfg.TEST.VERBOSE)

        # Multi-head: instances + classification
        if self.separated_class_channel:
            print("Adapting class channel . . .")
            # Filter the background by value: slicing ``[1:]`` would drop a real instance
            # whenever the prediction contains no background pixel
            labels = np.unique(pred_labels)
            labels = labels[labels != 0]
            new_class_channel = np.zeros(pred_labels.shape, dtype=pred_labels.dtype)
            # Classify each instance counting the most prominent class of all the pixels that compose it
            for l in labels:
                instance_mask = pred_labels == l
                instance_classes, instance_classes_count = np.unique(
                    class_channel[instance_mask], return_counts=True
                )

                # Remove background
                if instance_classes[0] == 0:
                    instance_classes = instance_classes[1:]
                    instance_classes_count = instance_classes_count[1:]

                if len(instance_classes) > 0:
                    label_selected = int(instance_classes[np.argmax(instance_classes_count)])
                else:  # Label by default with class 1 in case there was no class info
                    label_selected = 1
                new_class_channel[instance_mask] = label_selected

            class_channel = new_class_channel.squeeze()
            del new_class_channel
            save_tif(
                np.expand_dims(
                    np.concatenate(
                        [
                            np.expand_dims(pred_labels.squeeze(), -1),
                            np.expand_dims(class_channel, -1),
                        ],
                        axis=-1,
                    ),
                    0,
                ),
                out_dir,
                filenames,
                verbose=self.cfg.TEST.VERBOSE,
            )
        else:
            save_tif(
                np.expand_dims(np.expand_dims(pred_labels, -1), 0),
                out_dir,
                filenames,
                verbose=self.cfg.TEST.VERBOSE,
            )

        # Add extra dimension if working in 2D
        if pred_labels.ndim == 2:
            pred_labels = np.expand_dims(pred_labels, 0)
    else:
        pred_labels = pred.squeeze()
        if pred_labels.ndim == 2:
            pred_labels = np.expand_dims(pred_labels, 0)

    results = None
    results_class = None
    do_matching = (
        calculate_metrics
        and self.cfg.TEST.MATCHING_STATS
        and (self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST)
    )
    if do_matching:
        print("Calculating matching stats . . .")

        # Need to load instance labels, as Y are binary channels used for IoU calculation
        if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK and len(self.test_filenames) == pred_labels.shape[0]:
            del self.current_sample["Y"]
            _Y = np.zeros(pred_labels.shape, dtype=pred_labels.dtype)
            for i in range(len(self.test_filenames)):
                test_file = os.path.join(self.original_test_mask_path, self.test_filenames[i])
                _Y[i] = read_img_as_ndarray(test_file, is_3d=False).squeeze()
        else:
            test_file = os.path.join(self.original_test_mask_path, self.test_filenames[self.f_numbers[0]])
            if not os.path.exists(test_file):
                print(
                    "WARNING: The image seems to have different name than its mask file. Using the mask file that's "
                    "in the same spot (within the mask files list) where the image is in its own list of images. Check if it is correct!"
                )
                test_file = os.path.join(
                    self.original_test_mask_path,
                    self.test_gt_filenames[self.f_numbers[0]],
                )
                print(f"Its respective image seems to be: {test_file}")

            if test_file.endswith(".zarr") or test_file.endswith(".n5"):
                if self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA:
                    _, _Y = read_chunked_nested_data(
                        test_file,
                        self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_GT_PATH,
                    )
                else:
                    _, _Y = read_chunked_data(test_file)
                _Y = np.array(_Y).squeeze()
            else:
                _Y = read_img_as_ndarray(test_file, is_3d=self.cfg.PROBLEM.NDIM == "3D").squeeze()

        # Multi-head: instances + classification
        if self.separated_class_channel:
            # Channel check
            error_shape = None
            if self.cfg.PROBLEM.NDIM == "2D" and _Y.ndim != 3:
                error_shape = (256, 256, 2)
            elif self.cfg.PROBLEM.NDIM == "3D" and _Y.ndim != 4:
                error_shape = (40, 256, 256, 2)
            if error_shape:
                raise ValueError(
                    f"Image {test_file} wrong dimension. In instance segmentation, when 'DATA.N_CLASSES' are "
                    f"more than 2 labels need to have two channels, e.g. {error_shape}, containing the instance "
                    "segmentation map (first channel) and classification map (second channel)."
                )

            # Separate instance and classification channels
            _Y_classes = _Y[..., 1]  # Classes
            _Y = _Y[..., 0]  # Instances

            # Measure class IoU
            class_iou = self.jaccard_index_matching(
                torch.as_tensor(class_channel.squeeze().astype(np.uint8)).to(self.test_device, non_blocking=True),
                torch.as_tensor(_Y_classes.squeeze().astype(np.uint8)).to(self.test_device, non_blocking=True),
            )
            class_iou = class_iou.item() if not torch.isnan(class_iou) else 0
            print(f"Class IoU: {class_iou}")
            results_class = class_iou

        if _Y.ndim == 2:
            _Y = np.expand_dims(_Y, 0)

        # For torchvision models that resize need to rezise the images
        if pred_labels.shape != _Y.shape:
            # preserve_range keeps the label values (otherwise integer inputs are rescaled to [0, 1]
            # floats); nearest neighbour without smoothing so no new label values are invented
            pred_labels = resize(
                pred_labels, _Y.shape, order=0, preserve_range=True, anti_aliasing=False
            ).astype(pred_labels.dtype)

        # Convert instances to integer
        if _Y.dtype == np.float32:
            _Y = _Y.astype(np.uint32)
        if _Y.dtype == np.float64:
            _Y = _Y.astype(np.uint64)

        # Pad the colored-image thresholds with -1 so they can be indexed like MATCHING_STATS_THS
        diff_ths_colored_img = abs(
            len(self.cfg.TEST.MATCHING_STATS_THS_COLORED_IMG) - len(self.cfg.TEST.MATCHING_STATS_THS)
        )
        colored_img_ths = self.cfg.TEST.MATCHING_STATS_THS_COLORED_IMG + [-1] * diff_ths_colored_img

        results = matching(_Y, pred_labels, thresh=self.cfg.TEST.MATCHING_STATS_THS, report_matches=True)
        self._report_matching_stats(results, _Y, pred_labels, base_name, colored_img_ths)

    ###################
    # Post-processing #
    ###################
    if self.cfg.TEST.POST_PROCESSING.INSTANCE_REFINEMENT.ENABLE:
        pred_labels = apply_label_refinement(
            pred_labels,
            is_3d=self.cfg.PROBLEM.NDIM == "3D",
            operations=self.cfg.TEST.POST_PROCESSING.INSTANCE_REFINEMENT.OPERATIONS,
            values=self.cfg.TEST.POST_PROCESSING.INSTANCE_REFINEMENT.VALUES,
        )

    if self.cfg.TEST.POST_PROCESSING.REPARE_LARGE_BLOBS_SIZE != -1:
        if is_2d:
            pred_labels = pred_labels[0]
        # The whole label image/volume is passed: indexing ``[0]`` again here handed a single row (2D)
        # or just the first slice (3D) to the function
        pred_labels = repare_large_blobs(pred_labels, self.cfg.TEST.POST_PROCESSING.REPARE_LARGE_BLOBS_SIZE)
        if is_2d:
            pred_labels = np.expand_dims(pred_labels, 0)

    if self.cfg.TEST.POST_PROCESSING.VORONOI_ON_MASK:
        channels = self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS
        erode_size = 0
        # Pick the prediction channel(s) used as the mask the instances are grown into
        if "M" in channels:
            pred = pred[..., channels.index("M")]
        elif "F" in channels:
            if "C" in channels:
                pred = pred[..., channels.index("F")] + pred[..., channels.index("C")]
            else:
                pred = pred[..., channels.index("F")]
        elif "B" in channels:
            pred = 1 - pred[..., channels.index("B")]
        elif "C" in channels:
            pred = pred[..., channels.index("C")]
            erode_size = 2  # As the contours are thicker we erode a little bit

        pred_labels = voronoi_on_mask(
            pred_labels,
            pred,
            th=self.cfg.TEST.POST_PROCESSING.VORONOI_TH,
            verbose=self.cfg.TEST.VERBOSE,
            erode_size=erode_size,
        )
    del pred

    if (
        self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.ENABLE
        or self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.ENABLE
    ):
        props_cfg = self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES
        if is_2d:
            pred_labels = pred_labels[0]
        pred_labels, d_result = measure_morphological_props_and_filter(
            pred_labels,
            intensity_image=self.current_sample["X"].squeeze(),
            resolution=self.resolution,
            extra_props=props_cfg.EXTRA_PROPS,
            filter_instances=props_cfg.REMOVE_BY_PROPERTIES.ENABLE,
            properties=props_cfg.REMOVE_BY_PROPERTIES.PROPS,
            prop_values=props_cfg.REMOVE_BY_PROPERTIES.VALUES,
            comp_signs=props_cfg.REMOVE_BY_PROPERTIES.SIGNS,
        )
        extra_properties_keys = props_cfg.EXTRA_PROPS
        if is_2d:
            pred_labels = np.expand_dims(pred_labels, 0)

        # Save all instance stats
        if is_2d:
            # Base properties that are always included
            base_data_series = [
                np.array(d_result["labels"], dtype=np.uint64),
                list(d_result["centers"][:, 0]),
                list(d_result["centers"][:, 1]),
                d_result["npixels"],
                d_result["areas"],
                d_result["circularities"],
                d_result["diameters"],
                d_result["perimeters"],
                d_result["elongations"],
                d_result["comment"],
                d_result["conditions"],
            ]
            # Base column names
            base_columns = [
                "label",
                "axis-0",
                "axis-1",
                "npixels",
                "area",
                "circularity",
                "diameter",
                "perimeter",
                "elongation",
                "comment",
                "conditions",
            ]
        else:
            base_data_series = [
                np.array(d_result["labels"], dtype=np.uint64),
                list(d_result["centers"][:, 0]),
                list(d_result["centers"][:, 1]),
                list(d_result["centers"][:, 2]),
                d_result["npixels"],
                d_result["areas"],
                d_result["sphericities"],
                d_result["diameters"],
                d_result["perimeters"],
                d_result["comment"],
                d_result["conditions"],
            ]
            base_columns = [
                "label",
                "axis-0",
                "axis-1",
                "axis-2",
                "npixels",
                "volume",
                "sphericity",
                "diameter",
                "perimeter (surface area)",
                "comment",
                "conditions",
            ]

        # Only extra properties that were really measured and do not clash with a base column
        extra_properties_keys = [
            key for key in extra_properties_keys if key not in base_columns and key in d_result
        ]
        extra_data_series = [d_result[key] for key in extra_properties_keys]

        df = pd.DataFrame(
            zip(*(base_data_series + extra_data_series)),
            columns=base_columns + extra_properties_keys,
        )
        df = df.sort_values(by=["label"])
        # The folder may not exist yet when the instances were not saved by this call
        os.makedirs(out_dir, exist_ok=True)
        df.to_csv(
            os.path.join(out_dir, base_name + "_full_stats.csv"),
            index=False,
        )
        # Save only remain instances stats
        df = df[df["comment"].str.contains("Strange") == False]
        os.makedirs(out_dir_post_proc, exist_ok=True)
        df.to_csv(
            os.path.join(out_dir_post_proc, base_name + "_filtered_stats.csv"),
            index=False,
        )
        del df

    results_post_proc = None
    results_class_post_proc = None
    if self.post_processing["instance_post"]:
        if is_2d:
            pred_labels = pred_labels[0]

        # Multi-head: instances + classification
        if self.separated_class_channel:
            class_channel = np.where(pred_labels > 0, class_channel, 0)  # Adapt changes to post-processed pred_labels
            save_tif(
                np.expand_dims(
                    np.concatenate(
                        [
                            np.expand_dims(pred_labels, -1),
                            np.expand_dims(class_channel, -1),
                        ],
                        axis=-1,
                    ),
                    0,
                ),
                out_dir_post_proc,
                filenames,
                verbose=self.cfg.TEST.VERBOSE,
            )
        else:
            save_tif(
                np.expand_dims(np.expand_dims(pred_labels, -1), 0),
                out_dir_post_proc,
                filenames,
                verbose=self.cfg.TEST.VERBOSE,
            )

        if do_matching:
            # Multi-head: instances + classification
            if self.separated_class_channel:
                # Measure class IoU (tensors moved to the test device as in the pre-post-processing call)
                class_iou = self.jaccard_index_matching(
                    torch.as_tensor(class_channel.squeeze().astype(np.int32)).to(
                        self.test_device, non_blocking=True
                    ),
                    torch.as_tensor(_Y_classes.squeeze().astype(np.int32)).to(
                        self.test_device, non_blocking=True
                    ),
                )
                class_iou = class_iou.item() if not torch.isnan(class_iou) else 0
                print(f"Class IoU (post-processing): {class_iou}")
                results_class_post_proc = class_iou

            if is_2d:
                pred_labels = np.expand_dims(pred_labels, 0)

            print("Calculating matching stats after post-processing . . .")
            results_post_proc = matching(
                _Y,
                pred_labels,
                thresh=self.cfg.TEST.MATCHING_STATS_THS,
                report_matches=True,
            )
            self._report_matching_stats(
                results_post_proc, _Y, pred_labels, base_name, colored_img_ths, file_tag="_post-proc"
            )

    return results, results_post_proc, results_class, results_class_post_proc
[docs]
def synapse_seg_process(
    self,
    pred: NDArray,
    filenames: Optional[List[str]] = None,
    out_dir: Optional[str] = None,
    out_dir_post_proc: Optional[str] = None,
    calculate_metrics: bool = True,
    do_post_processing: bool = True,
) -> Dict:
    """
    Synapse segmentation workflow engine for test/inference.

    Process model's prediction to prepare synapse segmentation output and calculate metrics.

    Parameters
    ----------
    pred : 4D array
        Model predictions. E.g. ``(z, y, x, channels)`` for both 2D and 3D.

    filenames : List of str
        Predicted image's filenames.

    out_dir : path
        Output directory to save the instances. Required when the metrics are calculated.

    out_dir_post_proc : str
        Output directory to save the post-processed instances.

    calculate_metrics : bool, optional
        Whether to calculate or not the metrics. Normally we disable it when doing inference per chunks
        as the metrics are calculated at the end on the whole image.

    do_post_processing : bool
        Whether to do or not the post-processing step. Normally we disable it when doing inference per chunks
        as the post-processing is done at the end on the whole image.

    Returns
    -------
    Dict[str, Any]
        A dictionary containing the predicted synapse-related points. It is indexed by point type
        (``"pre"``, ``"post"`` or ``"cleft"``) and each entry holds ``"points"`` and ``"df"``, plus
        ``"gt"``, ``"gt_assoc"`` and ``"fps"`` when the metrics are calculated.

    Raises
    ------
    ValueError
        If ``self.synapse_method`` is not one of the supported methods.
    """
    assert pred.ndim == 4, f"Prediction doesn't have 4 dim: {pred.shape}"

    syn_cfg = self.cfg.PROBLEM.INSTANCE_SEG.SYNAPSES

    #############################
    ### INSTANCE SEGMENTATION ###
    #############################
    # One peak threshold per predicted channel
    threshold_abs = []
    for c in range(pred.shape[-1]):
        if syn_cfg.TH_TYPE == "auto":
            threshold_abs.append(threshold_otsu(pred[..., c]))
        else:  # "manual", "relative_by_patch", "relative"
            threshold_abs.append(syn_cfg.MIN_TH_TO_BE_PEAK)

    points_available = {}
    if self.synapse_method == "synful":
        pre_points_df, pre_points, post_points_df, post_points = extract_synful_synapses(
            data=pred,
            channels=self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS,
            threshold_abs=0.2,
            min_distance=1,
            cluster_distance=5.0,
            out_dir=out_dir,
            verbose=self.cfg.TEST.VERBOSE,
        )
        points_available["pre"] = {"points": pre_points, "df": pre_points_df}
        points_available["post"] = {"points": post_points, "df": post_points_df}
    elif self.synapse_method in ("simpsyn", "cleft", "F_post_only"):
        # Peak-detection settings shared by the three point-based methods
        peak_kwargs = dict(
            point_creation_func=syn_cfg.POINT_CREATION_FUNCTION,
            min_distance=syn_cfg.PEAK_LOCAL_MAX_MIN_DISTANCE,
            min_sigma=syn_cfg.BLOB_LOG_MIN_SIGMA,
            max_sigma=syn_cfg.BLOB_LOG_MAX_SIGMA,
            num_sigma=syn_cfg.BLOB_LOG_NUM_SIGMA,
            exclude_border=syn_cfg.EXCLUDE_BORDER,
            relative_th_value=syn_cfg.TH_TYPE in ["relative", "relative_by_patch"],
            out_dir=out_dir,
            filenames=filenames,
        )
        if self.synapse_method == "simpsyn":
            pre_points_df, pre_points, post_points_df, post_points = create_synapses_from_point_probs(
                data=pred,
                channels=self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS,
                min_th_to_be_peak=threshold_abs,
                **peak_kwargs,
            )
            points_available["pre"] = {"points": pre_points, "df": pre_points_df}
            points_available["post"] = {"points": post_points, "df": post_points_df}
        else:
            # "cleft" and "F_post_only" read a single channel and only differ in the point type
            point_type = "cleft" if self.synapse_method == "cleft" else "post"
            points_df, points = extract_points_in_predictions(
                data=pred[..., 0],
                point_type=point_type,
                min_th_to_be_peak=threshold_abs[0],
                verbose=self.cfg.TEST.VERBOSE,
                **peak_kwargs,
            )
            points_available[point_type] = {"points": points, "df": points_df}
    else:
        raise ValueError(f"Synapse method {self.synapse_method} not recognized.")

    # Parentheses are required: ``a and b or c`` would compute the metrics whenever USE_VAL_AS_TEST
    # is set, even with calculate_metrics=False (e.g. the per-chunk calls, which pass no out_dir)
    if calculate_metrics and (self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST):
        print("Calculating synapse detection stats . . .")
        gt_info = load_synapse_gt_points(
            locations_path=self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_LOCATIONS_PATH,
            resolution_path=self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_RESOLUTION_PATH,
            partners_path=self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_PARTNERS_PATH,
            id_path=self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_ID_PATH,
            data_filename=os.path.join(self.current_sample["X_dir"], self.current_sample["X_filename"]),
        )
        assert out_dir is not None, "Output directory must be provided to save the synapse detection metrics results."

        # Calculate detection metrics for each type of points if they are available
        for key in points_available:
            if key not in ["pre", "post", "cleft"]:
                raise ValueError(f"Unknown point type {key} found in points_available. Expected 'pre', 'post' or 'cleft'.")
            points_available[key]["gt"] = gt_info[key]
            assert "points" in points_available[key], f"Points not found for key {key} in points_available. Found keys: {points_available[key].keys()}"
            assert "gt" in points_available[key], f"GT not found for key {key} in points_available. Found keys: {points_available[key].keys()}"
            points_available[key]["gt_assoc"], points_available[key]["fps"] = self.calculate_synapse_det_metrics_on_points(
                points_available[key]["gt"],
                points_available[key]["points"],
                gt_info["resolution"],
                self.current_sample["X_filename"],
                out_dir,
                point_type=key,
            )

    ###################
    # Post-processing #
    ###################
    if do_post_processing and self.post_processing["per_image"]:
        print("TODO: post-processing")

    return points_available
[docs]
def calculate_synapse_det_metrics_on_points(self,
    gt_points: NDArray | List[int],
    pred_points: NDArray,
    resolution: List[int | float],
    filename: str,
    out_dir: str,
    point_type: str = "pre",
    post_processing: bool = False,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Calculate synapse detection metrics on the predicted points and save the associations between GT points and predicted ones.

    The metric values are accumulated into ``self.stats`` (under ``"merge_patches"`` or
    ``"merge_patches_post"``) and also stored in ``self.current_sample_metrics``.

    Parameters
    ----------
    gt_points : np.array or list of int
        Ground truth synapse points.

    pred_points : np.array
        Predicted synapse points.

    resolution : list
        Image resolution.

    filename : str
        Filename of the predicted image.

    out_dir : str
        Output directory to save the csv files with the associations between GT points and predicted ones.

    point_type : str
        Type of synaptic point to calculate the metrics on. E.g. "pre" or "post".

    post_processing : bool
        Whether the predicted points are from the post-processing step or not. Used for printing and
        to choose where the metric values are accumulated.

    Returns
    -------
    gt_assoc : pd.DataFrame
        DataFrame with the associations between GT points and predicted ones.

    fps : pd.DataFrame
        DataFrame with the false positive predicted points.
    """
    d_metrics, gt_assoc, fps = detection_metrics(
        true_points=gt_points,
        pred_points=pred_points,
        true_classes=None,
        pred_classes=[],
        tolerance=self.cfg.TEST.DET_TOLERANCE,
        resolution=resolution,
        bbox_to_consider=[],
        verbose=True,
    )

    # Extra metrics whose name mentions this point type; they are paired by position with the
    # entries of ``d_metrics`` (presumably both follow the same order -- TODO confirm)
    point_metrics = [x for x in self.test_extra_metrics if point_type in str(x).lower()]
    stat_key = "merge_patches" if not post_processing else "merge_patches_post"
    print("Synapse detection ({} points) metrics{}: {}".format(point_type, " (post-processing)" if post_processing else "", d_metrics))

    run_stats = self.stats[stat_key]
    sample_suffix = f" ({point_type} points{(', post-processing' if post_processing else '')})"
    for n, item in enumerate(d_metrics.items()):
        # Normalise the key once so creation and accumulation always use the same name
        # (``str(metric.lower())`` failed for metric entries that are not plain strings)
        metric_name = str(point_metrics[n]).lower()
        run_stats[metric_name] = run_stats.get(metric_name, 0) + item[1]
        self.current_sample_metrics[metric_name + sample_suffix] = item[1]

    # Save csv files with the associations between GT points and predicted ones
    gt_assoc.to_csv(
        os.path.join(
            out_dir,
            filename + f"_pred_{point_type}_locations_gt_assoc.csv",
        ),
        index=False,
    )
    fps.to_csv(
        os.path.join(
            out_dir,
            filename + f"_pred_{point_type}_locations_fp.csv",
        ),
        index=False,
    )

    return gt_assoc, fps
[docs]
def process_test_sample(self):
    """
    Process a sample in the inference phase.

    For every model source but ``"torchvision"`` the base-class implementation is used. With
    torchvision models the prediction is made here on the full image and nothing else is done,
    as their output is a collection of bboxes.

    Returns
    -------
    bool or None
        ``True`` if the sample was discarded (torchvision path). Otherwise whatever the base-class
        implementation returns, or ``None`` for a processed torchvision sample.
    """
    if self.cfg.MODEL.SOURCE != "torchvision":
        self.instances_already_created = False
        # Forward the base-class result so a skipped sample is reported the same way as in the
        # torchvision path (NOTE(review): assumes the caller inspects the returned value)
        return super().process_test_sample()

    # Skip processing image
    if "discard" in self.current_sample and self.current_sample["discard"]:
        return True

    self.instances_already_created = True

    ##################
    ### FULL IMAGE ###
    ##################
    # Make the prediction
    pred = self.model_call_func(self.current_sample["X"])
    pred = to_numpy_format(pred, self.axes_order_back)
    del self.current_sample["X"]

    # In Torchvision the output is a collection of bboxes so there is nothing else to do here.
    # (The crop/mask/``after_full_image`` steps that used to follow were unreachable in this branch.)
    return None
[docs]
def after_merge_patches(self, pred):
    """
    Execute steps needed after merging all predicted patches into the original image.

    The merged prediction is sent to :meth:`instance_seg_process` (``"regular"`` instance
    segmentation) or :meth:`synapse_seg_process` (synapses) and, for the former, the returned
    stats are stored in the accumulators and in ``self.current_sample_metrics``. Nothing is done
    when ``TEST.ANALIZE_2D_IMGS_AS_3D_STACK`` is enabled.

    Parameters
    ----------
    pred : Torch Tensor
        Model prediction.

    Raises
    ------
    NotImplementedError
        If the first dimension of ``pred`` is not 1.
    """
    if pred.shape[0] != 1:
        raise NotImplementedError

    if self.cfg.PROBLEM.NDIM == "3D":
        pred = pred[0]

    if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK:
        return

    if self.cfg.PROBLEM.INSTANCE_SEG.TYPE != "regular":  # synapses
        self.synapse_seg_process(
            pred,
            [self.current_sample["X_filename"]],
            self.cfg.PATHS.RESULT_DIR.PER_IMAGE_INSTANCES,
            self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING,
        )
        return

    r, r_post, rcls, rcls_post = self.instance_seg_process(
        pred,
        [self.current_sample["X_filename"]],
        self.cfg.PATHS.RESULT_DIR.PER_IMAGE_INSTANCES,
        self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING,
    )

    metric_keys = (
        'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'n_true', 'n_pred',
        'mean_true_score', 'mean_matched_score', 'panoptic_quality',
    )

    def _log_matching(stats_per_th, label):
        # Copy the summary numbers of each threshold into the per-sample metrics. ``label`` already
        # ends with a space so the metric name never gets glued to it (was ``"... TH (post)fp"``)
        for i, r_per_th in enumerate(stats_per_th):
            if 'thresh' in r_per_th:
                prefix = f"{r_per_th['thresh']} TH {label}"
            else:
                prefix = f"TH{i} {label}"
            for mkey in metric_keys:
                if mkey in r_per_th:
                    self.current_sample_metrics[prefix + mkey] = r_per_th[mkey]

    if r:
        self.all_matching_stats_merge_patches.append(r)
        _log_matching(r, "")
    if r_post:
        self.all_matching_stats_merge_patches_post.append(r_post)
        _log_matching(r_post, "(post) ")

    # ``None`` means "not calculated"; a class IoU of 0 is a valid result and must be recorded too
    if rcls is not None:
        self.all_class_stats_merge_patches.append(rcls)
        self.current_sample_metrics["class iou"] = rcls
    if rcls_post is not None:
        self.all_class_stats_merge_patches_post.append(rcls_post)
        self.current_sample_metrics["class iou (post)"] = rcls_post
[docs]
def after_one_chunk_raw_prediction(
    self, chunk_id: int, chunk: NDArray, chunk_in_data: PatchCoords, added_pad: List[List[int]]
):
    """
    Place any code that needs to be done after predicting one chunk of data in "by chunks" setting.

    For synapse detection (all methods but ``"synful"``) the points of the chunk are extracted, the
    ones lying in the padded area are discarded, the rest are moved to global coordinates and a csv
    file per point type is saved in ``PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK``. Nothing is done
    for ``"regular"`` instance segmentation.

    Parameters
    ----------
    chunk_id: int
        Chunk identifier.

    chunk : NDArray
        Predicted chunk

    chunk_in_data : PatchCoords
        Global coordinates of the chunk.

    added_pad: List of list of ints
        Padding added to the chunk in each dimension. The order of dimensions is the same as the input
        image, and the order of the list is: [[pad_before_dim1, pad_after_dim1], [pad_before_dim2, pad_after_dim2], ....
    """
    if self.cfg.PROBLEM.INSTANCE_SEG.TYPE == "regular":
        # Important to maintain calculate_metrics=False in the future call here
        # pre_points_df, post_points_df = self.instance_seg_process(chunk, filenames, out_dir, out_dir_post_proc, calculate_metrics=False)
        return

    # synapses
    if self.synapse_method == "synful":
        return

    # "simpsyn", "cleft" or "F_post_only"
    points_available = self.synapse_seg_process(chunk, calculate_metrics=False, do_post_processing=False)

    _filename, _ = os.path.splitext(os.path.basename(self.current_sample["X_filename"]))
    # Number of digits used to zero-pad the chunk id in the file name
    npatches = len(str(len(self.test_generator)))

    axis_cols = ("axis-0", "axis-1", "axis-2")
    chunk_origin = (chunk_in_data.z_start, chunk_in_data.y_start, chunk_in_data.x_start)
    out_folder = self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK

    for key in points_available:
        assert key in ["pre", "post", "cleft"], f"Unknown point type {key} found in points_available. Expected 'pre', 'post' or 'cleft'."
        assert "df" in points_available[key], f"'df' key not found for {key} in points_available. Found keys: {points_available[key].keys()}"
        point_df = points_available[key]["df"]

        # Remove possible points in the padded area: a point is kept only if, in every axis, it lies
        # before the trailing pad and is not negative once the leading pad is subtracted
        keep = np.ones(len(point_df), dtype=bool)
        for dim, col in enumerate(axis_cols):
            coords = point_df[col]
            keep &= ((coords < chunk.shape[dim] - added_pad[dim][1]) & (coords - added_pad[dim][0] >= 0)).to_numpy()

        # Explicit copy: the columns are rewritten below and must not be assigned on a filtered
        # view of the original frame (pandas chained-assignment warning)
        point_df = point_df[keep].copy()

        # Remove the leading pad and add the chunk shift to the detected coordinates so they
        # represent global coords
        for dim, col in enumerate(axis_cols):
            point_df[col] = point_df[col] - added_pad[dim][0] + chunk_origin[dim]

        # Save the csv file
        os.makedirs(out_folder, exist_ok=True)
        point_df.to_csv(
            os.path.join(
                out_folder,
                _filename + "_patch" + str(chunk_id).zfill(npatches) + "_" + key + "_points.csv",
            ),
            index=False,
        )
[docs]
def after_all_chunk_prediction_workflow_process(self):
    """
    Run the workflow stage that follows the prediction of all patches in "by chunks" setting.

    This function is called on all ranks. Work is only done when
    ``PROBLEM.INSTANCE_SEG.TYPE == "regular"`` and
    ``TEST.BY_CHUNKS.WORKFLOW_PROCESS.TYPE == "chunk_by_chunk"``; any other
    configuration (synapses included) returns without doing anything.
    ``TEST.BY_CHUNKS.PHASES`` selects what runs: ``"instance_creation"`` enables
    Pass A and ``"instance_merging"`` enables Passes B to E.

    A. Per-chunk instance labelling through the base-class implementation of this
       method (which is expected to call :meth:`after_one_chunk_workflow_process`).
       When skipped, the grid is derived from ``DATA.PATCH_SIZE`` and the shape of
       the already existing label Zarr.
    B. Global-offset assignment. The maximum ID of every chunk is gathered and chunk
       *k* adds the sum of the maximum IDs of chunks ``0 .. k-1`` (a prefix sum) to
       its non-zero labels, so ID ranges never overlap.
    C. Boundary-edge extraction. For every pair of adjacent chunks the two one-voxel
       thick faces touching the boundary are compared; a pair of IDs becomes a merge
       edge when ``overlap / (size_a + size_b - overlap)`` exceeds
       ``TEST.BY_CHUNKS.WORKFLOW_PROCESS.INSTANCE_SEG_MERGE_IOU_TH``.
    D. The main process turns the edges into an ID remap with
       ``self._compute_global_id_remap``; the remap is broadcast to all ranks.
    E. Relabelling. Every chunk is rewritten through ``self._apply_id_remap``.

    If ``TEST.BY_CHUNKS.SAVE_OUT_TIF`` is set, the TIF is written by the main
    process only after Pass E so it contains the merged IDs.

    Raises
    ------
    FileNotFoundError
        If Pass A is skipped and the instance Zarr does not exist.
    """
    if self.cfg.PROBLEM.INSTANCE_SEG.TYPE != "regular":
        return  # synapses: nothing to do at this stage
    if self.cfg.TEST.BY_CHUNKS.WORKFLOW_PROCESS.TYPE != "chunk_by_chunk":
        return

    # Instances go to PER_IMAGE_INSTANCES (same as the non-chunked pipeline).
    self.test_chunked_workflow_process_vars["out_dir"] = self.cfg.PATHS.RESULT_DIR.PER_IMAGE_INSTANCES
    # Use uint64 to hold large globally-offset IDs.
    self.test_chunked_workflow_process_vars["dtype_str"] = "uint64"

    phases = self.cfg.TEST.BY_CHUNKS.PHASES
    run_pass_a = "instance_creation" in phases
    run_merging = "instance_merging" in phases
    if not run_pass_a and not run_merging:
        return

    save_out_tif = self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF
    # Stored on the instance because after_one_chunk_workflow_process reads it.
    self._halo = self._effective_halo()
    print(f"[Rank {get_rank()} ({os.getpid()})] Effective halo: {self._halo}")

    if run_pass_a:
        # Temporarily suppress TIF creation inside the base-class pass.
        # The base class would save it after Pass A with un-merged local IDs,
        # which is wrong. We create the final TIF ourselves after Pass E.
        if save_out_tif:
            self.cfg.defrost()
            self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF = False
            self.cfg.freeze()
        try:
            # ---- Pass A: per-chunk instance labelling (base-class loop, all ranks) ----
            print(f"[Rank {get_rank()} ({os.getpid()})] Pass A: creating per-chunk instance labels . . .")
            super().after_all_chunk_prediction_workflow_process()
        finally:
            # Always restore the user's setting, even if Pass A fails.
            if save_out_tif:
                self.cfg.defrost()
                self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF = True
                self.cfg.freeze()

        # Retrieve grid parameters from the generator
        from biapy.data.generators.chunked_workflow_process_generator import (
            chunked_workflow_process_generator as _CWP,
        )

        tgen: _CWP = self.test_generator.dataset  # type: ignore
        zarr_path = tgen._shared_zarr_path()
        axes_order = tgen.out_data_order
        z_dim, y_dim, x_dim = tgen.z_dim, tgen.y_dim, tgen.x_dim
        step_z, step_y, step_x = tgen.step_z, tgen.step_y, tgen.step_x
        vols_per_z = tgen.vols_per_z
        vols_per_y = tgen.vols_per_y
        vols_per_x = tgen.vols_per_x
    else:
        # ---- Skip Pass A: derive grid params from config + existing label Zarr ----
        print(
            f"[Rank {get_rank()} ({os.getpid()})] Pass A skipped ('instance_creation' not in PHASES), "
            "reusing existing instance label Zarr."
        )
        if "C" not in self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER:
            axes_order = self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER + "C"
        else:
            axes_order = self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER
        base_filename = os.path.splitext(self.current_sample["X_filename"])[0]
        zarr_path = os.path.join(
            self.cfg.PATHS.RESULT_DIR.PER_IMAGE_INSTANCES, base_filename + ".zarr"
        )
        if not os.path.exists(zarr_path):
            raise FileNotFoundError(
                f"Pass A was skipped but the instance Zarr was not found: {zarr_path}"
            )
        _existing = zarr.open(zarr_path, mode="r", zarr_format=3)
        _, z_dim, _, y_dim, x_dim = order_dimensions(_existing.shape, axes_order)
        assert isinstance(z_dim, int) and isinstance(y_dim, int) and isinstance(x_dim, int)
        # NOTE(review): assumes the chunk grid step equals DATA.PATCH_SIZE, i.e. the
        # same step the generator used when Pass A originally ran - confirm.
        patch_size = self.cfg.DATA.PATCH_SIZE
        step_z = int(patch_size[0])
        step_y = int(patch_size[1])
        step_x = int(patch_size[2])
        vols_per_z = math.ceil(z_dim / step_z)
        vols_per_y = math.ceil(y_dim / step_y)
        vols_per_x = math.ceil(x_dim / step_x)

    if not run_merging:
        return

    total_chunks = vols_per_z * vols_per_y * vols_per_x
    rank = get_rank()
    world_size = get_world_size()
    distributed = self.cfg.SYSTEM.NUM_GPUS > 1 and is_dist_avail_and_initialized()
    merge_iou_th = self.cfg.TEST.BY_CHUNKS.WORKFLOW_PROCESS.INSTANCE_SEG_MERGE_IOU_TH

    # Helper: linear chunk index -> PatchCoords
    def _chunk_coords(linear_idx: int) -> PatchCoords:
        zi, yi, xi = np.unravel_index(
            linear_idx, (vols_per_z, vols_per_y, vols_per_x)
        )
        z0 = int(zi) * step_z
        y0 = int(yi) * step_y
        x0 = int(xi) * step_x
        return PatchCoords(
            z_start=z0, z_end=min(z0 + step_z, z_dim),
            y_start=y0, y_end=min(y0 + step_y, y_dim),
            x_start=x0, x_end=min(x0 + step_x, x_dim),
        )

    # Helper: merge edges between two flattened, position-aligned boundary faces
    def _face_merge_edges(flat_a, flat_b) -> set:
        both = (flat_a > 0) & (flat_b > 0)
        if not both.any():
            return set()
        # Overlap voxel count per (a_id, b_id) pair and face area per instance.
        pairs, pair_counts = np.unique(
            np.stack((flat_a[both], flat_b[both]), axis=1), axis=0, return_counts=True
        )
        ids_a, n_a = np.unique(flat_a[flat_a > 0], return_counts=True)
        ids_b, n_b = np.unique(flat_b[flat_b > 0], return_counts=True)
        size_a = dict(zip(ids_a.tolist(), n_a.tolist()))
        size_b = dict(zip(ids_b.tolist(), n_b.tolist()))
        edges = set()
        for (a_id, b_id), cnt in zip(pairs.tolist(), pair_counts.tolist()):
            union = size_a[a_id] + size_b[b_id] - cnt
            if union > 0 and cnt / union > merge_iou_th:
                edges.add((min(a_id, b_id), max(a_id, b_id)))
        return edges

    # Chunks are distributed round-robin so ranks work on disjoint sets.
    my_chunk_indices = list(range(rank, total_chunks, world_size))
    zarr_data = zarr.open(zarr_path, mode="r+", zarr_format=3)

    # ---- Pass B1: collect per-chunk max ID (all ranks, disjoint) ----
    print(
        f"[Rank {get_rank()} ({os.getpid()})] Pass B: collecting per-chunk max IDs from "
        f"{len(my_chunk_indices)}/{total_chunks} chunks . . ."
    )
    local_max_ids: Dict[int, int] = {}
    for idx in my_chunk_indices:
        coords = _chunk_coords(idx)
        patch = extract_patch_from_efficient_file(zarr_data, coords, axes_order)
        local_max_ids[idx] = int(patch.max())

    # Gather max IDs from all ranks so every rank can compute prefix sums.
    if distributed:
        gathered_max: List = [None] * world_size
        dist.all_gather_object(gathered_max, local_max_ids)
        all_max_ids: Dict[int, int] = {}
        for d in gathered_max:
            if d:
                all_max_ids.update(d)
    else:
        all_max_ids = local_max_ids

    # Prefix-sum offsets: offset[k] = sum of max IDs of chunks 0 ... k-1.
    # This guarantees non-overlapping ID ranges without any fixed constant.
    chunk_offsets: List[int] = [0] * total_chunks
    for k in range(1, total_chunks):
        chunk_offsets[k] = chunk_offsets[k - 1] + all_max_ids.get(k - 1, 0)
    if self.cfg.TEST.VERBOSE:
        print(
            f"[Rank {get_rank()} ({os.getpid()})] Pass B: chunk offsets computed "
            f"(max global ID will be {chunk_offsets[-1] + all_max_ids.get(total_chunks - 1, 0)})"
        )

    # ---- Pass B2: apply offsets (all ranks, disjoint) ----
    for i, idx in enumerate(my_chunk_indices):
        offset = np.uint64(chunk_offsets[idx])
        if offset == 0:
            if self.cfg.TEST.VERBOSE:
                print(
                    f"[Rank {get_rank()} ({os.getpid()})] Pass B: chunk {i+1}/{len(my_chunk_indices)} "
                    f"(linear idx {idx}) no offset needed (first chunk)"
                )
            continue
        coords = _chunk_coords(idx)
        # Widen BEFORE adding: an in-place add into a narrower array would wrap around.
        patch = extract_patch_from_efficient_file(zarr_data, coords, axes_order).astype(np.uint64)
        nonzero = patch > 0
        patch[nonzero] += offset
        insert_patch_in_efficient_file(
            zarr_data, patch, coords,
            axes_order, "ZYXC", mode="replace",
        )
        if self.cfg.TEST.VERBOSE:
            print(
                f"[Rank {get_rank()} ({os.getpid()})] Pass B: chunk {i+1}/{len(my_chunk_indices)} "
                f"(linear idx {idx}, coords {coords}) offset by {offset}"
            )
    if distributed:
        print(f"[Rank {get_rank()} ({os.getpid()})] Pass B done. Waiting for all ranks . . .")
        dist.barrier()

    # ---- Pass C: extract boundary edges (all ranks, disjoint set of faces) ----
    # A boundary (zi, yi, xi, axis) is the face between chunk (zi, yi, xi) and its
    # next neighbour along `axis`.
    all_boundaries: List[Tuple[int, int, int, str]] = []
    for zi in range(vols_per_z - 1):
        for yi in range(vols_per_y):
            for xi in range(vols_per_x):
                all_boundaries.append((zi, yi, xi, "z"))
    for zi in range(vols_per_z):
        for yi in range(vols_per_y - 1):
            for xi in range(vols_per_x):
                all_boundaries.append((zi, yi, xi, "y"))
    for zi in range(vols_per_z):
        for yi in range(vols_per_y):
            for xi in range(vols_per_x - 1):
                all_boundaries.append((zi, yi, xi, "x"))
    my_boundaries = all_boundaries[rank::world_size]
    print(
        f"[Rank {get_rank()} ({os.getpid()})] Pass C: extracting boundary edges from "
        f"{len(my_boundaries)}/{len(all_boundaries)} boundary faces . . ."
    )

    # Single-face IoU boundary matching.
    #
    # For each boundary we compare exactly the two one-voxel thick slices that are
    # physically adjacent: the last slice of chunk A and the first slice of chunk B
    # along the split axis. Both cover the same cross-sectional positions, so a
    # flattened index refers to the same location in both faces.
    #
    # We use IoU = cnt / (size_a + size_b - cnt) rather than cnt / min(size_a, size_b).
    # The min-based metric gives 1.0 whenever a small cell's face lies entirely inside
    # a large unrelated cell's face, causing chains of spurious merges. IoU is high
    # only when BOTH cross-sections are similar (a truly split cell).
    steps = {"z": step_z, "y": step_y, "x": step_x}
    local_edges: set = set()
    for zi, yi, xi, direction in my_boundaries:
        z0_a = zi * step_z
        y0_a = yi * step_y
        x0_a = xi * step_x
        # Extent of chunk A, clipped to the volume.
        span = {
            "z": (z0_a, min(z0_a + step_z, z_dim)),
            "y": (y0_a, min(y0_a + step_y, y_dim)),
            "x": (x0_a, min(x0_a + step_x, x_dim)),
        }
        grid_pos = {"z": zi, "y": yi, "x": xi}
        end_a = span[direction][1]
        start_b = (grid_pos[direction] + 1) * steps[direction]
        span_a = dict(span)
        span_a[direction] = (end_a - 1, end_a)  # last slice of chunk A
        span_b = dict(span)
        span_b[direction] = (start_b, start_b + 1)  # first slice of chunk B
        face_a_coords = PatchCoords(
            z_start=span_a["z"][0], z_end=span_a["z"][1],
            y_start=span_a["y"][0], y_end=span_a["y"][1],
            x_start=span_a["x"][0], x_end=span_a["x"][1],
        )
        face_b_coords = PatchCoords(
            z_start=span_b["z"][0], z_end=span_b["z"][1],
            y_start=span_b["y"][0], y_end=span_b["y"][1],
            x_start=span_b["x"][0], x_end=span_b["x"][1],
        )
        face_a = extract_patch_from_efficient_file(zarr_data, face_a_coords, axes_order)
        face_b = extract_patch_from_efficient_file(zarr_data, face_b_coords, axes_order)
        flat_a = face_a.ravel().astype(np.int64)
        flat_b = face_b.ravel().astype(np.int64)
        if flat_a.size != flat_b.size:
            # Faces that cannot be aligned position by position are ignored.
            continue
        local_edges |= _face_merge_edges(flat_a, flat_b)

    local_edge_list: List[Tuple[int, int]] = list(local_edges)
    print(
        f"[Rank {get_rank()} ({os.getpid()})] Pass C done: found {len(local_edge_list)} local merge edges."
    )

    # ---- Gather edges from all ranks ----
    if distributed:
        gathered: List = [None] * world_size
        dist.all_gather_object(gathered, local_edge_list)
        all_edges: List[Tuple[int, int]] = [
            e for sublist in gathered if sublist for e in sublist
        ]
    else:
        all_edges = local_edge_list

    # ---- Pass D: global ID remap on the main process, then broadcast ----
    if is_main_process():
        print(
            f"[Rank {get_rank()} ({os.getpid()})] Pass D: computing global ID remap "
            f"from {len(all_edges)} total merge edges . . ."
        )
        remap = self._compute_global_id_remap(all_edges)
        print(
            f"[Rank {get_rank()} ({os.getpid()})] Pass D done: {len(remap)} IDs will be remapped."
        )
    else:
        remap = None
    if distributed:
        remap_container: List = [remap]
        dist.broadcast_object_list(remap_container, src=0)
        remap = remap_container[0]
    if remap is None:
        remap = {}

    # ---- Pass E: relabelling (all ranks, disjoint chunks) ----
    if remap:
        print(
            f"[Rank {get_rank()} ({os.getpid()})] Pass E: relabelling "
            f"{len(my_chunk_indices)}/{total_chunks} chunks . . ."
        )
        for i, idx in enumerate(my_chunk_indices):
            coords = _chunk_coords(idx)
            patch = extract_patch_from_efficient_file(zarr_data, coords, axes_order)
            patch_relabeled = self._apply_id_remap(patch, remap)
            insert_patch_in_efficient_file(
                zarr_data, patch_relabeled, coords,
                axes_order, "ZYXC", mode="replace",
            )
            if self.cfg.TEST.VERBOSE:
                print(
                    f"[Rank {get_rank()} ({os.getpid()})] Pass E: chunk {i+1}/{len(my_chunk_indices)} "
                    f"(linear idx {idx}) relabelled."
                )
    else:
        print(f"[Rank {get_rank()} ({os.getpid()})] Pass E: no cross-boundary merges needed, skipping relabelling.")
    if distributed:
        print(f"[Rank {get_rank()} ({os.getpid()})] Pass E done. Waiting for all ranks . . .")
        dist.barrier()
    print(f"[Rank {get_rank()} ({os.getpid()})] Chunk-by-chunk instance merging complete. Result: {zarr_path}")

    # Save TIF of the final merged labels if requested.
    if save_out_tif:
        if distributed:
            dist.barrier()
        if is_main_process():
            if not run_pass_a:
                # tgen was not created when skipping Pass A - save directly.
                data = np.array(zarr.open(zarr_path, mode="r", zarr_format=3))
                data = ensure_3d_shape(data)
                out_filename = os.path.splitext(os.path.basename(zarr_path))[0] + ".tif"
                save_tif(
                    np.expand_dims(data, 0),
                    self.cfg.PATHS.RESULT_DIR.PER_IMAGE_INSTANCES,
                    [out_filename],
                    verbose=True,
                )
            else:
                tgen.save_parallel_data_as_tif()
        if distributed:
            dist.barrier()
[docs]
def after_one_chunk_workflow_process(self, chunks: List[NDArray], patch_in_data: List) -> Optional[List[NDArray]]:
    """
    Turn predicted chunks into instance labels during "by chunks" inference.

    Only ``PROBLEM.INSTANCE_SEG.TYPE == "regular"`` is supported. When
    ``TEST.BY_CHUNKS.WORKFLOW_PROCESS.TYPE == "chunk_by_chunk"`` the supplied chunk
    content is not used: the chunk region, enlarged by ``self._halo`` voxels per
    axis and clipped to the volume, is re-read from the generator's
    ``X_parallel_data``, labelled as a whole, and cropped back to the chunk
    extent. Otherwise the chunk itself is labelled.

    Parameters
    ----------
    chunks : List[NDArray]
        List of chunks. Expected axes are: ``(z, y, x, channels)`` for 3D and
        ``(y, x, channels)`` for 2D.

    patch_in_data : List[PatchCoords]
        Spatial coordinates of each chunk in the full volume.

    Returns
    -------
    chunks : Optional[List[NDArray]]
        One label array per input chunk, with a trailing singleton axis appended.

    Raises
    ------
    NotImplementedError
        If the instance segmentation type is not ``"regular"``.
    """
    if self.cfg.PROBLEM.INSTANCE_SEG.TYPE != "regular":
        raise NotImplementedError

    with_halo = self.cfg.TEST.BY_CHUNKS.WORKFLOW_PROCESS.TYPE == "chunk_by_chunk"
    processed = []
    for block, pos in zip(chunks, patch_in_data):
        if not with_halo:
            inst = self._create_instance_labels(block)
        else:
            # Label a halo-extended region so that objects near the chunk border
            # get context from the neighbouring chunks; only the inner part of
            # the result is kept afterwards.
            dataset = self.test_generator.dataset
            hz, hy, hx = self._halo
            halo = (hz, hy, hx)
            dims = (dataset.z_dim, dataset.y_dim, dataset.x_dim)
            starts = (pos.z_start, pos.y_start, pos.x_start)
            ends = (pos.z_end, pos.y_end, pos.x_end)

            # Extended region, clipped to the volume limits.
            ext_lo = [max(0, s - h) for s, h in zip(starts, halo)]
            ext_hi = [min(d, e + h) for d, e, h in zip(dims, ends, halo)]
            extended = extract_patch_from_efficient_file(
                dataset.X_parallel_data,
                PatchCoords(
                    z_start=ext_lo[0], z_end=ext_hi[0],
                    y_start=ext_lo[1], y_end=ext_hi[1],
                    x_start=ext_lo[2], x_end=ext_hi[2],
                ),
                dataset.out_data_order,
            )
            extended_labels = self._create_instance_labels(extended)

            # Position of the original chunk inside the extended region.
            inner = tuple(
                slice(s - lo, (s - lo) + (e - s))
                for s, e, lo in zip(starts, ends, ext_lo)
            )
            inst = extended_labels[inner]
        processed.append(np.expand_dims(inst, -1))
    return processed
[docs]
def after_all_chunk_prediction_workflow_process_master_rank(self):
    """
    Execute steps needed after merging all predicted patches into the original image in "by chunks" setting.

    Behaviour per configuration:

    * ``PROBLEM.INSTANCE_SEG.TYPE == "regular"``: nothing is done for
      ``chunk_by_chunk`` processing; otherwise the merged prediction Zarr is loaded
      and handed to ``self.after_merge_patches``.
    * synapses, ``"synful"`` method: same Zarr hand-off (``chunk_by_chunk`` is not
      implemented), then the function returns.
    * synapses, ``"F_post_only"``, ``"cleft"`` or ``"simpsyn"`` methods: the detected
      points are collected from the per-chunk CSV files, detection metrics are
      computed when GT is available, close points are optionally removed
      (``self.post_processing["per_image"]``), and auxiliary TIFs are written when
      ``TEST.BY_CHUNKS.SAVE_OUT_TIF`` is set.

    Raises
    ------
    NotImplementedError
        For the ``"synful"`` method combined with ``chunk_by_chunk`` processing.
    ValueError
        If an unexpected point type is found, or if mask-based close point removal
        is requested for a point type that has no global threshold available.
    """
    assert isinstance(self.all_pred, list) and isinstance(self.all_gt, list)
    filename = os.path.basename(self.current_sample["X_filename"])
    base_name = os.path.splitext(filename)[0]

    def _process_merged_prediction():
        """Load the merged prediction (H5/Zarr) as a numpy array and run ``after_merge_patches``."""
        fpath = os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, base_name + ".zarr")
        pred_file, pred = read_chunked_data(fpath)
        pred = np.squeeze(np.array(pred, dtype=self.dtype))
        if isinstance(pred_file, h5py.File):
            pred_file.close()
        pred = ensure_3d_shape(pred, fpath)
        self.after_merge_patches(np.expand_dims(pred, 0))

    def _check_point_type(key):
        """Reject point types other than the three the synapse workflow produces."""
        if key not in ["pre", "post", "cleft"]:
            raise ValueError(f"Unknown point type {key} found in points_available. Expected 'pre', 'post' or 'cleft'.")

    # ---------------- Regular instances ----------------
    if self.cfg.PROBLEM.INSTANCE_SEG.TYPE == "regular":
        # chunk_by_chunk results were already produced chunk by chunk.
        if self.cfg.TEST.BY_CHUNKS.WORKFLOW_PROCESS.TYPE != "chunk_by_chunk":
            _process_merged_prediction()
        return

    # ---------------- Synapses ----------------
    if self.synapse_method == "synful":
        if self.cfg.TEST.BY_CHUNKS.WORKFLOW_PROCESS.TYPE == "chunk_by_chunk":
            raise NotImplementedError
        _process_merged_prediction()
        print("TODO: synful support")
        return

    points_available = {}
    # Global binarisation threshold per point type. Only "simpsyn" provides them, so
    # other methods must not assume they exist.
    th_globals = {}
    if self.synapse_method in ["F_post_only", "cleft"]:
        p_type = "post" if self.synapse_method == "F_post_only" else "cleft"
        point_info = collect_point_type_csv_files(
            filename=base_name,
            point_type=p_type,
            csv_dir=self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK,
            min_th_to_be_peak=self.cfg.PROBLEM.INSTANCE_SEG.SYNAPSES.MIN_TH_TO_BE_PEAK,
            th_type=self.cfg.PROBLEM.INSTANCE_SEG.SYNAPSES.TH_TYPE,
        )
        points = []
        if isinstance(point_info["df"], pd.DataFrame) and len(point_info["df"]) > 0:
            points = [
                list(coord)
                for coord in zip(
                    point_info["df"]["axis-0"], point_info["df"]["axis-1"], point_info["df"]["axis-2"]
                )
            ]
        points_available[p_type] = {"points": points, "df": point_info["df"]}
    elif self.synapse_method == "simpsyn":
        pre_points_df, pre_points, pre_th_global, post_points_df, post_points, post_th_global = extract_synapse_connectivity(
            filename=base_name,
            reuse_predictions=self.cfg.TEST.REUSE_PREDICTIONS,
            csv_dir=self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK,
            min_th_to_be_peak=self.cfg.PROBLEM.INSTANCE_SEG.SYNAPSES.MIN_TH_TO_BE_PEAK,
            th_type=self.cfg.PROBLEM.INSTANCE_SEG.SYNAPSES.TH_TYPE,
            verbose=self.cfg.TEST.VERBOSE,
        )
        points_available["pre"] = {"points": pre_points, "df": pre_points_df}
        points_available["post"] = {"points": post_points, "df": post_points_df}
        th_globals["pre"] = pre_th_global
        th_globals["post"] = post_th_global

    has_gt = self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST

    # Calculate synapse detection metrics if GT is available
    if has_gt:
        print("Calculating synapse detection stats . . .")
        gt_info = load_synapse_gt_points(
            locations_path = self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_LOCATIONS_PATH,
            resolution_path = self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_RESOLUTION_PATH,
            partners_path = self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_PARTNERS_PATH,
            id_path = self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA_ID_PATH,
            data_filename = os.path.join(self.current_sample["X_dir"], filename),
        )
        # Calculate detection metrics for each type of points if they are available
        for key in points_available:
            _check_point_type(key)
            points_available[key]["gt"] = gt_info[key]
            assert "points" in points_available[key], f"Points not found for key {key} in points_available. Found keys: {points_available[key].keys()}"
            points_available[key]["gt_assoc"], points_available[key]["fps"] = self.calculate_synapse_det_metrics_on_points(
                points_available[key]["gt"],
                points_available[key]["points"],
                gt_info["resolution"],
                filename,
                self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK,
                point_type=key
            )

    # Post-processing: remove points that are too close to each other based on a radius threshold defined in the config.
    for key in points_available:
        _check_point_type(key)
        assert "points" in points_available[key], f"Points not found for key {key} in points_available. Found keys: {points_available[key].keys()}"
        assert "df" in points_available[key], f"Dataframe not found for key {key} in points_available. Found keys: {points_available[key].keys()}"
        if not self.post_processing["per_image"]:
            continue

        # Each point type is gated by its OWN radius: "pre" uses the pre radius,
        # "post" and "cleft" use the post radius.
        if key == "pre":
            pradius = self.cfg.PROBLEM.INSTANCE_SEG.SYNAPSES.REMOVE_CLOSE_PRE_POINTS_RADIUS
        else:
            pradius = self.cfg.PROBLEM.INSTANCE_SEG.SYNAPSES.REMOVE_CLOSE_POST_POINTS_RADIUS
        if not pradius > 0:
            continue

        # NOTE(review): the resolution is taken from gt_info, which only exists when GT
        # was loaded above - confirm how it should be obtained without GT.
        if self.cfg.PROBLEM.INSTANCE_SEG.SYNAPSES.REMOVE_CLOSE_POINTS_RADIUS_BY_MASK:
            th_global = th_globals.get(key)
            if th_global is None:
                raise ValueError(
                    f"No global threshold available for '{key}' points with synapse method "
                    f"'{self.synapse_method}'; mask-based close point removal cannot be applied."
                )
            # Load H5/Zarr and convert it into numpy array
            fpath = os.path.join(self.cfg.PATHS.RESULT_DIR.PER_IMAGE, base_name + ".zarr")
            pred_file, pred = read_chunked_data(fpath)
            try:
                points_available[key]["points"], dropped_pos = remove_close_points_by_mask(  # type: ignore
                    points=points_available[key]["points"],
                    radius=pradius,
                    raw_predictions=pred,
                    bin_th=th_global,
                    resolution=gt_info["resolution"],
                    channel_to_look_into=1,  # post channel
                    ndim=self.dims,
                    return_drops=True,
                )
            finally:
                # Release the file even if the removal fails.
                if isinstance(pred_file, h5py.File):
                    pred_file.close()
        else:
            points_available[key]["points"], dropped_pos = remove_close_points(  # type: ignore
                points_available[key]["points"],
                pradius,
                gt_info["resolution"],
                ndim=self.dims,
                return_drops=True,
            )

        key_df = points_available[key]["df"]
        if key_df is not None:
            # Keep the dataframe in sync with the filtered point list and save it.
            key_df.drop(key_df.index[dropped_pos], inplace=True)  # type: ignore
            os.makedirs(self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK_POST_PROCESSING, exist_ok=True)
            key_df.to_csv(
                os.path.join(
                    self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK_POST_PROCESSING,
                    str(filename)+"_pred_"+key+"_locations.csv",
                ),
                index=False,
            )

        # After removing close points, calculate again the detection metrics to see the effect of this post-processing step on the metrics.
        if has_gt:
            points_available[key]["gt_assoc"], points_available[key]["fps"] = self.calculate_synapse_det_metrics_on_points(
                points_available[key]["gt"],
                points_available[key]["points"],
                gt_info["resolution"],
                filename,
                self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK_POST_PROCESSING,
                point_type=key,
                post_processing=True
            )

    # If both pre and post points are available and the post-processing step was done we need to connect again the points to create the
    # synapse connectivity.
    if self.post_processing["per_image"] and "pre" in points_available and "post" in points_available:
        connect_pre_post_synapse_points_by_distance(
            pre_points_df=points_available["pre"]["df"],
            pre_points=points_available["pre"]["points"],
            post_points_df=points_available["post"]["df"],
            post_points=points_available["post"]["points"],
            out_dir=self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK_POST_PROCESSING,
        )

    if not self.cfg.TEST.BY_CHUNKS.SAVE_OUT_TIF:
        return

    print("Preparing prediction and GT tiffs as auxiliary images for checking the output. . .")
    sshape = list(self.current_sample["X"].shape)
    channels = len(points_available) if len(points_available) > 0 else 1
    assert len(sshape) >= 3
    if len(sshape) == 3:
        sshape += [channels]
    else:
        sshape[-1] = channels
    out_dir = (
        self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK_POST_PROCESSING
        if self.post_processing["per_image"]
        else self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK
    )

    # Create a tif with the predicted points, coloring them based on their ID in the dataframe (if available) and dilating them to make them more visible.
    # We create one channel per type of point (pre, post, cleft).
    aux_tif = np.zeros(sshape, dtype=np.uint16)
    for i, key in enumerate(points_available):
        point_df = points_available[key]["df"]
        # Paint points
        if point_df is not None:
            point_ids = point_df[f"{key}_id"].to_list()
            assert len(points_available[key]["points"]) == len(point_ids)
            for j, cor in enumerate(points_available[key]["points"]):
                z, y, x = int(cor[0]), int(cor[1]), int(cor[2])
                aux_tif[z, y, x, i] = point_ids[j]
            # Dilate points to make them more visible in the tif
            aux_tif[..., i] = dilation(aux_tif[..., i], ball(3))
    save_tif(
        np.expand_dims(aux_tif, 0),
        out_dir,
        [str(filename) + "_points.tif"],
        verbose=self.cfg.TEST.VERBOSE,
    )

    # Everything below needs GT points, which are only stored when GT was loaded.
    if not has_gt:
        return

    # Create another tif with the GT points, coloring them by their (1-based) position in the GT list and dilating them to make them more visible.
    aux_tif = np.zeros(sshape, dtype=np.uint16)
    for i, key in enumerate(points_available):
        gt_points = points_available[key]["gt"]
        if len(gt_points) > 0:
            for j, coord in enumerate(gt_points):
                # GT coordinates are shifted by one before being used as indices.
                z, y, x = int(coord[0])-1, int(coord[1])-1, int(coord[2])-1
                aux_tif[z, y, x, i] = j+1
            aux_tif[..., i] = dilation(aux_tif[..., i], ball(3))
    save_tif(
        np.expand_dims(aux_tif, 0),
        out_dir,
        [str(filename) + "_gt_ids.tif"],
        verbose=self.cfg.TEST.VERBOSE,
    )

    # Create one RGB tif per type of point (pre, post, cleft) summarising the detection: GT points are painted in green
    # when tagged as TP and in red otherwise (missed GT points), and the false positive predictions are painted in blue.
    # This is useful to visually check the quality of the predictions and the errors.
    for key in points_available:
        _check_point_type(key)
        assert "gt" in points_available[key], f"GT not found for key {key} in points_available. Found keys: {points_available[key].keys()}"
        assert "gt_assoc" in points_available[key], f"GT association not found for key {key} in points_available. Found keys: {points_available[key].keys()}"
        assert "fps" in points_available[key], f"FPs not found for key {key} in points_available. Found keys: {points_available[key].keys()}"
        aux_tif = np.zeros(sshape[:-1] + [3,], dtype=np.uint8)
        gt_assoc = points_available[key]["gt_assoc"]
        print(f"Creating the image with a summary of detected points and false positives with colors ({key}-points) . . .")
        print(f"Painting TPs and FNs ({key}-points) . . .")
        for j, cor in tqdm(enumerate(points_available[key]["gt"]), total=len(points_available[key]["gt"])):
            z, y, x = int(cor[0])-1, int(cor[1])-1, int(cor[2])-1
            tag = gt_assoc[gt_assoc["gt_id"]==j+1]["tag"].iloc[0]
            color = (0, 255, 0) if tag == "TP" else (255, 0, 0)  # Green or red
            try:
                aux_tif[z, y, x] = color
            except IndexError:
                # Best effort: points outside the image are simply not painted.
                pass
        print(f"Painting FPs ({key}-points) . . .")
        for index, row in tqdm(points_available[key]["fps"].iterrows(), total=len(points_available[key]["fps"])):
            z, y, x = int(row['axis-0']), int(row['axis-1']), int(row['axis-2'])
            try:
                aux_tif[z, y, x] = (0, 0, 255)  # Blue
            except IndexError:
                pass
        print(f"Dilating points ({key}-points) . . .")
        for c in range(aux_tif.shape[-1]):
            aux_tif[..., c] = dilation(aux_tif[..., c], ball(3))
        save_tif(
            np.expand_dims(aux_tif, 0),
            out_dir,
            [str(filename) + f"_{key}_point_assoc.tif"],
            verbose=self.cfg.TEST.VERBOSE,
        )
[docs]
def after_full_image(self, pred: NDArray):
    """
    Create the instances once the model has predicted an entire image at once.

    Nothing is done here when ``TEST.ANALIZE_2D_IMGS_AS_3D_STACK`` is enabled.
    Otherwise the prediction goes through ``instance_seg_process`` (regular
    instances, collecting the returned statistics that are not empty) or through
    ``synapse_seg_process`` (synapses).

    Parameters
    ----------
    pred : NDArray
        Model prediction. Its first axis must have length 1; for 3D problems that
        axis is removed before processing.

    Raises
    ------
    NotImplementedError
        If the first axis of ``pred`` does not have length 1.
    """
    if pred.shape[0] != 1:
        raise NotImplementedError

    if self.cfg.PROBLEM.NDIM == "3D":
        pred = pred[0]

    if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK:
        return

    if self.cfg.PROBLEM.INSTANCE_SEG.TYPE != "regular":
        # synapses
        self.synapse_seg_process(
            pred,
            [self.current_sample["X_filename"]],
            self.cfg.PATHS.RESULT_DIR.FULL_IMAGE_INSTANCES,
            self.cfg.PATHS.RESULT_DIR.FULL_IMAGE_POST_PROCESSING,
        )
        return

    match_stats, match_stats_post, cls_stats, cls_stats_post = self.instance_seg_process(
        pred,
        [self.current_sample["X_filename"]],
        self.cfg.PATHS.RESULT_DIR.FULL_IMAGE_INSTANCES,
        self.cfg.PATHS.RESULT_DIR.FULL_IMAGE_POST_PROCESSING,
    )
    # Each non-empty result is stored in its own accumulator.
    destinations = (
        (match_stats, "all_matching_stats"),
        (match_stats_post, "all_matching_stats_post"),
        (cls_stats, "all_class_stats"),
        (cls_stats_post, "all_class_stats_post"),
    )
    for value, attr_name in destinations:
        if value:
            getattr(self, attr_name).append(value)
[docs]
def after_all_images(self):
    """
    Execute steps needed after predicting all images.

    After the base-class step, and only when ``TEST.ANALIZE_2D_IMGS_AS_3D_STACK``
    is enabled, all stored predictions are concatenated and processed as one
    stack named ``3D_stack_instances.tif``: regular instances go through
    ``instance_seg_process`` (non-empty statistics are accumulated) and synapses
    through ``synapse_seg_process``.
    """
    super().after_all_images()
    assert isinstance(self.all_pred, list) and isinstance(self.all_gt, list)

    if not self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK:
        return

    print("Analysing all images as a 3D stack . . .")
    self.all_pred = np.concatenate(self.all_pred)

    if self.cfg.PROBLEM.INSTANCE_SEG.TYPE != "regular":
        # synapses
        self.synapse_seg_process(
            self.all_pred,
            ["3D_stack_instances.tif"],
            self.cfg.PATHS.RESULT_DIR.AS_3D_STACK,
            self.cfg.PATHS.RESULT_DIR.AS_3D_STACK_POST_PROCESSING,
        )
        return

    match_stats, match_stats_post, cls_stats, cls_stats_post = self.instance_seg_process(
        self.all_pred,
        ["3D_stack_instances.tif"],
        self.cfg.PATHS.RESULT_DIR.AS_3D_STACK,
        self.cfg.PATHS.RESULT_DIR.AS_3D_STACK_POST_PROCESSING,
    )
    # Each non-empty result is stored in its own "as 3D stack" accumulator.
    destinations = (
        (match_stats, "all_matching_stats_as_3D_stack"),
        (match_stats_post, "all_matching_stats_as_3D_stack_post"),
        (cls_stats, "all_class_stats_as_3D_stack"),
        (cls_stats_post, "all_class_stats_as_3D_stack_post"),
    )
    for value, attr_name in destinations:
        if value:
            getattr(self, attr_name).append(value)
[docs]
def normalize_stats(self, image_counter):
    """
    Normalize statistics.

    On top of the base-class normalisation, and only when test GT is in use, the
    collected results are summarised into ``self.stats``: matching statistics
    (if ``TEST.MATCHING_STATS``) through ``wrapper_matching_dataset_lazy`` and,
    when ``self.separated_class_channel`` is set, class statistics through their
    mean. The three sources ("merge patches", "as 3D stack" and "full image") are
    handled alike, and their post-processed variants are added when
    ``self.post_processing["instance_post"]`` is set. Empty collections produce
    no entry.

    Parameters
    ----------
    image_counter : int
        Number of images to average the metrics.
    """
    super().normalize_stats(image_counter)

    if not self.cfg.DATA.TEST.LOAD_GT and not self.cfg.DATA.TEST.USE_VAL_AS_TEST:
        return

    # Name suffix of each source: merge patches, as 3D stack, full image.
    sources = ("_merge_patches", "_as_3D_stack", "")

    def _summarise_matching(tail):
        for source in sources:
            collected = getattr(self, "all_matching_stats" + source + tail)
            if len(collected) > 0:
                self.stats["inst_stats" + source + tail] = wrapper_matching_dataset_lazy(
                    collected,
                    self.cfg.TEST.MATCHING_STATS_THS,
                )

    def _summarise_classes(tail):
        for source in sources:
            collected = getattr(self, "all_class_stats" + source + tail)
            if len(collected) > 0:
                self.stats["class_stats" + source + tail] = np.mean(collected)

    if self.cfg.TEST.MATCHING_STATS:
        _summarise_matching("")
        if self.post_processing["instance_post"]:
            _summarise_matching("_post")

    # Multi-head: instances + classification
    if self.separated_class_channel:
        _summarise_classes("")
        if self.post_processing["instance_post"]:
            _summarise_classes("_post")
[docs]
def print_stats(self, image_counter: int) -> None:
    """
    Print the test metrics stored in ``self.stats``.

    Nothing is printed when ``MODEL.SOURCE`` is ``"torchvision"``. Otherwise the
    parent ``print_stats`` is called first and then, if ``TEST.MATCHING_STATS`` is
    enabled and ground truth is available (``DATA.TEST.LOAD_GT`` or
    ``DATA.TEST.USE_VAL_AS_TEST``), the instance matching results are printed for
    every threshold listed in ``TEST.MATCHING_STATS_THS``. Post-processed results
    are added when ``self.post_processing["instance_post"]`` is set, and the
    classification IoU values when ``self.separated_class_channel`` is set.

    Every entry of ``self.stats`` is printed only if it is truthy, so entries that
    were never filled are silently skipped.

    Parameters
    ----------
    image_counter : int
        Number of images to call ``normalize_stats``. Here it is only forwarded to
        the parent ``print_stats``.
    """
    # NOTE(review): the nesting of the conditionals below was reconstructed from a
    # listing without indentation - confirm it against the repository version.
    if self.cfg.MODEL.SOURCE != "torchvision":
        super().print_stats(image_counter)
        if self.cfg.PROBLEM.INSTANCE_SEG.TYPE == "regular":
            print("Instance segmentation specific metrics:")
        if self.cfg.TEST.MATCHING_STATS and (self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST):
            # One report per IoU threshold; ``i`` indexes both the threshold list and
            # each matching-stats entry.
            for i in range(len(self.cfg.TEST.MATCHING_STATS_THS)):
                if self.cfg.PROBLEM.INSTANCE_SEG.TYPE == "regular":
                    print("IoU TH={}".format(self.cfg.TEST.MATCHING_STATS_THS[i]))
                    # Merge patches
                    if self.stats["inst_stats_merge_patches"]:
                        print(" Merge patches:")
                        print(f" {self.stats['inst_stats_merge_patches'][i]}")
                    # As 3D stack
                    if self.stats["inst_stats_as_3D_stack"]:
                        print(" As 3D stack:")
                        print(f" {self.stats['inst_stats_as_3D_stack'][i]}")
                    # Full image
                    if self.stats["inst_stats"]:
                        print(" Full image:")
                        print(f" {self.stats['inst_stats'][i]}")
                    # Same three reports, computed after the instance post-processing
                    if self.post_processing["instance_post"]:
                        print("IoU (post-processing) TH={}".format(self.cfg.TEST.MATCHING_STATS_THS[i]))
                        # Merge patches
                        if self.stats["inst_stats_merge_patches_post"]:
                            print(" Merge patches (post-processing):")
                            print(f" {self.stats['inst_stats_merge_patches_post'][i]}")
                        # As 3D stack
                        if self.stats["inst_stats_as_3D_stack_post"]:
                            print(" As 3D stack (post-processing):")
                            print(f" {self.stats['inst_stats_as_3D_stack_post'][i]}")
                        # Full image
                        if self.stats["inst_stats_post"]:
                            print(" Full image (post-processing):")
                            print(f" {self.stats['inst_stats_post'][i]}")
                    # Multi-head: instances + classification
                    # These values are not indexed by ``i``: the same number is printed
                    # regardless of the threshold being reported.
                    if self.separated_class_channel:
                        # Merge patches
                        if self.stats["class_stats_merge_patches"]:
                            print(f" Merge patches classification IoU: {self.stats['class_stats_merge_patches']}")
                        # As 3D stack
                        if self.stats["class_stats_as_3D_stack"]:
                            print(f" As 3D stack classification IoU: {self.stats['class_stats_as_3D_stack']}")
                        # Full image
                        if self.stats["class_stats"]:
                            print(f" Full image classification IoU: {self.stats['class_stats']}")
                        if self.post_processing["instance_post"]:
                            # Merge patches
                            if self.stats["class_stats_merge_patches_post"]:
                                print(
                                    f" Merge patches classification IoU (post-processing): {self.stats['class_stats_merge_patches_post']}"
                                )
                            # As 3D stack
                            if self.stats["class_stats_as_3D_stack_post"]:
                                print(
                                    f" As 3D stack classification IoU (post-processing): {self.stats['class_stats_as_3D_stack_post']}"
                                )
                            # Full image
                            if self.stats["class_stats_post"]:
                                print(
                                    f" Full image classification IoU (post-processing): {self.stats['class_stats_post']}"
                                )
[docs]
def prepare_instance_data(self):
    """
    Create the instance channel ground truth and point the configuration to it.

    For every split that is going to be used (train, val and test), the folder
    configured in ``DATA.<SPLIT>.INSTANCE_CHANNELS_MASK_DIR`` is checked and, when it
    does not exist, ``create_instance_channels`` is called to build it from the
    instance labels. Afterwards the configuration is updated through
    ``self.cfg.merge_from_list`` so that ``DATA.<SPLIT>.GT_PATH`` points to those
    folders and, for ``INPUT_ZARR_MULTIPLE_DATA`` inputs, so that
    ``DATA.<SPLIT>.INPUT_MASK_AXES_ORDER`` follows the image axes order with a
    trailing ``"C"`` axis.

    Returns
    -------
    original_test_path : value of ``DATA.TEST.PATH`` or None
        ``DATA.TEST.PATH`` when test is enabled and validation is not used as test;
        ``None`` otherwise.

    original_test_mask_path : value of ``DATA.TEST.GT_PATH`` / ``DATA.TEST.PATH`` or None
        Test ground truth location as configured before this method redirected it
        (``DATA.TEST.PATH`` is used instead when ``DATA.TEST.INPUT_ZARR_MULTIPLE_DATA``
        is set). ``None`` under the same condition as above.
    """
    # NOTE(review): block nesting was reconstructed from a listing without
    # indentation - confirm it against the repository version.
    original_test_path, original_test_mask_path = None, None
    train_channel_mask_dir = self.cfg.DATA.TRAIN.INSTANCE_CHANNELS_MASK_DIR
    val_channel_mask_dir = self.cfg.DATA.VAL.INSTANCE_CHANNELS_MASK_DIR
    test_channel_mask_dir = self.cfg.DATA.TEST.INSTANCE_CHANNELS_MASK_DIR
    # Remember where the original test instances live before DATA.TEST.GT_PATH is
    # overwritten further down.
    if not self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA:
        test_instance_mask_dir = self.cfg.DATA.TEST.GT_PATH
    else:
        test_instance_mask_dir = self.cfg.DATA.TEST.PATH
    # Flat [key, value, key, value, ...] list applied to the config in one go at the end
    opts = []
    print("###########################")
    print("# PREPARE INSTANCE DATA #")
    print("###########################")
    # Create selected channels for train data
    if self.cfg.TRAIN.ENABLE or self.cfg.DATA.TEST.USE_VAL_AS_TEST:
        if not os.path.isdir(train_channel_mask_dir):
            # Barrier need as some of the threads may check the existence of the folder after it is created
            # NOTE(review): only ranks that found the folder missing reach this barrier -
            # confirm every rank evaluates the check consistently.
            if is_dist_avail_and_initialized():
                dist.barrier()
            print(
                "You select to create {} channels from given instance labels and no file is detected in {} . "
                "So let's prepare the data. This process will be done just once!".format(
                    self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS, train_channel_mask_dir
                )
            )
            create_instance_channels(self.cfg)
        # Change the value of DATA.TRAIN.INPUT_MASK_AXES_ORDER as we have created the instance mask and maybe the user doesn't
        # know the data order that is created.
        if self.cfg.DATA.TRAIN.INPUT_ZARR_MULTIPLE_DATA:
            out_data_order = self.cfg.DATA.TRAIN.INPUT_IMG_AXES_ORDER
            if "C" not in self.cfg.DATA.TRAIN.INPUT_IMG_AXES_ORDER:
                out_data_order += "C"
            print(
                "DATA.TRAIN.INPUT_MASK_AXES_ORDER changed from {} to {}".format(
                    self.cfg.DATA.TRAIN.INPUT_MASK_AXES_ORDER, out_data_order
                )
            )
            opts.extend([f"DATA.TRAIN.INPUT_MASK_AXES_ORDER", out_data_order])
    # Create selected channels for val data
    if self.cfg.TRAIN.ENABLE and not self.cfg.DATA.VAL.FROM_TRAIN:
        if not os.path.isdir(val_channel_mask_dir):
            # Barrier need as some of the threads may check the existence of the folder after it is created
            if is_dist_avail_and_initialized():
                dist.barrier()
            print(
                "You select to create {} channels from given instance labels and no file is detected in {} . "
                "So let's prepare the data. This process will be done just once!".format(
                    self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS, val_channel_mask_dir
                )
            )
            create_instance_channels(self.cfg, data_type="val")
        # Change the value of DATA.VAL.INPUT_MASK_AXES_ORDER as we have created the instance mask and maybe the user doesn't
        # know the data order that is created.
        if self.cfg.DATA.VAL.INPUT_ZARR_MULTIPLE_DATA:
            out_data_order = self.cfg.DATA.VAL.INPUT_IMG_AXES_ORDER
            if "C" not in self.cfg.DATA.VAL.INPUT_IMG_AXES_ORDER:
                out_data_order += "C"
            print(
                "DATA.VAL.INPUT_MASK_AXES_ORDER changed from {} to {}".format(
                    self.cfg.DATA.VAL.INPUT_MASK_AXES_ORDER, out_data_order
                )
            )
            opts.extend([f"DATA.VAL.INPUT_MASK_AXES_ORDER", out_data_order])
    # Create selected channels for test data once
    if self.cfg.TEST.ENABLE and not self.cfg.DATA.TEST.USE_VAL_AS_TEST and self.cfg.DATA.TEST.LOAD_GT:
        if not os.path.isdir(test_channel_mask_dir):
            # Barrier need as some of the threads may check the existence of the folder after it is created
            if is_dist_avail_and_initialized():
                dist.barrier()
            print(
                "You select to create {} channels from given instance labels and no file is detected in {} . "
                "So let's prepare the data. This process will be done just once!".format(
                    self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS,
                    test_channel_mask_dir,
                )
            )
            create_instance_channels(self.cfg, data_type="test")
        # Change the value of DATA.TEST.INPUT_MASK_AXES_ORDER as we have created the instance mask and maybe the user doesn't
        # know the data order that is created.
        if self.cfg.DATA.TEST.INPUT_ZARR_MULTIPLE_DATA:
            out_data_order = self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER
            if "C" not in self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER:
                out_data_order += "C"
            print(
                "DATA.TEST.INPUT_MASK_AXES_ORDER changed from {} to {}".format(
                    self.cfg.DATA.TEST.INPUT_MASK_AXES_ORDER, out_data_order
                )
            )
            opts.extend([f"DATA.TEST.INPUT_MASK_AXES_ORDER", out_data_order])
    # Wait for every process before the ground truth paths are redirected
    if is_dist_avail_and_initialized():
        dist.barrier()
    # Redirect the ground truth paths to the instance channel folders
    if self.cfg.TRAIN.ENABLE:
        if self.cfg.DATA.TRAIN.GT_PATH != train_channel_mask_dir:
            print(
                "DATA.TRAIN.GT_PATH changed from {} to {}".format(
                    self.cfg.DATA.TRAIN.GT_PATH, train_channel_mask_dir
                )
            )
            opts.extend(
                [
                    "DATA.TRAIN.GT_PATH",
                    train_channel_mask_dir,
                ]
            )
        if not self.cfg.DATA.VAL.FROM_TRAIN:
            if self.cfg.DATA.VAL.GT_PATH != val_channel_mask_dir:
                print("DATA.VAL.GT_PATH changed from {} to {}".format(self.cfg.DATA.VAL.GT_PATH, val_channel_mask_dir))
                opts.extend(
                    [
                        "DATA.VAL.GT_PATH",
                        val_channel_mask_dir,
                    ]
                )
    if self.cfg.TEST.ENABLE and not self.cfg.DATA.TEST.USE_VAL_AS_TEST:
        if self.cfg.DATA.TEST.LOAD_GT:
            if self.cfg.DATA.TEST.GT_PATH != test_channel_mask_dir:
                print(
                    "DATA.TEST.GT_PATH changed from {} to {}".format(
                        self.cfg.DATA.TEST.GT_PATH, test_channel_mask_dir
                    )
                )
                opts.extend(["DATA.TEST.GT_PATH", test_channel_mask_dir])
        # Returned so the caller still knows the locations configured by the user
        original_test_path = self.cfg.DATA.TEST.PATH
        original_test_mask_path = test_instance_mask_dir
    # Apply all the collected changes to the configuration at once
    self.cfg.merge_from_list(opts)
    return original_test_path, original_test_mask_path
[docs]
def torchvision_model_call(self, in_img: torch.Tensor, is_train: bool = False) -> torch.Tensor | None:
"""
Call a regular Pytorch model.
Parameters
----------
in_img : torch.Tensor
Input image to pass through the model.
is_train : bool, optional
Whether if the call is during training or inference.
Returns
-------
prediction : torch.Tensor
Image prediction.
"""
assert self.torchvision_preprocessing and self.model
filename, file_extension = os.path.splitext(self.current_sample["X_filename"])
# Convert first to 0-255 range if uint16
if in_img.dtype == torch.float32:
if torch.max(in_img) > 1:
in_img = (self.torchvision_norm.apply_image_norm(in_img)[0] * 255).to(torch.uint8) # type: ignore
in_img = in_img.to(torch.uint8)
# Apply TorchVision pre-processing
in_img = self.torchvision_preprocessing(in_img)
pred = self.model(in_img)
masks = pred[0]["masks"].cpu().numpy().transpose(0, 2, 3, 1)
if masks.shape[0] != 0:
masks = np.argmax(pred[0]["masks"].cpu().numpy().transpose(0, 2, 3, 1), axis=0)
else:
masks = torch.ones(
(1,) + pred[0]["masks"].cpu().numpy().transpose(0, 2, 3, 1).shape[1:],
dtype=torch.uint8,
)
if not is_train and masks.shape[0] != 0:
# Extract each output from MaskRCNN
bboxes = pred[0]["boxes"].cpu().numpy().astype(np.uint16)
labels = pred[0]["labels"].cpu().numpy()
scores = pred[0]["scores"].cpu().numpy()
# Save all info in a csv file
df = pd.DataFrame(
zip(
labels,
scores,
bboxes[:, 0],
bboxes[:, 1],
bboxes[:, 2],
bboxes[:, 3],
),
columns=["label", "scores", "x1", "y1", "x2", "y2"],
)
df = df.sort_values(by=["label"])
df.to_csv(
os.path.join(self.cfg.PATHS.RESULT_DIR.FULL_IMAGE, filename + ".csv"),
index=False,
)
# Save masks
save_tif(
np.expand_dims(masks, 0),
self.cfg.PATHS.RESULT_DIR.FULL_IMAGE,
[self.current_sample["X_filename"]],
verbose=self.cfg.TEST.VERBOSE,
)
return None