|
| 1 | +#ifndef CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_ |
| 2 | +#define CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_ |
| 3 | + |
| 4 | +#include "caffe2/core/operator.h" |
| 5 | +#include "caffe2/utils/eigen_utils.h" |
| 6 | +#include "caffe2/utils/math.h" |
| 7 | + |
| 8 | +namespace caffe2 { |
| 9 | + |
| 10 | +template <typename Context> |
| 11 | +class ByteWeightDequantOp : public Operator<Context> { |
| 12 | + public: |
| 13 | + ByteWeightDequantOp(const OperatorDef& operator_def, Workspace* ws) |
| 14 | + : Operator<Context>(operator_def, ws), |
| 15 | + min_(OperatorBase::GetSingleArgument<float>("min", -3)), |
| 16 | + max_(OperatorBase::GetSingleArgument<float>("max", 3)), |
| 17 | + shape_(OperatorBase::GetRepeatedArgument<int>("shape")) {} |
| 18 | + |
| 19 | + USE_OPERATOR_FUNCTIONS(Context); |
| 20 | + using Operator<Context>::Operator; |
| 21 | + |
| 22 | + bool RunOnDevice() override { |
| 23 | + const auto& WI = Input(0); |
| 24 | + auto* Y = Output(0); |
| 25 | + Y->Resize(shape_); |
| 26 | + float bin_interval = (max_ - min_) / 255.0; |
| 27 | + int total = 1; |
| 28 | + for (int i = 0; i < shape_.size(); i++) { |
| 29 | + total *= Y->dim(i); |
| 30 | + } |
| 31 | + const uint8_t* Xdata; |
| 32 | + if (WI.template IsType<uint8_t>()) { |
| 33 | + CAFFE_ENFORCE(total, WI.nbytes()); |
| 34 | + Xdata = WI.template data<uint8_t>(); |
| 35 | + } else { |
| 36 | + CAFFE_ENFORCE(total, WI.template data<std::string>()[0].size()); |
| 37 | + Xdata = reinterpret_cast<const uint8_t*>( |
| 38 | + WI.template data<std::string>()[0].c_str()); |
| 39 | + } |
| 40 | + auto* Ydata = Y->template mutable_data<float>(); |
| 41 | + ConstEigenVectorMap<uint8_t> index(&Xdata[0], total); |
| 42 | + EigenVectorMap<float> weights(&Ydata[0], total); |
| 43 | + weights = (index.cast<float>().array() * bin_interval) + min_; |
| 44 | + return true; |
| 45 | + } |
| 46 | + |
| 47 | + private: |
| 48 | + float min_; |
| 49 | + float max_; |
| 50 | + std::vector<int> shape_; |
| 51 | +}; |
| 52 | + |
| 53 | +} // namespace caffe2 |
| 54 | + |
| 55 | +#endif // CAFFE2_OPERATORS_BYTE_WEIGHT_DEQUANT_OP_H_ |
0 commit comments