Skip to content

Commit 322ffad

Browse files
Fix spot tpu bug (skypilot-org#1717)
* debug * shorten * debug * fix many bugs * undo debug * comment * bug.. * fix * fix comments * comments * fix * fix multinode * fix * refactor * type * Update sky/spot/recovery_strategy.py Co-authored-by: Zhanghao Wu <zhanghao.wu@outlook.com> * update --------- Co-authored-by: Zhanghao Wu <zhanghao.wu@outlook.com>
1 parent 3654d72 commit 322ffad

File tree

3 files changed

+94
-58
lines changed

3 files changed

+94
-58
lines changed

sky/backends/backend_utils.py

Lines changed: 89 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -996,6 +996,37 @@ def get_timestamp_from_run_timestamp(run_timestamp: str) -> float:
996996
run_timestamp.partition('-')[2], '%Y-%m-%d-%H-%M-%S-%f').timestamp()
997997

998998

999+
def _count_healthy_nodes_from_ray(output: str,
1000+
is_local_cloud: bool = False
1001+
) -> Tuple[int, int]:
1002+
"""Count the number of healthy nodes from the output of `ray status`."""
1003+
1004+
def get_ready_nodes(pattern, output):
1005+
result = pattern.findall(output)
1006+
# On-prem/local case is handled differently.
1007+
# `ray status` produces different output for local case, and
1008+
# we poll for number of nodes launched instead of counting for
1009+
# head and number of worker nodes separately (it is impossible
1010+
# to distinguish between head and worker node for local case).
1011+
if is_local_cloud:
1012+
# In the local case, ready_workers mean the total number
1013+
# of nodes launched, including head.
1014+
return len(result)
1015+
if len(result) == 0:
1016+
return 0
1017+
assert len(result) == 1, result
1018+
return int(result[0])
1019+
1020+
if is_local_cloud:
1021+
ready_head = 0
1022+
ready_workers = get_ready_nodes(_LAUNCHED_LOCAL_WORKER_PATTERN, output)
1023+
else:
1024+
ready_head = get_ready_nodes(_LAUNCHED_HEAD_PATTERN, output)
1025+
ready_workers = get_ready_nodes(_LAUNCHED_WORKER_PATTERN, output)
1026+
assert ready_head <= 1, f'#head node should be <=1 (Got {ready_head}).'
1027+
return ready_head, ready_workers
1028+
1029+
9991030
@timeline.event
10001031
def wait_until_ray_cluster_ready(
10011032
cluster_config_file: str,
@@ -1035,32 +1066,8 @@ def wait_until_ray_cluster_ready(
10351066
stderr)
10361067
logger.debug(output)
10371068

1038-
# Workers that are ready
1039-
ready_workers = 0
1040-
# On-prem/local case is handled differently.
1041-
# `ray status` produces different output for local case, and
1042-
# we poll for number of nodes launched instead of counting for
1043-
# head and number of worker nodes separately (it is impossible
1044-
# to distinguish between head and worker node for local case).
1045-
if is_local_cloud:
1046-
result = _LAUNCHED_LOCAL_WORKER_PATTERN.findall(output)
1047-
# In the local case, ready_workers mean the total number
1048-
# of nodes launched, including head.
1049-
ready_workers = len(result)
1050-
else:
1051-
result = _LAUNCHED_WORKER_PATTERN.findall(output)
1052-
if len(result) == 0:
1053-
ready_workers = 0
1054-
else:
1055-
assert len(result) == 1, result
1056-
ready_workers = int(result[0])
1057-
1058-
result = _LAUNCHED_HEAD_PATTERN.findall(output)
1059-
ready_head = 0
1060-
if result:
1061-
assert len(result) == 1, result
1062-
ready_head = int(result[0])
1063-
assert ready_head <= 1, ready_head
1069+
ready_head, ready_workers = _count_healthy_nodes_from_ray(
1070+
output, is_local_cloud=is_local_cloud)
10641071

10651072
worker_status.update('[bold cyan]'
10661073
f'{ready_workers} out of {num_nodes - 1} '
@@ -1445,16 +1452,16 @@ def _get_tpu_vm_pod_ips(ray_config: Dict[str, Any],
14451452
'Hint: make sure it is not leaked.')
14461453
continue
14471454

1448-
if not get_internal_ips:
1449-
ips = [
1450-
endpoint['accessConfig']['externalIp']
1451-
for endpoint in tpuvm_json['networkEndpoints']
1452-
]
1453-
else:
1454-
ips = [
1455-
endpoint['ipAddress']
1456-
for endpoint in tpuvm_json['networkEndpoints']
1457-
]
1455+
ips = []
1456+
for endpoint in tpuvm_json['networkEndpoints']:
1457+
# Note: if TPU VM is being preempted, its IP field may not exist.
1458+
# We use get() to avoid KeyError.
1459+
if get_internal_ips:
1460+
ip = endpoint.get('ipAddress', None)
1461+
else:
1462+
ip = endpoint['accessConfig'].get('externalIp', None)
1463+
if ip is not None:
1464+
ips.append(ip)
14581465
all_ips.extend(ips)
14591466

14601467
return all_ips
@@ -1762,6 +1769,8 @@ def _update_cluster_status_no_lock(
17621769
return record
17631770

17641771
cluster_name = handle.cluster_name
1772+
use_spot = handle.launched_resources.use_spot
1773+
ray_cluster_up = False
17651774
try:
17661775
# TODO(zhwu): This function cannot distinguish transient network error
17671776
# in ray's get IPs vs. ray runtime failing.
@@ -1770,32 +1779,57 @@ def _update_cluster_status_no_lock(
17701779
if external_ips is None or len(external_ips) == 0:
17711780
raise exceptions.FetchIPError(
17721781
reason=exceptions.FetchIPError.Reason.HEAD)
1773-
if handle.launched_nodes == 1:
1774-
# Check the ray cluster status. We have to check it for single node
1775-
# case, since the get_node_ips() does not require ray cluster to be
1776-
# running.
1777-
ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml)
1778-
runner = command_runner.SSHCommandRunner(external_ips[0],
1779-
**ssh_credentials)
1780-
returncode = runner.run('ray status', stream_logs=False)
1781-
if returncode:
1782-
raise exceptions.FetchIPError(
1783-
reason=exceptions.FetchIPError.Reason.HEAD)
1784-
# If we get node ips correctly, the cluster is UP. It is safe to
1785-
# set the status to UP, as the `handle.external_ips` function uses ray
1786-
# to fetch IPs and starting ray is the final step of sky launch.
1782+
# Check if ray cluster status is healthy.
1783+
ssh_credentials = ssh_credential_from_yaml(handle.cluster_yaml)
1784+
runner = command_runner.SSHCommandRunner(external_ips[0],
1785+
**ssh_credentials)
1786+
rc, output, _ = runner.run('ray status',
1787+
stream_logs=False,
1788+
require_outputs=True,
1789+
separate_stderr=True)
1790+
if rc:
1791+
raise exceptions.FetchIPError(
1792+
reason=exceptions.FetchIPError.Reason.HEAD)
1793+
1794+
ready_head, ready_workers = _count_healthy_nodes_from_ray(output)
1795+
1796+
if ready_head + ready_workers == handle.launched_nodes:
1797+
ray_cluster_up = True
1798+
1799+
# For non-spot clusters:
1800+
# If ray status shows all nodes are healthy, it is safe to set
1801+
# the status to UP as starting ray is the final step of sky launch.
1802+
# For spot clusters, the above can be unsafe because the Ray cluster
1803+
# may remain healthy for a while before the cloud completely
1804+
# preempts the VMs.
1805+
# Additionally, we query the VM state from the cloud provider.
1806+
if ray_cluster_up and not use_spot:
1807+
record['status'] = global_user_state.ClusterStatus.UP
1808+
global_user_state.add_or_update_cluster(cluster_name,
1809+
handle,
1810+
requested_resources=None,
1811+
ready=True,
1812+
is_launch=False)
1813+
return record
1814+
except exceptions.FetchIPError:
1815+
logger.debug('Refreshing status: Failed to get IPs from cluster '
1816+
f'{cluster_name!r}, trying to fetch from provider.')
1817+
# For all code below, we query cluster status by cloud CLI for two cases:
1818+
# 1) ray fails to get IPs for the cluster.
1819+
# 2) the cluster is a spot cluster.
1820+
node_statuses = _get_cluster_status_via_cloud_cli(handle)
1821+
1822+
all_nodes_up = (all(status == global_user_state.ClusterStatus.UP
1823+
for status in node_statuses) and
1824+
len(node_statuses) == handle.launched_nodes)
1825+
if ray_cluster_up and all_nodes_up:
17871826
record['status'] = global_user_state.ClusterStatus.UP
17881827
global_user_state.add_or_update_cluster(cluster_name,
17891828
handle,
17901829
requested_resources=None,
17911830
ready=True,
17921831
is_launch=False)
17931832
return record
1794-
except exceptions.FetchIPError:
1795-
logger.debug('Refreshing status: Failed to get IPs from cluster '
1796-
f'{cluster_name!r}, trying to fetch from provider.')
1797-
# For all code below, ray fails to get IPs for the cluster.
1798-
node_statuses = _get_cluster_status_via_cloud_cli(handle)
17991833

18001834
if len(node_statuses) > handle.launched_nodes:
18011835
# Unexpected: in the queried region more than 1 cluster with the same

sky/spot/recovery_strategy.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""The strategy to handle launching/recovery/termination of spot clusters."""
22
import time
3+
import traceback
34
import typing
45
from typing import Optional, Tuple
56

@@ -117,7 +118,9 @@ def terminate_cluster(self, max_retry: int = 3) -> None:
117118
raise RuntimeError('Failed to terminate the spot cluster '
118119
f'{self.cluster_name}.') from e
119120
logger.error('Failed to terminate the spot cluster '
120-
f'{self.cluster_name}. Retrying.')
121+
f'{self.cluster_name}. Retrying.'
122+
f'Details: {common_utils.format_exception(e)}')
123+
logger.error(f' Traceback: {traceback.format_exc()}')
121124

122125
def _try_cancel_all_jobs(self):
123126
handle = global_user_state.get_handle_from_cluster_name(
@@ -288,7 +291,6 @@ def _launch(self, max_retry=3, raise_on_failure=True) -> Optional[float]:
288291
# code.
289292
logger.info('Failed to launch the spot cluster with error: '
290293
f'{common_utils.format_exception(e)})')
291-
import traceback # pylint: disable=import-outside-toplevel
292294
logger.info(f' Traceback: {traceback.format_exc()}')
293295
else: # No exception, the launch succeeds.
294296
# At this point, a sky.launch() has succeeded. Cluster may be

sky/utils/tpu_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ def is_tpu_vm_pod(resources: Optional[resources_lib.Resources]) -> bool:
2727
if resources is None or not is_tpu_vm(resources):
2828
return False
2929
acc, _ = list(resources.accelerators.items())[0]
30-
return acc not in ['tpu-v2-8', 'tpu-v3-8']
30+
return acc not in ['tpu-v2-8', 'tpu-v3-8', 'tpu-v4-8']
3131

3232

3333
def get_num_tpu_devices(

0 commit comments

Comments (0)