Skip to content

Commit 2d0f1c4

Browse files
andreh7soumith
authored andcommitted
added a function makedirs() which works both for python 2 and 3 (pytorch#176)
1 parent 1b26501 commit 2d0f1c4

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

snli/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torchtext import datasets
1111

1212
from model import SNLIClassifier
13-
from util import get_args
13+
from util import get_args, makedirs
1414

1515

1616
args = get_args()
@@ -27,7 +27,7 @@
2727
inputs.vocab.vectors = torch.load(args.vector_cache)
2828
else:
2929
inputs.vocab.load_vectors(wv_dir=args.data_cache, wv_type=args.word_vectors, wv_dim=args.d_embed)
30-
os.makedirs(os.path.dirname(args.vector_cache), exist_ok=True)
30+
makedirs(os.path.dirname(args.vector_cache))
3131
torch.save(inputs.vocab.vectors, args.vector_cache)
3232
answers.build_vocab(train)
3333

@@ -61,7 +61,7 @@
6161
header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'
6262
dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
6363
log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
64-
os.makedirs(args.save_path, exist_ok=True)
64+
makedirs(args.save_path)
6565
print(header)
6666

6767
for epoch in range(args.epochs):

snli/util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
11
import os
22
from argparse import ArgumentParser
33

4+
def makedirs(name):
5+
"""helper function for python 2 and 3 to call os.makedirs()
6+
avoiding an error if the directory to be created already exists"""
7+
8+
import os, errno
9+
10+
try:
11+
os.makedirs(name)
12+
except OSError as ex:
13+
if ex.errno == errno.EEXIST and os.path.isdir(name):
14+
# ignore existing directory
15+
pass
16+
else:
17+
# a different error happened
18+
raise
19+
20+
421
def get_args():
522
parser = ArgumentParser(description='PyTorch/torchtext SNLI example')
623
parser.add_argument('--epochs', type=int, default=50)

0 commit comments

Comments
 (0)