Skip to content

Commit 4f7812b

Browse files
PEP8 changes
1 parent 573d717 commit 4f7812b

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

src/transition_amr_parser/parse.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def argument_parsing():
6161
parser.add_argument(
6262
'--in-doc',
6363
type=str,
64-
help='File with one __tokenized__ sentence per line and one newline separating each doc (and the end)'
64+
help='File with one __tokenized__ sentence per line and one newline '
65+
'separating each doc (and the end)'
6566
)
6667
parser.add_argument(
6768
'--in-amr',
@@ -217,7 +218,8 @@ def argument_parsing():
217218
assert (
218219
bool(args.in_tokenized_sentences) or bool(args.in_amr)
219220
) or bool(args.service) or bool(args.in_doc), \
220-
"Must either specify --in-tokenized-sentences or --in-doc or set --service"
221+
"Must either specify --in-tokenized-sentences or --in-doc or set " \
222+
"--service"
221223

222224
if not (bool(args.model_name) ^ bool(args.in_checkpoint)):
223225
raise Exception("Use either --model-name or --in-checkpoint")
@@ -229,7 +231,10 @@ def argument_parsing():
229231
)
230232
if bool(args.in_actions) and bool(args.in_doc):
231233
print(yellow_font(
232-
'WARNING: Given force actions will be used superseding the force actions generated by the inhouse force action generator. Make sure it is in the correct format '))
234+
'WARNING: Given force actions will be used superseding the force '
235+
'actions generated by the inhouse force action generator. Make '
236+
'sure it is in the correct format '
237+
))
233238

234239
# num samples replaces beam search
235240
if args.num_samples:
@@ -585,14 +590,14 @@ def from_pretrained(cls, model_name, dict_dir=None,
585590
roberta_cache_path=None, fp16=False,
586591
inspector=None, beam=1, nbest=1, num_samples=None,
587592
sampling_topp=-1, temperature=1.0):
588-
""" Load model checkpoints from available model names;
589-
Will check if the model is downloaded to cache, if not, download from cloud storage;
593+
""" Load model checkpoints from available model names;
594+
Will check if the model is downloaded to cache, if not, download from cloud storage;
590595
Below is a list of available modelnames to trun
591596
{
592597
'AMR3.0':'https://s3.us-east.cloud-object-storage.appdomain.cloud/cloud-object-storage-xc-cos-standard-htg/amr3.0-structured-bart-large-neur-al-sampling5.zip'
593598
}
594599
Args:
595-
modelname (str): a model name within our pretrained model library.
600+
modelname (str): a model name within our pretrained model library.
596601
dict_dir (_type_, optional): _description_. Defaults to None.
597602
roberta_cache_path (_type_, optional): _description_. Defaults to None.
598603
fp16 (bool, optional): _description_. Defaults to False.
@@ -1223,7 +1228,6 @@ def run_service(args, parser):
12231228

12241229
def prepare_data(args, parser):
12251230

1226-
force_actions = None
12271231
if args.in_amr:
12281232

12291233
# align mode: read input AMR to be aligned
@@ -1358,8 +1362,8 @@ def main():
13581362
if args.in_actions:
13591363
with open(args.in_actions) as fact:
13601364
force_actions = [eval(line.strip()) + [[]] for line in fact]
1361-
assert len(tok_sentences) == len(
1362-
force_actions), "Number of force actions doesn't match the number of sentences"
1365+
assert len(tok_sentences) == len(force_actions), \
1366+
"Number of force actions doesn't match the number of sentences"
13631367

13641368
# sampling needs copy of force actions N times
13651369
if args.num_samples is not None:
@@ -1376,15 +1380,17 @@ def main():
13761380
start = time.time()
13771381

13781382
if args.in_doc:
1379-
annotations, machines = parser.parse_docs(tok_sentences,
1380-
gold_amrs=gold_amrs,
1381-
window_size=args.window_size,
1382-
window_overlap=args.window_overlap,
1383-
batch_size=args.batch_size,
1384-
roberta_batch_size=args.roberta_batch_size,
1385-
beam=args.beam,
1386-
jamr=args.jamr,
1387-
no_isi=args.no_isi)
1383+
annotations, machines = parser.parse_docs(
1384+
tok_sentences,
1385+
gold_amrs=gold_amrs,
1386+
window_size=args.window_size,
1387+
window_overlap=args.window_overlap,
1388+
batch_size=args.batch_size,
1389+
roberta_batch_size=args.roberta_batch_size,
1390+
beam=args.beam,
1391+
jamr=args.jamr,
1392+
no_isi=args.no_isi
1393+
)
13881394

13891395
else:
13901396

0 commit comments

Comments
 (0)