Source code for model.utils

import numpy as np
import torch
import torch.nn as nn
import pdb

[docs] def get_model(args, pretrain=False): """Function to get the model based on the arguments. Actually the models available, only 3 dimensions, are: - UNETR - UNet - VNet - AttentionUNet - ResUNet - MedFormer - SegFormer - UNetPlusPlus - SwinUNETR - SAM3D - nnFormer: don't work properly TODO - VTUNet: don't work properly TODO - FCN_Net: don't work properly TODO TODO: Implement the other models, for example: - DeepLabV3 - PSPNet Args: args (argparse.Namespace): Arguments from the command line. pretrain (bool, optional): Set to true if you use a pretrained model. Defaults to False. Raises: ValueError: No pretrain model available ValueError: Invalid dimension, should be '2d' or '3d' Returns: Model: The model object. """ if args.dimension == '3d': if args.model == 'unetr': from .dim3 import UNETR return UNETR( args.in_chan, args.classes, args.training_size, feature_size=args.base_chan, hidden_size=768, mlp_dim=3072, num_heads=12, pos_embed='perceptron', norm_name='instance', res_block=True ) elif args.model == "segformer": from .dim3 import SegFormer3D return SegFormer3D( in_channels=args.in_chan, sr_ratios=args.sr_ratios, embed_dims=args.embed_dims, patch_kernel_size=args.patch_kernel_size, patch_stride=args.patch_stride, patch_padding=args.patch_padding, mlp_ratios=args.mlp_ratios, num_heads=args.num_heads, depths=args.depths, decoder_head_embedding_dim=args.decoder_head_embedding_dim, num_classes=args.classes, decoder_dropout=args.decoder_dropout, ) elif args.model == 'sam': from .dim3 import sam_model_registry3D return sam_model_registry3D[args.vit_name](checkpoint=None) raise print("\n[TODO] DONT WORK PROPERLY YET! TODO FIX IT!") return Sam3D( num_classes=args.classes, ckpt=None, image_size=args.crop_size, vit_name=args.vit_name, num_modalities=args.in_chan, do_ds=args.do_ds ) elif args.model == 'dints': from .dim3 import DiNTS, TopologySearch, TopologyInstance # dints_space = TopologySearch( # channel_mul=0.5, # num_blocks=12, # num_depths=4, # use_downsample=True, # device=args.device, # ) dints_space = TopologyInstance( channel_mul=args.channel_mul, num_blocks=args.num_blocks, num_depths=args.num_depths, use_downsample=args.use_downsample, device=args.aug_device, ) return DiNTS( dints_space=dints_space, in_channels=args.in_chan, num_classes=args.classes, act_name=args.act_name, norm_name=("INSTANCE", {"affine": True}), spatial_dims=args.spatial_dims, use_downsample=args.use_downsample, node_a=None, ) elif args.model == 'resunet': from .dim3 import UNet if pretrain: raise ValueError('No pretrain model available') return UNet( args.in_chan, args.base_chan, num_classes=args.classes, scale=args.down_scale, norm=args.norm, kernel_size=args.kernel_size, block=args.block ) elif args.model == 'vnet': from .dim3 import VNet if pretrain: raise ValueError('No pretrain model available') return VNet( args.in_chan, args.classes, scale=args.downsample_scale, baseChans=args.base_chan ) elif args.model == 'attention_unet': from .dim3 import AttentionUNet return AttentionUNet( args.in_chan, args.base_chan, num_classes=args.classes, scale=args.down_scale, norm=args.norm, kernel_size=args.kernel_size, block=args.block ) elif args.model == 'unet': from .dim3 import UNet return UNet( args.in_chan, args.base_chan, num_classes=args.classes, scale=args.down_scale, norm=args.norm, kernel_size=args.kernel_size, block=args.block ) elif args.model == 'medformer': from .dim3 import MedFormer return MedFormer( args.in_chan, args.classes, args.base_chan, map_size=args.map_size, conv_block=args.conv_block, conv_num=args.conv_num, trans_num=args.trans_num, num_heads=args.num_heads, fusion_depth=args.fusion_depth, fusion_dim=args.fusion_dim, fusion_heads=args.fusion_heads, expansion=args.expansion, attn_drop=args.attn_drop, proj_drop=args.proj_drop, proj_type=args.proj_type, norm=args.norm, act=args.act, kernel_size=args.kernel_size, scale=args.down_scale, aux_loss=args.aux_loss ) elif args.model == 'unet++': from .dim3 import UNetPlusPlus return UNetPlusPlus( args.in_chan, args.base_chan, num_classes=args.classes, scale=args.down_scale, norm=args.norm, kernel_size=args.kernel_size, block=args.block ) elif args.model == 'swin_unetr': from .dim3 import SwinUNETR return SwinUNETR( args.window_size, args.in_chan, args.classes, feature_size=args.base_chan, num_heads=args.num_heads ) elif args.model == 'nnformer': from .dim3 import nnFormer return nnFormer( args.window_size, input_channels=args.in_chan, num_classes=args.classes, deep_supervision=args.aux_loss ) elif args.model == 'vtunet': from .dim3 import VTUNet return VTUNet( args, args.classes ) elif args.model == 'fcn_net': from .dim3 import FCN_Net return FCN_Net( args.in_chan, args.classes, ) else: raise ValueError('Invalid dimension, should be \'2d\' or \'3d\'')