Skip to content

Commit a375679

Browse files
committed
[Paddle-Inference] fix_qkv_plugin: fix half scale
1 parent 80f301e commit a375679

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -229,7 +229,9 @@ template <typename T>
 __global__ void apply_scale(T *data, T scale, int n) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  data[tid] = data[tid] * scale;
+  if (tid < n) {
+    data[tid] = data[tid] * scale;
+  }
 #endif
 }

@@ -347,8 +349,8 @@ int QkvToContextPluginDynamic::enqueue(
           platform::CUDAPlace(device_id)));

     int n_q = seq_len * head_number_ * head_size_ * batch;
-    int threads = head_number_ * head_size_ * batch;
-    int blocks = seq_len;
+    constexpr int threads = 128;
+    int blocks = (n_q + threads - 1) / threads;

     apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_),
                                                 n_q);

0 commit comments

Comments
 (0)