# Source code for biapy.utils.callbacks

import os
import time
import warnings
import numpy as np
import math
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from import imsave, imread

import numpy as np
import torch

class EarlyStopping:
    """
    Early stops the training if validation loss doesn't improve after a given patience.

    Copied from an external source (the reference link is missing in this copy
    of the file).

    The instance is called once per validation step with the current
    validation loss. Internally the loss is negated so that "higher is better";
    a call counts as an improvement when the negated loss is not below
    ``best_score + delta``.

    Attributes set by the instance and readable by callers:

    * ``counter``: number of consecutive calls without improvement.
    * ``best_score``: best negated validation loss seen so far (``None``
      before the first call).
    * ``early_stop``: becomes ``True`` once ``counter`` reaches ``patience``.
    * ``val_loss_min``: validation loss associated with ``best_score``.
    """

    def __init__(self, patience=7, delta=0, trace_func=print):
        """
        Parameters
        ----------
        patience : int, optional
            How long to wait after last time validation loss improved.

        delta : float, optional
            Minimum change in the monitored quantity to qualify as an improvement.

        trace_func : function, optional
            Trace print function. It is called with a single message string
            each time the no-improvement counter increases.
        """
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.trace_func = trace_func

    def __call__(self, val_loss):
        """
        Update the early-stopping state with a new validation loss.

        Parameters
        ----------
        val_loss : float
            Validation loss measured for the current epoch/step.
        """
        # Negate so that a larger score means a better (lower) loss.
        score = -val_loss

        if self.best_score is None:
            # First measurement: it is the best seen so far by definition.
            self.best_score = score
            self.val_loss_min = val_loss
        elif score < self.best_score + self.delta:
            # No (sufficient) improvement: count it and stop once patience runs out.
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: remember it and restart the patience counter.
            self.best_score = score
            self.val_loss_min = val_loss
            self.counter = 0
# class MAE_callback():
#     """ Save MAE process in selected epochs. """
#     def __init__(self, test_gen, epoch_interval=None, out_dir="MAE_checks"):
#         self.epoch_interval = epoch_interval
#         self.test_gen = test_gen
#         self.out_dir = out_dir
#
#     def on_epoch_end(self, epoch, logs=None):
#         if self.epoch_interval and epoch != 0 and epoch % self.epoch_interval == 0:
#             test_images = next(iter(self.test_gen))
#             idx = np.random.choice(test_images.shape[0])
#             cmap = 'gray' if test_images[idx].shape[-1] == 1 else ''
#             test_images_patches = self.model.patchify(test_images)[idx]
#             latent, mask, ids_restore, unmasked_ids, masked_ids = self.model.encoder(test_images, training=False)
#             reconstructed = self.model.decoder(latent, training=False, ids_restore=ids_restore)[idx]
#             masked_img = self.model.generate_masked_image(test_images_patches, unmasked_ids[idx])
#
#             fig = plt.figure(figsize=(30, 5), constrained_layout=True)
#             grid = gridspec.GridSpec(1, 5, figure=fig)
#             ax1 = fig.add_subplot(grid[0])
#             ax1.imshow(test_images[idx], cmap=cmap)
#             ax1.set_title(f"Original: {epoch:03d}")
#
#             rows, cols = self.model.encoder.patch_embed.grid_size
#             patch_size = self.model.encoder.patch_embed.patch_size[0]
#             gs00 = gridspec.GridSpecFromSubplotSpec(rows, cols, subplot_spec=grid[1])
#             for i in range(rows):
#                 for j in range(cols):
#                     ax = fig.add_subplot(gs00[i, j], xticklabels=[],yticklabels=[])
#                     ax.imshow(tf.reshape(test_images_patches[(i*cols)+j],(patch_size, patch_size)), cmap=cmap)
#                     ax.set_xticks([])
#                     ax.set_yticks([])
#
#             ax3 = fig.add_subplot(grid[2])
#             ax3.imshow(self.model.unpatchify(test_images_patches), cmap=cmap)
#             ax3.set_title(f"Unpatchify: {epoch:03d}")
#
#             ax4 = fig.add_subplot(grid[3])
#             ax4.imshow(self.model.unpatchify(masked_img), cmap=cmap)
#             ax4.set_title(f"Masked: {epoch:03d}")
#
#             ax5 = fig.add_subplot(grid[4])
#             ax5.imshow(self.model.unpatchify(reconstructed), cmap=cmap)
#             ax5.set_title(f"Reconstructed: {epoch:03d}")
#
#             os.makedirs(self.out_dir, exist_ok=True)
#             f = os.path.join(self.out_dir, "img_{}.png".format(epoch))
#             plt.savefig(f)
#             plt.close()
#             print("Saving MAE callback image . . .")