Skip to content

Commit 217f6c6

Browse files
authored
Make the flexible unet more extensible to different backbones (Project-MONAI#5038)
Signed-off-by: binliu <binliu@nvidia.com> Fixes Project-MONAI#5011 . ### Description Currently, the FlexibleUNet structure only supports the efficient-net series as the backbone and it is hard to extend to other network structures. A more extensible and convenient way to add more benchmark or user-defined backbones will make it more flexible. I plan to make it by doing steps below 1. A base backbone/encoder class that defines interfaces like the number of output feature maps, a list of output feature map channels, string names of backbones and so on. 2. A register that can dynamically register backbones to a dict that will be used during the flexible unet initialization. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: binliu <binliu@nvidia.com>
1 parent cb41e69 commit 217f6c6

File tree

7 files changed

+465
-75
lines changed

7 files changed

+465
-75
lines changed

monai/networks/blocks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
2020
from .downsample import MaxAvgPool
2121
from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
22+
from .encoder import BaseEncoder
2223
from .fcn import FCN, GCN, MCFCN, Refine
2324
from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
2425
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock

monai/networks/blocks/encoder.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from abc import ABCMeta, abstractmethod
13+
from typing import Dict, List, Tuple
14+
15+
__all__ = ["BaseEncoder"]
16+
17+
18+
class BaseEncoder(metaclass=ABCMeta):
19+
"""
20+
Abstract class defines interface of encoders in flexible unet.
21+
Encoders in flexible unet must derive from this class. Each interface method
22+
should return a list containing relative information about a series of newtworks
23+
defined by encoder. For example, the efficient-net encoder implement 10 basic
24+
network structures in one encoder. When calling `get_encoder_name_string_list`
25+
function, a string list like ["efficientnet-b0", "efficientnet-b1" ... "efficientnet-l2"]
26+
should be returned.
27+
"""
28+
29+
@classmethod
30+
@abstractmethod
31+
def get_encoder_parameters(cls) -> List[Dict]:
32+
"""
33+
Get parameter list to initialize encoder networks.
34+
Each parameter dict must have `spatial_dims`, `in_channels`
35+
and `pretrained` parameters.
36+
The reason that this function should return a list is that a
37+
series of encoders can be implemented by one encoder class
38+
given different initialization parameters. Each parameter dict
39+
in return list should be able to initialize a unique encoder.
40+
"""
41+
raise NotImplementedError
42+
43+
@classmethod
44+
@abstractmethod
45+
def num_channels_per_output(cls) -> List[Tuple[int, ...]]:
46+
"""
47+
Get number of output features' channels.
48+
The reason that this function should return a list is that a
49+
series of encoders can be implemented by one encoder class
50+
given different initialization parameters. And it is possible
51+
that different encoders have different output feature map
52+
channels. Therefore a list of output feature map channel tuples
53+
corresponding to each encoder should be returned by this method.
54+
"""
55+
raise NotImplementedError
56+
57+
@classmethod
58+
@abstractmethod
59+
def num_outputs(cls) -> List[int]:
60+
"""
61+
Get number of outputs of encoder.
62+
The reason that this function should return a list is that a
63+
series of encoders can be implemented by one encoder class
64+
given different initialization parameters. And it is possible
65+
that different encoders have different output feature numbers.
66+
Therefore a list of output feature numbers corresponding to
67+
each encoder should be returned by this method.
68+
"""
69+
raise NotImplementedError
70+
71+
@classmethod
72+
@abstractmethod
73+
def get_encoder_names(cls) -> List[str]:
74+
"""
75+
Get the name string of encoders which will be used to initialize
76+
flexible unet.
77+
The reason that this function should return a list is that a
78+
series of encoders can be implemented by one encoder class
79+
given different initialization parameters. And a name string is
80+
the key to each encoder in flexible unet backbone registry.
81+
Therefore this method should return every encoder name that needs
82+
to be registed in flexible unet.
83+
"""
84+
raise NotImplementedError

monai/networks/nets/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,11 @@
3838
EfficientNet,
3939
EfficientNetBN,
4040
EfficientNetBNFeatures,
41+
EfficientNetEncoder,
4142
drop_connect,
4243
get_efficientnet_image_size,
4344
)
44-
from .flexible_unet import FlexibleUNet
45+
from .flexible_unet import FLEXUNET_BACKBONE, FlexibleUNet, FlexUNet, FlexUNetEncoderRegister
4546
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
4647
from .generator import Generator
4748
from .highresnet import HighResBlock, HighResNet
@@ -50,7 +51,18 @@
5051
from .netadapter import NetAdapter
5152
from .regressor import Regressor
5253
from .regunet import GlobalNet, LocalNet, RegUNet
53-
from .resnet import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
54+
from .resnet import (
55+
ResNet,
56+
ResNetBlock,
57+
ResNetBottleneck,
58+
resnet10,
59+
resnet18,
60+
resnet34,
61+
resnet50,
62+
resnet101,
63+
resnet152,
64+
resnet200,
65+
)
5466
from .segresnet import SegResNet, SegResNetVAE
5567
from .segresnet_ds import SegResNetDS
5668
from .senet import (

monai/networks/nets/efficientnet.py

Lines changed: 87 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
import operator
1414
import re
1515
from functools import reduce
16-
from typing import List, NamedTuple, Optional, Tuple, Type, Union
16+
from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union
1717

1818
import torch
1919
from torch import nn
2020
from torch.utils import model_zoo
2121

22+
from monai.networks.blocks import BaseEncoder
2223
from monai.networks.layers.factories import Act, Conv, Pad, Pool
2324
from monai.networks.layers.utils import get_norm_layer
2425
from monai.utils.module import look_up_option
@@ -30,6 +31,7 @@
3031
"drop_connect",
3132
"EfficientNetBNFeatures",
3233
"BlockArgs",
34+
"EfficientNetEncoder",
3335
]
3436

3537
efficientnet_params = {
@@ -528,11 +530,8 @@ def __init__(
528530

529531
# check if model_name is valid model
530532
if model_name not in efficientnet_params.keys():
531-
raise ValueError(
532-
"invalid model_name {} found, must be one of {} ".format(
533-
model_name, ", ".join(efficientnet_params.keys())
534-
)
535-
)
533+
model_name_string = ", ".join(efficientnet_params.keys())
534+
raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")
536535

537536
# get network parameters
538537
weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name]
@@ -588,11 +587,8 @@ def __init__(
588587

589588
# check if model_name is valid model
590589
if model_name not in efficientnet_params.keys():
591-
raise ValueError(
592-
"invalid model_name {} found, must be one of {} ".format(
593-
model_name, ", ".join(efficientnet_params.keys())
594-
)
595-
)
590+
model_name_string = ", ".join(efficientnet_params.keys())
591+
raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")
596592

597593
# get network parameters
598594
weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name]
@@ -638,6 +634,80 @@ def forward(self, inputs: torch.Tensor):
638634
return features
639635

640636

637+
class EfficientNetEncoder(EfficientNetBNFeatures, BaseEncoder):
638+
"""
639+
Wrap the original efficientnet to an encoder for flexible-unet.
640+
"""
641+
642+
backbone_names = [
643+
"efficientnet-b0",
644+
"efficientnet-b1",
645+
"efficientnet-b2",
646+
"efficientnet-b3",
647+
"efficientnet-b4",
648+
"efficientnet-b5",
649+
"efficientnet-b6",
650+
"efficientnet-b7",
651+
"efficientnet-b8",
652+
"efficientnet-l2",
653+
]
654+
655+
@classmethod
656+
def get_encoder_parameters(cls) -> List[Dict]:
657+
"""
658+
Get the initialization parameter for efficientnet backbones.
659+
"""
660+
parameter_list = []
661+
for backbone_name in cls.backbone_names:
662+
parameter_list.append(
663+
{
664+
"model_name": backbone_name,
665+
"pretrained": True,
666+
"progress": True,
667+
"spatial_dims": 2,
668+
"in_channels": 3,
669+
"num_classes": 1000,
670+
"norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
671+
"adv_prop": "ap" in backbone_name,
672+
}
673+
)
674+
return parameter_list
675+
676+
@classmethod
677+
def num_channels_per_output(cls) -> List[Tuple[int, ...]]:
678+
"""
679+
Get number of efficientnet backbone output feature maps' channel.
680+
"""
681+
return [
682+
(16, 24, 40, 112, 320),
683+
(16, 24, 40, 112, 320),
684+
(16, 24, 48, 120, 352),
685+
(24, 32, 48, 136, 384),
686+
(24, 32, 56, 160, 448),
687+
(24, 40, 64, 176, 512),
688+
(32, 40, 72, 200, 576),
689+
(32, 48, 80, 224, 640),
690+
(32, 56, 88, 248, 704),
691+
(72, 104, 176, 480, 1376),
692+
]
693+
694+
@classmethod
695+
def num_outputs(cls) -> List[int]:
696+
"""
697+
Get number of efficientnet backbone output feature maps.
698+
Since every backbone contains the same 5 output feature maps,
699+
the number list should be `[5] * 10`.
700+
"""
701+
return [5] * 10
702+
703+
@classmethod
704+
def get_encoder_names(cls) -> List[str]:
705+
"""
706+
Get names of efficient backbone.
707+
"""
708+
return cls.backbone_names
709+
710+
641711
def get_efficientnet_image_size(model_name: str) -> int:
642712
"""
643713
Get the input image size for a given efficientnet model.
@@ -651,9 +721,8 @@ def get_efficientnet_image_size(model_name: str) -> int:
651721
"""
652722
# check if model_name is valid model
653723
if model_name not in efficientnet_params.keys():
654-
raise ValueError(
655-
"invalid model_name {} found, must be one of {} ".format(model_name, ", ".join(efficientnet_params.keys()))
656-
)
724+
model_name_string = ", ".join(efficientnet_params.keys())
725+
raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")
657726

658727
# return input image size (all dims equal so only need to return for one dim)
659728
_, _, res, _, _ = efficientnet_params[model_name]
@@ -927,15 +996,10 @@ def to_string(self):
927996
A string notation of BlockArgs object arguments.
928997
Example: "r1_k3_s11_e1_i32_o16_se0.25_noskip".
929998
"""
930-
string = "r{}_k{}_s{}{}_e{}_i{}_o{}_se{}".format(
931-
self.num_repeat,
932-
self.kernel_size,
933-
self.stride,
934-
self.stride,
935-
self.expand_ratio,
936-
self.input_filters,
937-
self.output_filters,
938-
self.se_ratio,
999+
string = (
1000+
f"r{self.num_repeat}_k{self.kernel_size}_s{self.stride}{self.stride}"
1001+
f"_e{self.expand_ratio}_i{self.input_filters}_o{self.output_filters}"
1002+
f"_se{self.se_ratio}"
9391003
)
9401004

9411005
if not self.id_skip:

0 commit comments

Comments
 (0)