1818 * Copyright (c) 2018 Amazon.com, Inc. or its affiliates. All Rights reserved.
1919 * Copyright (c) 2019 Research Organization for Information Science
2020 * and Technology (RIST). All rights reserved.
21- * Copyright (c) 2018 Triad National Security, LLC. All rights
21+ * Copyright (c) 2018-2025 Triad National Security, LLC. All rights
2222 * reserved.
2323 * Copyright (c) 2021 IBM Corporation. All rights reserved.
2424 * $COPYRIGHT$
@@ -61,12 +61,16 @@ BEGIN_C_DECLS
6161 */
6262typedef void (ompi_op_c_handler_fn_t )(const void * , void * , int * ,
6363 struct ompi_datatype_t * * );
64+ typedef void (ompi_op_c_handler_bc_fn_t )(const void * , void * , size_t * ,
65+ struct ompi_datatype_t * * );
6466
6567/**
6668 * Typedef for fortran user-defined MPI_Ops.
6769 */
6870typedef void (ompi_op_fortran_handler_fn_t )(const void * , void * ,
6971 MPI_Fint * , MPI_Fint * );
72+ typedef void (ompi_op_fortran_handler_bc_fn_t )(const void * , void * ,
73+ size_t * , MPI_Fint * );
7074
7175/**
7276 * Typedef for Java op functions intercept (used for user-defined
@@ -98,8 +102,8 @@ typedef void (ompi_op_java_handler_fn_t)(const void *, void *, int *,
98102#define OMPI_OP_FLAGS_FLOAT_ASSOC 0x0020
99103/** Set if the callback function is commutative */
100104#define OMPI_OP_FLAGS_COMMUTE 0x0040
101-
102-
105+ /** Set if the callback function is using bigcount */
106+ #define OMPI_OP_FLAGS_BIGCOUNT 0x0080
103107
104108
105109/*
@@ -152,8 +156,12 @@ struct ompi_op_t {
152156 ompi_op_base_op_fns_t intrinsic ;
153157 /** C handler function pointer */
154158 ompi_op_c_handler_fn_t * c_fn ;
159+ /** C handler function pointer - bigcount*/
160+ ompi_op_c_handler_bc_fn_t * c_fn_bc ;
155161 /** Fortran handler function pointer */
156162 ompi_op_fortran_handler_fn_t * fort_fn ;
163+ /** Fortran handler function pointer - bigcount*/
164+ ompi_op_fortran_handler_bc_fn_t * fort_fn_bc ;
157165 /** Java intercept function data */
158166 struct {
159167 /* The OMPI C++ callback/intercept function */
@@ -333,6 +341,8 @@ int ompi_op_init(void);
333341 *
334342 * @param commute Boolean indicating whether the operation is
335343 * commutative or not
344+ * @param bigcount Boolean indicating whether or not the op is
345+ * using the bigcount (MPI_Count) interface
336346 * @param func Function pointer of the user-defined op handler
337347 *
338348 * @returns op Pointer to the ompi_op_t that will be
@@ -355,6 +365,7 @@ int ompi_op_init(void);
355365 * manually.
356366 */
357367ompi_op_t * ompi_op_create_user (bool commute ,
368+ bool bigcount ,
358369 ompi_op_fortran_handler_fn_t func );
359370
360371/**
@@ -512,11 +523,9 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
512523 * in iterations of counts <= INT_MAX since it has an `int *len`
513524 * parameter.
514525 *
515- * Note: When we add BigCount support then we can distinguish between
516- * a reduction operation with `int *len` and `MPI_Count *len`. At which
517- * point we can avoid this loop.
518526 */
519- if ( OPAL_UNLIKELY (full_count > INT_MAX ) ) {
527+ if (OPAL_UNLIKELY ((full_count > INT_MAX ) &&
528+ (0 == (op -> o_flags & OMPI_OP_FLAGS_BIGCOUNT )))) {
520529 size_t done_count = 0 , shift ;
521530 int iter_count ;
522531 ptrdiff_t ext , lb ;
@@ -578,8 +587,12 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
578587 /* User-defined function */
579588 if (0 != (op -> o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC )) {
580589 f_dtype = OMPI_INT_2_FINT (dtype -> d_f_to_c_index );
581- f_count = OMPI_INT_2_FINT (count );
582- op -> o_func .fort_fn (source , target , & f_count , & f_dtype );
590+ if (0 == (op -> o_flags & OMPI_OP_FLAGS_BIGCOUNT )) {
591+ f_count = OMPI_INT_2_FINT (count );
592+ op -> o_func .fort_fn (source , target , & f_count , & f_dtype );
593+ } else {
594+ op -> o_func .fort_fn_bc (source , target , & full_count , & f_dtype );
595+ }
583596 return ;
584597 } else if (0 != (op -> o_flags & OMPI_OP_FLAGS_JAVA_FUNC )) {
585598 op -> o_func .java_data .intercept_fn (source , target , & count , & dtype ,
@@ -588,15 +601,25 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
588601 op -> o_func .java_data .object );
589602 return ;
590603 }
591- op -> o_func .c_fn (source , target , & count , & dtype );
604+ if (0 == (op -> o_flags & OMPI_OP_FLAGS_BIGCOUNT )) {
605+ op -> o_func .c_fn (source , target , & count , & dtype );
606+ } else {
607+ op -> o_func .c_fn_bc (source , target , & full_count , & dtype );
608+ }
592609 return ;
593610}
594611
595612static inline void ompi_3buff_op_user (ompi_op_t * op , void * restrict source1 , void * restrict source2 ,
596- void * restrict result , int count , struct ompi_datatype_t * dtype )
613+ void * restrict result , size_t full_count , struct ompi_datatype_t * dtype )
597614{
598- ompi_datatype_copy_content_same_ddt (dtype , count , (char * )result , (char * )source1 );
599- op -> o_func .c_fn (source2 , result , & count , & dtype );
615+ ompi_datatype_copy_content_same_ddt (dtype , full_count , (char * )result , (char * )source1 );
616+ if (0 == (op -> o_flags & OMPI_OP_FLAGS_BIGCOUNT )) {
617+ assert (full_count <= INT_MAX );
618+ int count = (int )full_count ; /* protected by loop in only caller of this function */
619+ op -> o_func .c_fn (source2 , result , & count , & dtype );
620+ } else {
621+ op -> o_func .c_fn_bc (source2 , result , & full_count , & dtype );
622+ }
600623}
601624
602625/**
@@ -618,13 +641,11 @@ static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, v
618641 * with the values in the source buffer and the result is stored in
619642 * the target buffer).
620643 *
621- * This function will *only* be invoked on intrinsic MPI_Ops.
622- *
623644 * Otherwise, this function is the same as ompi_op_reduce.
624645 */
625646static inline void ompi_3buff_op_reduce (ompi_op_t * op , void * source1 ,
626647 void * source2 , void * target ,
627- int count , ompi_datatype_t * dtype )
648+ size_t full_count , ompi_datatype_t * dtype )
628649{
629650 void * restrict src1 ;
630651 void * restrict src2 ;
@@ -633,13 +654,36 @@ static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1,
633654 src2 = source2 ;
634655 tgt = target ;
635656
657+ if (OPAL_UNLIKELY ((full_count > INT_MAX ) &&
658+ (0 == (op -> o_flags & OMPI_OP_FLAGS_BIGCOUNT )))) {
659+ size_t done_count = 0 , shift , iter_count ;
660+ ptrdiff_t ext , lb ;
661+
662+ ompi_datatype_get_extent (dtype , & lb , & ext );
663+
664+ while (done_count < full_count ) {
665+ if (done_count + INT_MAX > full_count ) {
666+ iter_count = full_count - done_count ;
667+ } else {
668+ iter_count = INT_MAX ;
669+ }
670+ shift = done_count * ext ;
671+ // Recurse one level in iterations of 'int'
672+ ompi_3buff_op_reduce (op , (char * )source1 + shift , (char * )source2 + shift ,
673+ (char * )target + shift , iter_count , dtype );
674+ done_count += iter_count ;
675+ }
676+ return ;
677+ }
678+
636679 if (OPAL_LIKELY (ompi_op_is_intrinsic (op ))) {
680+ int count = (int )full_count ;
637681 op -> o_3buff_intrinsic .fns [ompi_op_ddt_map [dtype -> id ]](src1 , src2 ,
638682 tgt , & count ,
639683 & dtype ,
640684 op -> o_3buff_intrinsic .modules [ompi_op_ddt_map [dtype -> id ]]);
641685 } else {
642- ompi_3buff_op_user (op , src1 , src2 , tgt , count , dtype );
686+ ompi_3buff_op_user (op , src1 , src2 , tgt , full_count , dtype );
643687 }
644688}
645689
0 commit comments