# Source code for biapy.models.efficientnet

import torchvision.models as models
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url

import torch.nn as nn

def efficientnet(efficientnet_name, image_shape, n_classes=2, load_imagenet_weights=True):
    """
    Create EfficientNet.

    Reference: `EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks <https://arxiv.org/abs/1905.11946>`_.

    Parameters
    ----------
    efficientnet_name : str
        Efficientnet model name to be loaded. Available options: "efficientnet_b[0-7]"

    image_shape : 2D tuple
        Dimensions of the input image. Not read by this function.

    n_classes: int, optional
        Number of classes, i.e. the output size of the new classification head.

    load_imagenet_weights : bool, optional
        If ``True`` the model is built with the ``"IMAGENET1K_V1"`` weights;
        otherwise it is built without pretrained weights.

    Returns
    -------
    model : Torch model
        EfficientNet model.
    """

    def get_state_dict(self, *args, **kwargs):
        # Drop ``check_hash`` before delegating so the download is not
        # hash-verified. A default is supplied so that callers which do not
        # pass ``check_hash`` at all no longer trigger a KeyError.
        # NOTE(review): this disables integrity checking of downloaded weights
        # and is presumably a workaround for a checksum mismatch -- confirm it
        # is still required for the torchvision version in use.
        kwargs.pop("check_hash", None)
        return load_state_dict_from_url(self.url, *args, **kwargs)

    # This replaces the method on the shared base class, so it affects every
    # torchvision weights enum in the process, not only EfficientNet.
    WeightsEnum.get_state_dict = get_state_dict

    weights = "IMAGENET1K_V1" if load_imagenet_weights else None
    model = getattr(models, efficientnet_name)(weights=weights)

    # Change the final classification head. The input width is taken from the
    # existing head rather than hard-coded, because it differs between
    # variants (1280 is only valid for the smallest models).
    head_in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features=head_in_features, out_features=n_classes)

    return model