Hello,
I am currently implementing a wrapper function to call the cudnnMultiHeadAttnForward() API provided by cuDNN. However, after extensive testing, I consistently encounter a parameter error:
`cudnnMultiHeadAttnForward failed: CUDNN_STATUS_BAD_PARAM`. I am unsure whether the issue comes from my API usage or from how I am passing the parameters.
Below is the implementation of my wrapper function:
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCudnnMultiHeadAttention( int num_heads, int embed_dim, int max_seq_len_q, int max_seq_len_kv, int batch_size, int curr_idx, const int* lo_win_idx, const int* hi_win_idx, const void* dev_seq_lengths_qo, const void* dev_seq_lengths_kv, const void* queries, const void* residuals, const void* keys, const void* values, void* output, size_t weight_size_bytes, const void* weights, size_t workspace_size_bytes, CUstream stream ) { mgpuEnsureContext(); StreamHandles handles; if (!getHandlesForStream(stream, handles)) { fprintf(stderr, "[MHA] ERROR: Failed to get handles for stream %p\n", stream); return; } cudnnHandle_t handle = handles.cudnn_handle; if (num_heads <= 0 || embed_dim <= 0 || max_seq_len_q <= 0 || max_seq_len_kv <= 0 || batch_size <= 0) { fprintf(stderr, "[MHA] ERROR: Invalid dimensions: heads=%d, embed=%d, seq_q=%d, seq_kv=%d, batch=%d\n", num_heads, embed_dim, max_seq_len_q, max_seq_len_kv, batch_size); return; } if (!queries || !keys || !values || !output) { fprintf(stderr, "[MHA] ERROR: One or more data pointers are NULL\n"); return; } if (!weights && weight_size_bytes > 0) { fprintf(stderr, "[MHA] ERROR: Weights pointer is NULL but weight size > 0\n"); return; } if (embed_dim % num_heads != 0) { fprintf(stderr, "[MHA] ERROR: embed_dim (%d) must be divisible by num_heads (%d)\n", embed_dim, num_heads); return; } cudnnAttnDescriptor_t attn_desc = nullptr; cudnnStatus_t status = cudnnCreateAttnDescriptor(&attn_desc); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to create attention descriptor: %s\n", cudnnGetErrorString(status)); return; } int head_dim = embed_dim / num_heads; double sm_scaler = 1.0f / sqrtf((double)head_dim); unsigned attnMode = 0; attnMode |= CUDNN_ATTN_DISABLE_PROJ_BIASES; attnMode |= CUDNN_ATTN_QUERYMAP_ALL_TO_ONE; // attnMode |= CUDNN_ATTN_QUERYMAP_ONE_TO_ONE; status = cudnnSetAttnDescriptor( attn_desc, attnMode, // attnMode num_heads, // nHeads sm_scaler, // 
smScaler CUDNN_DATA_FLOAT, // dataType CUDNN_DATA_FLOAT, // computePrec CUDNN_DEFAULT_MATH, // mathType nullptr, // attnDropoutDesc nullptr, // postDropoutDesc embed_dim, // qSize embed_dim, // kSize embed_dim, // vSize 0, // qProjSize 0, // kProjSize 0, // vProjSize 0, // oProjSize max_seq_len_q, // qoMaxSeqLength max_seq_len_kv, // kvMaxSeqLength batch_size, // maxBatchSize 1 ); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to set attention descriptor: %s\n", cudnnGetErrorString(status)); cudnnDestroyAttnDescriptor(attn_desc); return; } cudnnSeqDataDescriptor_t q_desc = nullptr, k_desc = nullptr; cudnnSeqDataDescriptor_t v_desc = nullptr, o_desc = nullptr; status = cudnnCreateSeqDataDescriptor(&q_desc); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to create Q descriptor: %s\n", cudnnGetErrorString(status)); cudnnDestroyAttnDescriptor(attn_desc); return; } int q_dims[CUDNN_SEQDATA_DIM_COUNT]; q_dims[CUDNN_SEQDATA_BATCH_DIM] = batch_size; q_dims[CUDNN_SEQDATA_TIME_DIM] = max_seq_len_q; q_dims[CUDNN_SEQDATA_BEAM_DIM] = 1; q_dims[CUDNN_SEQDATA_VECT_DIM] = embed_dim; std::vector<int> q_seq_lengths(batch_size, max_seq_len_q); cudnnSeqDataAxis_t q_axes[CUDNN_SEQDATA_DIM_COUNT]; q_axes[0] = CUDNN_SEQDATA_BATCH_DIM; q_axes[1] = CUDNN_SEQDATA_BEAM_DIM; q_axes[2] = CUDNN_SEQDATA_TIME_DIM; q_axes[3] = CUDNN_SEQDATA_VECT_DIM; status = cudnnSetSeqDataDescriptor( q_desc, CUDNN_DATA_FLOAT, 4, q_dims, q_axes, batch_size, // seqLengthArraySize = batch_size * beam_size q_seq_lengths.data(), // seqLengthArray nullptr // paddingFill ); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to set Q descriptor: %s\n", cudnnGetErrorString(status)); cudnnDestroySeqDataDescriptor(q_desc); cudnnDestroyAttnDescriptor(attn_desc); return; } status = cudnnCreateSeqDataDescriptor(&k_desc); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to create K descriptor: %s\n", 
cudnnGetErrorString(status)); cudnnDestroySeqDataDescriptor(q_desc); cudnnDestroyAttnDescriptor(attn_desc); return; } int k_dims[CUDNN_SEQDATA_DIM_COUNT]; k_dims[CUDNN_SEQDATA_BATCH_DIM] = batch_size; k_dims[CUDNN_SEQDATA_TIME_DIM] = max_seq_len_kv; k_dims[CUDNN_SEQDATA_BEAM_DIM] = 1; k_dims[CUDNN_SEQDATA_VECT_DIM] = embed_dim; std::vector<int> kv_seq_lengths(batch_size, max_seq_len_kv); cudnnSeqDataAxis_t kv_axes[CUDNN_SEQDATA_DIM_COUNT]; kv_axes[0] = CUDNN_SEQDATA_BATCH_DIM; kv_axes[1] = CUDNN_SEQDATA_BEAM_DIM; kv_axes[2] = CUDNN_SEQDATA_TIME_DIM; kv_axes[3] = CUDNN_SEQDATA_VECT_DIM; status = cudnnSetSeqDataDescriptor( k_desc, CUDNN_DATA_FLOAT, 4, k_dims, kv_axes, batch_size, // seqLengthArraySize = batch_size * beam_size kv_seq_lengths.data(), // seqLengthArray nullptr // paddingFill ); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to set K descriptor: %s\n", cudnnGetErrorString(status)); cudnnDestroySeqDataDescriptor(k_desc); cudnnDestroySeqDataDescriptor(q_desc); cudnnDestroyAttnDescriptor(attn_desc); return; } status = cudnnCreateSeqDataDescriptor(&v_desc); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to create V descriptor: %s\n", cudnnGetErrorString(status)); cudnnDestroySeqDataDescriptor(k_desc); cudnnDestroySeqDataDescriptor(q_desc); cudnnDestroyAttnDescriptor(attn_desc); return; } status = cudnnSetSeqDataDescriptor( v_desc, CUDNN_DATA_FLOAT, 4, k_dims, kv_axes, batch_size, kv_seq_lengths.data(), nullptr // paddingFill ); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to set V descriptor: %s\n", cudnnGetErrorString(status)); cudnnDestroySeqDataDescriptor(v_desc); cudnnDestroySeqDataDescriptor(k_desc); cudnnDestroySeqDataDescriptor(q_desc); cudnnDestroyAttnDescriptor(attn_desc); return; } status = cudnnCreateSeqDataDescriptor(&o_desc); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to create O descriptor: %s\n", cudnnGetErrorString(status)); 
cudnnDestroySeqDataDescriptor(v_desc); cudnnDestroySeqDataDescriptor(k_desc); cudnnDestroySeqDataDescriptor(q_desc); cudnnDestroyAttnDescriptor(attn_desc); return; } status = cudnnSetSeqDataDescriptor( o_desc, CUDNN_DATA_FLOAT, 4, q_dims, q_axes, batch_size, q_seq_lengths.data(), nullptr ); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: Failed to set O descriptor: %s\n", cudnnGetErrorString(status)); cudnnDestroySeqDataDescriptor(o_desc); cudnnDestroySeqDataDescriptor(v_desc); cudnnDestroySeqDataDescriptor(k_desc); cudnnDestroySeqDataDescriptor(q_desc); cudnnDestroyAttnDescriptor(attn_desc); return; } void* workspace = nullptr; bool using_workspace_pool = false; bool workspace_allocated = false; size_t actual_workspace_size = workspace_size_bytes; if (workspace_size_bytes == 0) { size_t estimated_size = (size_t)batch_size * max_seq_len_q * max_seq_len_kv * num_heads * sizeof(float); actual_workspace_size = std::max(estimated_size * 2, (size_t)(32 * 1024 * 1024)); // fprintf(stderr, "[MHA] Using estimated workspace size: %.2f MB\n", // actual_workspace_size / (1024.0 * 1024.0)); } if (actual_workspace_size > 0) { workspace = acquirePooledWorkspace(actual_workspace_size, stream, TENSOR_CORE_ALIGNMENT); if (workspace != nullptr) { using_workspace_pool = true; } else { CUdeviceptr workspace_ptr = allocateAlignedMemory(actual_workspace_size, TENSOR_CORE_ALIGNMENT); if (workspace_ptr != 0) { workspace = reinterpret_cast<void*>(workspace_ptr); workspace_allocated = true; fprintf(stderr, "[MHA] Using dynamic workspace (%.2f MB)\n", actual_workspace_size / (1024.0 * 1024.0)); } else { fprintf(stderr, "[MHA] ERROR: Failed to allocate workspace of %.2f MB\n", actual_workspace_size / (1024.0 * 1024.0)); cudnnDestroySeqDataDescriptor(o_desc); cudnnDestroySeqDataDescriptor(v_desc); cudnnDestroySeqDataDescriptor(k_desc); cudnnDestroySeqDataDescriptor(q_desc); cudnnDestroyAttnDescriptor(attn_desc); return; } } } const int* dev_seq_qo = static_cast<const 
int*>(dev_seq_lengths_qo); const int* dev_seq_kv = static_cast<const int*>(dev_seq_lengths_kv); status = cudnnMultiHeadAttnForward( handle, attn_desc, curr_idx, lo_win_idx, hi_win_idx, dev_seq_qo, dev_seq_kv, q_desc, queries, residuals, k_desc, keys, v_desc, values, o_desc, output, weight_size_bytes, weights, actual_workspace_size, workspace, 0, nullptr ); if (status != CUDNN_STATUS_SUCCESS) { fprintf(stderr, "[MHA] ERROR: cudnnMultiHeadAttnForward failed: %s\n", cudnnGetErrorString(status)); } if (o_desc) cudnnDestroySeqDataDescriptor(o_desc); if (v_desc) cudnnDestroySeqDataDescriptor(v_desc); if (k_desc) cudnnDestroySeqDataDescriptor(k_desc); if (q_desc) cudnnDestroySeqDataDescriptor(q_desc); if (attn_desc) cudnnDestroyAttnDescriptor(attn_desc); if (workspace_allocated && workspace) { CUresult result = cuMemFree(reinterpret_cast<CUdeviceptr>(workspace)); if (result != CUDA_SUCCESS) { fprintf(stderr, "[MHA] WARNING: Failed to free workspace\n"); } } } During actual testing, the debug information I receive is as follows:
I! CuDNN (v90501 17) function cudnnMultiHeadAttnForward() called:
i! handle: type=cudnnHandle_t; streamId=0x17609670;
i! attnDesc: type=cudnnAttnDescriptor_t:
i! attnMode: type=unsigned; val=CUDNN_ATTN_QUERYMAP_ONE_TO_ONE|CUDNN_ATTN_DISABLE_PROJ_BIASES (0x1);
i! nHeads: type=int; val=8;
i! smScaler: type=double; val=0.176777;
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! mathPrec: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! attnDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! postDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! qSize: type=int; val=256;
i! kSize: type=int; val=256;
i! vSize: type=int; val=256;
i! qProjSize: type=int; val=0;
i! kProjSize: type=int; val=0;
i! vProjSize: type=int; val=0;
i! oProjSize: type=int; val=0;
i! qoMaxSeqLength: type=int; val=256;
i! kvMaxSeqLength: type=int; val=32;
i! maxBatchSize: type=int; val=64;
i! maxBeamSize: type=int; val=1;
i! currIdx: type=int; val=-1;
i! loWinIdx: location=host; addr=0x7ffd4a348ad8;
i! hiWinIdx: location=host; addr=0x7ffd4a348ed8;
i! devSeqLengthsQO: location=dev; addr=0x7734c7000400;
i! devSeqLengthsKV: location=dev; addr=0x7734c7000600;
i! qDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=256;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256];
i! paddingFill: type=double; val=0;
i! queries: location=dev; addr=0x7734dd200000;
i! residuals: location=dev; addr=NULL_PTR;
i! kDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=32;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32];
i! paddingFill: type=double; val=0;
i! keys: location=dev; addr=0x7734dd400000;
i! vDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=32;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32];
i! paddingFill: type=double; val=0;
i! values: location=dev; addr=0x7734dd600000;
i! oDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=256;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256];
i! paddingFill: type=double; val=0;
i! out: location=dev; addr=0x7734bc400000;
i! weightSizeInBytes: type=size_t; val=0;
i! weights: location=dev; addr=NULL_PTR;
i! workSpaceSizeInBytes: type=size_t; val=33554432;
i! workSpace: location=dev; addr=0x7734a2000000;
i! reserveSpaceSizeInBytes: type=size_t; val=0;
i! reserveSpace: location=dev; addr=NULL_PTR;
i! Time: 2025-09-28T21:37:40.822081 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=0; Handle=0x258f98a0; StreamId=0x17609670.
I! CuDNN (v90501 17) function cudnnSetAttnDescriptor() called:
i! attnMode: type=unsigned; val=CUDNN_ATTN_QUERYMAP_ONE_TO_ONE|CUDNN_ATTN_DISABLE_PROJ_BIASES (0x1);
i! nHeads: type=int; val=8;
i! smScaler: type=double; val=0.176777;
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! computePrec: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! attnDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! postDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! qSize: type=int; val=256;
i! kSize: type=int; val=256;
i! vSize: type=int; val=256;
i! qProjSize: type=int; val=0;
i! kProjSize: type=int; val=0;
i! vProjSize: type=int; val=0;
i! oProjSize: type=int; val=0;
i! qoMaxSeqLength: type=int; val=256;
i! kvMaxSeqLength: type=int; val=32;
i! maxBatchSize: type=int; val=64;
i! maxBeamSize: type=int; val=1;
i! Time: 2025-09-28T21:37:40.821349 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=NULL; Handle=NULL; StreamId=NULL.
I! CuDNN (v90501 17) function cudnnSetSeqDataDescriptor() called:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! nbDims: type=int; val=4;
i! dimA: type=int; val=[256,64,1,256];
i! : type=int; val=[1,2,0,3];
i! seqLengthArray: type=int; val=[256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256];
i! paddingFill: type=CUDNN_DATA_FLOAT; val=NULL_PTR;
i! Time: 2025-09-28T21:37:40.821496 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=NULL; Handle=NULL; StreamId=NULL.I! CuDNN (v90501 17) function cudnnSetSeqDataDescriptor() called:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! nbDims: type=int; val=4;
i! dimA: type=int; val=[32,64,1,256];
i! : type=int; val=[1,2,0,3];
i! seqLengthArray: type=int; val=[32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32];
i! paddingFill: type=CUDNN_DATA_FLOAT; val=NULL_PTR;
i! Time: 2025-09-28T21:37:40.821593 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=NULL; Handle=NULL; StreamId=NULL.
I would appreciate any guidance on which part of my API usage or parameter setup might be incorrect.
Thank you in advance for your help!