
Commit 6d26b33

[bf16] add bf16 kernel: scale gather sum (#39683)
* add scale gather sum
* refine CUDA_ATOMIC_WRAPPER ADD for bf16
* add gather unittest
* solve conflict
* add scale unittest
* add sum unittest
* solve conflict
* refine gather unittest
* refine unittest
1 parent 9de7989 commit 6d26b33

10 files changed: 164 additions, 8 deletions

paddle/fluid/operators/gather_op.cc

Lines changed: 4 additions & 2 deletions
@@ -201,12 +201,14 @@ REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
 REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
                        ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
                        ops::GatherOpKernel<uint8_t>,
-                       ops::GatherOpKernel<int64_t>);
+                       ops::GatherOpKernel<int64_t>,
+                       ops::GatherOpKernel<phi::dtype::bfloat16>);
 REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
                        ops::GatherGradientOpKernel<double>,
                        ops::GatherGradientOpKernel<int>,
                        ops::GatherGradientOpKernel<uint8_t>,
-                       ops::GatherGradientOpKernel<int64_t>);
+                       ops::GatherGradientOpKernel<int64_t>,
+                       ops::GatherGradientOpKernel<phi::dtype::bfloat16>);
 REGISTER_OP_VERSION(gather)
     .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
                    paddle::framework::compatible::OpVersionDesc().NewInput(

paddle/fluid/operators/gather_op.cu

Lines changed: 4 additions & 2 deletions
@@ -130,9 +130,11 @@ REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
                         ops::GatherOpCUDAKernel<double>,
                         ops::GatherOpCUDAKernel<int64_t>,
                         ops::GatherOpCUDAKernel<int>,
-                        ops::GatherOpCUDAKernel<plat::float16>);
+                        ops::GatherOpCUDAKernel<plat::float16>,
+                        ops::GatherOpCUDAKernel<plat::bfloat16>);
 REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
                         ops::GatherGradOpCUDAKernel<double>,
                         ops::GatherGradOpCUDAKernel<int64_t>,
                         ops::GatherGradOpCUDAKernel<int>,
-                        ops::GatherGradOpCUDAKernel<plat::float16>);
+                        ops::GatherGradOpCUDAKernel<plat::float16>,
+                        ops::GatherGradOpCUDAKernel<plat::bfloat16>);

paddle/fluid/operators/math/selected_rows_functor.cu

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <vector>

 #include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -445,6 +446,7 @@ template struct MergeAdd<platform::CUDADeviceContext, double>;
 template struct MergeAdd<platform::CUDADeviceContext, int>;
 template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
 template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
+template struct MergeAdd<platform::CUDADeviceContext, platform::bfloat16>;
 template struct MergeAdd<platform::CUDADeviceContext, platform::complex<float>>;
 template struct MergeAdd<platform::CUDADeviceContext,
                          platform::complex<double>>;

paddle/fluid/operators/sum_op.cu

Lines changed: 2 additions & 1 deletion
@@ -258,4 +258,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SumKernel<paddle::platform::CUDADeviceContext, double>,
     ops::SumKernel<paddle::platform::CUDADeviceContext, int>,
     ops::SumKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>);
+    ops::SumKernel<paddle::platform::CUDADeviceContext, plat::float16>,
+    ops::SumKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>);

paddle/fluid/platform/device/gpu/gpu_primitives.h

Lines changed: 67 additions & 0 deletions
@@ -20,6 +20,7 @@ limitations under the License. */
 #include <hip/hip_runtime.h>
 #endif
 #include <stdio.h>
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"

@@ -244,6 +245,72 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
 #endif
 #endif

+// NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16.
+inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) {
+  bfloat16 low_half;
+  // the bfloat16 in lower 16bits
+  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
+  low_half = static_cast<bfloat16>(static_cast<float>(low_half) + x);
+  return (val & 0xFFFF0000u) | low_half.x;
+}
+
+inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val, float x) {
+  bfloat16 high_half;
+  // the bfloat16 in higher 16bits
+  high_half.x = static_cast<uint16_t>(val >> 16);
+  high_half = static_cast<bfloat16>(static_cast<float>(high_half) + x);
+  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
+}
+
+#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+static __device__ __forceinline__ bfloat16 CUDABF16ToPDBF16(__nv_bfloat16 x) {
+  return *reinterpret_cast<bfloat16 *>(&x);
+}
+
+static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(bfloat16 x) {
+  return *reinterpret_cast<__nv_bfloat16 *>(&x);
+}
+
+CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
+  return CUDABF16ToPDBF16(atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
+                                    PDBF16ToCUDABF16(val)));
+}
+#else
+CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
+  // concrete packed bfloat16 value may exsits in lower or higher 16bits
+  // of the 32bits address.
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t sum;
+  uint32_t newval;
+  uint32_t assumed;
+  if (((uintptr_t)address & 0x02) == 0) {
+    // the bfloat16 value stay at lower 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed,
+                      bf16_add_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    bfloat16 ret;
+    ret.x = old & 0xFFFFu;
+    return ret;
+  } else {
+    // the bfloat16 value stay at higher 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed,
+                      bf16_add_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    bfloat16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
+#endif
+
 CUDA_ATOMIC_WRAPPER(Add, complex<float>) {
   float *real = reinterpret_cast<float *>(address);
   float *imag = real + 1;
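Note on the fallback branch above: before SM 8.0 (or with CUDA older than 11) there is no native bfloat16 atomicAdd, so the wrapper aligns the address down to its containing 32-bit word, redoes the addition in float, repacks the bfloat16 result into the correct half of the word, and loops on atomicCAS until no other thread has raced it. Below is a minimal host-side C++ sketch of that retry loop using std::atomic; FloatToBF16, BF16ToFloat, AddToHalf and AtomicBF16Add are illustrative names, not Paddle APIs, and the conversion is plain truncation rather than Paddle's bfloat16 rounding.

```cpp
// Host-side sketch (not Paddle code) of the CAS retry loop above: atomically
// add a float delta to a bfloat16 stored in one half of a 32-bit word.
#include <atomic>
#include <cstdint>
#include <cstring>
#include <iostream>

static uint16_t FloatToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);  // keep sign, exponent, 7 mantissa bits
}

static float BF16ToFloat(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Rebuild the 32-bit word with `x` added to the bf16 in the chosen half,
// mirroring bf16_add_to_low_half / bf16_add_to_high_half.
static uint32_t AddToHalf(uint32_t word, float x, bool low_half) {
  uint16_t h = low_half ? static_cast<uint16_t>(word & 0xFFFFu)
                        : static_cast<uint16_t>(word >> 16);
  uint16_t sum = FloatToBF16(BF16ToFloat(h) + x);
  return low_half ? ((word & 0xFFFF0000u) | sum)
                  : ((word & 0x0000FFFFu) | (static_cast<uint32_t>(sum) << 16));
}

// Retry until no other thread modified the word between our read and our CAS,
// the same contract as atomicCAS(address_as_ui, assumed, newval) on the GPU.
void AtomicBF16Add(std::atomic<uint32_t> *word, float x, bool low_half) {
  uint32_t observed = word->load();
  while (!word->compare_exchange_weak(observed,
                                      AddToHalf(observed, x, low_half))) {
    // on failure, `observed` is refreshed with the current value; recompute
  }
}

int main() {
  std::atomic<uint32_t> packed{0};  // two bf16 zeros packed in one 32-bit word
  AtomicBF16Add(&packed, 1.5f, /*low_half=*/true);
  AtomicBF16Add(&packed, 2.0f, /*low_half=*/false);
  std::cout << BF16ToFloat(static_cast<uint16_t>(packed.load() & 0xFFFFu)) << " "
            << BF16ToFloat(static_cast<uint16_t>(packed.load() >> 16)) << "\n";
  return 0;
}
```

On SM 8.0+ with CUDA 11, the first branch of the #if instead forwards directly to the hardware atomicAdd overload for __nv_bfloat16.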

paddle/phi/kernels/gpu/scale_kernel.cu

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ PD_REGISTER_KERNEL(scale,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    uint8_t,
                    int8_t,
                    int16_t,

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 6 additions & 1 deletion
@@ -482,7 +482,12 @@ def _append_ops(self, block):

         op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
         "infer datatype from inputs and outputs for this test case"
-        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
+        if self.is_bfloat16_op():
+            self.dtype = np.uint16
+            self.__class__.dtype = self.dtype
+            self.output_dtype = np.uint16
+        else:
+            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
         inputs = append_input_output(block, op_proto, self.inputs, True,
                                      self.dtype)
         outputs = append_input_output(block, op_proto, self.outputs, False,
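The op_test.py change above is why every bf16 test in this commit sets self.dtype = np.uint16: bfloat16 tensors are passed to the framework as raw uint16 bit patterns produced by convert_float_to_uint16 (essentially the upper 16 bits of the float32 encoding). The following is a hedged, standalone C++ sketch of that encoding and of the roughly two to three decimal digits it preserves, which is the reason the bf16 tests use loose tolerances such as numeric_grad_delta=0.5 and 0.8; the helper names are illustrative and the conversion is truncation, whereas the real helper may round.

```cpp
// Illustrative only: how a float32 maps to the uint16 bf16 bit pattern the
// tests store, and how much precision survives the round trip.
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <iostream>

// bfloat16 keeps float32's sign and 8 exponent bits but only 7 mantissa bits.
static uint16_t FloatToBF16Bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);  // truncation; real helpers may round
}

static float BF16BitsToFloat(uint16_t h) {
  uint32_t u = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  for (float v : {0.1f, 1.2345f, 1024.0f, 3.14159f}) {
    float back = BF16BitsToFloat(FloatToBF16Bits(v));
    std::cout << v << " -> 0x" << std::hex << FloatToBF16Bits(v) << std::dec
              << " -> " << back << "  (abs err " << v - back << ")\n";
  }
  return 0;
}
```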

python/paddle/fluid/tests/unittests/test_gather_op.py

Lines changed: 34 additions & 1 deletion
@@ -16,7 +16,7 @@

 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.fluid as fluid
 from paddle.framework import core
@@ -117,6 +117,39 @@ def config(self):
         self.index_type = "int32"


+class TestGatherBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "gather"
+        self.dtype = np.uint16
+        self.config()
+        xnp = np.random.random(self.x_shape).astype(np.float32)
+        axis_np = np.array(self.axis).astype(self.axis_type)
+        index_np = np.array(self.index).astype(self.index_type)
+        self.inputs = {
+            'X': convert_float_to_uint16(xnp),
+            'Index': index_np,
+            'Axis': axis_np
+        }
+        out = gather_numpy(self.inputs['X'], index_np, axis_np[0])
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', numeric_grad_delta=0.5)
+
+    def config(self):
+        """
+        For multi-dimension input
+        """
+        self.x_shape = (3, 88, 3)
+        self.index = [1, 3, 5]
+        self.index_type = "int32"
+        self.axis = [1]
+        self.axis_type = "int32"
+
+
 class TestGatherOp1(OpTest):
     def setUp(self):
         self.op_type = "gather"

python/paddle/fluid/tests/unittests/test_scale_op.py

Lines changed: 18 additions & 1 deletion
@@ -16,7 +16,7 @@

 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
@@ -153,6 +153,23 @@ def test_check_grad(self):
                 place, ["X"], "Out", max_relative_error=0.05)


+class TestScaleBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "scale"
+        self.dtype = np.uint16
+        self.attrs = {'scale': -2.3}
+        x = np.random.random((10, 10)).astype(np.float32)
+        out = x * np.float32(self.attrs['scale'])
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', numeric_grad_delta=0.8)
+
+
 @unittest.skipIf(not core.is_compiled_with_cuda(),
                  "core is not compiled with CUDA")
 class TestScaleFp16OpSelectedRows(TestScaleOpSelectedRows):

python/paddle/fluid/tests/unittests/test_sum_op.py

Lines changed: 26 additions & 0 deletions
@@ -298,6 +298,32 @@ def test_w_is_selected_rows(self):
     globals()[cls_name] = TestSumFp16Case


+#----------- test bf16 -----------
+class TestSumBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "sum"
+        self.init_kernel_type()
+        x0 = np.random.random((3, 40)).astype(np.float32)
+        x1 = np.random.random((3, 40)).astype(np.float32)
+        x2 = np.random.random((3, 40)).astype(np.float32)
+        y = x0 + x1 + x2
+        self.inputs = {
+            "X": [("x0", convert_float_to_uint16(x0)),
+                  ("x1", convert_float_to_uint16(x1)),
+                  ("x2", convert_float_to_uint16(x2))]
+        }
+        self.outputs = {'Out': convert_float_to_uint16(y)}
+
+    def init_kernel_type(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out', numeric_grad_delta=0.5)
+
+
 class API_Test_Add_n(unittest.TestCase):
     def test_api(self):
         with fluid.program_guard(fluid.Program(), fluid.Program()):
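TestSumBF16Op builds its reference output by summing x0, x1, and x2 in float32 and converting to bf16 only once, whereas a bf16 kernel may round intermediate results; that difference is a plausible reason for the wide numeric_grad_delta=0.5 above. Below is a small, illustrative C++ comparison of the two accumulation orders (not Paddle code; truncation-based bf16 rounding).

```cpp
// Illustrative comparison: summing values in float32 and rounding once at the
// end vs. rounding to bf16 after every addition.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Snap a float to the nearest-below value representable in bfloat16.
static float RoundTripBF16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  u &= 0xFFFF0000u;  // keep only the 16 bits a bf16 can represent (truncation)
  std::memcpy(&f, &u, sizeof(u));
  return f;
}

int main() {
  std::vector<float> xs(1000, 0.1f);

  float fp32_sum = 0.f;
  for (float x : xs) fp32_sum += x;  // accumulate in float32

  float bf16_sum = 0.f;
  for (float x : xs) bf16_sum = RoundTripBF16(bf16_sum + RoundTripBF16(x));

  std::cout << "float32 accumulate, round once: " << RoundTripBF16(fp32_sum)
            << "\nbf16 rounding at every step:    " << bf16_sum << "\n";
  return 0;
}
```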
