@@ -65,8 +65,9 @@ namespace {
6565// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
6666// otherwise, no infeed is performed.
6767StatusOr<std::unique_ptr<Literal>> ReplayComputation (
68- const SessionModule& module , tensorflow::StringPiece fake_infeed_shape,
69- bool use_fake_data, Client* client) {
68+ const SessionModule& module , int num_runs,
69+ tensorflow::StringPiece fake_infeed_shape, bool use_fake_data,
70+ Client* client) {
7071 TF_ASSIGN_OR_RETURN (Computation computation, client->LoadSnapshot (module ));
7172
7273 std::vector<std::unique_ptr<GlobalData>> arguments;
@@ -107,33 +108,58 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
107108 for (auto & argument : arguments) {
108109 execute_arguments.push_back (argument.get ());
109110 }
110- return client->ExecuteAndTransfer (computation, execute_arguments);
111+
112+ // Run the computation num_runs times, and return the result from the last
113+ // execution.
114+ std::unique_ptr<Literal> result;
115+ for (int i = 0 ; i < num_runs; ++i) {
116+ ExecutionProfile profile;
117+ if (use_fake_data) {
118+ // If using fake data, execute the computation but don't bother retrieving
119+ // the result -- presumably it's uninteresting, since our data is fake.
120+ TF_RETURN_IF_ERROR (client
121+ ->Execute (computation, execute_arguments,
122+ /* execution_options=*/ nullptr , &profile)
123+ .status ());
124+ } else {
125+ TF_ASSIGN_OR_RETURN (result, client->ExecuteAndTransfer (
126+ computation, execute_arguments,
127+ /* execution_options=*/ nullptr , &profile));
128+ }
129+ LOG (INFO) << " Execution took "
130+ << static_cast <double >(profile.compute_time_ns ()) / 1e9 << " s" ;
131+ }
132+
133+ return std::move (result);
111134}
112135
113- int RealMain (tensorflow::gtl::ArraySlice<char *> args,
136+ int RealMain (tensorflow::gtl::ArraySlice<char *> args, int num_runs,
114137 tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) {
115138 Client* client = ClientLibrary::LocalClientOrDie ();
116139 tensorflow::Env* env = tensorflow::Env::Default ();
117140 int exit_status = EXIT_SUCCESS;
118141 for (char * arg : args) {
119142 SessionModule module ;
120143 TF_CHECK_OK (tensorflow::ReadBinaryProto (env, arg, &module ));
121- StatusOr<std::unique_ptr<Literal>> result_status =
122- ReplayComputation ( module , fake_infeed_shape, use_fake_data, client);
144+ StatusOr<std::unique_ptr<Literal>> result_status = ReplayComputation (
145+ module , num_runs , fake_infeed_shape, use_fake_data, client);
123146 if (!result_status.ok ()) {
124147 fprintf (stderr, " %s: error: %s\n " , arg,
125148 result_status.status ().ToString ().c_str ());
126149 exit_status = EXIT_FAILURE;
127150 continue ;
128151 }
152+
129153 std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie ();
130- fprintf (stdout, " %s: %s :: %s:%s\n " , arg, module .entry ().name ().c_str (),
131- ShapeUtil::HumanString (result->shape ()).c_str (),
132- result->ToString ().c_str ());
133- if (module .has_result ()) {
134- fprintf (stdout, " was %s:%s\n " ,
135- ShapeUtil::HumanString (module .result ().shape ()).c_str (),
136- Literal (module .result ()).ToString ().c_str ());
154+ if (result != nullptr ) {
155+ fprintf (stdout, " %s: %s :: %s:%s\n " , arg, module .entry ().name ().c_str (),
156+ ShapeUtil::HumanString (result->shape ()).c_str (),
157+ result->ToString ().c_str ());
158+ if (module .has_result ()) {
159+ fprintf (stdout, " was %s:%s\n " ,
160+ ShapeUtil::HumanString (module .result ().shape ()).c_str (),
161+ Literal (module .result ()).ToString ().c_str ());
162+ }
137163 }
138164 }
139165 return exit_status;
@@ -147,9 +173,12 @@ int main(int argc, char** argv) {
147173 // Flags
148174 xla::string fake_infeed_shape;
149175 bool use_fake_data = false ;
176+ int num_runs = 1 ;
150177 const std::vector<tensorflow::Flag> flag_list = {
151178 tensorflow::Flag (" use_fake_data" , &use_fake_data,
152179 " Replay computation using fake data" ),
180+ tensorflow::Flag (" num_runs" , &num_runs,
181+ " Number of times to run each computation" ),
153182 tensorflow::Flag (" fake_infeed_shape" , &fake_infeed_shape,
154183 " Shape of fake data to construct for (infinite) infeed" ),
155184 };
@@ -162,5 +191,5 @@ int main(int argc, char** argv) {
162191
163192 tensorflow::gtl::ArraySlice<char *> args (argv, argc);
164193 args.pop_front (); // Pop off the binary name, argv[0]
165- return xla::tools::RealMain (args, fake_infeed_shape, use_fake_data);
194+ return xla::tools::RealMain (args, num_runs, fake_infeed_shape, use_fake_data);
166195}
0 commit comments