biapy.engine.semantic_seg

Semantic segmentation workflow for BiaPy.

This module defines the Semantic_Segmentation_Workflow class, which implements the training, validation, and inference pipeline for semantic segmentation tasks in BiaPy. It handles data preparation, model setup, metrics, predictions, post-processing, and result saving for assigning a class to each pixel in 2D and 3D images.

class biapy.engine.semantic_seg.Semantic_Segmentation_Workflow(cfg, job_identifier, device, system_dict, args, **kwargs)[source]

Bases: Base_Workflow

Semantic segmentation workflow where the goal is to assign a class to each pixel of the input image.

More details in our documentation.

Parameters:
  • cfg (YACS configuration) – Running configuration.

  • job_identifier (str) – Complete name of the running job.

  • device (Torch device) – Device used.

  • args (argparse class) – Arguments used in BiaPy’s call.

define_activations_and_channels()[source]

Define the activations to be applied to the model output and the channels that the model will output.

This function must define the following variables:

self.model_output_channels : List of int

Number of channels for each output head of the model. E.g. [3] for a model with one head outputting 3 channels, [1, 5] for a model with two heads outputting 1 and 5 channels respectively, etc.

self.model_output_channel_info : List of str

Information about the output channels. A value per output head of the model must be defined.

self.separated_class_channel : bool

Whether to expect a separate output channel for classification.

self.head_activations : List of str

Activations to be applied to the model output. A value per output channel (not output head) of the model must be defined. “linear” and “ce_sigmoid” will not be applied. E.g. [“linear”] for a model with one channel, [“linear”, “sigmoid”] for a model with two channels, etc.

Example of a correct definition of the function for a model with two output heads: 1) the first predicts foreground and contours; 2) the second classifies the predicted objects into 3 classes. In this case the following definition would be correct:

self.model_output_channels = [2, 3]
self.model_output_channel_info = ["mask", "class"]
self.separated_class_channel = True
self.head_activations = ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"]
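
The constraints described above can be checked mechanically. The following is a minimal sketch, not part of BiaPy: the helper name `check_head_definition` is hypothetical, and it assumes the first head in the example has two channels (foreground and contours), so that the five activations map onto 2 + 3 channels.

```python
from typing import List


def check_head_definition(
    model_output_channels: List[int],
    model_output_channel_info: List[str],
    head_activations: List[str],
) -> List[List[str]]:
    """Hypothetical helper (not a BiaPy function): validate the lists and
    return the activations grouped per output head."""
    # One info string is required per output head.
    if len(model_output_channel_info) != len(model_output_channels):
        raise ValueError("model_output_channel_info needs one value per head")
    # One activation is required per output channel, not per head.
    if len(head_activations) != sum(model_output_channels):
        raise ValueError("head_activations needs one value per output channel")

    grouped, start = [], 0
    for n_channels in model_output_channels:
        grouped.append(head_activations[start:start + n_channels])
        start += n_channels
    return grouped


grouped = check_head_definition(
    [2, 3],
    ["mask", "class"],
    ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"],
)
print(grouped)
```

A definition whose channel counts do not add up to the number of activations raises a ValueError in this sketch, which makes the "per channel, not per head" rule easy to verify.
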

define_metrics()[source]

Define the metrics to be calculated during training and test/inference.

This function must define the following variables:

self.train_metrics : List of functions

Metrics to be calculated during model’s training.

self.train_metric_names : List of str

Names of the metrics calculated during training.

self.train_metric_best : List of str

Whether a higher or a lower value of each metric is considered the best one. Each entry must be “max” or “min”.

self.test_metrics : List of functions

Metrics to be calculated during model’s test/inference.

self.test_metric_names : List of str

Names of the metrics calculated during test/inference.

self.loss : Function

Loss function used during training and test.
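
As an illustration of how these variables fit together, the sketch below defines one metric and uses the “max”/“min” flag to decide whether a new value improves on the best one seen so far. It is an example under stated assumptions rather than BiaPy code: `pixel_accuracy` and `is_improvement` are hypothetical names, and the metric operates on plain Python lists instead of tensors.

```python
from typing import Callable, List, Sequence


def pixel_accuracy(pred: Sequence[int], target: Sequence[int]) -> float:
    """Hypothetical metric: fraction of pixels whose class matches."""
    correct = sum(p == t for p, t in zip(pred, target))
    return correct / len(target)


def is_improvement(new: float, best: float, mode: str) -> bool:
    """Hypothetical helper: compare two metric values according to mode."""
    if mode == "max":
        return new > best
    if mode == "min":
        return new < best
    raise ValueError("mode must be 'max' or 'min'")


# Parallel lists, one entry per metric, mirroring the variables above.
train_metrics: List[Callable[..., float]] = [pixel_accuracy]
train_metric_names: List[str] = ["accuracy"]
train_metric_best: List[str] = ["max"]

value = train_metrics[0]([0, 1, 1, 2], [0, 1, 2, 2])
print(train_metric_names[0], value)
print(is_improvement(value, 0.5, train_metric_best[0]))
```

Keeping the three lists index-aligned means a loss-like metric only needs its flag set to “min” for the same comparison logic to work.
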

process_test_sample()[source]

Process a sample in the inference phase.

torchvision_model_call(in_img: Tensor, is_train=False) → Tensor[source]

Call a regular PyTorch model.

Parameters:
  • in_img (torch.Tensor) – Input image to pass through the model.

  • is_train (bool, optional) – Whether the call is made during training or inference.

Returns:

prediction – Image prediction.

Return type:

torch.Tensor

metric_calculation(output: ndarray[tuple[int, ...], dtype[_ScalarType_co]] | Tensor, targets: ndarray[tuple[int, ...], dtype[_ScalarType_co]] | Tensor, train: bool = True, metric_logger: MetricLogger | None = None) → Dict[source]

Calculate the metrics defined in define_metrics() function.

Parameters:
  • output (NDArray or Torch Tensor) – Prediction of the model.

  • targets (NDArray or Torch Tensor) – Ground truth to compare the prediction with.

  • train (bool, optional) – Whether to calculate train or test metrics.

  • metric_logger (MetricLogger, optional) – Class to be updated with the new metric(s) value(s) calculated.

Returns:

out_metrics – Value of the metrics for the given prediction.

Return type:

dict
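
The general pattern, evaluating each function from define_metrics() and collecting the results by name, might look like the sketch below. It is not the BiaPy implementation: `SimpleLogger` stands in for MetricLogger, whose real interface is not shown here, and `iou` and `calculate_metrics` are hypothetical names operating on flat binary lists.

```python
from typing import Callable, Dict, List, Optional, Sequence


class SimpleLogger:
    """Hypothetical stand-in for a metric logger; stores the latest values."""

    def __init__(self) -> None:
        self.values: Dict[str, float] = {}

    def update(self, **metrics: float) -> None:
        self.values.update(metrics)


def iou(pred: Sequence[int], target: Sequence[int]) -> float:
    """Hypothetical metric: intersection over union of two binary masks."""
    inter = sum(1 for p, t in zip(pred, target) if p and t)
    union = sum(1 for p, t in zip(pred, target) if p or t)
    return inter / union if union else 1.0


def calculate_metrics(
    output: Sequence[int],
    targets: Sequence[int],
    metrics: List[Callable[..., float]],
    names: List[str],
    metric_logger: Optional[SimpleLogger] = None,
) -> Dict[str, float]:
    """Evaluate every metric and return a name -> value dictionary."""
    out_metrics = {name: fn(output, targets) for fn, name in zip(metrics, names)}
    # The logger is optional, matching the metric_logger parameter above.
    if metric_logger is not None:
        metric_logger.update(**out_metrics)
    return out_metrics


logger = SimpleLogger()
result = calculate_metrics([1, 1, 0, 0], [1, 0, 0, 0], [iou], ["IoU"], logger)
print(result)
```
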

prepare_targets(targets, batch)[source]

Prepare the targets for the loss calculation.

This function is used to convert the targets to the correct format and device, ensuring they match the model’s expected input format.

Parameters:
  • targets (Torch Tensor) – Ground truth to compare the prediction with.

  • batch (Torch Tensor) – Prediction of the model. Only used in SSL workflow.

Returns:

targets – Resulting targets.

Return type:

Torch Tensor

after_merge_patches(pred)[source]

Execute steps needed after merging all predicted patches into the original image.

Parameters:

pred (Torch Tensor) – Model prediction.

after_full_image(pred: ndarray[tuple[int, ...], dtype[_ScalarType_co]])[source]

Execute steps needed after generating the prediction by supplying the entire image to the model.

Parameters:

pred (NDArray) – Model prediction.

after_all_images()[source]

Execute steps needed after predicting all images.

after_one_chunk_raw_prediction(chunk_id: int, chunk: ndarray[tuple[int, ...], dtype[_ScalarType_co]], chunk_in_data: PatchCoords, added_pad: List[List[int]])[source]

Place any code that needs to run after predicting one chunk of data in the “by chunks” setting.

Parameters:
  • chunk_id (int) – Chunk identifier.

  • chunk (NDArray) – Predicted chunk.

  • chunk_in_data (PatchCoords) – Global coordinates of the chunk.

  • added_pad (List of list of ints) – Padding added to the chunk in each dimension. The order of dimensions is the same as the input image, and the order of the list is: [[pad_before_dim1, pad_after_dim1], [pad_before_dim2, pad_after_dim2], …].
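
To make the added_pad layout concrete, the sketch below turns it into slices that select the unpadded region of a chunk. The helper `unpad_slices` is hypothetical and not part of BiaPy; it only assumes the [[pad_before, pad_after], …] ordering described above, with one pair per chunk dimension.

```python
from typing import List, Sequence, Tuple


def unpad_slices(shape: Sequence[int], added_pad: List[List[int]]) -> Tuple[slice, ...]:
    """Hypothetical helper: build one slice per dimension that drops the padding."""
    if len(shape) != len(added_pad):
        raise ValueError("added_pad needs one [before, after] pair per dimension")
    return tuple(
        slice(before, size - after)
        for size, (before, after) in zip(shape, added_pad)
    )


# A (z, y, x, channels) chunk padded on z and y only.
slices = unpad_slices((10, 64, 64, 1), [[2, 2], [8, 0], [0, 0], [0, 0]])
print(slices)
```

With a NumPy array, indexing as `chunk[unpad_slices(chunk.shape, added_pad)]` would then return the unpadded region. Computing the stop index from the shape, rather than using a negative index, keeps the slice correct when the padding after a dimension is zero.
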

after_one_chunk_workflow_process(chunks: List[ndarray[tuple[int, ...], dtype[_ScalarType_co]]], patch_in_data: List) → List[ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None[source]

Process a list of chunks during inference in the “by chunks” setting. Each workflow should have its own implementation of this method.

Parameters:

chunks (List[NDArray]) – List of chunks. Expected axes are: (z, y, x, channels) for 3D and (y, x, channels) for 2D.

Returns:

chunks – Processed chunks.

Return type:

Optional[List[NDArray]]

after_all_chunk_prediction_workflow_process_master_rank()[source]

Place any code that needs to run after predicting all the patches in the “by chunks” setting. This function is only called on the master rank.