Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,11 @@ Nets
.. autoclass:: ResNet
:members:

`ResNetFeatures`
~~~~~~~~~~~~~~~~
.. autoclass:: ResNetFeatures
:members:

`SENet`
~~~~~~~
.. autoclass:: SENet
Expand Down
2 changes: 2 additions & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
ResNet,
ResNetBlock,
ResNetBottleneck,
ResNetEncoder,
ResNetFeatures,
get_medicalnet_pretrained_resnet_args,
get_pretrained_resnet_medicalnet,
resnet10,
Expand Down
8 changes: 5 additions & 3 deletions monai/networks/nets/flexible_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from monai.networks.layers.utils import get_act_layer
from monai.networks.nets import EfficientNetEncoder
from monai.networks.nets.basic_unet import UpCat
from monai.networks.nets.resnet import ResNetEncoder
from monai.utils import InterpolateMode, optional_import

__all__ = ["FlexibleUNet", "FlexUNet", "FLEXUNET_BACKBONE", "FlexUNetEncoderRegister"]
Expand Down Expand Up @@ -78,6 +79,7 @@ def register_class(self, name: type[Any] | str):

# Create the backbone registry and register the EfficientNet and ResNet encoder classes with it.
FLEXUNET_BACKBONE = FlexUNetEncoderRegister()
FLEXUNET_BACKBONE.register_class(EfficientNetEncoder)
FLEXUNET_BACKBONE.register_class(ResNetEncoder)


class UNetDecoder(nn.Module):
Expand Down Expand Up @@ -238,7 +240,7 @@ def __init__(
) -> None:
"""
A flexible implement of UNet, in which the backbone/encoder can be replaced with
any efficient network. Currently the input must have a 2 or 3 spatial dimension
any efficient or residual network. Currently the input must have a 2 or 3 spatial dimension
and the spatial size of each dimension must be a multiple of 32 if is_pad parameter
is False.
Please notice each output of backbone must be 2x downsample in spatial dimension
Expand All @@ -248,8 +250,8 @@ def __init__(
Args:
in_channels: number of input channels.
out_channels: number of output channels.
backbone: name of backbones to initialize, only support efficientnet right now,
can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
backbone: name of backbones to initialize, only support efficientnet and resnet right now,
can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200].
pretrained: whether to initialize pretrained ImageNet weights, only available
for spatial_dims=2 and batch norm is used, default to False.
decoder_channels: number of output channels for all feature maps in decoder.
Expand Down
141 changes: 141 additions & 0 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
import torch.nn as nn

from monai.networks.blocks.encoder import BaseEncoder
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_pool_layer
from monai.utils import ensure_tuple_rep
Expand All @@ -45,6 +46,19 @@
"resnet200",
]


# Per-architecture configuration consumed by ``ResNetFeatures.__init__``.
# Each value is unpacked positionally there, so the field order is significant:
#   block           - "basic" or "bottleneck", forwarded to the ``ResNet`` constructor as ``block``
#   layers          - list of four integers, forwarded as ``layers``
#   shortcut_type   - "A" or "B", forwarded as ``shortcut_type``
#   bias_downsample - boolean, forwarded as ``bias_downsample``
#   datasets23      - boolean, forwarded to ``_load_state_dict`` when pretrained weights are requested
resnet_params = {
    # model_name: (block, layers, shortcut_type, bias_downsample, datasets23)
    "resnet10": ("basic", [1, 1, 1, 1], "B", False, True),
    "resnet18": ("basic", [2, 2, 2, 2], "A", True, True),
    "resnet34": ("basic", [3, 4, 6, 3], "A", True, True),
    "resnet50": ("bottleneck", [3, 4, 6, 3], "B", False, True),
    "resnet101": ("bottleneck", [3, 4, 23, 3], "B", False, False),
    "resnet152": ("bottleneck", [3, 8, 36, 3], "B", False, False),
    "resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False),
}


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -335,6 +349,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class ResNetFeatures(ResNet):
    """ResNet backbone that returns its intermediate feature maps.

    The network is constructed through ``ResNet.__init__`` with ``feed_forward=False``;
    compared with the class `ResNet`, the only difference is the forward function, which
    returns a list of feature maps instead of a single tensor. The backbone can be used
    as an encoder for segmentation and object detection models.
    """

    def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None:
        """Initialize resnet10 to resnet200 models as a backbone.

        Args:
            model_name: name of model to initialize, can be from [resnet10, ..., resnet200].
            pretrained: whether to initialize pretrained Med3D weights,
                only available for spatial_dims=3 and in_channels=1.
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels for first convolutional layer.

        Raises:
            ValueError: when ``model_name`` is not a key of ``resnet_params``.
            ValueError: when ``pretrained`` is True but ``spatial_dims != 3`` or ``in_channels != 1``.
        """
        if model_name not in resnet_params:
            model_name_string = ", ".join(resnet_params.keys())
            raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")

        # Fail fast: reject an unsupported pretrained configuration before the network is built,
        # rather than constructing every layer and only then raising.
        if pretrained and not (spatial_dims == 3 and in_channels == 1):
            raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.")

        block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name]

        super().__init__(
            block=block,
            layers=layers,
            block_inplanes=get_inplanes(),
            spatial_dims=spatial_dims,
            n_input_channels=in_channels,
            conv1_t_stride=2,
            shortcut_type=shortcut_type,
            feed_forward=False,
            bias_downsample=bias_downsample,
        )

        if pretrained:
            _load_state_dict(self, model_name, datasets23=datasets23)

    def forward(self, inputs: torch.Tensor) -> list[torch.Tensor]:
        """Run the backbone and collect the intermediate feature maps.

        Args:
            inputs: input should have spatially N dimensions
                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `spatial_dims`.

        Returns:
            a list of five torch Tensors: the output of the stem (``conv1`` -> ``bn1`` -> ``relu``)
            followed by the outputs of ``layer1``, ``layer2``, ``layer3`` and ``layer4``, in that order.
        """
        x = self.relu(self.bn1(self.conv1(inputs)))

        # The stem output is recorded before the optional max-pooling step.
        features = [x]

        if not self.no_max_pool:
            x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            features.append(x)

        return features


class ResNetEncoder(ResNetFeatures, BaseEncoder):
    """Wrap the original resnet to an encoder for flexible-unet.

    All classmethods return lists that are index-aligned with ``backbone_names``:
    entry ``i`` of each list describes the backbone ``backbone_names[i]``.
    """

    backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]

    @classmethod
    def get_encoder_parameters(cls) -> list[dict]:
        """Get the initialization parameter for resnet backbones.

        Returns:
            one keyword-argument dict per entry of ``backbone_names``, with the keys
            ``model_name``, ``pretrained``, ``spatial_dims`` and ``in_channels``.
        """
        return [
            {"model_name": backbone_name, "pretrained": True, "spatial_dims": 3, "in_channels": 1}
            for backbone_name in cls.backbone_names
        ]

    @classmethod
    def num_channels_per_output(cls) -> list[tuple[int, ...]]:
        """Get number of resnet backbone output feature maps channel.

        Returns:
            one tuple of five channel counts per entry of ``backbone_names``.
        """
        return [
            (64, 64, 128, 256, 512),  # resnet10
            (64, 64, 128, 256, 512),  # resnet18
            (64, 64, 128, 256, 512),  # resnet34
            (64, 256, 512, 1024, 2048),  # resnet50
            (64, 256, 512, 1024, 2048),  # resnet101
            (64, 256, 512, 1024, 2048),  # resnet152
            (64, 256, 512, 1024, 2048),  # resnet200
        ]

    @classmethod
    def num_outputs(cls) -> list[int]:
        """Get number of resnet backbone output feature maps.

        Every backbone contains the same 5 output feature maps, so the list holds
        one ``5`` per entry of ``backbone_names``.
        """
        # Derived from backbone_names instead of a literal count so the two cannot drift apart.
        return [5] * len(cls.backbone_names)

    @classmethod
    def get_encoder_names(cls) -> list[str]:
        """Get names of resnet backbones."""
        # Return a copy so callers cannot mutate the class-level list by accident.
        return list(cls.backbone_names)


def _resnet(
arch: str,
block: type[ResNetBlock | ResNetBottleneck],
Expand Down Expand Up @@ -541,3 +669,16 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
return bias_downsample, shortcut_type


def _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None:
    """Fetch pretrained MedicalNet weights and load them into ``model`` in place.

    Args:
        model: network that receives the weights through ``model.load_state_dict``.
        model_name: name containing the resnet depth, e.g. ``resnet18`` or ``resnet18_23datasets``.
        datasets23: forwarded to ``get_pretrained_resnet_medicalnet``. The flag is honored as passed;
            a ``_23datasets`` suffix on ``model_name`` also enables it. (Previously the argument was
            overwritten by the suffix test and therefore had no effect.)

    Raises:
        ValueError: when ``model_name`` does not contain ``resnet<depth>``.
    """
    search_res = re.search(r"resnet(\d+)", model_name)
    if search_res is None:
        raise ValueError("model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.")

    resnet_depth = int(search_res.group(1))
    # Keep the caller's choice; the name suffix remains an alternative way to opt in.
    datasets23 = datasets23 or model_name.endswith("_23datasets")

    model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device="cpu", datasets23=datasets23)
    # Drop every "module." fragment from the checkpoint keys before loading them into the model.
    model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
    model.load_state_dict(model_state_dict)
Loading