@@ -139,8 +139,10 @@ def argument_parsing():
     parser.add_argument(
         '-m', '--model-name',
         type=str,
-        help='name of model config (instead of checkpoint) and optionally'
-             'seed separated by : e.g. amr2.0-structured-bart-large:42'
+        help='pretrained checkpoint name; the local cache is checked first '
+             'and, if the model is missing, it is downloaded and saved to '
+             'the cache, e.g. AMR2-structbart-L; see the README for the '
+             'complete list of available checkpoint names'
     )
     parser.add_argument(
         '-o', '--out-amr',
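The new help text above describes a check-the-cache-then-download flow for named checkpoints. As a rough sketch of that idea only, with a hypothetical cache directory and download URL (the project's actual locations and helpers are not shown in this diff):

    import os
    import urllib.request

    def fetch_checkpoint_sketch(model_name,
                                cache_dir=os.path.expanduser('~/.cache/amr_models'),
                                base_url='https://example.org/checkpoints'):
        # Hedged sketch of "check cache first, else download and save to cache";
        # cache_dir and base_url are placeholders, not the project's values.
        target = os.path.join(cache_dir, model_name, 'checkpoint_best.pt')
        if not os.path.isfile(target):
            os.makedirs(os.path.dirname(target), exist_ok=True)
            urllib.request.urlretrieve(f'{base_url}/{model_name}.pt', target)
        return target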
@@ -241,13 +243,23 @@ def argument_parsing():
 
     return args
 
-
+############### HELPER FUNCTIONS ##########################
 def ordered_exit(signum, frame):
     print("\nStopped by user\n")
     exit(0)
 
 
 def load_models_and_task(args, use_cuda, task=None):
+    """Load a fairseq model ensemble and its task from a checkpoint
+
+    Args:
+        args (Namespace): arguments from the argument parser
+        use_cuda (bool): whether to move the loaded models to GPU
+        task (FairseqTask, optional): if None, restored from the saved model args
+
+    Returns:
+        tuple: the loaded models and their task
+    """
     # if `task` is not provided, it will be from the saved model args
     models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
         args.path.split(':'),
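The hunk above documents load_models_and_task, but the rest of the function lies outside this diff. For orientation, the usual fairseq pattern for preparing a loaded ensemble for inference looks roughly like the hedged sketch below; the file's actual handling of use_cuda and fp16 may differ:

    # Hedged sketch: typical post-load preparation of a fairseq ensemble.
    # `models`, `args` and `use_cuda` are the names used in the function above.
    for model in models:
        model.eval()                        # disable dropout for decoding
        if getattr(args, 'fp16', False):
            model.half()                    # half precision, if requested
        if use_cuda:
            model.cuda()                    # move parameters to the GPU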
@@ -274,7 +286,14 @@ def load_models_and_task(args, use_cuda, task=None):
 
 
 def load_args_from_config(config_path):
-    """Load args from bash configuration scripts"""
+    """Load args from bash configuration scripts
+
+    Args:
+        config_path (str): the path to a training configuration file
+
+    Returns:
+        dict: dictionary mapping configuration variables to their values
+    """
     # TODO there might be better ways; e.g. source the bash script in python
     # and use $BERT_LAYERS directly
     config_dict = {}
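Since the body of load_args_from_config is mostly outside this hunk, the following is only a hypothetical sketch of the approach named in the docstring (collect shell VAR=value assignments into a dict); the real implementation may treat quoting and comments differently:

    import re

    def load_args_from_config_sketch(config_path):
        # Hedged sketch: pick up simple VAR=value assignments from a bash
        # configuration script and return them as a dict of strings.
        config_dict = {}
        assignment = re.compile(r'^\s*([A-Za-z_][A-Za-z0-9_]*)=(.*)$')
        with open(config_path) as fid:
            for line in fid:
                match = assignment.match(line)
                if match:
                    key, value = match.groups()
                    config_dict[key] = value.strip().strip('"').strip("'")
        return config_dict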
@@ -489,6 +508,7 @@ def get_sliding_output(tok_sentences, parser, gold_amrs=None,
         all_tokens, all_actions, parser.machine_config)
     return annotations, machines
 
+############### END HELPER FUNCTIONS ##########################
 
 class AMRParser:
     def __init__(
@@ -703,9 +723,26 @@ def from_checkpoint(cls, checkpoint, dict_dir=None,
                         roberta_cache_path=None, fp16=False,
                         inspector=None, beam=1, nbest=1, num_samples=None,
                         sampling_topp=-1, temperature=1.0):
-        '''
-        Initialize model from checkpoint
-        '''
+        """Load an AMRParser from a model checkpoint
+
+        Args:
+            checkpoint (str): path to the model checkpoint
+            dict_dir (str, optional): folder with the model dictionaries; None uses the checkpoint folder
+            roberta_cache_path (str, optional): path to a cached RoBERTa model
+            fp16 (bool, optional): use half precision. Defaults to False.
+            inspector (callable, optional): function to call after each parsing step
+            beam (int, optional): beam size for decoding. Defaults to 1.
+            nbest (int, optional): number of hypotheses to return. Defaults to 1.
+            num_samples (int, optional): number of samples to draw; implies sampling
+            sampling_topp (float, optional): nucleus sampling threshold; -1 disables it
+            temperature (float, optional): softmax temperature. Defaults to 1.0.
+
+        Raises:
+            ValueError: if the provided arguments are invalid or inconsistent
+
+        Returns:
+            AMRParser: parser instance initialized from the checkpoint
+        """
         # load default args: some are dummy
         parser = options.get_interactive_generation_parser()
         # model path set here
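A hedged usage sketch for the keyword arguments documented above; the import path and checkpoint location are illustrative assumptions, not taken from this diff:

    from transition_amr_parser.parse import AMRParser  # assumed module path

    parser = AMRParser.from_checkpoint(
        'path/to/checkpoint_best.pt',   # placeholder checkpoint path
        beam=10,                        # wider beam search
        nbest=1,                        # return only the best hypothesis
        fp16=True,                      # half precision decoding
    )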
@@ -1104,19 +1141,21 @@ def save_multiple_files(args, num_sentences, out_path, string_list):
 
 
 def load_parser(args, inspector):
+    """Load an AMRParser either from a pretrained model name or from a
+    checkpoint path.
+
+    Args:
+        args (Namespace): arguments from the argument parser
+        inspector (callable): function to call after each parsing step
+
+    Returns:
+        AMRParser: an AMRParser instance
+
+    """
 
     if args.model_name:
-        # load from name and optionally seed
-        items = args.model_name.split(':')
-        model_name = items[0]
-        if len(items) > 1:
-            seed = items[1]
-        else:
-            seed = None
-        # load from model/config name
-        return AMRParser.load(
-            model_name,
-            seed=seed,
+        return AMRParser.from_pretrained(
+            args.model_name,
             roberta_cache_path=args.roberta_cache_path,
             inspector=inspector,
             # selected fairseq decoder arguments
@@ -1127,6 +1166,7 @@ def load_parser(args, inspector):
             # this is not, but implies --sampling
             num_samples=args.num_samples
         )
+
     else:
         # load from checkpoint and files in its folder
         return AMRParser.from_checkpoint(
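For context, a hedged sketch of the name-based loading path this change enables (the import path is an assumption; the model name is the one cited in the new help text):

    from transition_amr_parser.parse import AMRParser  # assumed module path

    # by name: the local cache is checked first, the model is downloaded on first use
    parser = AMRParser.from_pretrained('AMR2-structbart-L', beam=10, nbest=1)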