Skip to content

Commit bff7d44

Browse files
committed
🐛 ⬆️ -model-name will use from_pretrained; add docstring
1 parent 330577c commit bff7d44

File tree

1 file changed

+58
-18
lines changed

1 file changed

+58
-18
lines changed

src/transition_amr_parser/parse.py

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,10 @@ def argument_parsing():
139139
parser.add_argument(
140140
'-m', '--model-name',
141141
type=str,
142-
help='name of model config (instead of checkpoint) and optionally'
143-
'seed separated by : e.g. amr2.0-structured-bart-large:42'
142+
help="pretrained checkpoint name; will first check if it is already in \
143+
cache, if not will automatically download and save to cache; \
144+
eg: AMR2-structbart-L \
145+
for complete list of available checkpoint names, see README"
144146
)
145147
parser.add_argument(
146148
'-o', '--out-amr',
@@ -241,13 +243,23 @@ def argument_parsing():
241243

242244
return args
243245

244-
246+
############### HELPER FUNCTIONS ##########################
245247
def ordered_exit(signum, frame):
246248
print("\nStopped by user\n")
247249
exit(0)
248250

249251

250252
def load_models_and_task(args, use_cuda, task=None):
253+
"""Fairseq load from task method
254+
255+
Args:
256+
args (args): args from argparser
257+
use_cuda (bool): _description_
258+
task (FairseqTask, optional): _description_. Defaults to None.
259+
260+
Returns:
261+
_type_: _description_
262+
"""
251263
# if `task` is not provided, it will be from the saved model args
252264
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
253265
args.path.split(':'),
@@ -274,7 +286,14 @@ def load_models_and_task(args, use_cuda, task=None):
274286

275287

276288
def load_args_from_config(config_path):
277-
"""Load args from bash configuration scripts"""
289+
"""Load args from bash configuration scripts
290+
291+
Args:
292+
config_path (str): the path to a training configuration file
293+
294+
Returns:
295+
dict: dictionary containing config info
296+
"""
278297
# TODO there might be better ways; e.g. source the bash script in python
279298
# and use $BERT_LAYERS directly
280299
config_dict = {}
@@ -489,6 +508,7 @@ def get_sliding_output(tok_sentences, parser, gold_amrs=None,
489508
all_tokens, all_actions, parser.machine_config)
490509
return annotations, machines
491510

511+
############### END HELPER FUNCTIONS ##########################
492512

493513
class AMRParser:
494514
def __init__(
@@ -703,9 +723,26 @@ def from_checkpoint(cls, checkpoint, dict_dir=None,
703723
roberta_cache_path=None, fp16=False,
704724
inspector=None, beam=1, nbest=1, num_samples=None,
705725
sampling_topp=-1, temperature=1.0):
706-
'''
707-
Initialize model from checkpoint
708-
'''
726+
""" Load a checkpoint from model path
727+
728+
Args:
729+
checkpoint (str): path to the model checkpoint
730+
dict_dir (_type_, optional): _description_. Defaults to None.
731+
roberta_cache_path (_type_, optional): _description_. Defaults to None.
732+
fp16 (bool, optional): _description_. Defaults to False.
733+
inspector (_type_, optional): _description_. Defaults to None.
734+
beam (int, optional): _description_. Defaults to 1.
735+
nbest (int, optional): _description_. Defaults to 1.
736+
num_samples (_type_, optional): _description_. Defaults to None.
737+
sampling_topp (int, optional): _description_. Defaults to -1.
738+
temperature (float, optional): _description_. Defaults to 1.0.
739+
740+
Raises:
741+
ValueError: _description_
742+
743+
Returns:
744+
_type_: _description_
745+
"""
709746
# load default args: some are dummy
710747
parser = options.get_interactive_generation_parser()
711748
# model path set here
@@ -1104,19 +1141,21 @@ def save_multiple_files(args, num_sentences, out_path, string_list):
11041141

11051142

11061143
def load_parser(args, inspector):
1144+
""" A meta load to check for loading from model name, or loading
1145+
from checkpoint path
1146+
1147+
Args:
1148+
args (_type_): arguments from arg parser
1149+
inspector (_type_): function to call after each step
1150+
1151+
Returns:
1152+
AMRParser: AMRParser class object
1153+
1154+
"""
11071155

11081156
if args.model_name:
1109-
# load from name and optionally seed
1110-
items = args.model_name.split(':')
1111-
model_name = items[0]
1112-
if len(items) > 1:
1113-
seed = items[1]
1114-
else:
1115-
seed = None
1116-
# load from model/config name
1117-
return AMRParser.load(
1118-
model_name,
1119-
seed=seed,
1157+
return AMRParser.from_pretrained(
1158+
args.model_name,
11201159
roberta_cache_path=args.roberta_cache_path,
11211160
inspector=inspector,
11221161
# selected fairseq decoder arguments
@@ -1127,6 +1166,7 @@ def load_parser(args, inspector):
11271166
# this is not, but implies --sampling
11281167
num_samples=args.num_samples
11291168
)
1169+
11301170
else:
11311171
# load from checkpoint and files in its folder
11321172
return AMRParser.from_checkpoint(

0 commit comments

Comments
 (0)