# Source code for model.dim3.unet

# original U-Net
# Modified from https://github.com/milesial/Pytorch-UNet

import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_utils import inconv, down_block, up_block
from .utils import get_block, get_norm
import pdb


class UNet(nn.Module):
    """3D U-Net with configurable block type, downsampling and normalization.

    Encoder: ``inc`` followed by four ``down_block`` stages with channel widths
    ``base_ch -> 2*base_ch -> 4*base_ch -> 8*base_ch -> 10*base_ch``.
    Decoder: four ``up_block`` stages mirroring the encoder, each receiving the
    matching encoder feature as a skip connection, followed by a 1x1x1
    ``nn.Conv3d`` head that maps ``base_ch`` to ``num_classes`` channels.
    """

    def __init__(
        self,
        in_ch,
        base_ch,
        scale=(2, 2, 2, 2),
        kernel_size=(3, 3, 3, 3, 3),
        num_classes=1,
        block='ConvNormAct',
        pool=True,
        norm='bn',
    ):
        '''
        Args:
            in_ch: the number of input channels.
            base_ch: the number of channels at the entry level.
            scale: a sequence of (at least) 4 downsample scales, one per encoder
                stage (down1..down4); the decoder reuses them in reverse order.
                Each entry is either a single value used for all axes, e.g.
                [1, 1, 2, 2], or a per-axis list, e.g.
                [[1,2,2], [2,2,2], [2,2,2], [2,2,2]] for a different scale on
                each axis.
            kernel_size: a sequence of (at least) 5 kernel sizes, one per level
                (inc, down1..down4); the decoder reuses entries 3, 2, 1, 0.
                Each entry is a single value or a per-axis list, e.g.
                [3,3,3,3,3] or [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]].
            num_classes: the number of output channels (target classes).
            block: 'ConvNormAct' for the original UNet, 'BasicBlock' for
                ResUNet (resolved through ``get_block``).
            pool: use maxpool (True) or strided conv (False) for downsampling.
            norm: the norm layer type, 'bn' or 'in' (resolved through
                ``get_norm``).

        Raises:
            ValueError: if ``scale`` has fewer than 4 entries or
                ``kernel_size`` has fewer than 5 entries.
        '''
        super().__init__()

        # Four scales (down1..down4) and five kernel sizes (inc + down1..down4)
        # are consumed below; fail early with a clear message rather than an
        # IndexError halfway through building the network.
        if len(scale) < 4:
            raise ValueError(
                f"scale must have at least 4 entries (one per encoder stage), "
                f"got {len(scale)}"
            )
        if len(kernel_size) < 5:
            raise ValueError(
                f"kernel_size must have at least 5 entries (inc + 4 encoder "
                f"stages), got {len(kernel_size)}"
            )

        num_block = 2
        block = get_block(block)
        norm = get_norm(norm)

        self.inc = inconv(in_ch, base_ch, block=block, kernel_size=kernel_size[0], norm=norm)

        # Encoder
        self.down1 = down_block(base_ch, 2*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[0], kernel_size=kernel_size[1], norm=norm)
        self.down2 = down_block(2*base_ch, 4*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[1], kernel_size=kernel_size[2], norm=norm)
        self.down3 = down_block(4*base_ch, 8*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[2], kernel_size=kernel_size[3], norm=norm)
        # Bottleneck width is 10*base_ch, not the 16*base_ch plain doubling would give.
        self.down4 = down_block(8*base_ch, 10*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[3], kernel_size=kernel_size[4], norm=norm)

        # Decoder: mirrors the encoder, reusing its scales and kernel sizes in reverse.
        self.up1 = up_block(10*base_ch, 8*base_ch, num_block=num_block, block=block, up_scale=scale[3], kernel_size=kernel_size[3], norm=norm)
        self.up2 = up_block(8*base_ch, 4*base_ch, num_block=num_block, block=block, up_scale=scale[2], kernel_size=kernel_size[2], norm=norm)
        self.up3 = up_block(4*base_ch, 2*base_ch, num_block=num_block, block=block, up_scale=scale[1], kernel_size=kernel_size[1], norm=norm)
        self.up4 = up_block(2*base_ch, base_ch, num_block=num_block, block=block, up_scale=scale[0], kernel_size=kernel_size[0], norm=norm)

        # 1x1x1 projection to per-class output channels.
        self.outc = nn.Conv3d(base_ch, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder-decoder and return the output of the ``outc`` head.

        Args:
            x: input tensor; presumably (B, in_ch, D, H, W) as consumed by the
                3D conv layers -- TODO confirm. Spatial sizes presumably must
                be compatible with the cumulative downsample scale so the skip
                connections align (depends on ``up_block``).

        Returns:
            Tensor with ``num_classes`` channels produced by the 1x1x1 conv
            head; no activation is applied here.
        """
        # Encoder: keep each level's output for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: each up_block takes the deeper feature plus the mirrored
        # encoder feature.
        out = self.up1(x5, x4)
        out = self.up2(out, x3)
        out = self.up3(out, x2)
        out = self.up4(out, x1)

        return self.outc(out)