Skip to content
53 changes: 37 additions & 16 deletions paddle/phi/kernels/funcs/elementwise_grad_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/elementwise_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"
Expand Down Expand Up @@ -64,18 +65,28 @@ void CommonGradBroadcastCPU(const DenseTensor &x,
const CPUContext &dev_ctx,
DX_OP dx_op,
DY_OP dy_op) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

std::vector<int64_t> index_array(max_dim, 0);
const T *x_data = x.data<T>();
const T *y_data = y.data<T>();
const Tout *out_data = out.data<Tout>();
const Tout *dout_data = dout.data<Tout>();
T *dx_data = dx == nullptr ? nullptr : dev_ctx.Alloc<T>(dx);
T *dy_data = dy == nullptr ? nullptr : dev_ctx.Alloc<T>(dy);
if (dx_data != nullptr) {
memset(dx_data, 0, dx->numel() * sizeof(T));

DenseTensor dx_mp, dy_mp;
MPType *dx_mp_data = nullptr;
MPType *dy_mp_data = nullptr;
if (dx != nullptr) {
dx_mp.Resize(dx->dims());
dev_ctx.Alloc<MPType>(&dx_mp);
dx_mp_data = dx_mp.data<MPType>();
memset(dx_mp_data, 0, dx->numel() * sizeof(MPType));
}
if (dy_data != nullptr) {
memset(dy_data, 0, dy->numel() * sizeof(T));
if (dy != nullptr) {
dy_mp.Resize(dy->dims());
dev_ctx.Alloc<MPType>(&dy_mp);
dy_mp_data = dy_mp.data<MPType>();
memset(dy_mp_data, 0, dy->numel() * sizeof(MPType));
}
const int64_t out_size = std::accumulate(out_dims_array,
out_dims_array + max_dim,
Expand All @@ -87,22 +98,32 @@ void CommonGradBroadcastCPU(const DenseTensor &x,
GetElementwiseIndex<int64_t>(x_dims_array, max_dim, index_array.data());
y_index =
GetElementwiseIndex<int64_t>(y_dims_array, max_dim, index_array.data());
if (dx_data != nullptr) {
dx_data[x_index] += dx_op(x_data[x_index],
y_data[y_index],
out_data[out_index],
dout_data[out_index]);
if (dx_mp_data != nullptr) {
dx_mp_data[x_index] += static_cast<MPType>(dx_op(x_data[x_index],
y_data[y_index],
out_data[out_index],
dout_data[out_index]));
}
if (dy_data != nullptr) {
dy_data[y_index] += dy_op(x_data[x_index],
y_data[y_index],
out_data[out_index],
dout_data[out_index]);
if (dy_mp_data != nullptr) {
dy_mp_data[y_index] += static_cast<MPType>(dy_op(x_data[x_index],
y_data[y_index],
out_data[out_index],
dout_data[out_index]));
}

UpdateElementwiseIndexArray<int64_t>(
out_dims_array, max_dim, index_array.data());
}
if (dx != nullptr) {
dev_ctx.Alloc<T>(dx);
phi::CastKernel<MPType, CPUContext>(
dev_ctx, dx_mp, phi::CppTypeToDataType<T>::Type(), dx);
}
if (dy != nullptr) {
dev_ctx.Alloc<T>(dy);
phi::CastKernel<MPType, CPUContext>(
dev_ctx, dy_mp, phi::CppTypeToDataType<T>::Type(), dy);
}
}

template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
Expand Down
21 changes: 21 additions & 0 deletions test/legacy_test/test_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,5 +290,26 @@ def init_shapes(self):
self.y_shape = [5, 1]


class TestMultiplyApiBF16(unittest.TestCase):
# Now only check the successful run of multiply with bfloat16 and backward.
def setUp(self):
paddle.device.set_device('cpu')

def test_multiply(self):
self.x_shape = [1, 1024, 32, 128]
self.y_shape = [1, 1024, 1, 128]
x = paddle.rand(self.x_shape, dtype='bfloat16')
x.stop_gradient = False
y = paddle.rand(self.y_shape, dtype='bfloat16')
y.stop_gradient = False
res = paddle.multiply(x, y)
loss = res.sum()
loss.backward()
assert x.grad is not None
assert x.grad.dtype == paddle.bfloat16
assert y.grad is not None
assert y.grad.dtype == paddle.bfloat16


if __name__ == '__main__':
unittest.main()