
Commit 6489006

wwbitejotunn authored and committed
Wint8 gemm and gemv opt (#59291)
* fpAintB split-k * workspace * fix error * just_for_llama13b_bsz64-128 * llama13 opt * fix scale type of weight ony quant * draft gemv batched * accuracy fix * m size dispatch for gemv and gemm * fit dispatch * refine gemv * remove useless kernel * refine * fix bug for split-k-limit * fix bug for half scale * weight quant kernel fit for half scale * fix bf16 compile * fix sm70 autogen * fix sm70 compile error * fix code style * update * update * code-style * code-style * windows compile fix * code-style * fix merge bug --------- Co-authored-by: wwbitejotunn <wwbitejotunn@outlook.com>
1 parent 238d99d · commit 6489006

25 files changed: +2331, -638 lines

paddle/phi/infermeta/unary.cc

Lines changed: 2 additions & 2 deletions
@@ -5221,8 +5221,8 @@ void WeightQuantizeInferMeta(const MetaTensor& x,
 
   out->set_dtype(DataType::INT8);
 
-  scale->set_dims(common::make_ddim(dim_scale));
-  scale->set_dtype(DataType::FLOAT32);
+  scale->set_dims(phi::make_ddim(dim_scale));
+  scale->set_dtype(x.dtype());
 }
 
 void ChannelShuffleInferMeta(const MetaTensor& x,
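
The notable change: the scale tensor's dtype now follows the input weight (x.dtype()) instead of being pinned to FLOAT32, so float16/bfloat16 weight-only pipelines keep a single dtype end to end. A minimal sketch of the caller-visible contract (WeightOnlyQuantResult is a hypothetical name, not a Paddle type):

// Sketch only: illustrates the revised WeightQuantizeInferMeta contract.
#include <cstdint>

template <typename T>          // T = the weight element type (e.g. half)
struct WeightOnlyQuantResult {
  int8_t* qweight;             // quantized weights: dtype INT8 (unchanged)
  T*      scale;               // per-channel scales: now dtype T (was float*)
};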

paddle/phi/kernels/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -145,6 +145,9 @@ if(WITH_CUTLASS)
   )
 
   execute_process(
+    COMMAND
+      ${CMAKE_COMMAND} -E remove_directory
+      "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen"
     COMMAND
       ${CMAKE_COMMAND} -E make_directory
       "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/autogen"

The autogen directory is now wiped before being recreated, so generated fpA_intB kernel sources from an earlier build cannot go stale when the generator changes.

paddle/phi/kernels/cpu/weight_quantize_kernel.cc

Lines changed: 2 additions & 2 deletions
@@ -47,7 +47,7 @@ void quant_compute(const DeviceContext& dev_ctx,
   DDim dims = {num};
   const T* x_data = x.data<T>();
   D* out_data = out->data<D>();
-  float* scale_data = scale->data<float>();
+  T* scale_data = scale->data<T>();
 
   DenseTensor x_int(out->type());
 
@@ -108,7 +108,7 @@ void WeightQuantizeKernel(const Context& dev_ctx,
                           DenseTensor* out,
                           DenseTensor* scale) {
   dev_ctx.template Alloc<int8_t>(out);
-  dev_ctx.template Alloc<float>(scale);
+  dev_ctx.template Alloc<T>(scale);
   if (algo == "weight_only_int8" || algo == "llm.int8") {
     quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, algo, arch);
   } else if (algo == "weight_only_int4") {
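
With scale_data now typed T*, the per-channel math is unchanged except for the final store. A simplified reference of what the abs-max path boils down to (a sketch assuming row-major [n, k] weights; the real quant_compute also permutes the layout for the CUTLASS fpA_intB GEMM):

// Reference sketch, not the Paddle kernel: per-channel abs-max int8
// quantization with the scale stored in the weight's own dtype T.
#include <algorithm>
#include <cmath>
#include <cstdint>

template <typename T>
void quant_rows_ref(const T* x, int n, int k, int8_t* out, T* scale) {
  for (int row = 0; row < n; ++row) {
    float max_abs = 0.f;
    for (int i = 0; i < k; ++i) {
      max_abs =
          std::max(max_abs, std::fabs(static_cast<float>(x[row * k + i])));
    }
    float s = max_abs > 0.f ? max_abs / 127.f : 1.f;  // avoid divide-by-zero
    scale[row] = static_cast<T>(s);  // stored as T, matching Alloc<T> above
    for (int i = 0; i < k; ++i) {
      out[row * k + i] = static_cast<int8_t>(
          std::lround(static_cast<float>(x[row * k + i]) / s));
    }
  }
}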

paddle/phi/kernels/funcs/weight_dequant_functor.h

Lines changed: 6 additions & 6 deletions
@@ -120,7 +120,7 @@ struct FastWeightOnlyHalfConverter<__nv_bfloat16, 4> {
 
 template <typename T>
 __global__ void int8_weight_only_dequant(const uint8_t* weight,
-                                         const float* scale_list,
+                                         const T* scale_list,
                                          T* output,
                                          const int n,
                                          const int k) {

@@ -145,7 +145,7 @@ __global__ void int8_weight_only_dequant(const uint8_t* weight,
   int row_id = tile_id * 2 + ((lane_id % 8) > 3 ? 1 : 0);
   weight += tile_id * k * 2;
   output += row_id * k;
-  float scale = scale_list[row_id];
+  float scale = static_cast<float>(scale_list[row_id]);
 #pragma unroll
   for (int i = lane_id * 16; i < k * 2; i += 16 * 32) {
     Load<uint8_t, 16>(&weight[i], &vec_weight);

@@ -175,7 +175,7 @@ __global__ void int8_weight_only_dequant(const uint8_t* weight,
 
 template <typename T>
 __global__ void int4_weight_only_dequant(const uint8_t* weight,
-                                         const float* scale_list,
+                                         const T* scale_list,
                                          T* output,
                                          const int n,
                                          const int k) {

@@ -201,7 +201,7 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight,
   int row_id = tile_id * 4 + ((lane_id % 8) / 2);
   weight += tile_id * k / 2 * 4;
   output += row_id * k;
-  float scale = scale_list[row_id];
+  float scale = static_cast<float>(scale_list[row_id]);
 #pragma unroll
   for (int i = lane_id * 32; i < k * 4; i += 32 * 32) {
     Load<uint8_t, 16>(&weight[i / 2], &vec_weight);

@@ -249,15 +249,15 @@ void WeightDequantize(const Context& dev_ctx,
   if (algo == "weight_only_int8") {
     int8_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
         reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
-        scale.data<float>(),
+        reinterpret_cast<const DataType*>(scale.data<T>()),
         reinterpret_cast<DataType*>(out->data<T>()),
         n,
         k);
   } else if (algo == "weight_only_int4") {
     grid.x /= 2;
     int4_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
         reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
-        scale.data<float>(),
+        reinterpret_cast<const DataType*>(scale.data<T>()),
         reinterpret_cast<DataType*>(out->data<T>()),
         n,
         k);
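
Both kernels now receive the scale list in the compute dtype and widen each scale to float before multiplying, keeping the multiply itself in float as before. A stripped-down sketch of that pattern (one block per row; the real kernels use 16-byte vectorized loads and a tile-interleaved weight layout, so this is a reference for the dtype handling only):

// Naive CUDA sketch of the dequant pattern above, not the optimized kernel.
// T is e.g. __half or __nv_bfloat16; the T-typed scale is widened per row.
#include <cuda_fp16.h>

template <typename T>
__global__ void dequant_rows_naive(const int8_t* weight,  // [n, k] quantized
                                   const T* scale_list,   // [n] scales in T
                                   T* output,             // [n, k] result
                                   int k) {
  const int row = blockIdx.x;
  const float scale = static_cast<float>(scale_list[row]);  // same widening
  for (int i = threadIdx.x; i < k; i += blockDim.x) {
    output[row * k + i] =
        static_cast<T>(static_cast<float>(weight[row * k + i]) * scale);
  }
}
// Launch sketch: dequant_rows_naive<__half><<<n, 256, 0, stream>>>(w, s, out, k);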
