Skip to content

Commit 723f285

Browse files
Justin Lebartensorflower-gardener
authored andcommitted
[XLA] Improvements to replay_computation tool.
* Reduce threshold at which we run fake-data generation on the device from 1gb to 1mb. At the old threshold, I observed cases where we'd spend many seconds, and >50% of our runtime, in logf(), used for computing random numbers. * Don't retrieve or print the result when running with fake data. Presumably this is uninteresting, because garbage in, garbage out. Retrieving this data can take as long as running the whole computation, and printing it can take many times longer. * Add a LOG(INFO) indicating how long execution took. * Add a --num_runs flag. This is particularly important on GPUs, where the first run does autotuning, and so isn't interesting from a performance perspective. PiperOrigin-RevId: 177185636
1 parent c81a8ae commit 723f285

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

tensorflow/compiler/xla/client/lib/testing.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
5151

5252
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
5353
Client* client) {
54-
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {
54+
if (ShapeUtil::ByteSizeOf(shape) < (1LL << 20)) {
5555
StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
5656
if (!literal_status.ok()) {
5757
// If we got an Unimplemented error, fall back to making the fake data via

tensorflow/compiler/xla/tools/replay_computation.cc

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ namespace {
6565
// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
6666
// otherwise, no infeed is performed.
6767
StatusOr<std::unique_ptr<Literal>> ReplayComputation(
68-
const SessionModule& module, tensorflow::StringPiece fake_infeed_shape,
69-
bool use_fake_data, Client* client) {
68+
const SessionModule& module, int num_runs,
69+
tensorflow::StringPiece fake_infeed_shape, bool use_fake_data,
70+
Client* client) {
7071
TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module));
7172

7273
std::vector<std::unique_ptr<GlobalData>> arguments;
@@ -107,33 +108,58 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(
107108
for (auto& argument : arguments) {
108109
execute_arguments.push_back(argument.get());
109110
}
110-
return client->ExecuteAndTransfer(computation, execute_arguments);
111+
112+
// Run the computation num_runs times, and return the result from the last
113+
// execution.
114+
std::unique_ptr<Literal> result;
115+
for (int i = 0; i < num_runs; ++i) {
116+
ExecutionProfile profile;
117+
if (use_fake_data) {
118+
// If using fake data, execute the computation but don't bother retrieving
119+
// the result -- presumably it's uninteresting, since our data is fake.
120+
TF_RETURN_IF_ERROR(client
121+
->Execute(computation, execute_arguments,
122+
/*execution_options=*/nullptr, &profile)
123+
.status());
124+
} else {
125+
TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer(
126+
computation, execute_arguments,
127+
/*execution_options=*/nullptr, &profile));
128+
}
129+
LOG(INFO) << "Execution took "
130+
<< static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
131+
}
132+
133+
return std::move(result);
111134
}
112135

113-
int RealMain(tensorflow::gtl::ArraySlice<char*> args,
136+
int RealMain(tensorflow::gtl::ArraySlice<char*> args, int num_runs,
114137
tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) {
115138
Client* client = ClientLibrary::LocalClientOrDie();
116139
tensorflow::Env* env = tensorflow::Env::Default();
117140
int exit_status = EXIT_SUCCESS;
118141
for (char* arg : args) {
119142
SessionModule module;
120143
TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
121-
StatusOr<std::unique_ptr<Literal>> result_status =
122-
ReplayComputation(module, fake_infeed_shape, use_fake_data, client);
144+
StatusOr<std::unique_ptr<Literal>> result_status = ReplayComputation(
145+
module, num_runs, fake_infeed_shape, use_fake_data, client);
123146
if (!result_status.ok()) {
124147
fprintf(stderr, "%s: error: %s\n", arg,
125148
result_status.status().ToString().c_str());
126149
exit_status = EXIT_FAILURE;
127150
continue;
128151
}
152+
129153
std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
130-
fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(),
131-
ShapeUtil::HumanString(result->shape()).c_str(),
132-
result->ToString().c_str());
133-
if (module.has_result()) {
134-
fprintf(stdout, "was %s:%s\n",
135-
ShapeUtil::HumanString(module.result().shape()).c_str(),
136-
Literal(module.result()).ToString().c_str());
154+
if (result != nullptr) {
155+
fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(),
156+
ShapeUtil::HumanString(result->shape()).c_str(),
157+
result->ToString().c_str());
158+
if (module.has_result()) {
159+
fprintf(stdout, "was %s:%s\n",
160+
ShapeUtil::HumanString(module.result().shape()).c_str(),
161+
Literal(module.result()).ToString().c_str());
162+
}
137163
}
138164
}
139165
return exit_status;
@@ -147,9 +173,12 @@ int main(int argc, char** argv) {
147173
// Flags
148174
xla::string fake_infeed_shape;
149175
bool use_fake_data = false;
176+
int num_runs = 1;
150177
const std::vector<tensorflow::Flag> flag_list = {
151178
tensorflow::Flag("use_fake_data", &use_fake_data,
152179
"Replay computation using fake data"),
180+
tensorflow::Flag("num_runs", &num_runs,
181+
"Number of times to run each computation"),
153182
tensorflow::Flag("fake_infeed_shape", &fake_infeed_shape,
154183
"Shape of fake data to construct for (infinite) infeed"),
155184
};
@@ -162,5 +191,5 @@ int main(int argc, char** argv) {
162191

163192
tensorflow::gtl::ArraySlice<char*> args(argv, argc);
164193
args.pop_front(); // Pop off the binary name, argv[0]
165-
return xla::tools::RealMain(args, fake_infeed_shape, use_fake_data);
194+
return xla::tools::RealMain(args, num_runs, fake_infeed_shape, use_fake_data);
166195
}

0 commit comments

Comments
 (0)