7 changes: 7 additions & 0 deletions test/cpp/metrics_snapshot.cpp
@@ -4,6 +4,7 @@

#include "tensorflow/compiler/xla/xla_client/tf_logging.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch/csrc/lazy/core/metrics.h"

namespace torch_xla {
namespace cpp_test {
@@ -20,6 +21,12 @@ MetricsSnapshot::MetricsSnapshot() {
xla::metrics::CounterData* counter = xla::metrics::GetCounter(name);
counters_map_.emplace(name, counter->Value());
}

// See NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER].
for (auto& name : torch::lazy::GetCounterNames()) {
auto* counter = torch::lazy::GetCounter(name);
counters_map_.emplace(name, counter->Value());
}
}

std::vector<MetricsSnapshot::ChangedCounter> MetricsSnapshot::CounterChanged(
51 changes: 50 additions & 1 deletion test/test_metrics.py
@@ -54,6 +54,7 @@ def test_short_metrics_report_default_list(self):
xm.mark_step()
t4 = t1 * 2
xm.mark_step()
short_report = met.short_metrics_report()
self.assertIn("CachedCompile", short_report)

def test_short_metrics_report_custom_list(self):
@@ -69,9 +70,11 @@ def test_short_metrics_report_custom_list(self):
# using the default metrics list in this case
self.assertIn('CompileTime', short_report)
short_report = met.short_metrics_report(
counter_names=['CreateCompileHandles'], metric_names=['InboundData'])
counter_names=['CreateCompileHandles'],
metric_names=['InboundData', 'InputOutputAliasCount'])
self.assertNotIn('CompileTime', short_report)
self.assertIn('InboundData', short_report)
self.assertIn('InputOutputAliasCount', short_report)

def test_short_metrics_fallback_counter(self):
xla_device = xm.xla_device()
@@ -87,6 +90,52 @@ def test_short_metrics_fallback_counter(self):
counter_names=['CreateCompileHandles'],
metric_names=['InboundData']))

def test_metrics_report(self):
# TODO(jwtan): Add test to cover TrimIrGraph, SyncTensorsToData, TransferToServerAsync, IrValueTensorToXlaData
xla_device = xm.xla_device()
t1 = torch.tensor(1456, device=xla_device)
t2 = t1 * 2
xm.mark_step()
t2_cpu = t2.cpu()
report = met.metrics_report()

# counters
self.assertIn("DeviceDataCacheMiss", report)
self.assertIn("CreateXlaTensor", report)
self.assertIn("DestroyXlaTensor", report)
self.assertIn("UncachedCompile", report)
self.assertIn("MarkStep", report)
# If test_metrics_report is run together with other tests,
# the counts could differ. So we simply assert that they
# are non-zero.
self.assertNotEqual(len(met.counter_names()), 0)
self.assertNotEqual(met.counter_value("DeviceDataCacheMiss"), 0)
self.assertNotEqual(met.counter_value("CreateXlaTensor"), 0)
self.assertNotEqual(met.counter_value("DestroyXlaTensor"), 0)
self.assertNotEqual(met.counter_value("UncachedCompile"), 0)
self.assertNotEqual(met.counter_value("MarkStep"), 0)

met.clear_counters()
self.assertEqual(met.counter_value("DeviceDataCacheMiss"), 0)

# metrics
self.assertIn("TensorsGraphSize", report)
self.assertIn("InputOutputAliasCount", report)

# timed metrics
self.assertIn("TensorToData", report)
self.assertIn("UnwrapXlaData", report)
self.assertIn("WrapXlaData", report)
self.assertIn("DeviceLockWait", report)

# repeat the same computation and expect to see the CachedCompile counter
t3 = t1 * 2
xm.mark_step()
t4 = t1 * 2
xm.mark_step()
report = met.metrics_report()
self.assertIn("CachedCompile", report)


if __name__ == '__main__':
test = unittest.main()
3 changes: 3 additions & 0 deletions third_party/xla_client/metrics.h
@@ -177,6 +177,9 @@ class Counter {
mutable std::atomic<CounterData*> data_;
};

// XLA_COUNTER should only be used within xla_client. Please use
// TORCH_LAZY_COUNTER in pytorch/xla. For more information, see
// NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER].
#define XLA_COUNTER(name, value) \
do { \
static ::xla::metrics::Counter* __counter = \
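Counters registered through either macro end up behind the same Python surface once the bindings below merge the two arenas. A minimal sketch of the observable behavior (assuming a working XLA device; the counter names come from the tests above):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

t = torch.tensor(1456, device=xm.xla_device())
t2 = t * 2
xm.mark_step()

# "MarkStep" is still recorded via XLA_COUNTER (see tensor.cpp below),
# while "CreateXlaTensor" is now a TORCH_LAZY_COUNTER; both resolve here.
print(met.counter_value('MarkStep'))         # >= 1
print(met.counter_value('CreateXlaTensor'))  # >= 1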
4 changes: 2 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1074,7 +1074,7 @@ at::Tensor XLANativeFunctions::einsum(c10::string_view equation,
if (tensors.size() < 1 || tensors.size() > 2 ||
!EinsumUtilities::EquationIsValid(cleansed_equation) ||
TensorsAreOfType(xla_tensors, at::ScalarType::Long)) {
XLA_COUNTER("EinsumFallback", 1);
TORCH_LAZY_COUNTER("EinsumFallback", 1);
return at::native::einsum(equation, tensors, path);
}
return aten_autograd_ops::EinsumAutogradFunction::apply(cleansed_equation,
@@ -2987,7 +2987,7 @@ at::Scalar XLANativeFunctions::_local_scalar_dense(const at::Tensor& self) {
XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
XLATensor::SyncLiveTensorsGraph(&self_tensor->GetDevice(), /*devices=*/{},
/*wait=*/true);
XLA_COUNTER("EarlySyncLiveTensorsCount", 1);
TORCH_LAZY_COUNTER("EarlySyncLiveTensorsCount", 1);
}
return at::native::call_fallback_fn<&xla_cpu_fallback,
ATEN_OP(_local_scalar_dense)>::call(self);
68 changes: 49 additions & 19 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1182,31 +1182,61 @@ void InitXlaModuleBindings(py::module m) {
XLATensor::WaitDeviceOps(devices);
},
py::arg("devices"));
m.def("_xla_counter_names", []() { return xla::metrics::GetCounterNames(); });
m.def("_xla_counter_names", []() {
auto counter_names = torch::lazy::GetCounterNames();
auto xla_counter_names = xla::metrics::GetCounterNames();
counter_names.insert(counter_names.end(), xla_counter_names.begin(),
xla_counter_names.end());
return counter_names;
});
m.def("_xla_counter_value", [](const std::string& name) -> py::object {
xla::metrics::CounterData* data = xla::metrics::GetCounter(name);
return data != nullptr ? py::cast<int64_t>(data->Value()) : py::none();
auto* data = torch::lazy::GetCounter(name);
if (data != nullptr) {
return py::cast<int64_t>(data->Value());
}

auto* xla_data = xla::metrics::GetCounter(name);
return xla_data != nullptr ? py::cast<int64_t>(xla_data->Value())
: py::none();
});
m.def("_xla_metric_names", []() { return xla::metrics::GetMetricNames(); });
m.def("_xla_metric_data", [](const std::string& name) -> py::object {
return GetMetricData(name);
});
m.def("_xla_metrics_report",
[]() { return xla::metrics_reader::CreateMetricReport(); });
m.def("_short_xla_metrics_report",
[](const py::list& counter_names, const py::list& metric_names) {
std::vector<std::string> counter_name_vec;
std::vector<std::string> metric_name_vec;
for (auto& counter : counter_names) {
counter_name_vec.push_back(counter.cast<std::string>());
}
for (auto& metric : metric_names) {
metric_name_vec.push_back(metric.cast<std::string>());
}
return xla::metrics_reader::CreateMetricReport(counter_name_vec,
metric_name_vec);
});
m.def("_clear_xla_counters", []() { xla::metrics::ClearCounters(); });
m.def("_xla_metrics_report", []() {
// NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER]
// Counters and metrics are divided into two groups: one in PyTorch/XLA and
// another in ComputationClient. Therefore, we need to stitch the two
// reports together. Ideally, the two sets should not overlap.
// ComputationClient cannot use TORCH_LAZY_COUNTER because, as part of
// TensorFlow, it currently cannot depend on PyTorch.
// TODO(jwtan): Unify them once ComputationClient becomes a standalone
// library.
return torch::lazy::CreateMetricReport() +
xla::metrics_reader::CreateMetricReport();
});
m.def("_short_xla_metrics_report", [](const py::list& counter_names,
const py::list& metric_names) {
std::vector<std::string> counter_name_vec;
std::vector<std::string> metric_name_vec;
for (auto& counter : counter_names) {
counter_name_vec.push_back(counter.cast<std::string>());
}
for (auto& metric : metric_names) {
metric_name_vec.push_back(metric.cast<std::string>());
}
// See NOTE: [TORCH_LAZY_COUNTER v.s. XLA_COUNTER].
return torch::lazy::CreateMetricReport(counter_name_vec, metric_name_vec) +
xla::metrics_reader::CreateMetricReport(counter_name_vec,
metric_name_vec);
});
m.def("_clear_xla_counters", []() {
// TODO(jwtan): We should probably upstream the ability to reset counters
// and metrics separately.
torch::lazy::MetricsArena::Get()->Reset();
xla::metrics::ClearCounters();
});
m.def("_clear_xla_metrics", []() { xla::metrics::ClearMetrics(); });
m.def("_xla_tensors_report",
[](size_t nodes_threshold, const std::string& device) {
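With these bindings, met.metrics_report() returns the torch::lazy report concatenated with the XLA client report, and met.clear_counters() resets both arenas. A short sketch mirroring the new test above (again assuming a working XLA device):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

t1 = torch.tensor(1456, device=xm.xla_device())
t2 = t1 * 2
xm.mark_step()

report = met.metrics_report()
assert 'CreateXlaTensor' in report   # counter from the torch::lazy arena
assert 'TensorsGraphSize' in report  # metric still recorded on the XLA side

# _clear_xla_counters resets the torch::lazy arena and the XLA counters.
met.clear_counters()
assert met.counter_value('DeviceDataCacheMiss') == 0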
2 changes: 1 addition & 1 deletion torch_xla/csrc/op_by_op_executor.cpp
@@ -143,7 +143,7 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
ComputeNodeKey(node, op_input_shapes, nodes_key_seed);
cxop.computation = compile_cache_.Get(cache_key);
if (cxop.computation == nullptr) {
XLA_COUNTER("OpByOpCompileCacheMiss", 1);
TORCH_LAZY_COUNTER("OpByOpCompileCacheMiss", 1);

// Within a single IR graph, there can be many duplicated IR nodes, so
// make sure we do not issue an XLA compilation for each one of those.
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/dynamic_ir.cpp
@@ -36,7 +36,7 @@ SizeNode::SizeNode(torch::lazy::Value input, size_t dim)

int64_t SizeNode::getDynamicValue() const {
if (dynamic_value_computed_) {
XLA_COUNTER("CachedSizeNodeValue", 1);
TORCH_LAZY_COUNTER("CachedSizeNodeValue", 1);
return runtime_size_;
}
torch::lazy::NodePtr cloned =
18 changes: 10 additions & 8 deletions torch_xla/csrc/tensor.cpp
@@ -19,7 +19,6 @@
#include "tensorflow/compiler/xla/xla_client/cache.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/env_vars.h"
#include "tensorflow/compiler/xla/xla_client/metrics.h"
#include "tensorflow/compiler/xla/xla_client/sys_util.h"
#include "tensorflow/compiler/xla/xla_client/thread_pool.h"
#include "tensorflow/compiler/xla/xla_client/unique.h"
@@ -31,6 +30,7 @@
#include "torch/csrc/lazy/core/helpers.h"
#include "torch/csrc/lazy/core/ir_util.h"
#include "torch/csrc/lazy/core/lazy_graph_executor.h"
#include "torch/csrc/lazy/core/metrics.h"
#include "torch/csrc/lazy/core/tensor_util.h"
#include "torch/csrc/lazy/core/util.h"
#include "torch_xla/csrc/computation.h"
@@ -242,7 +242,7 @@ torch::lazy::BackendDataPtr GetDeviceData(
at::Tensor tensor_copy = torch::lazy::CopyTensor(tensor);
device_data = TensorToXlaData(tensor_copy, device);
cache->Add(std::move(tensor_copy), device_data);
XLA_COUNTER("DeviceDataCacheMiss", 1);
TORCH_LAZY_COUNTER("DeviceDataCacheMiss", 1);
}
return device_data;
}
@@ -303,14 +303,14 @@ class XLATensor::DeviceContextArena {
DeviceContext* devctx = GetDeviceContext(data->device);
std::lock_guard<std::mutex> lock(devctx->lock);
devctx->tensors_data.emplace(data->unique_id, data);
XLA_COUNTER("CreateXlaTensor", 1);
TORCH_LAZY_COUNTER("CreateXlaTensor", 1);
}

void UnregisterTensor(Data* data) {
DeviceContext* devctx = GetDeviceContext(data->device);
std::lock_guard<std::mutex> lock(devctx->lock);
devctx->tensors_data.erase(data->unique_id);
XLA_COUNTER("DestroyXlaTensor", 1);
TORCH_LAZY_COUNTER("DestroyXlaTensor", 1);
}

std::vector<XLATensorPtr> GetLiveTensors(
@@ -742,7 +742,7 @@ void XLATensor::TryLimitGraphSize() {
size_t graph_size =
torch::lazy::Util::GetGraphSize({data()->ir_value.node.get()});
if (graph_size > kMaxPendingGraphSize) {
XLA_COUNTER("TrimIrGraph", 1);
TORCH_LAZY_COUNTER("TrimIrGraph", 1);
ApplyPendingGraph();
}
}
@@ -1322,7 +1322,7 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors(
coll.hash,
xla::ComputationClient::Get()->GetResourceDomain(coll.device.toString()));
if (!at_tensors.empty()) {
XLA_COUNTER("SyncTensorsToData", at_tensors.size());
TORCH_LAZY_COUNTER("SyncTensorsToData", at_tensors.size());
// Create data handles with shardings. If a tensor has a
// sharding annotation, then a BackendDataPtr with PjRtShardedData is
// returned; if there is no sharding annotation, then a BackendDataPtr with
@@ -1348,7 +1348,7 @@ XLATensor::ComputationCache::TypePtr XLATensor::LookupCachedCompile(
ComputationCache::TypePtr cached_computation =
GetComputationCache()->Get(hash);
if (cached_computation == nullptr) {
XLA_COUNTER("UncachedCompile", 1);
TORCH_LAZY_COUNTER("UncachedCompile", 1);
return nullptr;
}
TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(hash)
@@ -1357,7 +1357,7 @@
cached_computation->computation->computation()
.proto()
.SerializeAsString()));
XLA_COUNTER("CachedCompile", 1);
TORCH_LAZY_COUNTER("CachedCompile", 1);
return cached_computation;
}

@@ -1619,6 +1619,8 @@ void XLATensor::SyncLiveTensorsGraph(const torch::lazy::BackendDevice* device,
}

void XLATensor::MarkStep(const torch::lazy::BackendDevice& device) {
// TODO(jwtan): Replace this with TORCH_LAZY_COUNTER. We need MarkStep to
// remain as XLA_COUNTER to support xla::metrics::CreatePerformanceReport().
XLA_COUNTER("MarkStep", 1);
DeviceContextArena::Get()->MarkStep(device);
torch::lazy::ScopePusher::ResetScopes();