
Commit bee751a

[0-size Tensor Job2 No.55] Add 0-size Tensor support for paddle.nn.functional.cross_entropy (#74131)
* Fix
* Fix
* Fix
* Fix
1 parent 866abcf commit bee751a

File tree

2 files changed: +33 −1 lines changed

python/paddle/nn/functional/loss.py

Lines changed: 7 additions & 0 deletions
@@ -3071,6 +3071,13 @@ def cross_entropy(
             # so, reduce_sum all directly is ok
             return _C_ops.sum(out, [], None, False)
         elif reduction == "mean":
+            # when reduction is mean, use paddle.nan
+            def _replace_nan(out):
+                return out + paddle.nan
+
+            if 0 in input.shape:
+                out = _replace_nan(out)
+                return _C_ops.mean_all(out)
             # 1. if weight==none,
             # numerator: reduce_sum all loss directly is ok causeof base_softmax_with_cross_entropy's inner logic
             # denominator: count sample num with class_index!=ignore_index
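
In effect, calling cross_entropy on a 0-size input with reduction="mean" now returns NaN instead of erroring. A minimal sketch of the new behavior, assuming a Paddle build that includes this commit (shapes chosen to mirror the test below):

import numpy as np
import paddle
import paddle.nn.functional as F

# The class axis is empty, so the loss has no elements to average.
x = paddle.to_tensor(np.random.random((16, 0)).astype("float64"))
x.stop_gradient = False
label = paddle.to_tensor(np.random.random((16, 0)).astype("float64"))

loss = F.cross_entropy(x, label, soft_label=True, reduction="mean")
print(loss.numpy())  # nan, produced by the out + paddle.nan path above

loss.backward()      # backward still runs on the 0-size graph
print(x.grad.shape)  # [16, 0], matching the input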

test/legacy_test/test_cross_entropy_op.py

Lines changed: 26 additions & 1 deletion
@@ -15,7 +15,12 @@
 import unittest
 
 import numpy as np
-from op_test import OpTest, paddle_static_guard, randomize_probability
+from op_test import (
+    OpTest,
+    get_places,
+    paddle_static_guard,
+    randomize_probability,
+)
 
 import paddle
 from paddle import base
@@ -497,5 +502,25 @@ def get_cross_entropy(self):
         self.cross_entropy = np.random.random([0, 1]).astype(np.float64)
 
 
+class TestCrossEntropyOp_ZeroSize2(unittest.TestCase):
+    def test_dygraph_api(self):
+        for place in get_places():
+            paddle.disable_static(place)
+            x_np = np.random.random((16, 0)).astype(np.float64)
+            label_np = np.random.random((16, 0)).astype(np.float64)
+            x = paddle.to_tensor(x_np)
+            x.stop_gradient = False
+            label = paddle.to_tensor(label_np)
+            label.stop_gradient = False
+            out1 = paddle.nn.functional.cross_entropy(
+                x, label, soft_label=True, reduction="mean"
+            )
+            out2 = np.array(np.nan).astype(np.float64)
+            np.testing.assert_allclose(out1.numpy(), out2)
+            paddle.sum(out1).backward()
+            np.testing.assert_allclose(x.grad.shape, x.shape)
+            paddle.enable_static()
+
+
 if __name__ == "__main__":
     unittest.main()
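
A note on the design, as I read it: constructing a brand-new NaN scalar would detach the loss from the autograd graph, whereas adding paddle.nan to out keeps the dependency on the input, which is what lets the backward() call in the test produce a gradient with the correct 0-size shape. A hedged sketch of that distinction (variable names are illustrative, and a Paddle version with 0-size tensor support is assumed):

import paddle

x = paddle.zeros([16, 0], dtype="float64")
x.stop_gradient = False
out = paddle.sum(x)  # stand-in for the 0-size loss before reduction

# A fresh NaN scalar would carry no graph history back to x, so
# backward() on it could not populate x.grad:
# detached = paddle.full([], float("nan"), dtype="float64")

# Adding NaN instead keeps the result connected to x, so backward()
# still fills in x.grad with the matching 0-size shape.
connected = out + float("nan")
connected.backward()
print(x.grad.shape)  # [16, 0]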
