@@ -20,6 +20,7 @@ limitations under the License. */
2020#include < hip/hip_runtime.h>
2121#endif
2222#include < stdio.h>
23+ #include " paddle/fluid/platform/bfloat16.h"
2324#include " paddle/fluid/platform/complex.h"
2425#include " paddle/fluid/platform/float16.h"
2526
@@ -244,6 +245,72 @@ __device__ __forceinline__ void VectorizedAtomicAddPerBlock(
244245#endif
245246#endif
246247
// NOTE(zhangbo): CUDA does not provide an atomicCAS overload for __nv_bfloat16.
// Treats the low 16 bits of `val` as the raw bits of a bfloat16, adds `x` to it
// in float, and returns the 32-bit word with the low half replaced by the
// result. The high 16 bits of `val` are passed through unchanged.
inline static __device__ uint32_t bf16_add_to_low_half(uint32_t val, float x) {
  const uint32_t untouched_high = val & 0xFFFF0000u;
  bfloat16 lane;
  // Load the raw bits of the bfloat16 held in the lower 16 bits.
  lane.x = static_cast<uint16_t>(val & 0xFFFFu);
  const float updated = static_cast<float>(lane) + x;
  lane = static_cast<bfloat16>(updated);
  return untouched_high | lane.x;
}
256+
// Treats the high 16 bits of `val` as the raw bits of a bfloat16, adds `x` to
// it in float, and returns the 32-bit word with the high half replaced by the
// result. The low 16 bits of `val` are passed through unchanged.
inline static __device__ uint32_t bf16_add_to_high_half(uint32_t val, float x) {
  const uint32_t untouched_low = val & 0xFFFFu;
  bfloat16 lane;
  // Load the raw bits of the bfloat16 held in the upper 16 bits.
  lane.x = static_cast<uint16_t>(val >> 16);
  const float updated = static_cast<float>(lane) + x;
  lane = static_cast<bfloat16>(updated);
  const uint32_t shifted_bits = static_cast<uint32_t>(lane.x) << 16;
  return untouched_low | shifted_bits;
}
264+
265+ #if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// Converts a CUDA __nv_bfloat16 into the project's bfloat16 by copying its raw
// 16-bit pattern (no numeric conversion takes place).
//
// The bits are extracted with the __bfloat16_as_ushort intrinsic instead of
// reinterpret_cast-ing the address of the by-value parameter: punning between
// two unrelated class types through a pointer cast violates strict aliasing.
static __device__ __forceinline__ bfloat16 CUDABF16ToPDBF16(__nv_bfloat16 x) {
  bfloat16 converted;
  converted.x = __bfloat16_as_ushort(x);
  return converted;
}
269+
// Converts the project's bfloat16 into a CUDA __nv_bfloat16 by copying its raw
// 16-bit pattern (no numeric conversion takes place).
//
// The value is rebuilt with the __ushort_as_bfloat16 intrinsic instead of
// reinterpret_cast-ing the address of the by-value parameter: punning between
// two unrelated class types through a pointer cast violates strict aliasing.
static __device__ __forceinline__ __nv_bfloat16 PDBF16ToCUDABF16(bfloat16 x) {
  return __ushort_as_bfloat16(x.x);
}
273+
// Atomic add for bfloat16 on the path selected by the enclosing #if
// (CUDA >= 11.0 and __CUDA_ARCH__ >= 800): forwards to atomicAdd on
// __nv_bfloat16 and converts its result back to the project's bfloat16.
CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
  __nv_bfloat16 *target = reinterpret_cast<__nv_bfloat16 *>(address);
  const __nv_bfloat16 addend = PDBF16ToCUDABF16(val);
  const __nv_bfloat16 result = atomicAdd(target, addend);
  return CUDABF16ToPDBF16(result);
}
278+ #else
// Fallback atomic add for bfloat16, built on a 32-bit atomicCAS.
//
// The target bfloat16 may sit in either the lower or the higher 16 bits of the
// 4-byte word that contains it. The word is updated with a compare-and-swap
// loop that rewrites only the relevant half, and the value that half held in
// the word returned by the final successful atomicCAS is returned.
CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
  // Bit 1 of the address tells which half of the containing word is targeted.
  const uintptr_t half_offset = reinterpret_cast<uintptr_t>(address) & 0x02;
  // Step back by that offset to reach the start of the containing word.
  uint32_t *word_ptr = reinterpret_cast<uint32_t *>(
      reinterpret_cast<char *>(address) - half_offset);
  const float addend = static_cast<float>(val);

  uint32_t observed = *word_ptr;
  uint32_t expected;
  bfloat16 previous;

  if (half_offset != 0) {
    // The bfloat16 value occupies the higher 16 bits of the word.
    do {
      expected = observed;
      observed = atomicCAS(word_ptr, expected,
                           bf16_add_to_high_half(expected, addend));
    } while (observed != expected);  // retry if the word changed meanwhile
    previous.x = observed >> 16;
    return previous;
  }

  // The bfloat16 value occupies the lower 16 bits of the word.
  do {
    expected = observed;
    observed = atomicCAS(word_ptr, expected,
                         bf16_add_to_low_half(expected, addend));
  } while (observed != expected);  // retry if the word changed meanwhile
  previous.x = observed & 0xFFFFu;
  return previous;
}
312+ #endif
313+
247314CUDA_ATOMIC_WRAPPER (Add, complex <float >) {
248315 float *real = reinterpret_cast <float *>(address);
249316 float *imag = real + 1 ;
0 commit comments