biapy.engine.base_workflow

Base workflow class for BiaPy.

This module defines the Base_Workflow abstract class, which provides the main structure and utility methods for building training and inference workflows in BiaPy. It handles configuration, model preparation, data loading, training, testing, metrics, logging, and post-processing for both 2D and 3D biomedical image analysis.

class biapy.engine.base_workflow.Base_Workflow(cfg: CfgNode, job_identifier: str, device: device, system_dict: Dict[str, int], args: Namespace)[source]

Bases: object

Base workflow class. A new workflow should extend this class.

Parameters:
  • cfg (YACS configuration) – Running configuration.

  • job_identifier (str) – Complete name of the running job.

  • device (Torch device) – Device used.

  • args (argparse class) – Arguments used in BiaPy’s call.
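
As a rough illustration of "a new workflow should extend this class", the sketch below uses a hypothetical stand-in base (`WorkflowSketch`, not part of BiaPy) that declares only the four methods this page marks as abstract. The bodies are placeholders, and whether BiaPy itself blocks instantiation of incomplete subclasses is not stated here; the stand-in does so only because it derives from `abc.ABC`.

```python
from abc import ABC, abstractmethod


class WorkflowSketch(ABC):
    """Hypothetical stand-in that mirrors the methods marked abstract on this page."""

    @abstractmethod
    def metric_calculation(self, output, targets, train=True, metric_logger=None):
        ...

    @abstractmethod
    def torchvision_model_call(self, in_img, is_train=False):
        ...

    @abstractmethod
    def after_merge_patches(self, pred):
        ...

    @abstractmethod
    def after_full_image(self, pred):
        ...


class DemoWorkflow(WorkflowSketch):
    """Placeholder subclass: every abstract method gets a trivial body."""

    def metric_calculation(self, output, targets, train=True, metric_logger=None):
        return {}  # a real workflow would return computed metric values

    def torchvision_model_call(self, in_img, is_train=False):
        return in_img  # a real workflow would run the model here

    def after_merge_patches(self, pred):
        pass

    def after_full_image(self, pred):
        pass


# Instantiable because all abstract methods are implemented.
demo_workflow = DemoWorkflow()
```

With the real class, the subclass would derive from Base_Workflow and receive cfg, job_identifier, device, system_dict and args in its constructor.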

define_activations_and_channels()[source]

Define the activations to be applied to the model output and the channels that the model will output.

This function must define the following variables:

self.model_output_channels (List of int)

Number of channels for each output head of the model. E.g. [3] for a model with one head outputting 3 channels, [1, 5] for a model with two heads outputting 1 and 5 channels respectively, etc.

self.model_output_channel_info (List of str)

Information about the output channels. A value per output head of the model must be defined.

self.separated_class_channel (bool)

Whether to expect a separate output channel for classification.

self.head_activations (List of str)

Activations to be applied to the model output. A value per output channel (not output head) of the model must be defined. “linear” and “ce_sigmoid” will not be applied. E.g. [“linear”] for a model with one channel, [“linear”, “sigmoid”] for a model with two channels, etc.

Example of a correct definition of the function for a model with two output heads: 1) the first one predicts foreground and contours; 2) the second one classifies the predicted objects into 3 classes. In this case the following definition would be correct:

self.model_output_channels = [2, 3]
self.model_output_channel_info = ["mask", "class"]
self.separated_class_channel = True
self.head_activations = ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"]
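
The rules above (one info entry per head, one activation per channel) imply that the number of activations equals the sum of the per-head channel counts. A minimal sketch of that check follows; `check_head_definition` is a hypothetical helper for illustration, not a BiaPy function.

```python
from typing import List


def check_head_definition(
    model_output_channels: List[int],
    model_output_channel_info: List[str],
    head_activations: List[str],
) -> None:
    """Raise ValueError if the three lists disagree with each other."""
    # One description is expected per output head.
    if len(model_output_channel_info) != len(model_output_channels):
        raise ValueError("model_output_channel_info needs one entry per head")
    # One activation is expected per output channel, summed over all heads.
    total_channels = sum(model_output_channels)
    if len(head_activations) != total_channels:
        raise ValueError(
            f"expected {total_channels} activations, got {len(head_activations)}"
        )


# Two heads: 2 channels (foreground, contours) and 3 class channels.
check_head_definition(
    [2, 3],
    ["mask", "class"],
    ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"],
)
```

A check like this could be run at the end of define_activations_and_channels() to catch a channel count that does not match the activation list.
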
define_metrics()[source]

Define the metrics to be calculated during training and test.

This function must define the following variables:

self.train_metrics (List of functions)

Metrics to be calculated during model’s training.

self.train_metric_names (List of str)

Names of the metrics calculated during training.

self.train_metric_best (List of str)

Which direction counts as the best value for each training metric. Options: “max” or “min”.

self.test_metrics (List of functions)

Metrics to be calculated during model’s test/inference.

self.test_metric_names (List of str)

Names of the metrics calculated during test/inference.

self.loss (Function)

Loss function used during training and test.
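
A minimal sketch of how the variables listed above might be populated. `ToyWorkflow` and `mean_absolute_error` are hypothetical names; a real workflow would extend Base_Workflow and use metrics that operate on tensors rather than plain sequences.

```python
from typing import Callable, List, Sequence


def mean_absolute_error(pred: Sequence[float], target: Sequence[float]) -> float:
    """Toy metric: average absolute difference between prediction and target."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


class ToyWorkflow:
    """Hypothetical stand-in; a real workflow would extend Base_Workflow."""

    def define_metrics(self) -> None:
        self.train_metrics: List[Callable] = [mean_absolute_error]
        self.train_metric_names: List[str] = ["MAE"]
        # Lower MAE is better, so the best value is the minimum.
        self.train_metric_best: List[str] = ["min"]
        self.test_metrics: List[Callable] = [mean_absolute_error]
        self.test_metric_names: List[str] = ["MAE"]
        # The toy metric doubles as the loss purely for illustration.
        self.loss: Callable = mean_absolute_error


toy_workflow = ToyWorkflow()
toy_workflow.define_metrics()
```

The three training lists are kept the same length, so that each metric has a name and a "max"/"min" direction at the same index.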

abstract metric_calculation(output: ndarray[tuple[int, ...], dtype[_ScalarType_co]] | Tensor, targets: ndarray[tuple[int, ...], dtype[_ScalarType_co]] | Tensor, train: bool = True, metric_logger: MetricLogger | None = None) → Dict[source]

Calculate the metrics defined in define_metrics().

Parameters:
  • output (Torch Tensor) – Prediction of the model.

  • targets (Torch Tensor) – Ground truth to compare the prediction with.

  • train (bool, optional) – Whether to calculate train or test metrics.

  • metric_logger (MetricLogger, optional) – Class to be updated with the new metric(s) value(s) calculated.

Returns:

value – Value of the metric for the given prediction.

Return type:

float
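
One way an implementation could be organized is to pair each metric function with its name and collect the results. The sketch below assumes a name-to-value dictionary as the result, in line with the `Dict` annotation in the signature; `calculate_metrics` and `max_abs_error` are hypothetical helpers operating on plain sequences, not BiaPy code.

```python
from typing import Callable, Dict, List, Sequence

Metric = Callable[[Sequence[float], Sequence[float]], float]


def calculate_metrics(
    output: Sequence[float],
    targets: Sequence[float],
    metrics: List[Metric],
    metric_names: List[str],
) -> Dict[str, float]:
    """Apply each metric to (output, targets) and key the result by its name."""
    return {
        name: float(metric(output, targets))
        for name, metric in zip(metric_names, metrics)
    }


def max_abs_error(pred: Sequence[float], target: Sequence[float]) -> float:
    """Toy metric: largest absolute difference."""
    return max(abs(p - t) for p, t in zip(pred, target))


metric_values = calculate_metrics(
    [0.0, 1.0], [0.5, 3.0], [max_abs_error], ["max_err"]
)
```

In a workflow, the `train` flag would select between the training and test metric lists before this loop runs.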

prepare_targets(targets, batch)[source]

Location to perform any necessary data transformations to targets before calculating the loss.

Parameters:
  • targets (Torch Tensor) – Ground truth to compare the prediction with.

  • batch (Torch Tensor) – Prediction of the model. Only used in SSL workflow.

Returns:

targets – Resulting targets.

Return type:

Torch tensor

load_train_data()[source]

Load training and validation data.

destroy_train_data()[source]

Delete training variables to release memory.

prepare_train_generators()[source]

Build training and validation generators.

bmz_model_call(in_img, is_train=False)[source]

Call BioImage Model Zoo model.

Parameters:
  • in_img (torch.Tensor) – Input image to pass through the model.

  • is_train (bool, optional) – Whether the call is made during training or inference.

Returns:

prediction – Image prediction.

Return type:

torch.Tensor

abstract torchvision_model_call(in_img: Tensor, is_train=False) → Tensor[source]

Call a regular PyTorch model.

Parameters:
  • in_img (torch.Tensor) – Input image to pass through the model.

  • is_train (bool, optional) – Whether the call is made during training or inference.

Returns:

prediction – Image prediction.

Return type:

torch.Tensor

model_call_func(in_img: ndarray[tuple[int, ...], dtype[_ScalarType_co]] | Tensor, is_train: bool = False, apply_act: bool = True) → Any[source]

Call a regular PyTorch model.

Parameters:
  • in_img (torch.Tensor) – Input image to pass through the model.

  • is_train (bool, optional) – Whether the call is made during training or inference.

  • apply_act (bool, optional) – Whether to apply activations or not.

Returns:

prediction – Image prediction.

Return type:

torch.Tensor

prepare_model()[source]

Build the model.

prepare_logging_tool()[source]

Prepare the logging tool.

train()[source]

Training phase.

load_test_data()[source]

Load test data.

destroy_test_data()[source]

Delete test variables to release memory.

prepare_test_generators()[source]

Prepare test data generator.

apply_model_activations(pred: Tensor | Dict, training=False) → Tensor | Dict[source]

Apply the last activation (if any) to the model’s output.

Parameters:
  • pred (Torch Tensor) – Predictions of the model.

  • training (bool, optional) – Whether the function is being applied during training or inference. During training, ce_sigmoid activations will NOT be applied, because torch.nn.BCEWithLogitsLoss applies the sigmoid itself in a more numerically stable way.

Returns:

pred – Resulting predictions after applying last activation(s).

Return type:

Torch tensor
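
The training/inference distinction can be sketched with scalar logits, one per channel. This is an illustration under assumptions, not BiaPy's implementation: `apply_activations` is a hypothetical helper, it handles only three activation names, and it assumes that at inference time "ce_sigmoid" is treated as a plain sigmoid.

```python
import math
from typing import List


def apply_activations(
    channels: List[float], head_activations: List[str], training: bool = False
) -> List[float]:
    """Apply one activation per channel to a list of raw logits."""
    out = []
    for value, act in zip(channels, head_activations):
        if act == "linear":
            out.append(value)  # identity: nothing to apply
        elif act == "ce_sigmoid" and training:
            out.append(value)  # kept as a logit; the loss applies the sigmoid
        elif act in ("sigmoid", "ce_sigmoid"):
            out.append(1.0 / (1.0 + math.exp(-value)))
        else:
            raise ValueError(f"unsupported activation in this sketch: {act}")
    return out


train_out = apply_activations([0.0, 0.0], ["linear", "ce_sigmoid"], training=True)
test_out = apply_activations([0.0, 0.0], ["linear", "ce_sigmoid"], training=False)
```

Here the second channel stays a raw logit during training and becomes a probability only at inference.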

test()[source]

Test/Inference step.

predict_batches_in_test(x_batch: ndarray[tuple[int, ...], dtype[_ScalarType_co]], y_batch: ndarray[tuple[int, ...], dtype[_ScalarType_co]] | None, stats_name='per_crop', disable_tqdm: bool = False) → ndarray[tuple[int, ...], dtype[_ScalarType_co]][source]

Predict data for the test phase.

Parameters:
  • x_batch (NDArray) – X data. Expected axes are: (num_patches, z, y, x, channels) for 3D and (num_patches, y, x, channels) for 2D.

  • y_batch (NDArray) – Y data. Expected axes are: (num_patches, z, y, x, channels) for 3D and (num_patches, y, x, channels) for 2D.

  • stats_name (str, optional) – Name of the statistics to save.

  • disable_tqdm (bool, optional) – Whether to disable tqdm or not.

Returns:

pred – Predicted batch.

Return type:

NDArray

prepare_bmz_data(img)[source]

Prepare required data for exporting a model into BMZ.

Parameters:

img (4D/5D Numpy array) – Image to save (unnormalized). The axes must be in Torch format already, i.e. (b,c,y,x) for 2D or (b,c,z,y,x) for 3D.
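
Batches elsewhere on this page are described as channels-last, e.g. (num_patches, y, x, channels), while this method expects Torch order. A small sketch of the axis permutation between the two layouts follows; `channels_last_to_first` and `permute_shape` are hypothetical helpers, not BiaPy functions.

```python
from typing import Tuple


def channels_last_to_first(ndim: int) -> Tuple[int, ...]:
    """Return the permutation moving a channels-last batch to channels-first.

    4D: (b, y, x, c)    -> (b, c, y, x)
    5D: (b, z, y, x, c) -> (b, c, z, y, x)
    """
    if ndim not in (4, 5):
        raise ValueError("expected a 4D (2D images) or 5D (3D images) batch")
    # Keep the batch axis first, then move the last (channel) axis to position 1.
    return (0, ndim - 1) + tuple(range(1, ndim - 1))


def permute_shape(shape: Tuple[int, ...], perm: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape that results from applying the permutation to an array's axes."""
    return tuple(shape[i] for i in perm)


perm_2d = channels_last_to_first(4)
torch_shape = permute_shape((8, 256, 256, 1), perm_2d)
```

With NumPy, the returned tuple could be passed as the axes argument of `np.transpose`.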

process_test_sample()[source]

Process a sample in the inference phase.

normalize_stats(image_counter)[source]

Normalize statistics.

Parameters:

image_counter (int) – Number of images to average the metrics.
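
Averaging by `image_counter` amounts to dividing each accumulated metric by the number of images. A minimal sketch, assuming the statistics are held in a name-to-sum dictionary; `average_stats` is a hypothetical helper, and how BiaPy stores its statistics internally is not shown on this page.

```python
from typing import Dict


def average_stats(stats: Dict[str, float], image_counter: int) -> Dict[str, float]:
    """Divide each accumulated metric by the number of processed images."""
    if image_counter <= 0:
        raise ValueError("image_counter must be positive")
    return {name: total / image_counter for name, total in stats.items()}


averaged = average_stats({"iou": 1.5, "loss": 3.0}, image_counter=3)
```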

print_stats(image_counter)[source]

Print statistics.

Parameters:

image_counter (int) – Number of images, passed on to normalize_stats.

abstract after_merge_patches(pred)[source]

Place any code that needs to be done after merging all predicted patches into the original image.

Parameters:

pred (Torch Tensor) – Model prediction.

abstract after_full_image(pred: ndarray[tuple[int, ...], dtype[_ScalarType_co]])[source]

Place here any code that must be executed after generating the prediction by supplying the entire image to the model.

To enable this, the model should be convolutional, and the image(s) should be in a 2D format. Using 3D images as direct inputs to the model is not feasible due to their large size.

Parameters:

pred (NDArray) – Model prediction.

after_all_images()[source]

Place here any code that must be done after predicting all images.

process_test_sample_by_chunks()[source]

Process a sample in the inference phase.

A final H5/Zarr file is created in “TZCYX” or “TZYXC” order depending on DATA.TEST.INPUT_IMG_AXES_ORDER (‘T’ is always included).

after_one_chunk_raw_prediction(chunk_id: int, chunk: ndarray[tuple[int, ...], dtype[_ScalarType_co]], chunk_in_data: PatchCoords, added_pad: List[List[int]])[source]

Place any code that needs to be done after predicting one chunk of data in “by chunks” setting.

Parameters:
  • chunk_id (int) – Chunk identifier.

  • chunk (NDArray) – Predicted chunk.

  • chunk_in_data (PatchCoords) – Global coordinates of the chunk.

  • added_pad (List of list of ints) – Padding added to the chunk in each dimension. The order of dimensions is the same as the input image, and the order of the list is: [[pad_before_dim1, pad_after_dim1], [pad_before_dim2, pad_after_dim2], …].
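
Given the `added_pad` layout described above, the padding can be dropped by slicing each dimension from `pad_before` to `size - pad_after`. The helper below, `unpad_slices`, is a hypothetical illustration that works on a shape tuple; it is not part of BiaPy.

```python
from typing import List, Tuple


def unpad_slices(
    shape: Tuple[int, ...], added_pad: List[List[int]]
) -> Tuple[slice, ...]:
    """Build one slice per dimension that drops the padding from a chunk."""
    if len(shape) != len(added_pad):
        raise ValueError("added_pad needs one [before, after] pair per dimension")
    return tuple(
        slice(before, dim - after)
        for dim, (before, after) in zip(shape, added_pad)
    )


# A (10, 20) chunk padded by [2, 3] in the first axis and [0, 5] in the second.
crop = unpad_slices((10, 20), [[2, 3], [0, 5]])
```

For a NumPy array, `chunk[unpad_slices(chunk.shape, added_pad)]` would return the unpadded region.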

after_one_chunk_workflow_process(chunks: List[ndarray[tuple[int, ...], dtype[_ScalarType_co]]], patch_in_data: List) → List[ndarray[tuple[int, ...], dtype[_ScalarType_co]]] | None[source]

Process a list of chunks during inference in “by chunks” setting. Each workflow should have its own implementation of this method.

Parameters:
  • chunks (List[NDArray]) – List of chunks. Expected axes are: (z, y, x, channels) for 3D and (y, x, channels) for 2D.

  • patch_in_data (List[PatchCoords]) – Spatial coordinates of each chunk in the full volume.

Returns:

chunks – Processed chunks.

Return type:

Optional[List[NDArray]]

after_all_chunk_prediction_workflow_process()[source]

Place any code that needs to be done after predicting all patches in “by chunks” setting. This function is called on all ranks.

after_all_chunk_prediction_workflow_process_master_rank()[source]

Place any code that needs to be done after predicting all the patches in the “by chunks” setting. This function is only called on the master rank.