Skip to content

Commit b7b94e0

Browse files
Merge branch 'master' into change-dash-m-method
2 parents bff7d44 + 797f37b commit b7b94e0

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

src/transition_amr_parser/parse.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def argument_parsing():
6161
parser.add_argument(
6262
'--in-doc',
6363
type=str,
64-
help='File with one __tokenized__ sentence per line and one newline separating each doc (and the end)'
64+
help='File with one __tokenized__ sentence per line and one newline '
65+
'separating each doc (and the end)'
6566
)
6667
parser.add_argument(
6768
'--in-amr',
@@ -219,7 +220,8 @@ def argument_parsing():
219220
assert (
220221
bool(args.in_tokenized_sentences) or bool(args.in_amr)
221222
) or bool(args.service) or bool(args.in_doc), \
222-
"Must either specify --in-tokenized-sentences or --in-doc or set --service"
223+
"Must either specify --in-tokenized-sentences or --in-doc or set " \
224+
"--service"
223225

224226
if not (bool(args.model_name) ^ bool(args.in_checkpoint)):
225227
raise Exception("Use either --model-name or --in-checkpoint")
@@ -231,7 +233,10 @@ def argument_parsing():
231233
)
232234
if bool(args.in_actions) and bool(args.in_doc):
233235
print(yellow_font(
234-
'WARNING: Given force actions will be used superseding the force actions generated by the inhouse force action generator. Make sure it is in the correct format '))
236+
'WARNING: Given force actions will be used superseding the force '
237+
'actions generated by the inhouse force action generator. Make '
238+
'sure it is in the correct format '
239+
))
235240

236241
# num samples replaces beam search
237242
if args.num_samples:
@@ -605,14 +610,14 @@ def from_pretrained(cls, model_name, dict_dir=None,
605610
roberta_cache_path=None, fp16=False,
606611
inspector=None, beam=1, nbest=1, num_samples=None,
607612
sampling_topp=-1, temperature=1.0):
608-
""" Load model checkpoints from available model names;
609-
Will check if the model is downloaded to cache, if not, download from cloud storage;
613+
""" Load model checkpoints from available model names;
614+
Will check if the model is downloaded to cache, if not, download from cloud storage;
610615
Below is a list of available modelnames to trun
611616
{
612617
'AMR3.0':'https://s3.us-east.cloud-object-storage.appdomain.cloud/cloud-object-storage-xc-cos-standard-htg/amr3.0-structured-bart-large-neur-al-sampling5.zip'
613618
}
614619
Args:
615-
modelname (str): a model name within our pretrained model library.
620+
modelname (str): a model name within our pretrained model library.
616621
dict_dir (_type_, optional): _description_. Defaults to None.
617622
roberta_cache_path (_type_, optional): _description_. Defaults to None.
618623
fp16 (bool, optional): _description_. Defaults to False.
@@ -1263,7 +1268,6 @@ def run_service(args, parser):
12631268

12641269
def prepare_data(args, parser):
12651270

1266-
force_actions = None
12671271
if args.in_amr:
12681272

12691273
# align mode: read input AMR to be aligned
@@ -1398,8 +1402,8 @@ def main():
13981402
if args.in_actions:
13991403
with open(args.in_actions) as fact:
14001404
force_actions = [eval(line.strip()) + [[]] for line in fact]
1401-
assert len(tok_sentences) == len(
1402-
force_actions), "Number of force actions doesn't match the number of sentences"
1405+
assert len(tok_sentences) == len(force_actions), \
1406+
"Number of force actions doesn't match the number of sentences"
14031407

14041408
# sampling needs copy of force actions N times
14051409
if args.num_samples is not None:
@@ -1416,15 +1420,17 @@ def main():
14161420
start = time.time()
14171421

14181422
if args.in_doc:
1419-
annotations, machines = parser.parse_docs(tok_sentences,
1420-
gold_amrs=gold_amrs,
1421-
window_size=args.window_size,
1422-
window_overlap=args.window_overlap,
1423-
batch_size=args.batch_size,
1424-
roberta_batch_size=args.roberta_batch_size,
1425-
beam=args.beam,
1426-
jamr=args.jamr,
1427-
no_isi=args.no_isi)
1423+
annotations, machines = parser.parse_docs(
1424+
tok_sentences,
1425+
gold_amrs=gold_amrs,
1426+
window_size=args.window_size,
1427+
window_overlap=args.window_overlap,
1428+
batch_size=args.batch_size,
1429+
roberta_batch_size=args.roberta_batch_size,
1430+
beam=args.beam,
1431+
jamr=args.jamr,
1432+
no_isi=args.no_isi
1433+
)
14281434

14291435
else:
14301436

0 commit comments

Comments
 (0)