Skip to content

Commit 04b33b7

Browse files
Jokerenfacebook-github-bot
authored andcommitted
Add byte_weight_dequant_op
Summary: Pull Request resolved: pytorch#9541 Reviewed By: hlu1 Differential Revision: D8882964 fbshipit-source-id: 06d2e0d227ea6a4a8dc5ef1ea9dd1d449c149b47
1 parent c1ee883 commit 04b33b7

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#include "caffe2/operators/byte_weight_dequant_op.h"
2+
3+
#include "caffe2/utils/math.h"
4+
5+
namespace caffe2 {
6+
7+
REGISTER_CPU_OPERATOR(ByteWeightDequant, ByteWeightDequantOp<CPUContext>);
8+
9+
OPERATOR_SCHEMA(ByteWeightDequant).NumInputs(1).NumOutputs(1);
10+
11+
} // namespace caffe2
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifndef CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_
2+
#define CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_
3+
4+
#include "caffe2/core/operator.h"
5+
#include "caffe2/utils/eigen_utils.h"
6+
#include "caffe2/utils/math.h"
7+
8+
namespace caffe2 {
9+
10+
template <typename Context>
11+
class ByteWeightDequantOp : public Operator<Context> {
12+
public:
13+
ByteWeightDequantOp(const OperatorDef& operator_def, Workspace* ws)
14+
: Operator<Context>(operator_def, ws),
15+
min_(OperatorBase::GetSingleArgument<float>("min", -3)),
16+
max_(OperatorBase::GetSingleArgument<float>("max", 3)),
17+
shape_(OperatorBase::GetRepeatedArgument<int>("shape")) {}
18+
19+
USE_OPERATOR_FUNCTIONS(Context);
20+
using Operator<Context>::Operator;
21+
22+
bool RunOnDevice() override {
23+
const auto& WI = Input(0);
24+
auto* Y = Output(0);
25+
Y->Resize(shape_);
26+
float bin_interval = (max_ - min_) / 255.0;
27+
int total = 1;
28+
for (int i = 0; i < shape_.size(); i++) {
29+
total *= Y->dim(i);
30+
}
31+
const uint8_t* Xdata;
32+
if (WI.template IsType<uint8_t>()) {
33+
CAFFE_ENFORCE(total, WI.nbytes());
34+
Xdata = WI.template data<uint8_t>();
35+
} else {
36+
CAFFE_ENFORCE(total, WI.template data<std::string>()[0].size());
37+
Xdata = reinterpret_cast<const uint8_t*>(
38+
WI.template data<std::string>()[0].c_str());
39+
}
40+
auto* Ydata = Y->template mutable_data<float>();
41+
ConstEigenVectorMap<uint8_t> index(&Xdata[0], total);
42+
EigenVectorMap<float> weights(&Ydata[0], total);
43+
weights = (index.cast<float>().array() * bin_interval) + min_;
44+
return true;
45+
}
46+
47+
private:
48+
float min_;
49+
float max_;
50+
std::vector<int> shape_;
51+
};
52+
53+
} // namespace caffe2
54+
55+
#endif // CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_

0 commit comments

Comments
 (0)