@@ -4584,6 +4584,86 @@ void FusedRopeInferMeta(const MetaTensor& q,
45844584 }
45854585}
45864586
4587+ void FusedTokenPruneInferMeta (const MetaTensor& attn,
4588+ const MetaTensor& x,
4589+ const MetaTensor& mask,
4590+ const MetaTensor& new_mask,
4591+ bool keep_first_token,
4592+ bool keep_order,
4593+ MetaTensor* slimmed_x,
4594+ MetaTensor* cls_inds) {
4595+ auto mask_dim = mask.dims ();
4596+ auto attn_dim = attn.dims ();
4597+ auto x_dim = x.dims ();
4598+ auto new_mask_dim = new_mask.dims ();
4599+
4600+ PADDLE_ENFORCE_EQ (
4601+ mask_dim.size (),
4602+ 4 ,
4603+ phi::errors::InvalidArgument (" The input mask must be 4-dimension" ));
4604+ PADDLE_ENFORCE_EQ (
4605+ attn_dim.size (),
4606+ 4 ,
4607+ phi::errors::InvalidArgument (" The input attn must be 4-dimension" ));
4608+ PADDLE_ENFORCE_EQ (
4609+ x_dim.size (),
4610+ 3 ,
4611+ phi::errors::InvalidArgument (" The input x must be 4-dimension" ));
4612+ PADDLE_ENFORCE_EQ (
4613+ new_mask_dim.size (),
4614+ 4 ,
4615+ phi::errors::InvalidArgument (" The input attn must be 4-dimension" ));
4616+ PADDLE_ENFORCE_EQ (mask_dim[0 ],
4617+ attn_dim[0 ],
4618+ phi::errors::InvalidArgument (
4619+ " The first dim of mask and attn should be the same"
4620+ " which is batch size" ));
4621+ PADDLE_ENFORCE_EQ (mask_dim[1 ],
4622+ attn_dim[1 ],
4623+ phi::errors::InvalidArgument (
4624+ " The second dim of mask and attn should be the same"
4625+ " which is nb_head" ));
4626+ PADDLE_ENFORCE_EQ (mask_dim[0 ],
4627+ x_dim[0 ],
4628+ phi::errors::InvalidArgument (
4629+ " The first dim of mask and x should be the same"
4630+ " which is batch size" ));
4631+ PADDLE_ENFORCE_EQ (
4632+ mask_dim[2 ],
4633+ mask_dim[3 ],
4634+ phi::errors::InvalidArgument (
4635+ " The third dim and the fourth dim of mask should be the same"
4636+ " which is max seq len" ));
4637+ PADDLE_ENFORCE_EQ (
4638+ attn_dim[2 ],
4639+ attn_dim[3 ],
4640+ phi::errors::InvalidArgument (
4641+ " The third dim and the fourth dim of mask should be the same"
4642+ " which is max seq len" ));
4643+ PADDLE_ENFORCE_EQ (attn_dim[2 ],
4644+ mask_dim[2 ],
4645+ phi::errors::InvalidArgument (
4646+ " The third dim of mask and attn should be the same"
4647+ " which is max seq len" ));
4648+ PADDLE_ENFORCE_EQ (attn_dim[2 ],
4649+ x_dim[1 ],
4650+ phi::errors::InvalidArgument (
4651+ " The third dim of mask and the second dim of attn"
4652+ " should be the same which is max seq len" ));
4653+
4654+ auto bsz = mask_dim[0 ];
4655+ auto c = x_dim[2 ];
4656+ auto slim_seq_len = new_mask_dim[2 ];
4657+
4658+ std::vector<int64_t > slimmed_x_dims ({bsz, slim_seq_len, c});
4659+ slimmed_x->set_dims (common::make_ddim (slimmed_x_dims));
4660+ slimmed_x->set_dtype (x.dtype ());
4661+
4662+ std::vector<int64_t > cls_inds_dims ({bsz, slim_seq_len});
4663+ cls_inds->set_dims (common::make_ddim (cls_inds_dims));
4664+ cls_inds->set_dtype (phi::DataType::INT64);
4665+ }
4666+
45874667void MoeInferMeta (const MetaTensor& x,
45884668 const MetaTensor& gate,
45894669 const MetaTensor& bmm0,
0 commit comments