
Conversation

@yeounoh (Contributor) commented on Jul 6, 2022

This is a follow-up to #3476 and contributes to #3871. The changes include:

  • Compile partitioned HLO computation graph with sharding annotations.
  • PjRtComputationClient integration to support SPMD sharded operations.
  • PjRtShardedData struct to represent sharded Data (see the sketch after the example below).
  • InputHandler for parameter sharding and sharded data transfer.
  • Remove duplicate copies of sharding annotations.
  • ExecuteReplicated for partitioned computation.

The PoC implementation supports replicated and tiled sharding annotations on the single-host xla:tpu backend. This enables a simple sharded computation on a v3-8, like:

t1 = torch.randn(1, 128, device='cpu')
t2 = torch.randn(1, 128, device='cpu')
expected = t1 @ t2.T

xt1 = t1.to(xm.xla_device())
xt2 = t2.to(xm.xla_device())
xs.mark_sharding(xt1, (1, 8), (0, 1))
self.assertEqual('{devices=[1,8]0,1,2,3,4,5,6,7}',
                 torch_xla._XLAC._get_xla_sharding_spec(xt1))

actual = (xt1 @ xt2.T).cpu()
self.assertTrue(torch.allclose(expected, actual))
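As a rough illustration of the PjRtShardedData item in the change list above, here is a minimal sketch; the constructor and field names are assumptions for illustration, not the exact definition added in this PR's pjrt_computation_client changes:

// Hypothetical sketch: a sharded Data handle that keeps the logical
// (unpartitioned) shape and the virtual device, plus one per-device buffer
// holding that device's shard.
struct PjRtShardedData : public ComputationClient::Data {
  PjRtShardedData(std::string device, xla::Shape shape,
                  std::vector<std::shared_ptr<xla::PjRtBuffer>> shards)
      : Data(std::move(device), std::move(shape)), shards(std::move(shards)) {}

  // Per-device shards, ordered to match the local device list.
  std::vector<std::shared_ptr<xla::PjRtBuffer>> shards;
};

Conceptually, each shard lives on one addressable device while the wrapper presents the unpartitioned shape and the virtual device to the rest of the stack.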
@yeounoh added the DO_NOT_MERGE (Not for merging) and distributed (SPMD and other distributed things) labels on Jul 6, 2022
@yeounoh self-assigned this on Jul 6, 2022
@yeounoh marked this pull request as a draft on July 6, 2022 01:14
@yeounoh force-pushed the xla_spmd_pjrt_integration branch 5 times, most recently from d38bd7d to 09f4640 on July 11, 2022 07:18
@yeounoh force-pushed the xla_spmd_pjrt_integration branch from 09f4640 to 5e07428 on July 13, 2022 20:13
@yeounoh force-pushed the xla_spmd_pjrt_integration branch 4 times, most recently from c9399ac to 91262bc on July 23, 2022 00:14
@yeounoh force-pushed the xla_spmd_pjrt_integration branch 15 times, most recently from 0ca964c to c26f94b on July 26, 2022 04:28
@yeounoh force-pushed the xla_spmd_pjrt_integration branch 3 times, most recently from 2c62b19 to 3fb72d0 on October 13, 2022 06:47
@yeounoh force-pushed the xla_spmd_pjrt_integration branch from 3fb72d0 to bfa3b85 on October 13, 2022 06:49
@yeounoh (Contributor, Author) commented on Oct 13, 2022

The CPU test passes, but the GPU test fails with the following seemingly unrelated (at least on the surface) error:

*** Begin stack trace ***
    tsl::CurrentStackTrace[abi:cxx11]()
    gsignal
    abort
    xla::XrtLocalService::XrtLocalService(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, int)
    xla::XrtComputationClient::MaybeCreateLocalService(xla::XrtComputationClient::Options const&)
    xla::XrtComputationClient::XrtComputationClient(xla::XrtComputationClient::Options, std::unique_ptr<tensorflow::tpu::TopologyProto, std::default_delete<tensorflow::tpu::TopologyProto> >)
    xla::ComputationClient::Create()
    xla::ComputationClient::Get()
    _PyMethodDef_RawFastCallKeywords
    _PyEval_EvalFrameDefault
    _PyFunction_FastCallKeywords
    _PyEval_EvalFrameDefault
    _PyFunction_FastCallDict
    _PyObject_GenericGetAttrWithDict
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_FastCallKeywords
    _PyEval_EvalFrameDefault
    _PyEval_EvalFrameDefault
    _PyFunction_FastCallKeywords
    _PyEval_EvalFrameDefault
    _PyFunction_FastCallKeywords
    _PyEval_EvalFrameDefault
    _PyFunction_FastCallDict
    _PyEval_EvalFrameDefault
    _PyFunction_FastCallDict
    _PyEval_EvalFrameDefault
    _PyFunction_FastCallKeywords
    _PyEval_EvalFrameDefault
    _PyFunction_FastCallKeywords
    _PyEval_EvalFrameDefault
    _PyFunction_FastCallKeywords
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    _PyFunction_FastCallKeywords
    _PyEval_EvalFrameDefault
    _PyEval_EvalCodeWithName
    PyEval_EvalCode
    PyRun_StringFlags
    PyRun_SimpleStringFlags
    _Py_UnixMain
    __libc_start_main
*** End stack trace ***

Traceback (most recent call last):
  File "/tmp/pytorch/xla/test/test_torch_distributed_multi_all_reduce_xla_backend.py", line 38, in <module>
    xmp.spawn(_mp_fn, args=())
  File "/opt/conda/lib/python3.7/site-packages/torch_xla-1.14-py3.7-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 399, in spawn
    start_method=start_method)
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 146, in join
    signal_name=name
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGABRT

The master branch is green, though. cc @JackCaoG

@JackCaoG (Collaborator) commented:
Seems irrelevant, let me just restart the GPU CI.

@JackCaoG (Collaborator) commented:
I will take another pass and try to merge it.

@yeounoh (Contributor, Author) commented on Oct 13, 2022

> Seems irrelevant, let me just restart the GPU CI.

Yeah, this one succeeded. Thanks @JackCaoG

@JackCaoG self-requested a review on October 13, 2022 22:39
expected = t + t

xt = t.to(xm.xla_device())
n_devices = xm.xrt_world_size()
Collaborator:

Does CI run this test, or do we only run it on TPU?

Contributor (Author):

We only run the C++ tests -- they cover the internal changes that affect the non-SPMD code paths -- and the Python API tests are disabled (link). I will re-enable them after debugging and adding the API unit tests.

virtual void TransferToServer(absl::Span<const TensorSource> tensors,
                              absl::Span<const DataPtr> datas) = 0;

// Transfers local sharded tensor values to the TPU servers and returns a
Collaborator:

I would use "TPU device" instead of "TPU server"; there is no server in the PJRT context.

Contributor (Author):

Done
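For context, a minimal sketch of the kind of sharded-transfer entry point being discussed. The method name TransferShardsToServer and its parameter list here are assumptions for illustration, not necessarily the exact signature added in this PR:

// Hypothetical sketch: transfers local shards to the addressable TPU devices
// (not a "TPU server") and returns a single handle to the logical sharded data.
virtual DataPtr TransferShardsToServer(
    absl::Span<const TensorSource> tensor_shards,  // one shard per local device
    const std::string& device,                     // virtual SPMD device
    const xla::Shape& shape,                       // unpartitioned (global) shape
    const xla::OpSharding& sharding) = 0;          // sharding annotation

The wording change only affects the comment: it should describe transferring shards to the TPU devices, since PJRT has no separate server process.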

void XLATensor::SetShardingSpec(const ShardingSpec& sharding_spec) {
  XLA_CHECK(GetIrValue().node != nullptr) << "Trying to access a null cursor";
  dynamic_cast<XlaNode*>(data()->ir_value.node.get())
  dynamic_cast<XlaNode*>(GetIrValue().node.get())
Collaborator:

Hmm, we should add an XlaNodeCast to replace dynamic_cast<XlaNode*> so it is cleaner.

Contributor (Author):

I see. I normally prefer more explicit type identifiers, especially for casting (similar to avoiding overuse of auto).
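For reference, a minimal sketch of the helper suggested above. This is hypothetical: the name XlaNodeCast comes from the review comment, and the base node type is assumed to be torch::lazy::Node:

// Hypothetical helper: wraps dynamic_cast so call sites read
// XlaNodeCast(node) instead of dynamic_cast<XlaNode*>(node).
// Returns nullptr when the node is not an XlaNode, mirroring
// dynamic_cast semantics.
inline XlaNode* XlaNodeCast(torch::lazy::Node* node) {
  return dynamic_cast<XlaNode*>(node);
}

// Illustrative call site in SetShardingSpec:
//   auto* xla_node = XlaNodeCast(GetIrValue().node.get());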

// TODO(yeounoh): Sharding annotation must be removed by explicit call to
// ClearSharding.
ShardingSpecPtr sharding = sharding_spec();
if (sharding != nullptr) {
Collaborator:

We need a test for this. For example, when we deep copy a tensor with sharding, the result tensor should also have sharding. Something similar to:

y = copy.deepcopy(x)

@steventk-g can you add a test case?

@steventk-g (Collaborator) commented on Oct 14, 2022

Yep, I've created an issue to track it: #4095

Contributor (Author):

Good point. @steventk-g, let me handle this if you haven't already started.


auto cached_computation = std::make_shared<CachedComputation>(
std::move(compile_result.computation));
std::move(compile_result.computation), compile_result.is_sharded);
Collaborator:

Why do we need is_sharded separately in CachedComputation?

Contributor (Author):

We could either pass is_sharded around between APIs or wrap it inside the CachedComputation. is_sharded is needed later at execution time (it is associated with the cached computation only), and the latter option doesn't require changing the function APIs here and there.
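To make that concrete, here is a rough sketch of the wrapped option (hypothetical; the field and type names are assumptions, not the exact definition in this PR):

// Hypothetical sketch: the cached computation carries is_sharded so the
// execution path can later pick the sharded route (e.g. ExecuteReplicated)
// without threading an extra flag through every intermediate API.
struct CachedComputation {
  CachedComputation(ComputationPtr computation, bool is_sharded = false)
      : computation(std::move(computation)), is_sharded(is_sharded) {}

  ComputationPtr computation;
  bool is_sharded;
};

With this shape, the Compile() path in the diff above simply forwards compile_result.is_sharded into the cache entry, and execution can branch on the flag.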

@JackCaoG (Collaborator) left a comment:

Mostly LGTM. I had a question regarding ExecuteReplicated in #3684 (comment). If we can align on that, this PR is ready to merge.

@JackCaoG (Collaborator) left a comment:

Thanks @yeounoh! I will merge this PR to unblock @steventk-g.

@JackCaoG merged commit 0f23514 into master on Oct 17, 2022
@steventk-g pushed a commit that referenced this pull request on Oct 18, 2022:
  • Create gspmd test
  • Add experimental XLAShardedTensor and mark_sharding API
  • Add ShardingSpec annotation to XLA tensor
  • Add partitioning test
  • Update sharding spec to support full replication & mesh sharding
  • Add testing probe for partitioner
  • Add spmd partitioner dependency
  • Tensor sharding annotation and sharded HLO dumping function.
  • Add checks for _xla_mark_sharding
  • Compile tensor ops with sharding annotations
  • Make SpmdPartitioningPass in tensor.Compile()
  • Use sharding custom_call & add more comments
  • is_spmd device assignment for xrt_computation
  • Disable GPU for SPMD
  • Rebasing master with ltc migration changes
  • CreateXrtSpmdComputation only if spmd is enabled in HloModule
  • Remove xrt changes before landing the experimental feature.
  • Refactor experimental support for SPMD changes
  • Update sharding spec to support full replication & mesh sharding
  • Tensor sharding annotation and sharded HLO dumping function.
  • Add checks for _xla_mark_sharding
  • Compile tensor ops with sharding annotations
  • Make SpmdPartitioningPass in tensor.Compile()
  • Rebasing master with ltc migration changes
  • CreateXrtSpmdComputation only if spmd is enabled in HloModule
  • PjRt compile partitioned graph with SPMD partitioning option
  • Introduce ExecuteReplicated in pjrt_computation_client
  • Add ShardingUtil::InputHandler for input sharding
  • Add `SPMD` XlaDeviceType & GetVirtualDevice()
  • Add PjRtComputationClient::PjRtShardedData
  • Add more unit tests to XLAShardingTest (C++)
  • Add more unit tests to XLAShardingTest (Python)
  • Allow `_xla_mark_sharding` to initiate sharded data transfer
  • Refactor device transfers to use `BufferFromHostBuffer`
  • Replace InlinedVector in TransferToServer
  • Remove kEnvSpmdTest flag
  • Fix mark_sharding bugs
  • Allow XLATensor::SetShardingSpec to receive ShardingSpecPtr
  • Use unpartitioned tensor shape and device for PjRtShardedData.
  • Disable partial sharding in mark_sharding
  • Remove duplicate copies of sharding annotation
  • Allow ToHlo to return partitioned HLO if sharded
  • Fix lint errors
  • Add/expand CreateTensorsData & InputHandler tests
  • Add device assignment for SPMD compilation
  • [SPMD] Refactor `_xla_partitioning_pass`.
  • [SPMD] Refactor `_xla_mark_sharding`.
  • [SPMD] Support higher-order mesh topology.
  • [SPMD] inherit global tensor requires_grad in XLAShardedTensor
  • [SPMD] Disable aliasing if is_spmd.
  • [SPMD] Use all devices instead of local
  • [SPMD] Define clear_sharding interface
  • [SPMD] experiment with IR sharding preservation.
  • Rebase master
  • Refactor and add comments

Co-authored-by: Will Cromar <wcromar@google.com>