Skip to content

Commit 223ae54

Browse files
Add a graph transform utility for hexagon
Change: 147385554
1 parent bad13d7 commit 223ae54

File tree

7 files changed

+300
-59
lines changed

7 files changed

+300
-59
lines changed

tensorflow/contrib/makefile/build_all_android.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ if [[ "${DOWNLOAD_AND_USE_HEXAGON}" == "true" ]]; then
9090
fi
9191

9292
if [[ ! -z "${HEXAGON_LIB_PATH}" ]]; then
93-
echo "Copy hexagon libraries"
93+
echo "Copy hexagon libraries from ${HEXAGON_LIB_PATH}"
9494

9595
mkdir -p "${HEXAGON_DOWNLOAD_PATH}/libs"
9696
cp -fv "${HEXAGON_LIB_PATH}/libhexagon_controller.so" \

tensorflow/core/kernels/hexagon/BUILD

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ tf_cc_test(
5050
"graph_transferer_test.cc",
5151
"hexagon_graph_execution_test.cc",
5252
],
53+
data = ["//tensorflow/core:example_parser_configuration_testdata"],
5354
deps = [
5455
":graph_transferer",
5556
"//tensorflow/cc:cc_ops",
@@ -87,7 +88,6 @@ tf_kernel_library(
8788
"i_graph_transfer_ops_definitions.h",
8889
"i_soc_control_wrapper.h",
8990
],
90-
data = ["//tensorflow/core:example_parser_configuration_testdata"],
9191
deps = [
9292
"//tensorflow/cc:cc_ops",
9393
"//tensorflow/cc:remote_fused_graph_ops",
@@ -99,3 +99,40 @@ tf_kernel_library(
9999
"//third_party/eigen3",
100100
],
101101
)
102+
103+
cc_library(
104+
name = "hexagon_rewriter_transform",
105+
srcs = [
106+
"hexagon_rewriter_transform.cc",
107+
],
108+
deps = [
109+
":graph_transferer",
110+
"//tensorflow/cc:cc_ops",
111+
"//tensorflow/cc:remote_fused_graph_ops",
112+
"//tensorflow/cc:scope",
113+
"//tensorflow/core",
114+
"//tensorflow/core:core_cpu",
115+
"//tensorflow/core:framework",
116+
"//tensorflow/core:lib",
117+
"//tensorflow/tools/graph_transforms:transform_utils",
118+
"//third_party/eigen3",
119+
],
120+
alwayslink = 1,
121+
)
122+
123+
tf_cc_test(
124+
name = "hexagon_rewriter_transform_test",
125+
size = "small",
126+
srcs = ["hexagon_rewriter_transform_test.cc"],
127+
deps = [
128+
":hexagon_rewriter_transform",
129+
"//tensorflow/cc:cc_ops",
130+
"//tensorflow/core:core_cpu",
131+
"//tensorflow/core:core_cpu_internal",
132+
"//tensorflow/core:framework",
133+
"//tensorflow/core:test",
134+
"//tensorflow/core:test_main",
135+
"//tensorflow/core:testlib",
136+
"//tensorflow/tools/graph_transforms:transform_utils",
137+
],
138+
)

tensorflow/core/kernels/hexagon/graph_transfer_utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ GraphTransferUtils::GetTopNFloatResults(const float* const data,
5858
CHECK(gt != nullptr);
5959
GraphTransferer::OutputTensorInfo output_tensor_info;
6060
Status status = gt->DryRunInferenceForAllNode(
61-
def, inputs, false /* initialize_by_zero */, &output_tensor_info);
61+
def, inputs, true /* initialize_by_zero */, &output_tensor_info);
6262
CHECK(status.ok());
6363
status = gt->LoadGraphFromProto(ops_definitions, def, inputs, outputs, false,
6464
output_tensor_info.output_tensor_map);

tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc

Lines changed: 87 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
15-
// Before calling this test program, download a model as follows.
16-
// $ curl https://storage.googleapis.com/download.tensorflow.org/models/tensorflow_inception_v3_stripped_optimized_quantized.pb \
17-
// -o /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb
18-
// adb push /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
19-
// /data/local/tmp
20-
// $ curl
21-
// https://storage.googleapis.com/download.tensorflow.org/models/imagenet_comp_graph_label_strings.txt
22-
// -o /tmp/imagenet_comp_graph_label_strings.txt
23-
// adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
15+
/* Before calling this test program, download a model as follows.
16+
$ curl
17+
https://storage.googleapis.com/download.tensorflow.org/models/tensorflow_inception_v3_stripped_optimized_quantized.pb
18+
\ -o /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb
19+
$ adb push /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
20+
/data/local/tmp
21+
$ curl
22+
https://storage.googleapis.com/download.tensorflow.org/models/imagenet_comp_graph_label_strings.txt
23+
-o /tmp/imagenet_comp_graph_label_strings.txt
24+
adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
25+
*/
2426

2527
#include <memory>
2628

@@ -49,15 +51,26 @@ using ConstByteArray = ISocControlWrapper::ConstByteArray;
4951
constexpr const char* const IMAGE_FILENAME = "/data/local/tmp/img_299x299.bmp";
5052
constexpr const char* const MODEL_FILENAME =
5153
"/data/local/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb";
54+
constexpr const char* const FUSED_MODEL_FILENAME =
55+
"/data/local/tmp/"
56+
"tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb";
57+
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME =
58+
"remote_fused_graph_execute_node";
5259

53-
const bool USE_TF_RUNTIME = true;
5460
const bool DBG_DUMP_FLOAT_DATA = false;
5561
const int WIDTH = 299;
5662
const int HEIGHT = 299;
5763
const int DEPTH = 3;
5864
const int EXPECTED_FIRST_RESULT_ID = 59;
5965
const int EXECUTION_REPEAT_COUNT = 3;
6066

67+
static void CheckHexagonControllerVersion() {
68+
HexagonControlWrapper hexagon_control_wrapper;
69+
const int version = hexagon_control_wrapper.GetVersion();
70+
ASSERT_GE(version, 1);
71+
LOG(INFO) << "Hexagon controller version is " << version;
72+
}
73+
6174
static void DumpTop10Results(const int byte_size,
6275
const float* const float_array) {
6376
const int element_count = byte_size / sizeof(float);
@@ -159,9 +172,6 @@ static void RunInferenceByHexagonControlWrapper(
159172
img_floats.size() * sizeof(float), DT_FLOAT);
160173

161174
HexagonControlWrapper hexagon_control_wrapper;
162-
const int version = hexagon_control_wrapper.GetVersion();
163-
ASSERT_GE(version, 1);
164-
LOG(INFO) << "Hexagon controller version is " << version;
165175
// 1. Initialize hexagon
166176
hexagon_control_wrapper.Init();
167177

@@ -196,13 +206,61 @@ static void RunInferenceByHexagonControlWrapper(
196206
hexagon_control_wrapper.Finalize();
197207
}
198208

209+
static void RunFusedGraph(const GraphDef& fused_graph_def) {
210+
// Setup input tensor
211+
std::vector<float> img_floats;
212+
LoadImage(&img_floats);
213+
214+
LOG(INFO) << "Ioading image finished.";
215+
Tensor img_tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH});
216+
ASSERT_EQ(WIDTH * HEIGHT * DEPTH, img_floats.size());
217+
ASSERT_EQ(img_tensor.TotalBytes(), img_floats.size() * sizeof(float));
218+
219+
LOG(INFO) << "Copy data to tensor.";
220+
std::memcpy(img_tensor.flat<float>().data(), img_floats.data(),
221+
img_tensor.TotalBytes());
222+
223+
// Setup session
224+
std::vector<Tensor> output_tensors;
225+
SessionOptions session_options;
226+
session_options.env = Env::Default();
227+
std::unique_ptr<Session> session =
228+
std::unique_ptr<Session>(NewSession(session_options));
229+
Status status = session->Create(fused_graph_def);
230+
ASSERT_TRUE(status.ok());
231+
232+
// Setup session arguments
233+
RunOptions run_options;
234+
run_options.set_trace_level(RunOptions::FULL_TRACE);
235+
RunMetadata run_metadata;
236+
237+
std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
238+
input_tensors.emplace_back("Mul", img_tensor);
239+
std::vector<string> output_node_names;
240+
output_node_names.emplace_back(REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME);
241+
242+
LOG(INFO) << "Run graph";
243+
// Run inference with all node as output
244+
status = session->Run(run_options, input_tensors, output_node_names, {},
245+
&output_tensors, &run_metadata);
246+
ASSERT_TRUE(status.ok());
247+
ASSERT_EQ(1, output_tensors.size());
248+
const Tensor& output_tensor = output_tensors.at(0);
249+
LOG(INFO) << "Output byte size = " << output_tensor.TotalBytes();
250+
LOG(INFO) << "Output shape = " << output_tensor.shape().DebugString();
251+
DumpTop10Results(output_tensor.TotalBytes(),
252+
output_tensor.flat<float>().data());
253+
}
254+
199255
// CAVEAT: This test only runs when you specify hexagon library using
200256
// makefile.
201257
// TODO(satok): Make this generic so that this can run without any
202258
// additional steps.
203259
#ifdef USE_HEXAGON_LIBS
204260
TEST(GraphTransferer, RunInceptionV3OnHexagonExample) {
205-
if (USE_TF_RUNTIME) return;
261+
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
262+
CheckHexagonControllerVersion();
263+
206264
const IGraphTransferOpsDefinitions* ops_definitions =
207265
&HexagonOpsDefinitions::getInstance();
208266
std::vector<GraphTransferer::InputNodeInfo> input_node_info_list = {
@@ -226,31 +284,22 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExample) {
226284
}
227285

228286
TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
229-
if (!USE_TF_RUNTIME) return;
287+
LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime";
288+
CheckHexagonControllerVersion();
289+
230290
const IGraphTransferOpsDefinitions* ops_definitions =
231291
&HexagonOpsDefinitions::getInstance();
232292
std::vector<GraphTransferer::InputNodeInfo> inputs = {
233293
GraphTransferer::InputNodeInfo{
234294
"Mul", Tensor{DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}}}};
235295
std::vector<string> outputs = {"softmax"};
236-
const bool is_text_proto = false;
237296

238297
std::vector<float> img_floats;
239298
LoadImage(&img_floats);
240299

241300
LOG(INFO) << "Ioading image finished.";
242301

243-
Tensor img_tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH});
244-
ASSERT_EQ(WIDTH * HEIGHT * DEPTH, img_floats.size());
245-
ASSERT_EQ(img_tensor.TotalBytes(), img_floats.size() * sizeof(float));
246-
247-
LOG(INFO) << "Copy data to tensor.";
248-
249-
std::memcpy(img_tensor.flat<float>().data(), img_floats.data(),
250-
img_tensor.TotalBytes());
251-
252302
GraphDef graph_def;
253-
254303
Status status = ReadBinaryProto(Env::Default(), MODEL_FILENAME, &graph_def);
255304

256305
ASSERT_TRUE(status.ok());
@@ -259,40 +308,22 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
259308
GraphTransferer gt;
260309
gt.EnableStrictCheckMode(false);
261310
GraphDef fused_graph_def = GraphTransferUtils::BuildFusedGraphDef(
262-
HexagonOpsDefinitions::getInstance(), "remote_fused_graph_execute_node",
263-
inputs, outputs, graph_def, &gt);
264-
265-
// Setup session
266-
std::vector<Tensor> output_tensors;
267-
SessionOptions session_options;
268-
session_options.env = Env::Default();
269-
std::unique_ptr<Session> session =
270-
std::unique_ptr<Session>(NewSession(session_options));
271-
status = session->Create(fused_graph_def);
272-
ASSERT_TRUE(status.ok());
311+
HexagonOpsDefinitions::getInstance(),
312+
REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME, inputs, outputs, graph_def, &gt);
273313

274-
// Setup session arguments
275-
RunOptions run_options;
276-
run_options.set_trace_level(RunOptions::FULL_TRACE);
277-
RunMetadata run_metadata;
314+
RunFusedGraph(fused_graph_def);
315+
}
278316

279-
std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
280-
input_tensors.emplace_back("Mul", img_tensor);
281-
std::vector<string> output_node_names;
282-
output_node_names.emplace_back("remote_fused_graph_execute_node");
317+
TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithFusedGraph) {
318+
LOG(INFO) << "Run inception v3 with fused graph";
319+
CheckHexagonControllerVersion();
283320

284-
LOG(INFO) << "Run graph";
285-
// Run inference with all node as output
286-
status = session->Run(run_options, input_tensors, output_node_names, {},
287-
&output_tensors, &run_metadata);
288-
ASSERT_TRUE(status.ok());
289-
ASSERT_EQ(1, output_tensors.size());
290-
const Tensor& output_tensor = output_tensors.at(0);
291-
LOG(INFO) << "Output byte size = " << output_tensor.TotalBytes();
292-
LOG(INFO) << "Output shape = " << output_tensor.shape().DebugString();
293-
DumpTop10Results(output_tensor.TotalBytes(),
294-
output_tensor.flat<float>().data());
321+
GraphDef fused_graph_def;
322+
Status status =
323+
ReadBinaryProto(Env::Default(), FUSED_MODEL_FILENAME, &fused_graph_def);
324+
RunFusedGraph(fused_graph_def);
295325
}
326+
296327
#endif
297328

298329
} // namespace tensorflow
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
// Wraps the hexagon rewriter in a transform so it can be used as part of the
17+
// graph transform tool.
18+
// A usage example, based on the Image Understanding pipeline:
19+
/*
20+
bazel build tensorflow/tools/graph_transforms:transform_graph
21+
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
22+
--in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
23+
--out_graph=\
24+
/tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \
25+
--inputs='Mul' \
26+
--outputs='softmax' \
27+
--transforms='\
28+
rewrite_quantized_stripped_model_for_hexagon(
29+
input_shape0="1,299,299,3" \
30+
input_type0="float" \
31+
)'
32+
*/
33+
34+
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
35+
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
36+
#include "tensorflow/tools/graph_transforms/transform_utils.h"
37+
38+
namespace tensorflow {
39+
namespace graph_transforms {
40+
constexpr const char* const INPUT_SHAPE_PREFIX = "input_shape";
41+
constexpr const char* const INPUT_TYPE_PREFIX = "input_type";
42+
43+
Status RewriteQuantizedStrippedModelForHexagon(
44+
const GraphDef& input_graph_def, const TransformFuncContext& context,
45+
GraphDef* output_graph_def) {
46+
LOG(INFO) << "Transforming quantized stripped model to a remote fused "
47+
"graph execute op...";
48+
std::vector<GraphTransferer::InputNodeInfo> inputs;
49+
std::vector<string> outputs;
50+
for (int i = 0; i < context.input_names.size(); ++i) {
51+
const string& input_name = context.input_names.at(i);
52+
53+
// Get input shape
54+
string shape_string;
55+
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
56+
INPUT_SHAPE_PREFIX + std::to_string(i), "", &shape_string));
57+
std::vector<int64> dims;
58+
CHECK(str_util::SplitAndParseAsInts(shape_string, ',', &dims));
59+
60+
// Get input data type
61+
string data_type_string;
62+
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
63+
INPUT_TYPE_PREFIX + std::to_string(i), "", &data_type_string));
64+
DataType data_type;
65+
CHECK(DataTypeFromString(data_type_string, &data_type))
66+
<< "\"" << data_type_string << "\" was an invalid type";
67+
68+
LOG(INFO) << "Input(" << i << "): name = " << input_name
69+
<< ", shape = " << shape_string
70+
<< ", type = " << data_type_string;
71+
72+
inputs.emplace_back(GraphTransferer::InputNodeInfo{
73+
input_name, {data_type, TensorShape(dims)}});
74+
}
75+
76+
for (const string& output_name : context.output_names) {
77+
outputs.emplace_back(output_name);
78+
}
79+
GraphTransferer gt;
80+
gt.EnableStrictCheckMode(false);
81+
*output_graph_def = GraphTransferUtils::BuildFusedGraphDef(
82+
HexagonOpsDefinitions::getInstance(), "remote_fused_graph_execute_node",
83+
inputs, outputs, input_graph_def, &gt);
84+
return Status::OK();
85+
}
86+
87+
REGISTER_GRAPH_TRANSFORM("rewrite_quantized_stripped_model_for_hexagon",
88+
RewriteQuantizedStrippedModelForHexagon);
89+
90+
} // namespace graph_transforms
91+
} // namespace tensorflow

0 commit comments

Comments
 (0)