Self-Supervised Workflow
- class biapy.engine.self_supervised.Self_supervised_Workflow(cfg, job_identifier, device, args, **kwargs)
Bases: Base_Workflow
Self-supervised workflow whose goal is to pretrain the backbone model by solving a so-called pretext task without labels. This way, the model learns a representation that can later be transferred to solve a downstream task on a labeled (but smaller) dataset. More details are available in our documentation.
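As a rough illustration of what "solving a pretext task without labels" means, the sketch below builds one training pair for a masking-style pretext task: the model would receive a partially masked image and be asked to reconstruct the original, so the target comes from the data itself. The helper name `make_masked_pair`, the zero-filling strategy and the `mask_ratio` parameter are hypothetical choices for illustration; this is not BiaPy's implementation.

```python
import random
from typing import List, Tuple

Image = List[List[float]]


def make_masked_pair(image: Image, mask_ratio: float = 0.5, seed: int = 0) -> Tuple[Image, Image]:
    """Hypothetical helper: build an (input, target) pair for a masking pretext task.

    The input is the image with a random subset of pixels set to zero; the target
    is an untouched copy of the original, so no manual annotation is needed.
    """
    rng = random.Random(seed)
    h, w = len(image), len(image[0])
    coords = [(r, c) for r in range(h) for c in range(w)]
    n_masked = int(round(mask_ratio * len(coords)))
    masked = set(rng.sample(coords, n_masked))  # pixels hidden from the model
    corrupted = [
        [0.0 if (r, c) in masked else image[r][c] for c in range(w)]
        for r in range(h)
    ]
    target = [row[:] for row in image]  # the "label" is the original image itself
    return corrupted, target


img = [[1.0, 2.0], [3.0, 4.0]]
x, y = make_masked_pair(img, mask_ratio=0.5)  # x has 2 of its 4 pixels zeroed
```

A model trained to map `x` back to `y` must learn something about image structure, which is the representation later reused for the downstream task.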
- Parameters:
cfg (YACS configuration) – Running configuration.
job_identifier (str) – Complete name of the running job.
device (Torch device) – Device used.
args (argparse class) – Arguments used in BiaPy’s call.
- metric_calculation(output, targets, metric_logger=None)
Execute the metrics defined in the define_metrics() function.
- Parameters:
output (Torch Tensor) – Prediction of the model.
targets (Torch Tensor) – Ground truth to compare the prediction with.
metric_logger (MetricLogger, optional) – Logger to update with the newly calculated metric value(s).
- Returns:
value – Value of the metric for the given prediction.
- Return type:
float
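A minimal sketch of what such a method does, assuming PSNR as the metric and a plain dict standing in for the MetricLogger. Both are assumptions for illustration: the function names are hypothetical, and BiaPy's actual MetricLogger interface is not reproduced here.

```python
import math
from typing import Dict, List, Optional


def psnr(output: List[float], targets: List[float], max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio between two flattened images with values in [0, max_val]."""
    mse = sum((o - t) ** 2 for o, t in zip(output, targets)) / len(output)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


def metric_calculation_sketch(
    output: List[float],
    targets: List[float],
    metric_logger: Optional[Dict[str, List[float]]] = None,
) -> float:
    """Hypothetical stand-in: compute the metric and, if given, record it in the logger."""
    value = psnr(output, targets)
    if metric_logger is not None:
        metric_logger.setdefault("PSNR", []).append(value)
    return value


logger: Dict[str, List[float]] = {}
v = metric_calculation_sketch([0.0, 0.5], [0.1, 0.5], logger)  # about 23.01 dB
```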
- prepare_targets(targets, batch)
Location to perform any necessary data transformations to targets before calculating the loss.
- Parameters:
targets (Torch Tensor) – Ground truth to compare the prediction with.
batch (Torch Tensor) – Prediction of the model.
- Returns:
targets – Resulting targets.
- Return type:
Torch Tensor
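For illustration only, the sketch below shows the kind of transformation that could live in this hook: rescaling 8-bit ground-truth images to the [0, 1] range so they match a model output in that range. The function name and the rescaling step are assumptions, not what BiaPy's self-supervised workflow actually does.

```python
from typing import List, Optional


def prepare_targets_sketch(
    targets: List[List[int]], batch: Optional[object] = None, max_val: float = 255.0
) -> List[List[float]]:
    """Hypothetical transformation: rescale 8-bit ground truth to [0, 1].

    `batch` is accepted only to mirror the documented signature; it is unused here.
    """
    return [[v / max_val for v in row] for row in targets]


scaled = prepare_targets_sketch([[0, 255]])  # [[0.0, 1.0]]
```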
- process_sample(norm)
Function to process a sample in the inference phase.
- Parameters:
norm (List of dicts) – Normalization used during training. Required to denormalize the predictions of the model.
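To show why the normalization used during training is needed at inference, here is a minimal sketch that inverts a zero-mean/unit-variance normalization, x = x_norm · std + mean. It assumes a single entry of the `norm` list is a dict with "mean" and "std" keys; that layout is hypothetical and may differ from what BiaPy stores.

```python
from typing import Dict, List


def denormalize(pred: List[float], norm: Dict[str, float]) -> List[float]:
    """Invert a zero-mean / unit-variance normalization: x = x_norm * std + mean.

    The "mean"/"std" keys are a hypothetical layout for one normalization dict.
    """
    return [p * norm["std"] + norm["mean"] for p in pred]


stats = {"mean": 100.0, "std": 20.0}
print(denormalize([-1.0, 0.0, 0.5], stats))  # [80.0, 100.0, 110.0]
```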
- torchvision_model_call(in_img, is_train=False)
Call a regular PyTorch model.
- Parameters:
in_img (Tensor) – Input image to pass through the model.
is_train (bool, optional) – Whether the call is made during training or inference.
- Returns:
prediction – Image prediction.
- Return type:
Tensor
- after_merge_patches(pred)
Steps that need to be done after merging all predicted patches into the original image.
- Parameters:
pred (Torch Tensor) – Model prediction.
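This hook runs once the patch predictions have been reassembled. To make "merging patches into the original image" concrete, here is a minimal sketch that stitches non-overlapping 2D patches; BiaPy's own merging (which may deal with overlap, padding and N-D data) is not reproduced, and the dict-of-corners input layout is an assumption.

```python
from typing import Dict, List, Tuple

Patch = List[List[float]]


def merge_patches(patches: Dict[Tuple[int, int], Patch], image_shape: Tuple[int, int]) -> List[List[float]]:
    """Stitch non-overlapping 2D patches back into a full image.

    `patches` maps the (row, col) of each patch's top-left corner to its values.
    """
    h, w = image_shape
    merged = [[0.0] * w for _ in range(h)]  # one independent row list per image row
    for (r0, c0), patch in patches.items():
        for dr, row in enumerate(patch):
            for dc, v in enumerate(row):
                merged[r0 + dr][c0 + dc] = v
    return merged


tiles = {(0, 0): [[1.0, 2.0], [3.0, 4.0]], (0, 2): [[5.0, 6.0], [7.0, 8.0]]}
full = merge_patches(tiles, (2, 4))  # [[1.0, 2.0, 5.0, 6.0], [3.0, 4.0, 7.0, 8.0]]
```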
- after_merge_patches_by_chunks_proccess_patch(filename)
Place any code that needs to be executed after merging all predicted patches into the original image when the merge is done chunk by chunk. This function operates patch by patch, with patches defined by DATA.PATCH_SIZE.
- Parameters:
filename (List of str) – Filename of the predicted H5/Zarr image.
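A simplified picture of "operating patch by patch" over a large H5/Zarr image: the sketch below only computes the 2D windows, sized like the spatial part of DATA.PATCH_SIZE, that such a loop would visit. Reading and writing the file itself is left out, and the function name and the truncation of edge patches are assumptions.

```python
from typing import Iterator, Tuple


def iter_patch_slices(image_shape: Tuple[int, int], patch_size: Tuple[int, int]) -> Iterator[Tuple[slice, slice]]:
    """Yield (row_slice, col_slice) pairs that tile a 2D image patch by patch.

    Edge patches are truncated when the image size is not a multiple of the patch size.
    """
    h, w = image_shape
    ph, pw = patch_size
    for r in range(0, h, ph):
        for c in range(0, w, pw):
            yield slice(r, min(r + ph, h)), slice(c, min(c + pw, w))


windows = list(iter_patch_slices((5, 4), (2, 4)))  # 3 windows; the last covers only row 4
```

Each yielded pair could index an array-like dataset (e.g. `dataset[rows, cols]`) so that only one patch is held in memory at a time.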
- after_full_image(pred)
Steps that must be executed after generating the prediction by supplying the entire image to the model.
- Parameters:
pred (Torch Tensor) – Model prediction.
- normalize_stats(image_counter)
Normalize statistics.
- Parameters:
image_counter (int) – Number of images over which to average the metrics.
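Assuming the workflow accumulates a running sum of each metric over the processed images, normalizing the statistics amounts to dividing each sum by image_counter. A minimal sketch with a hypothetical dict of accumulated sums (the function name and data layout are illustrative, not BiaPy's):

```python
from typing import Dict


def normalize_stats_sketch(accumulated: Dict[str, float], image_counter: int) -> Dict[str, float]:
    """Turn per-metric sums accumulated over `image_counter` images into averages."""
    if image_counter <= 0:
        raise ValueError("image_counter must be positive")
    return {name: total / image_counter for name, total in accumulated.items()}


averages = normalize_stats_sketch({"PSNR": 60.0, "MSE": 0.3}, 3)
```

Guarding against a zero counter avoids a division error when no images were processed.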