Skip to content

Commit 80f301e

Browse files
committed
fix_qkv_plugin: half_scale
1 parent 81cfbdd commit 80f301e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,8 @@ int QkvToContextPluginDynamic::enqueue(
347 347
platform::CUDAPlace(device_id)));
348 348

349 349
int n_q = seq_len * head_number_ * head_size_ * batch;
350-
constexpr int threads = 128;
351-
int blocks = (n_q + threads - 1) / threads;
350+
int threads = head_number_ * head_size_ * batch;
351+
int blocks = seq_len;
352 352

353 353
apply_scale<<<blocks, threads, 0, stream>>>(tptr, static_cast<half>(scale_),
354 354
n_q);

0 commit comments

Comments
 (0)