@@ -23,6 +23,9 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
                                          const int64_t *input_data,
                                          const int *cum_offsets,
                                          const int *seq_lens,
+                                         const int64_t *draft_tokens,
+                                         const int *seq_lens_encoder,
+                                         const int max_draft_tokens,
                                          const int max_seq_len) {
   // get padding offset of each batch
   const int bi = blockIdx.x;
@@ -31,8 +34,18 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
   for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
     padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
     const int tgt_seq_id = bi * max_seq_len - cum_offset + i;
-    const int src_seq_id = bi * max_seq_len + i;
-    output_data[tgt_seq_id] = input_data[src_seq_id];
+    if (draft_tokens == nullptr) {
+      const int src_seq_id = bi * max_seq_len + i;
+      output_data[tgt_seq_id] = input_data[src_seq_id];
+    } else {  // speculative decoding
+      if (seq_lens_encoder[bi] > 0) {
+        const int src_seq_id = bi * max_seq_len + i;
+        output_data[tgt_seq_id] = input_data[src_seq_id];
+      } else {
+        const int src_seq_id = bi * max_draft_tokens + i;
+        output_data[tgt_seq_id] = draft_tokens[src_seq_id];
+      }
+    }
   }
   if (ti == 0) {
     if (bi == 0) {
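For orientation, the gather the kernel now performs can be read as the host-side sketch below. It is illustrative only: the function name and the plain `std::vector` interface are made up for this note, and the `padding_offset` / `cu_seqlens` bookkeeping is omitted. When `draft_tokens` is absent, or a row is still in its prefill phase (`seq_lens_encoder[bi] > 0`), tokens are copied from the padded `input_ids`; otherwise the row's tokens come from the draft-token buffer, whose rows are `max_draft_tokens` wide instead of `max_seq_len`.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical CPU reference of the token gather in GetPaddingOffsetV2Kernel.
// Names and the std::vector interface are illustrative; the real kernel works
// on raw device pointers and also fills padding_offset / cu_seqlens_q / cu_seqlens_k.
std::vector<int64_t> remove_padding_reference(
    const std::vector<int64_t>& input_ids,      // [bsz, max_seq_len], padded
    const std::vector<int64_t>& draft_tokens,   // [bsz, max_draft_tokens], empty if unused
    const std::vector<int>& seq_lens,           // [bsz]
    const std::vector<int>& seq_lens_encoder,   // [bsz], > 0 means prefill step
    int bsz, int max_seq_len, int max_draft_tokens) {
  std::vector<int64_t> packed;
  for (int bi = 0; bi < bsz; ++bi) {
    for (int i = 0; i < seq_lens[bi]; ++i) {
      if (draft_tokens.empty() || seq_lens_encoder[bi] > 0) {
        // Default path and prefill rows: copy from the padded input row.
        packed.push_back(input_ids[bi * max_seq_len + i]);
      } else {
        // Speculative decoding: copy the draft tokens proposed for this row.
        packed.push_back(draft_tokens[bi * max_draft_tokens + i]);
      }
    }
  }
  return packed;  // all rows concatenated with their padding removed
}
```

The kernel's target index `bi * max_seq_len - cum_offset + i` lands on exactly this packed position, since `cum_offset` is the padding accumulated before row `bi`.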
@@ -50,7 +63,9 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
 std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
                                                const paddle::Tensor& cum_offsets,
                                                const paddle::Tensor& token_num,
-                                               const paddle::Tensor& seq_len) {
+                                               const paddle::Tensor& seq_len,
+                                               const paddle::optional<paddle::Tensor>& draft_tokens,
+                                               const paddle::optional<paddle::Tensor>& seq_lens_encoder) {
   auto cu_stream = input_ids.stream();
   std::vector<int64_t> input_ids_shape = input_ids.shape();
   const int bsz = seq_len.shape()[0];
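The tensor layouts the wrapper and kernel rely on are not spelled out in the patch; the summary below is inferred from how the tensors are indexed and should be read as an assumption:

```cpp
// Inferred layouts (assumption, derived from the indexing in this patch):
//   input_ids        : int64, [bsz, max_seq_len]       -> seq_length = input_ids.shape()[1]
//   seq_len          : int32, [bsz]                    -> bsz = seq_len.shape()[0]
//   draft_tokens     : int64, [bsz, max_draft_tokens]  (optional, speculative decoding)
//   seq_lens_encoder : int32, [bsz]                    (optional; > 0 marks a prefill row)
```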
@@ -65,23 +80,46 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
   auto cu_seqlens_q = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_k = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
 
-  GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
-      padding_offset.data<int>(),
-      cum_offsets_out.data<int>(),
-      cu_seqlens_q.data<int>(),
-      cu_seqlens_k.data<int>(),
-      x_remove_padding.data<int64_t>(),
-      input_ids.data<int64_t>(),
-      cum_offsets.data<int>(),
-      seq_len.data<int>(),
-      seq_length);
+  int max_draft_tokens = 0;
+  if (draft_tokens) {  // speculative decoding
+    max_draft_tokens = draft_tokens.get().shape()[1];
+    GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
+        padding_offset.data<int>(),
+        cum_offsets_out.data<int>(),
+        cu_seqlens_q.data<int>(),
+        cu_seqlens_k.data<int>(),
+        x_remove_padding.data<int64_t>(),
+        input_ids.data<int64_t>(),
+        cum_offsets.data<int>(),
+        seq_len.data<int>(),
+        draft_tokens.get_ptr()->data<int64_t>(),
+        seq_lens_encoder.get_ptr()->data<int>(),
+        max_draft_tokens,
+        seq_length);
+  } else {
+    GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
+        padding_offset.data<int>(),
+        cum_offsets_out.data<int>(),
+        cu_seqlens_q.data<int>(),
+        cu_seqlens_k.data<int>(),
+        x_remove_padding.data<int64_t>(),
+        input_ids.data<int64_t>(),
+        cum_offsets.data<int>(),
+        seq_len.data<int>(),
+        nullptr,
+        nullptr,
+        max_draft_tokens,
+        seq_length);
+  }
   return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k};  // , enc_token_num, dec_token_num};
 }
 
 std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(const std::vector<int64_t>& input_ids_shape,
                                                                const std::vector<int64_t>& cum_offsets_shape,
                                                                const std::vector<int64_t>& token_num_shape,
-                                                               const std::vector<int64_t>& seq_len_shape) {
+                                                               const std::vector<int64_t>& seq_len_shape,
+                                                               const std::vector<int64_t>& draft_tokens_shape,
+                                                               const std::vector<int64_t>& seq_lens_encoder_shape) {
   int64_t bsz = seq_len_shape[0];
   int64_t seq_len = input_ids_shape[1];
   return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
@@ -90,12 +128,14 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(const std::vector
 std::vector<paddle::DataType> GetPaddingOffsetV2InferDtype(const paddle::DataType& input_ids_dtype,
                                                            const paddle::DataType& cum_offsets_dtype,
                                                            const paddle::DataType& token_num_dtype,
-                                                           const paddle::DataType& seq_len_dtype) {
+                                                           const paddle::DataType& seq_len_dtype,
+                                                           const paddle::DataType& draft_tokens_dtype,
+                                                           const paddle::DataType& seq_lens_encoder_dtype) {
   return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
 }
 
 PD_BUILD_OP(get_padding_offset_v2)
-    .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
+    .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len", paddle::Optional("draft_tokens"), paddle::Optional("seq_lens_encoder"),})
     .Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
     .SetKernelFn(PD_KERNEL(GetPaddingOffsetV2))
     .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape))
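The two launch branches in `GetPaddingOffsetV2` differ only in their three trailing arguments. A hypothetical way to read (or collapse) that dispatch, using only the `paddle::optional` accessors the patch itself already calls (`operator bool`, `get()`, `get_ptr()`), is sketched below; it is not part of the patch:

```cpp
// Hypothetical single-launch equivalent of the if/else dispatch in the wrapper above.
const int64_t *draft_tokens_ptr = nullptr;
const int *seq_lens_encoder_ptr = nullptr;
int max_draft_tokens = 0;
if (draft_tokens) {  // speculative-decoding inputs were supplied
  draft_tokens_ptr = draft_tokens.get_ptr()->data<int64_t>();
  seq_lens_encoder_ptr = seq_lens_encoder.get_ptr()->data<int>();
  max_draft_tokens = draft_tokens.get().shape()[1];
}
GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
    padding_offset.data<int>(),
    cum_offsets_out.data<int>(),
    cu_seqlens_q.data<int>(),
    cu_seqlens_k.data<int>(),
    x_remove_padding.data<int64_t>(),
    input_ids.data<int64_t>(),
    cum_offsets.data<int>(),
    seq_len.data<int>(),
    draft_tokens_ptr,
    seq_lens_encoder_ptr,
    max_draft_tokens,
    seq_length);
```

Because the kernel already branches on `draft_tokens == nullptr`, passing null pointers in the non-speculative case preserves the original behaviour.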