 LOG_PREFIX_HOST_IP = "Distributed Training HOST IP: "
 LOG_PREFIX_NODE_IP = "Node IP: "
 LOG_PREFIX_PUBLIC_KEY = "HOST PUBLIC KEY: "
-LOG_PREFIX_NODE_HOST_KEY = "NODE HOST KEY: "
+LOG_PREFIX_HOST_KEY_RSA = "NODE HOST KEY RSA: "
+LOG_PREFIX_HOST_KEY_ECDSA = "NODE HOST KEY ECDSA: "
 # Other constants used within this script
-NODE_HOST_KEY_PATH = "/etc/ssh/ssh_host_rsa_key.pub"
+HOST_KEY_PATH_RSA = "/etc/ssh/ssh_host_rsa_key.pub"
+HOST_KEY_PATH_ECDSA = "/etc/ssh/ssh_host_ecdsa_key.pub"
 USER_HOME = os.environ.get("HOME", f"/home/{getpass.getuser()}")
 SSH_DIR = os.environ.get("OCI__SSH_DIR", os.path.join(USER_HOME, ".ssh"))
 DEFAULT_LAUNCHER = "torchrun"
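A quick illustration of the convention these new constants support: each node writes its SSH host keys to the job run logs as `<prefix><node-ip>-<public-key>`, one line per key type, so the host can recover them later. This is a minimal sketch only; the IP below is a placeholder, and the real driver does this inside the runner class (see `_print_host_key` further down).

```python
# Sketch only: mirrors the prefix convention added above, with placeholder values.
LOG_PREFIX_HOST_KEY_RSA = "NODE HOST KEY RSA: "
LOG_PREFIX_HOST_KEY_ECDSA = "NODE HOST KEY ECDSA: "
HOST_KEY_PATH_RSA = "/etc/ssh/ssh_host_rsa_key.pub"
HOST_KEY_PATH_ECDSA = "/etc/ssh/ssh_host_ecdsa_key.pub"


def print_host_keys(node_ip: str) -> None:
    """Emit each host key under its prefix so it can be parsed back out of the logs."""
    for prefix, path in [
        (LOG_PREFIX_HOST_KEY_RSA, HOST_KEY_PATH_RSA),
        (LOG_PREFIX_HOST_KEY_ECDSA, HOST_KEY_PATH_ECDSA),
    ]:
        try:
            with open(path, encoding="utf-8") as f:
                public_key = f.read().strip()
        except FileNotFoundError:
            continue  # this key type was not generated on the node
        # Log line format: "<prefix><ip>-<key>"; the consumer splits on the first "-".
        print(f"{prefix}{node_ip}-{public_key}", flush=True)


if __name__ == "__main__":
    print_host_keys("203.0.113.10")  # placeholder IP
```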
@@ -155,7 +157,8 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
 
         # IP address of other nodes as a list
         self.node_ip_list = []
-        self.node_runs = []
+        # For DTv2, node_runs should not be used.
+        self.node_runs = None
         self.host_ocid = None
         self.host_job_run = None
 
@@ -204,6 +207,9 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
 
         logger.debug("Runner initialized.")
 
+    def is_dtv2(self):
+        return CONST_ENV_META_FILE in os.environ
+
     def launch_cmd_contains(self, arg) -> bool:
         """Checks if the command for launching the training contains a specific keyword argument."""
         return f"--{arg}" in self.launch_cmd
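For reference, a minimal sketch of the two checks introduced or used here: DTv2 detection is simply the presence of the metadata-file environment variable, and launch-command flags are detected by a plain `--<arg>` substring test. The environment variable name below is a placeholder; the real value of `CONST_ENV_META_FILE` is defined elsewhere in the driver.

```python
import os

# Placeholder for CONST_ENV_META_FILE, which is defined elsewhere in the driver;
# the actual environment variable name may differ.
CONST_ENV_META_FILE = "OCI__METADATA_FILE"


def is_dtv2() -> bool:
    # DTv2 runs are identified by the presence of the metadata-file env var.
    return CONST_ENV_META_FILE in os.environ


def launch_cmd_contains(launch_cmd: str, arg: str) -> bool:
    # Mirrors Runner.launch_cmd_contains(): a plain substring check for "--<arg>".
    return f"--{arg}" in launch_cmd


if __name__ == "__main__":
    print(is_dtv2())  # False unless the env var is set
    print(launch_cmd_contains("torchrun --nnodes=2 train.py", "nnodes"))  # True
```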
@@ -250,7 +256,7 @@ def wait_for_ip_address(self, job_run, timeout=15 * 60) -> str:
         logger.info("IP of %s: %s", job_run.id[-6:], ip_address)
         return ip_address
 
-    def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str:
+    def wait_for_log(self, job_run, log_prefix, timeout=15 * 60, limit=1) -> str:
         """Waits until a log message with a specific prefix is found in the logs of a job run.
 
         Parameters
@@ -276,10 +282,11 @@ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str:
             "Waiting for logs with prefix '%s' from %s.", log_prefix, job_run.id
         )
         second_started = time.time()
-        log = None
-        while not log:
-            log = self.check_job_run_logs(job_run=job_run, log_prefix=log_prefix)
-            if log:
+        logs = None
+        while True:
+            logs = self.check_job_run_logs(job_run=job_run, log_prefix=log_prefix)
+            if logs and len(logs) >= limit:
+                logs = logs[:limit]
                 break
             if time.time() - second_started > timeout:
                 logs = job_run.logs()
@@ -289,10 +296,12 @@ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str:
                     f"Last log obtained: {last_log}"
                 )
             time.sleep(60)
-        return log
+        if limit == 1:
+            return logs[0]
+        return logs
 
     @staticmethod
-    def check_job_run_logs(job_run, log_prefix: str) -> str:
+    def check_job_run_logs(job_run, log_prefix: str) -> list:
        """Checks the logs of a specific job run and finds the log messages with a specific prefix.
 
         Parameters
@@ -309,10 +318,12 @@ def check_job_run_logs(job_run, log_prefix: str) -> str:
         """
         logger.debug("Checking logs for job run %s", job_run.id)
         logs = job_run.logs()
-        for log in logs:
-            if log["message"].startswith(log_prefix):
-                return log["message"][len(log_prefix) :]
-        return None
+        logs = [
+            log["message"][len(log_prefix) :]
+            for log in logs
+            if log["message"].startswith(log_prefix)
+        ]
+        return logs
 
 
     def find_self_ip(self):
@@ -408,6 +419,8 @@ def read_metadata(self):
                 )
                 time.sleep(20)
                 continue
+            logger.debug("All nodes are found in metadata file.")
+            logger.debug(node_list)
             return {int(meta[CONST_RANK]): meta[CONST_IP_ADDRESS] for meta in node_list}
 
     def fetch_code(self):
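For illustration, the return shape of `read_metadata()` is a rank-to-IP mapping. The sketch below assumes only the two fields the comprehension reads; the actual metadata file may use different key names (via `CONST_RANK` and `CONST_IP_ADDRESS`) and carry more fields.

```python
# Hypothetical node_list; key names are placeholders standing in for whatever
# CONST_RANK and CONST_IP_ADDRESS resolve to in the driver.
CONST_RANK = "rank"
CONST_IP_ADDRESS = "ipAddress"

node_list = [
    {"rank": "0", "ipAddress": "10.0.0.10"},
    {"rank": "1", "ipAddress": "10.0.0.11"},
]

# Same shape as the return value of read_metadata(): {rank: ip}
rank_to_ip = {int(meta[CONST_RANK]): meta[CONST_IP_ADDRESS] for meta in node_list}
print(rank_to_ip)  # {0: '10.0.0.10', 1: '10.0.0.11'}
```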
@@ -513,6 +526,7 @@ class TorchRunner(Runner):
 
     def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
         super().__init__(code_dir)
+        logger.debug("Initializing Torch Runner...")
         self.build_c_library()
 
     def build_c_library(self):
@@ -588,20 +602,30 @@ class DeepSpeedRunner(Runner):
 
     def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
         super().__init__(code_dir)
-        self.host_key = None
-        self.deepspeed_setup()
+        logger.debug("Initializing DeepSpeed Runner...")
+        # Set up DeepSpeed only if it is used.
+        if self.use_deepspeed():
+            self.host_key = None
+            self.deepspeed_setup()
+
+    def use_deepspeed(self):
+        """Indicates if DeepSpeed is used."""
+        # Accelerate supports using DeepSpeed by adding the "--use_deepspeed" argument.
+        return bool(
+            os.environ.get(CONST_ENV_DEEPSPEED)
+            or self.launch_cmd_contains("use_deepspeed")
+            or self.launch_cmd_contains("deepspeed")
+        )
 
     def deepspeed_setup(self):
         """Setup for DeepSpeed."""
-        self.host_key = (
-            NODE_HOST_KEY_PATH if os.path.exists(NODE_HOST_KEY_PATH) else None
-        )
+        self.host_key = HOST_KEY_PATH_RSA if os.path.exists(HOST_KEY_PATH_RSA) else None
         # Create the temp dir if one does not exist.
         # This is needed for JIT.
         if self.TMPDIR and not os.path.isdir(self.TMPDIR):
             logger.info("Creating temp directory: %s", self.TMPDIR)
             os.makedirs(self.TMPDIR, exist_ok=True)
-        self.install_dependencies()
+        self.install_deepspeed_dependencies()
         # host_job_run is needed for DeepSpeed to fetch the public SSH key from the logs.
         if self.host_ocid and self.node_count > 1:
             self.host_job_run = DataScienceJobRun.from_ocid(self.host_ocid)
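A condensed sketch of the detection logic moved into `use_deepspeed()`: DeepSpeed is treated as enabled when the DeepSpeed environment variable is set or when the launch command carries `--use_deepspeed` or `--deepspeed`. The environment variable name below is a placeholder for `CONST_ENV_DEEPSPEED`.

```python
import os

# Placeholder name; the real CONST_ENV_DEEPSPEED is defined elsewhere in the driver.
CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED"


def use_deepspeed(launch_cmd: str) -> bool:
    # DeepSpeed is considered "in use" when the env var is set, or when the launch
    # command carries either --use_deepspeed (accelerate) or --deepspeed.
    return bool(
        os.environ.get(CONST_ENV_DEEPSPEED)
        or "--use_deepspeed" in launch_cmd
        or "--deepspeed" in launch_cmd
    )


if __name__ == "__main__":
    print(use_deepspeed("torchrun train.py"))                           # False
    print(use_deepspeed("accelerate launch --use_deepspeed train.py"))  # True
```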
@@ -621,7 +645,35 @@ def install_epel(self):
                 )
                 break
 
-    def install_dependencies(self):
+    def _print_host_key(self, host_key_path, prefix):
+        with open(host_key_path, encoding=CONST_ENCODING) as f:
+            public_key = f.read()
+        print(f"{prefix}{self.ip}-{public_key}")
+
+    def _add_known_hosts_from_file(self, ip_addr, key_file):
+        if not os.path.exists(key_file):
+            logger.warning(
+                "Unable to add host key %s to known_hosts: key file not found.",
+                key_file,
+            )
+            return
+        self.run_command(
+            f"echo -n '{ip_addr} ' | " f"cat - {key_file} >> {SSH_DIR}/known_hosts",
+            level=logging.DEBUG,
+            check=True,
+        )
+
+    def _add_known_hosts_from_log(self, job_run, prefix, ip_address=None):
+        ip_key = self.wait_for_log(job_run, f"{prefix}")
+        ip_addr, public_key = ip_key.split("-", 1)
+        if ip_address:
+            ip_addr = ip_address
+        with open(f"{SSH_DIR}/known_hosts", "a+", encoding=CONST_ENCODING) as f:
+            line = f"{ip_addr} {public_key}"
+            f.write(f"{line}\n")
+        logger.debug("Added host key: %s", line)
+
+    def install_deepspeed_dependencies(self):
         """Installs extra dependencies and starts the SSH service."""
         if self.node_count == 1:
             logger.debug(
@@ -637,9 +689,8 @@ def install_dependencies(self):
         else:
             # Generate SSH host keys for SSH server
             self.run_command("sudo ssh-keygen -A", level=logging.DEBUG, check=True)
-            with open(NODE_HOST_KEY_PATH, encoding=CONST_ENCODING) as f:
-                public_key = f.read()
-            print(f"{LOG_PREFIX_NODE_HOST_KEY}{self.ip}-{public_key}")
+            self._print_host_key(HOST_KEY_PATH_RSA, LOG_PREFIX_HOST_KEY_RSA)
+            self._print_host_key(HOST_KEY_PATH_ECDSA, LOG_PREFIX_HOST_KEY_ECDSA)
 
         if self.run_command("which pdsh", level=logging.DEBUG) != 0:
             # Install "openssh-server" to accept SSH connections
@@ -670,35 +721,29 @@ def generate_key_pair(self):
         with open(os.path.join(SSH_DIR, "id_rsa.pub"), encoding=CONST_ENCODING) as f:
             public_key = f.read()
         print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}", flush=True)
-        self.add_authoried_key(public_key)
-        # Public
-        self.run_command(
-            f"echo -n '{self.host_ip} ' | "
-            f"cat - {NODE_HOST_KEY_PATH} >> {SSH_DIR}/known_hosts",
-            level=logging.DEBUG,
-            check=True,
-        )
+        self._add_authoried_key(public_key)
+        # Add host key to known hosts
+        self._add_known_hosts_from_file(self.host_ip, HOST_KEY_PATH_RSA)
+        self._add_known_hosts_from_file(self.host_ip, HOST_KEY_PATH_ECDSA)
         self.test_ssh_connection(self.host_ip)
         # Check DeepSpeed compatibility
         self.run_command(
             "ds_report", conda_prefix=self.conda_prefix, level=logging.DEBUG
         )
         return self
 
-    @staticmethod
-    def add_authoried_key(public_key):
+    def _add_authoried_key(self, public_key):
         auth_keys_file = os.path.join(SSH_DIR, "authorized_keys")
         os.makedirs(SSH_DIR, exist_ok=True)
         with open(auth_keys_file, "a+", encoding=CONST_ENCODING) as f:
             f.write(public_key)
             f.write("\n")
-        logger.debug("Public key saved to %s", auth_keys_file)
+        logger.debug("Public key saved to %s:%s", self.ip, auth_keys_file)
 
     def fetch_host_public_key(self):
         public_key = self.wait_for_log(self.host_job_run, LOG_PREFIX_PUBLIC_KEY)
         print(f"{LOG_PREFIX_PUBLIC_KEY}{public_key}", flush=True)
-        # logger.debug("%s", LOG_PREFIX_PUBLIC_KEY + public_key)
-        self.add_authoried_key(public_key)
+        self._add_authoried_key(public_key)
 
     def generate_hostfile(self):
         if not self.node_ip_list:
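The two helpers used above populate `known_hosts` in different ways: `_add_known_hosts_from_file` concatenates an IP with a local host-key file, while `_add_known_hosts_from_log` parses a `<ip>-<key>` payload recovered from the job run logs. A self-contained sketch against temporary files (paths, IPs, and keys are placeholders):

```python
# Illustrative sketch of the two known_hosts paths, written against local temp files
# rather than /etc/ssh and ~/.ssh.
import os
import tempfile

ssh_dir = tempfile.mkdtemp()
known_hosts = os.path.join(ssh_dir, "known_hosts")


def add_from_file(ip_addr: str, key_file: str) -> None:
    # Equivalent of _add_known_hosts_from_file: writes "<ip> <key-file-content>".
    if not os.path.exists(key_file):
        return
    with open(key_file, encoding="utf-8") as f:
        key = f.read().strip()
    with open(known_hosts, "a+", encoding="utf-8") as f:
        f.write(f"{ip_addr} {key}\n")


def add_from_log(log_payload: str, ip_address: str = None) -> None:
    # Equivalent of _add_known_hosts_from_log: the payload is "<ip>-<key>";
    # split on the first "-" and optionally override the IP.
    ip_addr, public_key = log_payload.split("-", 1)
    if ip_address:
        ip_addr = ip_address
    with open(known_hosts, "a+", encoding="utf-8") as f:
        f.write(f"{ip_addr} {public_key}\n")


if __name__ == "__main__":
    key_file = os.path.join(ssh_dir, "ssh_host_rsa_key.pub")
    with open(key_file, "w", encoding="utf-8") as f:
        f.write("ssh-rsa AAAA...example\n")
    add_from_file("10.0.0.11", key_file)
    add_from_log("10.0.0.12-ssh-ed25519 AAAA...example")
    with open(known_hosts, encoding="utf-8") as f:
        print(f.read())
```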
@@ -716,9 +761,12 @@ def generate_hostfile(self):
         # Hostfile
         logger.debug("Writing hostfile to %s", self.HOST_FILE)
         os.makedirs(os.path.dirname(self.HOST_FILE), exist_ok=True)
-        host_file_content = [f"{ip} slots={self.gpu_count}" for ip in self.node_ip_list]
+        host_file_content = [
+            f"{ip} slots={self.gpu_count}\n" for ip in self.node_ip_list
+        ]
         with open(self.HOST_FILE, "w", encoding=CONST_ENCODING) as f:
-            f.write(f"{self.host_ip} slots={self.gpu_count}\n")
+            if self.host_ip not in self.node_ip_list:
+                f.write(f"{self.host_ip} slots={self.gpu_count}\n")
             f.writelines(host_file_content)
         self.run_command(f"cat {self.HOST_FILE}", level=logging.DEBUG)
         # SSH config
@@ -727,16 +775,18 @@ def generate_hostfile(self):
         with open(ssh_config_path, "w", encoding=CONST_ENCODING) as f:
             f.writelines(
                 [
-                    "",
-                    f"Host {self.host_ip}",
+                    "\n",
+                    f"Host {self.host_ip}\n",
                     "KexAlgorithms diffie-hellman-group-exchange-sha256\n",
                 ]
             )
             for node_ip in self.node_ip_list:
+                if node_ip == self.host_ip:
+                    continue
                 f.writelines(
                     [
-                        "",
-                        f"Host {node_ip}",
+                        "\n",
+                        f"Host {node_ip}\n",
                         "KexAlgorithms diffie-hellman-group-exchange-sha256\n",
                     ]
                 )
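To make the new de-duplication concrete, here is what the generated hostfile and SSH config would look like for a hypothetical two-node cluster where the node IP list already contains the host IP (IPs and GPU count are made up):

```python
# Quick illustration of the generated artifacts. With the new checks, the host IP is
# written only once even when it also appears in node_ip_list.
host_ip = "10.0.0.10"
node_ip_list = ["10.0.0.10", "10.0.0.11"]  # DTv2 metadata may include the host itself
gpu_count = 8

hostfile_lines = []
if host_ip not in node_ip_list:
    hostfile_lines.append(f"{host_ip} slots={gpu_count}\n")
hostfile_lines += [f"{ip} slots={gpu_count}\n" for ip in node_ip_list]
print("".join(hostfile_lines))
# 10.0.0.10 slots=8
# 10.0.0.11 slots=8

ssh_config = [f"\nHost {host_ip}\nKexAlgorithms diffie-hellman-group-exchange-sha256\n"]
ssh_config += [
    f"\nHost {ip}\nKexAlgorithms diffie-hellman-group-exchange-sha256\n"
    for ip in node_ip_list
    if ip != host_ip
]
print("".join(ssh_config))
```

Each host now also gets a terminating newline after its `Host` line, so the `KexAlgorithms` option lands on its own line as OpenSSH expects.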
@@ -758,9 +808,8 @@ def touch_file(self, filename):
         for node_ip in self.node_ip_list:
             logger.debug("Sending stop file to %s", node_ip)
             self.run_command(
-                f"ssh -v {node_ip} 'touch {filename}'",
+                f"ssh -v -o PasswordAuthentication=no {node_ip} 'touch {filename}'",
                 level=logging.DEBUG,
-                check=True,
             )
@@ -815,6 +864,9 @@ def save_deepspeed_env(self):
         logger.debug("Environment variables saved to %s", self.ENV_FILE)
         self.run_command(f"cat {self.ENV_FILE}")
 
+    def wait_for_nodes(self):
+        pass
+
     def run_deepspeed_host(self, launch_args=None):
         """Prepares the host and launches the DeepSpeed training.
 
@@ -830,30 +882,41 @@ def run_deepspeed_host(self, launch_args=None):
         self.generate_key_pair().generate_hostfile()
         self.save_deepspeed_env()
         # Wait for nodes to be ready
-        for run in self.node_runs:
-            self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
+        # For DTv2, self.node_runs will be None
+        if self.is_dtv2():
+            self.wait_for_log(
+                self.host_job_run, LOG_PREFIX_PUBLIC_KEY, limit=self.node_count
+            )
+        else:
+            for run in self.node_runs:
+                self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
 
         if self.host_key:
             # If host key exists, it should be the same for all nodes.
             for node_ip in self.node_ip_list:
-                self.run_command(
-                    f"echo -n '{node_ip} ' | "
-                    f"cat - /etc/ssh/ssh_host_rsa_key.pub >> {SSH_DIR}/known_hosts",
-                    level=logging.DEBUG,
-                    check=True,
-                )
-        else:
+                self._add_known_hosts_from_file(node_ip, HOST_KEY_PATH_RSA)
+                self._add_known_hosts_from_file(node_ip, HOST_KEY_PATH_ECDSA)
+        elif self.is_dtv2():
             # If the host key did not exist, it is generated on the fly,
             # so each node will have a different key.
             # We will need to check the logs for the public key.
+            logger.debug("Adding node host keys to known_hosts...")
+            for node_ip in self.node_ip_list:
+                self._add_known_hosts_from_log(
+                    self.host_job_run,
+                    LOG_PREFIX_HOST_KEY_RSA + node_ip,
+                    ip_address=node_ip,
+                )
+                self._add_known_hosts_from_log(
+                    self.host_job_run,
+                    LOG_PREFIX_HOST_KEY_ECDSA + node_ip,
+                    ip_address=node_ip,
+                )
+        else:
+            logger.debug("Adding job run host keys to known_hosts...")
             for run in self.node_runs:
-                ip_key = self.wait_for_log(run, f"{LOG_PREFIX_NODE_HOST_KEY}")
-                ip_addr, public_key = ip_key.split("-", 1)
-                with open(
-                    f"{SSH_DIR}/known_hosts", "a+", encoding=CONST_ENCODING
-                ) as f:
-                    f.write(f"{ip_addr} {public_key}\n")
-                logger.debug("Added host key for %s", ip_addr)
+                self._add_known_hosts_from_log(run, LOG_PREFIX_HOST_KEY_RSA)
+                self._add_known_hosts_from_log(run, LOG_PREFIX_HOST_KEY_ECDSA)
 
         cmd = self.prepare_cmd(launch_args)
         # For DeepSpeed, we only need to run the cmd on the host
@@ -913,18 +976,8 @@ class GenericRunner(TorchRunner, DeepSpeedRunner):
     LAUNCHER = ""
 
     def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
-        TorchRunner.__init__(self, code_dir)
-        if self.use_deepspeed():
-            self.deepspeed_setup()
-
-    def use_deepspeed(self):
-        """Indicate if DeepSpeed is used."""
-        # Accelerate support using DeepSpeed by adding the "--use_deepspeed" argument.
-        return bool(
-            os.environ.get(CONST_ENV_DEEPSPEED)
-            or self.launch_cmd_contains("use_deepspeed")
-            or self.launch_cmd_contains("deepspeed")
-        )
+        super().__init__(code_dir)
+        logger.debug("Initializing Generic Runner...")
 
     def set_env_var(self):
         """Set default environment variables."""
@@ -975,7 +1028,10 @@ class AccelerateRunner(GenericRunner):
     LAUNCHER = "accelerate launch"
 
     def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
-        super().__init__(self, code_dir)
+        # Here we need to call GenericRunner.__init__() explicitly
+        # to avoid calling the DeepSpeedRunner.__init__().
+        super().__init__(code_dir)
+        logger.debug("Initializing Accelerate Runner...")
         # For "accelerate launch", only one of the following options can be used at one time
         # `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`.
         # When a config file is not provided,
@@ -1069,7 +1125,7 @@ def main():
         runner_class = AccelerateRunner
     else:
         runner_class = GenericRunner
-
+    logger.debug("Using %s", str(runner_class))
    runner = runner_class()
 
    runner: Runner
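Since all runners now chain through `super().__init__()`, initialization order is governed by Python's MRO: `AccelerateRunner -> GenericRunner -> TorchRunner -> DeepSpeedRunner -> Runner`, and `DeepSpeedRunner.__init__()` only performs its setup when `use_deepspeed()` is true. The toy classes below demonstrate the call order; they are stand-ins, not the real runners (which also take a `code_dir` argument and do real work):

```python
# Toy classes mirroring the runner hierarchy, to show the MRO behind the
# super().__init__() calls in this change.
class Runner:
    def __init__(self):
        print("Runner.__init__")


class TorchRunner(Runner):
    def __init__(self):
        super().__init__()
        print("TorchRunner.__init__")


class DeepSpeedRunner(Runner):
    def __init__(self, use_deepspeed=False):
        super().__init__()
        print("DeepSpeedRunner.__init__")
        if use_deepspeed:
            print("  deepspeed_setup() would run here")


class GenericRunner(TorchRunner, DeepSpeedRunner):
    def __init__(self):
        super().__init__()
        print("GenericRunner.__init__")


class AccelerateRunner(GenericRunner):
    def __init__(self):
        super().__init__()
        print("AccelerateRunner.__init__")


if __name__ == "__main__":
    print([cls.__name__ for cls in AccelerateRunner.__mro__])
    # ['AccelerateRunner', 'GenericRunner', 'TorchRunner', 'DeepSpeedRunner', 'Runner', 'object']
    AccelerateRunner()  # prints the __init__ chain from Runner up to AccelerateRunner
```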