@@ -108,14 +108,43 @@ class FusedWeightOnlyLinearWithBiasPattern
108108 //
109109 paddle::drr::ResultPattern res = src.ResultPattern ();
110110
111- const auto &weight_quantize =
112- res.Op (paddle::dialect::WeightQuantizeOp::name (),
113- {{" algo" , res.StrAttr (algo_)},
114- {" arch" , res.Int32Attr (sm_version_)},
115- {" group_size" , res.Int32Attr (-1 )}});
116- weight_quantize ({&res.Tensor (" w" )},
117- {&res.Tensor (" quanted_weight_tensor" ),
118- &res.Tensor (" weight_scale_tensor" )});
111+ if (algo_ == " weight_only_int4" ) {
112+ // TODO(liuyuanle): When the operator weight_quantize supports
113+ // weight_only_int4 on gpu version, delete the memory copy.
114+ const auto &memcpy_d2h =
115+ res.Op (paddle::dialect::MemcpyOp::name (),
116+ {{" dst_place_type" , res.Int32Attr (0 /* cpu*/ )}});
117+ res.Tensor (" w_cpu" ) = memcpy_d2h (res.Tensor (" w" ));
118+ const auto &weight_quantize =
119+ res.Op (paddle::dialect::WeightQuantizeOp::name (),
120+ {{" algo" , res.StrAttr (algo_)},
121+ {" arch" , res.Int32Attr (sm_version_)},
122+ {" group_size" , res.Int32Attr (-1 )}});
123+ weight_quantize ({&res.Tensor (" w_cpu" )},
124+ {&res.Tensor (" quanted_weight_tensor_cpu" ),
125+ &res.Tensor (" weight_scale_tensor_cpu" )});
126+
127+ const auto &memcpy_h2d_1 =
128+ res.Op (paddle::dialect::MemcpyOp::name (),
129+ {{" dst_place_type" , res.Int32Attr (1 /* gpu*/ )}});
130+ res.Tensor (" quanted_weight_tensor" ) =
131+ memcpy_h2d_1 (res.Tensor (" quanted_weight_tensor_cpu" ));
132+ const auto &memcpy_h2d_2 =
133+ res.Op (paddle::dialect::MemcpyOp::name (),
134+ {{" dst_place_type" , res.Int32Attr (1 /* gpu*/ )}});
135+ res.Tensor (" weight_scale_tensor" ) =
136+ memcpy_h2d_2 (res.Tensor (" weight_scale_tensor_cpu" ));
137+ } else {
138+ const auto &weight_quantize =
139+ res.Op (paddle::dialect::WeightQuantizeOp::name (),
140+ {{" algo" , res.StrAttr (algo_)},
141+ {" arch" , res.Int32Attr (sm_version_)},
142+ {" group_size" , res.Int32Attr (-1 )}});
143+
144+ weight_quantize ({&res.Tensor (" w" )},
145+ {&res.Tensor (" quanted_weight_tensor" ),
146+ &res.Tensor (" weight_scale_tensor" )});
147+ }
119148
120149 const auto &weight_only_linear =
121150 res.Op (paddle::dialect::WeightOnlyLinearOp::name (),
@@ -192,15 +221,43 @@ class FusedWeightOnlyLinearNoBiasPattern : public paddle::drr::DrrPatternBase {
192221 //
193222 paddle::drr::ResultPattern res = src.ResultPattern ();
194223
195- const auto &weight_quantize =
196- res.Op (paddle::dialect::WeightQuantizeOp::name (),
197- {{" algo" , res.StrAttr (algo_)},
198- {" arch" , res.Int32Attr (sm_version_)},
199- {" group_size" , res.Int32Attr (-1 )}});
200- weight_quantize ({&res.Tensor (" w" )},
201- {&res.Tensor (" quanted_weight_tensor" ),
202- &res.Tensor (" weight_scale_tensor" )});
203-
224+ if (algo_ == " weight_only_int4" ) {
225+ // TODO(liuyuanle): When the operator weight_quantize supports
226+ // weight_only_int4 on gpu version, delete the memory copy.
227+ const auto &memcpy_d2h =
228+ res.Op (paddle::dialect::MemcpyOp::name (),
229+ {{" dst_place_type" , res.Int32Attr (0 /* cpu*/ )}});
230+ res.Tensor (" w_cpu" ) = memcpy_d2h (res.Tensor (" w" ));
231+ const auto &weight_quantize =
232+ res.Op (paddle::dialect::WeightQuantizeOp::name (),
233+ {{" algo" , res.StrAttr (algo_)},
234+ {" arch" , res.Int32Attr (sm_version_)},
235+ {" group_size" , res.Int32Attr (-1 )}});
236+ weight_quantize ({&res.Tensor (" w_cpu" )},
237+ {&res.Tensor (" quanted_weight_tensor_cpu" ),
238+ &res.Tensor (" weight_scale_tensor_cpu" )});
239+
240+ const auto &memcpy_h2d_1 =
241+ res.Op (paddle::dialect::MemcpyOp::name (),
242+ {{" dst_place_type" , res.Int32Attr (1 /* gpu*/ )}});
243+ res.Tensor (" quanted_weight_tensor" ) =
244+ memcpy_h2d_1 (res.Tensor (" quanted_weight_tensor_cpu" ));
245+ const auto &memcpy_h2d_2 =
246+ res.Op (paddle::dialect::MemcpyOp::name (),
247+ {{" dst_place_type" , res.Int32Attr (1 /* gpu*/ )}});
248+ res.Tensor (" weight_scale_tensor" ) =
249+ memcpy_h2d_2 (res.Tensor (" weight_scale_tensor_cpu" ));
250+ } else {
251+ const auto &weight_quantize =
252+ res.Op (paddle::dialect::WeightQuantizeOp::name (),
253+ {{" algo" , res.StrAttr (algo_)},
254+ {" arch" , res.Int32Attr (sm_version_)},
255+ {" group_size" , res.Int32Attr (-1 )}});
256+
257+ weight_quantize ({&res.Tensor (" w" )},
258+ {&res.Tensor (" quanted_weight_tensor" ),
259+ &res.Tensor (" weight_scale_tensor" )});
260+ }
204261 const auto &weight_only_linear =
205262 res.Op (paddle::dialect::WeightOnlyLinearOp::name (),
206263 {{" weight_dtype" ,
0 commit comments