Skip to content

Commit a4e6972

Browse files
PeterChe1990soumith
authored andcommitted
Fix test data in time_sequence_prediction (#186)
Previously the training data input[:3] is incorrectly taken as the test data. Change the input for prediction to data[:3].
1 parent 10b22dc commit a4e6972

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

time_sequence_prediction/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def forward(self, input, future = 0):
4242
data = torch.load('traindata.pt')
4343
input = Variable(torch.from_numpy(data[3:, :-1]), requires_grad=False)
4444
target = Variable(torch.from_numpy(data[3:, 1:]), requires_grad=False)
45+
test_input = Variable(torch.from_numpy(data[:3, :-1]), requires_grad=False)
46+
test_target = Variable(torch.from_numpy(data[:3, 1:]), requires_grad=False)
4547
# build the model
4648
seq = Sequence()
4749
seq.double()
@@ -61,7 +63,9 @@ def closure():
6163
optimizer.step(closure)
6264
# begin to predict
6365
future = 1000
64-
pred = seq(input[:3], future = future)
66+
pred = seq(test_input, future = future)
67+
loss = criterion(pred[:, :-future], test_target)
68+
print('test loss:', loss.data.numpy()[0])
6569
y = pred.data.numpy()
6670
# draw the result
6771
plt.figure(figsize=(30,10))

0 commit comments

Comments
 (0)