Source code for model.dim3.utils

import torch
import torch.nn as nn
import torch.nn.functional as F
from .conv_layers import BasicBlock, Bottleneck, SingleConv
from .trans_layers import LayerNorm

[docs] def get_block(name): block_map = { 'SingleConv': SingleConv, 'BasicBlock': BasicBlock, 'Bottleneck': Bottleneck, } return block_map[name]
[docs] def get_norm(name): norm_map = {'bn': nn.BatchNorm3d, 'in': nn.InstanceNorm3d, 'ln': LayerNorm } return norm_map[name]
[docs] def get_act(name): act_map = { 'relu': nn.ReLU, 'lrelu': nn.LeakyReLU, 'gelu': nn.GELU, 'swish': nn.SiLU } return act_map[name]