
Commit fd2389f

【PIR OpTest Fix No.26】 fix test_fused_token_prune_op (#62644)
* fix test_fused_token_prune_op
* fix codestyle
* fix codestyle2
* Update paddle/fluid/pir/dialect/operator/ir/ops.yaml
* Update paddle/phi/api/yaml/op_compat.yaml
* fix name

---------

Co-authored-by: kangguangli <kangguangli@hotmail.com>
1 parent edc3bfc · commit fd2389f

7 files changed: +106 −0 lines changed

paddle/fluid/pir/dialect/op_generator/ops_api_gen.py

Lines changed: 1 addition & 0 deletions

@@ -148,6 +148,7 @@
     'fused_elemwise_add_activation',
     'fused_scale_bias_relu_conv_bn',
     'fused_scale_bias_add_relu',
+    'fused_token_prune',
     'fused_dconv_drelu_dbn',
     'fused_dot_product_attention',
     'nce',
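For context: ops_api_gen.py drives generation of the PIR Python ops API, and the hunk above adds 'fused_token_prune' to one of its op lists (alongside the other fused ops shown) so that the generator accounts for the new op.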

paddle/fluid/pir/dialect/operator/ir/ops.yaml

Lines changed: 8 additions & 0 deletions

@@ -805,6 +805,14 @@
     func : fused_softmax_mask_upper_triangle
   backward: fused_softmax_mask_upper_triangle_grad

+- op : fused_token_prune
+  args : (Tensor attn, Tensor x, Tensor mask, Tensor new_mask, bool keep_first_token = true, bool keep_order = false)
+  output : Tensor(slimmed_x), Tensor(cls_inds)
+  infer_meta :
+    func : FusedTokenPruneInferMeta
+  kernel:
+    func : fused_token_prune
+
 - op : gaussian
   args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
   output: Tensor(out)
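Read together, the `args` and `output` lines define the op's calling convention: four tensors plus two flags in, two tensors out. Below is a minimal sketch of invoking it from Python. The binding name `_C_ops.fused_token_prune` is an assumption (it follows the usual generated-API pattern), and the shape variables are illustrative, chosen to satisfy the constraints enforced by FusedTokenPruneInferMeta further down.

```python
import paddle
from paddle import _C_ops  # assumed location of the generated binding

bsz, nb_head, max_seq_len, c, slim_seq_len = 2, 4, 16, 32, 8

attn = paddle.rand([bsz, nb_head, max_seq_len, max_seq_len])
x = paddle.rand([bsz, max_seq_len, c])
mask = paddle.rand([bsz, nb_head, max_seq_len, max_seq_len])
new_mask = paddle.rand([bsz, nb_head, slim_seq_len, slim_seq_len])

# Positional flags mirror the YAML defaults:
# keep_first_token=True, keep_order=False.
slimmed_x, cls_inds = _C_ops.fused_token_prune(
    attn, x, mask, new_mask, True, False
)

print(slimmed_x.shape)  # [2, 8, 32] -> (bsz, slim_seq_len, c)
print(cls_inds.shape)   # [2, 8]     -> (bsz, slim_seq_len)
```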

paddle/fluid/pir/dialect/operator/utils/utils.cc

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ const std::unordered_set<std::string> LegacyOpList = {
     FtrlOp::name(),
     FusedElemwiseAddActivationOp::name(),
     FusedElemwiseAddActivationGradOp::name(),
+    FusedTokenPruneOp::name(),
     DpsgdOp::name(),
     SendV2Op::name(),
     RecvV2Op::name(),
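LegacyOpList tracks ops that PIR still routes through the legacy (fluid) operator mechanism; registering FusedTokenPruneOp::name() here appears to keep fused_token_prune on that path, consistent with its fluid-era kernel registration.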

paddle/phi/api/yaml/op_compat.yaml

Lines changed: 6 additions & 0 deletions

@@ -3712,6 +3712,12 @@
   outputs :
     {out : Out}

+- op: fused_token_prune
+  inputs :
+    {attn: Attn, x: X, mask: Mask, new_mask: NewMask}
+  outputs :
+    {slimmed_x : SlimmedX, cls_inds : CLSInds}
+
 - op: fusion_squared_mat_sub
   inputs :
     x : X
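This op_compat.yaml entry records how the new lowercase PIR argument names map onto the legacy fluid ones (attn -> Attn, x -> X, mask -> Mask, new_mask -> NewMask, and SlimmedX / CLSInds for the outputs), so programs written against the old operator translate cleanly to the new definition.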

paddle/phi/infermeta/multiary.cc

Lines changed: 80 additions & 0 deletions

@@ -4584,6 +4584,86 @@ void FusedRopeInferMeta(const MetaTensor& q,
   }
 }

+void FusedTokenPruneInferMeta(const MetaTensor& attn,
+                              const MetaTensor& x,
+                              const MetaTensor& mask,
+                              const MetaTensor& new_mask,
+                              bool keep_first_token,
+                              bool keep_order,
+                              MetaTensor* slimmed_x,
+                              MetaTensor* cls_inds) {
+  auto mask_dim = mask.dims();
+  auto attn_dim = attn.dims();
+  auto x_dim = x.dims();
+  auto new_mask_dim = new_mask.dims();
+
+  PADDLE_ENFORCE_EQ(
+      mask_dim.size(),
+      4,
+      phi::errors::InvalidArgument("The input mask must be 4-dimensional"));
+  PADDLE_ENFORCE_EQ(
+      attn_dim.size(),
+      4,
+      phi::errors::InvalidArgument("The input attn must be 4-dimensional"));
+  PADDLE_ENFORCE_EQ(
+      x_dim.size(),
+      3,
+      phi::errors::InvalidArgument("The input x must be 3-dimensional"));
+  PADDLE_ENFORCE_EQ(new_mask_dim.size(),
+                    4,
+                    phi::errors::InvalidArgument(
+                        "The input new_mask must be 4-dimensional"));
+  PADDLE_ENFORCE_EQ(mask_dim[0],
+                    attn_dim[0],
+                    phi::errors::InvalidArgument(
+                        "The first dim of mask and attn should be the same, "
+                        "which is batch size"));
+  PADDLE_ENFORCE_EQ(mask_dim[1],
+                    attn_dim[1],
+                    phi::errors::InvalidArgument(
+                        "The second dim of mask and attn should be the same, "
+                        "which is nb_head"));
+  PADDLE_ENFORCE_EQ(mask_dim[0],
+                    x_dim[0],
+                    phi::errors::InvalidArgument(
+                        "The first dim of mask and x should be the same, "
+                        "which is batch size"));
+  PADDLE_ENFORCE_EQ(
+      mask_dim[2],
+      mask_dim[3],
+      phi::errors::InvalidArgument(
+          "The third dim and the fourth dim of mask should be the same, "
+          "which is max seq len"));
+  PADDLE_ENFORCE_EQ(
+      attn_dim[2],
+      attn_dim[3],
+      phi::errors::InvalidArgument(
+          "The third dim and the fourth dim of attn should be the same, "
+          "which is max seq len"));
+  PADDLE_ENFORCE_EQ(attn_dim[2],
+                    mask_dim[2],
+                    phi::errors::InvalidArgument(
+                        "The third dim of mask and attn should be the same, "
+                        "which is max seq len"));
+  PADDLE_ENFORCE_EQ(attn_dim[2],
+                    x_dim[1],
+                    phi::errors::InvalidArgument(
+                        "The third dim of attn and the second dim of x "
+                        "should be the same, which is max seq len"));
+
+  auto bsz = mask_dim[0];
+  auto c = x_dim[2];
+  auto slim_seq_len = new_mask_dim[2];
+
+  std::vector<int64_t> slimmed_x_dims({bsz, slim_seq_len, c});
+  slimmed_x->set_dims(common::make_ddim(slimmed_x_dims));
+  slimmed_x->set_dtype(x.dtype());
+
+  std::vector<int64_t> cls_inds_dims({bsz, slim_seq_len});
+  cls_inds->set_dims(common::make_ddim(cls_inds_dims));
+  cls_inds->set_dtype(phi::DataType::INT64);
+}
+
 void MoeInferMeta(const MetaTensor& x,
                   const MetaTensor& gate,
                   const MetaTensor& bmm0,
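For readers cross-checking the constraints above, here is a small self-contained Python mirror of the same shape logic (an illustration only, not part of the commit; shapes are plain tuples):

```python
def fused_token_prune_infer_shapes(attn, x, mask, new_mask):
    """Mirror FusedTokenPruneInferMeta: validate ranks/dims, return output shapes."""
    # Rank checks: attn, mask, new_mask are 4-D; x is 3-D.
    assert len(attn) == len(mask) == len(new_mask) == 4
    assert len(x) == 3
    # Batch size must agree across mask, attn, and x.
    assert mask[0] == attn[0] == x[0]
    # Head count must agree between mask and attn.
    assert mask[1] == attn[1]
    # Max seq len: mask and attn are square in their last two dims,
    # and must match x's second dim.
    assert mask[2] == mask[3] == attn[2] == attn[3] == x[1]

    bsz, c, slim_seq_len = mask[0], x[2], new_mask[2]
    slimmed_x_shape = (bsz, slim_seq_len, c)  # dtype follows x
    cls_inds_shape = (bsz, slim_seq_len)      # dtype is int64
    return slimmed_x_shape, cls_inds_shape


# Example: bsz=2, nb_head=4, max_seq_len=16, c=32, slim_seq_len=8.
print(fused_token_prune_infer_shapes(
    (2, 4, 16, 16), (2, 16, 32), (2, 4, 16, 16), (2, 4, 8, 8)))
# ((2, 8, 32), (2, 8))
```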

paddle/phi/infermeta/multiary.h

Lines changed: 9 additions & 0 deletions

@@ -908,6 +908,15 @@ void FusedRopeInferMeta(const MetaTensor& q,
                         MetaTensor* out_k,
                         MetaTensor* out_v);

+void FusedTokenPruneInferMeta(const MetaTensor& attn,
+                              const MetaTensor& x,
+                              const MetaTensor& mask,
+                              const MetaTensor& new_mask,
+                              bool keep_first_token,
+                              bool keep_order,
+                              MetaTensor* slimmed_x,
+                              MetaTensor* cls_inds);
+
 void MultiheadMatmulInferMeta(const MetaTensor& input,
                               const MetaTensor& w,
                               const MetaTensor& bias,

test/white_list/pir_op_test_white_list

Lines changed: 1 addition & 0 deletions

@@ -116,6 +116,7 @@ test_fused_fc_elementwise_layernorm_op
 test_fused_feedforward_op
 test_fused_gate_attention_op
 test_fused_multihead_matmul_op
+test_fused_token_prune_op
 test_fusion_seqexpand_concat_fc_op
 test_fusion_transpose_flatten_concat_op
 test_gather_nd_op
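Finally, with the op definition, infer-meta, and name mapping in place, adding test_fused_token_prune_op to pir_op_test_white_list opts the existing op test into running under PIR, which exercises the new definition end to end.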
