# original U-Net
# Modified from https://github.com/milesial/Pytorch-UNet
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_utils import inconv, down_block, up_block
from .utils import get_block, get_norm
import pdb
class UNet(nn.Module):
    """3D U-Net with one input stage, four down stages and four up stages.

    Channel widths along the encoder are ``base_ch``, ``2*base_ch``,
    ``4*base_ch``, ``8*base_ch`` and ``10*base_ch``; the decoder mirrors
    them back down to ``base_ch`` and a final 1x1x1 ``nn.Conv3d`` maps to
    ``num_classes`` output channels.
    """

    def __init__(
        self,
        in_ch,
        base_ch,
        scale=(2, 2, 2, 2),
        kernel_size=(3, 3, 3, 3, 3),
        num_classes=1,
        block='ConvNormAct',
        pool=True,
        norm='bn',
    ):
        """Build the encoder/decoder stages.

        Args:
            in_ch: the num of input channel
            base_ch: the num of channels in the entry level
            scale: sequence with (at least) 4 entries, one per down block,
                giving the downsample scale of that level. Each entry is
                either a single number applied to all axes,
                e.g. [1, 1, 2, 2], or a per-axis list,
                e.g. [[1,2,2], [2,2,2], [2,2,2], [2,2,2]].
                The matching up block reuses the same entry.
            kernel_size: sequence with (at least) 5 entries: entry 0 is for
                the input stage and entries 1..4 are for the four down
                blocks. The up blocks reuse entries 3, 2, 1, 0.
                e.g. [3,3,3,3,3] or
                [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]]
            num_classes: the target class number
            block: 'ConvNormAct' for origin UNet, 'BasicBlock' for ResUNet
            pool: use maxpool or use strided conv for downsample
            norm: the norm layer type, bn or in

        Raises:
            ValueError: if ``scale`` has fewer than 4 entries or
                ``kernel_size`` has fewer than 5 entries.
        """
        super().__init__()

        # Fail early with a readable message instead of an IndexError
        # halfway through layer construction.
        if len(scale) < 4:
            raise ValueError(
                f"scale needs 4 entries (one per down block), got {len(scale)}"
            )
        if len(kernel_size) < 5:
            raise ValueError(
                "kernel_size needs 5 entries (input stage + 4 down blocks), "
                f"got {len(kernel_size)}"
            )

        num_block = 2
        block = get_block(block)
        norm = get_norm(norm)

        # Keyword arguments shared by every down/up block.
        shared = dict(num_block=num_block, block=block, norm=norm)

        # Attribute names and creation order are kept as-is so that
        # state_dict keys remain compatible with existing checkpoints.
        self.inc = inconv(in_ch, base_ch, block=block, kernel_size=kernel_size[0], norm=norm)

        self.down1 = down_block(base_ch, 2*base_ch, pool=pool, down_scale=scale[0], kernel_size=kernel_size[1], **shared)
        self.down2 = down_block(2*base_ch, 4*base_ch, pool=pool, down_scale=scale[1], kernel_size=kernel_size[2], **shared)
        self.down3 = down_block(4*base_ch, 8*base_ch, pool=pool, down_scale=scale[2], kernel_size=kernel_size[3], **shared)
        # The deepest level widens to 10x (not 16x) the base width.
        self.down4 = down_block(8*base_ch, 10*base_ch, pool=pool, down_scale=scale[3], kernel_size=kernel_size[4], **shared)

        # Each up block undoes the scale of the down block at the same depth.
        self.up1 = up_block(10*base_ch, 8*base_ch, up_scale=scale[3], kernel_size=kernel_size[3], **shared)
        self.up2 = up_block(8*base_ch, 4*base_ch, up_scale=scale[2], kernel_size=kernel_size[2], **shared)
        self.up3 = up_block(4*base_ch, 2*base_ch, up_scale=scale[1], kernel_size=kernel_size[1], **shared)
        self.up4 = up_block(2*base_ch, base_ch, up_scale=scale[0], kernel_size=kernel_size[0], **shared)

        self.outc = nn.Conv3d(base_ch, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Args:
            x: input tensor fed to the input stage; presumably shaped
                (batch, in_ch, D, H, W) since the head is an ``nn.Conv3d``
                (confirm against ``inconv``).

        Returns:
            The output of the final 1x1x1 convolution, with ``num_classes``
            channels.
        """
        # Encoder: keep every level's output for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: each up block receives the running feature map and the
        # encoder output from the matching depth.
        out = self.up1(x5, x4)
        out = self.up2(out, x3)
        out = self.up3(out, x2)
        out = self.up4(out, x1)

        return self.outc(out)