Skip to content

Commit 725c8de

Browse files
Torchvision pretrain fix (#8563)
Fixes #8552. ### Description This updates how pretrained model weights are loaded through Torchvision. This may not preserve historical results if the weights being loaded are now different since the "DEFAULT" weights may not be the weights loaded when using the `pretrained=True` argument. I tried to preserved behaviour as indicated in the Torchvision source code where possible. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b5bc69d commit 725c8de

File tree

7 files changed

+23
-22
lines changed

7 files changed

+23
-22
lines changed

monai/networks/blocks/fcn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def __init__(
123123
self.upsample_mode = upsample_mode
124124
self.conv2d_type = conv2d_type
125125
self.out_channels = out_channels
126-
resnet = models.resnet50(pretrained=pretrained, progress=progress)
126+
resnet = models.resnet50(
127+
progress=progress, weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
128+
)
127129

128130
self.conv1 = resnet.conv1
129131
self.bn0 = resnet.bn1

monai/networks/nets/milmodel.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
import torch.nn as nn
1818

19-
from monai.utils.module import optional_import
19+
from monai.utils import optional_import
2020

2121
models, _ = optional_import("torchvision.models")
2222

@@ -48,7 +48,6 @@ class MILModel(nn.Module):
4848
Defaults to ``None`` (necessary only when using a custom backbone)
4949
trans_blocks: number of the blocks in `TransformEncoder` layer.
5050
trans_dropout: dropout rate in `TransformEncoder` layer.
51-
5251
"""
5352

5453
def __init__(
@@ -74,7 +73,7 @@ def __init__(
7473
self.transformer: nn.Module | None = None
7574

7675
if backbone is None:
77-
net = models.resnet50(pretrained=pretrained)
76+
net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None)
7877
nfc = net.fc.in_features # save the number of final features
7978
net.fc = torch.nn.Identity() # remove final linear layer
8079

@@ -99,7 +98,7 @@ def hook(module, input, output):
9998
torch_model = getattr(models, backbone, None)
10099
if torch_model is None:
101100
raise ValueError("Unknown torch vision model" + str(backbone))
102-
net = torch_model(pretrained=pretrained)
101+
net = torch_model(weights="DEFAULT" if pretrained else None)
103102

104103
if getattr(net, "fc", None) is not None:
105104
nfc = net.fc.in_features # save the number of final features

monai/networks/nets/torchvision_fc.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,11 @@ def __init__(
112112
weights=None,
113113
**kwargs,
114114
):
115-
if weights is not None:
116-
model = getattr(models, model_name)(weights=weights, **kwargs)
117-
elif pretrained:
118-
model = getattr(models, model_name)(weights="DEFAULT", **kwargs)
119-
else:
120-
model = getattr(models, model_name)(weights=None, **kwargs)
115+
# if pretrained is False, weights is a weight tensor or None for no pretrained loading
116+
if pretrained and weights is None:
117+
weights = "DEFAULT"
118+
119+
model = getattr(models, model_name)(weights=weights, **kwargs)
121120

122121
super().__init__(
123122
model=model,

monai/transforms/io/array.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@
4343
from monai.data.utils import is_no_channel
4444
from monai.transforms.transform import Transform
4545
from monai.transforms.utility.array import EnsureChannelFirst
46-
from monai.utils import GridSamplePadMode
47-
from monai.utils import ImageMetaKey as Key
4846
from monai.utils import (
47+
GridSamplePadMode,
48+
ImageMetaKey,
4949
MetaKeys,
5050
OptionalImportError,
5151
convert_to_dst_type,
@@ -293,7 +293,8 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
293293
# make sure all elements in metadata are little endian
294294
meta_data = switch_endianness(meta_data, "<")
295295

296-
meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
296+
# Path obj should be strings for data loader
297+
meta_data[ImageMetaKey.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}"
297298
img = MetaTensor.ensure_torch_and_prune_meta(
298299
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
299300
)
@@ -548,7 +549,7 @@ def __call__(self, img: NdarrayOrTensor):
548549
"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
549550
)
550551

551-
input_path = meta_data[Key.FILENAME_OR_OBJ]
552+
input_path = meta_data[ImageMetaKey.FILENAME_OR_OBJ]
552553
output_path = meta_data[MetaKeys.SAVED_TO]
553554
log_data = {"input": input_path, "output": output_path}
554555

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pep8-naming
1818
pycodestyle
1919
pyflakes
2020
black>=25.1.0
21-
isort>=5.1, <6.0
21+
isort>=5.1, !=6.0.0
2222
ruff
2323
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
2424
types-setuptools

tests/networks/nets/test_densenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_pretrain_consistency(self, model, input_param, input_shape):
9696
net = model(**input_param).to(device)
9797
with eval_mode(net):
9898
result = net.features.forward(example)
99-
torchvision_net = torchvision.models.densenet121(pretrained=True).to(device)
99+
torchvision_net = torchvision.models.densenet121(weights="DEFAULT").to(device)
100100
with eval_mode(torchvision_net):
101101
expected_result = torchvision_net.features.forward(example)
102102
self.assertTrue(torch.all(result == expected_result))

tests/networks/nets/test_milmodel.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@
4444
TEST_CASE_MILMODEL.append(test_case)
4545

4646
# torchvision backbone
47-
TEST_CASE_MILMODEL.append(
48-
[{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)]
49-
)
50-
TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)])
47+
for pretrained in [True, False]:
48+
TEST_CASE_MILMODEL.append(
49+
[{"num_classes": 5, "backbone": "resnet18", "pretrained": pretrained}, (2, 2, 3, 512, 512), (2, 5)]
50+
)
5151

5252
# custom backbone
53-
backbone = models.densenet121(pretrained=False)
53+
backbone = models.densenet121()
5454
backbone_nfeatures = backbone.classifier.in_features
5555
backbone.classifier = torch.nn.Identity()
5656
TEST_CASE_MILMODEL.append(

0 commit comments

Comments
 (0)