# Source code for biapy.models.simple_cnn

"""
This module implements a simple Convolutional Neural Network (CNN) for image classification tasks. It is designed to be a straightforward and adaptable model for both 2D and 3D image inputs.

The `simple_CNN` class constructs a network composed of two convolutional
blocks. Each block stacks three convolutions interleaved with batch
normalization and activation, and includes one max-pooling step and a final
dropout. A final dense layer with Softmax activation is used for classification.

The architecture is flexible, automatically adapting to 2D or 3D input based
on the provided `image_shape`.

Classes:

- ``simple_CNN``: The main class for creating the simple CNN model.

This module uses a helper function `get_activation` from `biapy.models.blocks`
to dynamically select the activation function.
"""
from typing import Dict

import torch
import torch.nn as nn

from biapy.models.blocks import get_activation


class simple_CNN(nn.Module):
    """
    Create a simple Convolutional Neural Network (CNN) model.

    This CNN architecture is designed for classification tasks and can handle
    both 2D and 3D image inputs. It consists of two convolutional blocks (each
    with three convolutions, one max-pooling step and dropout), culminating in
    a fully connected layer with Softmax for classification.

    Parameters
    ----------
    image_shape : Tuple[int, ...]
        Dimensions of the input image, channels last.

        - For 2D: `(height, width, channels)`
        - For 3D: `(depth, height, width, channels)`

        The last element `image_shape[-1]` is used as the number of input
        channels. Every spatial dimension must be at least 4, because the
        network halves the spatial size twice.

    activation : str or nn.Module, optional
        Activation to use within the convolutional blocks. A string is
        resolved through ``get_activation``; an ``nn.Module`` instance is used
        as-is. Defaults to "ReLU".

    n_classes : int, optional
        Number of output classes for the classification task. Defaults to 2.

    Attributes
    ----------
    ndim : int
        Number of spatial dimensions of the input (2 or 3).
    block1, block2 : nn.Sequential
        The two convolutional blocks (32 and 64 feature maps respectively).
    heads : nn.Sequential
        Flatten + dropout + linear + Softmax classification head.

    Raises
    ------
    ValueError
        If `image_shape` does not have 3 or 4 elements, or if any spatial
        dimension is smaller than 4.
    """

    def __init__(self, image_shape, activation="ReLU", n_classes=2):
        """
        Initialize the simple CNN model.

        Sets up the convolutional layers, batch normalization, pooling,
        dropout, and the final classification head based on the input image
        dimensions and specified parameters. It dynamically selects 2D or 3D
        layers from the length of `image_shape`.

        Parameters
        ----------
        image_shape : Tuple[int, ...]
            Dimensions of the input image, channels last.

            - For 2D: `(height, width, channels)`
            - For 3D: `(depth, height, width, channels)`

        activation : str or nn.Module, optional
            Name of the activation layer to use (e.g., "ReLU", "ELU", "SiLU"),
            or an already constructed activation module. Defaults to "ReLU".

        n_classes : int, optional
            Number of output classes for the classification task.
            Defaults to 2.

        Raises
        ------
        ValueError
            If `image_shape` does not have 3 or 4 elements, or if any spatial
            dimension is smaller than 4.
        """
        super().__init__()

        # Only (H, W, C) and (D, H, W, C) are meaningful here; anything else
        # used to be silently treated as 2D and produced a mis-sized model.
        if len(image_shape) not in (3, 4):
            raise ValueError(
                "image_shape must be (height, width, channels) or "
                "(depth, height, width, channels), got {}".format(tuple(image_shape))
            )

        spatial_dims = tuple(image_shape[:-1])
        # Two pool(2) layers need at least 4 elements along every spatial axis;
        # otherwise the linear layer below would be built with 0 input features.
        if any(size < 4 for size in spatial_dims):
            raise ValueError(
                "every spatial dimension of image_shape must be >= 4, got {}".format(spatial_dims)
            )

        self.ndim = 3 if len(image_shape) == 4 else 2
        if self.ndim == 3:
            conv = nn.Conv3d
            batchnorm_layer = nn.BatchNorm3d
            pool = nn.MaxPool3d
        else:
            conv = nn.Conv2d
            batchnorm_layer = nn.BatchNorm2d
            pool = nn.MaxPool2d

        first_block_features = 32
        second_block_features = 64

        # A ready-made module is used directly; a name is resolved by the
        # project helper.
        # NOTE(review): the same activation instance is inserted at every
        # position below, so a parametric activation would share its
        # parameters across layers.
        if not isinstance(activation, nn.Module):
            activation = get_activation(activation)

        # Block 1 (conv -> batchnorm -> activation ordering).
        # The layer order inside each nn.Sequential is kept exactly as-is so
        # that state_dict keys stay compatible with existing checkpoints.
        self.block1 = nn.Sequential(
            conv(image_shape[-1], first_block_features, kernel_size=3, padding="same"),
            batchnorm_layer(first_block_features),
            activation,
            conv(first_block_features, first_block_features, kernel_size=3, padding="same"),
            batchnorm_layer(first_block_features),
            activation,
            conv(first_block_features, first_block_features, kernel_size=5, padding="same"),
            pool(2),
            batchnorm_layer(first_block_features),
            activation,
            nn.Dropout(0.4),
        )

        # Block 2 (conv -> activation -> batchnorm ordering; differs from
        # block 1 and is preserved deliberately for checkpoint compatibility).
        self.block2 = nn.Sequential(
            conv(first_block_features, second_block_features, kernel_size=3, padding="same"),
            activation,
            batchnorm_layer(second_block_features),
            conv(second_block_features, second_block_features, kernel_size=3, padding="same"),
            activation,
            batchnorm_layer(second_block_features),
            conv(second_block_features, second_block_features, kernel_size=5, padding="same"),
            pool(2),
            activation,
            batchnorm_layer(second_block_features),
            nn.Dropout(0.4),
        )

        # Classification head. "same"-padded convolutions keep the spatial
        # size, and each pool(2) floors it by half, so every spatial axis ends
        # up at size // 4 after the two blocks.
        flat_features = second_block_features
        for size in spatial_dims:
            flat_features *= size // 4

        self.heads = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(flat_features, n_classes),
            nn.Softmax(dim=1),
        )

    def forward(self, x) -> torch.Tensor:
        """
        Perform the forward pass of the simple CNN model.

        The input `x` passes sequentially through `block1`, `block2`, and then
        `heads`, which flattens the features and applies a linear layer with
        Softmax for classification.

        Parameters
        ----------
        x : torch.Tensor
            The input image tensor, channels first.
            Expected shape for 2D: `(batch_size, channels, height, width)`.
            Expected shape for 3D: `(batch_size, channels, depth, height, width)`.

        Returns
        -------
        torch.Tensor
            Class probabilities of shape `(batch_size, n_classes)`; Softmax is
            applied along dimension 1.
        """
        out = self.block1(x)
        out = self.block2(out)
        out = self.heads(out)
        return out