
Commit bf475ff

Use clearer function
1 parent 7521f07 commit bf475ff

2 files changed: 3 additions and 3 deletions

pg/main.py

Lines changed: 2 additions & 2 deletions
@@ -69,12 +69,12 @@ def G(rewards, start=0, end=None):
 ]
 
 states = torch.stack(states)
-state_values = critic(states).reshape(-1)
+state_values = critic(states).flatten()
 
 cumulative_returns = tensor(cumulative_returns)
 Adv = cumulative_returns - state_values
 
-log_probs = torch.stack(log_probs).reshape(-1)
+log_probs = torch.stack(log_probs).flatten()
 
 loss = -(Adv @ log_probs) / len(rewards)
 if episode > 500 and loss.item() < -1000:
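
Reviewer note: the swap from `.reshape(-1)` to `.flatten()` should not change behavior here, since both collapse a tensor to one dimension. The sketch below illustrates why the flattening step matters at all. It assumes the critic returns one value per state with shape `(N, 1)`, which the diff does not confirm; the tensors are made-up stand-ins, not data from the repository.

```python
import torch

# Hypothetical stand-in for a critic's output: one value per state, shape (N, 1).
state_values = torch.tensor([[0.5], [1.0], [1.5]])
cumulative_returns = torch.tensor([1.0, 2.0, 3.0])  # shape (N,)

# Both calls collapse the tensor to a 1-D vector with the same contents.
flat_a = state_values.reshape(-1)
flat_b = state_values.flatten()
same = torch.equal(flat_a, flat_b)

# Without flattening, (N,) - (N, 1) broadcasts to an (N, N) matrix.
unflattened_shape = tuple((cumulative_returns - state_values).shape)

# With flattening, the advantage is a vector with one entry per state.
adv = cumulative_returns - flat_b
```

If the assumption about the `(N, 1)` shape holds, omitting the flatten would silently produce a square matrix instead of a per-state advantage, so the change is a readability improvement rather than a functional one.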

ppo/ppo.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def train(model, old_model, data) -> float:
 [sum(discounted_rewards[t:]) for t, _ in enumerate(discounted_rewards)]
 )
 
-state_values = model.vf(states).reshape(-1)
+state_values = model.vf(states).flatten()
 
 adv = cumulative_returns - state_values
 vf_loss = F.mse_loss(state_values, cumulative_returns)
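
Reviewer note: the context line in this hunk builds `cumulative_returns` by summing every suffix of `discounted_rewards`. As a rough illustration of what that expression computes, the sketch below compares it with a single backward pass. It assumes `discounted_rewards` is a plain list of floats, which the diff does not show; `returns_quadratic` and `returns_linear` are hypothetical helper names, not functions from the repository.

```python
def returns_quadratic(discounted_rewards):
    # Same expression as the context line in the hunk: sum each suffix.
    return [sum(discounted_rewards[t:]) for t, _ in enumerate(discounted_rewards)]


def returns_linear(discounted_rewards):
    # Walk the list backwards, carrying a running total of the suffix sum.
    out, running = [], 0.0
    for r in reversed(discounted_rewards):
        running += r
        out.append(running)
    out.reverse()
    return out


rewards = [1.0, 0.5, 0.25]
quadratic = returns_quadratic(rewards)
linear = returns_linear(rewards)
```

Both helpers return one reward-to-go value per timestep; the backward pass does so in a single sweep, which may matter for long episodes but is outside the scope of this commit.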
