Commit 8b4077e

update

1 parent e3f95cf commit 8b4077e

2 files changed: +79 -21 lines

paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc

Lines changed: 74 additions & 17 deletions
@@ -108,14 +108,43 @@ class FusedWeightOnlyLinearWithBiasPattern
     //
     paddle::drr::ResultPattern res = src.ResultPattern();
 
-    const auto &weight_quantize =
-        res.Op(paddle::dialect::WeightQuantizeOp::name(),
-               {{"algo", res.StrAttr(algo_)},
-                {"arch", res.Int32Attr(sm_version_)},
-                {"group_size", res.Int32Attr(-1)}});
-    weight_quantize({&res.Tensor("w")},
-                    {&res.Tensor("quanted_weight_tensor"),
-                     &res.Tensor("weight_scale_tensor")});
+    if (algo_ == "weight_only_int4") {
+      // TODO(liuyuanle): When the weight_quantize operator supports
+      // weight_only_int4 in the gpu version, remove the memory copies.
+      const auto &memcpy_d2h =
+          res.Op(paddle::dialect::MemcpyOp::name(),
+                 {{"dst_place_type", res.Int32Attr(0 /*cpu*/)}});
+      res.Tensor("w_cpu") = memcpy_d2h(res.Tensor("w"));
+      const auto &weight_quantize =
+          res.Op(paddle::dialect::WeightQuantizeOp::name(),
+                 {{"algo", res.StrAttr(algo_)},
+                  {"arch", res.Int32Attr(sm_version_)},
+                  {"group_size", res.Int32Attr(-1)}});
+      weight_quantize({&res.Tensor("w_cpu")},
+                      {&res.Tensor("quanted_weight_tensor_cpu"),
+                       &res.Tensor("weight_scale_tensor_cpu")});
+
+      const auto &memcpy_h2d_1 =
+          res.Op(paddle::dialect::MemcpyOp::name(),
+                 {{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
+      res.Tensor("quanted_weight_tensor") =
+          memcpy_h2d_1(res.Tensor("quanted_weight_tensor_cpu"));
+      const auto &memcpy_h2d_2 =
+          res.Op(paddle::dialect::MemcpyOp::name(),
+                 {{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
+      res.Tensor("weight_scale_tensor") =
+          memcpy_h2d_2(res.Tensor("weight_scale_tensor_cpu"));
+    } else {
+      const auto &weight_quantize =
+          res.Op(paddle::dialect::WeightQuantizeOp::name(),
+                 {{"algo", res.StrAttr(algo_)},
+                  {"arch", res.Int32Attr(sm_version_)},
+                  {"group_size", res.Int32Attr(-1)}});
+
+      weight_quantize({&res.Tensor("w")},
+                      {&res.Tensor("quanted_weight_tensor"),
+                       &res.Tensor("weight_scale_tensor")});
+    }
 
     const auto &weight_only_linear =
         res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
@@ -192,15 +221,43 @@ class FusedWeightOnlyLinearNoBiasPattern : public paddle::drr::DrrPatternBase {
     //
     paddle::drr::ResultPattern res = src.ResultPattern();
 
-    const auto &weight_quantize =
-        res.Op(paddle::dialect::WeightQuantizeOp::name(),
-               {{"algo", res.StrAttr(algo_)},
-                {"arch", res.Int32Attr(sm_version_)},
-                {"group_size", res.Int32Attr(-1)}});
-    weight_quantize({&res.Tensor("w")},
-                    {&res.Tensor("quanted_weight_tensor"),
-                     &res.Tensor("weight_scale_tensor")});
-
+    if (algo_ == "weight_only_int4") {
+      // TODO(liuyuanle): When the weight_quantize operator supports
+      // weight_only_int4 in the gpu version, remove the memory copies.
+      const auto &memcpy_d2h =
+          res.Op(paddle::dialect::MemcpyOp::name(),
+                 {{"dst_place_type", res.Int32Attr(0 /*cpu*/)}});
+      res.Tensor("w_cpu") = memcpy_d2h(res.Tensor("w"));
+      const auto &weight_quantize =
+          res.Op(paddle::dialect::WeightQuantizeOp::name(),
+                 {{"algo", res.StrAttr(algo_)},
+                  {"arch", res.Int32Attr(sm_version_)},
+                  {"group_size", res.Int32Attr(-1)}});
+      weight_quantize({&res.Tensor("w_cpu")},
+                      {&res.Tensor("quanted_weight_tensor_cpu"),
+                       &res.Tensor("weight_scale_tensor_cpu")});
+
+      const auto &memcpy_h2d_1 =
+          res.Op(paddle::dialect::MemcpyOp::name(),
+                 {{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
+      res.Tensor("quanted_weight_tensor") =
+          memcpy_h2d_1(res.Tensor("quanted_weight_tensor_cpu"));
+      const auto &memcpy_h2d_2 =
+          res.Op(paddle::dialect::MemcpyOp::name(),
+                 {{"dst_place_type", res.Int32Attr(1 /*gpu*/)}});
+      res.Tensor("weight_scale_tensor") =
+          memcpy_h2d_2(res.Tensor("weight_scale_tensor_cpu"));
+    } else {
+      const auto &weight_quantize =
+          res.Op(paddle::dialect::WeightQuantizeOp::name(),
+                 {{"algo", res.StrAttr(algo_)},
+                  {"arch", res.Int32Attr(sm_version_)},
+                  {"group_size", res.Int32Attr(-1)}});
+
+      weight_quantize({&res.Tensor("w")},
+                      {&res.Tensor("quanted_weight_tensor"),
+                       &res.Tensor("weight_scale_tensor")});
+    }
     const auto &weight_only_linear =
         res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
                {{"weight_dtype",

paddle/phi/kernels/gpu/weight_quantize_kernel.cu

Lines changed: 5 additions & 4 deletions
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/common/enforce.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/datatype_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
@@ -72,14 +73,14 @@ void WeightQuantizeKernel(const Context& dev_ctx,
                          weight_shape,
                          arch);
   } else if (algo == "weight_only_int4") {
-    phi::errors::Unimplemented(
+    PADDLE_FATAL(phi::errors::Unimplemented(
         "Weight quant gpu kernel currently don't support weight_only_int4 "
-        "algo, please use cpu version.");
+        "algo, please use cpu version."));
   } else {
-    phi::errors::Unimplemented(
+    PADDLE_FATAL(phi::errors::Unimplemented(
         "The algo must be in ['weight_only_int8', 'weight_only_int4', "
         "'llm.int8'], but got[%s]",
-        algo);
+        algo));
   }
 }
 }  // namespace phi
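
The kernel change is behavioral, not cosmetic: phi::errors::Unimplemented(...) only constructs a status object, so the old code built the error and silently discarded it, letting unsupported algos fall through without any diagnostic. Wrapping the status in PADDLE_FATAL actually raises it, which is evidently why the new #include "paddle/common/enforce.h" is added at the top of the file. Below is a self-contained analogue of the bug and the fix; Status, Unimplemented, and FATAL are simplified stand-ins, not Paddle's real definitions.

#include <cstdio>
#include <cstdlib>
#include <string>

// Stand-in for a status object: constructing one has no side effects.
struct Status { std::string msg; };

Status Unimplemented(const std::string& msg) { return Status{msg}; }

// Stand-in for PADDLE_FATAL: consumes the status, reports it, and aborts.
#define FATAL(status)                                    \
  do {                                                   \
    std::fprintf(stderr, "%s\n", (status).msg.c_str());  \
    std::abort();                                        \
  } while (0)

void BuggyCheck(const std::string& algo) {
  if (algo == "weight_only_int4") {
    // Bug: the status is built and immediately discarded; execution continues.
    Unimplemented("weight_only_int4 is not supported by the gpu kernel");
  }
}

void FixedCheck(const std::string& algo) {
  if (algo == "weight_only_int4") {
    // Fix: the fatal macro consumes the status, so the error surfaces.
    FATAL(Unimplemented("weight_only_int4 is not supported by the gpu kernel"));
  }
}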
