@@ -395,6 +395,182 @@ CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
395395 CudaAtomicAdd (imag, val.imag ));
396396}
397397
398+ // For atomicMul.
399+ CUDA_ATOMIC_WRAPPER (Mul, int ) {
400+ int res = *address, old = res; // NOLINT
401+ do {
402+ old = res;
403+ res = atomicCAS (address, // NOLINT
404+ old, // NOLINT
405+ val * old); // NOLINT
406+ } while (old != res);
407+ return res;
408+ }
409+
410+ CUDA_ATOMIC_WRAPPER (Mul, unsigned int ) {
411+ unsigned int res = *address, old = res; // NOLINT
412+ do {
413+ old = res;
414+ res = atomicCAS (address, // NOLINT
415+ old, // NOLINT
416+ val * old); // NOLINT
417+ } while (old != res);
418+ return res;
419+ }
420+ // CUDA API uses unsigned long long int, we cannot use uint64_t here.
421+ // It because unsigned long long int is not necessarily uint64_t
422+ CUDA_ATOMIC_WRAPPER (Mul, unsigned long long int ) { // NOLINT
423+ unsigned long long int old = *address, assumed; // NOLINT
424+
425+ do {
426+ assumed = old;
427+ old = atomicCAS (address, assumed, val * assumed);
428+ } while (assumed != old);
429+ return old;
430+ }
431+
432+ CUDA_ATOMIC_WRAPPER (Mul, int64_t ) {
433+ // Here, we check long long int must be int64_t.
434+ static_assert (sizeof (int64_t ) == sizeof (long long int ), // NOLINT
435+ " long long should be int64" );
436+ long long int res = *address, old = res; // NOLINT
437+ do {
438+ old = res;
439+ res = (long long int )atomicCAS ( // NOLINT
440+ (unsigned long long int *)address, // NOLINT
441+ (unsigned long long int )old, // NOLINT
442+ (unsigned long long int )val * (unsigned long long int )old); // NOLINT
443+ } while (old != res);
444+ return res;
445+ }
446+
447+ CUDA_ATOMIC_WRAPPER (Mul, float ) {
448+ int *const address_as_i = reinterpret_cast <int *>(address);
449+ int old = *address_as_i, assumed;
450+
451+ do {
452+ assumed = old;
453+ old = atomicCAS (
454+ address_as_i, assumed, __float_as_int (val * __int_as_float (assumed)));
455+ } while (assumed != old);
456+
457+ return __int_as_float (old);
458+ }
459+
460+ CUDA_ATOMIC_WRAPPER (Mul, double ) {
461+ unsigned long long int *const address_as_ull = // NOLINT
462+ reinterpret_cast <unsigned long long int *>(address); // NOLINT
463+ unsigned long long int old = *address_as_ull, assumed; // NOLINT
464+
465+ do {
466+ assumed = old;
467+
468+ old = atomicCAS (address_as_ull,
469+ assumed,
470+ __double_as_longlong (val * __longlong_as_double (assumed)));
471+ } while (assumed != old);
472+
473+ return __longlong_as_double (old);
474+ }
475+
476+ #ifdef PADDLE_CUDA_FP16
477+ inline static __device__ uint32_t mul_to_low_half (uint32_t val, float x) {
478+ phi::dtype::float16 low_half;
479+ // The float16 in lower 16bits
480+ low_half.x = static_cast <uint16_t >(val & 0xFFFFu );
481+ low_half = static_cast <phi::dtype::float16>(static_cast <float >(low_half) * x);
482+ return (val & 0xFFFF0000u ) | low_half.x ;
483+ }
484+
485+ inline static __device__ uint32_t mul_to_high_half (uint32_t val, float x) {
486+ phi::dtype::float16 high_half;
487+ // The float16 in higher 16bits
488+ high_half.x = static_cast <uint16_t >(val >> 16 );
489+ high_half =
490+ static_cast <phi::dtype::float16>(static_cast <float >(high_half) * x);
491+ return (val & 0xFFFFu ) | (static_cast <uint32_t >(high_half.x ) << 16 );
492+ }
493+
494+ CUDA_ATOMIC_WRAPPER (Mul, phi::dtype::float16) {
495+ if (*address >= val) {
496+ return *address;
497+ }
498+ uint32_t *address_as_ui = reinterpret_cast <uint32_t *>(
499+ reinterpret_cast <char *>(address) -
500+ (reinterpret_cast <uintptr_t >(address) & 0x02 ));
501+ float val_f = static_cast <float >(val);
502+ uint32_t old = *address_as_ui;
503+ uint32_t assumed;
504+ if (((uintptr_t )address & 0x02 ) == 0 ) {
505+ // The float16 value stay at lower 16 bits of the address.
506+ do {
507+ assumed = old;
508+ old = atomicCAS (address_as_ui, assumed, mul_to_low_half (assumed, val_f));
509+ } while (old != assumed);
510+ phi::dtype::float16 ret;
511+ ret.x = old & 0xFFFFu ;
512+ return ret;
513+ } else {
514+ // The float16 value stay at higher 16 bits of the address.
515+ do {
516+ assumed = old;
517+ old = atomicCAS (address_as_ui, assumed, mul_to_high_half (assumed, val_f));
518+ } while (old != assumed);
519+ phi::dtype::float16 ret;
520+ ret.x = old >> 16 ;
521+ return ret;
522+ }
523+ }
524+ #endif
525+
526+ inline static __device__ uint32_t bf16_mul_to_low_half (uint32_t val, float x) {
527+ phi::dtype::bfloat16 low_half;
528+ // The bfloat16 in lower 16bits
529+ low_half.x = static_cast <uint16_t >(val & 0xFFFFu );
530+ low_half =
531+ static_cast <phi::dtype::bfloat16>(static_cast <float >(low_half) * x);
532+ return (val & 0xFFFF0000u ) | low_half.x ;
533+ }
534+
535+ inline static __device__ uint32_t bf16_mul_to_high_half (uint32_t val, float x) {
536+ phi::dtype::bfloat16 high_half;
537+ // The bfloat16 in higher 16bits
538+ high_half.x = static_cast <uint16_t >(val >> 16 );
539+ high_half =
540+ static_cast <phi::dtype::bfloat16>(static_cast <float >(high_half) * x);
541+ return (val & 0xFFFFu ) | (static_cast <uint32_t >(high_half.x ) << 16 );
542+ }
543+
544+ CUDA_ATOMIC_WRAPPER (Mul, phi::dtype::bfloat16) {
545+ uint32_t *address_as_ui = reinterpret_cast <uint32_t *>(
546+ reinterpret_cast <char *>(address) -
547+ (reinterpret_cast <uintptr_t >(address) & 0x02 ));
548+ float val_f = static_cast <float >(val);
549+ uint32_t old = *address_as_ui;
550+ uint32_t assumed;
551+ if (((uintptr_t )address & 0x02 ) == 0 ) {
552+ // The bfloat16 value stay at lower 16 bits of the address.
553+ do {
554+ assumed = old;
555+ old = atomicCAS (
556+ address_as_ui, assumed, bf16_mul_to_low_half (assumed, val_f));
557+ } while (old != assumed);
558+ phi::dtype::bfloat16 ret;
559+ ret.x = old & 0xFFFFu ;
560+ return ret;
561+ } else {
562+ // The bfloat16 value stay at higher 16 bits of the address.
563+ do {
564+ assumed = old;
565+ old = atomicCAS (
566+ address_as_ui, assumed, bf16_mul_to_high_half (assumed, val_f));
567+ } while (old != assumed);
568+ phi::dtype::bfloat16 ret;
569+ ret.x = old >> 16 ;
570+ return ret;
571+ }
572+ }
573+
398574// For atomicMax
399575USE_CUDA_ATOMIC (Max, int );
400576USE_CUDA_ATOMIC (Max, unsigned int );
0 commit comments