Source code for model.dim3.conv_layers

import torch
import torch.nn as nn
import torch.nn.functional as F
from .trans_layers import LayerNorm
import pdb


# Names exported by a star-import of this module. SingleConv, SEBlock,
# DropPath, MBConv and FusedMBConv are also defined in this module but are
# not listed here, so they have to be imported explicitly by name.
__all__ = [
    'ConvNormAct',
    'BasicBlock',
    'Bottleneck',
    'DepthwiseSeparableConv',
]


class ConvNormAct(nn.Module):
    """
    A 3D convolution bundled with a normalization layer and an activation.

    The three sub-layers are applied in one of two orders:

    * ``preact=False``: conv -> norm -> act (norm is built for ``out_ch``)
    * ``preact=True``:  norm -> act -> conv (norm is built for ``in_ch``)

    Args:
        in_ch: number of input channels of the convolution.
        out_ch: number of output channels of the convolution.
        kernel_size, stride, padding, groups, dilation, bias: forwarded
            unchanged to ``nn.Conv3d``.
        norm: normalization class (``nn.BatchNorm3d``, ``nn.InstanceNorm3d``
            or ``LayerNorm``), constructed with ``eps=1e-4``; a falsy value
            replaces it with ``nn.Identity``.
        act: activation class (``nn.ReLU``, ``nn.ReLU6``, ``nn.GELU`` or
            ``nn.SiLU``), constructed without arguments; a falsy value
            replaces it with ``nn.Identity``.
        preact: selects the ordering described above.

    NOTE(review): ``True`` is accepted by both assertions, but the value is
    then called like a class, so in practice only the listed classes and
    ``False`` are usable.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1,
                 groups=1, dilation=1, bias=False, norm=nn.BatchNorm3d,
                 act=nn.ReLU, preact=False):
        super().__init__()

        assert norm in [nn.BatchNorm3d, nn.InstanceNorm3d, LayerNorm, True, False]
        assert act in [nn.ReLU, nn.ReLU6, nn.GELU, nn.SiLU, True, False]

        self.conv = nn.Conv3d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=bias,
        )

        # In pre-activation mode the norm sees the block input, otherwise it
        # sees the convolution output.
        norm_channels = in_ch if preact else out_ch
        if norm:
            self.norm = norm(norm_channels, eps=1e-4)
        else:
            self.norm = nn.Identity()

        if act:
            self.act = act()
        else:
            self.act = nn.Identity()

        self.preact = preact

    def forward(self, x):
        """Apply the three sub-layers in the order selected by ``preact``."""
        if not self.preact:
            convolved = self.conv(x)
            return self.act(self.norm(convolved))

        activated = self.act(self.norm(x))
        return self.conv(activated)
class SingleConv(nn.Module):
    """
    Thin wrapper around one ConvNormAct layer.

    The padding handed to the wrapped layer is ``k // 2`` for every entry
    ``k`` of ``kernel_size``, so ``kernel_size`` must be iterable.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: per-dimension kernel size (iterable of ints).
        stride: stride forwarded to the wrapped layer.
        norm: normalization class or a falsy value, checked by assertion.
        act: activation class or a falsy value, checked by assertion.
        preact: forwarded to the wrapped ConvNormAct.
    """

    def __init__(self, in_ch, out_ch, kernel_size=[3,3,3], stride=1, norm=nn.BatchNorm3d, act=nn.ReLU, preact=False):
        super().__init__()

        assert norm in [nn.BatchNorm3d, nn.InstanceNorm3d, LayerNorm, True, False]
        assert act in [nn.ReLU, nn.ReLU6, nn.GELU, nn.SiLU, True, False]

        # Half of each kernel extent, rounded down.
        half_kernel = [extent // 2 for extent in kernel_size]

        self.conv = ConvNormAct(
            in_ch,
            out_ch,
            kernel_size,
            stride=stride,
            padding=half_kernel,
            norm=norm,
            act=act,
            preact=preact,
        )

    def forward(self, x):
        """Run the input through the wrapped ConvNormAct layer."""
        result = self.conv(x)
        return result
class BasicBlock(nn.Module):
    """
    Residual block: two stacked ConvNormAct layers plus a shortcut branch.

    ``conv1`` maps ``in_ch`` to ``out_ch`` with the given ``stride``;
    ``conv2`` keeps ``out_ch`` channels with stride 1. Both use a padding of
    ``k // 2`` per kernel entry. The shortcut is an empty ``nn.Sequential``
    (pass-through) unless ``stride != 1`` or ``in_ch != out_ch``, in which
    case it is a ConvNormAct built with the same arguments as ``conv1``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: per-dimension kernel size (iterable of ints).
        stride: stride of ``conv1`` and of the projection shortcut.
        norm: normalization class or a falsy value, checked by assertion.
        act: activation class or a falsy value, checked by assertion.
        preact: forwarded to every ConvNormAct in the block.
    """

    def __init__(self, in_ch, out_ch, kernel_size=[3,3,3], stride=1, norm=nn.BatchNorm3d, act=nn.ReLU, preact=True):
        super().__init__()
        assert norm in [nn.BatchNorm3d, nn.InstanceNorm3d, LayerNorm, True, False]
        assert act in [nn.ReLU, nn.ReLU6, nn.GELU, nn.SiLU, True, False]

        pad_size = [i // 2 for i in kernel_size]

        self.conv1 = ConvNormAct(in_ch, out_ch, kernel_size, stride=stride, padding=pad_size, norm=norm, act=act, preact=preact)
        self.conv2 = ConvNormAct(out_ch, out_ch, kernel_size, stride=1, padding=pad_size, norm=norm, act=act, preact=preact)

        # Pass-through shortcut unless the spatial size or channel count of
        # the main path differs from the input.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = ConvNormAct(in_ch, out_ch, kernel_size, stride=stride, padding=pad_size, norm=norm, act=act, preact=preact)

    def forward(self, x):
        """Return ``conv2(conv1(x)) + shortcut(x)``."""
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)

        # Out-of-place add: an in-place ``+=`` would overwrite the tensor
        # returned by the last layer, which autograd may have saved for the
        # backward pass (e.g. a ReLU output when preact=False), and would
        # then fail during backward().
        out = out + self.shortcut(residual)

        return out
class Bottleneck(nn.Module):
    """
    Residual bottleneck block: 1x1x1 conv -> k-sized conv -> 1x1x1 conv.

    The first layer reduces ``in_ch`` to ``out_ch // expansion`` channels
    (``expansion`` is fixed to 2), the middle layer applies the
    ``kernel_size`` convolution with the requested ``stride``, ``groups`` and
    ``dilation``, and the last layer restores ``out_ch`` channels. The
    shortcut is an empty ``nn.Sequential`` (pass-through) unless
    ``stride != 1`` or ``in_ch != out_ch``, in which case it is a
    ConvNormAct with the full ``kernel_size`` and ``stride``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: per-dimension kernel size (iterable of ints).
        stride: stride of the middle convolution and projection shortcut.
        groups: groups of the middle convolution.
        dilation: dilation of the middle convolution; an int or one value
            per kernel dimension.
        norm: normalization class or a falsy value, checked by assertion.
        act: activation class or a falsy value, checked by assertion.
        preact: forwarded to every ConvNormAct in the block.
    """

    def __init__(self, in_ch, out_ch, kernel_size=[3,3,3], stride=1, groups=1, dilation=1, norm=nn.BatchNorm3d, act=nn.ReLU, preact=True):
        super().__init__()
        assert norm in [nn.BatchNorm3d, nn.InstanceNorm3d, LayerNorm, True, False]
        assert act in [nn.ReLU, nn.ReLU6, nn.GELU, nn.SiLU, True, False]

        pad_size = [i // 2 for i in kernel_size]

        # The dilated middle convolution covers dilation * (k - 1) + 1 voxels
        # per dimension, so its padding has to grow with the dilation for the
        # main path to keep the same spatial size as the (undilated)
        # shortcut. With dilation == 1 this equals pad_size.
        if isinstance(dilation, int):
            dilations = [dilation] * len(kernel_size)
        else:
            dilations = list(dilation)
        dilated_pad = [d * (k // 2) for k, d in zip(kernel_size, dilations)]

        self.expansion = 2
        self.conv1 = ConvNormAct(in_ch, out_ch//self.expansion, 1, stride=1, padding=0, norm=norm, act=act, preact=preact)
        self.conv2 = ConvNormAct(out_ch//self.expansion, out_ch//self.expansion, kernel_size, stride=stride, padding=dilated_pad, norm=norm, act=act, groups=groups, dilation=dilation, preact=preact)
        self.conv3 = ConvNormAct(out_ch//self.expansion, out_ch, 1, stride=1, padding=0, norm=norm, act=act, preact=preact)

        # Pass-through shortcut unless the spatial size or channel count of
        # the main path differs from the input.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = ConvNormAct(in_ch, out_ch, kernel_size, stride=stride, padding=pad_size, norm=norm, act=act, preact=preact)

    def forward(self, x):
        """Return ``conv3(conv2(conv1(x))) + shortcut(x)``."""
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        # Out-of-place add: an in-place ``+=`` would overwrite the tensor
        # returned by the last layer, which autograd may have saved for the
        # backward pass (e.g. a ReLU output when preact=False), and would
        # then fail during backward().
        out = out + self.shortcut(residual)

        return out
class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise 3D convolution followed by a 1x1x1 pointwise convolution.

    The depthwise stage keeps ``in_ch`` channels and uses ``groups=in_ch``;
    the pointwise stage maps ``in_ch`` to ``out_ch`` with kernel size 1.
    Neither stage is followed by a normalization or activation layer.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        stride: stride of the depthwise convolution.
        kernel_size: kernel size of the depthwise convolution, either an int
            or a list/tuple with one value per spatial dimension. The padding
            is ``k // 2`` per kernel entry.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, in_ch, out_ch, stride=1, kernel_size=3, bias=False):
        super().__init__()

        # Accept tuples as well as lists: nn.Conv3d takes either, and a tuple
        # would otherwise hit the integer branch and raise a TypeError.
        if isinstance(kernel_size, (list, tuple)):
            padding = [i // 2 for i in kernel_size]
        else:
            padding = kernel_size // 2

        self.depthwise = nn.Conv3d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_ch,
            bias=bias
        )
        self.pointwise = nn.Conv3d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            bias=bias
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        out = self.depthwise(x)
        out = self.pointwise(out)

        return out
class SEBlock(nn.Module):
    """
    Channel re-weighting block.

    The input is averaged to one value per channel
    (``nn.AdaptiveAvgPool3d(1)``), passed through two 1x1x1 convolutions
    (``in_ch -> in_ch // ratio -> in_ch``) with ``act`` in between and a
    sigmoid at the end, and the result multiplies the input.

    Args:
        in_ch: number of channels of the input.
        ratio: channel reduction factor of the hidden layer.
        act: activation class placed between the two convolutions.
    """

    def __init__(self, in_ch, ratio=4, act=nn.ReLU):
        super().__init__()

        self.squeeze = nn.AdaptiveAvgPool3d(1)
        self.excitation = nn.Sequential(
            nn.Conv3d(in_ch, in_ch//ratio, kernel_size=1),
            act(),
            nn.Conv3d(in_ch//ratio, in_ch, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled by the per-channel gate computed from it."""
        out = self.squeeze(x)
        out = self.excitation(out)

        return x * out


class DropPath(nn.Module):
    """
    Randomly zero whole samples of a batch during training.

    In training mode with ``p > 0`` each sample is kept when a uniform random
    draw exceeds ``p``; kept samples are divided by ``1 - p``. In eval mode,
    or when ``p`` is falsy, the input is returned unchanged.

    Args:
        p: drop probability.
    """

    def __init__(self, p=0):
        super().__init__()

        self.p = p

    def forward(self, x):
        """Apply the per-sample mask; identity when disabled or in eval."""
        if (not self.p) or (not self.training):
            return x

        batch_size = x.shape[0]
        # One random draw per sample, broadcast over the four remaining
        # dimensions of the mask.
        random_tensor = torch.rand(batch_size, 1, 1, 1, 1).to(x.device)
        binary_mask = self.p < random_tensor

        # Rescale only when something can be kept: with p >= 1 the mask is
        # all False, and dividing by (1 - p) == 0 would turn the masked
        # output into NaN (inf * 0) instead of zeros.
        keep_prob = 1 - self.p
        if keep_prob > 0:
            x = x.div(keep_prob)
        x = x * binary_mask

        return x


class MBConv(nn.Module):
    """
    Residual block: 1x1x1 expansion -> depthwise conv -> optional SEBlock
    -> 1x1x1 projection, followed by DropPath and a shortcut add.

    All ConvNormAct layers on the main path use ``preact=True``. The
    expansion layer is ``nn.Identity`` when ``expansion == 1``. The depthwise
    layer uses ``groups`` equal to the expanded channel count. The SEBlock is
    built with its default activation, independent of ``act``. The shortcut
    is an empty ``nn.Sequential`` (pass-through) unless ``in_ch != out_ch``
    or ``stride != 1``, in which case it is a ConvNormAct without norm or
    activation.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        expansion: multiplier giving the hidden channel count
            ``expansion * in_ch``.
        kernel_size: int or iterable of ints for the depthwise convolution
            and the projection shortcut; padding is ``(k - 1) // 2``.
        stride: stride of the depthwise convolution and projection shortcut.
        ratio: reduction ratio forwarded to SEBlock.
        p: drop probability forwarded to DropPath.
        se: whether to insert an SEBlock after the depthwise convolution.
        norm: normalization class forwarded to the ConvNormAct layers.
        act: activation class forwarded to the expansion and depthwise
            layers.
    """

    def __init__(self, in_ch, out_ch, expansion=4, kernel_size=3, stride=1, ratio=4, p=0, se=True, norm=nn.BatchNorm3d, act=nn.ReLU):
        super().__init__()

        if isinstance(kernel_size, int):
            padding = (kernel_size - 1) // 2
        else:
            padding = [(t-1)//2 for t in kernel_size]
        expanded = expansion * in_ch
        # ``self.se`` starts as the boolean flag and is replaced by the
        # SEBlock module below when the flag is set; the attribute name is
        # kept because it is part of the parameter names of this module.
        self.se = se

        self.expand_proj = nn.Identity() if (expansion==1) else ConvNormAct(in_ch, expanded, kernel_size=1, padding=0, norm=norm, act=act, preact=True)

        self.depthwise = ConvNormAct(expanded, expanded, kernel_size=kernel_size, stride=stride, padding=padding, groups=expanded, act=act, norm=norm, preact=True)

        if self.se:
            self.se = SEBlock(expanded, ratio=ratio)

        self.pointwise = ConvNormAct(expanded, out_ch, kernel_size=1, padding=0, norm=norm, act=False, preact=True)

        self.drop_path = DropPath(p)

        self.shortcut = nn.Sequential()
        if in_ch != out_ch or stride !=1:
            self.shortcut = nn.Sequential(ConvNormAct(in_ch, out_ch, kernel_size, stride=stride, padding=padding, norm=False, act=False))

    def forward(self, x):
        """Run the main path, apply DropPath and add the shortcut branch."""
        residual = x

        x = self.expand_proj(x)
        x = self.depthwise(x)
        if self.se:
            x = self.se(x)

        x = self.pointwise(x)

        x = self.drop_path(x)

        # Out-of-place add, the same form FusedMBConv uses, so no tensor
        # produced by the main path is modified after the fact.
        x = x + self.shortcut(residual)

        return x


class FusedMBConv(nn.Module):
    """
    Residual block: one full k-sized expansion convolution -> optional
    SEBlock -> 1x1x1 projection, followed by DropPath and a shortcut add.

    Compared with MBConv, the 1x1x1 expansion and the depthwise convolution
    are replaced by a single ConvNormAct with ``groups=1`` that maps
    ``in_ch`` to ``expansion * in_ch`` channels. The shortcut is an empty
    ``nn.Sequential`` (pass-through) unless ``in_ch != out_ch`` or
    ``stride != 1``, in which case it is a ConvNormAct without norm or
    activation.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        expansion: multiplier giving the hidden channel count
            ``expansion * in_ch``.
        kernel_size: int or iterable of ints for the expansion convolution
            and the projection shortcut; padding is ``(k - 1) // 2``.
        stride: stride of the expansion convolution and projection shortcut.
        ratio: reduction ratio forwarded to SEBlock.
        p: drop probability forwarded to DropPath.
        se: whether to insert an SEBlock after the expansion convolution.
        norm: normalization class forwarded to the ConvNormAct layers.
        act: activation class forwarded to the expansion layer.
    """

    def __init__(self, in_ch, out_ch, expansion=4, kernel_size=3, stride=1, ratio=4, p=0, se=True, norm=nn.BatchNorm3d, act=nn.ReLU):
        super().__init__()

        if isinstance(kernel_size, int):
            padding = (kernel_size -1) // 2
        else:
            padding = [(t-1)//2 for t in kernel_size]

        expanded = expansion * in_ch

        self.stride= stride
        self.se = se

        self.conv3x3 = ConvNormAct(in_ch, expanded, kernel_size=kernel_size, stride=stride, padding=padding, groups=1, norm=norm, act=act, preact=True)

        if self.se:
            self.se_block = SEBlock(expanded, ratio=ratio)

        self.pointwise = ConvNormAct(expanded, out_ch, kernel_size=1, padding=0, norm=norm, act=False, preact=True)

        self.drop_path = DropPath(p)

        self.shortcut = nn.Sequential()
        if in_ch != out_ch or stride !=1:
            self.shortcut = nn.Sequential(ConvNormAct(in_ch, out_ch, kernel_size, stride=stride, padding=padding, norm=False, act=False))

    def forward(self, x):
        """Run the main path, apply DropPath and add the shortcut branch."""
        residual = x

        x = self.conv3x3(x)
        if self.se:
            x = self.se_block(x)
        x = self.pointwise(x)

        x = self.drop_path(x)

        x = x + self.shortcut(residual)

        return x