from einops import rearrange
from copy import deepcopy
from .nnformer_utils import softmax_helper
from torch import nn
import torch
import numpy as np
from .nnformer_utils import InitWeights_He
from .nnformer_utils import SegmentationNetwork
import torch.nn.functional
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_3tuple, trunc_normal_
class ContiguousGrad(torch.autograd.Function):
    """Identity autograd op that makes the incoming gradient contiguous.

    The forward pass hands its input back untouched; the backward pass
    returns ``grad_out.contiguous()`` so the op upstream receives a
    contiguous gradient tensor.
    """

    @staticmethod
    def forward(ctx, x):
        # Nothing is computed and nothing is saved on ``ctx``.
        return x

    @staticmethod
    def backward(ctx, grad_out):
        contiguous_grad = grad_out.contiguous()
        return contiguous_grad
class Mlp(nn.Module):
    """Multilayer perceptron: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features: size of the last input dimension.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability used after the activation and after fc2.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features if out_features else in_features
        hidden_features = hidden_features if hidden_features else in_features
        # Keep this registration order: it fixes the parameter order and the
        # sequence of RNG draws made during initialization.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def window_partition(x, window_size):
    """Cut a channels-last volume into non-overlapping cubic windows.

    Args:
        x: tensor of shape (B, S, H, W, C); S, H and W must be multiples of
            ``window_size`` for the reshape to succeed.
        window_size: edge length of every cubic window.

    Returns:
        Tensor of shape (num_windows * B, window_size, window_size, window_size, C),
        ordered batch-major, then by window position along S, H, W.
    """
    ws = window_size
    B, S, H, W, C = x.shape
    blocks = x.view(B, S // ws, ws, H // ws, ws, W // ws, ws, C)
    # Group the three window-grid axes together, followed by the three in-window axes.
    blocks = blocks.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()
    return blocks.view(-1, ws, ws, ws, C)
def window_reverse(windows, window_size, S, H, W):
    """Stitch cubic windows back into a channels-last volume.

    Inverse of ``window_partition``.

    Args:
        windows: tensor of shape (B * nS * nH * nW, ws, ws, ws, C), where
            nS = S // ws, nH = H // ws and nW = W // ws.
        window_size: window edge length ``ws``.
        S, H, W: spatial size of the reassembled volume.

    Returns:
        Tensor of shape (B, S, H, W, C).

    Raises:
        ValueError: if S, H or W is not a multiple of ``window_size``.
    """
    ws = window_size
    if S % ws or H % ws or W % ws:
        raise ValueError(f"S, H, W ({S}, {H}, {W}) must be multiples of window_size {ws}")
    n_s, n_h, n_w = S // ws, H // ws, W // ws
    # Exact integer arithmetic: no float round-trip when recovering the batch size.
    B = windows.shape[0] // (n_s * n_h * n_w)
    x = windows.view(B, n_s, n_h, n_w, ws, ws, ws, -1)
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1)
    return x
class WindowAttention_kv(nn.Module):
    """Window attention with keys/values projected from ``skip`` and queries taken from ``x_up``.

    A learned relative-position bias (one value per head for every 3-D offset
    inside a window) is added to the attention logits, and an optional
    additive mask is applied for shifted windows.

    Args:
        dim: channel count C of the tokens.
        window_size: (Ws, Wh, Ww) window extent; N = Ws * Wh * Ww tokens per window.
        num_heads: number of attention heads; C must be divisible by it.
        qkv_bias: add a bias term to the key/value projection.
        qk_scale: override for the default ``head_dim ** -0.5`` query scale.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # One row per distinct relative offset, one column per head.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1),
                        num_heads))
        # Pair-wise lookup index into the table. The buffer is persistent, so a
        # checkpoint restores the index it was saved with.
        self.register_buffer("relative_position_index", self._relative_position_index(self.window_size))
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        trunc_normal_(self.relative_position_bias_table, std=.02)

    @staticmethod
    def _relative_position_index(window_size):
        """Return the (N, N) index mapping each token pair to its bias-table row.

        The per-axis offset between tokens i and j is shifted into
        [0, 2 * size - 2] and flattened in row-major order. As a result,
        every distinct 3-D offset gets a distinct row of the
        (2Ws-1) * (2Wh-1) * (2Ww-1)-row table.
        """
        ws_s, ws_h, ws_w = window_size[0], window_size[1], window_size[2]
        coords = torch.stack(torch.meshgrid([torch.arange(ws_s), torch.arange(ws_h), torch.arange(ws_w)]))
        coords_flatten = torch.flatten(coords, 1)  # 3, N
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, N, N
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # N, N, 3
        relative_coords[:, :, 0] += ws_s - 1  # shift to start from 0
        relative_coords[:, :, 1] += ws_h - 1
        relative_coords[:, :, 2] += ws_w - 1
        # Row-major strides over the shifted per-axis ranges.
        relative_coords[:, :, 0] *= (2 * ws_h - 1) * (2 * ws_w - 1)
        relative_coords[:, :, 1] *= 2 * ws_w - 1
        return relative_coords.sum(-1)

    def forward(self, skip, x_up, pos_embed=None, mask=None):
        """Attend from ``x_up`` queries to ``skip`` keys/values within each window.

        Args:
            skip: (B_, N, C) tokens used for keys and values.
            x_up: query tokens, reshaped to (B_, heads, N, C // heads); they must
                hold the same number of elements as ``skip``.
            pos_embed: optional tensor added to the attention output before projection.
            mask: optional (nW, N, N) additive mask; B_ must be divisible by nW.

        Returns:
            (B_, N, C) tensor.
        """
        B_, N, C = skip.shape
        head_dim = C // self.num_heads
        kv = self.kv(skip).reshape(B_, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4).contiguous()
        q = x_up.reshape(B_, N, self.num_heads, head_dim).permute(0, 2, 1, 3).contiguous()
        k, v = kv[0], kv[1]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1).contiguous()
        window_volume = self.window_size[0] * self.window_size[1] * self.window_size[2]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            window_volume, window_volume, -1)
        attn = attn + relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.attn_drop(self.softmax(attn))
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous()
        if pos_embed is not None:
            x = x + pos_embed
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class WindowAttention(nn.Module):
    """Multi-head self-attention inside 3-D windows with a relative-position bias.

    A learned relative-position bias (one value per head for every 3-D offset
    inside a window) is added to the attention logits, and an optional
    additive mask is applied for shifted windows.

    Args:
        dim: channel count C of the tokens.
        window_size: (Ws, Wh, Ww) window extent; N = Ws * Wh * Ww tokens per window.
        num_heads: number of attention heads; C must be divisible by it.
        qkv_bias: add a bias term to the qkv projection.
        qk_scale: override for the default ``head_dim ** -0.5`` query scale.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # One row per distinct relative offset, one column per head.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1),
                        num_heads))
        # Pair-wise lookup index into the table. The buffer is persistent, so a
        # checkpoint restores the index it was saved with.
        self.register_buffer("relative_position_index", self._relative_position_index(self.window_size))
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _relative_position_index(window_size):
        """Return the (N, N) index mapping each token pair to its bias-table row.

        The per-axis offset between tokens i and j is shifted into
        [0, 2 * size - 2] and flattened in row-major order. As a result,
        every distinct 3-D offset gets a distinct row of the
        (2Ws-1) * (2Wh-1) * (2Ww-1)-row table.
        """
        ws_s, ws_h, ws_w = window_size[0], window_size[1], window_size[2]
        coords = torch.stack(torch.meshgrid([torch.arange(ws_s), torch.arange(ws_h), torch.arange(ws_w)]))
        coords_flatten = torch.flatten(coords, 1)  # 3, N
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, N, N
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # N, N, 3
        relative_coords[:, :, 0] += ws_s - 1  # shift to start from 0
        relative_coords[:, :, 1] += ws_h - 1
        relative_coords[:, :, 2] += ws_w - 1
        # Row-major strides over the shifted per-axis ranges.
        relative_coords[:, :, 0] *= (2 * ws_h - 1) * (2 * ws_w - 1)
        relative_coords[:, :, 1] *= 2 * ws_w - 1
        return relative_coords.sum(-1)

    def forward(self, x, mask=None, pos_embed=None):
        """Self-attention within each window.

        Args:
            x: (B_, N, C) tokens, N being the window volume.
            mask: optional (nW, N, N) additive mask; B_ must be divisible by nW.
            pos_embed: optional tensor added to the attention output before projection.

        Returns:
            (B_, N, C) tensor.
        """
        B_, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1).contiguous()
        window_volume = self.window_size[0] * self.window_size[1] * self.window_size[2]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            window_volume, window_volume, -1)
        attn = attn + relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.attn_drop(self.softmax(attn))
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous()
        if pos_embed is not None:
            x = x + pos_embed
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class PatchMerging(nn.Module):
    """Downsample a token sequence with GELU, normalization and a strided convolution.

    The (B, S*H*W, C) tokens are viewed as a (B, S, H, W, C) volume, passed
    through GELU and ``norm_layer``, then convolved with a 3x3x3 kernel,
    stride 2 and padding 1 to 2*C channels. They are returned as
    (B, L_out, 2*C) tokens.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Conv3d(dim, dim * 2, kernel_size=3, stride=2, padding=1)
        self.norm = norm_layer(dim)

    def forward(self, x, S, H, W):
        B, L, C = x.shape
        assert L == H * W * S, "input feature has wrong size"
        vol = self.norm(F.gelu(x.view(B, S, H, W, C)))
        # Channels-first for Conv3d, then back to channels-last tokens.
        vol = self.reduction(vol.permute(0, 4, 1, 2, 3).contiguous())
        return vol.permute(0, 2, 3, 4, 1).contiguous().view(B, -1, 2 * C)
class Patch_Expanding(nn.Module):
    """Upsample a token sequence by 2 along every spatial axis, halving channels.

    The (B, S*H*W, C) tokens are normalized with ``norm_layer`` and viewed as a
    (B, C, S, H, W) volume. They go through a kernel-2, stride-2 transposed
    convolution to C // 2 channels and are flattened back to
    (B, L_out, C // 2) tokens.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.norm = norm_layer(dim)
        self.up = nn.ConvTranspose3d(dim, dim // 2, 2, 2)

    def forward(self, x, S, H, W):
        B, L, C = x.shape
        assert L == H * W * S, "input feature has wrong size"
        vol = self.norm(x.view(B, S, H, W, C)).permute(0, 4, 1, 2, 3).contiguous()
        # ContiguousGrad is an identity whose backward makes the gradient contiguous.
        vol = ContiguousGrad.apply(self.up(vol))
        return vol.permute(0, 2, 3, 4, 1).contiguous().view(B, -1, C // 2)
class BasicLayer(nn.Module):
    """One encoder stage: ``depth`` Swin blocks followed by optional downsampling.

    Even-indexed blocks use regular windows (shift 0). Odd-indexed blocks use
    windows shifted by ``window_size // 2``. Every block receives the same
    shifted-window attention mask.

    Args:
        dim: token channel count.
        input_resolution: forwarded unchanged to every SwinTransformerBlock.
        depth: number of Swin blocks.
        num_heads: attention heads per block.
        window_size: cubic window edge length.
        mlp_ratio, qkv_bias, qk_scale, drop, attn_drop: forwarded to each block.
        drop_path: a scalar, or a list with one stochastic-depth rate per block.
        norm_layer: normalization class for the blocks and the downsampler.
        downsample: module class built as ``downsample(dim=dim, norm_layer=norm_layer)``,
            or None (the default) for no downsampling.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None
                 ):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
            for i in range(depth)])
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def _build_attn_mask(self, S, H, W, device):
        """Return the (num_windows, N, N) additive mask for shifted-window attention.

        The volume is padded up to multiples of ``window_size`` and labelled
        with 27 regions (3 per axis, split at the shift boundaries). Token
        pairs that share a window but come from different regions get -100;
        all other pairs get 0. N = window_size ** 3.
        """
        ws, shift = self.window_size, self.shift_size
        # Integer ceiling to the next multiple of the window size.
        Sp = -(-S // ws) * ws
        Hp = -(-H // ws) * ws
        Wp = -(-W // ws) * ws
        img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=device)
        # The same three regions are used along every axis.
        region_slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
        cnt = 0
        for s in region_slices:
            for h in region_slices:
                for w in region_slices:
                    img_mask[:, s, h, w, :] = cnt
                    cnt += 1
        mask_windows = window_partition(img_mask, ws).view(-1, ws * ws * ws)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    def forward(self, x, S, H, W):
        """Run the blocks and the optional downsampler.

        Returns:
            ``(x, S, H, W, x_down, Ws, Wh, Ww)``. Without a downsampler, the
            second group repeats the first.
        """
        attn_mask = self._build_attn_mask(S, H, W, x.device)
        for blk in self.blocks:
            x = blk(x, attn_mask)
        if self.downsample is None:
            return x, S, H, W, x, S, H, W
        x_down = self.downsample(x, S, H, W)
        # Ceil-halving matches PatchMerging's kernel-3, stride-2, padding-1 conv.
        Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2
        return x, S, H, W, x_down, Ws, Wh, Ww
class BasicLayer_up(nn.Module):
    """One decoder stage: upsample, add the skip connection, then run Swin blocks.

    The first block is a SwinTransformerBlock_kv (shift 0). It additionally
    receives the skip tokens and the upsampled tokens. The remaining
    ``depth - 1`` blocks are shifted SwinTransformerBlocks.

    Args:
        dim: token channel count after upsampling; the upsampler is built for ``2 * dim``.
        input_resolution: forwarded unchanged to every block.
        depth: total number of blocks, including the first one.
        num_heads: attention heads per block.
        window_size: cubic window edge length.
        mlp_ratio, qkv_bias, qk_scale, drop, attn_drop: forwarded to each block.
        drop_path: a scalar, or a list with one stochastic-depth rate per block.
        norm_layer: normalization class for the blocks and the upsampler.
        upsample: module class built as ``upsample(dim=2 * dim, norm_layer=norm_layer)``.
            It is required because ``forward`` always uses it.

    Raises:
        TypeError: if ``upsample`` is not callable (e.g. the default ``True``).
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 upsample=True
                 ):
        super().__init__()
        if not callable(upsample):
            raise TypeError("upsample must be a module class such as Patch_Expanding")
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        # build blocks
        self.blocks = nn.ModuleList()
        self.blocks.append(
            SwinTransformerBlock_kv(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[0] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
        )
        for i in range(depth - 1):
            self.blocks.append(
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i + 1] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
            )
        self.Upsample = upsample(dim=2 * dim, norm_layer=norm_layer)

    def _build_attn_mask(self, S, H, W, device):
        """Return the (num_windows, N, N) additive mask for shifted-window attention.

        The volume is padded up to multiples of ``window_size`` and labelled
        with 27 regions (3 per axis, split at the shift boundaries). Token
        pairs that share a window but come from different regions get -100;
        all other pairs get 0. N = window_size ** 3.
        """
        ws, shift = self.window_size, self.shift_size
        # Integer ceiling to the next multiple of the window size.
        Sp = -(-S // ws) * ws
        Hp = -(-H // ws) * ws
        Wp = -(-W // ws) * ws
        img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=device)
        # The same three regions are used along every axis.
        region_slices = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
        cnt = 0
        for s in region_slices:
            for h in region_slices:
                for w in region_slices:
                    img_mask[:, s, h, w, :] = cnt
                    cnt += 1
        # N = window_size ** 3 tokens per window.
        mask_windows = window_partition(img_mask, ws).view(-1, ws * ws * ws)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    def forward(self, x, skip, S, H, W):
        """Upsample ``x`` (grid S x H x W), fuse it with ``skip``, and run the blocks.

        Returns:
            ``(x, S, H, W)`` with the spatial sizes doubled.
        """
        x_up = self.Upsample(x, S, H, W)
        x = x_up + skip
        S, H, W = S * 2, H * 2, W * 2
        attn_mask = self._build_attn_mask(S, H, W, x.device)
        # The first block also takes the skip tokens and the upsampled tokens.
        x = self.blocks[0](x, attn_mask, skip=skip, x_up=x_up)
        for blk in self.blocks[1:]:
            x = blk(x, attn_mask)
        return x, S, H, W
class project(nn.Module):
    """Two 3x3x3 convolutions with activation and channel-wise normalization.

    Sequence: conv1 (given stride/padding) -> activation -> norm1. Then
    conv2 (stride 1, padding 1). Unless ``last`` is set, conv2 is followed by
    activation -> norm2. The norms are applied over the channel axis by
    temporarily flattening the volume to (B, S*H*W, C).

    Args:
        in_dim: input channels.
        out_dim: output channels of both convolutions.
        stride: stride of conv1.
        padding: padding of conv1.
        activate: activation class, instantiated once with no arguments.
        norm: normalization class, built with ``out_dim``.
        last: if True, skip the trailing activation and norm2.
    """

    def __init__(self, in_dim, out_dim, stride, padding, activate, norm, last=False):
        super().__init__()
        # Construction order is kept as-is: it determines parameter order and RNG draws.
        self.out_dim = out_dim
        self.conv1 = nn.Conv3d(in_dim, out_dim, kernel_size=3, stride=stride, padding=padding)
        self.conv2 = nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=1, padding=1)
        self.activate = activate()
        self.norm1 = norm(out_dim)
        self.last = last
        if not last:
            self.norm2 = norm(out_dim)

    def _channel_norm(self, x, norm_module):
        """Apply ``norm_module`` over the channel axis of a (B, C, S, H, W) tensor."""
        s, h, w = x.size(2), x.size(3), x.size(4)
        tokens = norm_module(x.flatten(2).transpose(1, 2).contiguous())
        return tokens.transpose(1, 2).contiguous().view(-1, self.out_dim, s, h, w)

    def forward(self, x):
        x = self._channel_norm(self.activate(self.conv1(x)), self.norm1)
        x = self.conv2(x)
        if self.last:
            return x
        return self._channel_norm(self.activate(x), self.norm2)
class PatchEmbed(nn.Module):
    """Convolutional patch embedding for 3-D volumes.

    Two ``project`` stages are applied. The first uses stride
    (p0, p1 // 2, p2 // 2) and the second uses stride
    (p0 // 2, p1 // 2, p2 // 2), where (p0, p1, p2) = patch_size. Inputs
    whose spatial size is not a multiple of ``patch_size`` are zero-padded at
    the end of each axis first.

    Args:
        patch_size: int or 3-sequence.
        in_chans: input channels.
        embed_dim: output channels (the first stage produces embed_dim // 2).
        norm_layer: optional channel normalization applied to the output.
    """

    def __init__(self, patch_size=4, in_chans=4, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_3tuple(patch_size)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        p_s, p_h, p_w = patch_size
        stride1 = [p_s, p_h // 2, p_w // 2]
        stride2 = [p_s // 2, p_h // 2, p_w // 2]
        self.proj1 = project(in_chans, embed_dim // 2, stride1, 1, nn.GELU, nn.LayerNorm, False)
        self.proj2 = project(embed_dim // 2, embed_dim, stride2, 1, nn.GELU, nn.LayerNorm, True)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Pad to a multiple of patch_size, embed, and optionally normalize channels."""
        _, _, S, H, W = x.size()
        p_s, p_h, p_w = self.patch_size
        if W % p_w != 0:
            x = F.pad(x, (0, p_w - W % p_w))
        if H % p_h != 0:
            x = F.pad(x, (0, 0, 0, p_h - H % p_h))
        if S % p_s != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, p_s - S % p_s))
        x = self.proj2(self.proj1(x))  # B C Ws Wh Ww
        if self.norm is None:
            return x
        Ws, Wh, Ww = x.size(2), x.size(3), x.size(4)
        tokens = self.norm(x.flatten(2).transpose(1, 2).contiguous())
        return tokens.transpose(1, 2).contiguous().view(-1, self.embed_dim, Ws, Wh, Ww)
class Encoder(nn.Module):
    """Hierarchical Swin encoder: patch embedding followed by ``len(depths)`` stages.

    Stage i works at ``embed_dim * 2 ** i`` channels, and every stage except
    the last ends with PatchMerging. ``forward`` returns one
    (B, C_i, S_i, H_i, W_i) feature map for each stage index in
    ``out_indices``.

    Args:
        pretrain_img_size: int or 3-sequence (S, H, W); used only to derive the
            per-stage ``input_resolution`` passed to the blocks.
        patch_size: int or 3-sequence, patch size of the embedding.
        in_chans: input channels.
        embed_dim: channels after patch embedding.
        depths: number of Swin blocks per stage.
        num_heads: attention heads per stage.
        window_size: int shared by all stages, or one window size per stage.
        mlp_ratio, qkv_bias, qk_scale: forwarded to the blocks.
        drop_rate: dropout after patch embedding and inside the blocks.
        attn_drop_rate: attention dropout inside the blocks.
        drop_path_rate: final stochastic-depth rate; rates rise linearly from 0 over all blocks.
        norm_layer: normalization class.
        patch_norm: normalize the patch embedding output.
        out_indices: stages whose (normalized) outputs are returned.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=1,
                 embed_dim=96,
                 depths=(2, 2, 2, 2),
                 num_heads=(4, 8, 16, 32),
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3)
                 ):
        super().__init__()
        self.num_layers = len(depths)
        # The scalar defaults are indexed per axis / per stage below, so broadcast them.
        if isinstance(pretrain_img_size, int):
            pretrain_img_size = (pretrain_img_size,) * 3
        if isinstance(patch_size, int):
            patch_size = (patch_size,) * 3
        if isinstance(window_size, int):
            window_size = [window_size] * self.num_layers
        self.pretrain_img_size = pretrain_img_size
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            scale = 2 ** i_layer
            layer = BasicLayer(
                dim=int(embed_dim * scale),
                input_resolution=(
                    pretrain_img_size[0] // patch_size[0] // scale,
                    pretrain_img_size[1] // patch_size[1] // scale,
                    pretrain_img_size[2] // patch_size[2] // scale),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size[i_layer],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None
            )
            self.layers.append(layer)
        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features
        # add a norm layer for each output
        for i_layer in out_indices:
            self.add_module(f'norm{i_layer}', norm_layer(num_features[i_layer]))

    def forward(self, x):
        """Embed ``x`` and return the list of selected stage outputs (channels-first)."""
        x = self.patch_embed(x)
        down = []
        Ws, Wh, Ww = x.size(2), x.size(3), x.size(4)
        x = x.flatten(2).transpose(1, 2).contiguous()
        x = self.pos_drop(x)
        for i in range(self.num_layers):
            layer = self.layers[i]
            # x_out is the stage output at (S, H, W); x is the downsampled input for the next stage.
            x_out, S, H, W, x, Ws, Wh, Ww = layer(x, Ws, Wh, Ww)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)
                out = x_out.view(-1, S, H, W, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous()
                down.append(out)
        return down
class Decoder(nn.Module):
    """Swin decoder: a stack of BasicLayer_up stages consuming encoder skips.

    ``self.layers[j]`` is built for stage index ``i_layer = num_layers - 1 - j``
    with ``embed_dim * 2 ** j`` channels. ``forward`` walks ``self.layers``
    from the last entry to the first, pairing ``self.layers[i]`` with
    ``skips[i]``.

    Args:
        pretrain_img_size: int or 3-sequence (S, H, W); used only for the
            per-stage ``input_resolution``.
        embed_dim: base channel count.
        patch_size: int or 3-sequence.
        depths: blocks per stage.
        num_heads: attention heads per stage.
        window_size: int shared by all stages, or one window size per stage.
        mlp_ratio, qkv_bias, qk_scale: forwarded to the blocks.
        drop_rate: token dropout before the stages and inside the blocks.
        attn_drop_rate: attention dropout inside the blocks.
        drop_path_rate: final stochastic-depth rate; rates rise linearly from 0 over all blocks.
        norm_layer: normalization class.
    """

    def __init__(self,
                 pretrain_img_size,
                 embed_dim,
                 patch_size=4,
                 depths=(2, 2, 2),
                 num_heads=(24, 12, 6),
                 window_size=4,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm
                 ):
        super().__init__()
        self.num_layers = len(depths)
        # Scalars are indexed per axis / per stage below, so broadcast them.
        if isinstance(pretrain_img_size, int):
            pretrain_img_size = (pretrain_img_size,) * 3
        if isinstance(patch_size, int):
            patch_size = (patch_size,) * 3
        if isinstance(window_size, int):
            window_size = [window_size] * self.num_layers
        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        # build layers (deepest stage index first)
        self.layers = nn.ModuleList()
        for i_layer in reversed(range(self.num_layers)):
            scale = 2 ** (self.num_layers - i_layer - 1)
            layer = BasicLayer_up(
                dim=int(embed_dim * scale),
                input_resolution=(
                    pretrain_img_size[0] // patch_size[0] // scale,
                    pretrain_img_size[1] // patch_size[1] // scale,
                    pretrain_img_size[2] // patch_size[2] // scale),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size[i_layer],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                upsample=Patch_Expanding
            )
            self.layers.append(layer)
        self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]

    def forward(self, x, skips):
        """Decode ``x`` (B, C, S, H, W) using the (B, C_i, S_i, H_i, W_i) tensors in ``skips``.

        The caller's ``skips`` list is not modified.

        Returns:
            List of channels-last stage outputs, one per stage, in processing order.
        """
        outs = []
        S, H, W = x.size(2), x.size(3), x.size(4)
        x = x.flatten(2).transpose(1, 2).contiguous()
        # Build a new list instead of overwriting the caller's entries.
        flat_skips = [skip.flatten(2).transpose(1, 2).contiguous() for skip in skips]
        x = self.pos_drop(x)
        for i in reversed(range(self.num_layers)):
            x, S, H, W = self.layers[i](x, flat_skips[i], S, H, W)
            outs.append(x.view(-1, S, H, W, self.num_features[i]))
        return outs
class final_patch_expanding(nn.Module):
    """Project channels-last features to class logits at input resolution.

    A transposed convolution with kernel = stride = ``patch_size`` maps ``dim``
    channels to ``num_class`` channels. The input (B, S, H, W, dim) is made
    channels-first before the convolution.
    """

    def __init__(self, dim, num_class, patch_size):
        super().__init__()
        self.up = nn.ConvTranspose3d(dim, num_class, patch_size, patch_size)

    def forward(self, x):
        channels_first = x.permute(0, 4, 1, 2, 3).contiguous()
        return self.up(channels_first)