Merged

33 commits
36c9d3d
fixed it
vanbasten23 Dec 1, 2023
e1164be
added the test
vanbasten23 Dec 4, 2023
61d4cda
remove prints
vanbasten23 Dec 4, 2023
a025c23
fix up
vanbasten23 Dec 4, 2023
cb37524
replace gpu with cuda
vanbasten23 Dec 5, 2023
93f1dc5
fix up
vanbasten23 Jan 4, 2024
b6c58ca
trigger another ci
vanbasten23 Jan 5, 2024
d724f64
fix the test_runtime.py test so it runs for GPU as well.
vanbasten23 Jan 5, 2024
1dd4261
remove pdb
vanbasten23 Jan 5, 2024
b39fbb2
skip a test temperarily.
vanbasten23 Jan 6, 2024
9a7b423
Fix the test XLAShardingTest.CreateTensorsData
vanbasten23 Jan 8, 2024
b90b342
skip another multithreading test for gpu
vanbasten23 Jan 8, 2024
ee243d0
remove prints
vanbasten23 Jan 9, 2024
3fa038f
remove more prints
vanbasten23 Jan 9, 2024
9823269
remove global_runtime_device_count test case
vanbasten23 Jan 9, 2024
b10570b
fix linter
vanbasten23 Jan 16, 2024
1832477
fix the broken tests
vanbasten23 Jan 16, 2024
d590e30
fix a build issue
vanbasten23 Jan 17, 2024
4d74168
fix linter
vanbasten23 Jan 17, 2024
b154538
fix build after pin update and a failing spmd test
vanbasten23 Jan 19, 2024
6bfdf19
fix linter
vanbasten23 Jan 19, 2024
c5e717e
fix comments
vanbasten23 Jan 20, 2024
0291311
Incorporate the fix cl/601517680
vanbasten23 Jan 26, 2024
3c774fe
add comment to the patch and fix linter
vanbasten23 Jan 27, 2024
862f7da
add a single processing gpu test.
vanbasten23 Jan 31, 2024
5d64007
add print in spmd tests. Local test works but fail in the CI
vanbasten23 Jan 31, 2024
d59457f
fix test and linter
vanbasten23 Feb 1, 2024
e8e1697
clean up prints
vanbasten23 Feb 1, 2024
db3d0c8
clean up another prints
vanbasten23 Feb 1, 2024
10e459b
fix linter
vanbasten23 Feb 1, 2024
aee08df
fix BasicShardingTest.test_2d_tensor_3d_mesh on cpu
vanbasten23 Feb 2, 2024
b3991a7
fix comments
vanbasten23 Feb 2, 2024
953d9b4
remove unwanted function
vanbasten23 Feb 3, 2024
1 change: 1 addition & 0 deletions WORKSPACE
@@ -49,6 +49,7 @@ http_archive(
"//openxla_patches:cache_urls.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_hanging.diff",
"//openxla_patches:quant_dequant_converter.diff",
"//openxla_patches:stablehlo_quant_seralization.diff",
],
36 changes: 36 additions & 0 deletions openxla_patches/gpu_hanging.diff
@@ -0,0 +1,36 @@
// This patch is for https://github.com/openxla/xla/commit/ec0177de1748b4ebb0ecbd6f26043fdb1eb47d24.
// It can be removed in the next openXLA pin update after 01/26/2024.
diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc
Collaborator: Can you add a comment with the corresponding git commit hash in openXLA?

Author: done

index 0f1818be2..c181f3025 100644
--- a/xla/service/gpu/gpu_executable.cc
+++ b/xla/service/gpu/gpu_executable.cc
@@ -382,9 +382,13 @@ absl::Status ExecuteThunks(const std::string& module_name,
}
}

- // Maybe join a round of rendezvous after thunk initialization.
- TF_RETURN_IF_ERROR(
- MaybeRendezvousAfterInitialization(run_options, thunks_initialized));
+ // Maybe join a round of rendezvous after thunk initialization. We do this
+ // only in presence of collective cliques which means that we have collective
+ // operations in the XLA operations that tend to cause deadlocks.
+ if (!collective_cliques.empty()) {
+ TF_RETURN_IF_ERROR(
+ MaybeRendezvousAfterInitialization(run_options, thunks_initialized));
+ }

// Prepare parameters for thunks execution.
Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create(
diff --git a/xla/service/gpu/thunk.h b/xla/service/gpu/thunk.h
index 51a566b8f..94bab421f 100644
--- a/xla/service/gpu/thunk.h
+++ b/xla/service/gpu/thunk.h
@@ -175,6 +175,8 @@ class Thunk {
absl::StatusOr<NcclComm::Lock> GetComm(const NcclCliqueKey& clique_key,
int32_t rank) const;

+ bool empty() const { return cliques_map_.empty(); }
+
private:
CliquesMap cliques_map_;
};
16 changes: 11 additions & 5 deletions test/cpp/test_replication.cpp
@@ -46,14 +46,17 @@ void TestSingleReplication(
instances.emplace_back(CreateCrsComputation(shape), device_str,
all_device_strings, &shape);
}
auto compiled_computations =
torch_xla::runtime::GetComputationClient()->Compile(std::move(instances));
std::vector<torch_xla::runtime::ComputationClient::ComputationPtr>
compiled_computations =
torch_xla::runtime::GetComputationClient()->Compile(
std::move(instances));

std::vector<at::Tensor> tensors;
for (size_t i = 0; i < device_strings.size(); ++i) {
tensors.push_back(at::ones({8, 8}, at::TensorOptions(at::kFloat)));
}
auto tensors_data = CreateTensorsData(tensors, device_strings);
std::vector<torch::lazy::BackendDataPtr> tensors_data =
CreateTensorsData(tensors, device_strings);

std::vector<std::vector<torch_xla::runtime::ComputationClient::DataPtr>>
results(device_strings.size());
@@ -75,7 +78,7 @@ void TestSingleReplication(
counter.Wait();

for (size_t i = 0; i < results.size(); ++i) {
auto literals =
std::vector<xla::Literal> literals =
torch_xla::runtime::GetComputationClient()->TransferFromDevice(
results[i]);
ASSERT_EQ(literals.size(), 1);
@@ -92,9 +95,12 @@ void TestSingleReplication(

class ReplicationTest : public AtenXlaTensorTestBase {};

// Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU
// device per process instead of relying on threads so we will not run the test
// on GPU.
TEST_F(ReplicationTest, TestNSingleReplication) {
WithAllDevices(
{XlaDeviceType::TPU, XlaDeviceType::CUDA},
{XlaDeviceType::TPU},
[&](const std::vector<torch::lazy::BackendDevice>& devices,
const std::vector<torch::lazy::BackendDevice>& all_devices) {
TestSingleReplication(devices, all_devices);
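
For context (not part of this diff): on CUDA the expected pattern is one process per GPU rather than one thread per device, typically launched via torch_xla.distributed.xla_multiprocessing. A minimal sketch, assuming a CUDA PJRT runtime is already configured (e.g. PJRT_DEVICE=CUDA):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  # Each spawned process owns exactly one CUDA device.
  device = xm.xla_device()
  print(index, device)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
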
5 changes: 0 additions & 5 deletions test/cpp/test_xla_sharding.cpp
@@ -309,11 +309,6 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
}

TEST_F(XLAShardingTest, CreateTensorsData) {
if (torch_xla::runtime::sys_util::GetEnvString(
torch_xla::runtime::env::kEnvPjRtDevice, "") == "") {
GTEST_SKIP() << "`PJRT_DEVICE` is not set.";
}

std::vector<at::Tensor> tensors(2);
auto tensor = at::ones({8, 8}, at::TensorOptions(at::kFloat));
xla::Shape tensor_shape =
test/pjrt/test_runtime_gpu.py → test/pjrt/test_runtime_multi_gpu.py
@@ -19,7 +19,7 @@

@unittest.skipIf(xr.device_type() != "CUDA",
f"GPU tests should only run on GPU devices.")
class TestExperimentalPjrtGpu(parameterized.TestCase):
class TestExperimentalPjrtMultiGpu(parameterized.TestCase):

def setUp(self):
xr.set_device_type('CUDA')
49 changes: 49 additions & 0 deletions test/pjrt/test_runtime_single_proc_gpu.py
@@ -0,0 +1,49 @@
import concurrent.futures
import itertools
import os
import queue
import requests
import unittest
import subprocess

import numpy as np
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla import runtime as xr
from torch_xla._internal import pjrt
from absl.testing import absltest, parameterized


@unittest.skipIf(xr.device_type() != "CUDA",
                 f"GPU tests should only run on GPU devices.")
class TestExperimentalSingleProcPjrtGpu(parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    command = 'nvidia-smi --list-gpus | wc -l'
    result = subprocess.run(
        command,
        capture_output=True,
        shell=True,
        check=True,
        text=True,
    )
    cls.num_cuda_devices = int(result.stdout)

  def test_num_local_devices(self):
    self.assertLen(xm.get_xla_supported_devices(),
                   xr.addressable_device_count())
    self.assertEqual(self.num_cuda_devices, xr.addressable_device_count())

  def test_num_global_devices(self):
    self.assertLen(torch_xla._XLAC._xla_get_all_devices(),
                   xr.global_device_count())
    self.assertEqual(self.num_cuda_devices, xr.global_device_count())


if __name__ == '__main__':
  absltest.main()
4 changes: 4 additions & 0 deletions test/pjrt/test_torchrun.py
@@ -16,6 +16,10 @@ def setUp(self):
def tearDown(self) -> None:
dist.destroy_process_group()

def test_addressable_device_count(self):
devices_per_process = xr.addressable_device_count()
self.assertEqual(devices_per_process, 1)

def test_all_gather(self):
dist_world_size = xu.getenv_as('WORLD_SIZE', int)
devices_per_thread = xr.addressable_device_count()
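
A rough sketch (not part of this PR) of the invariant the new test_addressable_device_count test asserts, assuming the script is launched under torchrun so that each rank is a separate process:

import torch_xla.runtime as xr
import torch_xla.utils.utils as xu

world_size = xu.getenv_as('WORLD_SIZE', int)  # set by torchrun
devices_per_process = xr.addressable_device_count()  # expected to be 1 per process
expected_world_size = world_size * devices_per_process
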
3 changes: 2 additions & 1 deletion test/run_tests.sh
@@ -167,7 +167,8 @@ function run_xla_op_tests1 {
run_test "$CDIR/test_hlo_metadata.py"
run_test "$CDIR/test_profiler.py"
run_test "$CDIR/pjrt/test_runtime.py"
run_test "$CDIR/pjrt/test_runtime_gpu.py"
run_test "$CDIR/pjrt/test_runtime_single_proc_gpu.py"
run_test "$CDIR/pjrt/test_runtime_multi_gpu.py"
run_test "$CDIR/pjrt/test_runtime_multi_cpu.py"
run_test "$CDIR/pjrt/test_internal_tpu.py"
run_test "$CDIR/pjrt/test_ddp.py"
55 changes: 30 additions & 25 deletions test/spmd/test_xla_sharding.py
@@ -207,10 +207,13 @@ def test_xla_sharding_type(self):
t = torch.randn(10, 20).to(xm.xla_device())
self.assertEqual(torch_xla._XLAC._get_xla_sharding_type(t), None)

x_dim = 2 if self.n_devices % 4 == 0 else 1
x_dim = 2 if self.n_devices >= 2 else 1
# if self.n_devices==4, mesh=(2,2)
# if self.n_devices==2, mesh=(2,1)
# if self.n_devices==1, mesh=(1,1)
mesh = self._get_mesh((x_dim, self.n_devices // x_dim))
xt = xs.mark_sharding(t, mesh, (0, 1))
if self.n_devices > 1:
if self.n_devices >= 2:
self.assertEqual(xt.sharding_type, xs.ShardingType.TILED)
else:
self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED)
@@ -221,7 +224,7 @@ def test_xla_sharding_type(self):

xs.clear_sharding(t)
xt = xs.mark_sharding(t, mesh, (None, 1))
if self.n_devices > 1:
if mesh.get_logical_mesh().shape[1] > 1:
self.assertEqual(xt.sharding_type, xs.ShardingType.PARTIAL)
else:
self.assertEqual(xt.sharding_type, xs.ShardingType.REPLICATED)
@@ -339,14 +342,13 @@ def test_mark_sharding_partial(self):
mesh = self._get_mesh((z_dim, self.n_devices // z_dim))
xt1 = xs.mark_sharding(t1, mesh, (0, None))

# partial replication requires >1 devices; otherwise, it's replicated.
if self.n_devices > 1:
# partial replication requires >= 4 devices; otherwise, it's replicated.
if self.n_devices >= 4:
# xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
self.assertTrue('last_tile_dim_replicate' in
torch_xla._XLAC._get_xla_sharding_spec(t1))
self.assertTrue('[%d,1,%d]' %
(z_dim, self.n_devices //
z_dim) in torch_xla._XLAC._get_xla_sharding_spec(t1))
self.assertIn('last_tile_dim_replicate',
torch_xla._XLAC._get_xla_sharding_spec(t1))
self.assertIn('[%d,1,%d]' % (z_dim, self.n_devices // z_dim),
torch_xla._XLAC._get_xla_sharding_spec(t1))
# replicated group should share the same data content.
if (self.n_devices // z_dim) > 1:
shards = xt1.local_shards
@@ -381,14 +383,13 @@ def test_mark_sharding_partial_unordered(self):
mesh = self._get_mesh((z_dim, 1, self.n_devices // z_dim))
xt1 = xs.mark_sharding(t1, mesh, (1, None, 0))

# partial replication requires >1 devices; otherwise, it's replicated.
if self.n_devices > 1:
# partial replication requires >= 4 devices; otherwise, it's replicated.
if self.n_devices >= 4:
# xt1 is sharded `z_dim`-way, replicated `n_devices/z_dim`-way.
self.assertTrue('last_tile_dim_replicate' in
torch_xla._XLAC._get_xla_sharding_spec(t1))
self.assertTrue('[1,1,%d,%d]' %
(z_dim, self.n_devices //
z_dim) in torch_xla._XLAC._get_xla_sharding_spec(t1))
self.assertIn('last_tile_dim_replicate',
torch_xla._XLAC._get_xla_sharding_spec(t1))
self.assertIn('[1,1,%d,%d]' % (z_dim, self.n_devices // z_dim),
torch_xla._XLAC._get_xla_sharding_spec(t1))
# replicated group should share the same data content.
if (self.n_devices // z_dim) > 1:
shards = xt1.local_shards
@@ -485,14 +486,14 @@ def test_partial_replication_addmm(self):
xs.mark_sharding(xw, mesh, (None, 1))

# Check if the partial replication annotations are passed to the compiler.
# Note that partial replication requires >1 devices; otherwise, it's replicated.
if self.n_devices > 1:
self.assertTrue('last_tile_dim_replicate' in
torch_xla._XLAC._get_xla_sharding_spec(xx))
self.assertTrue('last_tile_dim_replicate' in
torch_xla._XLAC._get_xla_sharding_spec(xw))
# Note that partial replication requires >= 4 devices; otherwise, it's replicated.
if self.n_devices >= 4:
self.assertIn('last_tile_dim_replicate',
torch_xla._XLAC._get_xla_sharding_spec(xx))
self.assertIn('last_tile_dim_replicate',
torch_xla._XLAC._get_xla_sharding_spec(xw))
actual = (xx @ xw + xb).cpu()
self.assertTrue(torch.allclose(expected, actual))
self.assertTrue(torch.allclose(expected, actual, atol=1e-5))

def test_clear_sharding(self):
xt = torch.randn(2, 4, 8, 16).to(xm.xla_device())
@@ -723,10 +724,14 @@ def test_2d_tensor_3d_mesh(self):
# Meaningful test for higher-order mesh with extra replication
# requires multiple devices. Otherwise, this should defaults back to
# full replication.
if self.n_devices > 1:
if self.n_devices >= 4:
mesh = self._get_mesh((2, self.n_devices // 2, 1))
xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
sharding_annotation = 'sharding={devices=[1,%d,2]' % (self.n_devices // 2)
elif self.n_devices == 2:
mesh = self._get_mesh((2, 1, 1))
xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
sharding_annotation = "sharding={replicated}"
else:
mesh = self._get_mesh((1, 1, 1))
xs.mark_sharding(t1, mesh, partition_spec=(2, 1))
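
An illustrative sketch (not part of this diff) of the mesh shapes the updated test_xla_sharding_type distinguishes; it assumes an SPMD-enabled runtime and the torch_xla.distributed.spmd API:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n_devices = xr.global_runtime_device_count()

# Mirrors the updated test: mesh is (2, 2) for 4 devices, (2, 1) for 2, (1, 1) for 1.
x_dim = 2 if n_devices >= 2 else 1
mesh = xs.Mesh(list(range(n_devices)), (x_dim, n_devices // x_dim))

t = torch.randn(10, 20).to(xm.xla_device())
xt = xs.mark_sharding(t, mesh, (0, 1))  # TILED when n_devices >= 2, REPLICATED otherwise
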
2 changes: 0 additions & 2 deletions test/test_core_aten_ops.py
@@ -16,8 +16,6 @@ def diff_output(testcase, output1, output2, rtol, atol, equal_nan=True):
output2_cpu = output2.detach().cpu()
if output2_cpu.dtype != output1.dtype:
output2_cpu = output2_cpu.to(output1.dtype)
# import pdb
# pdb.set_trace()
testcase.assertTrue(
torch.allclose(
output1, output2_cpu, atol=atol, rtol=rtol, equal_nan=equal_nan))
11 changes: 9 additions & 2 deletions test/test_operations.py
@@ -237,9 +237,14 @@ def forward(self, x):
return F.log_softmax(x, dim=1)


@unittest.skipIf(
xr.device_type() == 'CUDA',
'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.'
)
class TestParallelTensorMNIST(test_utils.XlaTestCase):

def test(self):
# devices=['xla:0', 'xla:1', 'xla:2', 'xla:3'] for example.
devices = xm.get_xla_supported_devices()
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=8)
sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
@@ -267,6 +272,10 @@ def loop_fn(model, loader, device, context):
model_parallel(loop_fn, train_loader)


@unittest.skipIf(
xr.device_type() == 'CUDA',
'Parallelism for DataParallel uses multi-threads. But cuda assumes one GPU device per process instead of relying on threads.'
)
class TestParallelTensorResnet18(test_utils.XlaTestCase):

def test(self):
@@ -1247,8 +1256,6 @@ def test_fn(a):

self.runAtenTest(torch.zeros([4, 4]), test_fn)

@unittest.skipIf(xr.device_type() == 'GPU',
"This test fails only on GPU with 07/05 XLA pin update.")
def test_stack_pred(self):

def test_fn(a):
12 changes: 11 additions & 1 deletion torch_xla/core/xla_model.py
@@ -87,7 +87,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
that kind.

Returns:
The list of device strings.
The list of device strings such as ['xla:0', 'xla:1', ...]
"""
# TODO(wcromar): Remove `devkind` after 2.3 release cut. We no longer support
# multiple device types.
@@ -217,6 +217,14 @@ def _xla_real_device(device):


def xla_real_devices(devices: Optional[List[torch.device]] = None):
"""Returns the real devices' name.

Args:
devices: The list of torch devices such as ['xla:0', 'xla:1'].

Returns:
A list of real devices' name such as ['CUDA:0', 'CUDA:1'].
"""
if not devices:
devices = get_xla_supported_devices()

@@ -257,6 +265,7 @@ def xla_replication_devices(local_devices):
format(len(local_devices), len(kind_devices)))
replication_devices = []
for device in torch_xla._XLAC._xla_get_all_devices():
# device is like 'CUDA:0'
xdev = parse_xla_device(device)
if not xdev:
raise RuntimeError('Invalid device format: {}'.format(device))
@@ -284,6 +293,7 @@ def set_replication(device, devices):
devctx = _get_device_context(device=device)
devices = [str(x) for x in devices]
if devices:
# sample replication_devices: ['CUDA:0', 'CUDA:1', 'CUDA:2', 'CUDA:3']
replication_devices = xla_replication_devices(devices)
torch_xla._XLAC._xla_set_replication_devices(replication_devices)
devctx.device_index = devices.index(device)
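
A small usage sketch (not part of this diff) of the mapping the new docstrings describe, assuming a CUDA PJRT runtime with two visible GPUs; the printed values are illustrative:

import torch_xla.core.xla_model as xm

logical_devices = xm.get_xla_supported_devices()  # e.g. ['xla:0', 'xla:1']
real_devices = xm.xla_real_devices(logical_devices)  # e.g. ['CUDA:0', 'CUDA:1']
print(logical_devices, real_devices)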