Skip to content

Commit 901c3a3

Browse files
authored
Update Neuron initializations (#7952)
1 parent 900296a commit 901c3a3

File tree

7 files changed

+250
-18
lines changed

7 files changed

+250
-18
lines changed

test/neuron/run_tests.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/bin/bash
2+
set -xue
3+
4+
python3 test/neuron/test_neuron_utils.py

test/neuron/test_neuron_utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
2+
import pytest
3+
import unittest
4+
from torch_xla._internal.neuron_utils import *
5+
6+
7+
class NeuronUtilsTest(unittest.TestCase):
8+
9+
def test_get_visible_cores_list(self):
10+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1"
11+
assert (get_visible_cores_list() == [1])
12+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1,2,3"
13+
assert (get_visible_cores_list() == [1, 2, 3])
14+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3"
15+
assert (get_visible_cores_list() == [1, 2, 3])
16+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3,5-8"
17+
assert (get_visible_cores_list() == [1, 2, 3, 5, 6, 7, 8])
18+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1,3,5-8"
19+
assert (get_visible_cores_list() == [1, 3, 5, 6, 7, 8])
20+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3,5-8,3-5"
21+
with pytest.raises(ValueError):
22+
get_visible_cores_list()
23+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3,5-8-5"
24+
with pytest.raises(ValueError):
25+
get_visible_cores_list()
26+
os.environ["NEURON_RT_VISIBLE_CORES"] = "a-b,5-8-5"
27+
with pytest.raises(Exception):
28+
get_visible_cores_list()
29+
os.environ["NEURON_RT_VISIBLE_CORES"] = "a"
30+
with pytest.raises(Exception):
31+
get_visible_cores_list()
32+
33+
def test_remap_visible_cores(self):
34+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1"
35+
remap_visible_cores(0, 1)
36+
assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "1")
37+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1,2,3"
38+
remap_visible_cores(1, 3)
39+
assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "2")
40+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3"
41+
remap_visible_cores(2, 3)
42+
assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "3")
43+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1-3,5-8"
44+
remap_visible_cores(5, 7)
45+
assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "7")
46+
os.environ["NEURON_RT_VISIBLE_CORES"] = "1,3,5-8"
47+
remap_visible_cores(5, 6)
48+
assert (os.environ['NEURON_RT_VISIBLE_CORES'] == "8")
49+
with pytest.raises(ValueError):
50+
remap_visible_cores(5, 9)
51+
with pytest.raises(ValueError):
52+
remap_visible_cores(6, 6)
53+
54+
55+
if __name__ == "__main__":
56+
test = unittest.main()
57+
sys.exit(0 if test.result.wasSuccessful() else 1)

test/pjrt/test_runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
8686
reload(torch_xla)
8787
logs_context = contextlib.nullcontext()
8888
if expect_using_pjrt:
89-
self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU'])
89+
self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU', 'NEURON'])
9090
else:
9191
self.assertIsNone(xr.device_type())
9292

torch_xla/__init__.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,31 @@ def _summarize_fn_tracker():
116116

117117
def _aws_ec2_inf_trn_init():
118118
try:
119-
from torch_neuronx import xla
119+
from libneuronxla.libneuronpjrt_path import libneuronpjrt_path
120120
except ImportError:
121-
return
121+
pass
122122
else:
123-
xla.init()
123+
# Need to set NEURON_LIBRARY_PATH here for proper Neuron Cache behavior
124+
os.environ.setdefault('NEURON_LIBRARY_PATH', libneuronpjrt_path())
125+
# Enable addition features and overrides
126+
try:
127+
from torch_neuronx import xla
128+
except ImportError:
129+
pass
130+
else:
131+
xla.init()
132+
133+
# Basic initializations if torch-neuronx is not available
134+
from ._internal import neuron
135+
if os.path.basename(sys.argv[0]) != 'neuron_parallel_compile':
136+
import libneuronxla
137+
libneuronxla.configure_environment()
138+
neuron.set_envvar_defaults()
139+
neuron.configure_pjrt_environment()
140+
# Found libneuronxla
141+
return True
142+
# Did not find libneuronxla
143+
return False
124144

125145

126146
def _setup_tpu_vm_library_path() -> bool:
@@ -179,7 +199,7 @@ def _check_deprecated_env_var():
179199
_found_libtpu = _setup_tpu_vm_library_path()
180200

181201
# Setup Neuron library for AWS EC2 inf/trn instances.
182-
_aws_ec2_inf_trn_init()
202+
_found_libneuronxla = _aws_ec2_inf_trn_init()
183203

184204

185205
def _prepare_to_exit():

torch_xla/_internal/neuron.py

Lines changed: 95 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,115 @@
33

44
from torch_xla.experimental import plugins
55

6+
import sys
7+
import torch.distributed as dist
8+
9+
from .neuron_utils import get_visible_cores_list, remap_visible_cores
10+
11+
logging.basicConfig()
12+
logger = logging.getLogger(__name__)
13+
14+
15+
# Set root communication address/port
16+
def set_rt_root_comm_id():
17+
if os.environ.get('NEURON_RT_ROOT_COMM_ID', None) is None:
18+
if 'MASTER_ADDR' not in os.environ:
19+
logging.warning(
20+
"MASTER_ADDR environment variable is not set, defaulting to localhost"
21+
)
22+
root_port = 62182
23+
root_addr = os.environ.get('MASTER_ADDR', 'localhost')
24+
is_ipv6 = len(root_addr.split(":")) >= 3
25+
if is_ipv6:
26+
modified = False
27+
if not root_addr.startswith("["):
28+
root_addr = "[" + root_addr
29+
modified = True
30+
if not root_addr.endswith("]"):
31+
root_addr = root_addr + "]"
32+
modified = True
33+
if modified:
34+
logger.warning(
35+
"IPv6 address detected for MASTER_ADDR and missing brackets added: {}"
36+
.format(root_addr))
37+
os.environ['NEURON_RT_ROOT_COMM_ID'] = '{}:{}'.format(root_addr, root_port)
38+
39+
40+
def set_envvar_defaults():
41+
os.environ.setdefault('ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB', '50')
42+
43+
44+
def configure_pjrt_environment():
45+
"""
46+
Setting all necessary PJRT default environment variables.
47+
"""
48+
from torch.distributed import is_torchelastic_launched
49+
50+
# Set root communication address/port
51+
set_rt_root_comm_id()
52+
53+
# Set env variables if we don't use GSPMD, using PJRT, and using torchrun
54+
if os.environ.get('XLA_USE_SPMD', '0') != '1' \
55+
and is_torchelastic_launched():
56+
# Env variables that only need to be set once
57+
# NEURON_PJRT_PROCESSES_NUM_DEVICES is a list of core counts and is too long for very large cluster,
58+
# so use NEURON_PJRT_WORLD_SIZE to pass world size and use core count of 1 per process in PJRT client.
59+
if 'NEURON_PJRT_PROCESSES_NUM_DEVICES' not in os.environ and 'NEURON_PJRT_WORLD_SIZE' not in os.environ:
60+
if 'WORLD_SIZE' not in os.environ:
61+
logger.warning(
62+
'WORLD_SIZE environment variable not set, defaulting to 1.')
63+
os.environ["NEURON_PJRT_WORLD_SIZE"] = os.environ.get("WORLD_SIZE", "1")
64+
if 'LOCAL_WORLD_SIZE' not in os.environ:
65+
logger.warning(
66+
'LOCAL_WORLD_SIZE environment variable not set, defaulting to 1.')
67+
os.environ['PJRT_LOCAL_PROCESS_COUNT'] = os.environ.get(
68+
'LOCAL_WORLD_SIZE', '1')
69+
70+
# Env variables that need to be set once per process
71+
if not os.environ.get('NEURON_RT_VISIBLE_CORES', None):
72+
os.environ['NEURON_RT_VISIBLE_CORES'] = os.environ.get('LOCAL_RANK', '0')
73+
else:
74+
local_rank = int(os.environ.get('LOCAL_RANK', '0'))
75+
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', '1'))
76+
remap_visible_cores(local_rank, local_world_size)
77+
78+
if 'RANK' not in os.environ:
79+
logger.warning('RANK environment variable is not set, defaulting to 0.')
80+
os.environ['NEURON_PJRT_PROCESS_INDEX'] = os.environ.get('RANK', '0')
81+
if 'LOCAL_RANK' not in os.environ:
82+
logger.warning(
83+
'LOCAL RANK environment variable is not set, defaulting to 0.')
84+
os.environ['PJRT_LOCAL_PROCESS_RANK'] = os.environ.get('LOCAL_RANK', '0')
85+
686

787
def num_local_processes() -> int:
8-
if 'MASTER_ADDR' not in os.environ:
9-
logging.warning("MASTER_ADDR not setting, defaulting to localhost")
10-
os.environ['NEURON_RT_ROOT_COMM_ID'] = '{}:{}'.format(
11-
os.environ.get('MASTER_ADDR', 'localhost'), '62182')
12-
if "NEURONCORE_NUM_DEVICES" not in os.environ:
13-
logging.warning("NEURONCORE_NUM_DEVICES not set, defaulting to 1")
88+
set_rt_root_comm_id()
1489
num_processes = int(os.environ.get("NEURONCORE_NUM_DEVICES", "1"))
1590
os.environ['NEURON_PJRT_PROCESSES_NUM_DEVICES'] = ','.join(
1691
['1' for _ in range(num_processes)])
17-
1892
return num_processes
1993

2094

95+
# When torchrun is used, setting these environments causes the
96+
# second instance in 2-node cluster to think it is node 0 instead of node 1.
97+
# Need to skip these settings and let configure_pjrt_environment to
98+
# set the distributed PJRT environment variables.
99+
# If NEURONCORE_NUM_DEVICES is used, then go ahead and set the environments.
21100
def initialize_env(local_rank, local_world_size):
22-
os.environ["NEURON_PJRT_PROCESS_INDEX"] = str(local_rank)
23-
assert (
24-
local_rank < local_world_size
25-
), "ERROR in initialize_env: PJRT_LOCAL_PROCESS_RANK is not less than PJRT_LOCAL_PROCESS_COUNT"
26-
os.environ["NEURON_RT_VISIBLE_CORES"] = str(local_rank)
101+
from torch.distributed import is_torchelastic_launched
102+
if not is_torchelastic_launched():
103+
os.environ["NEURON_PJRT_PROCESS_INDEX"] = str(local_rank)
104+
if not os.environ.get('NEURON_RT_VISIBLE_CORES', None):
105+
os.environ["NEURON_RT_VISIBLE_CORES"] = str(local_rank)
106+
else:
107+
remap_visible_cores(local_rank, local_world_size)
27108

28109

29110
class NeuronPlugin(plugins.DevicePlugin):
30111

31112
def library_path(self):
32-
return os.environ.get("NEURON_LIBRARY_PATH", "libneuronpjrt.so")
113+
from libneuronxla.libneuronpjrt_path import libneuronpjrt_path
114+
return os.environ.get("NEURON_LIBRARY_PATH", libneuronpjrt_path())
33115

34116
def configure_multiprocess(self, local_rank, local_world_size):
35117
initialize_env(local_rank, local_world_size)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import os
2+
import logging
3+
logging.basicConfig()
4+
logger = logging.getLogger(__name__)
5+
6+
7+
def convert_range(range_spec):
8+
try:
9+
lowerupper = list(map(int, range_spec.split("-")))
10+
except Exception as e:
11+
print(f"ERROR: Malformed range specs in NEURON_RT_VISIBLE_CORES;" +
12+
f"expecting <int> or <lower int>-<upper int> (got {range_spec})")
13+
raise e
14+
if len(lowerupper) > 2:
15+
raise ValueError(
16+
f"ERROR: Range specs in NEURON_RT_VISIBLE_CORES should be of " +
17+
f"the form <int> or <lower int>-<upper int> (got {range_spec})")
18+
if len(lowerupper) == 2:
19+
if lowerupper[0] > lowerupper[1]:
20+
raise ValueError(
21+
f"ERROR: Range specs in NEURON_RT_VISIBLE_CORES should " +
22+
f"be of the form <int> or <lower int>-<upper int> (got {range_spec})")
23+
lowerupper = range(lowerupper[0], lowerupper[1] + 1)
24+
return lowerupper
25+
26+
27+
def get_visible_cores_list():
28+
import os
29+
30+
range_list = os.environ.get("NEURON_RT_VISIBLE_CORES")
31+
cores_list = None
32+
if range_list:
33+
range_list = range_list.split(",")
34+
cores_list = []
35+
for i in range_list:
36+
new = convert_range(i)
37+
if (set(cores_list) & set(new)) != set():
38+
raise ValueError(
39+
"ERROR: Please ensure the ranges in NEURON_RT_VISIBLE_CORES are mutually exclusive."
40+
)
41+
cores_list.extend(new)
42+
return cores_list
43+
44+
45+
def remap_visible_cores(local_rank, local_world_size):
46+
cores_list = get_visible_cores_list()
47+
count = len(cores_list)
48+
assert (local_world_size > 0), "Local world size should be non-zero"
49+
if count <= 1 and local_world_size == 1:
50+
# Allow user to pass NEURON_RT_VISIBLE_CORES for sinlge-core workload
51+
pass
52+
elif local_world_size != count:
53+
raise ValueError(
54+
f"LOCAL_WORLD_SIZE (torchrun) or PJRT_LOCAL_PROCESS_COUNT (xmp.spawn) value of {local_world_size} "
55+
+
56+
f"is not equal to count {count} from NEURON_RT_VISIBLE_CORES {cores_list}"
57+
)
58+
elif local_rank >= count:
59+
raise ValueError(
60+
f"LOCAL_RANK (torchrun) or PJRT_LOCAL_PROCESS_RANK (xmp.spawn) value of {local_rank} is higher than "
61+
+ f"count {count} from NEURON_RT_VISIBLE_CORES {cores_list}")
62+
else:
63+
remapped_core = cores_list[local_rank]
64+
logger.warning(f"Remapping NEURON_RT_VISIBLE_CORES {cores_list} to " +
65+
f"NEURON_RT_VISIBLE_CORES[LOCAL_RANK]={remapped_core}")
66+
os.environ['NEURON_RT_VISIBLE_CORES'] = str(remapped_core)

torch_xla/runtime.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ def _maybe_select_default_device():
7272
+ num_devices_str)
7373
os.environ[xenv.PJRT_DEVICE] = 'CUDA'
7474
os.environ[xenv.GPU_NUM_DEVICES] = num_devices_str
75+
elif torch_xla._found_libneuronxla:
76+
logging.warning('Found libneuronpjrt.so. Setting PJRT_DEVICE=NEURON.')
77+
os.environ[xenv.PJRT_DEVICE] = 'NEURON'
7578
else:
7679
logging.warning('Defaulting to PJRT_DEVICE=CPU')
7780
os.environ[xenv.PJRT_DEVICE] = 'CPU'

0 commit comments

Comments
 (0)