Commit 921c091

Fix a bug in ReadData, ReadDataBc and ReadDataReduce when NX != 1 (#36373)
* Update the implementation of reduceAnyKernel according to the kernel primitive API
* Fix a bug in ReadData, ReadDataBc and ReadDataReduce when NX != 1
1 parent 5eb640c commit 921c091
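
Both call-site hunks below make the same change: the trailing `1, 1` arguments to `ReadDataBc` are dropped, so the per-thread tile shape comes only from the `NX`/`NY` template parameters. The following is a minimal sketch of why that matters, using a hypothetical `ReadTile` helper rather than Paddle's actual primitive:

// Hypothetical sketch, not Paddle's implementation: the tile shape is a
// compile-time template parameter, so a second runtime copy of NX/NY (the
// removed `1, 1` arguments) is redundant and can silently disagree with the
// template parameters once NX != 1.
template <typename T, int NX, int NY>
__device__ __forceinline__ void ReadTile(T (&dst)[NY][NX], const T* src,
                                         int offset, int stride) {
#pragma unroll
  for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
    for (int nx = 0; nx < NX; ++nx) {
      // All indexing uses the compile-time NX/NY directly.
      dst[ny][nx] = src[offset + ny * stride + nx];
    }
  }
}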

File tree: 5 files changed, +286 −139 lines

paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h

Lines changed: 1 addition & 1 deletion

@@ -171,7 +171,7 @@ __device__ __forceinline__ void LoadData(
   // num: how many data will be deal with in this time
   if (need_broadcast) {
     kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(dst, src, block_offset,
-                                                        config, numel, 1, 1);
+                                                        config, numel);
   } else {
     kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
   }

paddle/fluid/operators/fused/attn_bias_add.cu.h

Lines changed: 2 additions & 2 deletions

@@ -72,14 +72,14 @@ __global__ void BroadcastKernelBinary(
   // load in0
   if (use_broadcast[0]) {
     kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
-        arg0, in0, fix, configlists[0], numel, 1, 1);
+        arg0, in0, fix, configlists[0], numel);
   } else {
     kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
   }
   // load in1
   if (use_broadcast[1]) {
     kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
-        arg1, in1, fix, configlists[1], numel, 1, 1);
+        arg1, in1, fix, configlists[1], numel);
   } else {
     kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
   }

paddle/fluid/operators/kernel_primitives/compute_primitives.h

Lines changed: 34 additions & 40 deletions

@@ -135,17 +135,16 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
 } // namespace details
 
 /**
- * @brief Perform unary calculation according to OpFunc. Size of input and
+ * @brief Perform unary calculation according to OpFunc. Shape of input and
  * output are the same.
  *
  * @template paraments
- * InT: Data type of in.
- * OutT: Data type of out.
+ * InT: The data type of in.
+ * OutT: The data type of out.
  * NX: The number of data columns loaded by each thread.
  * NY: The number of data rows loaded by each thread.
  * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
  * OpFunc: Compute functor which has an operator() as following:
  * template <typename InT, typename OutT>
  * struct XxxFunctor {
@@ -170,21 +169,20 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
 }
 
 /**
- * @brief Binary calculation according to OpFunc. Size of The input and output
+ * @brief Binary calculation according to OpFunc. Shape of The input and output
  * are the same.
  *
  * @template paraments
- * InT: Data type of in1 and in2.
- * OutT: Data type of out.
- * NX: The number of data columns loaded by each thread.
- * NY: The number of data rows loaded by each thread.
+ * InT: The data type of in1 and in2.
+ * OutT: The data type of out.
+ * NX: The number of data columns computed by each thread.
+ * NY: The number of data rows computed by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * OpFunc: Compute functor which has an operator() as following:
- * template <typename InT, typename OutT>
+ * template <typename InT>
 * struct XxxFunctor {
- *   HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
+ *   HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
 *     return ...;
 *   }
 * };
@@ -193,7 +191,7 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
 * out: The register pointer of out, the size is NX * NY.
 * in1: The register pointer of fist input, size is NX * NY.
 * in2: The register pointer of second input, size is NX * NY.
- * compute: Compute function which was declared like OpFunc<InT, OutT>().
+ * compute: Compute function which was declared like OpFunc<InT>().
 */
 template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
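
Concretely, a functor written against the updated `OpFunc<InT>` convention could look like the following. This is a hypothetical example, not code from the commit; `HOSTDEVICE` is Paddle's `__host__ __device__` macro:

// Hypothetical functor in the updated shape: one template parameter, and
// operator() takes and returns InT (any InT -> OutT conversion happens
// outside the functor, when the result is written to out).
template <typename InT>
struct AddFunctor {
  HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
    return a + b;
  }
};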
@@ -207,21 +205,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1,
 }
 
 /**
- * @brief Ternary calculation according to OpFunc. Size of input and output
+ * @brief Ternary calculation according to OpFunc. Shape of input and output
 * are the same.
 *
 * @template paraments
- * InT: Data type of in1 and in2.
- * OutT: Data type of out.
+ * InT: The data type of in1 and in2.
+ * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * OpFunc: Compute functor which has an operator() as following
- * template <typename InT, typename OutT>
+ * template <typename InT>
 * struct XxxFunctor {
- *   HOSTDEVICE OutT operator()(const InT& a, const InT& b, const InT& c)
+ *   HOSTDEVICE InT operator()(const InT& a, const InT& b, const InT& c)
 *       const {
 *     return ...;
 *   }
@@ -232,7 +229,7 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1,
 * in1: The register pointer of fist input, size is NX * NY.
 * in2: The register pointer of second input, size is NX * NY.
 * in3: The register pointer of third input, size is NX * NY.
- * compute: Compute function which was declared like OpFunc<InT, OutT>().
+ * compute: Compute function which was declared like OpFunc<InT>().
 */
 template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
@@ -247,30 +244,29 @@ __device__ __forceinline__ void ElementwiseTernary(OutT* out, const InT* in1,
 }
 
 /**
- * @brief Multivariate calculation according to OpFunc. Size of input and output
- * are the same.
+ * @brief Multivariate calculation according to OpFunc. Shape of inputs and
+ * output are the same.
 *
 * @template paraments
- * InT: Data type of in1, in2 and in3.
- * OutT: Data type of out.
+ * InT: The data type of in1, in2 and in3.
+ * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
- * Arity: The size of ins
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
+ * Arity: The size of ins.
 * OpFunc: Compute functor which has an operator() as following:
- * template <typename InT, typename OutT>
+ * template <typename InT>
 * struct XxxFunctor {
- *   HOSTDEVICE OutT operator()(const InT* args) const {
+ *   HOSTDEVICE InT operator()(const InT* args) const {
 *     return ...;
 *   }
 * };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
- * ins: An array of pointers consisting of multiple inputs.
- * compute: Compute function which was declared like OpFunc<InT, OutT>().
+ * ins: A pointers of array consisting of multiple inputs.
+ * compute: Compute function which was declared like OpFunc<InT>().
 */
 template <typename InT, typename OutT, int NX, int NY, int BlockSize, int Arity,
           class OpFunc>
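
For `ElementwiseAny`, the documented `OpFunc<InT>` receives a pointer to one packed element from each input. A hypothetical functor in that shape (the extra `Arity` template parameter is an assumption here; its value must match the kernel's `Arity`):

// Hypothetical multivariate functor: folds Arity packed inputs into one value.
template <typename InT, int Arity>
struct SumArgsFunctor {
  HOSTDEVICE InT operator()(const InT* args) const {
    InT total = args[0];
#pragma unroll
    for (int i = 1; i < Arity; ++i) {
      total += args[i];
    }
    return total;
  }
};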
@@ -293,13 +289,12 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY],
 * shape is [NY, NX].
 *
 * @template paraments
- * InT: Data type of in1 and in2.
- * OutT: Data type of out.
+ * InT: The data type of in1 and in2.
+ * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * OpFunc: Compute functor which has an operator() as following
 * template <typename InT, typename OutT>
 * struct XxxFunctor {
@@ -339,8 +334,7 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const InT* in1,
 * NX: The number of data continuously loaded by each thread.
 * NY: The number of data rows loaded by each thread, only NY = 1 was supported.
 * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
 * ReduceFunctor: Compute functor which has an operator() as following
 * template <typename InT>
 * struct ReduceFunctor {
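
The last hunk cuts off inside the `ReduceFunctor` example. A typical binary reducer in that documented shape might look like the following (hypothetical `MaxFunctor`, assuming the standard pairwise-combine contract the block-level reductions rely on):

// Hypothetical reducer: a pure pairwise combine that a block-level
// reduction can apply in any association order.
template <typename InT>
struct MaxFunctor {
  HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
    return a > b ? a : b;
  }
};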
