Skip to content

Commit 982131c

Browse files
authored
Deep supervision loss wrapper class (Project-MONAI#5338)
This adds a DeepSupervisionLoss wrapper class around the main loss function to accept a list of tensors returned from a deeply supervised network. The final loss is computed as the sum of weighted losses for each of the deep supervision levels (accounting for potential differences in shapes between targets and ds outputs). The wrapper class is designed to work with any existing loss, e.g. ``` loss = DiceCELoss(to_onehot_y=True, softmax=True) ds_loss = DeepSupervisionLoss(loss) ``` Whereas the existing loss accepts the input as a single Tensor, ds_loss accepts the input as a list of Tensors (one for each output of a deeply supervised network). If only a single Tensor input is provided, ds_loss behaves exactly the same as the underlying loss. I added unit tests too. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: myron <amyronenko@nvidia.com>
1 parent e37de69 commit 982131c

File tree

3 files changed

+270
-0
lines changed

3 files changed

+270
-0
lines changed

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
generalized_dice_focal,
2727
generalized_wasserstein_dice,
2828
)
29+
from .ds_loss import DeepSupervisionLoss
2930
from .focal_loss import FocalLoss
3031
from .giou_loss import BoxGIoULoss, giou
3132
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss

monai/losses/ds_loss.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import List, Optional, Union
13+
14+
import torch
15+
import torch.nn.functional as F
16+
from torch.nn.modules.loss import _Loss
17+
18+
from monai.utils import pytorch_after
19+
20+
21+
class DeepSupervisionLoss(_Loss):
    """
    Wrapper class around the main loss function to accept a list of tensors returned from a deeply
    supervised network. The final loss is computed as the sum of weighted losses for each of the deep
    supervision levels.

    When a single tensor is passed instead of a list/tuple, the wrapped loss is applied to it directly
    and no level weighting takes place.
    """

    def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: Optional[List[float]] = None) -> None:
        """
        Args:
            loss: main loss instance, e.g DiceLoss().
            weight_mode: {``"same"``, ``"exp"``, ``"two"``}
                Specifies the weights calculation for each image level. Defaults to ``"exp"``.

                - ``"same"``: all weights are equal to 1.
                - ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc.,
                  never going below 0.0625.
                - ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc

                Any other value yields a weight of 1 for every level.
            weights: a list of weights to apply to each deeply supervised sub-loss. If provided, it takes
                precedence over ``weight_mode`` for every level it covers; levels beyond the end of the list
                fall back to ``weight_mode``.
        """
        super().__init__()
        self.loss = loss
        self.weight_mode = weight_mode
        self.weights = weights
        # "nearest-exact" is only selected when pytorch_after(1, 11) holds; otherwise use plain "nearest".
        self.interp_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest"

    def get_weight(self, level: int = 0) -> float:
        """
        Calculates a weight constant for a given image scale level.

        Args:
            level: index of the deep supervision output (0 is the first element of the input list).

        Returns:
            The explicit entry of ``weights`` when the list covers ``level``, otherwise the value
            derived from ``weight_mode`` (1.0 for unrecognised modes).
        """
        if self.weights is not None and len(self.weights) > level:
            return self.weights[level]
        if self.weight_mode == "exp":
            # halve the weight at every level, but keep it at 0.0625 or above
            return max(0.5**level, 0.0625)
        if self.weight_mode == "two":
            return 1.0 if level == 0 else 0.5
        # "same" and any unrecognised mode
        return 1.0

    def get_loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Calculates a loss output accounting for differences in shapes,
        and resizing the target if necessary (using nearest neighbor interpolation).
        Generally downsizing occurs for all levels, except for the first (level==0).

        Args:
            input: one output of the network; dimensions from index 2 onwards are compared to the target's.
            target: the reference tensor, resized to ``input.shape[2:]`` when those dimensions differ.

        Returns:
            The value returned by the wrapped loss for this pair.
        """
        if input.shape[2:] != target.shape[2:]:
            target = F.interpolate(target, size=input.shape[2:], mode=self.interp_mode)
        return self.loss(input, target)

    def forward(self, input: Union[torch.Tensor, List[torch.Tensor]], target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: either a single tensor, or a list/tuple of tensors (one per deep supervision level).
            target: the reference tensor shared by all levels.

        Returns:
            For a list/tuple input, a tensor of shape ``(1,)`` on ``target.device`` holding the weighted
            sum of the per-level losses. For a single tensor input, the output of the wrapped loss.
        """
        if isinstance(input, (list, tuple)):
            # accumulator has shape (1,), so the per-level losses are expected to broadcast into it
            total = torch.zeros(1, dtype=torch.float, device=target.device)
            for level, level_input in enumerate(input):
                total += self.get_loss(level_input.float(), target) * self.get_weight(level)
            return total
        return self.loss(input.float(), target)


ds_loss = DeepSupervisionLoss

tests/test_ds_loss.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
import torch
16+
from parameterized import parameterized
17+
18+
from monai.losses import DeepSupervisionLoss, DiceCELoss, DiceFocalLoss, DiceLoss
19+
from tests.utils import SkipIfBeforePyTorchVersion, test_script_save
20+
21+
# Every case below uses the same prediction/target literals, so they are produced by small
# factories. Each call returns fresh tensors, so no tensor object is shared between cases.


def _full_res_input():
    """Return a new prediction tensor of shape (1, 2, 2, 3)."""
    return torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]])


def _target():
    """Return a new target tensor of shape (1, 1, 2, 3)."""
    return torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]])


def _ds_inputs():
    """Return a new list of three predictions with shapes (1, 2, 2, 3), (1, 2, 2, 2) and (1, 2, 2, 1)."""
    return [
        _full_res_input(),
        torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),
        torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),
    ]


# Case layout: [wrapped-loss kwargs, DeepSupervisionLoss kwargs, forward kwargs, expected value]
TEST_CASES_DICECE = [
    [
        {"to_onehot_y": True},
        {},
        {"input": _full_res_input(), "target": _target()},
        0.606557,
    ]
]

# Case layout: [wrapped-loss kwargs, DeepSupervisionLoss kwargs, forward kwargs, expected value]
TEST_CASES_DICECE2 = [
    [
        {"to_onehot_y": True},
        {},
        {"input": _ds_inputs(), "target": _target()},
        1.78144,
    ],
    [
        {"to_onehot_y": True},
        {"weight_mode": "same"},
        {"input": _ds_inputs(), "target": _target()},
        3.5529,
    ],
    [
        {"to_onehot_y": True},
        {"weight_mode": "two"},
        {"input": _ds_inputs(), "target": _target()},
        2.07973,
    ],
    [
        {"to_onehot_y": True},
        {"weights": [0.1, 0.2, 0.3]},
        {"input": _ds_inputs(), "target": _target()},
        0.76924,
    ],
]


# Case layout: [wrapped-loss kwargs, forward kwargs, expected value]
TEST_CASES_DICE = [
    [
        {"to_onehot_y": True},
        {"input": _full_res_input(), "target": _target()},
        # mean of the two per-channel Dice losses: channel 0 overlaps in 2 of 3 voxels (loss ~1/3),
        # channel 1 matches the target exactly (loss ~0)
        0.166666,
    ],
    [
        {"to_onehot_y": True},
        {"input": _ds_inputs(), "target": _target()},
        0.666665,
    ],
]

# Case layout: [wrapped-loss kwargs, forward kwargs, expected value]
TEST_CASES_DICEFOCAL = [
    [
        {"to_onehot_y": True},
        {"input": _full_res_input(), "target": _target()},
        0.32124,  # reference value for the single full-resolution input
    ],
    [
        {"to_onehot_y": True},
        {"input": _ds_inputs(), "target": _target()},
        1.06452,
    ],
]
134+
135+
136+
class TestDSLossDiceCE(unittest.TestCase):
    """Tests for DeepSupervisionLoss wrapping DiceCELoss, using single-tensor inputs."""

    @parameterized.expand(TEST_CASES_DICECE)
    def test_result(self, input_param, input_param2, input_data, expected_val):
        # input_param configures DiceCELoss, input_param2 configures the wrapper
        wrapped = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2)
        actual = wrapped(**input_data).detach().cpu().numpy()
        np.testing.assert_allclose(actual, expected_val, atol=1e-4, rtol=1e-4)

    def test_ill_shape(self):
        # a 3-D prediction against a 4-D target must raise ValueError (any message)
        wrapped = DeepSupervisionLoss(DiceCELoss())
        prediction = torch.ones((1, 2, 3))
        reference = torch.ones((1, 1, 2, 3))
        with self.assertRaisesRegex(ValueError, ""):
            wrapped(prediction, reference)

    def test_ill_reduction(self):
        # construction stays inside the context so a ValueError from either step is accepted
        with self.assertRaisesRegex(ValueError, ""):
            wrapped = DeepSupervisionLoss(DiceCELoss(reduction="none"))
            wrapped(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))

    @SkipIfBeforePyTorchVersion((1, 10))
    def test_script(self):
        # the same tensor is passed as both arguments to test_script_save
        example = torch.ones(2, 1, 8, 8)
        wrapped = DeepSupervisionLoss(DiceCELoss())
        test_script_save(wrapped, example, example)
158+
159+
160+
# NOTE(review): skipped before PyTorch 1.11, presumably because the reference values rely on the
# "nearest-exact" interpolation mode that DeepSupervisionLoss selects from 1.11 onwards - confirm.
@SkipIfBeforePyTorchVersion((1, 11))
class TestDSLossDiceCE2(unittest.TestCase):
    """Tests for DeepSupervisionLoss wrapping DiceCELoss, using lists of multi-resolution inputs."""

    @parameterized.expand(TEST_CASES_DICECE2)
    def test_result(self, input_param, input_param2, input_data, expected_val):
        # input_param configures DiceCELoss, input_param2 configures the wrapper
        wrapped = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2)
        actual = wrapped(**input_data).detach().cpu().numpy()
        np.testing.assert_allclose(actual, expected_val, atol=1e-4, rtol=1e-4)
167+
168+
169+
# NOTE(review): skipped before PyTorch 1.11, presumably because the reference values rely on the
# "nearest-exact" interpolation mode that DeepSupervisionLoss selects from 1.11 onwards - confirm.
@SkipIfBeforePyTorchVersion((1, 11))
class TestDSLossDice(unittest.TestCase):
    """Tests for DeepSupervisionLoss wrapping DiceLoss with the wrapper's default settings."""

    @parameterized.expand(TEST_CASES_DICE)
    def test_result(self, input_param, input_data, expected_val):
        # input_param configures DiceLoss; the wrapper keeps its defaults
        wrapped = DeepSupervisionLoss(DiceLoss(**input_param))
        actual = wrapped(**input_data).detach().cpu().numpy()
        np.testing.assert_allclose(actual, expected_val, atol=1e-4, rtol=1e-4)
176+
177+
178+
# NOTE(review): skipped before PyTorch 1.11, presumably because the reference values rely on the
# "nearest-exact" interpolation mode that DeepSupervisionLoss selects from 1.11 onwards - confirm.
@SkipIfBeforePyTorchVersion((1, 11))
class TestDSLossDiceFocal(unittest.TestCase):
    """Tests for DeepSupervisionLoss wrapping DiceFocalLoss with the wrapper's default settings."""

    @parameterized.expand(TEST_CASES_DICEFOCAL)
    def test_result(self, input_param, input_data, expected_val):
        # input_param configures DiceFocalLoss; the wrapper keeps its defaults
        wrapped = DeepSupervisionLoss(DiceFocalLoss(**input_param))
        actual = wrapped(**input_data).detach().cpu().numpy()
        np.testing.assert_allclose(actual, expected_val, atol=1e-4, rtol=1e-4)
185+
186+
187+
# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)