
Commit 371e691

[XPU] support cross attention for decoder model
1 parent 813ccc5 commit 371e691

9 files changed: +1129 −0 lines


paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -274,6 +274,8 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(decoder_attention_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(cross_attention_xpu_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(multi_encoder_xpu_adaptive_seqlen_fuse_pass inference DIR xpu

paddle/fluid/framework/ir/xpu/cross_attention_xpu_fuse_pass.cc

Lines changed: 666 additions & 0 deletions
Large diffs are not rendered by default.
paddle/fluid/framework/ir/xpu/cross_attention_xpu_fuse_pass.h

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+/*
+This pass is used to fuse the cross attention op into one op in decoder models.
+
+Origin subgraph:
+
+      mask    input_q      input_kv
+        |        |           |    |
+        |        |           |----------|
+        |     matmul      matmul      matmul
+        |        |q          |k          |v
+        |        |           |           |
+        |        |           |           |
+        |       add         add         add
+        |        |           |           |
+        |        |           |           |
+        |     reshape     reshape     reshape
+        |        |           |           |
+        |        |           |           |
+        |    transpose   transpose   transpose
+        |        |           |           |
+        |        |           |           |
+        |     (scale)        |           |
+        |        |           |           |
+         \       |(x)        |(y)        |
+          \       \         /            |
+           \      qk_matmul              |
+            \         |                  |
+             \        |                  |
+              add                       /
+               |                       /
+               |                      /
+            softmax                  /
+                \                   /
+                 \                 /
+                  qkv_matmul
+                      |
+                      |
+                  transpose
+                      |
+                      |
+                   reshape
+                      |
+                      |
+                    output
+
+-------------------------------------------------------
+Fused subgraph:
+      input_q    input_kv
+         |          |
+         |          |
+         |          |
+      cross_attention_xpu
+              |
+              |
+              |
+           output
+
+*/
+
+class CrossAttentionXPUFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void ApplyCrossAttentionXPUFuse(ir::Graph* graph, bool with_q_scale) const;
+
+  // 1. Generate q/k/v_w_max tensor
+  // 2. Quant q/k/v_w to int16
+  void PrepareQKVWeight(Graph* graph,
+                        Scope* scope,
+                        BlockDesc* block,
+                        Node* w,
+                        Node** real_w,
+                        Node** w_max) const;
+
+  // Cast fc_bias to fp32
+  void PrepareQKVBias(Graph* graph,
+                      Scope* scope,
+                      BlockDesc* block,
+                      Node* q_bias,
+                      Node* k_bias,
+                      Node* v_bias,
+                      Node** real_q_bias,
+                      Node** real_k_bias,
+                      Node** real_v_bias) const;
+
+  const std::string name_scope_{"cross_attention_xpu_fuse_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
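
The body of cross_attention_xpu_fuse_pass.cc is collapsed above ("Large diffs are not rendered by default"), so the following is only a minimal sketch of how a FusePassBase subclass with this header is usually wired up; the ApplyImpl body and the registration call are assumptions modeled on neighboring XPU fuse passes, not the rendered diff.

// Sketch only (assumed, not the collapsed .cc): typical ApplyImpl dispatch
// and pass registration for a FusePassBase subclass like this one.
void CrossAttentionXPUFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph,
      platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);
  // Match both origin subgraphs: with and without the optional (scale)
  // node on the q branch shown in the diagram in the header.
  for (auto with_q_scale : {true, false}) {
    ApplyCrossAttentionXPUFuse(graph, with_q_scale);
  }
}

REGISTER_PASS(cross_attention_xpu_fuse_pass,
              paddle::framework::ir::CrossAttentionXPUFusePass);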

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 0 deletions
@@ -545,6 +545,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "multi_encoder_xpu_slice_fuse_pass",
       "fused_multi_transformer_cachekv_layout_trans_pass",
       "fused_multi_transformer_int8_cachekv_layout_trans_pass",
+      "cross_attention_xpu_fuse_pass",
       "decoder_attention_xpu_fuse_pass",
       "one_beam_size_fuse_pass",
       "fold_interp_outsize_fuse_pass",

paddle/phi/api/yaml/fused_ops.yaml

Lines changed: 9 additions & 0 deletions
@@ -83,6 +83,15 @@
     data_type : x
   optional : bias, branch, branch_max ,x_max, scale_max, out_max_in
 
+- op : cross_attention_xpu
+  args : (Tensor input_q, Tensor input_kv, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor mask, int head_num, int head_dim, float alpha, DataType out_dtype)
+  output : Tensor(qkv), Tensor(qkv_max)
+  infer_meta :
+    func : CrossAttentionXPUInferMeta
+  kernel :
+    func : cross_attention_xpu
+    data_type : input_q
+
 - op : dequantize_xpu
   args : (Tensor x, DataType out_dtype, float scale = 1.0f)
   output : Tensor(y)
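
The yaml entry above is what the op registration is generated from. The matching phi kernel (added elsewhere in this commit, not shown in this view) would, by the usual phi convention, have a signature along the following lines; this is an inferred sketch, not the kernel diff itself.

// Assumed signature, derived from the args/output lists in fused_ops.yaml.
template <typename T, typename Context>
void CrossAttentionXPUKernel(const Context& ctx,
                             const DenseTensor& input_q,
                             const DenseTensor& input_kv,
                             const std::vector<const DenseTensor*>& fc_weight,
                             const std::vector<const DenseTensor*>& fc_weight_max,
                             const std::vector<const DenseTensor*>& fc_bias,
                             const DenseTensor& mask,
                             int head_num,
                             int head_dim,
                             float alpha,
                             DataType out_dtype,
                             DenseTensor* qkv,
                             DenseTensor* qkv_max);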

paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 2 additions & 0 deletions
@@ -1225,6 +1225,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"roformer_relative_embedding_xpu",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
+    {"cross_attention_xpu",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"variable_length_memory_efficient_attention",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"flash_attn_unpadded",

paddle/phi/infermeta/fusion.cc

Lines changed: 86 additions & 0 deletions
@@ -3816,6 +3816,92 @@ void SinePosXPUInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
+void CrossAttentionXPUInferMeta(
+    const MetaTensor& input_q,
+    const MetaTensor& input_kv,
+    const std::vector<const MetaTensor*>& fc_weight,
+    const std::vector<const MetaTensor*>& fc_weight_max,
+    const std::vector<const MetaTensor*>& fc_bias,
+    const MetaTensor& mask,
+    int head_num,
+    int head_dim,
+    float alpha,
+    DataType out_dtype,
+    MetaTensor* qkv,
+    MetaTensor* qkv_max) {
+  auto input_q_dims = input_q.dims();
+  auto input_kv_dims = input_kv.dims();
+  auto mask_dims = mask.dims();
+  // input shape : {B, L, H*D}
+  PADDLE_ENFORCE_EQ(input_q_dims.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The dim of input_q should be 3! But received ",
+                        input_q_dims.size()));
+  PADDLE_ENFORCE_EQ(input_kv_dims.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The dim of input_kv should be 3! But received ",
+                        input_kv_dims.size()));
+  // The sequence lengths of q and k/v are not required to be equal,
+  // but the batch size and hidden dim should be the same.
+  PADDLE_ENFORCE_EQ(
+      input_q_dims[0],
+      input_kv_dims[0],
+      phi::errors::InvalidArgument("The batch size of input_q and input_kv "
+                                   "should be the same! Received ",
+                                   input_q_dims[0],
+                                   " vs ",
+                                   input_kv_dims[0]));
+  PADDLE_ENFORCE_EQ(
+      input_q_dims[2],
+      input_kv_dims[2],
+      phi::errors::InvalidArgument("The hidden_dim of input_q and input_kv "
+                                   "should be the same! Received ",
+                                   input_q_dims[2],
+                                   " vs ",
+                                   input_kv_dims[2]));
+  int hidden_dim = head_num * head_dim;
+  PADDLE_ENFORCE_EQ(
+      input_q_dims[2],
+      hidden_dim,
+      phi::errors::InvalidArgument(
+          "The last dimension of input_q should be [H*D]! Received ",
+          input_q_dims[2],
+          " != expected ",
+          hidden_dim));
+  PADDLE_ENFORCE_EQ(fc_weight.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The size of fc_weight should be 3! But received ",
+                        fc_weight.size()));
+  PADDLE_ENFORCE_EQ(fc_weight_max.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The size of fc_weight_max should be 3! But received ",
+                        fc_weight_max.size()));
+  PADDLE_ENFORCE_EQ(
+      fc_bias.size(),
+      3,
+      phi::errors::InvalidArgument(
+          "The size of fc_bias should be 3! But received ", fc_bias.size()));
+  PADDLE_ENFORCE_EQ(
+      mask_dims.size(),
+      4,
+      phi::errors::InvalidArgument("The dim of mask should be 4! But received ",
+                                   mask_dims.size()));
+
+  // output shape: {B, qL, H*D}
+  qkv->set_dims(
+      phi::make_ddim({input_q_dims[0], input_q_dims[1], head_num * head_dim}));
+  qkv->set_dtype(out_dtype);
+  qkv->set_layout(input_q.layout());
+  int xpu2_max_value_num = phi::backends::xpu::get_xpu_max_ptr_size(-1);
+  qkv_max->set_dims(phi::make_ddim({xpu2_max_value_num}));
+  qkv_max->set_dtype(out_dtype);
+  qkv_max->set_layout(input_q.layout());
+}
+
 void MultiGruInferMeta(
     const MetaTensor& x,
     const std::vector<const MetaTensor*>& weight_x,
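
To make the checks above concrete, here is one hypothetical set of shapes that satisfies them (illustration only, not taken from the diff):

// Hypothetical shapes satisfying CrossAttentionXPUInferMeta's checks.
// head_num = 8, head_dim = 64  ->  hidden_dim = 512
// input_q : {4, 1,   512}   (batch 4, one decoding step)
// input_kv: {4, 128, 512}   (batch 4, encoder sequence of length 128)
// mask    : 4-D, covering the 1 x 128 score matrix per head
// qkv     : {4, 1,   512}
// qkv_max : {xpu2_max_value_num}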

paddle/phi/infermeta/fusion.h

Lines changed: 13 additions & 0 deletions
@@ -877,6 +877,19 @@ void RoformerRelativePosXPUInferMeta(const MetaTensor& x,
                                      const MetaTensor& cos_emb,
                                      int max_pos_len,
                                      MetaTensor* out);
+void CrossAttentionXPUInferMeta(
+    const MetaTensor& input_q,
+    const MetaTensor& input_kv,
+    const std::vector<const MetaTensor*>& fc_weight,
+    const std::vector<const MetaTensor*>& fc_weight_max,
+    const std::vector<const MetaTensor*>& fc_bias,
+    const MetaTensor& mask,
+    int head_num,
+    int head_dim,
+    float alpha,
+    DataType out_dtype,
+    MetaTensor* qkv,
+    MetaTensor* qkv_max);
 
 void MultiGruInferMeta(
     const MetaTensor& x,
