16 changes: 16 additions & 0 deletions paddle/phi/kernels/cpu/layer_norm_grad_kernel.cc
@@ -18,6 +18,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
@@ -39,6 +40,21 @@ void LayerNormGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
if (scale_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(scale_grad->dims())),
0,
scale_grad);
if (bias_grad)
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(bias_grad->dims())),
0,
bias_grad);
return;
}
auto* scale = scale_opt.get_ptr();
auto d_y = out_grad;

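The intent of the zero-size branch above can be seen from the Python side: the input gradient is allocated but empty, while the scale and bias gradients come back as zero-filled tensors of the parameter shape. The snippet below is a hypothetical dygraph check, not part of this PR, assuming a Paddle build that contains these kernels.

import paddle

# Hypothetical check (not from this PR): zero-size batch, x.numel() == 0,
# while the LayerNorm parameters still hold 6 * 3 = 18 elements each.
ln = paddle.nn.LayerNorm([6, 3])
x = paddle.zeros([0, 6, 3], dtype='float32')
x.stop_gradient = False

y = ln(x)                        # forward also returns an empty tensor
y.backward(paddle.ones_like(y))  # exercises the zero-size grad kernel

print(x.grad.shape)                       # expected: [0, 6, 3]
print(float(ln.weight.grad.abs().sum()))  # expected: 0.0 (all zeros)
print(float(ln.bias.grad.abs().sum()))    # expected: 0.0 (all zeros)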
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/layer_norm_kernel.cc
@@ -45,6 +45,7 @@ void LayerNormKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(y);
dev_ctx.template Alloc<T>(mean);
dev_ctx.template Alloc<T>(var);
if (x.numel() == 0) return;

auto matrix_dim = common::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
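For the forward kernels the change is just an early return after the outputs are allocated, so a zero-size input produces zero-size y, mean, and variance. A minimal hypothetical sketch of that behavior, again assuming a build that includes this change:

import paddle

x = paddle.zeros([0, 6, 3], dtype='float32')
y = paddle.nn.functional.layer_norm(x, normalized_shape=[6, 3])
print(y.shape)  # expected: [0, 6, 3] -- allocated, but holds no elements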
27 changes: 27 additions & 0 deletions paddle/phi/kernels/gpu/layer_norm_grad_kernel.cu
@@ -16,6 +16,8 @@

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#include "paddle/phi/kernels/funcs/layer_norm_util.h"

@@ -34,6 +36,31 @@ void LayerNormGradKernel(const Context &dev_ctx,
DenseTensor *x_grad,
DenseTensor *scale_grad,
DenseTensor *bias_grad) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
if (scale_grad) {
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(scale_grad->dims())),
0,
scale_grad);
if (scale_opt.get_ptr() && x.dtype() != scale_opt.get().dtype()) {
phi::CastKernel<T, Context>(
dev_ctx, *scale_grad, scale_opt.get().dtype(), scale_grad);
}
}
if (bias_grad) {
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(bias_grad->dims())),
0,
bias_grad);
if (bias_opt.get_ptr() && x.dtype() != bias_opt.get().dtype()) {
phi::CastKernel<T, Context>(
dev_ctx, *bias_grad, bias_opt.get().dtype(), bias_grad);
}
}
return;
}
using U = phi::funcs::LayerNormParamType<T>;
// d_x, d_scale, d_bias may be nullptr
auto *d_x = x_grad;
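The GPU branch additionally casts the zero-filled parameter gradients back to the dtype of scale/bias when that differs from the input dtype, e.g. float16 or bfloat16 activations with float32 parameters. A hypothetical illustration of the case this targets; the mixed-precision setup below is an assumption for illustration, not taken from the PR:

import paddle

# float16 activations with default float32 LayerNorm parameters; the
# zero-filled scale/bias gradients should keep the parameter dtype, which
# is what the extra CastKernel call above ensures.
ln = paddle.nn.LayerNorm([6, 3])
x = paddle.zeros([0, 6, 3], dtype='float16')
x.stop_gradient = False

y = ln(x)
y.backward(paddle.ones_like(y))

print(ln.weight.grad.dtype)  # expected: paddle.float32, not float16
print(ln.bias.grad.dtype)    # expected: paddle.float32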
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/layer_norm_kernel.cu
@@ -508,6 +508,7 @@ void LayerNormKernel(const Context &dev_ctx,
auto *y_data = dev_ctx.template Alloc<T>(y);
auto *mean_data = dev_ctx.template Alloc<U>(mean);
auto *var_data = dev_ctx.template Alloc<U>(var);
if (x.numel() == 0) return;

bool valid_scale = (scale != nullptr);
bool valid_bias = (bias != nullptr);
16 changes: 16 additions & 0 deletions paddle/phi/kernels/xpu/layer_norm_grad_kernel.cc
@@ -16,6 +16,7 @@

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"

namespace phi {

@@ -32,6 +33,21 @@ void LayerNormGradImpl(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
if (x.numel() == 0) {
dev_ctx.template Alloc<T>(x_grad);
if (scale_grad)
phi::Full<T, Context>(
dev_ctx,
phi::IntArray(common::vectorize(scale_grad->dims())),
0,
scale_grad);
if (bias_grad)
phi::Full<T, Context>(dev_ctx,
phi::IntArray(common::vectorize(bias_grad->dims())),
0,
bias_grad);
return;
}
const auto* scale_ptr = scale.get_ptr();
using XPUType = typename XPUTypeTrait<T>::Type;
using XPUTypeTW = typename XPUTypeTrait<TW>::Type;
1 change: 1 addition & 0 deletions paddle/phi/kernels/xpu/layer_norm_kernel.cc
@@ -43,6 +43,7 @@ void LayerNormKernelImpl(const Context& dev_ctx,
auto* out_data = dev_ctx.template Alloc<T>(out);
auto* mean_data = dev_ctx.template Alloc<float>(mean);
auto* variance_data = dev_ctx.template Alloc<float>(variance);
if (x.numel() == 0) return;

int r = xpu::layer_norm(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/vision.py
@@ -104,6 +104,8 @@ def affine_grid(
_out_shape = (
out_shape.tolist() if isinstance(out_shape, Variable) else out_shape
)
if isinstance(_out_shape, paddle.Tensor) and _out_shape.size == 0:
raise ValueError("The out_shape cannot be empty.")
theta = theta._use_gpudnn(use_cudnn)
return _C_ops.affine_grid(theta, _out_shape, align_corners)
elif in_pir_mode():
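On the Python side, the new check rejects an empty out_shape tensor before it reaches the C++ op. A hypothetical dygraph example of the resulting behavior:

import numpy as np
import paddle

theta = paddle.randn([1, 2, 3], dtype='float32')
empty_shape = paddle.to_tensor(np.empty([0], dtype='int32'))

try:
    paddle.nn.functional.affine_grid(theta, empty_shape)
except ValueError as e:
    print(e)  # The out_shape cannot be empty.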
29 changes: 28 additions & 1 deletion test/legacy_test/test_affine_grid_op.py
@@ -15,7 +15,7 @@
import unittest

import numpy as np
from op_test import OpTest
from op_test import OpTest, get_places

import paddle

@@ -224,6 +224,33 @@ def initTestCase(self):
self.align_corners = False


class TestAffineGridAPI_ZeroSize(unittest.TestCase):
def init_dtype(self):
self.dtype = 'float32'

def setUp(self):
self.init_dtype()
self.place = get_places()
self.theta_shape = (17, 2, 3)
self.output_shape = np.random.random([0]).astype("int32")

def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
theta_np = np.random.randint(1, 3, self.theta_shape).astype(
self.dtype
)
theta = paddle.to_tensor(theta_np)
with self.assertRaises(ValueError):
paddle.nn.functional.vision.affine_grid(
theta, paddle.to_tensor(self.output_shape)
)
paddle.enable_static()

for place in self.place:
run(place)


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
23 changes: 23 additions & 0 deletions test/legacy_test/test_layer_norm_op.py
@@ -805,6 +805,29 @@ def init_dtype(self):
self.dtype = 'bfloat16'


@unittest.skipIf(
not core.is_compiled_with_cuda() or paddle.is_compiled_with_rocm(),
"core is not compiled with CUDA",
)
class TestLayerNormBF16OpByOpTest_ZeroSize(TestLayerNormOpByOpTest):
def initConfig(self):
self.__class__.exist_fp64_check_grad = True
self.ori_atol = 1e-2
self.ori_rtol = 1e-2

self.max_relative_error = 1e-5

self.dtype = np.float32
self.x_shape = [2, 0, 6, 3]
self.epsilon = 0.00001
self.begin_norm_axis = 1
self.has_scale = True
self.has_bias = False
self.check_prim = False
self.check_prim_pir = False
self.check_pir = True


if __name__ == '__main__':
paddle.enable_static()
unittest.main()