# Source code for biapy.models.dfcan

# Adapted from https://github.com/L0-zhang/DFCAN-pytorch

import torch
import torch.nn as nn
import torch.fft

def fftshift2d(img, size_psc=128):
    """
    Swap the quadrants of the last two dimensions of a tensor.

    Each of the last two axes of size ``n`` is rolled circularly by
    ``-(n // 2)``, so the element at index ``n // 2`` ends up at index 0.
    For even sizes this is the usual FFT shift. For odd sizes the result
    matches ``torch.fft.ifftshift`` rather than ``torch.fft.fftshift``
    (despite the function name).

    Parameters
    ----------
    img : torch.Tensor
        Tensor with at least two dimensions. The callers in this module pass
        4-D tensors, but any leading dimensions are left untouched.

    size_psc : int, optional
        Not used by the computation. Kept so existing callers that pass it
        keep working.

    Returns
    -------
    torch.Tensor
        New tensor with the same shape and dtype as ``img``.
    """
    h, w = img.shape[-2:]
    # A single circular roll replaces the manual four-quadrant slicing and
    # concatenation while producing exactly the same element layout.
    return torch.roll(img, shifts=(-(h // 2), -(w // 2)), dims=(-2, -1))
class RCAB(nn.Module):
    """
    Residual block with Fourier channel attention.

    Two ``Conv2d + GELU`` layers produce a feature map. A per-channel gate is
    then computed from the magnitude of that map's 2-D FFT and used to rescale
    the features before they are added back to the block input.

    Parameters
    ----------
    size_psc : int, optional
        Value forwarded to ``fftshift2d`` (the original code comment calls it
        the crop size). It does not affect any layer shape in this class.

    n_feats : int, optional
        Number of channels the block receives and returns. Defaults to ``64``.

    reduction : int, optional
        Channel reduction factor of the attention bottleneck; the squeezed
        width is ``n_feats // reduction``. The default ``16`` gives ``4``
        channels for the default ``n_feats``.

    Raises
    ------
    ValueError
        If ``n_feats`` or ``reduction`` is smaller than 1, or the squeezed
        width ``n_feats // reduction`` would be zero.
    """

    def __init__(self, size_psc=128, n_feats=64, reduction=16):
        super().__init__()
        if n_feats < 1 or reduction < 1 or n_feats // reduction < 1:
            raise ValueError(
                "n_feats and reduction must be >= 1 and n_feats // reduction must be >= 1, "
                f"got n_feats={n_feats}, reduction={reduction}"
            )
        squeezed = n_feats // reduction
        self.size_psc = size_psc

        # Attribute names and creation order are kept as-is so parameter
        # initialisation and state_dict keys do not change.
        self.conv_gelu1 = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, kernel_size=3, stride=1, padding="same"),
            nn.GELU(),
        )
        self.conv_gelu2 = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, kernel_size=3, stride=1, padding="same"),
            nn.GELU(),
        )
        self.conv_relu1 = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, kernel_size=3, stride=1, padding="same"),
            nn.ReLU(),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_relu2 = nn.Sequential(
            nn.Conv2d(n_feats, squeezed, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        )
        self.conv_sigmoid = nn.Sequential(
            nn.Conv2d(squeezed, n_feats, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, x, gamma=0.8):
        """
        Apply the block.

        Parameters
        ----------
        x : torch.Tensor
            4-D input whose dimension 1 has ``n_feats`` channels; the FFT is
            taken over dimensions 2 and 3.

        gamma : float, optional
            Exponent applied to the FFT magnitude before the attention
            branch.

        Returns
        -------
        torch.Tensor
            ``float32`` tensor: the input plus the gated features.
        """
        identity = x.to(torch.float32)

        feats = self.conv_gelu1(x)
        feats = self.conv_gelu2(feats)
        # The FFT and the final arithmetic are done in float32.
        feats = feats.to(torch.float32)

        spectrum = torch.fft.fftn(feats, dim=(2, 3))
        # The small epsilon keeps the base of the power strictly positive.
        spectrum = torch.pow(torch.abs(spectrum) + 1e-8, gamma)
        spectrum = fftshift2d(spectrum, self.size_psc)

        # Squeeze the spectrum to one value per channel, then map it to a
        # sigmoid gate through the 1x1 bottleneck.
        gate = self.conv_relu1(spectrum)
        gate = self.avg_pool(gate)
        gate = self.conv_relu2(gate)
        gate = self.conv_sigmoid(gate)

        return identity + feats * gate
class ResGroup(nn.Module):
    """
    Chain of ``RCAB`` blocks wrapped in a skip connection.

    Parameters
    ----------
    n_RCAB : int, optional
        Number of ``RCAB`` blocks applied one after another.

    size_psc : int, optional
        Value handed to every ``RCAB`` constructor.
    """

    def __init__(self, n_RCAB=4, size_psc=128):
        super().__init__()
        self.RCABs = nn.Sequential(*(RCAB(size_psc) for _ in range(n_RCAB)))

    def forward(self, x):
        """
        Run the block chain and add the group input to its output.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, passed unchanged to the first ``RCAB``.

        Returns
        -------
        torch.Tensor
            ``x`` plus the output of the ``RCAB`` chain.
        """
        return x + self.RCABs(x)
class DFCAN(nn.Module):
    """
    Fourier channel attention network (DFCAN) for super-resolution.

    Reference: `Evaluation and development of deep neural networks for image super-resolution in optical
    microscopy <https://www.nature.com/articles/s41592-020-01048-5>`_.

    Parameters
    ----------
    ndim : int
        Stored as ``self.ndim``; not otherwise used by the layers built here,
        which are all 2-D.

    input_shape : sequence of int
        ``input_shape[-1]`` is used as the number of input and output image
        channels and ``input_shape[0]`` is passed down as ``size_psc``.
        Presumably laid out as ``(y, x, channels)`` -- confirm against the
        caller.

    scale : int or tuple/list of int, optional
        Upsampling factor given to ``nn.PixelShuffle``. When a tuple or list
        is given only its first element is used.

    n_ResGroup : int, optional
        Number of ``ResGroup`` modules chained in the body.

    n_RCAB : int, optional
        Number of ``RCAB`` blocks inside each ``ResGroup``.
    """

    def __init__(self, ndim, input_shape, scale=2, n_ResGroup=4, n_RCAB=4):
        super().__init__()
        # isinstance also accepts lists (e.g. values loaded from a config
        # file) and tuple subclasses, which an exact type comparison rejects.
        if isinstance(scale, (tuple, list)):
            scale = scale[0]
        self.ndim = ndim
        size_psc = input_shape[0]
        channels = input_shape[-1]

        # Attribute names and creation order are kept as-is so parameter
        # initialisation and state_dict keys do not change.
        self.input = nn.Sequential(
            nn.Conv2d(channels, 64, kernel_size=3, stride=1, padding="same"),
            nn.GELU(),
        )
        self.RGs = nn.Sequential(
            *[ResGroup(n_RCAB=n_RCAB, size_psc=size_psc) for _ in range(n_ResGroup)]
        )
        # PixelShuffle folds scale**2 channel groups into space, so the
        # convolution in front of it must emit 64 * scale**2 channels for
        # 64 channels to remain afterwards.
        self.conv_gelu = nn.Sequential(
            nn.Conv2d(64, 64 * (scale ** 2), kernel_size=3, stride=1, padding="same"),
            nn.GELU(),
        )
        self.pixel_shuffle = nn.PixelShuffle(scale)
        self.conv_sigmoid = nn.Sequential(
            nn.Conv2d(64, channels, kernel_size=3, stride=1, padding="same"),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        Super-resolve a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            4-D input whose dimension 1 has ``input_shape[-1]`` channels.

        Returns
        -------
        torch.Tensor
            Tensor with ``input_shape[-1]`` channels whose two trailing
            dimensions are ``scale`` times larger than those of ``x``. Values
            lie in ``(0, 1)`` because the last layer is a sigmoid.
        """
        x = self.input(x)
        x = self.RGs(x)
        x = self.conv_gelu(x)
        x = self.pixel_shuffle(x)  # upsampling
        return self.conv_sigmoid(x)