|
3 | 3 |
|
4 | 4 | from torch_xla.experimental import plugins |
5 | 5 |
|
| 6 | +import sys |
| 7 | +import torch.distributed as dist |
| 8 | + |
| 9 | +from .neuron_utils import get_visible_cores_list, remap_visible_cores |
| 10 | + |
| 11 | +logging.basicConfig() |
| 12 | +logger = logging.getLogger(__name__) |
| 13 | + |
| 14 | + |
# Set root communication address/port
def set_rt_root_comm_id():
  """Derive NEURON_RT_ROOT_COMM_ID from MASTER_ADDR when it is not already set.

  Uses a fixed root port (62182) and defaults the address to ``localhost``
  when MASTER_ADDR is missing. IPv6 literals are wrapped in square brackets
  so the resulting ``addr:port`` string remains parseable.

  No-op if NEURON_RT_ROOT_COMM_ID is already present in the environment.
  """
  if os.environ.get('NEURON_RT_ROOT_COMM_ID', None) is None:
    if 'MASTER_ADDR' not in os.environ:
      # Fix: use the module-level logger (as the IPv6 branch below already
      # does) instead of the root logger via logging.warning.
      logger.warning(
          "MASTER_ADDR environment variable is not set, defaulting to localhost"
      )
    root_port = 62182
    root_addr = os.environ.get('MASTER_ADDR', 'localhost')
    # Heuristic: two or more ':' separators means an IPv6 literal
    # (an IPv4 "host:port" has at most one).
    is_ipv6 = len(root_addr.split(":")) >= 3
    if is_ipv6:
      modified = False
      if not root_addr.startswith("["):
        root_addr = "[" + root_addr
        modified = True
      if not root_addr.endswith("]"):
        root_addr = root_addr + "]"
        modified = True
      if modified:
        logger.warning(
            "IPv6 address detected for MASTER_ADDR and missing brackets added: {}"
            .format(root_addr))
    os.environ['NEURON_RT_ROOT_COMM_ID'] = '{}:{}'.format(root_addr, root_port)
| 38 | + |
| 39 | + |
def set_envvar_defaults():
  """Apply default values for environment variables the user left unset."""
  # 50 MB gradient all-reduce bucket unless explicitly overridden.
  if 'ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB' not in os.environ:
    os.environ['ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB'] = '50'
| 42 | + |
| 43 | + |
def configure_pjrt_environment():
  """
  Setting all necessary PJRT default environment variables.
  """
  from torch.distributed import is_torchelastic_launched

  # Set root communication address/port
  set_rt_root_comm_id()

  # Distributed PJRT settings only apply when GSPMD is off and the process
  # was launched via torchrun (torchelastic); otherwise nothing else to do.
  if os.environ.get('XLA_USE_SPMD', '0') == '1' or not is_torchelastic_launched():
    return

  # Env variables that only need to be set once.
  # NEURON_PJRT_PROCESSES_NUM_DEVICES is a list of core counts and is too long
  # for very large clusters, so use NEURON_PJRT_WORLD_SIZE to pass world size
  # and use a core count of 1 per process in the PJRT client.
  if ('NEURON_PJRT_PROCESSES_NUM_DEVICES' not in os.environ and
      'NEURON_PJRT_WORLD_SIZE' not in os.environ):
    if 'WORLD_SIZE' not in os.environ:
      logger.warning(
          'WORLD_SIZE environment variable not set, defaulting to 1.')
    os.environ["NEURON_PJRT_WORLD_SIZE"] = os.environ.get("WORLD_SIZE", "1")
    if 'LOCAL_WORLD_SIZE' not in os.environ:
      logger.warning(
          'LOCAL_WORLD_SIZE environment variable not set, defaulting to 1.')
    os.environ['PJRT_LOCAL_PROCESS_COUNT'] = os.environ.get(
        'LOCAL_WORLD_SIZE', '1')

  # Env variables that need to be set once per process.
  if os.environ.get('NEURON_RT_VISIBLE_CORES', None):
    # Cores were pinned explicitly; remap them for this local rank.
    remap_visible_cores(
        int(os.environ.get('LOCAL_RANK', '0')),
        int(os.environ.get('LOCAL_WORLD_SIZE', '1')))
  else:
    os.environ['NEURON_RT_VISIBLE_CORES'] = os.environ.get('LOCAL_RANK', '0')

  if 'RANK' not in os.environ:
    logger.warning('RANK environment variable is not set, defaulting to 0.')
  os.environ['NEURON_PJRT_PROCESS_INDEX'] = os.environ.get('RANK', '0')
  if 'LOCAL_RANK' not in os.environ:
    logger.warning(
        'LOCAL RANK environment variable is not set, defaulting to 0.')
  os.environ['PJRT_LOCAL_PROCESS_RANK'] = os.environ.get('LOCAL_RANK', '0')
| 85 | + |
6 | 86 |
|
def num_local_processes() -> int:
  """Return the local process count for the NEURONCORE_NUM_DEVICES launch path.

  Ensures the root communication id is configured, then reads
  NEURONCORE_NUM_DEVICES (defaulting to 1) and advertises one device per
  process through NEURON_PJRT_PROCESSES_NUM_DEVICES.
  """
  set_rt_root_comm_id()
  process_count = int(os.environ.get("NEURONCORE_NUM_DEVICES", "1"))
  os.environ['NEURON_PJRT_PROCESSES_NUM_DEVICES'] = ','.join(
      ['1'] * process_count)
  return process_count
19 | 93 |
|
20 | 94 |
|
# When torchrun is used, setting these environments causes the
# second instance in 2-node cluster to think it is node 0 instead of node 1.
# Need to skip these settings and let configure_pjrt_environment to
# set the distributed PJRT environment variables.
# If NEURONCORE_NUM_DEVICES is used, then go ahead and set the environments.
def initialize_env(local_rank, local_world_size):
  """Per-process environment setup for the non-torchrun launch path."""
  from torch.distributed import is_torchelastic_launched
  if is_torchelastic_launched():
    # torchrun path: configure_pjrt_environment owns these settings.
    return
  os.environ["NEURON_PJRT_PROCESS_INDEX"] = str(local_rank)
  if os.environ.get('NEURON_RT_VISIBLE_CORES', None):
    # Cores were pinned explicitly; remap them for this local rank.
    remap_visible_cores(local_rank, local_world_size)
  else:
    os.environ["NEURON_RT_VISIBLE_CORES"] = str(local_rank)
27 | 108 |
|
28 | 109 |
|
class NeuronPlugin(plugins.DevicePlugin):
  """PJRT device-plugin registration hooks for AWS Neuron."""

  def library_path(self):
    """Return the Neuron PJRT plugin library path.

    NEURON_LIBRARY_PATH, when set, overrides the packaged location.
    """
    from libneuronxla.libneuronpjrt_path import libneuronpjrt_path
    # Resolve the packaged default first, then let the env var win.
    default_path = libneuronpjrt_path()
    return os.environ.get("NEURON_LIBRARY_PATH", default_path)

  def configure_multiprocess(self, local_rank, local_world_size):
    """Configure this worker's environment before PJRT initialization."""
    initialize_env(local_rank, local_world_size)
|
0 commit comments