Source code for model.dim3.medformer_utils

import torch
import torch.nn as nn
import torch.nn.functional as F
from .conv_layers import BasicBlock, Bottleneck, ConvNormAct, DepthwiseSeparableConv, MBConv, FusedMBConv
from .trans_layers import TransformerBlock, LayerNorm

from einops import rearrange
import pdb


class BidirectionAttention(nn.Module):
    """Bidirectional cross-attention between a 3D feature map and a semantic map.

    A single similarity matrix between feature-map queries and semantic-map
    queries is computed once and normalised along both axes, yielding
    feature<-map attention and map<-feature attention in one pass.

    Args:
        feat_dim: channel count of the feature map.
        map_dim: channel count of the semantic map.
        out_dim: channel count of the returned feature map.
        heads: number of attention heads.
        dim_head: channels per head (``inner_dim = heads * dim_head``).
        attn_drop: dropout probability applied to the map<-feature attention.
        proj_drop: dropout probability applied after the feature output projection.
        map_size: spatial size ``[d, h, w]`` of the semantic map; the map
            output is reshaped to this size in ``forward``.
        proj_type: ``'linear'`` (1x1x1 convolutions) or ``'depthwise'``
            (``DepthwiseSeparableConv``) for the feature-side projections.
        kernel_size: kernel size passed to the depthwise projections.
        no_map_out: if True, the map output is returned without its final
            1x1x1 projection (it then has ``inner_dim`` channels).
    """

    # NOTE: the default ``proj_type`` used to be the misspelling 'depthwose',
    # which always tripped the assertion below when the argument was omitted.
    def __init__(self, feat_dim, map_dim, out_dim, heads=4, dim_head=64, attn_drop=0.,
                 proj_drop=0., map_size=[8,8,8], proj_type='depthwise', kernel_size=[3,3,3],
                 no_map_out=False):
        super().__init__()

        self.inner_dim = dim_head * heads
        self.feat_dim = feat_dim
        self.map_dim = map_dim
        self.heads = heads
        self.scale = dim_head ** (-0.5)
        self.dim_head = dim_head
        self.map_size = map_size

        assert proj_type in ['linear', 'depthwise']

        # Feature side: one projection emits query and value stacked on the channel axis.
        if proj_type == 'linear':
            self.feat_qv = nn.Conv3d(feat_dim, self.inner_dim*2, kernel_size=1, stride=1, padding=0, bias=False)
            self.feat_out = nn.Conv3d(self.inner_dim, out_dim, kernel_size=1, stride=1, padding=0, bias=False)
        else:
            self.feat_qv = DepthwiseSeparableConv(feat_dim, self.inner_dim*2, kernel_size=kernel_size)
            self.feat_out = DepthwiseSeparableConv(self.inner_dim, out_dim, kernel_size=kernel_size)

        # Map side always uses pointwise convolutions.
        self.map_qv = nn.Conv3d(map_dim, self.inner_dim*2, kernel_size=1, stride=1, padding=0, bias=False)
        if no_map_out:
            self.map_out = nn.Identity()
        else:
            self.map_out = nn.Conv3d(self.inner_dim, map_dim, kernel_size=1, stride=1, padding=0, bias=False)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def rearrange1(self, x, heads):
        """Reshape ``b (dim_head heads) d h w -> b heads (d h w) dim_head``.

        Substitute for einops ``rearrange``, which is not supported by
        PyTorch 2.0 ``torch.compile``.
        """
        b, l, d, h, w = x.shape
        dim_head = l // heads

        x = x.view(b, dim_head, heads, -1).contiguous()  # b dim_head heads (dhw)
        x = x.permute(0, 2, 3, 1).contiguous()  # b heads (dhw) dim_head

        return x

    def rearrange2(self, x, d, h, w):
        """Reshape ``b heads (d h w) dim_head -> b (dim_head heads) d h w``.

        Inverse of ``rearrange1``; substitute for einops ``rearrange``, which is
        not supported by PyTorch 2.0 ``torch.compile``.
        """
        b, heads, l, dim_head = x.shape
        x = x.permute(0, 3, 1, 2).contiguous()  # b, dim_head, heads, l
        x = x.view(b, (heads*dim_head), d, h, w).contiguous()

        return x

    def forward(self, feat, semantic_map):
        """Exchange information between ``feat`` and ``semantic_map``.

        Args:
            feat: feature map of shape (B, feat_dim, D, H, W).
            semantic_map: semantic map whose spatial size must equal
                ``self.map_size`` (the map output is reshaped to it).

        Returns:
            Tuple ``(feat_out, map_out)``: the projected feature output and the
            map output (projected back to ``map_dim`` channels unless
            ``no_map_out`` was set).
        """
        B, C, D, H, W = feat.shape

        feat_q, feat_v = self.feat_qv(feat).chunk(2, dim=1)  # B, inner_dim, D, H, W
        map_q, map_v = self.map_qv(semantic_map).chunk(2, dim=1)  # B, inner_dim, ms, ms, ms

        # -> b heads tokens dim_head
        feat_q, feat_v = map(lambda t: self.rearrange1(t, self.heads), [feat_q, feat_v])
        map_q, map_v = map(lambda t: self.rearrange1(t, self.heads), [map_q, map_v])

        # Shared similarity matrix: i indexes feature tokens, j indexes map tokens.
        attn = torch.einsum('bhid,bhjd->bhij', feat_q, map_q)
        attn *= self.scale

        # Feature tokens attend over map tokens. No dropout here: the semantic
        # map is very concise, and dropping entries might make training unstable.
        feat_map_attn = F.softmax(attn, dim=-1)
        # Map tokens attend over feature tokens (normalised along the feature axis).
        map_feat_attn = self.attn_drop(F.softmax(attn, dim=-2))

        feat_out = torch.einsum('bhij,bhjd->bhid', feat_map_attn, map_v)
        feat_out = self.rearrange2(feat_out, d=D, h=H, w=W)

        map_out = torch.einsum('bhji,bhjd->bhid', map_feat_attn, feat_v)
        map_out = self.rearrange2(map_out, d=self.map_size[0], h=self.map_size[1], w=self.map_size[2])

        feat_out = self.proj_drop(self.feat_out(feat_out))
        map_out = self.map_out(map_out)

        return feat_out, map_out
class BidirectionAttentionBlock(nn.Module):
    """Pre-norm bidirectional attention followed by a convolutional feed-forward.

    Both the feature map and the semantic map are normalised, passed through
    ``BidirectionAttention``, and combined with residual connections. Only the
    feature branch goes through the feed-forward block.
    """

    def __init__(self, feat_dim, map_dim, out_dim, heads, dim_head, norm=nn.BatchNorm3d,
                 act=nn.ReLU, expansion=4, attn_drop=0., proj_drop=0., map_size=[8, 8, 8],
                 proj_type='depthwise', kernel_size=[3,3,3], no_map_out=False):
        super().__init__()

        assert norm in [nn.BatchNorm3d, nn.InstanceNorm3d, LayerNorm, True, False]
        assert act in [nn.ReLU, nn.ReLU6, nn.GELU, nn.SiLU, True, False]
        assert proj_type in ['linear', 'depthwise']

        # Separate normalisation layers for the feature map and the semantic map.
        if norm:
            self.norm1 = norm(feat_dim)
            self.norm2 = norm(map_dim)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        self.attn = BidirectionAttention(
            feat_dim, map_dim, out_dim, heads, dim_head,
            attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size,
            proj_type=proj_type, kernel_size=kernel_size, no_map_out=no_map_out,
        )

        # An empty Sequential acts as identity; a 1x1 projection is used only
        # when the residual has to change its channel count.
        if feat_dim != out_dim:
            self.shortcut = ConvNormAct(feat_dim, out_dim, 1, padding=0, norm=norm, act=act, preact=True)
        else:
            self.shortcut = nn.Sequential()

        if proj_type == 'linear':
            self.feedforward = FusedMBConv(out_dim, out_dim, expansion=expansion,
                                           kernel_size=1, act=act, norm=norm)
        else:
            self.feedforward = MBConv(out_dim, out_dim, expansion=expansion,
                                      kernel_size=kernel_size, act=act, norm=norm)

    def forward(self, x, semantic_map):
        """Return the updated ``(feature map, semantic map)`` pair."""
        attn_feat, attn_map = self.attn(self.norm1(x), self.norm2(semantic_map))

        # Residual connection on the feature branch, then the feed-forward block.
        attn_feat += self.shortcut(x)
        attn_feat = self.feedforward(attn_feat)

        # Residual connection on the semantic-map branch.
        attn_map += semantic_map

        return attn_feat, attn_map
class PatchMerging(nn.Module):
    """Modified patch-merging layer that works as down-sampling.

    Every axis is sub-sampled with stride ``down_scale[axis]``; all phase
    offsets are stacked on the channel axis, normalised, and projected to
    ``out_dim`` channels.

    Args:
        dim: channel count of the input.
        out_dim: channel count of the output.
        norm: normalisation layer class, called as ``norm(merged_dim)``.
        proj_type: ``'linear'`` (1x1x1 convolution) or ``'depthwise'``
            (``DepthwiseSeparableConv``) for the channel reduction.
        down_scale: per-axis integer down-sampling factors ``[d, h, w]``.
        kernel_size: kernel size used by the depthwise reduction.
    """

    def __init__(self, dim, out_dim, norm=nn.BatchNorm3d, proj_type='linear',
                 down_scale=[2,2,2], kernel_size=[3,3,3]):
        super().__init__()
        self.dim = dim
        assert proj_type in ['linear', 'depthwise']

        self.down_scale = down_scale

        # forward() concatenates one slice per phase offset, i.e. the product of
        # the three factors. (The former ``2 ** down_scale.count(2)`` was only
        # correct when every factor was 1 or 2.)
        merged_dim = down_scale[0] * down_scale[1] * down_scale[2] * dim

        if proj_type == 'linear':
            self.reduction = nn.Conv3d(merged_dim, out_dim, kernel_size=1, bias=False)
        else:
            self.reduction = DepthwiseSeparableConv(merged_dim, out_dim, kernel_size=kernel_size)

        self.norm = norm(merged_dim)

    def forward(self, x):
        """Down-sample ``x``.

        Args:
            x: tensor of shape (B, C, D, H, W). Each spatial size should be
               divisible by its factor; otherwise the strided slices differ in
               size and the concatenation fails.

        Returns:
            The normalised and channel-reduced merged tensor.
        """
        sd, sh, sw = self.down_scale[0], self.down_scale[1], self.down_scale[2]

        # One strided slice per (i, j, k) phase offset, in row-major order.
        merged_x = [
            x[:, :, i::sd, j::sh, k::sw]
            for i in range(sd)
            for j in range(sh)
            for k in range(sw)
        ]

        x = torch.cat(merged_x, 1)
        x = self.norm(x)
        x = self.reduction(x)

        return x
class BasicLayer(nn.Module):
    """
    One transformer stage: a stack of ``BidirectionAttentionBlock`` modules.

    The layer itself performs no down-sampling or up-sampling; those are
    handled by the surrounding down_block / up_block.
    """

    def __init__(self, feat_dim, map_dim, out_dim, num_blocks, heads=4, dim_head=64,
                 expansion=4, attn_drop=0., proj_drop=0., map_size=[8,8,8],
                 proj_type='depthwise', norm=nn.BatchNorm3d, act=nn.GELU,
                 kernel_size=[3,3,3], no_map_out=False):
        super().__init__()

        self.blocks = nn.ModuleList([])

        # The first block maps feat_dim -> out_dim; later blocks keep out_dim.
        in_dim = feat_dim
        last_idx = num_blocks - 1
        for idx in range(num_blocks):
            # ``no_map_out`` is honoured by the final block only.
            block_no_map_out = no_map_out if idx == last_idx else False
            block = BidirectionAttentionBlock(
                in_dim, map_dim, out_dim, heads, dim_head,
                expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop,
                map_size=map_size, proj_type=proj_type, norm=norm, act=act,
                kernel_size=kernel_size, no_map_out=block_no_map_out,
            )
            self.blocks.append(block)
            in_dim = out_dim

    def forward(self, x, semantic_map):
        """Run the feature map and semantic map through every block in order."""
        feat, sem = x, semantic_map
        for blk in self.blocks:
            feat, sem = blk(feat, sem)

        return feat, sem
class SemanticMapGeneration(nn.Module):
    """Pool a 3D feature map into a fixed-size semantic map.

    Each of the ``map_size[0] * map_size[1] * map_size[2]`` map entries is a
    softmax-weighted sum of the projected features over all spatial positions.
    """

    def __init__(self, feat_dim, map_dim, map_size):
        super().__init__()

        self.map_size = map_size
        self.map_dim = map_dim

        self.map_code_num = map_size[0] * map_size[1] * map_size[2]

        # Content projection and per-entry spatial weighting projection.
        self.base_proj = nn.Conv3d(feat_dim, map_dim, kernel_size=3, padding=1, bias=False)
        self.semantic_proj = nn.Conv3d(feat_dim, self.map_code_num, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Map ``x`` of shape (B, C, D, H, W) to (B, map_dim, *map_size)."""
        batch = x.shape[0]

        projected = self.base_proj(x)  # B, map_dim, d, h, w
        scores = self.semantic_proj(x)  # B, map_code_num, d, h, w

        # Normalise each map entry's weights over the flattened spatial axis.
        scores = scores.view(batch, self.map_code_num, -1)
        weights = F.softmax(scores, dim=2)  # B, map_code_num, dhw

        projected = projected.view(batch, self.map_dim, -1)  # B, map_dim, dhw

        pooled = torch.einsum('bij,bkj->bik', projected, weights)  # B, map_dim, map_code_num

        md, mh, mw = self.map_size[0], self.map_size[1], self.map_size[2]
        return pooled.view(batch, self.map_dim, md, mh, mw)
class SemanticMapFusion(nn.Module):
    """Fuse several semantic maps with a shared transformer.

    Every map is projected to ``dim`` channels, all maps are concatenated
    along the token axis and passed through ``TransformerBlock``, then each
    map is projected back to its original channel count.
    """

    def __init__(self, in_dim_list, dim, heads, depth=1, norm=nn.BatchNorm3d, attn_drop=0., proj_drop=0.):
        super().__init__()

        self.dim = dim

        # Project all maps to the same channel number.
        self.in_proj = nn.ModuleList([])
        for in_dim in in_dim_list:
            self.in_proj.append(nn.Conv3d(in_dim, dim, kernel_size=1, bias=False))

        self.fusion = TransformerBlock(dim, depth, heads, dim//heads, dim, attn_drop=attn_drop, proj_drop=proj_drop)

        # Project all maps back to their original channel number.
        self.out_proj = nn.ModuleList([])
        for in_dim in in_dim_list:
            self.out_proj.append(nn.Conv3d(dim, in_dim, kernel_size=1, bias=False))

    def forward(self, map_list):
        """Return the list of fused maps, one per entry of ``map_list``.

        The spatial size is read from the first map and reused for every
        output, so all maps are expected to share it.
        """
        B, _, D, H, W = map_list[0].shape
        num_maps = len(map_list)

        # Each map becomes a token sequence: B, L, C where L = D*H*W.
        tokens = []
        for idx, sem_map in enumerate(map_list):
            projected = self.in_proj[idx](sem_map)
            tokens.append(projected.view(B, self.dim, -1).permute(0, 2, 1))

        fused = self.fusion(torch.cat(tokens, dim=1))

        # Split the joint sequence back into one chunk per input map.
        fused_chunks = fused.chunk(num_maps, dim=1)

        maps_out = []
        for idx in range(num_maps):
            volume = fused_chunks[idx].permute(0, 2, 1).view(B, self.dim, D, H, W)
            maps_out.append(self.out_proj[idx](volume))

        return maps_out
################################################################
class inconv(nn.Module):
    """Input stem: a plain 3D convolution followed by one convolutional block."""

    def __init__(self, in_ch, out_ch, kernel_size=[3,3,3], block=BasicBlock, norm=nn.BatchNorm3d, act=nn.GELU):
        super().__init__()

        # Half-kernel padding on every axis.
        pad_size = [k // 2 for k in kernel_size]

        self.conv1 = nn.Conv3d(
            in_ch, out_ch,
            kernel_size=kernel_size, padding=pad_size, bias=False,
        )
        self.conv2 = block(out_ch, out_ch, kernel_size=kernel_size, norm=norm, act=act)

    def forward(self, x):
        """Apply the convolution and then the block to ``x``."""
        return self.conv2(self.conv1(x))
class down_block(nn.Module):
    """Encoder stage: patch merging, convolutional blocks, then transformer blocks.

    When ``map_generate`` is set, a semantic map is generated from the
    convolutional features; otherwise ``None`` is handed to the transformer
    blocks as the semantic map.
    """

    def __init__(self, in_ch, out_ch, conv_num, trans_num, down_scale=[2,2,2], kernel_size=[3,3,3],
                 conv_block=BasicBlock, heads=4, dim_head=64, expansion=1, attn_drop=0.,
                 proj_drop=0., map_size=[8,8,8], proj_type='depthwise', norm=nn.BatchNorm3d,
                 act=nn.GELU, map_generate=False, map_dim=None):
        super().__init__()

        # The semantic map defaults to the stage's output channel count.
        if map_dim is None:
            map_dim = out_ch

        self.map_generate = map_generate
        if map_generate:
            self.map_gen = SemanticMapGeneration(out_ch, map_dim, map_size)

        self.patch_merging = PatchMerging(in_ch, out_ch, norm=norm, proj_type=proj_type,
                                          down_scale=down_scale, kernel_size=kernel_size)

        self.conv_blocks = nn.Sequential(*[
            conv_block(out_ch, out_ch, norm=norm, act=act, kernel_size=kernel_size)
            for _ in range(conv_num)
        ])

        self.trans_blocks = BasicLayer(
            out_ch, map_dim, out_ch, num_blocks=trans_num, heads=heads,
            dim_head=dim_head, norm=norm, act=act, expansion=expansion, attn_drop=attn_drop,
            proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, kernel_size=kernel_size,
        )

    def forward(self, x):
        """Return ``(features, semantic_map)`` for the down-sampled input."""
        merged = self.patch_merging(x)
        feat = self.conv_blocks(merged)

        semantic_map = self.map_gen(feat) if self.map_generate else None

        feat, semantic_map = self.trans_blocks(feat, semantic_map)

        return feat, semantic_map
class up_block(nn.Module):
    """Decoder stage: up-sample, concatenate with the skip feature, then run
    transformer blocks followed by convolutional blocks.
    """

    def __init__(self, in_ch, out_ch, conv_num, trans_num, up_scale=[2,2,2], kernel_size=[3,3,3],
                 conv_block=BasicBlock, heads=4, dim_head=64, expansion=4, attn_drop=0.,
                 proj_drop=0., map_size=[4,8,8], proj_type='depthwise', norm=nn.BatchNorm3d,
                 act=nn.GELU, map_dim=None, map_shortcut=False, no_map_out=False):
        super().__init__()

        self.map_shortcut = map_shortcut
        if map_dim is None:
            map_dim = out_ch

        # With a map shortcut, the concatenated maps are reduced to map_dim
        # channels; otherwise the incoming map is used unchanged.
        if map_shortcut:
            self.map_reduction = nn.Conv3d(in_ch+out_ch, map_dim, kernel_size=1, bias=False)
        else:
            self.map_reduction = nn.Identity()

        self.trans_blocks = BasicLayer(
            in_ch+out_ch, map_dim, out_ch, num_blocks=trans_num,
            heads=heads, dim_head=dim_head, norm=norm, act=act, expansion=expansion,
            attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size,
            proj_type=proj_type, kernel_size=kernel_size, no_map_out=no_map_out,
        )

        # Without transformer blocks the first conv block sees the raw
        # concatenation; otherwise it sees the transformer output.
        first_in = in_ch+out_ch if trans_num == 0 else out_ch

        self.conv_blocks = nn.Sequential(*[
            conv_block(first_in if idx == 0 else out_ch, out_ch,
                       kernel_size=kernel_size, norm=norm, act=act)
            for idx in range(conv_num)
        ])

    def forward(self, x1, x2, map1, map2=None):
        """Decode one stage.

        Args:
            x1: low-res feature.
            x2: high-res feature shortcut from the encoder.
            map1: semantic map from the previous low-res layer.
            map2: semantic map from the encoder shortcut; may be None when the
                encoder provides no map.

        Returns:
            Tuple ``(features, semantic_map)``.
        """
        upsampled = F.interpolate(x1, size=x2.shape[-3:], mode='trilinear', align_corners=True)
        feat = torch.cat([upsampled, x2], dim=1)

        use_shortcut = self.map_shortcut and map2 is not None
        if use_shortcut:
            semantic_map = self.map_reduction(torch.cat([map1, map2], dim=1))
        else:
            semantic_map = map1

        feat, semantic_map = self.trans_blocks(feat, semantic_map)
        feat = self.conv_blocks(feat)

        return feat, semantic_map