Skip to content
2 changes: 2 additions & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ function run_xla_op_tests2 {
run_test "$CDIR/eager/test_eager_with_torch_compile.py"
run_test "$CDIR/eager/test_eager_all_reduce_in_place.py"
run_test "$CDIR/eager/test_eager_spmd.py"
run_test "$CDIR/test_callback.py"
XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
}

# All the new xla op tests should go to run_xla_op_tests3
Expand Down
34 changes: 34 additions & 0 deletions test/test_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import threading

from absl.testing import absltest
import torch
import torch_xla
from torch_xla.experimental import callback


class TestExperimentalCallback(absltest.TestCase):

@staticmethod
@torch_xla.compile
def executable():
a, b = torch.randn((100, 100), device=torch_xla.device()), torch.randn(
(100, 100), device=torch_xla.device())
return a @ b

def test_callback(self):
event = threading.Event()
c = self.executable()

def cb(tensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it difficult to support the cases where call_back takes more than just the tensor value? call_back might be a general python function that tensor is just one of the inputs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should functools.partial in that case to add other arguments to the callable https://docs.python.org/3/library/functools.html#functools.partial

self.assertIs(c, tensor)
# TODO: check that result is both assigned and completed
self.assertNotIn("Data Handle: None",
torch_xla._XLAC._get_xla_tensor_debug_info(tensor))
event.set()

callback.on_ready_callback(c, cb)
event.wait(3)


if __name__ == "__main__":
absltest.main()
38 changes: 31 additions & 7 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/detail/common.h"
#include "pybind11/functional.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
Expand Down Expand Up @@ -133,7 +134,7 @@ void PrepareToExit() {
if (client != nullptr) {
auto xla_device = GetDeviceOrCurrent("");
SetAllReduceToken(xla_device, nullptr);
XLAGraphExecutor::Get()->WaitDeviceOps({});
WaitDeviceOps();
}
}

Expand Down Expand Up @@ -2619,6 +2620,29 @@ void InitXlaModuleBindings(py::module m) {
return false;
});

m.def("_on_ready_callback",
[](const at::Tensor& tensor, const std::function<void()>& callback) {
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
XLA_CHECK(xtensor) << "The input is not an XLA tensor.";
// Wait for placeholder `Data`s to be assigned
XLAGraphExecutor::Get()->WaitDeviceOps({});
std::shared_ptr<runtime::ComputationClient::Data> data;
if (xtensor->CurrentDataHandle() != nullptr) {
data = UnwrapXlaData(xtensor->CurrentDataHandle());
} else if (xtensor->CurrentIrValue().node != nullptr) {
DeviceData* device_data =
DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (device_data != nullptr) {
data = UnwrapXlaData(device_data->data());
} else {
XLA_ERROR() << "Could not get the buffer pointer for XLATensor "
"with IR that's not DeviceData";
}
XLA_ERROR() << "Could not get buffer for tensor";
}
runtime::GetComputationClient()->OnReadyCallback(data, callback);
});

m.def("_unsafe_buffer_pointer",
[](const at::Tensor& input) -> std::uintptr_t {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
Expand Down Expand Up @@ -2646,9 +2670,9 @@ void InitXlaModuleBindings(py::module m) {

// from an XLA tensor to a PyCapsule.
// When consuming the PyCapsule, we should synchronize
// (waits for all kernels in all streams on a CUDA device to complete) if the
// current stream is different from the ext_data's stream. Otherwise, we may
// risk of getting incorrect results.
// (waits for all kernels in all streams on a CUDA device to complete) if
// the current stream is different from the ext_data's stream. Otherwise, we
// may risk of getting incorrect results.
m.def("_to_dlpack", [](const at::Tensor& input) -> py::handle {
DLManagedTensor* dlMTensor;
{
Expand All @@ -2660,9 +2684,9 @@ void InitXlaModuleBindings(py::module m) {

// from a dlpack PyCapsule to an XLA tensor
// If ext_data is the result of an CUDA computation, we should synchronize
// (waits for all kernels in all streams on a CUDA device to complete) if the
// current stream is different from the ext_data's stream. Otherwise, we may
// risk of getting incorrect results. Or you can use torch_xla's
// (waits for all kernels in all streams on a CUDA device to complete) if
// the current stream is different from the ext_data's stream. Otherwise, we
// may risk of getting incorrect results. Or you can use torch_xla's
// from_dlpack(cuda_tensor) and it will handle the synchronization for you.
m.def("_from_dlpack", [](py::handle ext_data) -> at::Tensor {
return tensor_fromDLPack(ext_data.ptr());
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,10 @@ class ComputationClient {
void* function_ptr,
const std::string& platform) = 0;

// Installs a callback to be called when the buffer backing `data` is ready.
virtual void OnReadyCallback(DataPtr data,
const std::function<void()>& callback) = 0;

// Utility API around the vector based Compile() API to compile a single
// computation.
ComputationPtr Compile(xla::XlaComputation computation,
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/runtime/ifrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ class IfrtComputationClient : public ComputationClient {
XLA_ERROR() << __FUNCTION__ << " not implemented";
};

void OnReadyCallback(DataPtr data,
const std::function<void()>& callback) override {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

private:
std::shared_ptr<xla::ifrt::PjRtClient> client_;
std::unique_ptr<XlaCoordinator> coordinator_;
Expand Down
17 changes: 17 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1033,5 +1033,22 @@ void PjRtComputationClient::RegisterCustomCall(const std::string& fn_name,
}
}

void PjRtComputationClient::OnReadyCallback(
ComputationClient::DataPtr data, const std::function<void()>& callback) {
std::shared_ptr<xla::PjRtBuffer> buffer;
if (auto pjrt_data = std::dynamic_pointer_cast<PjRtData>(data)) {
buffer = pjrt_data->buffer;
} else if (auto sharded_data =
std::dynamic_pointer_cast<PjRtShardedData>(data)) {
XLA_CHECK(sharded_data->shards.size()) << "sharded data has no shards";
buffer = sharded_data->shards[0]->buffer;
} else {
XLA_ERROR() << "received invalid data pointer";
}
XLA_CHECK(buffer) << "received placeholder data as argument";
buffer->GetReadyFuture().OnReady(
[callback](absl::Status unused) { callback(); });
}

} // namespace runtime
} // namespace torch_xla
3 changes: 3 additions & 0 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ class PjRtComputationClient : public ComputationClient {
void RegisterCustomCall(const std::string& fn_name, void* function_ptr,
const std::string& platform) override;

void OnReadyCallback(DataPtr data,
const std::function<void()>& callback) override;

private:
std::unique_ptr<xla::PjRtClient> client_;
std::unique_ptr<XlaCoordinator> coordinator_;
Expand Down
17 changes: 17 additions & 0 deletions torch_xla/experimental/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Callable
import torch
import torch_xla


def on_ready_callback(tensor, callback: Callable[[torch.Tensor], None]):
"""Installs callback on `tensor` to be called when underlying buffer is ready.

Note: Since `callback` will need to re-acquire the GIL since it is a Python
callable. If the main thread is blocking on `callback` and holding the GIL,
this will result in a deadlock.
"""

def _callback_wrapper():
callback(tensor)

torch_xla._XLAC._on_ready_callback(tensor, _callback_wrapper)