# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import itertools
from collections.abc import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import LayerNorm
from typing_extensions import Final
from monai.networks.blocks import MLPBlock as Mlp
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
from monai.networks.layers import DropPath, trunc_normal_
from monai.utils import ensure_tuple_rep, look_up_option, optional_import
from monai.utils.deprecate_utils import deprecated_arg
# `rearrange` comes from the optional `einops` package; only the first element of the pair returned by
# `optional_import` is kept.
# NOTE(review): presumably a missing `einops` only raises once `rearrange` is actually called --
# confirm against `monai.utils.optional_import`.
rearrange, _ = optional_import("einops", name="rearrange")
# Names exported by a star-import of this module. `get_window_size`, `compute_mask` and
# `filter_swinunetr` are defined in this file but are not listed here.
__all__ = [
"SwinUNETR",
"window_partition",
"window_reverse",
"WindowAttention",
"SwinTransformerBlock",
"PatchMerging",
"PatchMergingV2",
"MERGING_MODE",
"BasicLayer",
"SwinTransformer",
]
[docs]
class SwinUNETR(nn.Module):
    """
    Swin UNETR based on: "Hatamizadeh et al.,
    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
    <https://arxiv.org/abs/2201.01266>"
    """

    patch_size: Final[int] = 2

    @deprecated_arg(
        name="img_size",
        since="1.3",
        removed="1.5",
        msg_suffix="The img_size argument is not required anymore and "
        "checks on the input size are run during forward().",
    )
    def __init__(
        self,
        img_size: Sequence[int] | int,
        in_channels: int,
        out_channels: int,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        feature_size: int = 24,
        norm_name: tuple | str = "instance",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        dropout_path_rate: float = 0.0,
        normalize: bool = True,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        downsample="merging",
        use_v2: bool = False,
    ) -> None:
        """
        Args:
            img_size: spatial dimension of input image.
                This argument is only used for checking that the input image size is divisible by the patch size.
                The tensor passed to forward() can have a dynamic shape as long as its spatial dimensions are
                divisible by 2**5. It will be removed in an upcoming version.
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            depths: number of layers in each stage.
            num_heads: number of attention heads.
            feature_size: dimension of network feature size. Must be divisible by 12.
            norm_name: feature normalization type and arguments.
            drop_rate: dropout rate, in [0, 1].
            attn_drop_rate: attention dropout rate, in [0, 1].
            dropout_path_rate: drop path rate, in [0, 1].
            normalize: normalize output intermediate features in each stage.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
            spatial_dims: number of spatial dims, 2 or 3.
            downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
                user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
                The default is currently `"merging"` (the original version defined in v0.9.0).
            use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage.

        Raises:
            ValueError: if `spatial_dims` is not 2 or 3, if `img_size` is not divisible by ``patch_size**5``,
                if any of the drop rates is outside [0, 1], or if `feature_size` is not divisible by 12.

        Examples::

            # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48.
            >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48)

            # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage.
            >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))

            # for 2D 3-channel input with size (96,96), 2-channel output and gradient checkpointing.
            >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)

        """
        super().__init__()

        # Validate the rank first so that an unsupported value is reported as such, rather than as a
        # tuple-length mismatch coming out of `ensure_tuple_rep`.
        if spatial_dims not in (2, 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
        window_size = ensure_tuple_rep(7, spatial_dims)

        self._check_input_size(img_size)

        if not (0 <= drop_rate <= 1):
            raise ValueError("dropout rate should be between 0 and 1.")

        if not (0 <= attn_drop_rate <= 1):
            raise ValueError("attention dropout rate should be between 0 and 1.")

        if not (0 <= dropout_path_rate <= 1):
            raise ValueError("drop path rate should be between 0 and 1.")

        if feature_size % 12 != 0:
            raise ValueError("feature_size should be divisible by 12.")

        self.normalize = normalize

        self.swinViT = SwinTransformer(
            in_chans=in_channels,
            embed_dim=feature_size,
            window_size=window_size,
            patch_size=patch_sizes,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dropout_path_rate,
            norm_layer=nn.LayerNorm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
            # strings are resolved against MERGING_MODE, anything else is forwarded untouched
            downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
            use_v2=use_v2,
        )

        # Keyword arguments shared by every encoder / decoder block. The attributes below are assigned in
        # the same order as before so that the module registration order does not change.
        enc_kwargs = dict(spatial_dims=spatial_dims, kernel_size=3, stride=1, norm_name=norm_name, res_block=True)
        dec_kwargs = dict(
            spatial_dims=spatial_dims, kernel_size=3, upsample_kernel_size=2, norm_name=norm_name, res_block=True
        )

        self.encoder1 = UnetrBasicBlock(in_channels=in_channels, out_channels=feature_size, **enc_kwargs)
        self.encoder2 = UnetrBasicBlock(in_channels=feature_size, out_channels=feature_size, **enc_kwargs)
        self.encoder3 = UnetrBasicBlock(in_channels=2 * feature_size, out_channels=2 * feature_size, **enc_kwargs)
        self.encoder4 = UnetrBasicBlock(in_channels=4 * feature_size, out_channels=4 * feature_size, **enc_kwargs)
        self.encoder10 = UnetrBasicBlock(in_channels=16 * feature_size, out_channels=16 * feature_size, **enc_kwargs)

        self.decoder5 = UnetrUpBlock(in_channels=16 * feature_size, out_channels=8 * feature_size, **dec_kwargs)
        self.decoder4 = UnetrUpBlock(in_channels=feature_size * 8, out_channels=feature_size * 4, **dec_kwargs)
        self.decoder3 = UnetrUpBlock(in_channels=feature_size * 4, out_channels=feature_size * 2, **dec_kwargs)
        self.decoder2 = UnetrUpBlock(in_channels=feature_size * 2, out_channels=feature_size, **dec_kwargs)
        self.decoder1 = UnetrUpBlock(in_channels=feature_size, out_channels=feature_size, **dec_kwargs)

        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels)

    def load_from(self, weights):
        """
        Copy pretrained Swin encoder weights into `self.swinViT` in place, without tracking gradients.

        Args:
            weights: mapping with a ``"state_dict"`` entry whose keys are prefixed with ``"module."``
                (e.g. ``"module.patch_embed.proj.weight"``, ``"module.layers1.0.downsample.norm.bias"``).

        Raises:
            KeyError: if an expected key is missing from ``weights["state_dict"]``.
        """
        state_dict = weights["state_dict"]
        with torch.no_grad():
            patch_proj = self.swinViT.patch_embed.proj
            patch_proj.weight.copy_(state_dict["module.patch_embed.proj.weight"])
            patch_proj.bias.copy_(state_dict["module.patch_embed.proj.bias"])

            for layer_name in ("layers1", "layers2", "layers3", "layers4"):
                stage = getattr(self.swinViT, layer_name)[0]
                # each transformer block knows how to pick its own entries out of `weights`
                for bname, block in stage.blocks.named_children():
                    block.load_from(weights, n_block=bname, layer=layer_name)
                # NOTE(review): assumes the downsample module exposes `reduction` and `norm`, as the
                # patch-merging classes in this file do -- a custom downsample module may not.
                prefix = f"module.{layer_name}.0.downsample"
                stage.downsample.reduction.weight.copy_(state_dict[f"{prefix}.reduction.weight"])
                stage.downsample.norm.weight.copy_(state_dict[f"{prefix}.norm.weight"])
                stage.downsample.norm.bias.copy_(state_dict[f"{prefix}.norm.bias"])

    @torch.jit.unused
    def _check_input_size(self, spatial_shape):
        """
        Check that every entry of `spatial_shape` is divisible by ``patch_size**5``.

        Args:
            spatial_shape: spatial sizes of the input (no batch / channel entries).

        Raises:
            ValueError: listing the offending dimensions. The reported indices are shifted by 2,
                presumably so that they refer to axes of the full (batch, channel, ...) tensor.
        """
        img_size = np.array(spatial_shape)
        remainder = (img_size % np.power(self.patch_size, 5)) > 0
        if remainder.any():
            wrong_dims = (np.where(remainder)[0] + 2).tolist()
            raise ValueError(
                f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
                f" must be divisible by {self.patch_size}**5."
            )

    def forward(self, x_in):
        """
        Run the Swin encoder followed by the convolutional decoder.

        Args:
            x_in: input tensor; ``x_in.shape[2:]`` is treated as the spatial shape.

        Returns:
            The output of `self.out` applied to the last decoder feature map.
        """
        # the size check relies on numpy and is therefore skipped when running under TorchScript
        if not torch.jit.is_scripting():
            self._check_input_size(x_in.shape[2:])
        hidden_states_out = self.swinViT(x_in, self.normalize)
        enc0 = self.encoder1(x_in)
        enc1 = self.encoder2(hidden_states_out[0])
        enc2 = self.encoder3(hidden_states_out[1])
        enc3 = self.encoder4(hidden_states_out[2])
        # the deepest hidden state is the bottleneck; hidden state 3 is used directly as a skip connection
        dec4 = self.encoder10(hidden_states_out[4])
        dec3 = self.decoder5(dec4, hidden_states_out[3])
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        out = self.decoder1(dec0, enc0)
        logits = self.out(out)
        return logits
[docs]
def window_partition(x, window_size):
    """window partition operation based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        x: channel-last input tensor, either 5D ``(b, d, h, w, c)`` or 4D ``(b, h, w, c)``.
            Each spatial size has to be a multiple of the matching window size for the reshape to succeed.
        window_size: local window size, one entry per spatial dimension.

    Returns:
        Tensor of shape ``(num_windows * b, prod(window_size), c)``.

    Raises:
        ValueError: if `x` is neither 4D nor 5D.
    """
    x_shape = x.size()
    if len(x_shape) == 5:
        b, d, h, w, c = x_shape
        x = x.view(
            b,
            d // window_size[0],
            window_size[0],
            h // window_size[1],
            window_size[1],
            w // window_size[2],
            window_size[2],
            c,
        )
        # bring the three window-count axes together, followed by the three in-window axes
        windows = (
            x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
        )
    elif len(x_shape) == 4:
        b, h, w, c = x_shape
        x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
    else:
        # previously fell through and failed with an UnboundLocalError on `windows`
        raise ValueError(f"expecting 4D or 5D x, got {x.shape}.")
    return windows
[docs]
def window_reverse(windows, window_size, dims):
    """window reverse operation based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        windows: windows tensor whose leading axes enumerate the windows of every batch element.
        window_size: local window size, one entry per spatial dimension.
        dims: dimension values of the target tensor without the channel axis,
            ``(b, d, h, w)`` for 3D or ``(b, h, w)`` for 2D.

    Returns:
        Channel-last tensor of shape ``(b, d, h, w, c)`` or ``(b, h, w, c)``.

    Raises:
        ValueError: if `dims` has neither 3 nor 4 entries.
    """
    if len(dims) == 4:
        b, d, h, w = dims
        x = windows.view(
            b,
            d // window_size[0],
            h // window_size[1],
            w // window_size[2],
            window_size[0],
            window_size[1],
            window_size[2],
            -1,
        )
        # interleave each window-count axis with its in-window axis before merging them
        x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1)
    elif len(dims) == 3:
        b, h, w = dims
        x = windows.view(b, h // window_size[0], w // window_size[1], window_size[0], window_size[1], -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    else:
        # previously fell through and failed with an UnboundLocalError on `x`
        raise ValueError(f"expecting dims of length 3 or 4, got {tuple(dims)}.")
    return x
def get_window_size(x_size, window_size, shift_size=None):
    """Computing window size based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Along every axis where the input is not larger than the window, the window is shrunk to the input
    size and (when given) the shift is set to 0.

    Args:
        x_size: input size.
        window_size: local window size.
        shift_size: window shifting size.

    Returns:
        The adjusted window size as a tuple, or a ``(window_size, shift_size)`` pair of tuples when
        `shift_size` is not None.
    """
    effective_window = list(window_size)
    effective_shift = None if shift_size is None else list(shift_size)

    for axis, extent in enumerate(x_size):
        if extent > window_size[axis]:
            continue
        effective_window[axis] = extent
        if effective_shift is not None:
            effective_shift[axis] = 0

    if effective_shift is None:
        return tuple(effective_window)
    return tuple(effective_window), tuple(effective_shift)
[docs]
class WindowAttention(nn.Module):
    """
    Window based multi-head self attention module with relative position bias based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Sequence[int],
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            num_heads: number of attention heads.
            window_size: local window size (2 or 3 entries).
            qkv_bias: add a learnable bias to query, key, value.
            attn_drop: attention dropout rate.
            proj_drop: dropout rate of output.

        Raises:
            ValueError: if `window_size` does not have 2 or 3 entries.
        """
        super().__init__()
        if len(window_size) not in (2, 3):
            # previously surfaced as an UnboundLocalError further down
            raise ValueError(f"window_size should have 2 or 3 entries, got {len(window_size)}.")

        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        # one bias row per distinct relative offset, i.e. prod(2 * size - 1) rows
        table_size = 1
        for size in self.window_size:
            table_size *= 2 * size - 1
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_size, num_heads))
        self.register_buffer("relative_position_index", self._relative_position_index(self.window_size))

        # the layers below are created in the same order as before so that random initialisation
        # under a fixed seed is unchanged
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _relative_position_index(window_size):
        """Return the (n, n) table mapping each pair of in-window positions to a row of the bias table."""
        axes = [torch.arange(size) for size in window_size]
        # NOTE(review): `__kwdefaults__` is None when `torch.meshgrid` declares no keyword-only
        # arguments, presumably the case for older torch releases without `indexing` -- confirm.
        if torch.meshgrid.__kwdefaults__ is not None:
            coords = torch.stack(torch.meshgrid(*axes, indexing="ij"))
        else:
            coords = torch.stack(torch.meshgrid(*axes))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # shift every offset into the non-negative range [0, 2 * size - 2]
        for axis, size in enumerate(window_size):
            relative_coords[:, :, axis] += size - 1
        # row-major flattening of the offsets: the last axis has stride 1, each earlier axis is
        # scaled by the number of distinct offsets of all axes that follow it
        stride = 1
        for axis in reversed(range(len(window_size))):
            relative_coords[:, :, axis] *= stride
            stride *= 2 * window_size[axis] - 1
        return relative_coords.sum(-1)

    def forward(self, x, mask):
        """
        Args:
            x: window features of shape ``(b, n, c)``.
            mask: optional attention mask added to the attention logits before the softmax, or None.
                Its first dimension is taken as the number of windows ``nw`` and has to divide ``b``.

        Returns:
            Tensor of shape ``(b, n, c)``.
        """
        b, n, c = x.shape
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.clone()[:n, :n].reshape(-1)
        ].reshape(n, n, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nw = mask.shape[0]
            # expose the window axis so that the mask broadcasts over batch and heads
            attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn).to(v.dtype)
        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
[docs]
class PatchMergingV2(nn.Module):
    """
    Patch merging layer based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(self, dim: int, norm_layer: type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3) -> None:
        """
        Args:
            dim: number of feature channels.
            norm_layer: normalization layer.
            spatial_dims: number of spatial dims, 2 or 3.

        Raises:
            ValueError: if `spatial_dims` is not 2 or 3.
        """
        super().__init__()
        if spatial_dims not in (2, 3):
            # previously produced a module without `reduction` / `norm` that only failed once called
            raise ValueError("spatial dimension should be 2 or 3.")
        self.dim = dim
        # every 2x2(x2) neighbourhood is concatenated along the channel axis: 4 * dim or 8 * dim channels
        merged_dim = (2**spatial_dims) * dim
        self.reduction = nn.Linear(merged_dim, 2 * dim, bias=False)
        self.norm = norm_layer(merged_dim)

    def forward(self, x):
        """
        Args:
            x: channel-last tensor, ``(b, d, h, w, c)`` or ``(b, h, w, c)``. Odd spatial sizes are
                zero-padded at the end by one element.

        Returns:
            Tensor with every spatial size halved (after padding) and ``2 * dim`` channels.

        Raises:
            ValueError: if `x` is neither 4D nor 5D.
        """
        x_shape = x.size()
        if len(x_shape) == 5:
            _, d, h, w, _ = x_shape
            if (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1):
                # F.pad takes its pairs starting from the last axis: channels, then w, h, d
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2, 0, d % 2))
            x = torch.cat(
                [x[:, i::2, j::2, k::2, :] for i, j, k in itertools.product(range(2), range(2), range(2))], -1
            )
        elif len(x_shape) == 4:
            _, h, w, _ = x_shape
            if (h % 2 == 1) or (w % 2 == 1):
                x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2))
            # the (j, i) order is kept as is: it fixes the channel layout seen by `reduction`
            x = torch.cat([x[:, j::2, i::2, :] for i, j in itertools.product(range(2), range(2))], -1)
        else:
            # previously reached `self.norm` with an unmerged tensor
            raise ValueError(f"expecting 4D or 5D x, got {x.shape}.")

        x = self.norm(x)
        x = self.reduction(x)
        return x
[docs]
class PatchMerging(PatchMergingV2):
    """The `PatchMerging` module previously defined in v0.9.0."""

    def forward(self, x):
        """
        Merge 2x2x2 neighbourhoods of a 5D channel-last tensor; 4D inputs are delegated to the parent class.

        Args:
            x: tensor of shape ``(b, d, h, w, c)`` or ``(b, h, w, c)``.

        Raises:
            ValueError: if `x` is neither 4D nor 5D.
        """
        rank = len(x.size())
        if rank == 4:
            return super().forward(x)
        if rank != 5:
            raise ValueError(f"expecting 5D x, got {x.shape}.")

        _, depth, height, width, _ = x.size()
        if (height % 2 == 1) or (width % 2 == 1) or (depth % 2 == 1):
            x = F.pad(x, (0, 0, 0, width % 2, 0, height % 2, 0, depth % 2))

        # (d, h, w) start offsets of the eight strided slices, in concatenation order.
        # NOTE(review): (0, 1, 0) and (0, 0, 1) each occur twice while (1, 1, 0) and (0, 1, 1) never
        # occur. This is reproduced on purpose: it is the v0.9.0 behaviour this class exists to keep,
        # and changing it would alter the output of weights trained with it -- confirm before touching.
        offsets = (
            (0, 0, 0),
            (1, 0, 0),
            (0, 1, 0),
            (0, 0, 1),
            (1, 0, 1),
            (0, 1, 0),
            (0, 0, 1),
            (1, 1, 1),
        )
        merged = torch.cat([x[:, i::2, j::2, k::2, :] for i, j, k in offsets], -1)
        return self.reduction(self.norm(merged))
# Maps the string values accepted for the `downsample` argument to the patch-merging classes of this
# file; `SwinUNETR.__init__` resolves string inputs against this table through `look_up_option`.
MERGING_MODE = {"merging": PatchMerging, "mergingv2": PatchMergingV2}
def compute_mask(dims, window_size, shift_size, device):
    """Computing region masks based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer

    Args:
        dims: dimension values, the 2 or 3 spatial sizes of the map to be partitioned.
        window_size: local window size.
        shift_size: shift size.
        device: device.

    Returns:
        Tensor of shape ``(num_windows, n, n)`` with ``n = prod(window_size)``, holding 0.0 where two
        positions of a window carry the same region label and -100.0 where they do not.

    Raises:
        ValueError: if `dims` has neither 2 nor 3 entries.
    """
    n_spatial = len(dims)
    if n_spatial not in (2, 3):
        # previously surfaced as an UnboundLocalError on `img_mask`
        raise ValueError(f"expecting dims of length 2 or 3, got {n_spatial}.")

    img_mask = torch.zeros((1, *dims, 1), device=device)

    # Along every axis the map is cut into three bands; each combination of bands is one region.
    axis_bands = [
        (
            slice(-window_size[axis]),
            slice(-window_size[axis], -shift_size[axis]),
            slice(-shift_size[axis], None),
        )
        for axis in range(n_spatial)
    ]
    # `itertools.product` varies the last axis fastest, matching the former nested loops, so the
    # region labels -- and the result of overlapping assignments -- are unchanged.
    for region_id, bands in enumerate(itertools.product(*axis_bands)):
        img_mask[(slice(None), *bands, slice(None))] = region_id

    mask_windows = window_partition(img_mask, window_size)
    mask_windows = mask_windows.squeeze(-1)
    # pairwise label difference inside each window: non-zero means "different region"
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    return attn_mask
[docs]
class BasicLayer(nn.Module):
    """
    Basic Swin Transformer layer in one stage based on: "Liu et al.,
    Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>"
    https://github.com/microsoft/Swin-Transformer
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int,
        window_size: Sequence[int],
        drop_path: list,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        norm_layer: type[LayerNorm] = nn.LayerNorm,
        downsample: nn.Module | None = None,
        use_checkpoint: bool = False,
    ) -> None:
        """
        Args:
            dim: number of feature channels.
            depth: number of layers in each stage.
            num_heads: number of attention heads.
            window_size: local window size.
            drop_path: stochastic depth rate; a list is indexed per block, any other value is shared.
            mlp_ratio: ratio of mlp hidden dim to embedding dim.
            qkv_bias: add a learnable bias to query, key, value.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            norm_layer: normalization layer.
            downsample: an optional downsampling layer at the end of the layer. A callable is invoked
                with ``dim``, ``norm_layer`` and ``spatial_dims`` to build the layer.
            use_checkpoint: use gradient checkpointing for reduced memory usage.
        """
        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(size // 2 for size in window_size)
        self.no_shift = tuple(0 for _ in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        stage_blocks = []
        for index in range(depth):
            # even-indexed blocks use unshifted windows, odd-indexed blocks use the half-window shift
            block_shift = self.no_shift if (index % 2 == 0) else self.shift_size
            block_drop_path = drop_path[index] if isinstance(drop_path, list) else drop_path
            stage_blocks.append(
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=self.window_size,
                    shift_size=block_shift,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=block_drop_path,
                    norm_layer=norm_layer,
                    use_checkpoint=use_checkpoint,
                )
            )
        self.blocks = nn.ModuleList(stage_blocks)

        self.downsample = downsample
        if callable(self.downsample):
            self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size))

    def forward(self, x):
        """
        Args:
            x: channel-first tensor, ``(b, c, d, h, w)`` or ``(b, c, h, w)``. Tensors of any other
                rank are returned unchanged.

        Returns:
            Channel-first tensor after all blocks and the optional downsampling layer.
        """
        x_shape = x.size()
        rank = len(x_shape)
        if rank not in (4, 5):
            return x

        if rank == 5:
            to_channel_last, to_channel_first = "b c d h w -> b d h w c", "b d h w c -> b c d h w"
        else:
            to_channel_last, to_channel_first = "b c h w -> b h w c", "b h w c -> b c h w"

        batch = x_shape[0]
        spatial = tuple(x_shape[2:])
        window_size, shift_size = get_window_size(spatial, self.window_size, self.shift_size)
        x = rearrange(x, to_channel_last)

        # the mask is computed for the spatial sizes rounded up to whole windows
        padded = [int(np.ceil(extent / win)) * win for extent, win in zip(spatial, window_size)]
        attn_mask = compute_mask(padded, window_size, shift_size, x.device)

        for blk in self.blocks:
            x = blk(x, attn_mask)
        x = x.view(batch, *spatial, -1)
        if self.downsample is not None:
            x = self.downsample(x)
        return rearrange(x, to_channel_first)
def filter_swinunetr(key, value):
    """
    A filter function used to filter the pretrained weights from [1], then the weights can be loaded into
    MONAI SwinUNETR Model. This function is typically used with `monai.networks.copy_model_state`

    [1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image
    Pre-training <https://arxiv.org/abs/2307.16896>"

    Args:
        key: the key in the source state dict used for the update.
        value: the value in the source state dict used for the update.

    Returns:
        ``None`` when the entry is to be skipped (a key on the exclusion list or one that does not start
        with ``"encoder."``), otherwise a ``(new_key, value)`` pair whose key starts with ``"swinViT."``.

    Examples::

        import torch
        from monai.apps import download_url
        from monai.networks.utils import copy_model_state
        from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr

        model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48)
        resource = (
            "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
        )
        ssl_weights_path = "./ssl_pretrained_weights.pth"
        download_url(resource, ssl_weights_path)
        ssl_weights = torch.load(ssl_weights_path)["model"]

        dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)

    """
    skipped_keys = (
        "encoder.mask_token",
        "encoder.norm.weight",
        "encoder.norm.bias",
        "out.conv.conv.weight",
        "out.conv.conv.bias",
    )
    source_prefix = "encoder."

    if key in skipped_keys or key[: len(source_prefix)] != source_prefix:
        return None

    remainder = key[len(source_prefix) :]
    if remainder[:11] == "patch_embed":
        return "swinViT." + remainder, value

    # Keep the first 10 characters after the prefix and drop the 2 that follow them.
    # NOTE(review): presumably this removes an extra "0." index level, e.g.
    # "layers1.0.0.blocks..." -> "layers1.0.blocks..." -- confirm against the checkpoint layout.
    return "swinViT." + remainder[:10] + remainder[12:], value