19 changes: 19 additions & 0 deletions paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
@@ -95,6 +95,25 @@ void RemainderGradKernel(const Context& dev_ctx,
                          const DenseTensor& dout,
                          DenseTensor* dx,
                          DenseTensor* dy) {
+  if (dout.numel() == 0) {
+    if (dx) {
+      if (dx->numel() == 0) {
+        dev_ctx.template Alloc<T>(dx);
+      } else {
+        phi::Full<T, Context>(
+            dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+      }
+    }
+    if (dy) {
+      if (dy->numel() == 0) {
+        dev_ctx.template Alloc<T>(dy);
+      } else {
+        phi::Full<T, Context>(
+            dev_ctx, phi::IntArray(common::vectorize(dy->dims())), 0, dy);
+      }
+    }
+    return;
+  }
   funcs::ElementwiseGradPreProcess(dout, dx);
   int axis = -1;
   phi::funcs::
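The CPU and CUDA backward kernels now short-circuit when the incoming gradient `dout` has zero elements: a zero-size `dx`/`dy` is only allocated, while a non-zero-size one (the broadcast case) is filled with zeros via `phi::Full`. The dygraph sketch below is not part of the diff; it is a hypothetical illustration of the resulting behavior, assuming a Paddle build that includes this change, and uses only documented APIs (`paddle.remainder`, `paddle.grad`, `paddle.ones_like`).

# Hypothetical illustration (not from the PR): a zero-size dout makes the
# remainder grad kernel fill broadcastable grads with zeros instead of
# running the regular elementwise-grad path.
import numpy as np
import paddle

x = paddle.to_tensor(
    np.random.uniform(1, 10, [5, 0, 3]), dtype='float64', stop_gradient=False
)
y = paddle.to_tensor(
    np.random.uniform(1, 10, [5, 1, 3]), dtype='float64', stop_gradient=False
)
z = paddle.remainder(x, y)  # broadcast result has shape [5, 0, 3], numel() == 0
v = paddle.ones_like(z)     # upstream gradient is therefore zero-size as well

dx, dy = paddle.grad(z, [x, y], v)
print(dx.shape)                   # [5, 0, 3]: zero-size, the kernel only allocates it
print(dy.shape, float(dy.sum()))  # [5, 1, 3], 0.0: non-zero-size, filled by phi::Full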
19 changes: 19 additions & 0 deletions paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
@@ -274,6 +274,25 @@ void RemainderGradKernel(const Context& dev_ctx,
                          const DenseTensor& dout,
                          DenseTensor* dx,
                          DenseTensor* dy) {
+  if (dout.numel() == 0) {
+    if (dx) {
+      if (dx->numel() == 0) {
+        dev_ctx.template Alloc<T>(dx);
+      } else {
+        phi::Full<T, Context>(
+            dev_ctx, phi::IntArray(common::vectorize(dx->dims())), 0, dx);
+      }
+    }
+    if (dy) {
+      if (dy->numel() == 0) {
+        dev_ctx.template Alloc<T>(dy);
+      } else {
+        phi::Full<T, Context>(
+            dev_ctx, phi::IntArray(common::vectorize(dy->dims())), 0, dy);
+      }
+    }
+    return;
+  }
   const auto place = dev_ctx.GetPlace();
   int axis = -1;
   if (dx != nullptr && dy != nullptr) {
8 changes: 6 additions & 2 deletions paddle/phi/kernels/legacy/cpu/elementwise_kernel.cc
@@ -59,10 +59,14 @@ void RemainderRawKernel(const Context& dev_ctx,
                         const DenseTensor& y,
                         int axis,
                         DenseTensor* out) {
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   // allocate memory for out
   dev_ctx.template Alloc<T>(out);
-  auto x_dims = x.dims();
-  auto y_dims = y.dims();
+  const auto& x_dims = x.dims();
+  const auto& y_dims = y.dims();
   if (x_dims.size() >= y_dims.size()) {  // NOLINT
     funcs::ElementwiseCompute<funcs::RemainderFunctor<T>, T>(
         dev_ctx, x, y, funcs::RemainderFunctor<T>(), out, axis);
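On the forward side, the raw CPU kernel (and the XPU kernels below) now just allocates `out` and returns when the result has no elements, skipping the `ElementwiseCompute` path. A small sketch of the observable behavior follows; it is hypothetical, not part of the diff, and assumes a build that includes this change.

# Hypothetical illustration (not from the PR): a zero-size operand produces an
# empty output tensor; the kernel only allocates memory and returns early.
import numpy as np
import paddle

x_np = np.random.uniform(1, 10, [0, 4]).astype('float32')
y_np = np.random.uniform(1, 10, [2, 0, 4]).astype('float32')

z = paddle.remainder(paddle.to_tensor(x_np), paddle.to_tensor(y_np))
print(z.shape)  # [2, 0, 4] -- the broadcast shape is kept even though numel() == 0

# NumPy agrees on the (empty) result, so the early return is value-preserving.
np.testing.assert_allclose(np.remainder(x_np, y_np), z.numpy())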
12 changes: 10 additions & 2 deletions paddle/phi/kernels/xpu/elementwise_kernel.cc
@@ -63,6 +63,10 @@ void RemainderKernel(const Context& dev_ctx,
                      const DenseTensor& y,
                      DenseTensor* out) {
   using XPUType = typename XPUTypeTrait<T>::Type;
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   auto f = [](xpu::Context* xpu_ctx,
               const XPUType* x,
               const XPUType* y,
@@ -92,8 +96,12 @@ void RemainderKernel<phi::dtype::complex<float>, XPUContext>(
     const DenseTensor& y,
     DenseTensor* out) {
   using T = phi::dtype::complex<float>;
-  auto x_dims = x.dims();
-  auto y_dims = y.dims();
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
+  const auto& x_dims = x.dims();
+  const auto& y_dims = y.dims();
   auto out_dims = phi::funcs::BroadcastTwoDims(x_dims, y_dims);
   std::vector<int64_t> out_dims_vec = phi::vectorize(out_dims);

95 changes: 95 additions & 0 deletions test/legacy_test/test_elementwise_mod_op.py
@@ -59,6 +59,27 @@ def init_axis(self):
         pass


+class TestElementwiseModOp_ZeroSize1(TestElementwiseModOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0, 10000, [0, 1]).astype(self.dtype)
+        self.y = np.random.uniform(0, 1000, [0, 1]).astype(self.dtype)
+        self.out = np.mod(self.x, self.y)
+
+
+class TestElementwiseModOp_ZeroSize2(TestElementwiseModOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0, 10000, [6, 0, 1]).astype(self.dtype)
+        self.y = np.random.uniform(0, 1000, [6, 1, 0]).astype(self.dtype)
+        self.out = np.mod(self.x, self.y)
+
+
+class TestElementwiseModOp_ZeroSize3(TestElementwiseModOp):
+    def init_input_output(self):
+        self.x = np.random.uniform(0, 10000, [1, 0, 4]).astype(self.dtype)
+        self.y = np.random.uniform(0, 1000, [0, 1, 4]).astype(self.dtype)
+        self.out = np.mod(self.x, self.y)
+
+
 class TestElementwiseModOp_ZeroDim1(TestElementwiseModOp):
     def init_input_output(self):
         self.x = np.random.uniform(0, 10000, []).astype(self.dtype)
@@ -331,6 +352,28 @@ def test_dygraph_broadcast_to_z(self):
             self.assertEqual(z.dtype, x.dtype)
             np.testing.assert_allclose(z_np, z.numpy())

+    def test_dygraph_zero_size_shape(self):
+        with dygraph_guard():
+            dtypes = ['int32', 'int64', 'float32', 'float64']
+            places = [paddle.CPUPlace()]
+            if core.is_compiled_with_cuda():
+                places.append(paddle.CUDAPlace(0))
+            for dtype in dtypes:
+                for place in places:
+                    shape = [1, 2, 0, 4, 5]
+                    x_np = np.random.uniform(-1000, 1000, shape).astype(dtype)
+                    y_np = np.random.uniform(-1000, 1000, shape).astype(dtype)
+                    # make sure all elements in y are non-zero
+                    y_np[np.isclose(y_np, 0)] = -1
+                    z_np = np.remainder(x_np, y_np)
+                    x = paddle.to_tensor(x_np, dtype=dtype, place=place)
+                    x.stop_gradient = False
+                    y = paddle.to_tensor(y_np, dtype=dtype, place=place)
+                    y.stop_gradient = False
+                    z = paddle.remainder(x, y)
+                    self.assertEqual(z.dtype, x.dtype)
+                    np.testing.assert_allclose(z_np, z.numpy())
+
     def test_check_grad(self):
         with dygraph_guard():
             dtypes = ['int32', 'int64', 'float32', 'float64']
@@ -379,6 +422,58 @@ def test_check_grad(self):
                         dy_np = dy_np.sum(axis=dim, keepdims=True)
                     np.testing.assert_allclose(dy_np, dy.numpy(), 5e-5)

+    def test_check_grad_zero_size(self):
+        with dygraph_guard():
+            dtypes = ['int32', 'int64', 'float32', 'float64']
+            places = [paddle.CPUPlace()]  # always test on CPU
+            if core.is_compiled_with_cuda():
+                places.append(paddle.CUDAPlace(0))
+            shape_combinations = [
+                ([0], [0]),
+                ([2, 0, 4], [1]),
+                ([5, 0], [1, 5, 0]),
+                ([0, 4], [2, 0, 4]),
+                ([1, 0, 3], [1, 0, 3]),
+                ([3, 0, 2], [3, 1, 2]),
+                ([5, 1, 3], [5, 0, 3]),
+                ([2, 1, 0, 1], [1, 0, 1, 5]),
+            ]
+            for dtype in dtypes:
+                for place in places:
+                    for x_shape, y_shape in shape_combinations:
+                        x_np = np.random.uniform(-1000, 1000, x_shape).astype(
+                            dtype
+                        )
+                        x_np[x_np == 0] = -1
+                        y_np = np.random.uniform(-1000, 1000, y_shape).astype(
+                            dtype
+                        )
+                        y_np[np.isclose(y_np, 0)] = -1
+                        z_np = np.remainder(x_np, y_np)
+
+                        x = paddle.to_tensor(
+                            x_np, dtype=dtype, place=place, stop_gradient=False
+                        )
+                        y = paddle.to_tensor(
+                            y_np, dtype=dtype, place=place, stop_gradient=False
+                        )
+                        z = paddle.remainder(x, y)
+                        self.assertEqual(z.dtype, x.dtype)
+                        np.testing.assert_allclose(z_np, z.numpy())
+
+                        v_np = np.random.uniform(
+                            -1000, 1000, z_np.shape
+                        ).astype(dtype)
+                        v = paddle.to_tensor(v_np, dtype=dtype, place=place)
+
+                        dx = paddle.grad(z, x, v, retain_graph=True)[0]
+                        dx_np = np.zeros_like(dx.numpy())
+                        np.testing.assert_allclose(dx_np, dx.numpy(), 5e-5)
+
+                        dy = paddle.grad(z, y, v, retain_graph=True)[0]
+                        dy_np = np.zeros_like(dy.numpy())
+                        np.testing.assert_allclose(dy_np, dy.numpy(), 5e-5)
+

 class TestRemainderOp(unittest.TestCase):
     def setUp(self):
12 changes: 12 additions & 0 deletions test/xpu/test_elementwise_mod_op_xpu.py
@@ -71,6 +71,18 @@ def test_check_output(self):
             place = paddle.XPUPlace(0)
             self.check_output_with_place(place)

+    class ElementwiseModOpZeroSize(ElementwiseModOp):
+        def init_input_output(self):
+            self.x = np.random.uniform(0, 10000, [0, 10]).astype(self.dtype)
+            self.y = np.random.uniform(0, 1000, [0, 10]).astype(self.dtype)
+            self.out = np.mod(self.x, self.y)
+            self.inputs = {
+                'X': OpTest.np_dtype_to_base_dtype(self.x),
+                'Y': OpTest.np_dtype_to_base_dtype(self.y),
+            }
+            self.outputs = {'Out': self.out}
+            self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
+
     class TestRemainderOp(unittest.TestCase):
         def test_dygraph(self):
             with base.dygraph.guard():