Source code for biapy.engine.instance_seg

import os
import torch
import numpy as np
import pandas as pd
from skimage.io import imread
from tqdm import tqdm
from skimage.segmentation import clear_border
from skimage.transform import resize
import torch.distributed as dist

from biapy.data.post_processing.post_processing import (watershed_by_channels, voronoi_on_mask, 
    measure_morphological_props_and_filter, repare_large_blobs, apply_binary_mask)
from biapy.data.pre_processing import create_instance_channels, create_test_instance_channels, norm_range01
from biapy.utils.util import save_tif
from biapy.utils.matching import matching, wrapper_matching_dataset_lazy
from biapy.engine.metrics import jaccard_index, instance_segmentation_loss, instance_metrics
from biapy.engine.base_workflow import Base_Workflow
from biapy.utils.misc import is_main_process, is_dist_avail_and_initialized


[docs]class Instance_Segmentation_Workflow(Base_Workflow): """ Instance segmentation workflow where the goal is to assign an unique id, i.e. integer, to each object of the input image. More details in `our documentation <https://biapy.readthedocs.io/en/latest/workflows/instance_segmentation.html>`_. Parameters ---------- cfg : YACS configuration Running configuration. Job_identifier : str Complete name of the running job. device : Torch device Device used. args : argpase class Arguments used in BiaPy's call. """ def __init__(self, cfg, job_identifier, device, args, **kwargs): super(Instance_Segmentation_Workflow, self).__init__(cfg, job_identifier, device, args, **kwargs) self.original_test_path, self.original_test_mask_path = self.prepare_instance_data() # Merging the image self.all_matching_stats_merge_patches = [] self.all_matching_stats_merge_patches_post = [] self.stats['inst_stats_merge_patches'] = None self.stats['inst_stats_merge_patches_post'] = None # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: self.all_class_stats_merge_patches = [] self.all_class_stats_merge_patches_post = [] self.stats['class_stats_merge_patches'] = None self.stats['class_stats_merge_patches_post'] = None # As 3D stack self.all_matching_stats_as_3D_stack = [] self.all_matching_stats_as_3D_stack_post = [] self.stats['inst_stats_as_3D_stack'] = None self.stats['inst_stats_as_3D_stack_post'] = None # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: self.all_class_stats_as_3D_stack = [] self.all_class_stats_as_3D_stack_post = [] self.stats['class_stats_as_3D_stack'] = None self.stats['class_stats_as_3D_stack_post'] = None # Full image self.all_matching_stats = [] self.all_matching_stats_post = [] self.stats['inst_stats'] = None self.stats['inst_stats_post'] = None # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: self.all_class_stats = [] self.all_class_stats_post = [] self.stats['class_stats'] = None self.stats['class_stats_post'] = None self.instance_ths = {} self.instance_ths['TYPE'] = self.cfg.PROBLEM.INSTANCE_SEG.DATA_MW_TH_TYPE self.instance_ths['TH_BINARY_MASK'] = self.cfg.PROBLEM.INSTANCE_SEG.DATA_MW_TH_BINARY_MASK self.instance_ths['TH_CONTOUR'] = self.cfg.PROBLEM.INSTANCE_SEG.DATA_MW_TH_CONTOUR self.instance_ths['TH_FOREGROUND'] = self.cfg.PROBLEM.INSTANCE_SEG.DATA_MW_TH_FOREGROUND self.instance_ths['TH_DISTANCE'] = self.cfg.PROBLEM.INSTANCE_SEG.DATA_MW_TH_DISTANCE self.instance_ths['TH_POINTS'] = self.cfg.PROBLEM.INSTANCE_SEG.DATA_MW_TH_POINTS # From now on, no modification of the cfg will be allowed self.cfg.freeze() # Activations for each output channel: # channel number : 'activation' self.activations = {} if self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS == "Dv2": self.activations = {'0': 'Linear'} elif self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS in ["BC", "BP", "BCM"]: self.activations = {':': 'CE_Sigmoid'} elif self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS in ["BDv2", "BD"]: self.activations = {'0': 'CE_Sigmoid', '1': 'Linear'} elif self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS in ["BCD", "BCDv2"]: self.activations = {'0': 'CE_Sigmoid', '1': 'CE_Sigmoid', '2': 'Linear'} # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: self.activations = [self.activations, {'0': 'Linear'}] else: self.activations = [self.activations] # Workflow specific training variables self.mask_path = cfg.DATA.TRAIN.GT_PATH self.load_Y_val = True # Specific instance segmentation post-processing if self.cfg.TEST.POST_PROCESSING.VORONOI_ON_MASK or \ self.cfg.TEST.POST_PROCESSING.CLEAR_BORDER or\ self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.ENABLE or \ self.cfg.TEST.POST_PROCESSING.REPARE_LARGE_BLOBS_SIZE != -1: self.post_processing['instance_post'] = True else: self.post_processing['instance_post'] = False self.instances_already_created = False
[docs] def define_metrics(self): """ Definition of self.metrics, self.metric_names and self.loss variables. """ self.metrics = [] self.metric_names = [] self.loss = instance_segmentation_loss( self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNEL_WEIGHTS, self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS, self.cfg.PROBLEM.INSTANCE_SEG.DISTANCE_CHANNEL_MASK, self.cfg.MODEL.N_CLASSES ) if self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS == "BC": self.metric_names = ["jaccard_index_B", "jaccard_index_C"] elif self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS == "BCM": self.metric_names = ["jaccard_index_B", "jaccard_index_C", "jaccard_index_M"] elif self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS == "BP": self.metric_names = ["jaccard_index_B", "jaccard_index_P"] elif self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS == "BD": self.metric_names = ["jaccard_index_B", "L1_distance_channel"] elif self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS in ["BCD", 'BCDv2']: self.metric_names = ["jaccard_index_B", "jaccard_index_C", "L1_distance_channel"] elif self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS == "BDv2": self.metric_names = ["jaccard_index_B", "L1_distance_channel"] elif self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS == "Dv2": self.metric_names = ["L1_distance_channel"] # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: self.metric_names.append("jaccard_index_classes") # Used to calculate IoU with the classification results self.jaccard_index_matching = jaccard_index(device="cpu", num_classes=self.cfg.MODEL.N_CLASSES) self.metrics = instance_metrics(num_classes=self.cfg.MODEL.N_CLASSES, metric_names=self.metric_names, device=self.device, torchvision_models=True if self.cfg.MODEL.SOURCE == "torchvision" else False)
[docs] def metric_calculation(self, output, targets, metric_logger=None): """ Execution of the metrics defined in :func:`~define_metrics` function. Parameters ---------- output : Torch Tensor Prediction of the model. targets : Torch Tensor Ground truth to compare the prediction with. metric_logger : MetricLogger, optional Class to be updated with the new metric(s) value(s) calculated. Returns ------- value : float Value of the metric for the given prediction. """ with torch.no_grad(): out = self.metrics(output, targets) first_val = None for key, value in out.items(): value = value.item() if not torch.isnan(value) else 0 if first_val is None and "jaccard_index" in key: first_val = value if metric_logger is not None: metric_logger.meters[key].update(value) if first_val is None: first_val = 0 return first_val
[docs] def instance_seg_process(self, pred, filenames, out_dir, out_dir_post_proc, resolution): """ Instance segmentation workflow engine for test/inference. Process model's prediction to prepare instance segmentation output and calculate metrics. Parameters ---------- pred : 4D/5D Torch tensor Model predictions. E.g. ``(z, y, x, channels)`` for both 2D and 3D. filenames : List of str Predicted image's filenames. out_dir : path Output directory to save the instances. out_dir_post_proc : path Output directory to save the post-processed instances. """ assert pred.ndim == 4, "Prediction doesn't have 4 dim: {pred.shape}" ############################# ### INSTANCE SEGMENTATION ### ############################# if not self.instances_already_created: # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: class_channel = np.expand_dims(pred[...,-1], -1) pred = pred[...,:-1] print("Creating instances with watershed . . .") w_dir = os.path.join(self.cfg.PATHS.WATERSHED_DIR, filenames[0]) check_wa = w_dir if self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHECK_MW else None w_pred = watershed_by_channels(pred, self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS, ths=self.instance_ths, remove_before=self.cfg.PROBLEM.INSTANCE_SEG.DATA_REMOVE_BEFORE_MW, thres_small_before=self.cfg.PROBLEM.INSTANCE_SEG.DATA_REMOVE_SMALL_OBJ_BEFORE, seed_morph_sequence=self.cfg.PROBLEM.INSTANCE_SEG.SEED_MORPH_SEQUENCE, seed_morph_radius=self.cfg.PROBLEM.INSTANCE_SEG.SEED_MORPH_RADIUS, erode_and_dilate_foreground=self.cfg.PROBLEM.INSTANCE_SEG.ERODE_AND_DILATE_FOREGROUND, fore_erosion_radius=self.cfg.PROBLEM.INSTANCE_SEG.FORE_EROSION_RADIUS, fore_dilation_radius=self.cfg.PROBLEM.INSTANCE_SEG.FORE_DILATION_RADIUS, rmv_close_points=self.cfg.TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS, remove_close_points_radius=self.cfg.TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS_RADIUS[0], resolution=resolution, save_dir=check_wa, watershed_by_2d_slices=self.cfg.PROBLEM.INSTANCE_SEG.WATERSHED_BY_2D_SLICES) # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: print("Adapting class channel . . .") labels = np.unique(w_pred)[1:] new_class_channel = np.zeros(w_pred.shape, dtype=w_pred.dtype) # Classify each instance counting the most prominent class of all the pixels that compose it for l in labels: instance_classes, instance_classes_count = np.unique(class_channel[w_pred == l], return_counts=True) # Remove background if instance_classes[0] == 0: instance_classes = instance_classes[1:] instance_classes_count = instance_classes_count[1:] if len(instance_classes) > 0: label_selected = int(instance_classes[np.argmax(instance_classes_count)]) else: # Label by default with class 1 in case there was no class info label_selected = 1 new_class_channel = np.where(w_pred == l, label_selected, new_class_channel) class_channel = new_class_channel class_channel = class_channel.squeeze() del new_class_channel save_tif(np.expand_dims(np.concatenate([np.expand_dims(w_pred.squeeze(),-1), np.expand_dims(class_channel,-1)],axis=-1),0), out_dir, filenames, verbose=self.cfg.TEST.VERBOSE) else: save_tif(np.expand_dims(np.expand_dims(w_pred,-1),0), out_dir, filenames, verbose=self.cfg.TEST.VERBOSE) # Add extra dimension if working in 2D if w_pred.ndim == 2: w_pred = np.expand_dims(w_pred,0) else: w_pred = pred.squeeze() if w_pred.ndim == 2: w_pred = np.expand_dims(w_pred,0) results = None results_class = None if self.cfg.TEST.MATCHING_STATS and (self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST): print("Calculating matching stats . . .") # Need to load instance labels, as Y are binary channels used for IoU calculation if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK and len(self.test_filenames) == w_pred.shape[0]: del self._Y _Y = np.zeros(w_pred.shape, dtype=w_pred.dtype) for i in range(len(self.test_filenames)): test_file = os.path.join(self.original_test_mask_path, self.test_filenames[i]) _Y[i] = imread(test_file).squeeze() else: test_file = os.path.join(self.original_test_mask_path, self.test_filenames[self.f_numbers[0]]) _Y = imread(test_file).squeeze() # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: # Channel check error_shape = None if self.cfg.PROBLEM.NDIM == "2D" and _Y.ndim != 3: error_shape = (256,256,2) elif self.cfg.PROBLEM.NDIM == "3D" and _Y.ndim != 4: error_shape = (40,256,256,2) if error_shape is not None: raise ValueError(f"Image {test_file} wrong dimension. In instance segmentation, when 'MODEL.N_CLASSES' are " f"more than 2 labels need to have two channels, e.g. {error_shape}, containing the instance " "segmentation map (first channel) and classification map (second channel).") # Ensure channel position if self.cfg.PROBLEM.NDIM == "2D": if _Y.shape[0] <= 3: _Y = _Y.transpose((1,2,0)) else: min_val = min(_Y.shape) channel_pos = _Y.shape.index(min_val) if channel_pos != 3 and _Y.shape[channel_pos] <= 4: new_pos = [x for x in range(4) if x != channel_pos]+[channel_pos,] _Y = _Y.transpose(new_pos) # Separate instance and classification channels _Y_classes = _Y[...,1] # Classes _Y = _Y[...,0] # Instances # Measure class IoU class_iou = self.jaccard_index_matching(torch.as_tensor(class_channel.squeeze().astype(np.int32)), torch.as_tensor(_Y_classes.squeeze().astype(np.int32))) class_iou = class_iou.item() if not torch.isnan(class_iou) else 0 print(f"Class IoU: {class_iou}") results_class = class_iou if _Y.ndim == 2: _Y = np.expand_dims(_Y,0) # For torchvision models that resize need to rezise the images if w_pred.shape != _Y.shape: _Y = resize(_Y, w_pred.shape, order=0) # Convert instances to integer if _Y.dtype == np.float32: _Y = _Y.astype(np.uint32) if _Y.dtype == np.float64: _Y = _Y.astype(np.uint64) diff_ths_colored_img = abs(len(self.cfg.TEST.MATCHING_STATS_THS_COLORED_IMG) - len(self.cfg.TEST.MATCHING_STATS_THS)) colored_img_ths = self.cfg.TEST.MATCHING_STATS_THS_COLORED_IMG+[-1]*diff_ths_colored_img results = matching(_Y, w_pred, thresh=self.cfg.TEST.MATCHING_STATS_THS, report_matches=True) for i in range(len(results)): # Extract TPs, FPs and FNs from the resulting matching data structure r_stats = results[i] thr = r_stats['thresh'] # TP and FN gt_ids = r_stats['gt_ids'][1:] matched_pairs = r_stats['matched_pairs'] gt_match = [x[0] for x in matched_pairs] gt_unmatch = [x for x in gt_ids if x not in gt_match] matched_scores = list(r_stats['matched_scores'])+[0 for _ in gt_unmatch] pred_match = [x[1] for x in matched_pairs]+[-1 for _ in gt_unmatch] tag = ["TP" if score >= thr else "FN" for score in matched_scores] # FPs pred_ids = r_stats['pred_ids'][1:] fp_instances = [x for x in pred_ids if x not in pred_match] fp_instances += [pred_id for score, pred_id in zip(matched_scores, pred_match) if score < thr] # Save csv files df = pd.DataFrame(zip(gt_match+gt_unmatch, pred_match, matched_scores, tag), columns =['gt_id', 'pred_id', 'iou', 'tag']) df = df.sort_values(by=['gt_id']) df_fp = pd.DataFrame(zip(fp_instances), columns =['pred_id']) os.makedirs(self.cfg.PATHS.RESULT_DIR.INST_ASSOC_POINTS, exist_ok=True) df.to_csv(os.path.join(self.cfg.PATHS.RESULT_DIR.INST_ASSOC_POINTS, os.path.splitext(filenames[0])[0]+'_th_{}_gt_assoc.csv'.format(thr)), index=False) df_fp.to_csv(os.path.join(self.cfg.PATHS.RESULT_DIR.INST_ASSOC_POINTS, os.path.splitext(filenames[0])[0]+'_th_{}_fp.csv'.format(thr)), index=False) del r_stats['matched_scores']; del r_stats['matched_tps']; del r_stats['matched_pairs']; del r_stats['pred_ids']; del r_stats['gt_ids'] print("DatasetMatching: {}".format(r_stats)) if colored_img_ths[i] != -1 and colored_img_ths[i] == thr: print("Creating the image with a summary of detected points and false positives with colors . . .") colored_result = np.zeros(w_pred.shape+(3,), dtype=np.uint8) print("Painting TPs and FNs . . .") for j in tqdm(range(len(gt_match)), disable=not is_main_process()): color = (0,255,0) if tag[j] == "TP" else (255,0,0) # Green or red colored_result[np.where(_Y == gt_match[j])] = color for j in tqdm(range(len(gt_unmatch)), disable=not is_main_process()): colored_result[np.where(_Y == gt_unmatch[j])] = (255,0,0) # Red print("Painting FPs . . .") for j in tqdm(range(len(fp_instances)), disable=not is_main_process()): colored_result[np.where(w_pred == fp_instances[j])] = (0,0,255) # Blue save_tif(np.expand_dims(colored_result,0), self.cfg.PATHS.RESULT_DIR.INST_ASSOC_POINTS, [os.path.splitext(filenames[0])[0]+'_th_{}.tif'.format(thr)], verbose=self.cfg.TEST.VERBOSE) del colored_result ################### # Post-processing # ################### if self.cfg.TEST.POST_PROCESSING.REPARE_LARGE_BLOBS_SIZE != -1: if self.cfg.PROBLEM.NDIM == "2D": w_pred = w_pred[0] w_pred = repare_large_blobs(w_pred[0], self.cfg.TEST.POST_PROCESSING.REPARE_LARGE_BLOBS_SIZE) if self.cfg.PROBLEM.NDIM == "2D": w_pred = np.expand_dims(w_pred,0) if self.cfg.TEST.POST_PROCESSING.VORONOI_ON_MASK: w_pred = voronoi_on_mask(w_pred, pred, th=self.cfg.TEST.POST_PROCESSING.VORONOI_TH, verbose=self.cfg.TEST.VERBOSE) del pred if self.cfg.TEST.POST_PROCESSING.CLEAR_BORDER: print("Clearing borders . . .") if self.cfg.PROBLEM.NDIM == "2D": w_pred = w_pred[0] w_pred = clear_border(w_pred) if self.cfg.PROBLEM.NDIM == "2D": w_pred = np.expand_dims(w_pred,0) if self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.ENABLE or \ self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.ENABLE: if self.cfg.PROBLEM.NDIM == "2D": w_pred = w_pred[0] w_pred, d_result = measure_morphological_props_and_filter(w_pred, resolution, filter_instances=self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.ENABLE, properties=self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.PROPS, prop_values=self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.VALUES, comp_signs=self.cfg.TEST.POST_PROCESSING.MEASURE_PROPERTIES.REMOVE_BY_PROPERTIES.SIGN) if self.cfg.PROBLEM.NDIM == "2D": w_pred = np.expand_dims(w_pred,0) # Save all instance stats if self.cfg.PROBLEM.NDIM == "2D": df = pd.DataFrame(zip(np.array(d_result['labels'], dtype=np.uint64), list(d_result['centers'][:,0]), list(d_result['centers'][:,1]), d_result['npixels'], d_result['areas'], d_result['circularities'], d_result['diameters'], d_result['perimeters'], d_result['elongations'], d_result['comment'], d_result['conditions']), columns=['label', 'axis-0', 'axis-1', 'npixels', 'area', 'circularity', 'diameter', 'perimeter', 'elongation', 'comment', 'conditions']) else: df = pd.DataFrame(zip(np.array(d_result['labels'], dtype=np.uint64), list(d_result['centers'][:,0]), list(d_result['centers'][:,1]), list(d_result['centers'][:,2]), d_result['npixels'], d_result['areas'], d_result['sphericities'], d_result['diameters'], d_result['perimeters'], d_result['comment'], d_result['conditions']), columns=['label', 'axis-0', 'axis-1', 'axis-2', 'npixels', 'volume', 'sphericity', 'diameter', 'perimeter (surface area)', 'comment', 'conditions']) df = df.sort_values(by=['label']) df.to_csv(os.path.join(out_dir, os.path.splitext(filenames[0])[0]+'_full_stats.csv'), index=False) # Save only remain instances stats df = df[df["comment"].str.contains("Strange")==False] os.makedirs(out_dir_post_proc, exist_ok=True) df.to_csv(os.path.join(out_dir_post_proc, os.path.splitext(filenames[0])[0]+'_filtered_stats.csv'), index=False) del df results_post_proc = None results_class_post_proc = None if self.post_processing['instance_post']: if self.cfg.PROBLEM.NDIM == "2D": w_pred = w_pred[0] # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: class_channel = np.where(w_pred>0, class_channel, 0) # Adapt changes to post-processed w_pred save_tif(np.expand_dims(np.concatenate([np.expand_dims(w_pred,-1), np.expand_dims(class_channel,-1)],axis=-1),0), out_dir_post_proc, filenames, verbose=self.cfg.TEST.VERBOSE) else: save_tif(np.expand_dims(np.expand_dims(w_pred,-1),0), out_dir_post_proc, filenames, verbose=self.cfg.TEST.VERBOSE) if self.cfg.TEST.MATCHING_STATS and (self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST): # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: # Measure class IoU class_iou = self.jaccard_index_matching(torch.as_tensor(class_channel.squeeze().astype(np.int32)), torch.as_tensor(_Y_classes.squeeze().astype(np.int32))) class_iou = class_iou.item() if not torch.isnan(class_iou) else 0 print(f"Class IoU (post-processing): {class_iou}") results_class_post_proc = class_iou if self.cfg.PROBLEM.NDIM == "2D": w_pred = np.expand_dims(w_pred,0) print("Calculating matching stats after post-processing . . .") results_post_proc = matching(_Y, w_pred, thresh=self.cfg.TEST.MATCHING_STATS_THS, report_matches=True) for i in range(len(results_post_proc)): # Extract TPs, FPs and FNs from the resulting matching data structure r_stats = results_post_proc[i] thr = r_stats['thresh'] # TP and FN gt_ids = r_stats['gt_ids'][1:] matched_pairs = r_stats['matched_pairs'] gt_match = [x[0] for x in matched_pairs] gt_unmatch = [x for x in gt_ids if x not in gt_match] matched_scores = list(r_stats['matched_scores'])+[0 for _ in gt_unmatch] pred_match = [x[1] for x in matched_pairs]+[-1 for _ in gt_unmatch] tag = ["TP" if score >= thr else "FN" for score in matched_scores] # FPs pred_ids = r_stats['pred_ids'][1:] fp_instances = [x for x in pred_ids if x not in pred_match] fp_instances += [pred_id for score, pred_id in zip(matched_scores, pred_match) if score < thr] # Save csv files df = pd.DataFrame(zip(gt_match+gt_unmatch, pred_match, matched_scores, tag), columns =['gt_id', 'pred_id', 'iou', 'tag']) df = df.sort_values(by=['gt_id']) df_fp = pd.DataFrame(zip(fp_instances), columns =['pred_id']) os.makedirs(self.cfg.PATHS.RESULT_DIR.INST_ASSOC_POINTS, exist_ok=True) df.to_csv(os.path.join(self.cfg.PATHS.RESULT_DIR.INST_ASSOC_POINTS, os.path.splitext(filenames[0])[0]+'_post-proc_th_{}_gt_assoc.csv'.format(thr)), index=False) df_fp.to_csv(os.path.join(self.cfg.PATHS.RESULT_DIR.INST_ASSOC_POINTS, os.path.splitext(filenames[0])[0]+'_post-proc_th_{}_fp.csv'.format(thr)), index=False) del r_stats['matched_scores']; del r_stats['matched_tps']; del r_stats['matched_pairs']; del r_stats['pred_ids']; del r_stats['gt_ids'] print("DatasetMatching: {}".format(r_stats)) if colored_img_ths[i] != -1 and colored_img_ths[i] == thr: print("Creating the image with a summary of detected points and false positives with colors . . .") colored_result = np.zeros(w_pred.shape+(3,), dtype=np.uint8) print("Painting TPs and FNs . . .") for j in tqdm(range(len(gt_match)), disable=not is_main_process()): color = (0,255,0) if tag[j] == "TP" else (255,0,0) # Green or red colored_result[np.where(_Y == gt_match[j])] = color for j in tqdm(range(len(gt_unmatch)), disable=not is_main_process()): colored_result[np.where(_Y == gt_unmatch[j])] = (255,0,0) # Red print("Painting FPs . . .") for j in tqdm(range(len(fp_instances)), disable=not is_main_process()): colored_result[np.where(w_pred == fp_instances[j])] = (0,0,255) # Blue save_tif(np.expand_dims(colored_result,0), self.cfg.PATHS.RESULT_DIR.INST_ASSOC_POINTS, [os.path.splitext(filenames[0])[0]+'_post-proc_th_{}.tif'.format(thr)], verbose=self.cfg.TEST.VERBOSE) del colored_result return results, results_post_proc, results_class, results_class_post_proc
[docs] def process_sample(self, norm): """ Function to process a sample in the inference phase. Parameters ---------- norm : List of dicts Normalization used during training. Required to denormalize the predictions of the model. """ if self.cfg.MODEL.SOURCE != "torchvision": self.instances_already_created = False super().process_sample(norm) else: self.instances_already_created = True # Data channel check if self.cfg.DATA.PATCH_SIZE[-1] != self._X.shape[-1]: raise ValueError("Channel of the DATA.PATCH_SIZE given {} does not correspond with the loaded image {}. " "Please, check the channels of the images!".format(self.cfg.DATA.PATCH_SIZE[-1], self._X.shape[-1])) ################## ### FULL IMAGE ### ################## if self.cfg.TEST.FULL_IMG: # Make the prediction with torch.cuda.amp.autocast(): pred = self.model_call_func(self._X) del self._X if self.cfg.TEST.POST_PROCESSING.APPLY_MASK: pred = apply_binary_mask(pred, self.cfg.DATA.TEST.BINARY_MASKS) self.after_full_image(pred)
[docs] def after_merge_patches(self, pred): """ Steps need to be done after merging all predicted patches into the original image. Parameters ---------- pred : Torch Tensor Model prediction. """ if pred.shape[0] == 1: if self.cfg.PROBLEM.NDIM == "3D": pred = pred[0] if not self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK: resolution = self.cfg.DATA.TEST.RESOLUTION if len(self.cfg.DATA.TEST.RESOLUTION) == 2 else self.cfg.DATA.TEST.RESOLUTION[1:] r, r_post, rcls, rcls_post = self.instance_seg_process(pred, self.processing_filenames, self.cfg.PATHS.RESULT_DIR.PER_IMAGE_INSTANCES, self.cfg.PATHS.RESULT_DIR.PER_IMAGE_POST_PROCESSING, resolution) if r is not None: self.all_matching_stats_merge_patches.append(r) if r_post is not None: self.all_matching_stats_merge_patches_post.append(r_post) if rcls is not None: self.all_class_stats_merge_patches.append(rcls) if rcls_post is not None: self.all_class_stats_merge_patches_post.append(rcls_post) else: NotImplementedError
[docs] def after_merge_patches_by_chunks_proccess_patch(self, filename): """ Place any code that needs to be done after merging all predicted patches into the original image but in the process made chunk by chunk. This function will operate patch by patch defined by ``DATA.PATCH_SIZE``. Parameters ---------- filename : List of str Filename of the predicted image H5/Zarr. """ pass
[docs] def after_full_image(self, pred): """ Steps that must be executed after generating the prediction by supplying the entire image to the model. Parameters ---------- pred : Torch Tensor Model prediction. """ if pred.shape[0] == 1: if self.cfg.PROBLEM.NDIM == "3D": pred = pred[0] if not self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK: resolution = self.cfg.DATA.TEST.RESOLUTION if len(self.cfg.DATA.TEST.RESOLUTION) == 2 else self.cfg.DATA.TEST.RESOLUTION[1:] r, r_post, rcls, rcls_post = self.instance_seg_process(pred, self.processing_filenames, self.cfg.PATHS.RESULT_DIR.FULL_IMAGE_INSTANCES, self.cfg.PATHS.RESULT_DIR.FULL_IMAGE_POST_PROCESSING, resolution) if r is not None: self.all_matching_stats.append(r) if r_post is not None: self.all_matching_stats_post.append(r_post) if rcls is not None: self.all_class_stats.append(rcls) if rcls_post is not None: self.all_class_stats_post.append(rcls_post) else: NotImplementedError
[docs] def after_all_images(self): """ Steps that must be done after predicting all images. """ super().after_all_images() if self.cfg.TEST.ANALIZE_2D_IMGS_AS_3D_STACK: print("Analysing all images as a 3D stack . . .") if type(self.all_pred) is list: self.all_pred = np.concatenate(self.all_pred) resolution = self.cfg.DATA.TEST.RESOLUTION if len(self.cfg.DATA.TEST.RESOLUTION) == 3 else (self.cfg.DATA.TEST.RESOLUTION[0],)+self.cfg.DATA.TEST.RESOLUTION r, r_post, rcls, rcls_post = self.instance_seg_process(self.all_pred, ["3D_stack_instances.tif"], self.cfg.PATHS.RESULT_DIR.AS_3D_STACK, self.cfg.PATHS.RESULT_DIR.AS_3D_STACK_POST_PROCESSING, resolution) if r is not None: self.all_matching_stats_as_3D_stack.append(r) if r_post is not None: self.all_matching_stats_as_3D_stack_post.append(r_post) if rcls is not None: self.all_class_stats_as_3D_stack.append(rcls) if rcls_post is not None: self.all_class_stats_as_3D_stack_post.append(rcls_post)
[docs] def normalize_stats(self, image_counter): """ Normalize statistics. Parameters ---------- image_counter : int Number of images to average the metrics. """ super().normalize_stats(image_counter) if (self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST): if self.cfg.TEST.MATCHING_STATS: # Merge patches if len(self.all_matching_stats_merge_patches) > 0: self.stats['inst_stats_merge_patches'] = wrapper_matching_dataset_lazy(self.all_matching_stats_merge_patches, self.cfg.TEST.MATCHING_STATS_THS) # As 3D stack if len(self.all_matching_stats_as_3D_stack) > 0: self.stats['inst_stats_as_3D_stack'] = wrapper_matching_dataset_lazy(self.all_matching_stats_as_3D_stack, self.cfg.TEST.MATCHING_STATS_THS) # Full image if len(self.all_matching_stats) > 0: self.stats['inst_stats'] = wrapper_matching_dataset_lazy(self.all_matching_stats, self.cfg.TEST.MATCHING_STATS_THS) if self.post_processing['instance_post']: # Merge patches if len(self.all_matching_stats_merge_patches_post) > 0: self.stats['inst_stats_merge_patches_post'] = wrapper_matching_dataset_lazy(self.all_matching_stats_merge_patches_post, self.cfg.TEST.MATCHING_STATS_THS) # As 3D stack if len(self.all_matching_stats_as_3D_stack_post) > 0: self.stats['inst_stats_as_3D_stack_post'] = wrapper_matching_dataset_lazy(self.all_matching_stats_as_3D_stack_post, self.cfg.TEST.MATCHING_STATS_THS) # Full image if len(self.all_matching_stats_post) > 0: self.stats['inst_stats_post'] = wrapper_matching_dataset_lazy(self.all_matching_stats_post, self.cfg.TEST.MATCHING_STATS_THS) # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: # Merge patches if len(self.all_class_stats_merge_patches) > 0: self.stats['class_stats_merge_patches'] = np.mean(self.all_class_stats_merge_patches) # As 3D stack if len(self.all_class_stats_as_3D_stack) > 0: self.stats['class_stats_as_3D_stack'] = np.mean(self.all_class_stats_as_3D_stack) # Full image if len(self.all_class_stats) > 0: self.stats['class_stats'] = np.mean(self.all_class_stats) if self.post_processing['instance_post']: # Merge patches if len(self.all_class_stats_merge_patches_post) > 0: self.stats['class_stats_merge_patches_post'] = np.mean(self.all_class_stats_merge_patches_post) # As 3D stack if len(self.all_class_stats_as_3D_stack_post) > 0: self.stats['class_stats_as_3D_stack_post'] = np.mean(self.all_class_stats_as_3D_stack_post) # Full image if len(self.all_class_stats_post) > 0: self.stats['class_stats_post'] = np.mean(self.all_class_stats_post)
[docs] def print_stats(self, image_counter): """ Print statistics. Parameters ---------- image_counter : int Number of images to call ``normalize_stats``. """ if self.cfg.MODEL.SOURCE != "torchvision": super().print_stats(image_counter) super().print_post_processing_stats() print("Instance segmentation specific metrics:") if self.cfg.TEST.MATCHING_STATS and (self.cfg.DATA.TEST.LOAD_GT or self.cfg.DATA.TEST.USE_VAL_AS_TEST) : for i in range(len(self.cfg.TEST.MATCHING_STATS_THS)): print("IoU TH={}".format(self.cfg.TEST.MATCHING_STATS_THS[i])) # Merge patches if self.stats['inst_stats_merge_patches'] is not None: print(" Merge patches:") print(f" {self.stats['inst_stats_merge_patches'][i]}") # As 3D stack if self.stats['inst_stats_as_3D_stack'] is not None: print(" As 3D stack:") print(f" {self.stats['inst_stats_as_3D_stack'][i]}") # Full image if self.stats['inst_stats'] is not None: print(" Full image:") print(f" {self.stats['inst_stats'][i]}") if self.post_processing['instance_post']: print("IoU (post-processing) TH={}".format(self.cfg.TEST.MATCHING_STATS_THS[i])) # Merge patches if self.stats['inst_stats_merge_patches_post'] is not None: print(" Merge patches (post-processing):") print(f" {self.stats['inst_stats_merge_patches_post'][i]}") # As 3D stack if self.stats['inst_stats_as_3D_stack_post'] is not None: print(" As 3D stack (post-processing):") print(f" {self.stats['inst_stats_as_3D_stack_post'][i]}") # Full image if self.stats['inst_stats_post'] is not None: print(" Full image (post-processing):") print(f" {self.stats['inst_stats_post'][i]}") # Multi-head: instances + classification if self.cfg.MODEL.N_CLASSES > 2: # Merge patches if self.stats['class_stats_merge_patches'] is not None: print(f" Merge patches classification IoU: {self.stats['class_stats_merge_patches']}") # As 3D stack if self.stats['class_stats_as_3D_stack'] is not None: print(f" As 3D stack classification IoU: {self.stats['class_stats_as_3D_stack']}") # Full image if self.stats['class_stats'] is not None: print(f" Full image classification IoU: {self.stats['class_stats']}") if self.post_processing['instance_post']: # Merge patches if self.stats['class_stats_merge_patches_post'] is not None: print(f" Merge patches classification IoU (post-processing): {self.stats['class_stats_merge_patches_post']}") # As 3D stack if self.stats['class_stats_as_3D_stack_post'] is not None: print(f" As 3D stack classification IoU (post-processing): {self.stats['class_stats_as_3D_stack_post']}") # Full image if self.stats['class_stats_post'] is not None: print(f" Full image classification IoU (post-processing): {self.stats['class_stats_post']}")
[docs] def prepare_instance_data(self): """ Creates instance segmentation ground truth images to train the model based on the ground truth instances provided. They will be saved in a separate folder in the root path of the ground truth. """ original_test_path, original_test_mask_path = None, None if is_main_process(): print("###########################") print("# PREPARE INSTANCE DATA #") print("###########################") # Create selected channels for train data if (self.cfg.TRAIN.ENABLE or self.cfg.DATA.TEST.USE_VAL_AS_TEST) and (not os.path.isdir(self.cfg.DATA.TRAIN.INSTANCE_CHANNELS_DIR) or \ not os.path.isdir(self.cfg.DATA.TRAIN.INSTANCE_CHANNELS_MASK_DIR)): print("You select to create {} channels from given instance labels and no file is detected in {}. " "So let's prepare the data. Notice that, if you do not modify 'DATA.TRAIN.INSTANCE_CHANNELS_DIR' " "path, this process will be done just once!".format(self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS, self.cfg.DATA.TRAIN.INSTANCE_CHANNELS_DIR)) create_instance_channels(self.cfg) # Create selected channels for val data if self.cfg.TRAIN.ENABLE and not self.cfg.DATA.VAL.FROM_TRAIN and (not os.path.isdir(self.cfg.DATA.VAL.INSTANCE_CHANNELS_DIR) or \ not os.path.isdir(self.cfg.DATA.VAL.INSTANCE_CHANNELS_MASK_DIR)): print("You select to create {} channels from given instance labels and no file is detected in {}. " "So let's prepare the data. Notice that, if you do not modify 'DATA.VAL.INSTANCE_CHANNELS_DIR' " "path, this process will be done just once!".format(self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS, self.cfg.DATA.VAL.INSTANCE_CHANNELS_DIR)) create_instance_channels(self.cfg, data_type='val') # Create selected channels for test data once if self.cfg.TEST.ENABLE and not self.cfg.DATA.TEST.USE_VAL_AS_TEST and (not os.path.isdir(self.cfg.DATA.TEST.INSTANCE_CHANNELS_DIR) or \ (not os.path.isdir(self.cfg.DATA.TEST.INSTANCE_CHANNELS_MASK_DIR) and self.cfg.DATA.TEST.LOAD_GT)) and not self.cfg.TEST.BY_CHUNKS.ENABLE: print("You select to create {} channels from given instance labels and no file is detected in {}. " "So let's prepare the data. Notice that, if you do not modify 'DATA.TEST.INSTANCE_CHANNELS_DIR' " "path, this process will be done just once!".format(self.cfg.PROBLEM.INSTANCE_SEG.DATA_CHANNELS, self.cfg.DATA.TEST.INSTANCE_CHANNELS_MASK_DIR)) create_test_instance_channels(self.cfg) if is_dist_avail_and_initialized(): dist.barrier() opts = [] if self.cfg.TRAIN.ENABLE: print("DATA.TRAIN.PATH changed from {} to {}".format(self.cfg.DATA.TRAIN.PATH, self.cfg.DATA.TRAIN.INSTANCE_CHANNELS_DIR)) print("DATA.TRAIN.GT_PATH changed from {} to {}".format(self.cfg.DATA.TRAIN.GT_PATH, self.cfg.DATA.TRAIN.INSTANCE_CHANNELS_MASK_DIR)) opts.extend(['DATA.TRAIN.PATH', self.cfg.DATA.TRAIN.INSTANCE_CHANNELS_DIR, 'DATA.TRAIN.GT_PATH', self.cfg.DATA.TRAIN.INSTANCE_CHANNELS_MASK_DIR]) if not self.cfg.DATA.VAL.FROM_TRAIN: print("DATA.VAL.PATH changed from {} to {}".format(self.cfg.DATA.VAL.PATH, self.cfg.DATA.VAL.INSTANCE_CHANNELS_DIR)) print("DATA.VAL.GT_PATH changed from {} to {}".format(self.cfg.DATA.VAL.GT_PATH, self.cfg.DATA.VAL.INSTANCE_CHANNELS_MASK_DIR)) opts.extend(['DATA.VAL.PATH', self.cfg.DATA.VAL.INSTANCE_CHANNELS_DIR, 'DATA.VAL.GT_PATH', self.cfg.DATA.VAL.INSTANCE_CHANNELS_MASK_DIR]) if self.cfg.TEST.ENABLE and not self.cfg.DATA.TEST.USE_VAL_AS_TEST and not self.cfg.TEST.BY_CHUNKS.ENABLE: print("DATA.TEST.PATH changed from {} to {}".format(self.cfg.DATA.TEST.PATH, self.cfg.DATA.TEST.INSTANCE_CHANNELS_DIR)) opts.extend(['DATA.TEST.PATH', self.cfg.DATA.TEST.INSTANCE_CHANNELS_DIR]) if self.cfg.DATA.TEST.LOAD_GT: print("DATA.TEST.GT_PATH changed from {} to {}".format(self.cfg.DATA.TEST.GT_PATH, self.cfg.DATA.TEST.INSTANCE_CHANNELS_MASK_DIR)) opts.extend(['DATA.TEST.GT_PATH', self.cfg.DATA.TEST.INSTANCE_CHANNELS_MASK_DIR]) original_test_path = self.cfg.DATA.TEST.PATH original_test_mask_path = self.cfg.DATA.TEST.GT_PATH self.cfg.merge_from_list(opts) return original_test_path, original_test_mask_path
[docs] def torchvision_model_call(self, in_img, is_train=False): """ Call a regular Pytorch model. Parameters ---------- in_img : Tensor Input image to pass through the model. is_train : bool, optional Whether if the call is during training or inference. Returns ------- prediction : Tensor Image prediction. """ filename, file_extension = os.path.splitext(self.processing_filenames[0]) # Convert first to 0-255 range if uint16 if in_img.dtype == torch.float32: if torch.max(in_img) > 1: in_img = (norm_range01(in_img, torch.uint8)[0]*255).to(torch.uint8) in_img = in_img.to(torch.uint8) # Apply TorchVision pre-processing in_img = self.torchvision_preprocessing(in_img) pred = self.model(in_img) masks = pred[0]['masks'].cpu().numpy().transpose(0,2,3,1) if masks.shape[0] != 0: masks = np.argmax(pred[0]['masks'].cpu().numpy().transpose(0,2,3,1), axis=0) else: masks = torch.ones((1,)+pred[0]['masks'].cpu().numpy().transpose(0,2,3,1).shape[1:], dtype=torch.uint8) if not is_train and masks.shape[0] != 0: # Extract each output from MaskRCNN bboxes = pred[0]['boxes'].cpu().numpy().astype(np.uint16) labels = pred[0]['labels'].cpu().numpy() scores = pred[0]['scores'].cpu().numpy() # Save all info in a csv file df = pd.DataFrame(zip(labels, scores, bboxes[:,0],bboxes[:,1],bboxes[:,2],bboxes[:,3]), columns = ['label', 'scores', 'x1', 'y1', 'x2', 'y2']) df = df.sort_values(by=['label']) df.to_csv(os.path.join(self.cfg.PATHS.RESULT_DIR.FULL_IMAGE, filename+".csv"), index=False) # Save masks save_tif(np.expand_dims(masks,0), self.cfg.PATHS.RESULT_DIR.FULL_IMAGE, self.processing_filenames, verbose=self.cfg.TEST.VERBOSE) # Actually to allow training this should return a Torch Tensor and not a Numpy array. For instance, # segmentation training is disabled, due to the absence of generators that contain bboxes, so this can be left # returning a Numpy array. This will only be called in process_sample inference function and for full image setting return np.expand_dims(masks.squeeze(),0)