Skip to content

Commit 33eea54

Browse files
committed
Update PyTorch driver to fix SSH error.
1 parent 921c582 commit 33eea54

File tree

1 file changed

+131
-75
lines changed

1 file changed

+131
-75
lines changed

ads/jobs/templates/driver_pytorch.py

Lines changed: 131 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@
7676
LOG_PREFIX_HOST_IP = "Distributed Training HOST IP: "
7777
LOG_PREFIX_NODE_IP = "Node IP: "
7878
LOG_PREFIX_PUBLIC_KEY = "HOST PUBLIC KEY: "
79-
LOG_PREFIX_NODE_HOST_KEY = "NODE HOST KEY: "
79+
LOG_PREFIX_HOST_KEY_RSA = "NODE HOST KEY RSA: "
80+
LOG_PREFIX_HOST_KEY_ECDSA = "NODE HOST KEY ECDSA: "
8081
# Other constants used within this script
81-
NODE_HOST_KEY_PATH = "/etc/ssh/ssh_host_rsa_key.pub"
82+
HOST_KEY_PATH_RSA = "/etc/ssh/ssh_host_rsa_key.pub"
83+
HOST_KEY_PATH_ECDSA = "/etc/ssh/ssh_host_ecdsa_key.pub"
8284
USER_HOME = os.environ.get("HOME", f"/home/{getpass.getuser()}")
8385
SSH_DIR = os.environ.get("OCI__SSH_DIR", os.path.join(USER_HOME, ".ssh"))
8486
DEFAULT_LAUNCHER = "torchrun"
@@ -155,7 +157,8 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
155157

156158
# IP address of other nodes as a list
157159
self.node_ip_list = []
158-
self.node_runs = []
160+
# For DTv2, node_runs should not be used.
161+
self.node_runs = None
159162
self.host_ocid = None
160163
self.host_job_run = None
161164

@@ -204,6 +207,9 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
204207

205208
logger.debug("Runner initialized.")
206209

210+
def is_dtv2(self):
211+
return CONST_ENV_META_FILE in os.environ
212+
207213
def launch_cmd_contains(self, arg) -> bool:
208214
"""Checks if the cmd for launching the training contains specific keyword argument."""
209215
return f"--{arg}" in self.launch_cmd
@@ -250,7 +256,7 @@ def wait_for_ip_address(self, job_run, timeout=15 * 60) -> str:
250256
logger.info("IP of %s: %s", job_run.id[-6:], ip_address)
251257
return ip_address
252258

253-
def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str:
259+
def wait_for_log(self, job_run, log_prefix, timeout=15 * 60, limit=1) -> str:
254260
"""Waits until a log message with specific prefix is found in the logs of a job run.
255261
256262
Parameters
@@ -276,10 +282,11 @@ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str:
276282
"Waiting for logs with prefix '%s' from %s.", log_prefix, job_run.id
277283
)
278284
second_started = time.time()
279-
log = None
280-
while not log:
281-
log = self.check_job_run_logs(job_run=job_run, log_prefix=log_prefix)
282-
if log:
285+
logs = None
286+
while True:
287+
logs = self.check_job_run_logs(job_run=job_run, log_prefix=log_prefix)
288+
if logs and len(logs) >= limit:
289+
logs = logs[:limit]
283290
break
284291
if time.time() - second_started > timeout:
285292
logs = job_run.logs()
@@ -289,10 +296,12 @@ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str:
289296
f"Last log obtained: {last_log}"
290297
)
291298
time.sleep(60)
292-
return log
299+
if limit == 1:
300+
return logs[0]
301+
return logs
293302

294303
@staticmethod
295-
def check_job_run_logs(job_run, log_prefix: str) -> str:
304+
def check_job_run_logs(job_run, log_prefix: str) -> list:
296305
"""Checks the logs of a specific job run and find the log message with specific prefix.
297306
298307
Parameters
@@ -309,10 +318,12 @@ def check_job_run_logs(job_run, log_prefix: str) -> str:
309318
"""
310319
logger.debug("Checking logs for job run %s", job_run.id)
311320
logs = job_run.logs()
312-
for log in logs:
313-
if log["message"].startswith(log_prefix):
314-
return log["message"][len(log_prefix) :]
315-
return None
321+
logs = [
322+
log["message"][len(log_prefix) :]
323+
for log in logs
324+
if log["message"].startswith(log_prefix)
325+
]
326+
return logs
316327

317328
def find_self_ip(self):
318329
"""
@@ -408,6 +419,8 @@ def read_metadata(self):
408419
)
409420
time.sleep(20)
410421
continue
422+
logger.debug("All nodes are found in metadata file.")
423+
logger.debug(node_list)
411424
return {int(meta[CONST_RANK]): meta[CONST_IP_ADDRESS] for meta in node_list}
412425

413426
def fetch_code(self):
@@ -513,6 +526,7 @@ class TorchRunner(Runner):
513526

514527
def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
515528
super().__init__(code_dir)
529+
logger.debug("Initializing Torch Runner...")
516530
self.build_c_library()
517531

518532
def build_c_library(self):
@@ -588,20 +602,30 @@ class DeepSpeedRunner(Runner):
588602

589603
def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
590604
super().__init__(code_dir)
591-
self.host_key = None
592-
self.deepspeed_setup()
605+
logger.debug("Initializing DeepSpeed Runner...")
606+
# Setup DeepSpeed if it used.
607+
if self.use_deepspeed():
608+
self.host_key = None
609+
self.deepspeed_setup()
610+
611+
def use_deepspeed(self):
612+
"""Indicate if DeepSpeed is used."""
613+
# Accelerate support using DeepSpeed by adding the "--use_deepspeed" argument.
614+
return bool(
615+
os.environ.get(CONST_ENV_DEEPSPEED)
616+
or self.launch_cmd_contains("use_deepspeed")
617+
or self.launch_cmd_contains("deepspeed")
618+
)
593619

594620
def deepspeed_setup(self):
595621
"""Setup for DeepSpeed."""
596-
self.host_key = (
597-
NODE_HOST_KEY_PATH if os.path.exists(NODE_HOST_KEY_PATH) else None
598-
)
622+
self.host_key = HOST_KEY_PATH_RSA if os.path.exists(HOST_KEY_PATH_RSA) else None
599623
# Create the temp dir if one does not exist.
600624
# This is needed for JIT
601625
if self.TMPDIR and not os.path.isdir(self.TMPDIR):
602626
logger.info("Creating temp directory: %s", self.TMPDIR)
603627
os.makedirs(self.TMPDIR, exist_ok=True)
604-
self.install_dependencies()
628+
self.install_deepspeed_dependencies()
605629
# host_job_run is needed for DeepSpeed to fetch the public SSH key from the logs.
606630
if self.host_ocid and self.node_count > 1:
607631
self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid)
@@ -621,7 +645,35 @@ def install_epel(self):
621645
)
622646
break
623647

624-
def install_dependencies(self):
648+
def _print_host_key(self, host_key_path, prefix):
649+
with open(host_key_path, encoding=CONST_ENCODING) as f:
650+
public_key = f.read()
651+
print(f"{prefix}{self.ip}-{public_key}")
652+
653+
def _add_known_hosts_from_file(self, ip_addr, key_file):
654+
if not os.path.exists(key_file):
655+
logger.warning(
656+
"Unable to add host key %s to known_hosts: key file not found.",
657+
key_file,
658+
)
659+
return
660+
self.run_command(
661+
f"echo -n '{ip_addr} ' | " f"cat - {key_file} >> {SSH_DIR}/known_hosts",
662+
level=logging.DEBUG,
663+
check=True,
664+
)
665+
666+
def _add_known_hosts_from_log(self, job_run, prefix, ip_address=None):
667+
ip_key = self.wait_for_log(job_run, f"{prefix}")
668+
ip_addr, public_key = ip_key.split("-", 1)
669+
if ip_address:
670+
ip_addr = ip_address
671+
with open(f"{SSH_DIR}/known_hosts", "a+", encoding=CONST_ENCODING) as f:
672+
line = f"{ip_addr} {public_key}"
673+
f.write(f"{line}\n")
674+
logger.debug("Added host key: %s", line)
675+
676+
def install_deepspeed_dependencies(self):
625677
"""Installs extra dependencies and start SSH service."""
626678
if self.node_count == 1:
627679
logger.debug(
@@ -637,9 +689,8 @@ def install_dependencies(self):
637689
else:
638690
# Generate SSH host keys for SSH server
639691
self.run_command("sudo ssh-keygen -A", level=logging.DEBUG, check=True)
640-
with open(NODE_HOST_KEY_PATH, encoding=CONST_ENCODING) as f:
641-
public_key = f.read()
642-
print(f"{LOG_PREFIX_NODE_HOST_KEY}{self.ip}-{public_key}")
692+
self._print_host_key(HOST_KEY_PATH_RSA, LOG_PREFIX_HOST_KEY_RSA)
693+
self._print_host_key(HOST_KEY_PATH_ECDSA, LOG_PREFIX_HOST_KEY_ECDSA)
643694

644695
if self.run_command("which pdsh", level=logging.DEBUG) != 0:
645696
# Install "openssh-server" to accept SSH connections
@@ -670,35 +721,29 @@ def generate_key_pair(self):
670721
with open(os.path.join(SSH_DIR, "id_rsa.pub"), encoding=CONST_ENCODING) as f:
671722
public_key = f.read()
672723
print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}", flush=True)
673-
self.add_authoried_key(public_key)
674-
# Public
675-
self.run_command(
676-
f"echo -n '{self.host_ip} ' | "
677-
f"cat - {NODE_HOST_KEY_PATH} >> {SSH_DIR}/known_hosts",
678-
level=logging.DEBUG,
679-
check=True,
680-
)
724+
self._add_authoried_key(public_key)
725+
# Add host key to known hosts
726+
self._add_known_hosts_from_file(self.host_ip, HOST_KEY_PATH_RSA)
727+
self._add_known_hosts_from_file(self.host_ip, HOST_KEY_PATH_ECDSA)
681728
self.test_ssh_connection(self.host_ip)
682729
# Check DeepSpeed compatibility
683730
self.run_command(
684731
"ds_report", conda_prefix=self.conda_prefix, level=logging.DEBUG
685732
)
686733
return self
687734

688-
@staticmethod
689-
def add_authoried_key(public_key):
735+
def _add_authoried_key(self, public_key):
690736
auth_keys_file = os.path.join(SSH_DIR, "authorized_keys")
691737
os.makedirs(SSH_DIR, exist_ok=True)
692738
with open(auth_keys_file, "a+", encoding=CONST_ENCODING) as f:
693739
f.write(public_key)
694740
f.write("\n")
695-
logger.debug("Public key saved to %s", auth_keys_file)
741+
logger.debug("Public key saved to %s:%s", self.ip, auth_keys_file)
696742

697743
def fetch_host_public_key(self):
698744
public_key = self.wait_for_log(self.host_job_run, LOG_PREFIX_PUBLIC_KEY)
699745
print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}", flush=True)
700-
# logger.debug("%s", LOG_PREFIX_PUBLIC_KEY + public_key)
701-
self.add_authoried_key(public_key)
746+
self._add_authoried_key(public_key)
702747

703748
def generate_hostfile(self):
704749
if not self.node_ip_list:
@@ -716,9 +761,12 @@ def generate_hostfile(self):
716761
# Hostfile
717762
logger.debug("Writing hostfile to %s", self.HOST_FILE)
718763
os.makedirs(os.path.dirname(self.HOST_FILE), exist_ok=True)
719-
host_file_content = [f"{ip} slots={self.gpu_count}" for ip in self.node_ip_list]
764+
host_file_content = [
765+
f"{ip} slots={self.gpu_count}\n" for ip in self.node_ip_list
766+
]
720767
with open(self.HOST_FILE, "w", encoding=CONST_ENCODING) as f:
721-
f.write(f"{self.host_ip} slots={self.gpu_count}\n")
768+
if self.host_ip not in self.node_ip_list:
769+
f.write(f"{self.host_ip} slots={self.gpu_count}\n")
722770
f.writelines(host_file_content)
723771
self.run_command(f"cat {self.HOST_FILE}", level=logging.DEBUG)
724772
# SSH config
@@ -727,16 +775,18 @@ def generate_hostfile(self):
727775
with open(ssh_config_path, "w", encoding=CONST_ENCODING) as f:
728776
f.writelines(
729777
[
730-
"",
731-
f"Host {self.host_ip}",
778+
"\n",
779+
f"Host {self.host_ip}\n",
732780
"KexAlgorithms diffie-hellman-group-exchange-sha256\n",
733781
]
734782
)
735783
for node_ip in self.node_ip_list:
784+
if node_ip == self.host_ip:
785+
continue
736786
f.writelines(
737787
[
738-
"",
739-
f"Host {node_ip}",
788+
"\n",
789+
f"Host {node_ip}\n",
740790
"KexAlgorithms diffie-hellman-group-exchange-sha256\n",
741791
]
742792
)
@@ -758,9 +808,8 @@ def touch_file(self, filename):
758808
for node_ip in self.node_ip_list:
759809
logger.debug("Sending stop file to %s", node_ip)
760810
self.run_command(
761-
f"ssh -v {node_ip} 'touch {filename}'",
811+
f"ssh -v -o PasswordAuthentication=no {node_ip} 'touch {filename}'",
762812
level=logging.DEBUG,
763-
check=True,
764813
)
765814

766815
def save_deepspeed_env(self):
@@ -815,6 +864,9 @@ def save_deepspeed_env(self):
815864
logger.debug("Environment variables saved to %s", self.ENV_FILE)
816865
self.run_command(f"cat {self.ENV_FILE}")
817866

867+
def wait_for_nodes(self):
868+
pass
869+
818870
def run_deepspeed_host(self, launch_args=None):
819871
"""Prepares the host and launch the deepspeed training.
820872
@@ -830,30 +882,41 @@ def run_deepspeed_host(self, launch_args=None):
830882
self.generate_key_pair().generate_hostfile()
831883
self.save_deepspeed_env()
832884
# Wait for nodes to be ready
833-
for run in self.node_runs:
834-
self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
885+
# For DTv2, self.node_runs will be None
886+
if self.is_dtv2():
887+
self.wait_for_log(
888+
self.host_job_run, LOG_PREFIX_PUBLIC_KEY, limit=self.node_count
889+
)
890+
else:
891+
for run in self.node_runs:
892+
self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
835893

836894
if self.host_key:
837895
# If host key exists, it should be the same for all nodes.
838896
for node_ip in self.node_ip_list:
839-
self.run_command(
840-
f"echo -n '{node_ip} ' | "
841-
f"cat - /etc/ssh/ssh_host_rsa_key.pub >> {SSH_DIR}/known_hosts",
842-
level=logging.DEBUG,
843-
check=True,
844-
)
845-
else:
897+
self._add_known_hosts_from_file(node_ip, HOST_KEY_PATH_RSA)
898+
self._add_known_hosts_from_file(node_ip, HOST_KEY_PATH_ECDSA)
899+
elif self.is_dtv2():
846900
# If host key did not exist, it it generated on the fly,
847901
# Each node will have a different key.
848902
# We will need to check the logs for the public key.
903+
logger.debug("Adding node host keys to known_hosts...")
904+
for node_ip in self.node_ip_list:
905+
self._add_known_hosts_from_log(
906+
self.host_job_run,
907+
LOG_PREFIX_HOST_KEY_RSA + node_ip,
908+
ip_address=node_ip,
909+
)
910+
self._add_known_hosts_from_log(
911+
self.host_job_run,
912+
LOG_PREFIX_HOST_KEY_ECDSA + node_ip,
913+
ip_address=node_ip,
914+
)
915+
else:
916+
logger.debug("Adding job run host keys to known_hosts...")
849917
for run in self.node_runs:
850-
ip_key = self.wait_for_log(run, f"{LOG_PREFIX_NODE_HOST_KEY}")
851-
ip_addr, public_key = ip_key.split("-", 1)
852-
with open(
853-
f"{SSH_DIR}/known_hosts", "a+", encoding=CONST_ENCODING
854-
) as f:
855-
f.write(f"{ip_addr} {public_key}\n")
856-
logger.debug("Added host key for %s", ip_addr)
918+
self._add_known_hosts_from_log(run, LOG_PREFIX_HOST_KEY_RSA)
919+
self._add_known_hosts_from_log(run, LOG_PREFIX_HOST_KEY_ECDSA)
857920

858921
cmd = self.prepare_cmd(launch_args)
859922
# For DeepSpeed, we only need to run the cmd on the host
@@ -913,18 +976,8 @@ class GenericRunner(TorchRunner, DeepSpeedRunner):
913976
LAUNCHER = ""
914977

915978
def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
916-
TorchRunner.__init__(self, code_dir)
917-
if self.use_deepspeed():
918-
self.deepspeed_setup()
919-
920-
def use_deepspeed(self):
921-
"""Indicate if DeepSpeed is used."""
922-
# Accelerate support using DeepSpeed by adding the "--use_deepspeed" argument.
923-
return bool(
924-
os.environ.get(CONST_ENV_DEEPSPEED)
925-
or self.launch_cmd_contains("use_deepspeed")
926-
or self.launch_cmd_contains("deepspeed")
927-
)
979+
super().__init__(code_dir)
980+
logger.debug("Initializing Generic Runner...")
928981

929982
def set_env_var(self):
930983
"""Set default environment variables."""
@@ -975,7 +1028,10 @@ class AccelerateRunner(GenericRunner):
9751028
LAUNCHER = "accelerate launch"
9761029

9771030
def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
978-
super().__init__(self, code_dir)
1031+
# Here we need to call GenericRunner.__init__() explicitly
1032+
# to avoid calling the DeepSpeedRunner.__init__().
1033+
super().__init__(code_dir)
1034+
logger.debug("Initializing Accelerate Runner...")
9791035
# For "accelerate launch", only one of the following options can be used at one time
9801036
# `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`.
9811037
# When a config file is not provided,
@@ -1069,7 +1125,7 @@ def main():
10691125
runner_class = AccelerateRunner
10701126
else:
10711127
runner_class = GenericRunner
1072-
1128+
logger.debug("Using %s", str(runner_class))
10731129
runner = runner_class()
10741130

10751131
runner: Runner

0 commit comments

Comments
 (0)