Skip to content

Commit da93c1b

Browse files
authored
Dicece check for float target (Project-MONAI#5326)
Fixes Project-MONAI/tutorials#987 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: myron <amyronenko@nvidia.com>
1 parent e081510 commit da93c1b

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

.github/workflows/pythonapp-gpu.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ concurrency:
1616

1717
jobs:
1818
GPU-quick-py3: # GPU with full dependencies
19-
if: github.repository == 'Project-MONAI/MONAI'
19+
if: ${{ github.repository == 'Project-MONAI/MONAI' && github.event.pull_request.merged != true }}
2020
strategy:
2121
matrix:
2222
environment:
@@ -124,6 +124,7 @@ jobs:
124124
python -m pip install -r requirements-dev.txt
125125
python -m pip list
126126
- name: Run quick tests (GPU)
127+
if: github.event.pull_request.merged != true
127128
run: |
128129
git clone --depth 1 \
129130
https://github.com/Project-MONAI/MONAI-extra-test-data.git /MONAI-extra-test-data

monai/losses/dice.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,8 +711,9 @@ def ce(self, input: torch.Tensor, target: torch.Tensor):
711711
"Using argmax (as a workaround) to convert target to a single channel."
712712
)
713713
target = torch.argmax(target, dim=1)
714-
else: # target has the same shape as input, class probabilities in [0, 1], as floats
715-
target = target.to(input) # check its values are in [0, 1]??
714+
elif not torch.is_floating_point(target):
715+
target = target.to(dtype=input.dtype)
716+
716717
return self.cross_entropy(input, target)
717718

718719
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)