Skip to content

Commit 05a1ccc

Browse files
committed
fix BlockPrefixCallbackOp
1 parent 1e11628 commit 05a1ccc

File tree

1 file changed

+55
-14
lines changed

1 file changed

+55
-14
lines changed

paddle/phi/kernels/gpu/cum_kernel.cu

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -128,18 +128,36 @@ struct Identity<T, ComplexSum> {
128128
static constexpr T value = {0, 0};
129129
};
130130

131+
template <typename T, typename Op, bool UseKahan>
132+
struct BlockPrefixCallbackOp;
133+
131134
template <typename T, typename Op>
132-
struct BlockPrefixCallbackOp {
135+
struct BlockPrefixCallbackOp<T, Op, false> {
133136
// Running prefix
134137
T running_total_;
135-
T compensation_;
136138
Op op_;
137139

138140
__device__ BlockPrefixCallbackOp(T identity, Op op)
139-
: running_total_(identity), compensation_(identity), op_(op) {}
141+
: running_total_(identity), op_(op) {}
140142

141143
// Callback operator to be entered by the first warp of threads in the block.
142144
// tid 0 is responsible for returning a value for seeding the block-wide scan.
145+
__device__ T operator()(T block_aggregate) {
146+
const T old_prefix = running_total_;
147+
running_total_ = op_(running_total_, block_aggregate);
148+
return old_prefix;
149+
}
150+
};
151+
152+
template <typename T, typename Op>
153+
struct BlockPrefixCallbackOp<T, Op, true> {
154+
T running_total_;
155+
T compensation_;
156+
Op op_;
157+
158+
__device__ BlockPrefixCallbackOp(T identity, Op op)
159+
: running_total_(identity), compensation_(static_cast<T>(0.0)), op_(op) {}
160+
143161
__device__ T operator()(T block_aggregate) {
144162
T old_prefix = running_total_;
145163

@@ -155,20 +173,23 @@ struct BlockPrefixCallbackOp {
155173
};
156174

157175
template <typename T>
158-
struct BlockPrefixCallbackOp<T, LogAddExp> {
176+
struct BlockPrefixCallbackOp<T, LogAddExp, true> {
159177
T max_so_far_;
160178
T scaled_sum_;
161179
T compensation_;
162180
LogAddExp op_;
163181

164182
__device__ BlockPrefixCallbackOp(T identity, LogAddExp op)
165-
: max_so_far_(identity), scaled_sum_(0.0), compensation_(0.0), op_(op) {}
183+
: max_so_far_(identity),
184+
scaled_sum_(static_cast<T>(0.0)),
185+
compensation_(static_cast<T>(0.0)),
186+
op_(op) {}
166187

167188
__device__ T operator()(T block_aggregate) {
168189
if (scaled_sum_ == 0.0) {
169190
max_so_far_ = block_aggregate;
170-
scaled_sum_ = 1.0;
171-
compensation_ = 0.0;
191+
scaled_sum_ = static_cast<T>(1.0);
192+
compensation_ = static_cast<T>(0.0);
172193
return std::numeric_limits<T>::lowest();
173194
}
174195

@@ -195,15 +216,19 @@ struct BlockPrefixCallbackOp<T, LogAddExp> {
195216
}
196217
};
197218

198-
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
219+
template <typename T,
220+
int BLOCK_THREADS,
221+
int ITEMS_PER_THREAD,
222+
typename Op,
223+
bool UseKahan>
199224
__global__ void BlockScanKernel(T* d_out,
200225
const T* d_in,
201226
int64_t grid_size,
202227
int64_t scan_size,
203228
bool exclusive,
204229
Op op) {
205230
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
206-
using CallbackOp = BlockPrefixCallbackOp<MT, Op>;
231+
using CallbackOp = BlockPrefixCallbackOp<MT, Op, UseKahan>;
207232

208233
// Specialize the BlockLoad, BlockStore, and BlockScan collective types
209234
using BlockLoadT = cub::
@@ -350,14 +375,30 @@ void ScanKernel(const Context& dev_ctx,
350375
}
351376
}
352377

378+
// When scan_size exceeds KAHAN_SWITCH_LENGTH, switch to the Kahan-compensated scan for better precision
379+
constexpr int64_t KAHAN_SWITCH_LENGTH = 1 << 16;
380+
353381
// Do scan
354382
if (!transpose && !reverse) {
355-
BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
356-
out_data, in_data, grid_size, scan_size, exclusive, op);
357-
383+
if (scan_size > KAHAN_SWITCH_LENGTH) {
384+
BlockScanKernel<T, 128, 4, Op, true>
385+
<<<scan_grid, 128, 0, dev_ctx.stream()>>>(
386+
out_data, in_data, grid_size, scan_size, exclusive, op);
387+
} else {
388+
BlockScanKernel<T, 128, 4, Op, false>
389+
<<<scan_grid, 128, 0, dev_ctx.stream()>>>(
390+
out_data, in_data, grid_size, scan_size, exclusive, op);
391+
}
358392
} else {
359-
BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
360-
next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
393+
if (scan_size > KAHAN_SWITCH_LENGTH) {
394+
BlockScanKernel<T, 128, 4, Op, true>
395+
<<<scan_grid, 128, 0, dev_ctx.stream()>>>(
396+
next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
397+
} else {
398+
BlockScanKernel<T, 128, 4, Op, false>
399+
<<<scan_grid, 128, 0, dev_ctx.stream()>>>(
400+
next_out_data, next_in_data, grid_size, scan_size, exclusive, op);
401+
}
361402
}
362403
swap_ptr(next_in_data, next_out_data);
363404

0 commit comments

Comments
 (0)