
Commit 22fb3b1

[XPU] support cross attention for decoder model (#63203)
1 parent ccdfb84 commit 22fb3b1

11 files changed: +1400 -3 lines changed


paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -274,6 +274,8 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(decoder_attention_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(cross_attention_xpu_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(multi_encoder_xpu_adaptive_seqlen_fuse_pass inference DIR xpu

paddle/fluid/framework/ir/xpu/cross_attention_xpu_fuse_pass.cc

Lines changed: 666 additions & 0 deletions
Large diffs are not rendered by default.

paddle/fluid/framework/ir/xpu/cross_attention_xpu_fuse_pass.h

Lines changed: 126 additions & 0 deletions

@@ -0,0 +1,126 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+/*
+This pass is used to fuse the cross attention op into one op in decoder
+models.
+
+Origin subgraph:
+
+      mask    input_q          input_kv
+        |        |                |    |
+        |        |                |----|
+        |     matmul          matmul  matmul
+        |        |q              |k      |v
+        |        |               |       |
+        |        |               |       |
+        |       add             add     add
+        |        |               |       |
+        |        |               |       |
+        |     reshape        reshape  reshape
+        |        |               |       |
+        |        |               |       |
+        |    transpose     transpose  transpose
+        |        |               |       |
+        |        |               |       |
+        |     (scale)            |       |
+        |        |               |       |
+         \       |(x)            |(y)    |
+          \       \             /        |
+           \      qk_matmul              |
+            \         |                  |
+             \        |                  |
+              add                       /
+               |                       /
+               |                      /
+            softmax                  /
+                 \                  /
+                  \                /
+                   qkv_matmul
+                        |
+                        |
+                    transpose
+                        |
+                        |
+                     reshape
+                        |
+                        |
+                      output
+
+-------------------------------------------------------
+Fused subgraph:
+           input_q    input_kv
+               |          |
+               |          |
+               |          |
+            cross_attention_xpu
+                    |
+                    |
+                    |
+                  output
+
+*/
+
+class CrossAttentionXPUFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void ApplyCrossAttentionXPUFuse(ir::Graph* graph, bool with_q_scale) const;
+
+  // 1. Generate q/k/v_w_max tensor
+  // 2. Quant q/k/v_w to int16
+  void PrepareQKVWeight(Graph* graph,
+                        Scope* scope,
+                        BlockDesc* block,
+                        Node* w,
+                        Node** real_w,
+                        Node** w_max) const;
+
+  // Cast fc_bias to fp32
+  void PrepareQKVBias(Graph* graph,
+                      Scope* scope,
+                      BlockDesc* block,
+                      Node* q_bias,
+                      Node* k_bias,
+                      Node* v_bias,
+                      Node** real_q_bias,
+                      Node** real_k_bias,
+                      Node** real_v_bias) const;
+
+  const std::string name_scope_{"cross_attention_xpu_fuse_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
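
Note: the 666-line cross_attention_xpu_fuse_pass.cc is not rendered in this view. Purely as orientation, here is a minimal sketch, inferred from the header above and not copied from the commit, of how the pass entry point would dispatch the two pattern variants (with and without the optional q-scale op marked as (scale) in the diagram); the actual implementation in the commit is authoritative.

// Sketch only -- inferred from the header, not taken from the unrendered .cc.
void CrossAttentionXPUFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph,
      platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);
  // Match the subgraph once with the optional scale op and once without it.
  for (auto with_q_scale : {true, false}) {
    ApplyCrossAttentionXPUFuse(graph, with_q_scale);
  }
}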

paddle/fluid/framework/ir/xpu/decoder_attention_xpu_fuse_pass.cc

Lines changed: 3 additions & 3 deletions
@@ -163,15 +163,15 @@ DecoderAttentionFusePattern::DecoderAttentionFusePattern(
 
   // link nodes
   reshape2_1->LinksFrom({input_q}).LinksTo({reshape2_1_out});
-  reshape2_2->LinksFrom({input_k}).LinksTo({reshape2_2_out});
-  reshape2_3->LinksFrom({input_v}).LinksTo({reshape2_3_out});
   transpose2_1->LinksFrom({reshape2_1_out}).LinksTo({transpose2_1_out});
+  reshape2_2->LinksFrom({input_k}).LinksTo({reshape2_2_out});
   transpose2_2->LinksFrom({reshape2_2_out}).LinksTo({transpose2_2_out});
-  transpose2_3->LinksFrom({reshape2_3_out}).LinksTo({transpose2_3_out});
   qk_matmul->LinksFrom({transpose2_1_out, transpose2_2_out})
       .LinksTo({qk_matmul_out});
   scale->LinksFrom({qk_matmul_out}).LinksTo({scale_out});
   qk_softmax->LinksFrom({scale_out}).LinksTo({qk_softmax_out});
+  reshape2_3->LinksFrom({input_v}).LinksTo({reshape2_3_out});
+  transpose2_3->LinksFrom({reshape2_3_out}).LinksTo({transpose2_3_out});
   qkv_matmul->LinksFrom({qk_softmax_out, transpose2_3_out})
       .LinksTo({qkv_matmul_out});
   transpose2_4->LinksFrom({qkv_matmul_out}).LinksTo({transpose2_4_out});

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 0 deletions
@@ -546,6 +546,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "multi_encoder_xpu_slice_fuse_pass",
       "fused_multi_transformer_cachekv_layout_trans_pass",
       "fused_multi_transformer_int8_cachekv_layout_trans_pass",
+      "cross_attention_xpu_fuse_pass",
       "decoder_attention_xpu_fuse_pass",
       "one_beam_size_fuse_pass",
       "fold_interp_outsize_fuse_pass",

paddle/phi/api/yaml/fused_ops.yaml

Lines changed: 9 additions & 0 deletions
@@ -83,6 +83,15 @@
     data_type : x
   optional : bias, branch, branch_max ,x_max, scale_max, out_max_in
 
+- op : cross_attention_xpu
+  args : (Tensor input_q, Tensor input_kv, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor mask, int head_num, int head_dim, float alpha, DataType out_dtype)
+  output : Tensor(qkv), Tensor(qkv_max)
+  infer_meta :
+    func : CrossAttentionXPUInferMeta
+  kernel :
+    func : cross_attention_xpu
+    data_type : input_q
+
 - op : dequantize_xpu
   args : (Tensor x, DataType out_dtype, float scale = 1.0f)
   output : Tensor(y)
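
Note: for orientation only, a hedged sketch of how a fuse pass would typically materialize an op node with this signature. The node variables (input_q, q_w, q_w_max, q_b, qkv, ...) and the attribute plumbing are hypothetical; the actual construction lives in the unrendered cross_attention_xpu_fuse_pass.cc.

// Sketch only: shows the input/output/attr layout implied by the yaml entry.
framework::OpDesc fused_op_desc(block);
fused_op_desc.SetType("cross_attention_xpu");
fused_op_desc.SetInput("input_q", {input_q->Name()});
fused_op_desc.SetInput("input_kv", {input_kv->Name()});
fused_op_desc.SetInput("fc_weight", {q_w->Name(), k_w->Name(), v_w->Name()});
fused_op_desc.SetInput("fc_weight_max",
                       {q_w_max->Name(), k_w_max->Name(), v_w_max->Name()});
fused_op_desc.SetInput("fc_bias", {q_b->Name(), k_b->Name(), v_b->Name()});
fused_op_desc.SetInput("mask", {mask->Name()});
fused_op_desc.SetAttr("head_num", head_num);
fused_op_desc.SetAttr("head_dim", head_dim);
fused_op_desc.SetAttr("alpha", alpha);
// out_dtype is also carried as an attribute; its exact encoding is omitted here.
fused_op_desc.SetOutput("qkv", {qkv->Name()});
fused_op_desc.SetOutput("qkv_max", {qkv_max->Name()});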

paddle/phi/backends/xpu/xpu2_op_list.cc

Lines changed: 2 additions & 0 deletions
@@ -1231,6 +1231,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"roformer_relative_embedding_xpu",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
+    {"cross_attention_xpu",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"variable_length_memory_efficient_attention",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"flash_attn_unpadded",

paddle/phi/infermeta/fusion.cc

Lines changed: 89 additions & 0 deletions
@@ -3858,6 +3858,95 @@ void SinePosXPUInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
+void CrossAttentionXPUInferMeta(
+    const MetaTensor& input_q,
+    const MetaTensor& input_kv,
+    const std::vector<const MetaTensor*>& fc_weight,
+    const std::vector<const MetaTensor*>& fc_weight_max,
+    const std::vector<const MetaTensor*>& fc_bias,
+    const MetaTensor& mask,
+    int head_num,
+    int head_dim,
+    float alpha,
+    DataType out_dtype,
+    MetaTensor* qkv,
+    MetaTensor* qkv_max) {
+  auto input_q_dims = input_q.dims();
+  auto input_kv_dims = input_kv.dims();
+  auto mask_dims = mask.dims();
+  // input shape : {B, L, H*D}
+  PADDLE_ENFORCE_EQ(input_q_dims.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The dim of input_q should be 3! But received ",
+                        input_q_dims.size()));
+  PADDLE_ENFORCE_EQ(input_kv_dims.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The dim of input_kv should be 3! But received ",
+                        input_kv_dims.size()));
+  // sequence length of q and k/v is not required to be equal,
+  // but batch size and dim should be the same
+  PADDLE_ENFORCE_EQ(
+      input_q_dims[0],
+      input_kv_dims[0],
+      phi::errors::InvalidArgument("The batch size of input_q and input_kv "
+                                   "should be the same! Received ",
+                                   input_q_dims[0],
+                                   " vs ",
+                                   input_kv_dims[0]));
+  PADDLE_ENFORCE_EQ(
+      input_q_dims[2],
+      input_kv_dims[2],
+      phi::errors::InvalidArgument("The hidden_dim of input_q and input_kv "
+                                   "should be the same! Received ",
+                                   input_q_dims[2],
+                                   " vs ",
+                                   input_kv_dims[2]));
+  int hidden_dim = head_num * head_dim;
+  PADDLE_ENFORCE_EQ(
+      input_q_dims[2],
+      hidden_dim,
+      phi::errors::InvalidArgument(
+          "The last dimension of input_q should be [H*D]! Received ",
+          input_q_dims[2],
+          " != expected ",
+          hidden_dim));
+  PADDLE_ENFORCE_EQ(fc_weight.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The size of fc_weight should be 3! But received ",
+                        fc_weight.size()));
+  PADDLE_ENFORCE_EQ(fc_weight_max.size(),
+                    3,
+                    phi::errors::InvalidArgument(
+                        "The size of fc_weight_max should be 3! But received ",
+                        fc_weight_max.size()));
+  PADDLE_ENFORCE_EQ(
+      fc_bias.size(),
+      3,
+      phi::errors::InvalidArgument(
+          "The size of fc_bias should be 3! But received ", fc_bias.size()));
+  PADDLE_ENFORCE_LE(
+      mask_dims.size(),
+      4,
+      phi::errors::InvalidArgument(
+          "The dim of mask should be not greater than 4!", mask_dims.size()));
+
+  // output shape: {B, qL, H*D}
+  qkv->set_dims(
+      phi::make_ddim({input_q_dims[0], input_q_dims[1], head_num * head_dim}));
+  qkv->set_dtype(out_dtype);
+  qkv->set_layout(input_q.layout());
+  // TODO(Terry) optimize the max value num
+  // unable to pass few PR-CIs, so just use a constant value
+  // int xpu2_max_value_num = phi::backends::xpu::get_xpu_max_ptr_size(-1);
+  const int xpu2_max_value_num = 6;
+  qkv_max->set_dims(phi::make_ddim({xpu2_max_value_num}));
+  qkv_max->set_dtype(out_dtype);
+  qkv_max->set_layout(input_q.layout());
+}
+
 void MultiGruInferMeta(
     const MetaTensor& x,
     const std::vector<const MetaTensor*>& weight_x,
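
Note: a standalone numeric illustration of the shape contract enforced above, using hypothetical sizes and plain C++ with no Paddle dependency: a decoder step with a single query token attending over a 128-token key/value sequence.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t B = 1, qL = 1, kvL = 128;   // batch, query len, kv len
  const int head_num = 8, head_dim = 64;
  const int64_t hidden_dim = int64_t{head_num} * head_dim;  // H*D = 512

  const std::vector<int64_t> input_q_dims{B, qL, hidden_dim};    // {1, 1, 512}
  const std::vector<int64_t> input_kv_dims{B, kvL, hidden_dim};  // {1, 128, 512}
  assert(input_q_dims[0] == input_kv_dims[0]);  // batch sizes must match
  assert(input_q_dims[2] == input_kv_dims[2]);  // hidden dims must match
  assert(input_q_dims[2] == hidden_dim);        // last dim == head_num * head_dim

  // Output keeps the query sequence length; qkv_max is a fixed 6-element tensor.
  const std::vector<int64_t> qkv_dims{B, qL, hidden_dim};  // {1, 1, 512}
  const std::vector<int64_t> qkv_max_dims{6};
  return qkv_dims.size() == 3 && qkv_max_dims.size() == 1 ? 0 : 1;
}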

paddle/phi/infermeta/fusion.h

Lines changed: 13 additions & 0 deletions
@@ -905,6 +905,19 @@ void RoformerRelativePosXPUInferMeta(const MetaTensor& x,
                                      const MetaTensor& cos_emb,
                                      int max_pos_len,
                                      MetaTensor* out);
+void CrossAttentionXPUInferMeta(
+    const MetaTensor& input_q,
+    const MetaTensor& input_kv,
+    const std::vector<const MetaTensor*>& fc_weight,
+    const std::vector<const MetaTensor*>& fc_weight_max,
+    const std::vector<const MetaTensor*>& fc_bias,
+    const MetaTensor& mask,
+    int head_num,
+    int head_dim,
+    float alpha,
+    DataType out_dtype,
+    MetaTensor* qkv,
+    MetaTensor* qkv_max);
 
 void MultiGruInferMeta(
     const MetaTensor& x,
