@@ -12,15 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License.
14
14
==============================================================================*/
15
- // Before calling this test program, download a model as follows.
16
- // $ curl https://storage.googleapis.com/download.tensorflow.org/models/tensorflow_inception_v3_stripped_optimized_quantized.pb \
17
- // -o /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb
18
- // adb push /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
19
- // /data/local/tmp
20
- // $ curl
21
- // https://storage.googleapis.com/download.tensorflow.org/models/imagenet_comp_graph_label_strings.txt
22
- // -o /tmp/imagenet_comp_graph_label_strings.txt
23
- // adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
15
+ /* Before calling this test program, download a model as follows.
16
+ $ curl
17
+ https://storage.googleapis.com/download.tensorflow.org/models/tensorflow_inception_v3_stripped_optimized_quantized.pb
18
+ \ -o /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb
19
+ $ adb push /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
20
+ /data/local/tmp
21
+ $ curl
22
+ https://storage.googleapis.com/download.tensorflow.org/models/imagenet_comp_graph_label_strings.txt
23
+ -o /tmp/imagenet_comp_graph_label_strings.txt
24
+ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
25
+ */
24
26
25
27
#include < memory>
26
28
@@ -49,15 +51,26 @@ using ConstByteArray = ISocControlWrapper::ConstByteArray;
49
51
constexpr const char * const IMAGE_FILENAME = " /data/local/tmp/img_299x299.bmp" ;
50
52
constexpr const char * const MODEL_FILENAME =
51
53
" /data/local/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb" ;
54
+ constexpr const char * const FUSED_MODEL_FILENAME =
55
+ " /data/local/tmp/"
56
+ " tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb" ;
57
+ constexpr const char * const REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME =
58
+ " remote_fused_graph_execute_node" ;
52
59
53
- const bool USE_TF_RUNTIME = true ;
54
60
const bool DBG_DUMP_FLOAT_DATA = false ;
55
61
const int WIDTH = 299 ;
56
62
const int HEIGHT = 299 ;
57
63
const int DEPTH = 3 ;
58
64
const int EXPECTED_FIRST_RESULT_ID = 59 ;
59
65
const int EXECUTION_REPEAT_COUNT = 3 ;
60
66
67
+ static void CheckHexagonControllerVersion () {
68
+ HexagonControlWrapper hexagon_control_wrapper;
69
+ const int version = hexagon_control_wrapper.GetVersion ();
70
+ ASSERT_GE (version, 1 );
71
+ LOG (INFO) << " Hexagon controller version is " << version;
72
+ }
73
+
61
74
static void DumpTop10Results (const int byte_size,
62
75
const float * const float_array) {
63
76
const int element_count = byte_size / sizeof (float );
@@ -159,9 +172,6 @@ static void RunInferenceByHexagonControlWrapper(
159
172
img_floats.size () * sizeof (float ), DT_FLOAT);
160
173
161
174
HexagonControlWrapper hexagon_control_wrapper;
162
- const int version = hexagon_control_wrapper.GetVersion ();
163
- ASSERT_GE (version, 1 );
164
- LOG (INFO) << " Hexagon controller version is " << version;
165
175
// 1. Initialize hexagon
166
176
hexagon_control_wrapper.Init ();
167
177
@@ -196,13 +206,61 @@ static void RunInferenceByHexagonControlWrapper(
196
206
hexagon_control_wrapper.Finalize ();
197
207
}
198
208
209
+ static void RunFusedGraph (const GraphDef& fused_graph_def) {
210
+ // Setup input tensor
211
+ std::vector<float > img_floats;
212
+ LoadImage (&img_floats);
213
+
214
+ LOG (INFO) << " Ioading image finished." ;
215
+ Tensor img_tensor (DT_FLOAT, {1 , WIDTH, HEIGHT, DEPTH});
216
+ ASSERT_EQ (WIDTH * HEIGHT * DEPTH, img_floats.size ());
217
+ ASSERT_EQ (img_tensor.TotalBytes (), img_floats.size () * sizeof (float ));
218
+
219
+ LOG (INFO) << " Copy data to tensor." ;
220
+ std::memcpy (img_tensor.flat <float >().data (), img_floats.data (),
221
+ img_tensor.TotalBytes ());
222
+
223
+ // Setup session
224
+ std::vector<Tensor> output_tensors;
225
+ SessionOptions session_options;
226
+ session_options.env = Env::Default ();
227
+ std::unique_ptr<Session> session =
228
+ std::unique_ptr<Session>(NewSession (session_options));
229
+ Status status = session->Create (fused_graph_def);
230
+ ASSERT_TRUE (status.ok ());
231
+
232
+ // Setup session arguments
233
+ RunOptions run_options;
234
+ run_options.set_trace_level (RunOptions::FULL_TRACE);
235
+ RunMetadata run_metadata;
236
+
237
+ std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
238
+ input_tensors.emplace_back (" Mul" , img_tensor);
239
+ std::vector<string> output_node_names;
240
+ output_node_names.emplace_back (REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME);
241
+
242
+ LOG (INFO) << " Run graph" ;
243
+ // Run inference with all node as output
244
+ status = session->Run (run_options, input_tensors, output_node_names, {},
245
+ &output_tensors, &run_metadata);
246
+ ASSERT_TRUE (status.ok ());
247
+ ASSERT_EQ (1 , output_tensors.size ());
248
+ const Tensor& output_tensor = output_tensors.at (0 );
249
+ LOG (INFO) << " Output byte size = " << output_tensor.TotalBytes ();
250
+ LOG (INFO) << " Output shape = " << output_tensor.shape ().DebugString ();
251
+ DumpTop10Results (output_tensor.TotalBytes (),
252
+ output_tensor.flat <float >().data ());
253
+ }
254
+
199
255
// CAVEAT: This test only runs when you specify hexagon library using
200
256
// makefile.
201
257
// TODO(satok): Make this generic so that this can run without any
202
258
// additional steps.
203
259
#ifdef USE_HEXAGON_LIBS
204
260
TEST (GraphTransferer, RunInceptionV3OnHexagonExample) {
205
- if (USE_TF_RUNTIME) return ;
261
+ LOG (INFO) << " Run inception v3 on hexagon with hexagon controller" ;
262
+ CheckHexagonControllerVersion ();
263
+
206
264
const IGraphTransferOpsDefinitions* ops_definitions =
207
265
&HexagonOpsDefinitions::getInstance ();
208
266
std::vector<GraphTransferer::InputNodeInfo> input_node_info_list = {
@@ -226,31 +284,22 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExample) {
226
284
}
227
285
228
286
TEST (GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
229
- if (!USE_TF_RUNTIME) return ;
287
+ LOG (INFO) << " Fuse and run inception v3 on hexagon with tf runtime" ;
288
+ CheckHexagonControllerVersion ();
289
+
230
290
const IGraphTransferOpsDefinitions* ops_definitions =
231
291
&HexagonOpsDefinitions::getInstance ();
232
292
std::vector<GraphTransferer::InputNodeInfo> inputs = {
233
293
GraphTransferer::InputNodeInfo{
234
294
" Mul" , Tensor{DT_FLOAT, {1 , WIDTH, HEIGHT, DEPTH}}}};
235
295
std::vector<string> outputs = {" softmax" };
236
- const bool is_text_proto = false ;
237
296
238
297
std::vector<float > img_floats;
239
298
LoadImage (&img_floats);
240
299
241
300
LOG (INFO) << " Ioading image finished." ;
242
301
243
- Tensor img_tensor (DT_FLOAT, {1 , WIDTH, HEIGHT, DEPTH});
244
- ASSERT_EQ (WIDTH * HEIGHT * DEPTH, img_floats.size ());
245
- ASSERT_EQ (img_tensor.TotalBytes (), img_floats.size () * sizeof (float ));
246
-
247
- LOG (INFO) << " Copy data to tensor." ;
248
-
249
- std::memcpy (img_tensor.flat <float >().data (), img_floats.data (),
250
- img_tensor.TotalBytes ());
251
-
252
302
GraphDef graph_def;
253
-
254
303
Status status = ReadBinaryProto (Env::Default (), MODEL_FILENAME, &graph_def);
255
304
256
305
ASSERT_TRUE (status.ok ());
@@ -259,40 +308,22 @@ TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
259
308
GraphTransferer gt;
260
309
gt.EnableStrictCheckMode (false );
261
310
GraphDef fused_graph_def = GraphTransferUtils::BuildFusedGraphDef (
262
- HexagonOpsDefinitions::getInstance (), " remote_fused_graph_execute_node" ,
263
- inputs, outputs, graph_def, >);
264
-
265
- // Setup session
266
- std::vector<Tensor> output_tensors;
267
- SessionOptions session_options;
268
- session_options.env = Env::Default ();
269
- std::unique_ptr<Session> session =
270
- std::unique_ptr<Session>(NewSession (session_options));
271
- status = session->Create (fused_graph_def);
272
- ASSERT_TRUE (status.ok ());
311
+ HexagonOpsDefinitions::getInstance (),
312
+ REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME, inputs, outputs, graph_def, >);
273
313
274
- // Setup session arguments
275
- RunOptions run_options;
276
- run_options.set_trace_level (RunOptions::FULL_TRACE);
277
- RunMetadata run_metadata;
314
+ RunFusedGraph (fused_graph_def);
315
+ }
278
316
279
- std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
280
- input_tensors.emplace_back (" Mul" , img_tensor);
281
- std::vector<string> output_node_names;
282
- output_node_names.emplace_back (" remote_fused_graph_execute_node" );
317
+ TEST (GraphTransferer, RunInceptionV3OnHexagonExampleWithFusedGraph) {
318
+ LOG (INFO) << " Run inception v3 with fused graph" ;
319
+ CheckHexagonControllerVersion ();
283
320
284
- LOG (INFO) << " Run graph" ;
285
- // Run inference with all node as output
286
- status = session->Run (run_options, input_tensors, output_node_names, {},
287
- &output_tensors, &run_metadata);
288
- ASSERT_TRUE (status.ok ());
289
- ASSERT_EQ (1 , output_tensors.size ());
290
- const Tensor& output_tensor = output_tensors.at (0 );
291
- LOG (INFO) << " Output byte size = " << output_tensor.TotalBytes ();
292
- LOG (INFO) << " Output shape = " << output_tensor.shape ().DebugString ();
293
- DumpTop10Results (output_tensor.TotalBytes (),
294
- output_tensor.flat <float >().data ());
321
+ GraphDef fused_graph_def;
322
+ Status status =
323
+ ReadBinaryProto (Env::Default (), FUSED_MODEL_FILENAME, &fused_graph_def);
324
+ RunFusedGraph (fused_graph_def);
295
325
}
326
+
296
327
#endif
297
328
298
329
} // namespace tensorflow
0 commit comments