Skip to content

Commit 8383f8b

Browse files
authored
A small update to to ensure ds_loss returns a const (Project-MONAI#5350)
A small fix to followup on Project-MONAI#5338 to ensure ds_loss returns a constant, not an array ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: myron <amyronenko@nvidia.com>
1 parent 2a46e7d commit 8383f8b

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

monai/losses/ds_loss.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,23 @@ def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: Optional[List
4242
self.weights = weights
4343
self.interp_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest"
4444

45-
def get_weight(self, level: int = 0) -> float:
45+
def get_weights(self, levels: int = 1) -> List[float]:
4646
"""
47-
Calculates a weight constant for a given image scale level
47+
Calculates weights for a given number of scale levels
4848
"""
49-
weight = 1.0
50-
if self.weights is not None and len(self.weights) > level:
51-
weight = self.weights[level]
49+
levels = max(1, levels)
50+
if self.weights is not None and len(self.weights) >= levels:
51+
weights = self.weights[:levels]
5252
elif self.weight_mode == "same":
53-
weight = 1.0
53+
weights = [1.0] * levels
5454
elif self.weight_mode == "exp":
55-
weight = max(0.5**level, 0.0625)
55+
weights = [max(0.5**l, 0.0625) for l in range(levels)]
5656
elif self.weight_mode == "two":
57-
weight = 1.0 if level == 0 else 0.5
57+
weights = [1.0 if l == 0 else 0.5 for l in range(levels)]
58+
else:
59+
weights = [1.0] * levels
5860

59-
return weight
61+
return weights
6062

6163
def get_loss(self, input: torch.Tensor, target: torch.Tensor):
6264
"""
@@ -71,10 +73,12 @@ def get_loss(self, input: torch.Tensor, target: torch.Tensor):
7173
def forward(self, input: Union[torch.Tensor, List[torch.Tensor]], target: torch.Tensor):
7274

7375
if isinstance(input, (list, tuple)):
74-
loss = torch.zeros(1, dtype=torch.float, device=target.device)
76+
weights = self.get_weights(levels=len(input))
77+
loss = torch.tensor(0, dtype=torch.float, device=target.device)
7578
for l in range(len(input)):
76-
loss += self.get_loss(input[l].float(), target) * self.get_weight(l)
79+
loss += weights[l] * self.get_loss(input[l].float(), target)
7780
return loss
81+
7882
return self.loss(input.float(), target)
7983

8084

0 commit comments

Comments
 (0)