Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3071,6 +3071,13 @@ def cross_entropy(
# so, reduce_sum all directly is ok
return _C_ops.sum(out, [], None, False)
elif reduction == "mean":
# when reduction is mean, use paddle.nan
def _replace_nan(out):
return out + paddle.nan

if 0 in input.shape:
out = _replace_nan(out)
return _C_ops.mean_all(out)
# 1. if weight==none,
# numerator: reduce_sum all loss directly is ok causeof base_softmax_with_cross_entropy's inner logic
# denominator: count sample num with class_index!=ignore_index
Expand Down
27 changes: 26 additions & 1 deletion test/legacy_test/test_cross_entropy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
import unittest

import numpy as np
from op_test import OpTest, paddle_static_guard, randomize_probability
from op_test import (
OpTest,
get_places,
paddle_static_guard,
randomize_probability,
)

import paddle
from paddle import base
Expand Down Expand Up @@ -497,5 +502,25 @@ def get_cross_entropy(self):
self.cross_entropy = np.random.random([0, 1]).astype(np.float64)


class TestCrossEntropyOp_ZeroSize2(unittest.TestCase):
def test_dygraph_api(self):
for place in get_places():
paddle.disable_static(place)
x_np = np.random.random((16, 0)).astype(np.float64)
label_np = np.random.random((16, 0)).astype(np.float64)
x = paddle.to_tensor(x_np)
x.stop_gradient = False
label = paddle.to_tensor(label_np)
label.stop_gradient = False
out1 = paddle.nn.functional.cross_entropy(
x, label, soft_label=True, reduction="mean"
)
out2 = np.array(np.nan).astype(np.float64)
np.testing.assert_allclose(out1.numpy(), out2)
paddle.sum(out1).backward()
np.testing.assert_allclose(x.grad.shape, x.shape)
paddle.enable_static()


if __name__ == "__main__":
unittest.main()
Loading