Skip to content

Commit a723598

Browse files
boathitsoumith
authored andcommitted
change lr to 0.8 to fix the issue for 0.2
1 parent 407bd3e commit a723598

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

time_sequence_prediction/train.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import print_function
22
import torch
3-
import torch.nn as nn
3+
import torch.nn as nn
44
from torch.autograd import Variable
55
import torch.optim as optim
66
import numpy as np
@@ -49,7 +49,7 @@ def forward(self, input, future = 0):
4949
seq.double()
5050
criterion = nn.MSELoss()
5151
# use LBFGS as optimizer since we can load the whole data to train
52-
optimizer = optim.LBFGS(seq.parameters())
52+
optimizer = optim.LBFGS(seq.parameters(), lr=0.8)
5353
#begin to train
5454
for i in range(15):
5555
print('STEP: ', i)
@@ -69,7 +69,7 @@ def closure():
6969
y = pred.data.numpy()
7070
# draw the result
7171
plt.figure(figsize=(30,10))
72-
plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30)
72+
plt.title('Predict future values for time sequences\n(Dashlines are predicted values)', fontsize=30)
7373
plt.xlabel('x', fontsize=20)
7474
plt.ylabel('y', fontsize=20)
7575
plt.xticks(fontsize=20)
@@ -82,4 +82,3 @@ def draw(yi, color):
8282
draw(y[2], 'b')
8383
plt.savefig('predict%d.pdf'%i)
8484
plt.close()
85-

0 commit comments

Comments
 (0)