Skip to content

Commit 4547242

Browse files
merge custom op and fix code style
1 parent a803b53 commit 4547242

File tree

23 files changed

+213
-429
lines changed

23 files changed

+213
-429
lines changed

csrc/cpu/src/stop_generation_multi_ends.cc

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,9 @@
1515
#include <stdlib.h>
1616
#include <string.h>
1717

18-
#include "paddle/extension.h"
18+
#include "helper.h"
1919
#include <stdio.h>
2020

21-
22-
bool is_in_end(const int64_t id, const int64_t* end_ids, int length) {
23-
bool flag = false;
24-
for (int i = 0; i < length; i++) {
25-
if (id == end_ids[i]) {
26-
return true;
27-
}
28-
}
29-
return flag;
30-
}
31-
3221
void set_value_by_flags(const bool* stop_flags,
3322
const int64_t* end_ids,
3423
int64_t* topk_ids,

csrc/gpu/get_output.cc

Lines changed: 30 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,38 @@ struct msgdata {
2828
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
2929
};
3030

31-
void GetOutputFunc(const paddle::Tensor& x,
31+
struct SpeculateMsgData {
32+
long mtype;
33+
int mtext[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2]; // stop_flag, bsz, tokens
34+
};
35+
36+
static struct msgdata msg_rcv;
37+
static struct SpeculateMsgData specu_msg_rcv;
38+
39+
void GetOutput(const paddle::Tensor& x,
3240
int64_t rank_id,
33-
bool wait_flag) {
41+
bool wait_flag,
42+
bool speculative_decoding) {
3443
if (rank_id > 0) return;
3544

36-
static struct msgdata msg_rcv;
37-
3845
static key_t key = ftok("./", 1);
3946

4047
static int msgid = msgget(key, IPC_CREAT | 0666);
4148

4249
int64_t *out_data = const_cast<int64_t*>(x.data<int64_t>());
4350
int ret = -1;
4451
if (!wait_flag) {
45-
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
52+
if (!speculative_decoding) {
53+
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
54+
} else {
55+
ret = msgrcv(msgid, &specu_msg_rcv, (SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
56+
}
4657
} else {
47-
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
58+
if (!speculative_decoding) {
59+
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
60+
} else{
61+
ret = msgrcv(msgid, &specu_msg_rcv, (SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2) * 4, 0, 0);
62+
}
4863
}
4964
if(ret == -1)
5065
{
@@ -54,61 +69,20 @@ void GetOutputFunc(const paddle::Tensor& x,
5469
return;
5570
}
5671

57-
int bsz = msg_rcv.mtext[1];
58-
59-
for (int64_t i = 0; i < bsz + 2; i++) {
60-
out_data[i] = (int64_t)msg_rcv.mtext[i];
61-
}
62-
return;
63-
}
64-
65-
struct SpeculateMsgData {
66-
long mtype;
67-
int mtext[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2]; // stop_flag, bsz, tokens
68-
};
69-
70-
71-
void SpeculateGetOutputFunc(const paddle::Tensor& x,
72-
int64_t rank_id,
73-
bool wait_flag) {
74-
if (rank_id > 0) {
75-
return;
76-
}
77-
static struct SpeculateMsgData msg_rcv;
78-
79-
static key_t key = ftok("./", 1);
8072

81-
static int msgid = msgget(key, IPC_CREAT | 0666);
82-
83-
int64_t *out_data = const_cast<int64_t*>(x.data<int64_t>());
84-
int ret = -1;
85-
if (!wait_flag) {
86-
ret = msgrcv(msgid, &msg_rcv, (SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
87-
} else {
88-
ret = msgrcv(msgid, &msg_rcv, (SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2) * 4, 0, 0);
89-
}
90-
if(ret == -1) {
91-
out_data[0] = -2;
92-
out_data[1] = 0;
93-
return;
94-
}
73+
if (!speculative_decoding) {
9574
int bsz = msg_rcv.mtext[1];
96-
97-
for (int64_t i = 0; i < SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2; i++) {
98-
out_data[i] = (int64_t)msg_rcv.mtext[i];
75+
for (int64_t i = 0; i < bsz + 2; i++) {
76+
out_data[i] = (int64_t)msg_rcv.mtext[i];
9977
}
100-
return;
101-
}
102-
103-
void GetOutput(const paddle::Tensor& x,
104-
int64_t rank_id,
105-
bool wait_flag,
106-
bool speculative_decoding){
107-
if (speculative_decoding) {
108-
SpeculateGetOutputFunc(x, rank_id, wait_flag);
10978
} else {
110-
GetOutputFunc(x, rank_id, wait_flag);
79+
int bsz = specu_msg_rcv.mtext[1];
80+
for (int64_t i = 0; i < SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2; i++) {
81+
out_data[i] = (int64_t)specu_msg_rcv.mtext[i];
82+
}
11183
}
84+
85+
return;
11286
}
11387

11488
PD_BUILD_OP(get_output)

csrc/gpu/get_padding_offset_v2.cu

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
2323
const int64_t *input_data,
2424
const int *cum_offsets,
2525
const int *seq_lens,
26+
const int64_t *draft_tokens,
27+
const int *seq_lens_encoder,
28+
const int max_draft_tokens,
2629
const int max_seq_len) {
2730
// get padding offset of each batch
2831
const int bi = blockIdx.x;
@@ -31,8 +34,18 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
3134
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
3235
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
3336
const int tgt_seq_id = bi * max_seq_len - cum_offset + i;
34-
const int src_seq_id = bi * max_seq_len + i;
35-
output_data[tgt_seq_id] = input_data[src_seq_id];
37+
if (draft_tokens == nullptr) {
38+
const int src_seq_id = bi * max_seq_len + i;
39+
output_data[tgt_seq_id] = input_data[src_seq_id];
40+
} else { // speculative decoding
41+
if (seq_lens_encoder[bi] > 0) {
42+
const int src_seq_id = bi * max_seq_len + i;
43+
output_data[tgt_seq_id] = input_data[src_seq_id];
44+
} else {
45+
const int src_seq_id = bi * max_draft_tokens + i;
46+
output_data[tgt_seq_id] = draft_tokens[src_seq_id];
47+
}
48+
}
3649
}
3750
if (ti == 0) {
3851
if (bi == 0) {
@@ -50,7 +63,9 @@ __global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
5063
std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
5164
const paddle::Tensor& cum_offsets,
5265
const paddle::Tensor& token_num,
53-
const paddle::Tensor& seq_len) {
66+
const paddle::Tensor& seq_len,
67+
const paddle::optional<paddle::Tensor>& draft_tokens,
68+
const paddle::optional<paddle::Tensor>& seq_lens_encoder) {
5469
auto cu_stream = input_ids.stream();
5570
std::vector<int64_t> input_ids_shape = input_ids.shape();
5671
const int bsz = seq_len.shape()[0];
@@ -65,23 +80,46 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
6580
auto cu_seqlens_q = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
6681
auto cu_seqlens_k = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
6782

68-
GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
69-
padding_offset.data<int>(),
70-
cum_offsets_out.data<int>(),
71-
cu_seqlens_q.data<int>(),
72-
cu_seqlens_k.data<int>(),
73-
x_remove_padding.data<int64_t>(),
74-
input_ids.data<int64_t>(),
75-
cum_offsets.data<int>(),
76-
seq_len.data<int>(),
77-
seq_length);
83+
int max_draft_tokens = 0;
84+
if (draft_tokens) { // speculative decoding
85+
max_draft_tokens = draft_tokens.get().shape()[1];
86+
GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
87+
padding_offset.data<int>(),
88+
cum_offsets_out.data<int>(),
89+
cu_seqlens_q.data<int>(),
90+
cu_seqlens_k.data<int>(),
91+
x_remove_padding.data<int64_t>(),
92+
input_ids.data<int64_t>(),
93+
cum_offsets.data<int>(),
94+
seq_len.data<int>(),
95+
draft_tokens.get_ptr()->data<int64_t>(),
96+
seq_lens_encoder.get_ptr()->data<int>(),
97+
max_draft_tokens,
98+
seq_length);
99+
} else {
100+
GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
101+
padding_offset.data<int>(),
102+
cum_offsets_out.data<int>(),
103+
cu_seqlens_q.data<int>(),
104+
cu_seqlens_k.data<int>(),
105+
x_remove_padding.data<int64_t>(),
106+
input_ids.data<int64_t>(),
107+
cum_offsets.data<int>(),
108+
seq_len.data<int>(),
109+
nullptr,
110+
nullptr,
111+
max_draft_tokens,
112+
seq_length);
113+
}
78114
return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num};
79115
}
80116

81117
std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(const std::vector<int64_t>& input_ids_shape,
82118
const std::vector<int64_t>& cum_offsets_shape,
83119
const std::vector<int64_t>& token_num_shape,
84-
const std::vector<int64_t>& seq_len_shape) {
120+
const std::vector<int64_t>& seq_len_shape,
121+
const std::vector<int64_t>& draft_tokens_shape,
122+
const std::vector<int64_t>& seq_lens_encoder_shape) {
85123
int64_t bsz = seq_len_shape[0];
86124
int64_t seq_len = input_ids_shape[1];
87125
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
@@ -90,12 +128,14 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(const std::vector
90128
std::vector<paddle::DataType> GetPaddingOffsetV2InferDtype(const paddle::DataType& input_ids_dtype,
91129
const paddle::DataType& cum_offsets_dtype,
92130
const paddle::DataType& token_num_dtype,
93-
const paddle::DataType& seq_len_dtype) {
131+
const paddle::DataType& seq_len_dtype,
132+
const paddle::DataType& draft_tokens_dtype,
133+
const paddle::DataType& seq_lens_encoder_dtype) {
94134
return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
95135
}
96136

97137
PD_BUILD_OP(get_padding_offset_v2)
98-
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
138+
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len", paddle::Optional("draft_tokens"), paddle::Optional("seq_lens_encoder"),})
99139
.Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
100140
.SetKernelFn(PD_KERNEL(GetPaddingOffsetV2))
101141
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape))

0 commit comments

Comments
 (0)