3 changes: 3 additions & 0 deletions third_party/xla_client/computation_client.h
@@ -224,6 +224,9 @@ class ComputationClient {
absl::Span<const TensorSource> tensor_shards, std::string device,
xla::Shape shape) = 0;

// Copies `data->buffer` to a buffer on the `dst` device.
virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0;

// Reads the tensor literal values stored at TPU server sites, behind the
// supplied handles.
virtual std::vector<Literal> TransferFromServer(
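For context, a minimal caller-side sketch of the new entry point (illustrative only, not part of this diff; it assumes the existing ComputationClient::GetLocalDevices() and Data::device() accessors, and the include path may differ by build setup):

#include <string>
#include <vector>

#include "third_party/xla_client/computation_client.h"  // adjust path to your build

// Replicates `data` onto every local device with CopyToDevice; the source
// device keeps its existing buffer.
std::vector<xla::ComputationClient::DataPtr> ReplicateToLocalDevices(
    xla::ComputationClient* client, xla::ComputationClient::DataPtr data) {
  std::vector<xla::ComputationClient::DataPtr> replicas;
  for (const std::string& device : client->GetLocalDevices()) {
    replicas.push_back(device == data->device()
                           ? data  // already resident on this device
                           : client->CopyToDevice(data, device));
  }
  return replicas;
}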
22 changes: 22 additions & 0 deletions third_party/xla_client/pjrt_computation_client.cc
@@ -4,6 +4,7 @@

#include "absl/strings/ascii.h"
#include "absl/types/span.h"
#include "pjrt_computation_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -171,6 +172,27 @@ ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer(
return std::make_shared<PjRtShardedData>(device, shape, pjrt_data_shards);
}

ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
ComputationClient::DataPtr data, std::string dst) {
tensorflow::profiler::TraceMe activity(
"PjRtComputationClient::CopyToDevice",
tensorflow::profiler::TraceMeLevel::kInfo);
const PjRtData* pjrt_data = dynamic_cast<PjRtData*>(data.get());
XLA_CHECK(pjrt_data->HasValue()) << "Can't copy invalid device data.";

PjRtDevice* dst_device = StringToPjRtDevice(dst);
XLA_CHECK(dst_device->IsAddressable()) << dst << " is not addressable.";

// PjRtBuffer::CopyToDevice returns an error if the buffer is already on `dst_device`.
StatusOr<std::unique_ptr<PjRtBuffer>> status_or =
pjrt_data->buffer->CopyToDevice(dst_device);
XLA_CHECK(status_or.ok())
<< pjrt_data->device() << " buffer already exists on " << dst;

return std::make_shared<PjRtData>(dst, pjrt_data->shape(),
std::move(status_or.value()));
}

std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
absl::Span<const DataPtr> handles) {
metrics::TimedSection timed(TransferFromServerMetric());
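As the comment in CopyToDevice above notes, PjRtBuffer::CopyToDevice fails when the buffer already lives on the destination device. A hypothetical caller-side guard (not part of this diff; same includes as the sketch above) can short-circuit that case:

// Skips the copy when the data is already resident on `dst`; calling
// CopyToDevice in that case would trip the XLA_CHECK in the implementation.
xla::ComputationClient::DataPtr CopyIfNeeded(
    xla::ComputationClient* client, xla::ComputationClient::DataPtr data,
    const std::string& dst) {
  if (data->device() == dst) {
    return data;
  }
  return client->CopyToDevice(std::move(data), dst);
}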
2 changes: 2 additions & 0 deletions third_party/xla_client/pjrt_computation_client.h
@@ -32,6 +32,8 @@ class PjRtComputationClient : public ComputationClient {
DataPtr TransferShardsToServer(absl::Span<const TensorSource> tensor_shards,
std::string device, xla::Shape shape) override;

DataPtr CopyToDevice(DataPtr data, std::string dst) override;

std::vector<ComputationPtr> Compile(
std::vector<CompileInstance> instances) override;

4 changes: 4 additions & 0 deletions third_party/xla_client/xrt_computation_client.h
@@ -264,6 +264,10 @@ class XrtComputationClient : public ComputationClient {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

DataPtr CopyToDevice(DataPtr data, std::string dst) override {
XLA_ERROR() << __FUNCTION__ << " not implemented";
}

std::vector<Literal> TransferFromServer(
absl::Span<const DataPtr> handles) override;

31 changes: 3 additions & 28 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -195,38 +195,13 @@ ShardingUtil::InputHandler(
      int64_t source_device_i =
          ParseDeviceString(shards[0]->device()).ordinal();
      arguments_by_device[source_device_i][argument_i] = shards[0];

-      auto literal = std::make_shared<xla::Literal>(std::move(
-          xla::ComputationClient::Get()->TransferFromServer(shards)[0]));
-      std::vector<xla::ComputationClient::TensorSource> source_tensors;
-      for (int64_t device_i = 0; device_i < devices.size(); ++device_i) {
-        if (device_i != source_device_i) {
-          auto populate_fn =
-              [&](const xla::ComputationClient::TensorSource& source_tensor,
-                  void* dest_buffer, size_t dest_buffer_size) {
-                std::memcpy(dest_buffer, literal->untyped_data(),
-                            dest_buffer_size);
-              };
-          source_tensors.emplace_back(
-              xla::Shape(shards[0]->shape().ToProto()),
-              ParseDeviceString(absl::StrFormat(":%d", device_i)).toString(),
-              std::move(populate_fn));
-        }
-      }
-
-      std::vector<xla::ComputationClient::DataPtr> replicated_shards =
-          xla::ComputationClient::Get()->TransferToServer(source_tensors);
-      auto itr = replicated_shards.begin();
      for (int64_t device_i = 0; device_i < devices.size(); ++device_i) {
        if (device_i != source_device_i) {
-          arguments_by_device[device_i][argument_i] = *itr;
-          ++itr;
+          arguments_by_device[device_i][argument_i] =
+              xla::ComputationClient::Get()->CopyToDevice(shards[0],
+                                                          devices[device_i]);
        }
      }
-      XLA_CHECK(itr == replicated_shards.end())
-          << "Replicated arguments[" << argument_i << "] on "
-          << shards[0]->device() << " " << replicated_shards.size()
-          << " times (expected " << (devices.size() - 1) << ").";
    }
  }

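Net effect of the InputHandler hunk above (informal summary, not from the PR description): replicating an unsharded argument no longer round-trips through the host. The old path materialized the shard as a literal with TransferFromServer and re-uploaded it to each destination with TransferToServer; the new path issues one direct CopyToDevice per non-source device, which also removes the replicated_shards bookkeeping and its XLA_CHECK.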