Skip to content

Commit 2d62336

Browse files
authored
Accuracy op (#3907)
* init add * add topk op * someupdate * fix style check * add test py file * update top k cuda kernel * follow comments * remove debug print * accuracy_op * fix casting error * fix casting error * fix casting error * fix rename bug... * make it smaller * update cast
1 parent b3f6b5a commit 2d62336

File tree

5 files changed

+238
-0
lines changed

5 files changed

+238
-0
lines changed

paddle/operators/accuracy_op.cc

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/accuracy_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class AccuracyOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
protected:
25+
void InferShape(const framework::InferShapeContext &ctx) const override {
26+
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"),
27+
"Input of Inference must be initialized.");
28+
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
29+
"Input of Inference must be initialized.");
30+
auto *inference = ctx.Input<framework::Tensor>("Inference");
31+
auto *label = ctx.Input<framework::Tensor>("Label");
32+
33+
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector");
34+
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0],
35+
"inference size must be the same as label size");
36+
37+
ctx.Output<Tensor>("Accuracy")->Resize({1});
38+
}
39+
};
40+
41+
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
42+
public:
43+
AccuracyOpMaker(framework::OpProto *proto,
44+
framework::OpAttrChecker *op_checker)
45+
: OpProtoAndCheckerMaker(proto, op_checker) {
46+
// TODO(typhoonzero): support both inference value and indices.
47+
AddInput("Inference", "topk(indices) the network output");
48+
AddInput("Label", "Label of the training data");
49+
// TODO(typhoonzero): AddInput("Weight", ...
50+
AddOutput("Accuracy", "The accuracy of current batch");
51+
52+
AddComment(
53+
R"DOC(Accuracy. It will print accuracy rate for classification.
54+
The accuracy is:
55+
.. math::
56+
accuracy = \\frac{NumOfCorrectPredicts}{NumOfAllSamples})DOC");
57+
}
58+
};
59+
60+
} // namespace operators
61+
} // namespace paddle
62+
63+
namespace ops = paddle::operators;
64+
REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker);
65+
REGISTER_OP_CPU_KERNEL(accuracy,
66+
ops::AccuracyKernel<paddle::platform::CPUPlace, float>);

paddle/operators/accuracy_op.cu

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/accuracy_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
__global__ void AccuracySingleKernel(const int N, const int D, const int top_k,
21+
const int* Xdata, const int* labelData,
22+
float* accuracy) {
23+
int correct = 0;
24+
for (int row = 0; row < N; row++) {
25+
const int label = labelData[row];
26+
for (int col = 0; col < D; col++) {
27+
const int pred = Xdata[row * D + col];
28+
if (pred == label) {
29+
++correct;
30+
break;
31+
}
32+
}
33+
}
34+
*accuracy = static_cast<float>(correct) / static_cast<float>(N);
35+
}
36+
37+
template <typename T>
38+
class AccuracyOpCUDAKernel : public framework::OpKernel {
39+
public:
40+
void Compute(const framework::ExecutionContext& ctx) const override {
41+
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
42+
"It must use GPUPlace.");
43+
auto* inference = ctx.Input<Tensor>("Inference");
44+
auto* label = ctx.Input<Tensor>("Label");
45+
auto* accuracy = ctx.Output<Tensor>("Accuracy");
46+
// FIXME(typhoonzero): only support indices currently
47+
// if add support for output values, how to detect the data type?
48+
const int* inference_data = inference->data<int>();
49+
const int* label_data = label->data<int>();
50+
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
51+
52+
size_t num_samples = inference->dims()[0];
53+
size_t infer_width = inference->dims()[1];
54+
cudaMemset((void**)&accuracy_data, 0, sizeof(float));
55+
56+
if (num_samples == 0) {
57+
return;
58+
}
59+
60+
AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data,
61+
label_data, accuracy_data);
62+
}
63+
};
64+
65+
} // namespace operators
66+
} // namespace paddle
67+
68+
REGISTER_OP_GPU_KERNEL(accuracy,
69+
paddle::operators::AccuracyOpCUDAKernel<float>);

paddle/operators/accuracy_op.h

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <algorithm>
17+
#include "paddle/framework/eigen.h"
18+
#include "paddle/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
using Tensor = framework::Tensor;
24+
25+
template <typename T, int MajorType = Eigen::RowMajor,
26+
typename IndexType = Eigen::DenseIndex>
27+
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
28+
29+
template <typename T, int MajorType = Eigen::RowMajor,
30+
typename IndexType = Eigen::DenseIndex>
31+
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
32+
33+
template <typename T, int MajorType = Eigen::RowMajor,
34+
typename IndexType = Eigen::DenseIndex>
35+
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
36+
37+
template <typename Place, typename T>
38+
class AccuracyKernel : public framework::OpKernel {
39+
public:
40+
void Compute(const framework::ExecutionContext& ctx) const override {
41+
auto* inference = ctx.Input<Tensor>("Inference");
42+
auto* label = ctx.Input<Tensor>("Label");
43+
auto* accuracy = ctx.Output<Tensor>("Accuracy");
44+
45+
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
46+
47+
const T* inference_data = inference->data<T>();
48+
const T* label_data = label->data<T>();
49+
50+
size_t num_samples = inference->dims()[0];
51+
size_t class_dim = inference->dims()[1];
52+
*accuracy_data = 0.0f;
53+
54+
if (num_samples == 0) {
55+
return;
56+
}
57+
58+
int num_correct = 0;
59+
// assume inference is already the topk of the output
60+
for (size_t i = 0; i < num_samples; ++i) {
61+
PADDLE_ENFORCE_GE(label_data[i], 0, "label must >= 0");
62+
for (size_t j = 0; j < class_dim; ++j) {
63+
if (inference_data[i * class_dim + j] == label_data[i]) {
64+
++num_correct;
65+
break;
66+
}
67+
}
68+
}
69+
70+
// FIXME(typhoonzero): we don't accumulate the accuracy for now.
71+
*accuracy_data =
72+
static_cast<float>(num_correct) / static_cast<float>(num_samples);
73+
}
74+
};
75+
76+
} // namespace operators
77+
} // namespace paddle

paddle/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ USE_OP(cos_sim);
5252
USE_CPU_ONLY_OP(gather);
5353
USE_OP(pad);
5454
USE_CPU_ONLY_OP(scatter);
55+
USE_OP(accuracy);
5556
USE_CPU_ONLY_OP(concat);
5657
USE_OP(top_k);
5758
USE_OP(squared_l2_distance);
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import unittest
2+
import numpy as np
3+
from op_test import OpTest
4+
5+
6+
class TestAccuracyOp(OpTest):
7+
def setUp(self):
8+
self.op_type = "accuracy"
9+
infer = np.random.randint(0, 2, (32, 1)).astype("int")
10+
label = np.random.randint(0, 2, (32, )).astype("int")
11+
self.inputs = {'Inference': infer, "Label": label}
12+
num_correct = 0
13+
for rowid in xrange(32):
14+
for ele in infer[rowid]:
15+
if ele == label[rowid]:
16+
num_correct += 1
17+
break
18+
self.outputs = {'Accuracy': [num_correct / 32.0]}
19+
20+
def test_check_output(self):
21+
self.check_output()
22+
23+
24+
if __name__ == '__main__':
25+
unittest.main()

0 commit comments

Comments
 (0)