Skip to content

Commit 9108041

Browse files
bmccannsoumith
authored andcommitted
updates for torchtext and loading from snapshot
1 parent 832141b commit 9108041

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

snli/train.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515

1616
args = get_args()
1717

18-
inputs = data.Field()
18+
inputs = data.Field(lower=args.lower)
1919
answers = data.Field(sequential=False)
2020

2121
train, dev, test = datasets.SNLI.splits(inputs, answers)
2222

23-
inputs.build_vocab(train, dev, test, lower=args.lower)
23+
inputs.build_vocab(train, dev, test)
2424
if args.word_vectors:
2525
if os.path.isfile(args.vector_cache):
2626
inputs.vocab.vectors = torch.load(args.vector_cache)
@@ -40,10 +40,14 @@
4040
if config.birnn:
4141
config.n_cells *= 2
4242

43-
model = SNLIClassifier(config)
44-
if args.word_vectors:
45-
model.embed.weight.data = inputs.vocab.vectors
46-
model.cuda()
43+
if args.resume_snapshot:
44+
model = torch.load(args.resume_snapshot, map_location=lambda storage, locatoin: storage.cuda(args.gpu))
45+
else:
46+
model = SNLIClassifier(config)
47+
if args.word_vectors:
48+
model.embed.weight.data = inputs.vocab.vectors
49+
model.cuda(args.gpu)
50+
4751
criterion = nn.CrossEntropyLoss()
4852
opt = O.Adam(model.parameters(), lr=args.lr)
4953

snli/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ def get_args():
2323
parser.add_argument('--data_cache', type=str, default=os.path.join(os.getcwd(), '.data_cache'))
2424
parser.add_argument('--vector_cache', type=str, default=os.path.join(os.getcwd(), '.vector_cache/input_vectors.pt'))
2525
parser.add_argument('--word_vectors', type=str, default='glove.42B')
26+
parser.add_argument('--resume_snapshot', type=str, default='results/snapshot.pt')
2627
args = parser.parse_args()
2728
return args

0 commit comments

Comments
 (0)