"""
Detection workflow for BiaPy.
This module defines the Detection_Workflow class, which implements the
training, validation, and inference pipeline for object detection tasks in BiaPy.
It handles data preparation, model setup, metrics, predictions, post-processing,
and result saving for localization of objects in 2D and 3D images.
"""
import os
import torch
import torch.distributed as dist
import numpy as np
import pandas as pd
from skimage.feature import peak_local_max, blob_log
from skimage.morphology import disk, dilation
from typing import Dict, Optional, List
from numpy.typing import NDArray
from skimage.filters import threshold_otsu
from biapy.data.post_processing.post_processing import (
remove_close_points,
detection_watershed,
measure_morphological_props_and_filter,
)
from biapy.utils.misc import (
is_main_process,
is_dist_avail_and_initialized,
to_pytorch_format,
MetricLogger,
os_walk_clean
)
from biapy.engine.metrics import (
detection_metrics,
multiple_metrics,
DiceCELoss,
DiceLoss,
detection_loss,
)
from biapy.data.pre_processing import create_detection_masks
from biapy.engine.base_workflow import Base_Workflow
from biapy.data.data_3D_manipulation import order_dimensions, looks_like_hdf5
from biapy.data.data_manipulation import save_tif, decide_dtype
from biapy.data.dataset import PatchCoords
[docs]
class Detection_Workflow(Base_Workflow):
"""
Detection workflow where the goal is to localize objects in the input image, not requiring a pixel-level class.
More details in `our documentation <https://biapy.readthedocs.io/en/latest/workflows/detection.html>`_.
Parameters
----------
cfg : YACS configuration
Running configuration.
Job_identifier : str
Complete name of the running job.
device : Torch device
Device used.
args : argpase class
Arguments used in BiaPy's call.
"""
def __init__(self, cfg, job_identifier, device, system_dict, args, **kwargs):
"""
Initialize the Detection_Workflow.
Sets up configuration, device, job identifier, and initializes
workflow-specific attributes for detection tasks.
Parameters
----------
cfg : YACS configuration
Running configuration.
job_identifier : str
Complete name of the running job.
device : torch.device
Device used.
args : argparse.Namespace
Arguments used in BiaPy's call.
**kwargs : dict
Additional keyword arguments.
"""
super(Detection_Workflow, self).__init__(cfg, job_identifier, device, system_dict, args, **kwargs)
self.original_test_mask_path = self.prepare_detection_data()
if self.use_gt:
self.csv_files = next(os_walk_clean(self.original_test_mask_path))[2]
# From now on, no modification of the cfg will be allowed
self.cfg.freeze()
# Workflow specific training variables
self.mask_path = cfg.DATA.TRAIN.GT_PATH
self.is_y_mask = True
self.load_Y_val = True
# Workflow specific test variables
if self.cfg.TEST.POST_PROCESSING.DET_WATERSHED or self.cfg.TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS:
self.post_processing["detection_post"] = True
else:
self.post_processing["detection_post"] = False
[docs]
def define_activations_and_channels(self):
"""
Define the activations to be applied to the model output and the channels that the model will output.
This function must define the following variables:
self.model_output_channels : List of int
Number of channels for each output head of the model. E.g. [3] for a model with one head outputting 3 channels,
[1, 5] for a model with two heads outputting 1 and 5 channels respectively, etc.
self.model_output_channel_info : List of str
Information about the output channels. A value per output head of the model must be defined.
self.separated_class_channel : bool
Whether if we should expect a separated output channel for classification.
self.head_activations : List of str
Activations to be applied to the model output. A value per output channel (not output head) of the model must be defined.
"linear" and "ce_sigmoid" will not be applied. E.g. ["linear"] for a model with one channel, ["linear", "sigmoid"] for a
model with two channels, etc.
Example of a correct definition of the function for a model with two output heads: 1) the first one will be predicting foreground
and contours; 2) the second one will classify into 3 classes the predicted objects. In this case the following definition would
be correct::
self.model_output_channels = [1, 3]
self.model_output_channel_info = ["mask", "class"]
self.separated_class_channel = True
self.head_activations = ["ce_sigmoid", "ce_sigmoid", "ce_softmax", "ce_softmax", "ce_softmax"]
"""
# Multi-head: points + classification
if self.cfg.DATA.N_CLASSES > 2:
self.head_activations = ["ce_sigmoid"] + ["ce_softmax"] * self.cfg.DATA.N_CLASSES
self.model_output_channels = [1, self.cfg.DATA.N_CLASSES]
self.model_output_channel_info = ["points", "class"]
self.separated_class_channel = True
else:
self.head_activations = ["ce_sigmoid"]
self.model_output_channel_info = ["points"]
self.model_output_channels = [1]
self.separated_class_channel = False
self.gt_channels_expected = self.model_output_channels[0]
super().define_activations_and_channels()
[docs]
def define_metrics(self):
"""
Define the metrics to be calculated during training and test/inference phases.
This function must define the following variables:
self.train_metrics : List of functions
Metrics to be calculated during model's training.
self.train_metric_names : List of str
Names of the metrics calculated during training.
self.train_metric_best : List of str
To know which value should be considered as the best one. Options must be: "max" or "min".
self.test_metrics : List of functions
Metrics to be calculated during model's test/inference.
self.test_metric_names : List of str
Names of the metrics calculated during test/inference.
self.loss : Function
Loss function used during training and test.
"""
self.train_metrics = []
self.train_metric_names = []
self.train_metric_best = []
for metric in list(set(self.cfg.TRAIN.METRICS)):
if metric in ["iou", "jaccard_index"]:
self.train_metric_names.append("IoU")
self.train_metric_best.append("max")
# Multi-head: detection + classification
if self.separated_class_channel:
self.train_metric_names.append("IoU (classes)")
self.train_metric_best += ["max"]
self.train_metrics.append(
multiple_metrics(
num_classes=self.cfg.DATA.N_CLASSES,
metric_names=self.train_metric_names,
device=self.device,
model_source=self.cfg.MODEL.SOURCE,
ndim=self.dims,
ignore_index=self.cfg.LOSS.IGNORE_INDEX,
)
)
self.test_metrics = []
self.test_metric_names = []
for metric in list(set(self.cfg.TEST.METRICS)):
if metric in ["iou", "jaccard_index"]:
self.test_metric_names.append("IoU")
# Multi-head: detection + classification
if self.separated_class_channel:
self.test_metric_names.append("IoU (classes)")
self.test_metrics.append(
multiple_metrics(
num_classes=self.cfg.DATA.N_CLASSES,
metric_names=self.test_metric_names,
device=self.test_device,
model_source=self.cfg.MODEL.SOURCE,
ndim=self.dims,
ignore_index=self.cfg.LOSS.IGNORE_INDEX,
)
)
# Workflow specific metrics calculated in a different way than calling metric_calculation(). These metrics are
# always calculated
self.test_extra_metrics = ["Precision", "Recall", "F1", "TP", "FP", "FN"]
if self.separated_class_channel:
self.test_extra_metrics += ["Precision (class)", "Recall (class)", "F1 (class)", "TP (class)", "FN (class)"]
self.test_metric_names += self.test_extra_metrics
self.loss = detection_loss(
ndim=self.dims,
class_rebalance_within_channels=self.cfg.PROBLEM.DETECTION.CLASS_REBALANCE_WITHIN_CHANNELS,
separated_class_channel=self.separated_class_channel,
channel_weights = self.cfg.PROBLEM.DETECTION.DATA_CHANNEL_WEIGHTS,
class_rebalance=self.cfg.LOSS.CLASS_REBALANCE,
class_weights=self.cfg.LOSS.CLASS_WEIGHTS,
ignore_index=self.cfg.LOSS.IGNORE_INDEX,
device=self.device,
)
super().define_metrics()
[docs]
def metric_calculation(
self,
output: NDArray | torch.Tensor,
targets: NDArray | torch.Tensor,
train: bool = True,
metric_logger: Optional[MetricLogger] = None,
) -> Dict:
"""
Calculate the metrics defined in :func:`~define_metrics` function.
Parameters
----------
output : Torch Tensor
Prediction of the model.
targets : Torch Tensor
Ground truth to compare the prediction with.
train : bool, optional
Whether to calculate train or test metrics.
metric_logger : MetricLogger, optional
Class to be updated with the new metric(s) value(s) calculated.
Returns
-------
out_metrics : dict
Value of the metrics for the given prediction.
"""
if isinstance(output, np.ndarray):
_output = to_pytorch_format(
output.copy(),
self.axes_order,
self.device if train else self.test_device,
dtype=self.loss_dtype,
)
else: # torch.Tensor
if not train:
_output = output.clone()
else:
_output = output
if not train and self.separated_class_channel and self.gt_channels_expected != _output.shape[1]:
class_idx = self.model_output_channel_info.index("class") if "class" in self.model_output_channel_info else -1
_output = torch.cat(
(
_output[:,:-self.model_output_channels[class_idx]],
torch.argmax(_output[:, -self.model_output_channels[class_idx]:], dim=1).unsqueeze(1)
), dim=1)
if isinstance(targets, np.ndarray):
_targets = to_pytorch_format(
targets.copy(),
self.axes_order,
self.device if train else self.test_device,
dtype=self.loss_dtype,
)
else: # torch.Tensor
if not train:
_targets = targets.clone()
else:
_targets = targets
out_metrics = {}
list_to_use = self.train_metrics if train else self.test_metrics
list_names_to_use = self.train_metric_names if train else self.test_metric_names
with torch.no_grad():
k = 0
for i, metric in enumerate(list_to_use):
val = metric(_output, _targets)
if isinstance(val, dict):
for m in val:
v = val[m].item() if not torch.isnan(val[m]) else 0
out_metrics[list_names_to_use[k]] = v
if metric_logger:
metric_logger.meters[list_names_to_use[k]].update(v)
k += 1
else:
val = val.item() if not torch.isnan(val) else 0 # type: ignore
out_metrics[list_names_to_use[i]] = val
if metric_logger:
metric_logger.meters[list_names_to_use[i]].update(val)
return out_metrics
[docs]
def detection_process(
self,
pred: NDArray,
inference_type: str = "full_image",
patch_pos: Optional[PatchCoords] = None,
):
"""
Process model's prediction to prepare detection output and calculate metrics (detection workflow engine for test/inference).
Parameters
----------
pred : 4D Torch tensor
Model predictions. E.g. ``(z, y, x, channels)`` for both 2D and 3D.
inference_type : str, optional
Type of inference. Options: ["per_crop", "merge_patches", "as_3D_stack", "full_image"].
patch_pos : PatchCoords, optional
Position of the patch to analize. By setting this the function will take only into account the GT points
corresponding to the patch at hand.
"""
assert inference_type in ["per_crop", "merge_patches", "as_3D_stack", "full_image"]
assert pred.ndim == 4, f"Prediction doesn't have 4 dim: {pred.shape}"
# Multi-head: points + classification
if self.separated_class_channel:
class_channel = np.expand_dims(pred[..., -1], -1)
pred = pred[..., :-1]
pred_shape = pred.shape
if self.cfg.TEST.VERBOSE and not self.cfg.TEST.BY_CHUNKS.ENABLE:
print("Capturing the local maxima ")
# Find points
if self.cfg.TEST.DET_TH_TYPE == "auto":
threshold_abs = threshold_otsu(pred[..., 0])
else: # manual
threshold_abs = self.cfg.TEST.DET_MIN_TH_TO_BE_PEAK
if self.cfg.TEST.DET_POINT_CREATION_FUNCTION == "peak_local_max":
pred_points = peak_local_max(
pred[..., 0].astype(np.float32),
min_distance=self.cfg.TEST.DET_PEAK_LOCAL_MAX_MIN_DISTANCE,
threshold_abs=threshold_abs,
exclude_border=self.cfg.TEST.DET_EXCLUDE_BORDER,
)
else:
pred_points = blob_log(
pred[..., 0] * 255,
min_sigma=self.cfg.TEST.DET_BLOB_LOG_MIN_SIGMA,
max_sigma=self.cfg.TEST.DET_BLOB_LOG_MAX_SIGMA,
num_sigma=self.cfg.TEST.DET_BLOB_LOG_NUM_SIGMA,
threshold=threshold_abs,
exclude_border=self.cfg.TEST.DET_EXCLUDE_BORDER,
)
pred_points = pred_points[:, :3].astype(int) # Remove sigma
# Remove close points per class as post-processing method
out_dir = self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK
if self.cfg.TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS and not self.cfg.TEST.BY_CHUNKS.ENABLE:
out_dir = self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK_POST_PROCESSING
pred_points = remove_close_points(
pred_points,
self.cfg.TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS_RADIUS,
self.resolution,
ndim=self.dims,
)
assert isinstance(pred_points, list)
# Decide the class for each point
pred_points_classes = []
if self.separated_class_channel:
for point in pred_points:
if self.dims == 3:
point_area = class_channel[
max(0, point[0] - 1) : min(pred.shape[0], point[0] + 1),
max(0, point[1] - 2) : min(pred.shape[1], point[1] + 2),
max(0, point[2] - 2) : min(pred.shape[2], point[2] + 2),
]
else:
point_area = class_channel[
max(0, point[0] - 2) : min(pred.shape[0], point[0] + 2),
max(0, point[1] - 2) : min(pred.shape[1], point[1] + 2),
]
instance_classes, instance_classes_count = np.unique(point_area, return_counts=True)
# Remove background
if instance_classes[0] == 0:
instance_classes = instance_classes[1:]
instance_classes_count = instance_classes_count[1:]
if len(instance_classes) > 0:
label_selected = int(instance_classes[np.argmax(instance_classes_count)])
else: # Label by default with class 1 in case there was no class info
label_selected = 1
pred_points_classes.append(label_selected)
else:
pred_points_classes = [0] * len(pred_points)
# Create a file with detected point and other image with predictions ids (if GT given)
if not self.cfg.TEST.BY_CHUNKS.ENABLE:
file_ext = os.path.splitext(self.current_sample["X_filename"])[1]
if self.cfg.TEST.VERBOSE:
print("Creating the images with detected points . . .")
dtype = decide_dtype(len(pred_points)+1)
points_pred_mask = np.zeros(pred.shape[:-1], dtype=dtype)
if len(pred_points) > 0:
# Paint the points
for n, coord in enumerate(pred_points):
z, y, x = coord
points_pred_mask[z, y, x] = n + 1
# Dilate and save the detected point image
for i in range(points_pred_mask.shape[0]):
points_pred_mask[i] = dilation(points_pred_mask[i], disk(3))
if self.separated_class_channel:
class_channel = np.zeros(points_pred_mask.shape, dtype=dtype)
for n in range(len(pred_points)):
class_channel = np.where(points_pred_mask == n + 1, pred_points_classes[n], class_channel)
points_pred_mask = np.concatenate(
[
np.expand_dims(points_pred_mask, -1),
np.expand_dims(class_channel, -1),
],
axis=-1,
)
else:
points_pred_mask = np.expand_dims(points_pred_mask, -1)
save_tif(
np.expand_dims(points_pred_mask, 0),
out_dir,
[self.current_sample["X_filename"]],
verbose=self.cfg.TEST.VERBOSE,
)
if self.separated_class_channel:
points_pred_mask = points_pred_mask[..., 0]
else:
points_pred_mask = points_pred_mask.squeeze()
# Detection watershed
if self.cfg.TEST.POST_PROCESSING.DET_WATERSHED:
data_filename = os.path.join(self.cfg.DATA.TEST.PATH, self.current_sample["X_filename"])
w_dir = os.path.join(self.cfg.PATHS.WATERSHED_DIR, self.current_sample["X_filename"])
check_wa = w_dir if self.cfg.PROBLEM.DETECTION.DATA_CHECK_MW else None
assert isinstance(pred_points, list)
points_pred_mask = detection_watershed(
points_pred_mask,
pred_points,
data_filename,
self.cfg.TEST.POST_PROCESSING.DET_WATERSHED_FIRST_DILATION,
save_dir=check_wa,
ndim=self.dims,
donuts_classes=self.cfg.TEST.POST_PROCESSING.DET_WATERSHED_DONUTS_CLASSES,
donuts_patch=self.cfg.TEST.POST_PROCESSING.DET_WATERSHED_DONUTS_PATCH,
donuts_nucleus_diameter=self.cfg.TEST.POST_PROCESSING.DET_WATERSHED_DONUTS_NUCLEUS_DIAMETER,
)
# Instance filtering by properties
points_pred_mask, d_result = measure_morphological_props_and_filter(
points_pred_mask,
intensity_image=self.current_sample["X"][0],
resolution=self.resolution,
extra_props=self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.EXTRA_PROPS,
filter_instances=self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.ENABLE,
properties=self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.PROPS,
prop_values=self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.VALUES,
comp_signs=self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.SIGNS,
)
save_tif(
np.expand_dims(np.expand_dims(points_pred_mask, 0), -1),
self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING,
[self.current_sample["X_filename"]],
verbose=self.cfg.TEST.VERBOSE,
)
del points_pred_mask
# Save coords in a couple of csv files
df = None
if len(pred_points) > 0:
aux = np.array(pred_points)
if self.cfg.PROBLEM.NDIM == "3D":
prob = pred[aux[:, 0], aux[:, 1], aux[:, 2], 0]
if self.cfg.TEST.POST_PROCESSING.DET_WATERSHED:
df = pd.DataFrame(
zip(
d_result["labels"],
list(aux[:, 0]),
list(aux[:, 1]),
list(aux[:, 2]),
list(prob),
list(pred_points_classes),
d_result["npixels"],
d_result["areas"],
d_result["sphericities"],
d_result["diameters"],
d_result["perimeters"],
d_result["comment"],
d_result["conditions"],
),
columns=[
"pred_id",
"axis-0",
"axis-1",
"axis-2",
"probability",
"class",
"npixels",
"volume",
"sphericity",
"diameter",
"perimeter (surface area)",
"comment",
"conditions",
],
)
else:
labels = np.array(range(1, len(pred_points) + 1))
df = pd.DataFrame(
zip(
labels,
list(aux[:, 0]),
list(aux[:, 1]),
list(aux[:, 2]),
list(prob),
list(pred_points_classes),
),
columns=[
"pred_id",
"axis-0",
"axis-1",
"axis-2",
"probability",
"class",
],
)
else:
prob = pred[aux[:, 0], aux[:, 1], 0]
if self.cfg.TEST.POST_PROCESSING.DET_WATERSHED:
df = pd.DataFrame(
zip(
d_result["labels"],
list(aux[:, 0]),
list(aux[:, 1]),
list(prob),
list(pred_points_classes),
d_result["npixels"],
d_result["areas"],
d_result["circularities"],
d_result["diameters"],
d_result["perimeters"],
d_result["elongations"],
d_result["comment"],
d_result["conditions"],
),
columns=[
"pred_id",
"axis-0",
"axis-1",
"probability",
"class",
"npixels",
"area",
"circularity",
"diameter",
"perimeter",
"elongation",
"comment",
"conditions",
],
)
else:
labels = np.array(range(1, len(pred_points) + 1))
df = pd.DataFrame(
zip(
labels,
list(aux[:, 0]),
list(aux[:, 1]),
list(prob),
list(pred_points_classes),
),
columns=["pred_id", "axis-0", "axis-1", "probability", "class"],
)
df = df.sort_values(by=["pred_id"])
del aux
if not self.separated_class_channel:
df = df.drop(columns=["class"])
if not self.cfg.TEST.BY_CHUNKS.ENABLE:
# Save just the points and their probabilities
os.makedirs(out_dir, exist_ok=True)
df.to_csv(
os.path.join(
out_dir,
os.path.splitext(self.current_sample["X_filename"])[0] + "_points.csv",
),
index=False,
)
# Calculate detection metrics
if self.use_gt and not self.cfg.TEST.BY_CHUNKS.ENABLE:
all_channel_d_metrics = [0, 0, 0, 0, 0, 0]
if self.separated_class_channel:
all_channel_d_metrics += [0, 0, 0, 0, 0]
# Read the GT coordinates from the CSV file
csv_filename = os.path.join(
self.original_test_mask_path, os.path.splitext(self.current_sample["X_filename"])[0] + ".csv"
)
if not os.path.exists(csv_filename):
if self.cfg.TEST.VERBOSE:
print(
"WARNING: The CSV file seems to have different name than image. Using the CSV file "
"with the same position as the CSV in the directory. Check if it is correct!"
)
csv_filename = os.path.join(self.original_test_mask_path, self.csv_files[self.f_numbers[0]])
if self.cfg.TEST.VERBOSE:
print("Its respective CSV file seems to be: {}".format(csv_filename))
if self.cfg.TEST.VERBOSE:
print("Reading GT data from: {}".format(csv_filename))
df_gt = pd.read_csv(csv_filename)
df_gt = df_gt.rename(columns=lambda x: x.strip())
zcoords = df_gt["axis-0"].tolist()
ycoords = df_gt["axis-1"].tolist()
if self.cfg.PROBLEM.NDIM == "3D":
xcoords = df_gt["axis-2"].tolist()
gt_coordinates = [[z, y, x] for z, y, x in zip(zcoords, ycoords, xcoords)]
else:
gt_coordinates = [[0, y, x] for y, x in zip(zcoords, ycoords)]
if self.cfg.DATA.N_CLASSES > 2:
if "class" not in df_gt:
raise ValueError("DATA.N_CLASSES > 2 but no class specified in the CSV file")
gt_points_classes = None
if self.separated_class_channel:
if "class" not in df_gt:
raise ValueError("'class' column not present in the CSV file")
gt_points_classes = df_gt["class"].tolist()
# Take only into account the GT points corresponding to the patch at hand
if patch_pos:
patch_gt_coordinates = []
for j, cor in enumerate(gt_coordinates):
z, y, x = cor
z, y, x = int(z), int(y), int(x)
if (
patch_pos.z_start <= z < patch_pos.z_end
and patch_pos.y_start <= y < patch_pos.y_end
and patch_pos.x_start <= x < patch_pos.x_end
):
z = z - patch_pos.z_start
y = y - patch_pos.y_start
x = x - patch_pos.x_start
patch_gt_coordinates.append([z, y, x])
if z >= pred_shape[0] or y >= pred_shape[1] or x >= pred_shape[2]:
raise ValueError(f"Point [{z},{y},{x}] outside image with shape {pred_shape}")
gt_coordinates = patch_gt_coordinates.copy()
roi_to_consider = []
if self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX:
if self.cfg.PROBLEM.NDIM == "2D":
roi_to_consider = [
[
self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[0],
max(
pred_shape[0] - self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[0],
0,
),
],
[
self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[1],
max(
pred_shape[1] - self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[1],
0,
),
],
]
else:
roi_to_consider = [
[
self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[0],
max(
pred_shape[0] - self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[0],
0,
),
],
[
self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[1],
max(
pred_shape[1] - self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[1],
0,
),
],
[
self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[2],
max(
pred_shape[2] - self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[2],
0,
),
],
]
# Calculate detection metrics
fp, gt_assoc = None, None
if len(pred_points) > 0:
d_metrics, gt_assoc, fp = detection_metrics(
gt_coordinates,
pred_points,
true_classes=gt_points_classes,
pred_classes=pred_points_classes,
tolerance=self.cfg.TEST.DET_TOLERANCE,
resolution=self.resolution,
bbox_to_consider=roi_to_consider,
verbose=self.cfg.TEST.VERBOSE,
)
if self.cfg.TEST.VERBOSE:
print("Detection metrics: {}".format(d_metrics))
all_channel_d_metrics[0] += d_metrics["Precision"]
all_channel_d_metrics[1] += d_metrics["Recall"]
all_channel_d_metrics[2] += d_metrics["F1"]
all_channel_d_metrics[3] += d_metrics["TP"]
all_channel_d_metrics[4] += d_metrics["FP"]
all_channel_d_metrics[5] += d_metrics["FN"]
if self.separated_class_channel:
all_channel_d_metrics[6] += d_metrics["Precision (class)"]
all_channel_d_metrics[7] += d_metrics["Recall (class)"]
all_channel_d_metrics[8] += d_metrics["F1 (class)"]
all_channel_d_metrics[9] += d_metrics["TP (class)"]
all_channel_d_metrics[10] += d_metrics["FN (class)"]
# Save csv files with the associations between GT points and predicted ones
if gt_assoc is not None:
gt_assoc_orig = gt_assoc.copy()
if fp is not None:
fp_orig = fp.copy()
if self.cfg.PROBLEM.NDIM == "2D":
if gt_assoc is not None:
gt_assoc = gt_assoc.drop(columns=["axis-0"])
gt_assoc = gt_assoc.rename(columns={"axis-1": "axis-0", "axis-2": "axis-1"})
if fp is not None:
fp = fp.drop(columns=["axis-0"])
fp = fp.rename(columns={"axis-1": "axis-0", "axis-2": "axis-1"})
if gt_assoc is not None:
os.makedirs(self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS, exist_ok=True)
gt_assoc.to_csv(
os.path.join(
self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS,
os.path.splitext(self.current_sample["X_filename"])[0] + "_gt_assoc.csv",
),
index=False,
)
if fp is not None:
os.makedirs(self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS, exist_ok=True)
fp.to_csv(
os.path.join(
self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS,
os.path.splitext(self.current_sample["X_filename"])[0] + "_fp.csv",
),
index=False,
)
if gt_assoc is not None:
gt_assoc = gt_assoc_orig
if fp is not None:
fp = fp_orig
else:
if self.cfg.TEST.VERBOSE:
print("No point found to calculate the metrics!")
if not self.cfg.TEST.BY_CHUNKS.ENABLE:
for n, metric in enumerate(self.test_extra_metrics):
if str(metric).lower() not in self.stats[inference_type]:
self.stats[inference_type][str(metric.lower())] = 0
self.stats[inference_type][str(metric).lower()] += all_channel_d_metrics[n]
self.current_sample_metrics[str(metric).lower()] = all_channel_d_metrics[n]
if self.cfg.TEST.VERBOSE:
if len(gt_coordinates) == 0:
print("No points found in GT!")
print("Creating the image with a summary of detected points and false positives with colors . . .")
dtype = decide_dtype(len(gt_coordinates)+1)
points_pred_mask_color = np.zeros(pred_shape[:-1] + (3,), dtype=dtype)
# TP and FN
gt_id_img = np.zeros(pred_shape[:-1], dtype=dtype)
for j, cor in enumerate(gt_coordinates):
z, y, x = cor
z, y, x = int(z), int(y), int(x)
if z >= pred_shape[0] or y >= pred_shape[1] or x >= pred_shape[2]:
print(f"WARNING: GT point [{z},{y},{x}] outside image with shape {pred_shape}. Skipping it in the summary image.")
continue
if gt_assoc is not None:
if gt_assoc[gt_assoc["gt_id"] == j + 1]["tag"].iloc[0] == "TP":
points_pred_mask_color[z, y, x] = (0, 255, 0) # Green
elif gt_assoc[gt_assoc["gt_id"] == j + 1]["tag"].iloc[0] == "NC":
points_pred_mask_color[z, y, x] = (150, 150, 150) # Gray
else:
points_pred_mask_color[z, y, x] = (255, 0, 0) # Red
else:
points_pred_mask_color[z, y, x] = (255, 0, 0) # Red
gt_id_img[z, y, x] = j + 1
# Dilate and save the GT ids for the current class
for i in range(gt_id_img.shape[0]):
gt_id_img[i] = dilation(gt_id_img[i], disk(3))
save_tif(
np.expand_dims(np.expand_dims(gt_id_img, 0), -1),
self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS,
[os.path.splitext(self.current_sample["X_filename"])[0] + "_gt_ids" + file_ext],
verbose=self.cfg.TEST.VERBOSE,
)
# FP
if fp is not None:
for cor in zip(
fp["axis-0"].tolist(),
fp["axis-1"].tolist(),
fp["axis-2"].tolist(),
):
z, y, x = cor
z, y, x = int(z), int(y), int(x)
points_pred_mask_color[z, y, x] = (0, 0, 255) # Blue
# Dilate and save the predicted points for the current class
for i in range(points_pred_mask_color.shape[0]):
for j in range(points_pred_mask_color.shape[-1]):
points_pred_mask_color[i, ..., j] = dilation(points_pred_mask_color[i, ..., j], disk(3))
save_tif(
np.expand_dims(points_pred_mask_color, 0),
self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS,
[self.current_sample["X_filename"]],
verbose=self.cfg.TEST.VERBOSE,
)
return df
[docs]
def after_merge_patches(self, pred):
"""
Excute steps needed after merging all predicted patches into the original image.
Parameters
----------
pred : Torch Tensor
Model prediction.
"""
if pred.shape[0] == 1:
if self.cfg.PROBLEM.NDIM == "3D":
pred = pred[0]
self.detection_process(
pred,
inference_type="merge_patches",
)
else:
raise NotImplementedError
[docs]
def after_one_chunk_raw_prediction(
self, chunk_id: int, chunk: NDArray, chunk_in_data: PatchCoords, added_pad: List[List[int]]
):
"""
Place any code that needs to be done after predicting one chunk of data in "by chunks" setting.
Parameters
----------
chunk_id: int
Chunk identifier.
chunk : NDArray
Predicted chunk
chunk_in_data : PatchCoords
Global coordinates of the chunk.
added_pad: List of list of ints
Padding added to the chunk in each dimension. The order of dimensions is the same as the input
image, and the order of the list is: [[pad_before_dim1, pad_after_dim1], [pad_before_dim2, pad_after_dim2], ....
"""
df_patch = self.detection_process(chunk, patch_pos=chunk_in_data)
if df_patch is not None and len(df_patch) > 0:
# Remove possible points in the padded area
df_patch = df_patch[df_patch["axis-0"] < chunk.shape[0] - added_pad[0][1]]
df_patch = df_patch[df_patch["axis-1"] < chunk.shape[1] - added_pad[1][1]]
df_patch = df_patch[df_patch["axis-2"] < chunk.shape[2] - added_pad[2][1]]
df_patch["axis-0"] = df_patch["axis-0"] - added_pad[0][0]
df_patch["axis-1"] = df_patch["axis-1"] - added_pad[1][0]
df_patch["axis-2"] = df_patch["axis-2"] - added_pad[2][0]
df_patch = df_patch[df_patch["axis-0"] >= 0]
df_patch = df_patch[df_patch["axis-1"] >= 0]
df_patch = df_patch[df_patch["axis-2"] >= 0]
# Add the chunk shift to the detected coordinates so they represent global coords
df_patch["axis-0"] = df_patch["axis-0"] + chunk_in_data.z_start
df_patch["axis-1"] = df_patch["axis-1"] + chunk_in_data.y_start
df_patch["axis-2"] = df_patch["axis-2"] + chunk_in_data.x_start
# Save the csv file
output_dir = (
self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK_POST_PROCESSING
if self.post_processing["detection_post"]
else self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK
)
os.makedirs(output_dir, exist_ok=True)
_filename, _ = os.path.splitext(os.path.basename(self.current_sample["X_filename"]))
df_patch.to_csv(
os.path.join(
output_dir,
_filename + "_patch" + str(chunk_id).zfill(len(str(len(self.test_generator)))) + "_points.csv",
),
index=False,
)
[docs]
def after_one_chunk_workflow_process(self, chunks: List[NDArray], patch_in_data: List) -> Optional[List[NDArray]]:
"""
Process a list of chunks during inference in "by chunks" setting. Each workflow should have
its own implementation of this method.
Parameters
----------
chunks : List[NDArray]
List of chunks. Expected axes are: ``(z, y, x, channels)`` for 3D and
``(y, x, channels)`` for 2D.
Returns
-------
chunks : Optional[List[NDArray]]
Processed chunks.
"""
pass
[docs]
def after_all_chunk_prediction_workflow_process(self):
"""
Place any code that needs to be done after predicting all patches in "by chunks" setting.
This function is called on all ranks.
"""
pass
[docs]
def after_all_chunk_prediction_workflow_process_master_rank(self):
"""Excute stepes needed after predicting all the patches, one by one, in the "by chunks" setting."""
assert isinstance(self.all_pred, list)
filename, _ = os.path.splitext(self.current_sample["X_filename"])
input_dir = (
self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK_POST_PROCESSING
if self.post_processing["detection_post"]
else self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK
)
try:
all_pred_files = next(os_walk_clean(input_dir))[2]
except:
all_pred_files = []
all_pred_files = [x for x in all_pred_files if filename + "_patch" in x]
all_pred_files = [x for x in all_pred_files if "_points.csv" in x and "all_points.csv" not in x]
if len(all_pred_files) > 0:
point_counter = 0
for pred_file in all_pred_files:
pred_file_path = os.path.join(input_dir, pred_file)
pred_df = pd.read_csv(pred_file_path, index_col=False)
pred_df["pred_id"] = pred_df["pred_id"] + point_counter
point_counter += len(pred_df)
self.all_pred.append(pred_df)
if len(self.all_pred) > 0:
df = pd.concat(self.all_pred, ignore_index=True)
# Take point coords
pred_coordinates = []
if df is None:
print("No points created, skipping evaluation . . .")
return
coordz = df["axis-0"].tolist()
coordy = df["axis-1"].tolist()
coordx = df["axis-2"].tolist()
for z, y, x in zip(coordz, coordy, coordx):
pred_coordinates.append([z, y, x])
# Apply post-processing of removing points
out_dir = self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK
if self.cfg.TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS:
out_dir = self.cfg.PATHS.RESULT_DIR.DET_LOCAL_MAX_COORDS_CHECK_POST_PROCESSING
pred_coordinates, dropped_pos = remove_close_points( # type: ignore
pred_coordinates,
self.cfg.TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS_RADIUS,
self.resolution,
ndim=self.dims,
return_drops=True,
)
# Remove points from dataframe
df = df.drop(dropped_pos)
t_dim, z_dim, y_dim, x_dim, c_dim = order_dimensions(
self.cfg.DATA.PREPROCESS.ZOOM.ZOOM_FACTOR,
input_order=self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER,
output_order="TZYXC",
default_value=1,
)
df["axis-0"] = df["axis-0"] / z_dim # type: ignore
df["axis-1"] = df["axis-1"] / y_dim # type: ignore
df["axis-2"] = df["axis-2"] / x_dim # type: ignore
os.makedirs(out_dir, exist_ok=True)
df.to_csv(
os.path.join(
out_dir,
filename + "_all_points.csv",
),
index=False,
)
# Calculate metrics with all the points
if self.use_gt:
print("Calculating detection metrics with all the points found . . .")
# Read the GT coordinates from the CSV file
csv_filename = os.path.join(self.original_test_mask_path, os.path.splitext(filename[0])[0] + ".csv")
if not os.path.exists(csv_filename):
if self.cfg.TEST.VERBOSE:
print(
"WARNING: The CSV file seems to have different name than image. Using the CSV file "
"with the same position as the CSV in the directory. Check if it is correct!"
)
csv_filename = os.path.join(self.original_test_mask_path, self.csv_files[self.f_numbers[0]])
if self.cfg.TEST.VERBOSE:
print("Its respective CSV file seems to be: {}".format(csv_filename))
if self.cfg.TEST.VERBOSE:
print("Reading GT data from: {}".format(csv_filename))
df_gt = pd.read_csv(csv_filename)
df_gt = df_gt.rename(columns=lambda x: x.strip())
gt_coordinates = [
[z, y, x]
for z, y, x in zip(
df_gt["axis-0"].tolist(),
df_gt["axis-1"].tolist(),
df_gt["axis-2"].tolist(),
)
]
# Measure metrics
roi_to_consider = []
if self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX:
roi_to_consider = [
[
self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[0],
max(
self.parallel_data_shape[0] - self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[0],
0,
),
],
[
self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[1],
max(
self.parallel_data_shape[1] - self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[1],
0,
),
],
[
self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[2],
max(
self.parallel_data_shape[2] - self.cfg.TEST.DET_IGNORE_POINTS_OUTSIDE_BOX[2],
0,
),
],
]
d_metrics, gt_assoc, fp = detection_metrics(
gt_coordinates,
pred_coordinates,
tolerance=self.cfg.TEST.DET_TOLERANCE,
resolution=self.resolution,
bbox_to_consider=roi_to_consider,
verbose=self.cfg.TEST.VERBOSE,
)
print("Detection metrics: {}".format(d_metrics))
if gt_assoc is not None:
os.makedirs(self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS, exist_ok=True)
gt_assoc.to_csv(
os.path.join(
self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS,
filename + "_gt_assoc.csv",
),
index=False,
)
if fp is not None:
os.makedirs(self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS, exist_ok=True)
fp.to_csv(
os.path.join(
self.cfg.PATHS.RESULT_DIR.DET_ASSOC_POINTS,
filename + "_fp.csv",
),
index=False,
)
for metric in self.test_extra_metrics:
if str(metric).lower() not in self.stats["merge_patches"]:
self.stats["merge_patches"][str(metric).lower()] = 0
self.stats["merge_patches"][str(metric).lower()] += d_metrics[str(metric)]
self.current_sample_metrics[str(metric).lower()] = d_metrics[str(metric)]
else:
print("No points created for the given sample")
[docs]
def process_test_sample(self):
"""Process a sample in the test/inference phase."""
if self.cfg.MODEL.SOURCE != "torchvision":
super().process_test_sample()
else:
# Skip processing image
if "discard" in self.current_sample and self.current_sample["discard"]:
return True
##################
### FULL IMAGE ###
##################
# Make the prediction
pred = self.model_call_func(self.current_sample["X"])
del self.current_sample["X"]
# In Torchvision the output is a collection of bboxes so there is nothing else to do here
if self.cfg.MODEL.SOURCE == "torchvision" and pred is None:
return
if self.cfg.DATA.REFLECT_TO_COMPLETE_SHAPE:
reflected_orig_shape = (1,) + self.current_sample["reflected_orig_shape"]
if reflected_orig_shape != pred.shape:
if self.cfg.PROBLEM.NDIM == "2D":
pred = pred[:, -reflected_orig_shape[1] :, -reflected_orig_shape[2] :]
else:
pred = pred[
:,
-reflected_orig_shape[1] :,
-reflected_orig_shape[2] :,
-reflected_orig_shape[3] :,
]
[docs]
def torchvision_model_call(self, in_img: torch.Tensor, is_train: bool = False) -> torch.Tensor | None:
"""
Call a regular Pytorch model.
Parameters
----------
in_img : torch.Tensor
Input image to pass through the model.
is_train : bool, optional
Whether if the call is during training or inference.
Returns
-------
prediction : torch.Tensor
Image prediction.
"""
assert self.model and self.torchvision_preprocessing
filename, file_extension = os.path.splitext(self.current_sample["X_filename"])
# Convert first to 0-255 range if uint16
if in_img.dtype == torch.float32:
if torch.max(in_img) > 1:
in_img = (self.torchvision_norm.apply_image_norm(in_img)[0] * 255).to(torch.uint8) # type: ignore
in_img = in_img.to(torch.uint8)
# Apply TorchVision pre-processing
in_img = self.torchvision_preprocessing(in_img)
pred = self.model(in_img)
bboxes = pred[0]["boxes"].cpu().numpy()
if not is_train and len(bboxes) != 0:
# Extract each output from prediction
labels = pred[0]["labels"].cpu().numpy()
scores = pred[0]["scores"].cpu().numpy()
# Save all info in a csv file
df = pd.DataFrame(
zip(
labels,
scores,
bboxes[:, 0],
bboxes[:, 1],
bboxes[:, 2],
bboxes[:, 3],
),
columns=["label", "scores", "x1", "y1", "x2", "y2"],
)
df = df.sort_values(by=["label"])
df.to_csv(
os.path.join(self.cfg.PATHS.RESULT_DIR.FULL_IMAGE, filename + ".csv"),
index=False,
)
return None
[docs]
def after_full_image(self, pred: NDArray):
"""
Excute steps due after generating the prediction by supplying the entire image to the model.
Parameters
----------
pred : NDArray
Model prediction.
"""
if pred.shape[0] == 1:
if self.cfg.PROBLEM.NDIM == "3D":
pred = pred[0]
self.detection_process(
pred,
inference_type="full_image",
)
else:
raise NotImplementedError
[docs]
def after_all_images(self):
"""Execute steps that must be done after predicting all images."""
super().after_all_images()
[docs]
def prepare_detection_data(self) -> str:
"""
Create detection ground truth images to train the model based on the ground truth coordinates provided.
They will be saved in a separate folder in the root path of the ground truth.
"""
original_test_mask_path = self.cfg.DATA.TEST.GT_PATH
create_mask = False
print("############################")
print("# PREPARE DETECTION DATA #")
print("############################")
# Create selected channels for train data
if self.cfg.TRAIN.ENABLE or self.cfg.DATA.TEST.USE_VAL_AS_TEST:
create_mask = False
if not os.path.isdir(self.cfg.DATA.TRAIN.DETECTION_MASK_DIR):
print(
"You select to create detection masks from given .csv files but no file is detected in {}. "
"So let's prepare the data. Notice that, if you do not modify 'DATA.TRAIN.DETECTION_MASK_DIR' "
"path, this process will be done just once!".format(self.cfg.DATA.TRAIN.DETECTION_MASK_DIR)
)
create_mask = True
else:
if len(next(os_walk_clean(self.cfg.DATA.TRAIN.DETECTION_MASK_DIR))[2]) != len(
next(os_walk_clean(self.cfg.DATA.TRAIN.GT_PATH))[2]
) and len(next(os_walk_clean(self.cfg.DATA.TRAIN.DETECTION_MASK_DIR))[1]) != len(
next(os_walk_clean(self.cfg.DATA.TRAIN.GT_PATH))[2]
):
print(
"Different number of files found in {} and {}. Trying to create the the rest again".format(
self.cfg.DATA.TRAIN.GT_PATH,
self.cfg.DATA.TRAIN.DETECTION_MASK_DIR,
)
)
create_mask = True
if create_mask:
create_detection_masks(self.cfg)
# Create selected channels for val data
if self.cfg.TRAIN.ENABLE and not self.cfg.DATA.VAL.FROM_TRAIN:
create_mask = False
if not os.path.isdir(self.cfg.DATA.VAL.DETECTION_MASK_DIR):
print(
"You select to create detection masks from given .csv files but no file is detected in {}. "
"So let's prepare the data. Notice that, if you do not modify 'DATA.VAL.DETECTION_MASK_DIR' "
"path, this process will be done just once!".format(self.cfg.DATA.VAL.DETECTION_MASK_DIR)
)
create_mask = True
else:
if len(next(os_walk_clean(self.cfg.DATA.VAL.DETECTION_MASK_DIR))[2]) != len(
next(os_walk_clean(self.cfg.DATA.VAL.GT_PATH))[2]
) and len(next(os_walk_clean(self.cfg.DATA.VAL.DETECTION_MASK_DIR))[1]) != len(
next(os_walk_clean(self.cfg.DATA.VAL.GT_PATH))[2]
):
print(
"Different number of files found in {} and {}. Trying to create the the rest again".format(
self.cfg.DATA.VAL.GT_PATH,
self.cfg.DATA.VAL.DETECTION_MASK_DIR,
)
)
create_mask = True
if create_mask:
create_detection_masks(self.cfg, data_type="val")
# Create selected channels for test data once
if self.cfg.TEST.ENABLE and self.cfg.DATA.TEST.LOAD_GT and not self.cfg.DATA.TEST.USE_VAL_AS_TEST:
create_mask = False
if not os.path.isdir(self.cfg.DATA.TEST.DETECTION_MASK_DIR):
print(
"You select to create detection masks from given .csv files but no file is detected in {}. "
"So let's prepare the data. Notice that, if you do not modify 'DATA.TEST.DETECTION_MASK_DIR' "
"path, this process will be done just once!".format(self.cfg.DATA.TEST.DETECTION_MASK_DIR)
)
create_mask = True
else:
if len(next(os_walk_clean(self.cfg.DATA.TEST.DETECTION_MASK_DIR))[2]) != len(
next(os_walk_clean(self.cfg.DATA.TEST.GT_PATH))[2]
) and len(next(os_walk_clean(self.cfg.DATA.TEST.DETECTION_MASK_DIR))[1]) != len(
next(os_walk_clean(self.cfg.DATA.TEST.GT_PATH))[2]
):
print(
"Different number of files found in {} and {}. Trying to create the the rest again".format(
self.cfg.DATA.TEST.GT_PATH,
self.cfg.DATA.TEST.DETECTION_MASK_DIR,
)
)
create_mask = True
if create_mask:
create_detection_masks(self.cfg, data_type="test")
if is_dist_avail_and_initialized():
dist.barrier()
opts = []
if self.cfg.TRAIN.ENABLE:
print(
"DATA.TRAIN.GT_PATH changed from {} to {}".format(
self.cfg.DATA.TRAIN.GT_PATH, self.cfg.DATA.TRAIN.DETECTION_MASK_DIR
)
)
opts.extend(["DATA.TRAIN.GT_PATH", self.cfg.DATA.TRAIN.DETECTION_MASK_DIR])
out_data_order = self.cfg.DATA.TRAIN.INPUT_IMG_AXES_ORDER
if "C" not in self.cfg.DATA.TRAIN.INPUT_IMG_AXES_ORDER:
out_data_order += "C"
print(
"DATA.TRAIN.INPUT_MASK_AXES_ORDER changed from {} to {}".format(
self.cfg.DATA.TRAIN.INPUT_MASK_AXES_ORDER, out_data_order
)
)
opts.extend([f"DATA.TRAIN.INPUT_MASK_AXES_ORDER", out_data_order])
if not self.cfg.DATA.VAL.FROM_TRAIN:
print(
"DATA.VAL.GT_PATH changed from {} to {}".format(
self.cfg.DATA.VAL.GT_PATH, self.cfg.DATA.VAL.DETECTION_MASK_DIR
)
)
opts.extend(["DATA.VAL.GT_PATH", self.cfg.DATA.VAL.DETECTION_MASK_DIR])
out_data_order = self.cfg.DATA.VAL.INPUT_IMG_AXES_ORDER
if "C" not in self.cfg.DATA.VAL.INPUT_IMG_AXES_ORDER:
out_data_order += "C"
print(
"DATA.VAL.INPUT_MASK_AXES_ORDER changed from {} to {}".format(
self.cfg.DATA.VAL.INPUT_MASK_AXES_ORDER, out_data_order
)
)
opts.extend([f"DATA.VAL.INPUT_MASK_AXES_ORDER", out_data_order])
if self.cfg.TEST.ENABLE and self.cfg.DATA.TEST.LOAD_GT:
print(
"DATA.TEST.GT_PATH changed from {} to {}".format(
self.cfg.DATA.TEST.GT_PATH, self.cfg.DATA.TEST.DETECTION_MASK_DIR
)
)
opts.extend(["DATA.TEST.GT_PATH", self.cfg.DATA.TEST.DETECTION_MASK_DIR])
out_data_order = self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER
if "C" not in self.cfg.DATA.TEST.INPUT_IMG_AXES_ORDER:
out_data_order += "C"
print(
"DATA.TEST.INPUT_MASK_AXES_ORDER changed from {} to {}".format(
self.cfg.DATA.TEST.INPUT_MASK_AXES_ORDER, out_data_order
)
)
opts.extend([f"DATA.TEST.INPUT_MASK_AXES_ORDER", out_data_order])
self.cfg.merge_from_list(opts)
return original_test_mask_path