@@ -128,18 +128,36 @@ struct Identity<T, ComplexSum> {
128128 static  constexpr  T value = {0 , 0 };
129129};
130130
// Primary template for the block-scan prefix callback. It is declared only;
// the usable definitions are the partial specializations that follow.
// UseKahan selects the variant: the `false` specialization keeps just a
// running total, while the `true` specializations additionally carry a
// compensation term.
131+ template  <typename  T, typename  Op, bool  UseKahan>
132+ struct  BlockPrefixCallbackOp ;
133+ 
// Non-compensated prefix callback (UseKahan == false). It carries a running
// total across successive calls by folding each block aggregate into it
// with op_.
131134template  <typename  T, typename  Op>
132- struct  BlockPrefixCallbackOp  {
135+ struct  BlockPrefixCallbackOp <T, Op,  false >  {
133136 //  Running prefix: combination (under op_) of every block aggregate seen so far.
134137 T running_total_;
135-  T compensation_;
136138 Op op_;
137139
// Seeds the running total with the identity value supplied by the caller and
// stores the combining operator.
138140 __device__  BlockPrefixCallbackOp (T identity, Op op)
139-  : running_total_(identity), compensation_(identity),  op_(op) {}
141+  : running_total_(identity), op_(op) {}
140142
141143 //  Callback operator to be entered by the first warp of threads in the block.
142144 //  tid 0 is responsible for returning a value for seeding the block-wide scan.
// Returns the prefix accumulated before this call, then advances the running
// total to op_(running_total_, block_aggregate).
145+  __device__  T operator ()(T block_aggregate) {
// Snapshot the prefix first: the caller must receive the value from before
// this block's aggregate is folded in.
146+  const  T old_prefix = running_total_;
147+  running_total_ = op_ (running_total_, block_aggregate);
148+  return  old_prefix;
149+  }
150+ };
151+ 
152+ template  <typename  T, typename  Op>
153+ struct  BlockPrefixCallbackOp <T, Op, true > {
154+  T running_total_;
155+  T compensation_;
156+  Op op_;
157+ 
158+  __device__  BlockPrefixCallbackOp (T identity, Op op)
159+  : running_total_(identity), compensation_(static_cast <T>(0.0 )), op_(op) {}
160+ 
143161 __device__  T operator ()(T block_aggregate) {
144162 T old_prefix = running_total_;
145163
@@ -155,20 +173,23 @@ struct BlockPrefixCallbackOp {
155173};
156174
157175template  <typename  T>
158- struct  BlockPrefixCallbackOp <T, LogAddExp> {
176+ struct  BlockPrefixCallbackOp <T, LogAddExp,  true > {
159177 T max_so_far_;
160178 T scaled_sum_;
161179 T compensation_;
162180 LogAddExp op_;
163181
164182 __device__  BlockPrefixCallbackOp (T identity, LogAddExp op)
165-  : max_so_far_(identity), scaled_sum_(0.0 ), compensation_(0.0 ), op_(op) {}
183+  : max_so_far_(identity),
184+  scaled_sum_(static_cast <T>(0.0 )),
185+  compensation_(static_cast <T>(0.0 )),
186+  op_(op) {}
166187
167188 __device__  T operator ()(T block_aggregate) {
168189 if  (scaled_sum_ == 0.0 ) {
169190 max_so_far_ = block_aggregate;
170-  scaled_sum_ = 1.0 ;
171-  compensation_ = 0.0 ;
191+  scaled_sum_ = static_cast <T>( 1.0 ) ;
192+  compensation_ = static_cast <T>( 0.0 ) ;
172193 return  std::numeric_limits<T>::lowest ();
173194 }
174195
@@ -195,15 +216,19 @@ struct BlockPrefixCallbackOp<T, LogAddExp> {
195216 }
196217};
197218
198- template  <typename  T, int  BLOCK_THREADS, int  ITEMS_PER_THREAD, typename  Op>
219+ template  <typename  T,
220+  int  BLOCK_THREADS,
221+  int  ITEMS_PER_THREAD,
222+  typename  Op,
223+  bool  UseKahan>
199224__global__  void  BlockScanKernel (T* d_out,
200225 const  T* d_in,
201226 int64_t  grid_size,
202227 int64_t  scan_size,
203228 bool  exclusive,
204229 Op op) {
205230 using  MT = typename  phi::dtype::MPTypeTrait<T>::Type;
206-  using  CallbackOp = BlockPrefixCallbackOp<MT, Op>;
231+  using  CallbackOp = BlockPrefixCallbackOp<MT, Op, UseKahan >;
207232
208233 //  Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
209234 using  BlockLoadT = cub::
@@ -350,14 +375,30 @@ void ScanKernel(const Context& dev_ctx,
350375 }
351376 }
352377
378+  //  When scan_size is large, switch to Kahan scan to get better precision
379+  constexpr  int64_t  KAHAN_SWITCH_LENGTH = 1  << 16 ;
380+ 
353381 //  Do scan
354382 if  (!transpose && !reverse) {
355-  BlockScanKernel<T, 128 , 4 , Op><<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
356-  out_data, in_data, grid_size, scan_size, exclusive, op);
357- 
383+  if  (scan_size > KAHAN_SWITCH_LENGTH) {
384+  BlockScanKernel<T, 128 , 4 , Op, true >
385+  <<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
386+  out_data, in_data, grid_size, scan_size, exclusive, op);
387+  } else  {
388+  BlockScanKernel<T, 128 , 4 , Op, false >
389+  <<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
390+  out_data, in_data, grid_size, scan_size, exclusive, op);
391+  }
358392 } else  {
359-  BlockScanKernel<T, 128 , 4 , Op><<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
360-  next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
393+  if  (scan_size > KAHAN_SWITCH_LENGTH) {
394+  BlockScanKernel<T, 128 , 4 , Op, true >
395+  <<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
396+  next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
397+  } else  {
398+  BlockScanKernel<T, 128 , 4 , Op, false >
399+  <<<scan_grid, 128 , 0 , dev_ctx.stream()>>> (
400+  next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
401+  }
361402 }
362403 swap_ptr (next_in_data, next_out_data);
363404
0 commit comments