Skip to content

Commit 7a12ccd

Browse files
jspark1105 authored and facebook-github-bot committed
optimize FloatToFused8BitRowwiseQuantized and Fused8BitRowwiseQuantizedToFloat (pytorch#31470)
Summary: Pull Request resolved: pytorch#31470 Optimize performance of these two operators. Additionally use nearbyint instead of round to be consistent with 4-bit embedding table quantization. Reviewed By: hyuen Differential Revision: D19072103 fbshipit-source-id: efe96f14aeff7958cceb453ed625d3fd693891ff
1 parent 0b57b38 commit 7a12ccd

File tree

6 files changed

+364
-54
lines changed

6 files changed

+364
-54
lines changed

caffe2/operators/fused_rowwise_8bit_conversion_ops.cc

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
namespace caffe2 {
55

66
namespace {
7-
void convertfp32fp32(float* dst, const float* src, size_t N) {
8-
memcpy(dst, src, sizeof(float) * N);
9-
}
10-
117
void convertfp16fp32(float* dst, const at::Half* src, size_t N) {
128
for (size_t i = 0; i < N; i++) {
139
dst[i] = src[i];
@@ -23,7 +19,11 @@ void convertfp32fp16(at::Half* dst, const float* src, size_t N) {
2319

2420
REGISTER_CPU_OPERATOR(
2521
FloatToFused8BitRowwiseQuantized,
26-
FloatToFused8BitRowwiseQuantizedOp<float, convertfp32fp32, CPUContext>);
22+
FloatToFused8BitRowwiseQuantizedOp<
23+
float,
24+
nullptr,
25+
false, /* HAS_CONVERT */
26+
CPUContext>);
2727
OPERATOR_SCHEMA(FloatToFused8BitRowwiseQuantized)
2828
.NumInputs(1)
2929
.NumOutputs(1)
@@ -52,7 +52,11 @@ NO_GRADIENT(FloatToFused8BitRowwiseQuantized);
5252

5353
REGISTER_CPU_OPERATOR(
5454
HalfFloatToFused8BitRowwiseQuantized,
55-
FloatToFused8BitRowwiseQuantizedOp<at::Half, convertfp16fp32, CPUContext>);
55+
FloatToFused8BitRowwiseQuantizedOp<
56+
at::Half,
57+
convertfp16fp32,
58+
true, /* HAS_CONVERT*/
59+
CPUContext>);
5660
OPERATOR_SCHEMA(HalfFloatToFused8BitRowwiseQuantized)
5761
.NumInputs(1)
5862
.NumOutputs(1)
@@ -81,7 +85,11 @@ NO_GRADIENT(HalfFloatToFused8BitRowwiseQuantized);
8185

8286
REGISTER_CPU_OPERATOR(
8387
Fused8BitRowwiseQuantizedToFloat,
84-
Fused8BitRowwiseQuantizedToFloatOp<float, convertfp32fp32, CPUContext>);
88+
Fused8BitRowwiseQuantizedToFloatOp<
89+
float,
90+
nullptr,
91+
false, /* HAS_CONVERT */
92+
CPUContext>);
8593
OPERATOR_SCHEMA(Fused8BitRowwiseQuantizedToFloat)
8694
.NumInputs(1)
8795
.NumOutputs(1)
@@ -114,7 +122,11 @@ NO_GRADIENT(Fused8BitRowwiseQuantizedToFloat);
114122

115123
REGISTER_CPU_OPERATOR(
116124
Fused8BitRowwiseQuantizedToHalfFloat,
117-
Fused8BitRowwiseQuantizedToFloatOp<at::Half, convertfp32fp16, CPUContext>);
125+
Fused8BitRowwiseQuantizedToFloatOp<
126+
at::Half,
127+
convertfp32fp16,
128+
true, /* HAS_CONVERT */
129+
CPUContext>);
118130
OPERATOR_SCHEMA(Fused8BitRowwiseQuantizedToHalfFloat)
119131
.NumInputs(1)
120132
.NumOutputs(1)
@@ -152,7 +164,8 @@ NO_GRADIENT(Fused8BitRowwiseQuantizedToHalfFloat);
152164
using Fused8BitRowwiseQuantizedToFloatCPUOp =
153165
caffe2::Fused8BitRowwiseQuantizedToFloatOp<
154166
float,
155-
caffe2::convertfp32fp32,
167+
nullptr,
168+
false, /* HAS_CONVERT */
156169
caffe2::CPUContext>;
157170

158171
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(

caffe2/operators/fused_rowwise_8bit_conversion_ops.h

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,23 @@
66
#include "caffe2/core/logging.h"
77
#include "caffe2/core/operator.h"
88
#include "caffe2/operators/reducer_functors.h"
9-
#include "caffe2/utils/eigen_utils.h"
9+
#include "caffe2/perfkernels/fused_8bit_rowwise_conversion.h"
1010
#include "caffe2/utils/math.h"
1111

1212
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(Fused8BitRowwiseQuantizedToFloat);
1313

1414
namespace caffe2 {
1515

16-
#define IS_LITTLE_ENDIAN \
17-
[] { \
18-
const int32_t kValue = 1; \
19-
return reinterpret_cast<const uint8_t*>(&kValue)[0] == 1; \
16+
#define IS_LITTLE_ENDIAN \
17+
[] { \
18+
const int32_t kValue = 1; \
19+
return reinterpret_cast<const std::uint8_t*>(&kValue)[0] == 1; \
2020
}()
2121

2222
template <
2323
typename T,
2424
void (*convert)(float* dst, const T* src, size_t N),
25+
bool HAS_CONVERT,
2526
class Context>
2627
class FloatToFused8BitRowwiseQuantizedOp : public Operator<Context> {
2728
public:
@@ -45,40 +46,38 @@ class FloatToFused8BitRowwiseQuantizedOp : public Operator<Context> {
4546
// bytes of each row for scale (4 bytes) and bias (4 bytes).
4647
// | ... int8 data ... | scale | bias |
4748
// | number_of_columns | 4B | 4B |
48-
const std::vector<int64_t> output_dimensions = {input_rows,
49-
input_columns + 8};
49+
const std::vector<std::int64_t> output_dimensions = {
50+
input_rows,
51+
input_columns + static_cast<std::int64_t>(2 * sizeof(float))};
5052
auto* output = Output(
51-
DATA_FUSED_SCALE_BIAS_INT8, output_dimensions, at::dtype<uint8_t>());
53+
DATA_FUSED_SCALE_BIAS_INT8,
54+
output_dimensions,
55+
at::dtype<std::uint8_t>());
5256

5357
const auto* input_data = input.template data<T>();
54-
auto* output_data = output->template mutable_data<uint8_t>();
58+
auto* output_data = output->template mutable_data<std::uint8_t>();
5559
const auto output_columns = output->size(1);
5660

57-
if (!std::is_same<T, float>::value && !std::is_same<T, at::Half>::value) {
58-
CAFFE_THROW("Unsupported data type");
61+
bool is_float = std::is_same<T, float>::value;
62+
if (!HAS_CONVERT) {
63+
CAFFE_ENFORCE(is_float, "convert can be nullptr only if T is float");
64+
FloatToFused8BitRowwiseQuantized(
65+
reinterpret_cast<const float*>(input_data),
66+
input_rows,
67+
input_columns,
68+
output_data);
69+
return true;
5970
}
6071

61-
vector<float> tmp;
62-
tmp.resize(input_columns, 0.0);
72+
bool is_half = std::is_same<T, at::Half>::value;
73+
CAFFE_ENFORCE(is_float || is_half);
74+
75+
vector<float> tmp(input_columns);
6376

6477
for (size_t row = 0; row < input_rows; ++row) {
6578
convert(tmp.data(), input_data + row * input_columns, input_columns);
66-
ConstEigenVectorArrayMap<float> input_row(tmp.data(), input_columns);
67-
uint8_t* output_row = output_data + row * output_columns;
68-
EigenVectorArrayMap<uint8_t> output_row_values(output_row, input_columns);
69-
EigenVectorArrayMap<float> output_row_scale_bias(
70-
reinterpret_cast<float*>(output_row + input_columns), 2);
71-
72-
const float minimum_element = input_row.minCoeff();
73-
const float maximum_element = input_row.maxCoeff();
74-
const float range = maximum_element - minimum_element;
75-
76-
output_row_scale_bias(0) = range / 255.0f;
77-
output_row_scale_bias(1) = minimum_element;
78-
const auto inverse_scale = 255.0f / (range + kEpsilon);
79-
output_row_values = ((input_row - minimum_element) * inverse_scale)
80-
.round()
81-
.cast<uint8_t>();
79+
FloatToFused8BitRowwiseQuantized(
80+
tmp.data(), 1, input_columns, output_data + row * output_columns);
8281
}
8382

8483
return true;
@@ -92,6 +91,7 @@ class FloatToFused8BitRowwiseQuantizedOp : public Operator<Context> {
9291
template <
9392
typename T,
9493
void (*convert)(T* dst, const float* src, size_t N),
94+
bool HAS_CONVERT,
9595
class Context>
9696
class Fused8BitRowwiseQuantizedToFloatOp : public Operator<Context> {
9797
public:
@@ -109,28 +109,35 @@ class Fused8BitRowwiseQuantizedToFloatOp : public Operator<Context> {
109109

110110
// The last 8 bytes per row are the scale and the bias. The rest of
111111
// input_columns is the number of values in the original row.
112-
const std::vector<int64_t> output_dimensions = {input_rows,
113-
input_columns - 8};
112+
const std::vector<std::int64_t> output_dimensions = {
113+
input_rows,
114+
input_columns - static_cast<std::int64_t>(2 * sizeof(float))};
114115
auto* output = Output(DATA_FLOAT, output_dimensions, at::dtype<T>());
115116
const auto output_columns = output->size(1);
116117

117-
const auto* input_data = input.template data<uint8_t>();
118+
const auto* input_data = input.template data<std::uint8_t>();
118119
T* output_data = output->template mutable_data<T>();
119120

120-
vector<float> tmp;
121-
tmp.resize(input_columns, 0.0);
121+
bool is_float = std::is_same<T, float>::value;
122122

123-
for (size_t row = 0; row < input_rows; ++row) {
124-
const uint8_t* input_row = input_data + row * input_columns;
125-
ConstEigenVectorArrayMap<uint8_t> input_row_values(
126-
input_row, output_columns);
127-
ConstEigenVectorArrayMap<float> input_row_scale_bias(
128-
reinterpret_cast<const float*>(input_row + output_columns), 2);
123+
if (!HAS_CONVERT) {
124+
CAFFE_ENFORCE(is_float, "convert can be nullptr only if T is float");
125+
Fused8BitRowwiseQuantizedToFloat(
126+
input_data,
127+
input_rows,
128+
input_columns,
129+
reinterpret_cast<float*>(output_data));
130+
return true;
131+
}
129132

130-
EigenVectorArrayMap<float> output_row(tmp.data(), output_columns);
131-
output_row = input_row_values.cast<float>() * input_row_scale_bias(0) +
132-
input_row_scale_bias(1);
133+
bool is_half = std::is_same<T, at::Half>::value;
134+
CAFFE_ENFORCE(is_float || is_half);
133135

136+
vector<float> tmp(input_columns);
137+
138+
for (size_t row = 0; row < input_rows; ++row) {
139+
Fused8BitRowwiseQuantizedToFloat(
140+
input_data + row * input_columns, 1, input_columns, tmp.data());
134141
convert(output_data + row * output_columns, tmp.data(), output_columns);
135142
}
136143
return true;
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#include "fused_8bit_rowwise_conversion.h"
2+
3+
#include <algorithm>
4+
#include <cmath>
5+
6+
#include "common.h"
7+
8+
namespace caffe2 {
9+
10+
// Scalar reference kernel: quantizes each row of a (input_rows x
// input_columns) float matrix to uint8 with a fused per-row scale and bias.
//
// Output row layout: | input_columns x uint8 | scale (4B float) | bias (4B float) |
// so `output` must hold input_rows * (input_columns + 8) bytes.
void FloatToFused8BitRowwiseQuantized__base(
    const float* input,
    int input_rows,
    int input_columns,
    std::uint8_t* output) {
  // Keeps the inverse scale finite when a row is constant (range == 0).
  constexpr float kEpsilon = 1e-8f;

  const int output_columns =
      input_columns + 2 * static_cast<int>(sizeof(float));
  for (int row = 0; row < input_rows; ++row) {
    const float* input_row =
        input + static_cast<std::size_t>(row) * input_columns;
    std::uint8_t* output_row =
        output + static_cast<std::size_t>(row) * output_columns;
    float* output_row_scale_bias =
        reinterpret_cast<float*>(output_row + input_columns);

    if (input_columns == 0) {
      // Guard: min/max_element on an empty range is undefined behavior.
      output_row_scale_bias[0] = 0.0f;
      output_row_scale_bias[1] = 0.0f;
      continue;
    }

    const float minimum_element =
        *std::min_element(input_row, input_row + input_columns);
    const float maximum_element =
        *std::max_element(input_row, input_row + input_columns);
    const float range = maximum_element - minimum_element;

    output_row_scale_bias[0] = range / 255.0f;
    output_row_scale_bias[1] = minimum_element;
    const float inverse_scale = 255.0f / (range + kEpsilon);
    for (int col = 0; col < input_columns; ++col) {
      // lrintf rounds to nearest under the current rounding mode, matching the
      // nearbyint semantics used by the 4-bit embedding-table quantization.
      output_row[col] = static_cast<std::uint8_t>(
          std::lrintf((input_row[col] - minimum_element) * inverse_scale));
    }
  }
}
39+
40+
// Scalar reference kernel: dequantizes rows produced by
// FloatToFused8BitRowwiseQuantized. The last 8 bytes of each input row are the
// float scale and bias; the remaining bytes are the quantized values.
// `output` must hold input_rows * (input_columns - 8) floats.
void Fused8BitRowwiseQuantizedToFloat__base(
    const std::uint8_t* input,
    int input_rows,
    int input_columns,
    float* output) {
  // Signed on purpose: with size_t indices a malformed input_columns < 8 made
  // output_columns negative and the signed/unsigned comparison converted it to
  // a huge unsigned bound, driving the column loop far out of bounds. With int
  // indices the loop simply does not execute.
  const int output_columns =
      input_columns - 2 * static_cast<int>(sizeof(float));

  for (int row = 0; row < input_rows; ++row) {
    const std::uint8_t* input_row =
        input + static_cast<std::size_t>(row) * input_columns;
    const float* input_row_scale_bias =
        reinterpret_cast<const float*>(input_row + output_columns);
    float* output_row = output + static_cast<std::size_t>(row) * output_columns;

    for (int col = 0; col < output_columns; ++col) {
      // Dequantize: x = q * scale + bias.
      output_row[col] =
          input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
    }
  }
}
59+
60+
decltype(FloatToFused8BitRowwiseQuantized__base)
61+
FloatToFused8BitRowwiseQuantized__avx2_fma;
62+
void FloatToFused8BitRowwiseQuantized(
63+
const float* input,
64+
int input_rows,
65+
int input_columns,
66+
std::uint8_t* output) {
67+
AVX2_FMA_DO(
68+
FloatToFused8BitRowwiseQuantized,
69+
input,
70+
input_rows,
71+
input_columns,
72+
output);
73+
BASE_DO(
74+
FloatToFused8BitRowwiseQuantized,
75+
input,
76+
input_rows,
77+
input_columns,
78+
output);
79+
}
80+
81+
decltype(Fused8BitRowwiseQuantizedToFloat__base)
82+
Fused8BitRowwiseQuantizedToFloat__avx2_fma;
83+
void Fused8BitRowwiseQuantizedToFloat(
84+
const std::uint8_t* input,
85+
int input_rows,
86+
int input_columns,
87+
float* output) {
88+
AVX2_FMA_DO(
89+
Fused8BitRowwiseQuantizedToFloat,
90+
input,
91+
input_rows,
92+
input_columns,
93+
output);
94+
BASE_DO(
95+
Fused8BitRowwiseQuantizedToFloat,
96+
input,
97+
input_rows,
98+
input_columns,
99+
output);
100+
}
101+
102+
} // namespace caffe2
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#pragma once

#include <cstdint>

namespace caffe2 {

// Rowwise 8-bit quantization with fused per-row scale and bias.
//
// Quantized row layout: | columns x uint8 | scale (4B float) | bias (4B float) |
// so each quantized row occupies the original column count plus 8 bytes.

// Quantizes each row of an (input_rows x input_columns) float matrix to uint8
// and appends that row's float scale and bias. `output` must hold
// input_rows * (input_columns + 8) bytes.
void FloatToFused8BitRowwiseQuantized(
    const float* input,
    int input_rows,
    int input_columns,
    std::uint8_t* output);

// Inverse operation: `input_columns` here counts the fused row bytes (values
// plus the trailing 8-byte scale/bias), and `output` must hold
// input_rows * (input_columns - 8) floats.
void Fused8BitRowwiseQuantizedToFloat(
    const std::uint8_t* input,
    int input_rows,
    int input_columns,
    float* output);

} // namespace caffe2

0 commit comments

Comments
 (0)