biapy.engine.instance_seg
Instance segmentation workflow for BiaPy.
This module defines the Instance_Segmentation_Workflow class, which implements the training, validation, and inference pipeline for instance segmentation tasks in BiaPy. It handles data preparation, model setup, metrics, predictions, post-processing, and result saving for assigning unique IDs to each object in 2D and 3D images.
- class biapy.engine.instance_seg.Instance_Segmentation_Workflow(cfg, job_identifier, device, system_dict, args, **kwargs)[source]
Bases: Base_Workflow
Instance segmentation workflow where the goal is to assign a unique ID, i.e. an integer, to each object of the input image.
More details in our documentation.
- Parameters:
cfg (YACS configuration) – Running configuration.
job_identifier (str) – Complete name of the running job.
device (Torch device) – Device used.
args (argparse class) – Arguments used in BiaPy’s call.
- define_activations_and_channels()[source]
Define the activations to be applied to the model output and the channels that the model will output.
This function must define the following variables:
- self.model_output_channels : List of int
Number of channels for each output head of the model. E.g. [3] for a model with one head outputting 3 channels, [1, 5] for a model with two heads outputting 1 and 5 channels respectively, etc.
- self.model_output_channel_info : List of str
Information about the output channels. One value must be defined per output head of the model.
- self.separated_class_channel : bool
Whether to expect a separate output channel for classification.
- self.head_activations : List of str
Activations to be applied to the model output. One value must be defined per output channel (not per output head) of the model. “linear” and “ce_sigmoid” will not be applied. E.g. [“linear”] for a model with one channel, [“linear”, “sigmoid”] for a model with two channels, etc.
Example of a correct definition of the function for a model with two output heads: 1) the first predicts foreground and contours; 2) the second classifies the predicted objects into 3 classes. In this case the following definition would be correct:
self.model_output_channels = [2, 3]
self.model_output_channel_info = ["mask", "class"]
self.separated_class_channel = True
self.head_activations = ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"]
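The relationship between the per-head and per-channel lists can be checked mechanically. The sketch below is not part of BiaPy; `check_head_definition` and `split_activations_per_head` are hypothetical helpers, and it assumes only what is stated above: one `model_output_channel_info` entry per head and one `head_activations` entry per output channel.

```python
from typing import List


def check_head_definition(
    model_output_channels: List[int],
    model_output_channel_info: List[str],
    head_activations: List[str],
) -> None:
    """Raise ValueError if the per-head and per-channel lists disagree."""
    if len(model_output_channel_info) != len(model_output_channels):
        raise ValueError("one channel-info entry is needed per output head")
    total_channels = sum(model_output_channels)
    if len(head_activations) != total_channels:
        raise ValueError(
            f"expected {total_channels} activations, got {len(head_activations)}"
        )


def split_activations_per_head(
    model_output_channels: List[int], head_activations: List[str]
) -> List[List[str]]:
    """Group the flat per-channel activation list by output head."""
    groups, start = [], 0
    for n_channels in model_output_channels:
        groups.append(head_activations[start:start + n_channels])
        start += n_channels
    return groups


# A first head with two channels (foreground, contours) and a second head
# with three class channels: five activations in total.
activations = ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"]
check_head_definition([2, 3], ["mask", "class"], activations)
per_head = split_activations_per_head([2, 3], activations)
```

Here `per_head` groups the two sigmoid activations under the first head and the three softmax activations under the second.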
- define_metrics()[source]
Define the metrics to be used in the instance segmentation workflow.
This function must define the following variables:
- self.train_metrics : List of functions
Metrics to be calculated during model’s training.
- self.train_metric_names : List of str
Names of the metrics calculated during training.
- self.train_metric_best : List of str
Which value of each training metric is considered the best one. Options must be: “max” or “min”.
- self.test_metrics : List of functions
Metrics to be calculated during model’s test/inference.
- self.test_metric_names : List of str
Names of the metrics calculated during test/inference.
- self.loss : Function
Loss function used during training and test.
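As an illustration of how these variables fit together, the following minimal sketch fills them in with a single IoU metric. `ToyWorkflow`, `iou`, `one_minus_iou` and `is_improvement` are hypothetical names that do not exist in BiaPy, and the metrics operate on flat Python lists rather than tensors to keep the example self-contained.

```python
from typing import Callable, List, Sequence


def iou(pred: Sequence[int], target: Sequence[int]) -> float:
    """Intersection over union of two flat binary masks."""
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    union = sum(1 for p, t in zip(pred, target) if p or t)
    return inter / union if union else 1.0


def one_minus_iou(pred: Sequence[int], target: Sequence[int]) -> float:
    """Placeholder loss used only for this sketch."""
    return 1.0 - iou(pred, target)


class ToyWorkflow:
    """Hypothetical stand-in for a workflow; not a BiaPy class."""

    def define_metrics(self) -> None:
        self.train_metrics: List[Callable] = [iou]
        self.train_metric_names: List[str] = ["IoU"]
        self.train_metric_best: List[str] = ["max"]  # higher IoU is better
        self.test_metrics: List[Callable] = [iou]
        self.test_metric_names: List[str] = ["IoU"]
        self.loss: Callable = one_minus_iou


def is_improvement(new: float, best: float, mode: str) -> bool:
    """Interpret a train_metric_best entry ("max" or "min")."""
    if mode == "max":
        return new > best
    if mode == "min":
        return new < best
    raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
```

The three training lists are parallel: entry i of `train_metric_names` and `train_metric_best` describes entry i of `train_metrics`.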
- metric_calculation(output: NDArray | Tensor, targets: NDArray | Tensor, train: bool = True, metric_logger: MetricLogger | None = None) → Dict[source]
Calculate the metrics defined in the define_metrics() function.
- Parameters:
output (Torch Tensor) – Prediction of the model.
targets (Torch Tensor) – Ground truth to compare the prediction with.
train (bool, optional) – Whether to calculate train or test metrics.
metric_logger (MetricLogger, optional) – Class to be updated with the new metric(s) value(s) calculated.
- Returns:
out_metrics – Value of the metrics for the given prediction.
- Return type:
dict
- instance_seg_process(pred, filenames, out_dir, out_dir_post_proc, calculate_metrics: bool = True)[source]
Instance segmentation workflow engine for test/inference.
Process model’s prediction to prepare instance segmentation output and calculate metrics.
- Parameters:
pred (4D/5D Torch tensor) – Model predictions. E.g. (z, y, x, channels) for both 2D and 3D.
filenames (List of str) – Predicted image’s filenames.
out_dir (path) – Output directory to save the instances.
out_dir_post_proc (path) – Output directory to save the post-processed instances.
calculate_metrics (bool, optional) – Whether to calculate the metrics.
- synapse_seg_process(pred: NDArray, filenames: List[str] | None = None, out_dir: str | None = None, out_dir_post_proc: str | None = None, calculate_metrics: bool = True, do_post_processing: bool = True) → Dict[source]
Synapse segmentation workflow engine for test/inference.
Process model’s prediction to prepare synapse segmentation output and calculate metrics.
- Parameters:
pred (4D/5D NDArray) – Model predictions. E.g. (z, y, x, channels) for both 2D and 3D.
filenames (List of str) – Predicted image’s filenames.
out_dir (path) – Output directory to save the instances.
out_dir_post_proc (str) – Output directory to save the post-processed instances.
calculate_metrics (bool, optional) – Whether to calculate the metrics. This is normally disabled when doing inference by chunks, as the metrics are calculated at the end on the whole image.
do_post_processing (bool) – Whether to run the post-processing step. This is normally disabled when doing inference by chunks, as the post-processing is done at the end on the whole image.
- Returns:
A dictionary containing the predicted synapse-related points.
- Return type:
Dict[str, Any]
- calculate_synapse_det_metrics_on_points(gt_points: NDArray | List[int], pred_points: NDArray, resolution: List[float | int], filename: str, out_dir: str, point_type: str = 'pre', post_processing: bool = False) → Tuple[DataFrame, DataFrame][source]
Calculate synapse detection metrics on the predicted points and save the associations between GT points and predicted ones.
- Parameters:
gt_points (np.array or list of int) – Ground truth synapse points.
pred_points (np.array) – Predicted synapse points.
resolution (list) – Image resolution.
filename (str) – Filename of the predicted image.
out_dir (str) – Output directory to save the csv files with the associations between GT points and predicted ones.
point_type (str) – Type of synaptic point to calculate the metrics on. E.g. “pre” or “post”.
post_processing (bool) – Whether the predicted points are from the post-processing step or not. Used for printing and saving the results.
- Returns:
gt_assoc (pd.DataFrame) – DataFrame with the associations between GT points and predicted ones.
fps (pd.DataFrame) – DataFrame with the false positive predicted points.
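This page does not describe how GT points and predicted points are associated, so the following is only a generic sketch of point-based detection scoring, not BiaPy's implementation. It assumes a greedy nearest-neighbour matching in physical units; `match_points`, `detection_scores` and the `tolerance` parameter are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def match_points(
    gt_points: Sequence[Sequence[float]],
    pred_points: Sequence[Sequence[float]],
    resolution: Sequence[float],
    tolerance: float,
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
    """Greedily pair GT and predicted points, closest pairs first.

    Returns (matches, false_positive_indices, false_negative_indices).
    """
    def to_physical(point: Sequence[float]) -> List[float]:
        # Scale voxel coordinates by the image resolution.
        return [c * r for c, r in zip(point, resolution)]

    candidates = sorted(
        (math.dist(to_physical(g), to_physical(p)), gi, pi)
        for gi, g in enumerate(gt_points)
        for pi, p in enumerate(pred_points)
    )
    used_gt, used_pred, matches = set(), set(), []
    for distance, gi, pi in candidates:
        if distance > tolerance:
            break  # remaining candidates are even farther apart
        if gi in used_gt or pi in used_pred:
            continue
        used_gt.add(gi)
        used_pred.add(pi)
        matches.append((gi, pi))
    fps = [pi for pi in range(len(pred_points)) if pi not in used_pred]
    fns = [gi for gi in range(len(gt_points)) if gi not in used_gt]
    return matches, fps, fns


def detection_scores(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Precision, recall and F1 from true/false positive and false negative counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0
    return precision, recall, f1


gt = [(0, 0, 0), (10, 10, 10)]
pred = [(0, 0, 1), (50, 50, 50), (10, 11, 10)]
matches, fps, fns = match_points(gt, pred, resolution=(1, 1, 1), tolerance=2.0)
scores = detection_scores(len(matches), len(fps), len(fns))
```

In this sketch the matched pairs play the role of the GT-to-prediction associations, and the unmatched predictions play the role of the false positives returned by the method.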
- after_merge_patches(pred)[source]
Execute steps needed after merging all predicted patches into the original image.
- Parameters:
pred (Torch Tensor) – Model prediction.
- after_one_chunk_raw_prediction(chunk_id: int, chunk: NDArray, chunk_in_data: PatchCoords, added_pad: List[List[int]])[source]
Place any code that needs to be done after predicting one chunk of data in “by chunks” setting.
- Parameters:
chunk_id (int) – Chunk identifier.
chunk (NDArray) – Predicted chunk.
chunk_in_data (PatchCoords) – Global coordinates of the chunk.
added_pad (List of list of ints) – Padding added to the chunk in each dimension. The order of dimensions is the same as the input image, and the order of the list is: [[pad_before_dim1, pad_after_dim1], [pad_before_dim2, pad_after_dim2], …].
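The added_pad layout maps directly onto one slice per dimension. The helper below is a hypothetical illustration (`pad_to_slices` is not a BiaPy function) of how the padded border could be dropped from a chunk, assuming added_pad has one [before, after] pair per chunk dimension.

```python
from typing import List, Sequence, Tuple


def pad_to_slices(
    added_pad: List[List[int]], shape: Sequence[int]
) -> Tuple[slice, ...]:
    """Build one slice per dimension that excludes the padded border."""
    if len(added_pad) != len(shape):
        raise ValueError("added_pad needs one [before, after] pair per dimension")
    return tuple(
        slice(before, dim - after)
        for (before, after), dim in zip(added_pad, shape)
    )


# A 10 x 8 chunk padded by 2 before / 1 after in the first dimension
# and 0 before / 3 after in the second one.
slices = pad_to_slices([[2, 1], [0, 3]], (10, 8))
```

With a NumPy array, `chunk[pad_to_slices(added_pad, chunk.shape)]` would then select the unpadded region.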
- after_all_chunk_prediction_workflow_process()[source]
Place any code that needs to be done after predicting all patches in “by chunks” setting. This function is called on all ranks.
For PROBLEM.INSTANCE_SEG.TYPE == "regular" and TEST.BY_CHUNKS.WORKFLOW_PROCESS.TYPE == "chunk_by_chunk" this runs five passes:
1. Per-chunk instance labelling via after_one_chunk_workflow_process() (base-class loop).
2. Global-offset assignment: each chunk k adds k * MAX_INSTANCES_PER_CHUNK to every non-zero label so that IDs are unique across the whole volume.
3. Boundary-edge extraction: for every pair of spatially adjacent chunks, the shared boundary face is read and the pairs of IDs that co-occur on it are collected. A co-occurring pair indicates that the same physical instance was split by the tile boundary.
4. Union-Find on rank 0 to resolve connected components; the resulting remap is broadcast to all ranks.
5. Relabelling: every chunk is rewritten with the canonical global ID for each instance.
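Passes 2 to 5 can be sketched in a single process on toy data. This is not BiaPy's code: it assumes 2D chunks stored as nested lists and stacked along the first axis, local labels that do not exceed MAX_INSTANCES_PER_CHUNK, an arbitrary value for that constant, and no multi-rank broadcast. All function names are hypothetical.

```python
from typing import Dict, List, Set, Tuple

MAX_INSTANCES_PER_CHUNK = 1000  # assumed value for this sketch
Chunk = List[List[int]]


def add_global_offset(chunks: List[Chunk]) -> List[Chunk]:
    """Pass 2: shift the non-zero labels of chunk k by k * MAX_INSTANCES_PER_CHUNK."""
    return [
        [[v + k * MAX_INSTANCES_PER_CHUNK if v else 0 for v in row] for row in chunk]
        for k, chunk in enumerate(chunks)
    ]


def boundary_pairs(upper: Chunk, lower: Chunk) -> Set[Tuple[int, int]]:
    """Pass 3: ID pairs that touch across the face shared by two stacked chunks."""
    return {(a, b) for a, b in zip(upper[-1], lower[0]) if a and b}


def find(parent: Dict[int, int], x: int) -> int:
    """Union-Find root lookup with path halving."""
    while parent.setdefault(x, x) != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def resolve(pairs: Set[Tuple[int, int]]) -> Dict[int, int]:
    """Pass 4: map every ID seen on a boundary to the smallest ID of its component."""
    parent: Dict[int, int] = {}
    for a, b in pairs:
        root_a, root_b = find(parent, a), find(parent, b)
        if root_a != root_b:
            parent[max(root_a, root_b)] = min(root_a, root_b)
    return {x: find(parent, x) for x in list(parent)}


def relabel(chunks: List[Chunk], remap: Dict[int, int]) -> List[Chunk]:
    """Pass 5: rewrite every chunk with the canonical IDs."""
    return [[[remap.get(v, v) for v in row] for row in chunk] for chunk in chunks]


def merge_chunk_labels(chunks: List[Chunk]) -> List[Chunk]:
    shifted = add_global_offset(chunks)
    pairs: Set[Tuple[int, int]] = set()
    for upper, lower in zip(shifted, shifted[1:]):
        pairs |= boundary_pairs(upper, lower)
    return relabel(shifted, resolve(pairs))


# Instance 1 of the first chunk continues as local instance 1 of the second chunk.
chunk_a = [[1, 1, 0], [1, 1, 0]]
chunk_b = [[1, 0, 0], [0, 0, 2]]
merged = merge_chunk_labels([chunk_a, chunk_b])
```

After merging, the split instance keeps ID 1 in both chunks, while the unrelated instance in the second chunk keeps its offset ID 1002.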
- after_one_chunk_workflow_process(chunks: List[NDArray], patch_in_data: List) → List[NDArray] | None[source]
Process a list of chunks during inference in “by chunks” setting. Each workflow should have its own implementation of this method.
- Parameters:
chunks (List[NDArray]) – List of chunks. Expected axes are: (z, y, x, channels) for 3D and (y, x, channels) for 2D.
patch_in_data (List[PatchCoords]) – Spatial coordinates of each chunk in the full volume.
- Returns:
chunks – Processed chunks.
- Return type:
Optional[List[NDArray]]
- after_all_chunk_prediction_workflow_process_master_rank()[source]
Execute steps needed after merging all predicted patches into the original image in “by chunks” setting.
- after_full_image(pred: NDArray)[source]
Execute steps needed after generating the prediction by supplying the entire image to the model.
- Parameters:
pred (NDArray) – Model prediction.
- normalize_stats(image_counter)[source]
Normalize statistics.
- Parameters:
image_counter (int) – Number of images over which to average the metrics.
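The averaging itself amounts to dividing each accumulated value by image_counter. A minimal sketch, assuming the statistics are kept as per-metric sums in a dictionary (the `average_stats` helper and that storage layout are hypothetical, not taken from BiaPy):

```python
from typing import Dict


def average_stats(accumulated: Dict[str, float], image_counter: int) -> Dict[str, float]:
    """Divide every accumulated metric sum by the number of images."""
    if image_counter <= 0:
        raise ValueError("image_counter must be a positive integer")
    return {name: total / image_counter for name, total in accumulated.items()}


averaged = average_stats({"IoU": 1.5, "F1": 3.0}, image_counter=3)
```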
- print_stats(image_counter)[source]
Print statistics.
- Parameters:
image_counter (int) – Number of images, passed on to normalize_stats().
- prepare_instance_data()[source]
Create instance segmentation ground truth images to train the model based on the ground truth instances provided.
They will be saved in a separate folder in the root path of the ground truth.
- torchvision_model_call(in_img: Tensor, is_train: bool = False) → Tensor | None[source]
Call a regular PyTorch model.
- Parameters:
in_img (torch.Tensor) – Input image to pass through the model.
is_train (bool, optional) – Whether the call is made during training or inference.
- Returns:
prediction – Image prediction.
- Return type:
torch.Tensor