
Commit 888ec84

misc
1 parent b516a80 commit 888ec84

File tree

8 files changed: +10 -29 lines changed


dqn/args.py

Lines changed: 1 addition & 9 deletions
@@ -3,22 +3,14 @@
 
 import argparse
 
+
 parser = argparse.ArgumentParser()
 
 parser.add_argument("-b", "--batch_size", type=int, default=128)
-
 parser.add_argument("-r", "--replay_buffer_size", type=int, default=10 ** 4)
-
-
 parser.add_argument("-i", "--iterations", type=int, default=10 ** 3)
-
 parser.add_argument("-d", "--discount_rate", "--gamma", type=float, default=0.999)
-
 parser.add_argument("-e", "--exploration_rate", "--epsilon", type=float, default=0.9)
-
 parser.add_argument("-l", "--lr", "--learning_rate", type=float, default=1e-7)
 
 args = parser.parse_args()
-
-if __name__ == "__main__":
-    pass
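
One detail worth knowing about the aliases above: argparse stores each value under the first long option string, so -d/--discount_rate/--gamma is read back as args.discount_rate, and -l/--lr/--learning_rate as args.lr. A minimal usage sketch (the invocation below is hypothetical):

# python some_script.py -b 64 --gamma 0.99 --learning_rate 1e-4
from args import args

print(args.batch_size)     # 64
print(args.discount_rate)  # 0.99
print(args.lr)             # 0.0001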

dqn/dqn.py

Lines changed: 4 additions & 4 deletions
@@ -2,7 +2,6 @@
 # -*- coding: utf-8 -*-
 
 import random
-from copy import deepcopy
 
 import gym
 import numpy as np
@@ -50,10 +49,11 @@ def __init__(self, env):
 
         S = self.state_size = int(np.product(env.observation_space.shape))
         A = self.action_size = env.action_space.n
+        H = 50
 
-        self.fc1 = nn.Linear(S, 50)
-        self.fc2 = nn.Linear(50, 50)
-        self.fc3 = nn.Linear(50, A)
+        self.fc1 = nn.Linear(S, H)
+        self.fc2 = nn.Linear(H, H)
+        self.fc3 = nn.Linear(H, A)
 
         self.loss = nn.functional.mse_loss
         self.opt = torch.optim.Adam(self.parameters())
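
For context, the hunk above touches only __init__; the forward pass is not part of this commit. A minimal sketch of how such a three-layer Q-network is typically completed, where the class name and the ReLU activations are assumptions rather than anything shown in the diff:

import numpy as np
import torch
from torch import nn


class QNetwork(nn.Module):  # hypothetical name; the diff shows only __init__
    def __init__(self, env):
        super().__init__()
        S = self.state_size = int(np.prod(env.observation_space.shape))
        A = self.action_size = env.action_space.n
        H = 50  # hidden width, as introduced by this commit

        self.fc1 = nn.Linear(S, H)
        self.fc2 = nn.Linear(H, H)
        self.fc3 = nn.Linear(H, A)

        self.loss = nn.functional.mse_loss
        self.opt = torch.optim.Adam(self.parameters())

    def forward(self, x):  # assumed: ReLU between layers, raw Q-values out
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)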

dqn/exploration.py

Lines changed: 0 additions & 1 deletion
@@ -34,5 +34,4 @@ def decay_exploration(i, epsilon=epsilon):
 
 
 if __name__ == "__main__":
-    pass
     print(epsilon_greedy(env.observation_space.sample()))
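
Neither epsilon_greedy nor decay_exploration appears in the hunk itself. Purely as a sketch of the usual pattern (the body below is an assumption, not the repository's implementation; env and a Q-network Q are taken to be in scope in this module):

import random

import torch


def epsilon_greedy(state, epsilon=0.9):
    # Explore with probability epsilon: uniform random action.
    if random.random() < epsilon:
        return env.action_space.sample()  # env assumed in scope
    # Exploit otherwise: greedy action under the current Q-network.
    q = Q(torch.as_tensor(state, dtype=torch.float32))  # Q assumed in scope
    return int(q.argmax())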

dqn/main.py

Lines changed: 0 additions & 3 deletions
@@ -8,10 +8,7 @@
 import torchvision.datasets as dset
 import torchvision.transforms as transforms
 from torch import nn
-from torch.autograd import Variable as V
-from torch.nn import Parameter as P
 from torch.utils.data import DataLoader
-from torch.autograd import Variable
 
 from args import args
 from env import env

dqn/now

Lines changed: 0 additions & 1 deletion
This file was deleted.

dqn/replay_buffer.py

Lines changed: 0 additions & 3 deletions
@@ -5,6 +5,3 @@
 
 REPLAY_SIZE = 10 ** 6
 replay_buffer = deque(maxlen=REPLAY_SIZE)
-
-if __name__ == "__main__":
-    pass
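
As background for the lines kept above: deque(maxlen=...) evicts its oldest entry on append once full, which is exactly the FIFO behavior a DQN replay buffer needs, so no explicit size management is required. A small usage sketch (the transition layout and batch size are assumptions):

import random
from collections import deque

REPLAY_SIZE = 10 ** 6
replay_buffer = deque(maxlen=REPLAY_SIZE)

# Hypothetical transition layout: (state, action, reward, next_state, done).
for t in range(200):
    replay_buffer.append((t, 0, 1.0, t + 1, False))

# Uniform random minibatch; random.sample works directly on a deque.
batch = random.sample(replay_buffer, 128)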

dqn/train.py

Lines changed: 0 additions & 4 deletions
@@ -67,7 +67,3 @@ def train(buffer, Q):
 
     states.volatile, td_estimates.volatile = False, False
     return states, td_estimates
-
-
-if __name__ == "__main__":
-    pass
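
A side note on the surviving context lines: the volatile flag belongs to the pre-0.4 torch.autograd.Variable API, where it disabled graph construction during inference and had to be flipped back by hand, as done above. In PyTorch 0.4 and later the same effect is spelled with torch.no_grad(), roughly:

import torch
from torch import nn

Q = nn.Linear(4, 2)        # stand-in network; the real one lives in dqn.py
states = torch.randn(8, 4)

# Modern replacement for Variable(..., volatile=True):
with torch.no_grad():      # autograd records nothing inside this block
    td_estimates = Q(states)

# No flag to reset afterwards: outside the block, gradients flow as usual.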

xor_lstm/xor_lstm.py

Lines changed: 5 additions & 4 deletions
@@ -9,10 +9,11 @@
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-from torch import Tensor, nn
+from torch import nn
 from torch.autograd import Variable
 from torch.nn import Linear, ReLU, Softmax
 
+
 CUDA_AVAILABLE = torch.cuda.is_available()
 
 N = NUM_SAMPLES = 100_000
@@ -25,7 +26,7 @@
 
 
 def foldr(arr: np.ndarray, op) -> np.ndarray:
-    """Specific version of foldr that's only for Numpy arrays"""
+    """Specific version of foldr that's only for Numpy arrays."""
 
     return np.fromiter(itertools.accumulate(arr, op), dtype=np.float32, count=len(arr))
 
@@ -75,8 +76,8 @@ def argmax(tensor, dim=1):
 # Hack to check if we've already trained a model (assumed to be a good one.
 model_path = Path(f"model-{OP.__name__}" + ("cuda" if CUDA_AVAILABLE else "") + ".pth")
 
-test_mode = model_path.exists()
-train_mode = not test_mode
+test_mode: bool = model_path.exists()
+train_mode: bool = not test_mode
 
 if test_mode:
     model = torch.load(model_path)
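
A usage note on the foldr helper whose docstring was touched above: itertools.accumulate folds left-to-right and yields every intermediate value, so despite its name the function returns the running accumulation, which for the XOR task is the running parity of the bit sequence. For example:

import itertools

import numpy as np


def foldr(arr: np.ndarray, op) -> np.ndarray:
    """Specific version of foldr that's only for Numpy arrays."""
    return np.fromiter(itertools.accumulate(arr, op), dtype=np.float32, count=len(arr))


bits = np.array([1, 0, 1, 1], dtype=np.float32)
print(foldr(bits, np.logical_xor))  # [1. 1. 0. 1.], the running XOR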
