Source code for model.dim3.unet_utils

import torch
import torch.nn as nn
import torch.nn.functional as F
from .conv_layers import BasicBlock, Bottleneck, ConvNormAct
import pdb 

class inconv(nn.Module):
    """Input stage: a bias-free ``nn.Conv3d`` followed by a single ``block``.

    Args:
        in_ch: Number of channels of the incoming tensor.
        out_ch: Number of channels produced by both sub-layers.
        kernel_size: Either one int (used for all three axes) or a
            three-element sequence with one kernel extent per axis.
        block: Callable invoked as
            ``block(out_ch, out_ch, kernel_size=..., norm=...)`` to build the
            second layer.
        norm: Normalisation layer class forwarded untouched to ``block``.
    """

    # NOTE: the list default below is shared between calls; it is only read
    # (never mutated) inside this class.
    def __init__(self, in_ch, out_ch, kernel_size=[3,3,3], block=BasicBlock, norm=nn.BatchNorm3d):
        super().__init__()

        # Expand a scalar kernel size to one entry per spatial axis.
        kernel_size = [kernel_size] * 3 if isinstance(kernel_size, int) else kernel_size

        # Half the kernel extent on every axis; with the default stride of 1
        # this keeps the spatial size unchanged for odd kernel sizes.
        half_kernel = [extent // 2 for extent in kernel_size]

        self.conv1 = nn.Conv3d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            padding=half_kernel,
            bias=False,
        )
        self.conv2 = block(out_ch, out_ch, kernel_size=kernel_size, norm=norm)

    def forward(self, x):
        """Run ``x`` through ``conv1`` and then ``conv2`` and return the result."""
        return self.conv2(self.conv1(x))
class down_block(nn.Module):
    """Encoder stage: one downsampling step followed by a stack of ``block`` layers.

    Downsampling is done either by an explicit ``nn.MaxPool3d`` placed in
    front of the first block (``pool=True``) or by handing ``down_scale`` to
    the first block as its ``stride`` (``pool=False``).

    Args:
        in_ch: Channel count of the incoming tensor.
        out_ch: Channel count produced by every block in this stage.
        num_block: Total number of ``block`` layers. The first one is always
            created; ``num_block - 1`` further ones follow it.
        block: Callable used to build each layer.
        kernel_size: One int (applied to all three axes) or a three-element
            sequence.
        down_scale: One int (applied to all three axes) or a three-element
            sequence; used as the pooling window or as the first block's
            stride.
        pool: Selects max-pooling (True) or strided block (False).
        norm: Normalisation layer class forwarded to every block.
    """

    # NOTE: the list defaults below are shared between calls; they are only
    # read (never mutated) inside this class.
    def __init__(self, in_ch, out_ch, num_block, block=BasicBlock, kernel_size=[3,3,3], down_scale=[2,2,2], pool=True, norm=nn.BatchNorm3d):
        super().__init__()

        # Expand scalar arguments to one entry per spatial axis.
        kernel_size = [kernel_size] * 3 if isinstance(kernel_size, int) else kernel_size
        down_scale = [down_scale] * 3 if isinstance(down_scale, int) else down_scale

        if pool:
            # Pool first, then change the channel count with an unstrided block.
            stages = [
                nn.MaxPool3d(down_scale),
                block(in_ch, out_ch, kernel_size=kernel_size, norm=norm),
            ]
        else:
            # Let the first block itself reduce the resolution via its stride.
            stages = [
                block(in_ch, out_ch, stride=down_scale, kernel_size=kernel_size, norm=norm),
            ]

        # Remaining blocks keep both resolution and channel count.
        stages.extend(
            block(out_ch, out_ch, stride=1, kernel_size=kernel_size, norm=norm)
            for _ in range(num_block - 1)
        )

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the whole stage to ``x`` and return its output."""
        return self.conv(x)
class up_block(nn.Module):
    """Decoder stage: resize, concatenate with a skip tensor, then apply ``block`` layers.

    Args:
        in_ch: Together with ``out_ch`` determines the input width of the
            first block, which is built for ``in_ch + out_ch`` channels.
        out_ch: Channel count produced by every block in this stage.
        num_block: Total number of ``block`` layers. The first one is always
            created; ``num_block - 1`` further ones follow it.
        block: Callable used to build each layer.
        kernel_size: One int (applied to all three axes) or a three-element
            sequence.
        up_scale: One int (applied to all three axes) or a three-element
            sequence. It is stored on ``self.up_scale``; ``forward`` in this
            class does not read it and sizes its output from ``x2`` instead.
        norm: Normalisation layer class forwarded to every block.
    """

    # NOTE: the list defaults below are shared between calls; they are only
    # read (never mutated) inside this class.
    def __init__(self, in_ch, out_ch, num_block, block=BasicBlock, kernel_size=[3,3,3], up_scale=[2,2,2], norm=nn.BatchNorm3d):
        super().__init__()

        # Expand scalar arguments to one entry per spatial axis.
        kernel_size = [kernel_size] * 3 if isinstance(kernel_size, int) else kernel_size
        up_scale = [up_scale] * 3 if isinstance(up_scale, int) else up_scale
        self.up_scale = up_scale

        # The first block consumes the concatenated tensor; the rest keep out_ch.
        stages = [block(in_ch + out_ch, out_ch, kernel_size=kernel_size, norm=norm)]
        stages += [
            block(out_ch, out_ch, kernel_size=kernel_size, norm=norm)
            for _ in range(num_block - 1)
        ]

        self.conv = nn.Sequential(*stages)

    def forward(self, x1, x2):
        """Fuse ``x1`` with the skip tensor ``x2`` and run the block stack.

        ``x1`` is trilinearly resized so that every dimension after the first
        two matches ``x2``; the two tensors are then concatenated along
        dim 1 with ``x2`` first.

        Presumably ``x1`` carries ``in_ch`` channels and ``x2`` carries
        ``out_ch`` channels so the concatenation matches the first block's
        input width -- confirm against the caller.
        """
        resized = F.interpolate(x1, size=x2.shape[2:], mode='trilinear', align_corners=True)
        merged = torch.cat([x2, resized], dim=1)
        return self.conv(merged)