[SPMD][PoC] compile & execute with PjRt #3684
Conversation
CPU test passes, but the GPU fails with a somewhat unrelated (at least on the surface) error. The master branch is green, though. cc @JackCaoG
Seems irrelevant, let me just restart the GPU CI.
| I will take another pass and try to merge it. |
```python
expected = t + t

xt = t.to(xm.xla_device())
n_devices = xm.xrt_world_size()
```
Does CI run this test, or do we only run it on TPU?
We only run the C++ tests, which cover the internal changes that affect the non-SPMD code paths; the Python API tests are disabled (link). I will re-enable them after debugging and adding the API unit tests.
```cpp
virtual void TransferToServer(absl::Span<const TensorSource> tensors,
                              absl::Span<const DataPtr> datas) = 0;

// Transfers local sharded tensor values to the TPU servers and returns a
```
I would use "TPU device" instead of "TPU server"; there is no server in the PJRT context.
Done
```cpp
void XLATensor::SetShardingSpec(const ShardingSpec& sharding_spec) {
  XLA_CHECK(GetIrValue().node != nullptr) << "Trying to access a null cursor";
  dynamic_cast<XlaNode*>(data()->ir_value.node.get())
  dynamic_cast<XlaNode*>(GetIrValue().node.get())
```
Hmm, we should add an XlaNodeCast to replace dynamic_cast<XlaNode*> so it is cleaner.
I see. I normally prefer more explicit type identifiers, especially for casting (similar to avoiding the overuse of auto).
```cpp
// TODO(yeounoh): Sharding annotation must be removed by explicit call to
// ClearSharding.
ShardingSpecPtr sharding = sharding_spec();
if (sharding != nullptr) {
```
We need a test for this. For example, when we deep-copy a tensor with sharding, the result tensor should also have sharding. Something similar to line 1670 in b3f79cc:

```python
y = copy.deepcopy(x)
```
@steventk-g can you add a test case?
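A minimal sketch of such a test case, assuming the experimental `xs.mark_sharding` entry point from this PR; the module path, the `mark_sharding` signature, and the `_get_xla_sharding_spec` accessor below are assumptions and may differ from the actual API:

```python
import copy

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs  # module path assumed


def test_deepcopy_preserves_sharding():
  n_devices = xm.xrt_world_size()
  x = torch.randn(n_devices, 4).to(xm.xla_device())
  # Tile dim 0 across a 1-D mesh of all local devices (spec format assumed).
  xs.mark_sharding(x, list(range(n_devices)), (0, None))
  y = copy.deepcopy(x)
  # The deep copy should carry over the source tensor's sharding annotation.
  # _get_xla_sharding_spec is an assumed accessor for the annotation.
  assert torch_xla._XLAC._get_xla_sharding_spec(y) == \
      torch_xla._XLAC._get_xla_sharding_spec(x)
```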
Yep, I've created an issue to track it: #4095.
Good point. @steventk-g, let me handle this if you haven't already started.
```cpp
auto cached_computation = std::make_shared<CachedComputation>(
    std::move(compile_result.computation));
    std::move(compile_result.computation), compile_result.is_sharded);
```
Why do we need is_sharded separately in CachedComputation?
We could pass is_sharded around between APIs, or wrap it inside CachedComputation. is_sharded is needed later for execution (it is associated with the cached computation only), and the latter approach doesn't require changing the function APIs all over the place.
Mostly LGTM. I had a question regarding ExecuteReplicated in #3684 (comment). If we can align on that, this PR is ready to merge.
Thanks @yeounoh! I will merge this PR to unblock @steventk-g.
* Create gspmd test
* Add experimental XLAShardedTensor and mark_sharding API
* Add ShardingSpec annotation to XLA tensor
* Add partitioning test
* Update sharding spec to support full replication & mesh sharding
* Add testing probe for partitioner
* Add spmd partitioner dependency
* Tensor sharding annotation and sharded HLO dumping function
* Add checks for _xla_mark_sharding
* Compile tensor ops with sharding annotations
* Make SpmdPartitioningPass in tensor.Compile()
* Use sharding custom_call & add more comments
* is_spmd device assignment for xrt_computation
* Disable GPU for SPMD
* Rebase master with LTC migration changes
* CreateXrtSpmdComputation only if SPMD is enabled in HloModule
* Remove xrt changes before landing the experimental feature
* Refactor experimental support for SPMD changes
* PjRt compile partitioned graph with SPMD partitioning option
* Introduce ExecuteReplicated in pjrt_computation_client
* Add ShardingUtil::InputHandler for input sharding
* Add `SPMD` XlaDeviceType & GetVirtualDevice()
* Add PjRtComputationClient::PjRtShardedData
* Add more unit tests to XLAShardingTest (C++)
* Add more unit tests to XLAShardingTest (Python)
* Allow `_xla_mark_sharding` to initiate sharded data transfer
* Refactor device transfers to use `BufferFromHostBuffer`
* Replace InlinedVector in TransferToServer
* Remove kEnvSpmdTest flag
* Fix mark_sharding bugs
* Allow XLATensor::SetShardingSpec to receive ShardingSpecPtr
* Use unpartitioned tensor shape and device for PjRtShardedData
* Disable partial sharding in mark_sharding
* Remove duplicate copies of sharding annotation
* Allow ToHlo to return partitioned HLO if sharded
* Fix lint errors
* Add/expand CreateTensorsData & InputHandler tests
* Add device assignment for SPMD compilation
* [SPMD] Refactor `_xla_partitioning_pass`
* [SPMD] Refactor `_xla_mark_sharding`
* [SPMD] Support higher-order mesh topology
* [SPMD] Inherit global tensor requires_grad in XLAShardedTensor
* [SPMD] Disable aliasing if is_spmd
* [SPMD] Use all devices instead of local
* [SPMD] Define clear_sharding interface
* [SPMD] Experiment with IR sharding preservation
* Rebase master
* Refactor and add comments

Co-authored-by: Will Cromar <wcromar@google.com>
This is a follow-up to #3476 and contributes to #3871. The changes include:
* `Compile` partitioned HLO computation graph with sharding annotations.
* `PjRtComputationClient` integration to support `SPMD` sharded operations.
* `PjRtShardedData` struct to represent sharded `Data`.
* `InputHandler` for parameter sharding and sharded data transfer.
* `ExecuteReplicated` for partitioned computation.
The PoC implementation supports `replicated` and `tiled` sharding annotations, and a single-host `xla:tpu` backend. This enables a simple sharded computation on v3-8, like the sketch below.
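A minimal sketch of such a computation (the original snippet was not captured here; the `xla_sharding` module path, `mark_sharding` signature, and partition-spec format are assumptions):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs  # module path assumed

# A single v3-8 host exposes 8 TPU cores.
t = torch.randn(8, 4).to(xm.xla_device())

# Tile dim 0 across a 1-D mesh of the 8 devices (signature assumed).
xs.mark_sharding(t, list(range(8)), (0, None))

# The addition compiles with the SPMD partitioning option and runs on all
# local devices via ExecuteReplicated; the result reads back as a global
# tensor.
expected = t + t
print(expected.cpu())
```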