"""Source code for ``biapy.models.wdsr`` (WDSR super-resolution model)."""

import math
import torch
import torch.nn as nn
import torch.nn.init as init

class wdsr(nn.Module):
    """WDSR super-resolution model.

    Reference: `Wide Activation for Efficient and Accurate Image
    Super-Resolution <https://arxiv.org/abs/1808.08718>`_.

    Adapted from `here <https://github.com/yjn870/WDSR-pytorch/tree/master>`_.

    Parameters
    ----------
    scale : int or tuple of int
        Upscaling factor. If a tuple is given only the first element is used
        (isotropic scaling is assumed).
    num_filters : int, optional
        Number of feature channels in the residual body.
    num_res_blocks : int, optional
        Number of residual blocks in the body.
    res_block_expansion : int, optional
        Channel expansion factor inside each residual block ("wide activation").
    num_channels : int, optional
        Number of channels of the input image.
    """

    def __init__(self, scale, num_filters=32, num_res_blocks=16, res_block_expansion=6, num_channels=1):
        super().__init__()
        # Fix: isinstance() instead of `type(scale) is tuple` — idiomatic and
        # also accepts tuple subclasses (e.g. named tuples).
        if isinstance(scale, tuple):
            scale = scale[0]

        kernel_size = 3
        skip_kernel_size = 5
        # NOTE(review): torch.nn.utils.weight_norm is deprecated in recent
        # torch in favor of torch.nn.utils.parametrizations.weight_norm, but
        # the new API does not expose the `weight_g` attribute that the init
        # calls below rely on — keep the legacy function.
        weight_norm = torch.nn.utils.weight_norm
        # PixelShuffle needs scale*scale sub-pixel maps per output channel.
        num_outputs = scale * scale * num_channels

        # --- Residual body: head conv -> N residual blocks -> tail conv ---
        body = []
        conv = weight_norm(
            nn.Conv2d(num_channels, num_filters, kernel_size, padding=kernel_size // 2)
        )
        init.ones_(conv.weight_g)
        init.zeros_(conv.bias)
        body.append(conv)
        for _ in range(num_res_blocks):
            body.append(
                Block(
                    num_filters,
                    kernel_size,
                    res_block_expansion,
                    weight_norm=weight_norm,
                    # Scale each residual so the accumulated variance over
                    # num_res_blocks additions stays bounded.
                    res_scale=1 / math.sqrt(num_res_blocks),
                )
            )
        conv = weight_norm(
            nn.Conv2d(num_filters, num_outputs, kernel_size, padding=kernel_size // 2)
        )
        init.ones_(conv.weight_g)
        init.zeros_(conv.bias)
        body.append(conv)
        self.body = nn.Sequential(*body)

        # --- Global skip path: a 5x5 conv adapts channel count if needed,
        # otherwise the skip is an identity (empty Sequential). ---
        skip = []
        if num_channels != num_outputs:
            conv = weight_norm(
                nn.Conv2d(num_channels, num_outputs, skip_kernel_size, padding=skip_kernel_size // 2)
            )
            init.ones_(conv.weight_g)
            init.zeros_(conv.bias)
            skip.append(conv)
        self.skip = nn.Sequential(*skip)

        # --- Sub-pixel upsampling (no-op when scale == 1). ---
        shuf = []
        if scale > 1:
            shuf.append(nn.PixelShuffle(scale))
        self.shuf = nn.Sequential(*shuf)

    def forward(self, x):
        """Apply the residual body plus the global skip, then pixel-shuffle.

        Parameters
        ----------
        x : torch.Tensor
            Input batch; assumed (N, num_channels, H, W) — standard Conv2d layout.

        Returns
        -------
        torch.Tensor
            The upscaled output.
        """
        x = self.body(x) + self.skip(x)
        x = self.shuf(x)
        return x
class Block(nn.Module):
    """Wide-activation residual block: expand -> ReLU -> contract, plus identity.

    Parameters
    ----------
    num_residual_units : int
        Channel count at the block's input and output.
    kernel_size : int
        Convolution kernel size (same padding is applied).
    width_multiplier : float, optional
        Expansion factor for the intermediate (wide) channels.
    weight_norm : callable, optional
        Weight-normalization wrapper applied to each convolution.
    res_scale : float, optional
        Initial weight-norm magnitude of the contracting convolution,
        effectively scaling the residual branch at initialization.
    """

    def __init__(self, num_residual_units, kernel_size, width_multiplier=1, weight_norm=torch.nn.utils.weight_norm, res_scale=1):
        super().__init__()
        wide_channels = int(num_residual_units * width_multiplier)
        pad = kernel_size // 2

        # Expanding convolution; weight magnitudes start at 2.0.
        expand = weight_norm(
            nn.Conv2d(num_residual_units, wide_channels, kernel_size, padding=pad)
        )
        init.constant_(expand.weight_g, 2.0)
        init.zeros_(expand.bias)

        # Contracting convolution; weight magnitudes start at res_scale.
        contract = weight_norm(
            nn.Conv2d(wide_channels, num_residual_units, kernel_size, padding=pad)
        )
        init.constant_(contract.weight_g, res_scale)
        init.zeros_(contract.bias)

        self.body = nn.Sequential(expand, nn.ReLU(True), contract)

    def forward(self, x):
        """Return the body's output added to the identity shortcut."""
        return self.body(x) + x