# Source code for biapy.engine.classification

"""
Classification workflow for BiaPy.

This module defines the Classification_Workflow class, which implements the
training, validation, and inference pipeline for image classification tasks in BiaPy.
It handles data loading, model setup, metrics, predictions, and result saving for
single-label classification problems.
"""
import os
import torch
import math
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from torchmetrics import Accuracy
from typing import Dict, Optional
from numpy.typing import NDArray

from biapy.engine.base_workflow import Base_Workflow
from biapy.data.pre_processing import preprocess_data
from biapy.data.data_manipulation import load_and_prepare_train_data_cls, load_and_prepare_cls_test_data
from biapy.utils.misc import is_main_process, MetricLogger
from biapy.engine.metrics import loss_encapsulation


class Classification_Workflow(Base_Workflow):
    """
    Classification workflow whose goal is to assign a label to the input image.

    More details in `our documentation <https://biapy.readthedocs.io/en/latest/workflows/classification.html>`_.

    Parameters
    ----------
    cfg : YACS configuration
        Running configuration.

    job_identifier : str
        Complete name of the running job.

    device : torch.device
        Device used.

    system_dict : dict
        Forwarded unchanged to ``Base_Workflow``; this class does not read it.
        (Presumably a dict, judging by its name - confirm in the base class.)

    args : argparse.Namespace
        Arguments used in BiaPy's call.
    """

    def __init__(self, cfg, job_identifier, device, system_dict, args, **kwargs):
        """
        Initialize the Classification_Workflow.

        Calls the base workflow constructor, creates the containers that collect
        predictions (and ground truth) during inference, freezes the configuration
        and sets the workflow specific training variables.

        Parameters
        ----------
        cfg : YACS configuration
            Running configuration.

        job_identifier : str
            Complete name of the running job.

        device : torch.device
            Device used.

        system_dict : dict
            Forwarded unchanged to ``Base_Workflow``.

        args : argparse.Namespace
            Arguments used in BiaPy's call.

        **kwargs : dict
            Additional keyword arguments forwarded to ``Base_Workflow``.
        """
        super(Classification_Workflow, self).__init__(cfg, job_identifier, device, system_dict, args, **kwargs)

        self.all_pred = []
        if self.cfg.DATA.TEST.LOAD_GT:
            self.all_gt = []
        elif not hasattr(self, "all_gt"):
            # The rest of the class checks ``self.all_gt is not None``, so the attribute
            # must always exist. Any value already set by the base class is kept.
            self.all_gt = None
        self.test_filenames = None
        self.class_names = None

        # From now on, no modification of the cfg will be allowed
        self.cfg.freeze()

        # Workflow specific training variables
        self.mask_path = cfg.DATA.TRAIN.GT_PATH
        self.is_y_mask = False
        self.load_Y_val = True

    def define_activations_and_channels(self):
        """
        Define the activations to be applied to the model output and the channels that the model will output.

        This function must define the following variables:

        self.model_output_channels : List of int
            Number of channels for each output head of the model. E.g. ``[3]`` for a model with one
            head outputting 3 channels, ``[1, 5]`` for a model with two heads outputting 1 and 5
            channels respectively, etc.

        self.model_output_channel_info : List of str
            Information about the output channels. A value per output head of the model must be defined.

        self.separated_class_channel : bool
            Whether a separated output channel for classification is expected.

        self.head_activations : List of str
            Activations to be applied to the model output. A value per output channel (not output
            head) of the model must be defined. "linear" and "ce_sigmoid" will not be applied.
            E.g. ``["linear"]`` for a model with one channel, ``["linear", "sigmoid"]`` for a model
            with two channels, etc.

        Example of a correct definition of the function for a model with two output heads:
        1) the first one will be predicting foreground and contours; 2) the second one will
        classify into 3 classes the predicted objects. In this case the following definition
        would be correct::

            self.model_output_channels = [1, 3]
            self.model_output_channel_info = ["mask", "class"]
            self.separated_class_channel = True
            self.head_activations = ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"]

        In this workflow there is a single head with one ``"ce_softmax"`` channel per class.
        """
        n_classes = self.cfg.DATA.N_CLASSES
        self.model_output_channels = [n_classes]
        self.gt_channels_expected = n_classes
        self.separated_class_channel = False
        self.head_activations = ["ce_softmax"] * self.model_output_channels[0]
        self.model_output_channel_info = [
            "pred{}".format(i) for i in range(len(self.model_output_channels))
        ]

        super().define_activations_and_channels()

    def define_metrics(self):
        """
        Define the metrics to be used during training and test/inference.

        This function must define the following variables:

        self.train_metrics : List of functions
            Metrics to be calculated during model's training.

        self.train_metric_names : List of str
            Names of the metrics calculated during training.

        self.train_metric_best : List of str
            To know which value should be considered as the best one. Options must be: "max" or "min".

        self.test_metrics : List of functions
            Metrics to be calculated during model's test/inference.

        self.test_metric_names : List of str
            Names of the metrics calculated during test/inference.

        self.loss : Function
            Loss function used during training and test. Only set here when ``LOSS.TYPE`` is "CE".
        """
        n_classes = self.cfg.DATA.N_CLASSES

        self.train_metrics = []
        self.train_metric_names = []
        self.train_metric_best = []
        # ``dict.fromkeys`` removes duplicates like ``set`` did, but keeps the configured order so
        # the metric lists are identical on every run/process (set order depends on string hashing).
        for metric in dict.fromkeys(self.cfg.TRAIN.METRICS):
            if metric == "accuracy":
                self.train_metrics.append(
                    Accuracy(task="multiclass", num_classes=n_classes).to(self.device),
                )
                self.train_metric_names.append("Accuracy")
                self.train_metric_best.append("max")
            elif metric == "top-5-accuracy" and n_classes > 5:
                # Top-5 accuracy is only registered when there are more than 5 classes.
                self.train_metrics.append(
                    Accuracy(task="multiclass", num_classes=n_classes, top_k=5).to(self.device),
                )
                self.train_metric_names.append("Top 5 accuracy")
                self.train_metric_best.append("max")

        self.test_metrics = []
        self.test_metric_names = []
        for metric in dict.fromkeys(self.cfg.TEST.METRICS):
            if metric == "accuracy":
                self.test_metrics.append(accuracy_score)
                self.test_metric_names.append("Accuracy")
                # The confusion matrix is always reported together with the accuracy.
                self.test_metrics.append(confusion_matrix)
                self.test_metric_names.append("Confusion matrix")

        if self.cfg.LOSS.TYPE == "CE":
            self.loss = loss_encapsulation(torch.nn.CrossEntropyLoss())

        super().define_metrics()

    def metric_calculation(
        self,
        output: NDArray | torch.Tensor,
        targets: NDArray | torch.Tensor,
        train: bool = True,
        metric_logger: Optional[MetricLogger] = None,
    ) -> Dict:
        """
        Execute the calculation of metrics defined in :func:`~define_metrics` function.

        Parameters
        ----------
        output : Torch Tensor/List of ints
            Prediction of the model. If a dict is given its ``"pred"`` entry is used.

        targets : Torch Tensor/List of ints
            Ground truth to compare the prediction with.

        train : bool, optional
            Whether to calculate train or test metrics.

        metric_logger : MetricLogger, optional
            Class to be updated with the new metric(s) value(s) calculated.

        Returns
        -------
        out_metrics : dict
            Value of the metrics for the given prediction, keyed by metric name.
        """
        if isinstance(output, dict):
            output = output["pred"]

        if train:
            metrics, names = self.train_metrics, self.train_metric_names
        else:
            metrics, names = self.test_metrics, self.test_metric_names

        out_metrics = {}
        with torch.no_grad():
            for name, metric in zip(names, metrics):
                if train:
                    # torchmetrics signature: (preds, target)
                    val = metric(output, targets)
                else:
                    # scikit-learn signature: (y_true, y_pred). Passing the prediction first
                    # would transpose the confusion matrix.
                    val = metric(targets, output)

                if torch.is_tensor(val):
                    # NaN values are reported as 0
                    val = val.item() if not torch.isnan(val) else 0

                out_metrics[name] = val
                if metric_logger:
                    metric_logger.meters[name].update(val)
        return out_metrics

    def prepare_targets(self, targets, batch):
        """
        Perform any necessary data transformations to ``targets`` before calculating the loss.

        Parameters
        ----------
        targets : Torch Tensor
            Ground truth to compare the prediction with.

        batch : Torch Tensor
            Prediction of the model. Not used here.

        Returns
        -------
        targets : Torch tensor
            Targets moved to the workflow device.
        """
        return targets.to(self.device, non_blocking=True)

    def load_train_data(self):
        """
        Load training and validation data.

        Sets ``self.X_train``, ``self.X_val`` and ``self.cross_val_samples_ids`` with the values
        returned by ``load_and_prepare_train_data_cls``; ``self.Y_train`` and ``self.Y_val`` are
        set to ``None``.
        """
        data_cfg = self.cfg.DATA
        train_cfg = data_cfg.TRAIN
        val_cfg = data_cfg.VAL
        prep_cfg = data_cfg.PREPROCESS
        train_filter = train_cfg.FILTER_SAMPLES
        val_filter = val_cfg.FILTER_SAMPLES

        (
            self.X_train,
            self.X_val,
            self.cross_val_samples_ids,
        ) = load_and_prepare_train_data_cls(
            train_path=train_cfg.PATH,
            train_in_memory=train_cfg.IN_MEMORY,
            val_path=val_cfg.PATH,
            val_in_memory=val_cfg.IN_MEMORY,
            expected_classes=data_cfg.N_CLASSES,
            cross_val=val_cfg.CROSS_VAL,
            cross_val_nsplits=val_cfg.CROSS_VAL_NFOLD,
            cross_val_fold=val_cfg.CROSS_VAL_FOLD,
            val_split=val_cfg.SPLIT_TRAIN if val_cfg.FROM_TRAIN else 0.0,
            seed=self.cfg.SYSTEM.SEED,
            shuffle_val=val_cfg.RANDOM,
            train_preprocess_f=preprocess_data if prep_cfg.TRAIN else None,
            train_preprocess_cfg=prep_cfg if prep_cfg.TRAIN else None,
            train_filter_props=train_filter.PROPS if train_filter.ENABLE else [],
            train_filter_vals=train_filter.VALUES if train_filter.ENABLE else [],
            train_filter_signs=train_filter.SIGNS if train_filter.ENABLE else [],
            val_preprocess_f=preprocess_data if prep_cfg.VAL else None,
            val_preprocess_cfg=prep_cfg if prep_cfg.VAL else None,
            val_filter_props=val_filter.PROPS if val_filter.ENABLE else [],
            val_filter_vals=val_filter.VALUES if val_filter.ENABLE else [],
            val_filter_signs=val_filter.SIGNS if val_filter.ENABLE else [],
            crop_shape=data_cfg.PATCH_SIZE,
            reflect_to_complete_shape=data_cfg.REFLECT_TO_COMPLETE_SHAPE,
            convert_to_rgb=data_cfg.FORCE_RGB,
            is_3d=(self.cfg.PROBLEM.NDIM == "3D"),
            norm_before_filter=train_filter.NORM_BEFORE,
            norm_module=self.norm_module,
        )
        self.Y_train, self.Y_val = None, None

    def load_test_data(self):
        """
        Load test data.

        Does nothing unless ``TEST.ENABLE`` is set. Otherwise sets ``self.X_test`` and
        ``self.test_filenames`` with the values returned by ``load_and_prepare_cls_test_data``
        and ``self.Y_test`` to ``None``.
        """
        if not self.cfg.TEST.ENABLE:
            return

        print("######################")
        print("# LOAD TEST DATA #")
        print("######################")

        use_val_as_test_info = None
        if self.cfg.DATA.TEST.USE_VAL_AS_TEST:
            # Information needed to rebuild the validation split and use it as test data
            use_val_as_test_info = {
                "cross_val_samples_ids": self.cross_val_samples_ids,
                "train_path": self.cfg.DATA.TRAIN.PATH,
                "selected_fold": self.cfg.DATA.VAL.CROSS_VAL_FOLD,
                "n_splits": self.cfg.DATA.VAL.CROSS_VAL_NFOLD,
                "shuffle": self.cfg.DATA.VAL.RANDOM,
                "seed": self.cfg.SYSTEM.SEED,
            }

        self.Y_test = None
        (
            self.X_test,
            self.test_filenames,
        ) = load_and_prepare_cls_test_data(
            test_path=self.cfg.DATA.TEST.PATH,
            norm_module=self.norm_module,
            use_val_as_test=self.cfg.DATA.TEST.USE_VAL_AS_TEST,
            expected_classes=self.cfg.DATA.N_CLASSES if self.use_gt else 1,
            crop_shape=self.cfg.DATA.PATCH_SIZE,
            is_3d=(self.cfg.PROBLEM.NDIM == "3D"),
            reflect_to_complete_shape=self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE,
            convert_to_rgb=self.cfg.DATA.FORCE_RGB,
            use_val_as_test_info=use_val_as_test_info,
        )

    def process_test_sample(self):
        """
        Process a sample in the inference phase.

        The images in ``self.current_sample["X"]`` are passed through the model in batches of
        ``TRAIN.BATCH_SIZE``; the predicted class of each image (argmax over axis 1) is appended
        to ``self.all_pred``. The sample's ground truth, if any, is appended to ``self.all_gt``.

        Returns
        -------
        bool or None
            ``True`` when the sample is flagged as "discard" and is therefore skipped,
            ``None`` otherwise.
        """
        assert isinstance(self.all_pred, list)
        # ``all_gt`` is ``None`` when no ground truth is collected
        assert self.all_gt is None or isinstance(self.all_gt, list)

        # Skip processing image
        if "discard" in self.current_sample and self.current_sample["discard"]:
            return True

        # Predict each patch
        images = self.current_sample["X"]
        n_images = images.shape[0]
        batch_size = self.cfg.TRAIN.BATCH_SIZE
        for start in tqdm(range(0, n_images, batch_size), leave=False, disable=not is_main_process()):
            stop = min(start + batch_size, n_images)
            p = self.model_call_func(images[start:stop])
            if isinstance(p, dict):
                p = p["pred"]
            p = p.cpu().numpy()
            self.all_pred.append(np.argmax(p, axis=1))

        if self.current_sample["Y"] is not None and self.all_gt is not None:
            self.all_gt.append(self.current_sample["Y"])

    def torchvision_model_call(self, in_img: torch.Tensor, is_train: bool = False) -> torch.Tensor | None:
        """
        Call a regular Pytorch model.

        Parameters
        ----------
        in_img : torch.Tensor
            Input image to pass through the model.

        is_train : bool, optional
            Whether the call is during training or inference. Not used here.

        Returns
        -------
        prediction : torch.Tensor
            Image prediction.
        """
        # Convert first to 0-255 range if uint16
        if in_img.dtype == torch.float32:
            if torch.max(in_img) > 255:
                in_img = (self.torchvision_norm.apply_image_norm(in_img)[0] * 255).to(torch.uint8)  # type: ignore
            in_img = in_img.to(torch.uint8)

        # Apply TorchVision pre-processing
        assert self.torchvision_preprocessing and self.model
        in_img = self.torchvision_preprocessing(in_img)
        return self.model(in_img)

    @staticmethod
    def _flatten_labels(chunks) -> NDArray:
        """Concatenate per-sample/per-batch label chunks into a single 1-D array."""
        if len(chunks) == 0:
            return np.array([])
        return np.concatenate([np.asarray(chunk).reshape(-1) for chunk in chunks])

    def after_all_images(self):
        """
        Execute steps that are needed after predicting all images.

        Flattens the collected predictions (and ground truth) into 1-D arrays, writes
        ``predictions.csv`` (columns ``filename`` and ``class``) into the result directory and,
        when ground truth is available, stores the test metrics in ``self.stats["full_image"]``
        and ``self.current_sample_metrics``.
        """
        has_gt = self.cfg.DATA.TEST.LOAD_GT and self.all_gt is not None

        # Always 1-D: ``np.array(...).squeeze()`` would yield a 0-d array for a single image
        # and cannot stack batches of different lengths.
        self.all_pred = self._flatten_labels(self.all_pred)
        if has_gt:
            self.all_gt = self._flatten_labels(self.all_gt)

        # Save predictions in a csv file
        assert self.test_filenames is not None, "Test filenames must be defined before saving predictions."
        result_dir = self.cfg.PATHS.RESULT_DIR.PATH
        df = pd.DataFrame([x.path for x in self.test_filenames], columns=["filename"])
        df["class"] = self.all_pred
        os.makedirs(result_dir, exist_ok=True)
        df.to_csv(os.path.join(result_dir, "predictions.csv"), index=False, header=True)

        # Calculate the metrics
        if has_gt:
            metric_values = self.metric_calculation(
                self.all_pred,
                self.all_gt,  # type: ignore
                train=False,
            )
            for metric, value in metric_values.items():
                key = str(metric).lower()
                self.stats["full_image"][key] = value
                self.current_sample_metrics[key] = value

    def print_stats(self, image_counter):
        """
        Print statistics.

        Parameters
        ----------
        image_counter : int
            Number of images to call ``normalize_stats``. Not used here.
        """
        full_image_stats = self.stats["full_image"]
        if len(full_image_stats) == 0:
            return

        for metric in self.test_metric_names:
            key = metric.lower()
            if key not in full_image_stats:
                continue

            if metric != "Confusion matrix":
                print("Test {}: {}".format(metric, full_image_stats[key]))
                continue

            print("Confusion matrix: ")
            print(full_image_stats[key])

            n_classes = self.cfg.DATA.N_CLASSES
            if self.class_names:
                display_labels = [
                    "Category {} ({})".format(i, self.class_names[i]) for i in range(n_classes)
                ]
            else:
                display_labels = ["Category {}".format(i) for i in range(n_classes)]

            # ``labels`` is given explicitly so ``target_names`` always matches, even when a
            # class does not appear in the test data.
            report = classification_report(
                self.all_gt,  # type: ignore
                self.all_pred,
                labels=list(range(n_classes)),
                target_names=display_labels,
            )
            print("\n" + report)

    def after_merge_patches(self, pred):
        """
        Execute steps that are needed after merging all predicted patches into the original image.

        Nothing to do in this workflow.

        Parameters
        ----------
        pred : Torch Tensor
            Model prediction.
        """
        pass

    def after_full_image(self, pred: NDArray):
        """
        Execute steps that are needed after generating the prediction by supplying the entire image to the model.

        Nothing to do in this workflow.

        Parameters
        ----------
        pred : NDArray
            Model prediction.
        """
        pass

    def after_all_chunk_prediction_workflow_process(self):
        """
        Place any code that needs to be done after predicting all patches in "by chunks" setting.

        This function is called on all ranks. Nothing to do in this workflow.
        """
        pass

    def after_all_chunk_prediction_workflow_process_master_rank(self):
        """
        Place any code that needs to be done after predicting all patches in "by chunks" setting, but only on the master rank.

        This function is called only on the master rank. Nothing to do in this workflow.
        """
        pass