Commit 33318e8 (1 parent: c72e6a3)

20 files changed: +8366 -0 lines
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
target_sources(cpprl
    PRIVATE
        ${CMAKE_CURRENT_LIST_DIR}/storage.cpp
)

if (CPPRL_BUILD_TESTS)
    target_sources(cpprl_tests
        PRIVATE
            ${CMAKE_CURRENT_LIST_DIR}/storage.cpp
    )
endif (CPPRL_BUILD_TESTS)

add_subdirectory(algorithms)
add_subdirectory(distributions)
add_subdirectory(generators)
add_subdirectory(model)
add_subdirectory(third_party)
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
target_sources(cpprl
    PRIVATE
        ${CMAKE_CURRENT_LIST_DIR}/a2c.cpp
        ${CMAKE_CURRENT_LIST_DIR}/ppo.cpp
)

if (CPPRL_BUILD_TESTS)
    target_sources(cpprl_tests
        PRIVATE
            ${CMAKE_CURRENT_LIST_DIR}/a2c.cpp
            ${CMAKE_CURRENT_LIST_DIR}/ppo.cpp
    )
endif (CPPRL_BUILD_TESTS)
Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
#include <chrono>
#include <memory>

#include <torch/torch.h>

#include "cpprl/algorithms/a2c.h"
#include "cpprl/algorithms/algorithm.h"
#include "cpprl/model/mlp_base.h"
#include "cpprl/model/policy.h"
#include "cpprl/storage.h"
#include "cpprl/spaces.h"
#include "third_party/doctest.h"

namespace cpprl
{
A2C::A2C(Policy &policy,
         float value_loss_coef,
         float entropy_coef,
         float learning_rate,
         float epsilon,
         float alpha,
         float max_grad_norm)
    : policy(policy),
      value_loss_coef(value_loss_coef),
      entropy_coef(entropy_coef),
      max_grad_norm(max_grad_norm),
      optimizer(std::make_unique<torch::optim::RMSprop>(
          policy->parameters(),
          torch::optim::RMSpropOptions(learning_rate)
              .eps(epsilon)
              .alpha(alpha))) {}

std::vector<UpdateDatum> A2C::update(RolloutStorage &rollouts)
{
    // Prep work
    auto full_obs_shape = rollouts.get_observations().sizes();
    std::vector<int64_t> obs_shape(full_obs_shape.begin() + 2,
                                   full_obs_shape.end());
    obs_shape.insert(obs_shape.begin(), -1);
    auto action_shape = rollouts.get_actions().size(-1);
    auto rewards_shape = rollouts.get_rewards().sizes();
    int num_steps = rewards_shape[0];
    int num_processes = rewards_shape[1];

    // Run evaluation on rollouts
    auto evaluate_result = policy->evaluate_actions(
        rollouts.get_observations().slice(0, 0, -1).view(obs_shape),
        rollouts.get_hidden_states()[0].view({-1, policy->get_hidden_size()}),
        rollouts.get_masks().slice(0, 0, -1).view({-1, 1}),
        rollouts.get_actions().view({-1, action_shape}));
    auto values = evaluate_result[0].view({num_steps, num_processes, 1});
    auto action_log_probs = evaluate_result[1].view(
        {num_steps, num_processes, 1});

    // Calculate advantages
    // Advantages aren't normalized (they are in PPO)
    auto advantages = rollouts.get_returns().slice(0, 0, -1) - values;

    // Value loss
    auto value_loss = advantages.pow(2).mean();

    // Action loss
    auto action_loss = -(advantages.detach() * action_log_probs).mean();

    // Total loss
    auto loss = (value_loss * value_loss_coef +
                 action_loss -
                 evaluate_result[2] * entropy_coef);

    // Step optimizer
    optimizer->zero_grad();
    loss.backward();
    optimizer->step();

    return {{"Value loss", value_loss.item().toFloat()},
            {"Action loss", action_loss.item().toFloat()},
            {"Entropy", evaluate_result[2].item().toFloat()}};
}
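
// In other words, each call to update() takes one RMSprop step on
//     loss = value_loss_coef * mean(A^2)
//            - mean(A.detach() * action_log_probs)
//            - entropy_coef * entropy,
// where the advantage A = returns - values comes straight from the rollout
// returns and the re-evaluated value estimates, without normalization
// (unlike PPO).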

TEST_CASE("A2C")
{
    torch::manual_seed(0);
    SUBCASE("update() learns basic pattern")
    {
        auto base = std::make_shared<MlpBase>(1, false, 5);
        ActionSpace space{"Discrete", {2}};
        Policy policy(space, base);
        RolloutStorage storage(5, 2, {1}, space, 5, torch::kCPU);
        A2C a2c(policy, 0.5, 1e-3, 0.001);

        // The reward is the action
        auto pre_game_probs = policy->get_probs(
            torch::ones({2, 1}),
            torch::zeros({2, 5}),
            torch::ones({2, 1}));

        for (int i = 0; i < 10; ++i)
        {
            for (int j = 0; j < 5; ++j)
            {
                auto observation = torch::randint(0, 2, {2, 1});

                std::vector<torch::Tensor> act_result;
                {
                    torch::NoGradGuard no_grad;
                    act_result = policy->act(observation,
                                             torch::Tensor(),
                                             torch::ones({2, 1}));
                }
                auto actions = act_result[1];

                auto rewards = actions;
                storage.insert(observation,
                               torch::zeros({2, 5}),
                               actions,
                               act_result[2],
                               act_result[0],
                               rewards,
                               torch::ones({2, 1}));
            }

            torch::Tensor next_value;
            {
                torch::NoGradGuard no_grad;
                next_value = policy->get_values(
                                 storage.get_observations()[-1],
                                 storage.get_hidden_states()[-1],
                                 storage.get_masks()[-1])
                                 .detach();
            }
            storage.compute_returns(next_value, false, 0., 0.9);

            a2c.update(storage);
            storage.after_update();
        }

        auto post_game_probs = policy->get_probs(
            torch::ones({2, 1}),
            torch::zeros({2, 5}),
            torch::ones({2, 1}));

        INFO("Pre-training probabilities: \n"
             << pre_game_probs << "\n");
        INFO("Post-training probabilities: \n"
             << post_game_probs << "\n");
        CHECK(post_game_probs[0][0].item().toDouble() <
              pre_game_probs[0][0].item().toDouble());
        CHECK(post_game_probs[0][1].item().toDouble() >
              pre_game_probs[0][1].item().toDouble());
    }

    SUBCASE("update() learns basic game")
    {
        auto base = std::make_shared<MlpBase>(1, false, 5);
        ActionSpace space{"Discrete", {2}};
        Policy policy(space, base);
        RolloutStorage storage(5, 2, {1}, space, 5, torch::kCPU);
        A2C a2c(policy, 0.5, 1e-7, 0.0001);

        // The game is: If the action matches the input, give a reward of 1, otherwise -1
        auto pre_game_probs = policy->get_probs(
            torch::ones({2, 1}),
            torch::zeros({2, 5}),
            torch::ones({2, 1}));

        auto observation = torch::randint(0, 2, {2, 1});
        storage.set_first_observation(observation);

        for (int i = 0; i < 10; ++i)
        {
            for (int j = 0; j < 5; ++j)
            {
                std::vector<torch::Tensor> act_result;
                {
                    torch::NoGradGuard no_grad;
                    act_result = policy->act(observation,
                                             torch::Tensor(),
                                             torch::ones({2, 1}));
                }
                auto actions = act_result[1];

                auto rewards = ((actions == observation.to(torch::kLong)).to(torch::kFloat) * 2) - 1;
                observation = torch::randint(0, 2, {2, 1});
                storage.insert(observation,
                               torch::zeros({2, 5}),
                               actions,
                               act_result[2],
                               act_result[0],
                               rewards,
                               torch::ones({2, 1}));
            }

            torch::Tensor next_value;
            {
                torch::NoGradGuard no_grad;
                next_value = policy->get_values(
                                 storage.get_observations()[-1],
                                 storage.get_hidden_states()[-1],
                                 storage.get_masks()[-1])
                                 .detach();
            }
            storage.compute_returns(next_value, false, 0.1, 0.9);

            a2c.update(storage);
            storage.after_update();
        }

        auto post_game_probs = policy->get_probs(
            torch::ones({2, 1}),
            torch::zeros({2, 5}),
            torch::ones({2, 1}));

        INFO("Pre-training probabilities: \n"
             << pre_game_probs << "\n");
        INFO("Post-training probabilities: \n"
             << post_game_probs << "\n");
        CHECK(post_game_probs[0][0].item().toDouble() <
              pre_game_probs[0][0].item().toDouble());
        CHECK(post_game_probs[0][1].item().toDouble() >
              pre_game_probs[0][1].item().toDouble());
    }
}
}
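
For readers browsing the commit, the sketch below condenses the second test case into a standalone training loop to show the intended call order: act, insert, compute_returns, update, after_update. It reuses only names that appear in the diff (MlpBase, Policy, RolloutStorage, A2C) and the toy "match the input" reward; the hyperparameters are the ones from the test and are illustrative, not recommendations.

    #include <memory>
    #include <vector>

    #include <torch/torch.h>

    #include "cpprl/algorithms/a2c.h"
    #include "cpprl/model/mlp_base.h"
    #include "cpprl/model/policy.h"
    #include "cpprl/spaces.h"
    #include "cpprl/storage.h"

    int main()
    {
        using namespace cpprl;

        // Two parallel "environments", 1-dimensional observations, 5 steps per rollout
        auto base = std::make_shared<MlpBase>(1, false, 5);
        ActionSpace space{"Discrete", {2}};
        Policy policy(space, base);
        RolloutStorage storage(5, 2, {1}, space, 5, torch::kCPU);
        A2C a2c(policy, 0.5, 1e-7, 0.0001);

        auto observation = torch::randint(0, 2, {2, 1});
        storage.set_first_observation(observation);

        for (int update = 0; update < 10; ++update)
        {
            // Collect one rollout
            for (int step = 0; step < 5; ++step)
            {
                std::vector<torch::Tensor> act_result;
                {
                    torch::NoGradGuard no_grad;
                    act_result = policy->act(observation,
                                             torch::Tensor(),
                                             torch::ones({2, 1}));
                }
                auto actions = act_result[1];

                // Toy reward: +1 if the action matches the observation, -1 otherwise
                auto rewards = ((actions == observation.to(torch::kLong)).to(torch::kFloat) * 2) - 1;
                observation = torch::randint(0, 2, {2, 1});
                storage.insert(observation, torch::zeros({2, 5}), actions,
                               act_result[2], act_result[0], rewards,
                               torch::ones({2, 1}));
            }

            // Bootstrap from the value of the last observation, then update
            torch::Tensor next_value;
            {
                torch::NoGradGuard no_grad;
                next_value = policy->get_values(storage.get_observations()[-1],
                                                storage.get_hidden_states()[-1],
                                                storage.get_masks()[-1])
                                 .detach();
            }
            storage.compute_returns(next_value, false, 0.1, 0.9);

            a2c.update(storage);
            storage.after_update();
        }
    }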
