Skip to content

Commit 1d0cd70

Browse files
committed
Cleanup code
1 parent 77ffd09 commit 1d0cd70

File tree

2 files changed

+13
-8816
lines changed

2 files changed

+13
-8816
lines changed

pg/main.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,9 @@
24 | 24 |
25 | 25 | env = gym.make(args.env)
26 | 26 |
27 |    | - state_size, action_size = int(np.prod(env.observation_space.shape)), int(env.action_space.n)
   | 27 | + state_size = int(np.prod(env.observation_space.shape))
   | 28 | + action_size = int(env.action_space.n)
   | 29 | +
28 | 30 | hidden_size = 50
29 | 31 |
30 | 32 | S, A, H = state_size, action_size, hidden_size
@@ -79,8 +81,12 @@ def G(rewards, start=0, end=None):
79 | 81 |
80 | 82 | s = succ
81 | 83 |
82 |    | - discounted_rewards = [pow(DISCOUNT, t) * r for t, r in enumerate(rewards)]
83 |    | - cumulative_returns = [G(discounted_rewards, t) for t in range(len(discounted_rewards))]
   | 84 | + discounted_rewards = [
   | 85 | + pow(DISCOUNT, t) * r for t, r in enumerate(rewards)
   | 86 | + ]
   | 87 | + cumulative_returns = [
   | 88 | + G(discounted_rewards, t) for t in range(len(discounted_rewards))
   | 89 | + ]
84 | 90 |
85 | 91 | states = torch.stack(states).cuda()
86 | 92 | state_values = critic(states).view(-1)
@@ -106,4 +112,7 @@ def G(rewards, start=0, end=None):
106 | 112 |
107 | 113 | # turn into list of lists
108 | 114 | stats = [list(x) for x in zip(*stats)]
109 |     | - print(DISCOUNT, len([r for r in stats[1] if r >= env.spec.reward_threshold]))
    | 115 | + print(
    | 116 | + DISCOUNT,
    | 117 | + len([r for r in stats[1] if r >= env.spec.reward_threshold]),
    | 118 | + )

0 commit comments

Comments (0)