#include <chrono>
#include <memory>

#include <torch/torch.h>

#include "cpprl/algorithms/a2c.h"
#include "cpprl/algorithms/algorithm.h"
#include "cpprl/model/mlp_base.h"
#include "cpprl/model/policy.h"
#include "cpprl/storage.h"
#include "cpprl/spaces.h"
#include "third_party/doctest.h"

namespace cpprl
{
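// A2C (advantage actor-critic) update rule. The constructor stores the loss
// coefficients and builds an RMSprop optimizer over the policy's parameters
// with the given learning rate, epsilon and alpha (smoothing constant).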
A2C::A2C(Policy &policy,
         float value_loss_coef,
         float entropy_coef,
         float learning_rate,
         float epsilon,
         float alpha,
         float max_grad_norm)
    : policy(policy),
      value_loss_coef(value_loss_coef),
      entropy_coef(entropy_coef),
      max_grad_norm(max_grad_norm),
      optimizer(std::make_unique<torch::optim::RMSprop>(
          policy->parameters(),
          torch::optim::RMSpropOptions(learning_rate)
              .eps(epsilon)
              .alpha(alpha))) {}

std::vector<UpdateDatum> A2C::update(RolloutStorage &rollouts)
{
    // Prep work
    auto full_obs_shape = rollouts.get_observations().sizes();
    std::vector<int64_t> obs_shape(full_obs_shape.begin() + 2,
                                   full_obs_shape.end());
    obs_shape.insert(obs_shape.begin(), -1);
    auto action_shape = rollouts.get_actions().size(-1);
    auto rewards_shape = rollouts.get_rewards().sizes();
    int num_steps = rewards_shape[0];
    int num_processes = rewards_shape[1];

    // Run evaluation on rollouts
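    // Observations, masks and actions are flattened into a single
    // (num_steps * num_processes) batch (the recurrent hidden state is taken
    // from the first step); the result holds {values, action log-probs,
    // distribution entropy}.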
    auto evaluate_result = policy->evaluate_actions(
        rollouts.get_observations().slice(0, 0, -1).view(obs_shape),
        rollouts.get_hidden_states()[0].view({-1, policy->get_hidden_size()}),
        rollouts.get_masks().slice(0, 0, -1).view({-1, 1}),
        rollouts.get_actions().view({-1, action_shape}));
    auto values = evaluate_result[0].view({num_steps, num_processes, 1});
    auto action_log_probs = evaluate_result[1].view(
        {num_steps, num_processes, 1});

    // Calculate advantages
    // Advantages aren't normalized (they are in PPO)
    auto advantages = rollouts.get_returns().slice(0, 0, -1) - values;
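    // i.e. advantage_t = return_t - value_t, where the returns were filled in
    // by RolloutStorage::compute_returns() before update() was called.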

    // Value loss
    auto value_loss = advantages.pow(2).mean();
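    // The critic is trained by regression: mean squared error between the
    // stored returns and the predicted values.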

    // Action loss
    auto action_loss = -(advantages.detach() * action_log_probs).mean();
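    // Policy-gradient term: the advantages are detached so that this term only
    // trains the actor, not the value head.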

    // Total loss
    auto loss = (value_loss * value_loss_coef +
                 action_loss -
                 evaluate_result[2] * entropy_coef);
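    // The entropy term is subtracted, so more exploratory (higher-entropy)
    // policies reduce the loss.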

    // Step optimizer
    optimizer->zero_grad();
    loss.backward();
    optimizer->step();
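    // Note: max_grad_norm is stored by the constructor but no gradient
    // clipping is applied in this update.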

    return {{"Value loss", value_loss.item().toFloat()},
            {"Action loss", action_loss.item().toFloat()},
            {"Entropy", evaluate_result[2].item().toFloat()}};
}

TEST_CASE("A2C")
{
    torch::manual_seed(0);
    SUBCASE("update() learns basic pattern")
    {
        auto base = std::make_shared<MlpBase>(1, false, 5);
        ActionSpace space{"Discrete", {2}};
        Policy policy(space, base);
        RolloutStorage storage(5, 2, {1}, space, 5, torch::kCPU);
        A2C a2c(policy, 0.5, 1e-3, 0.001);

        // The reward is the action
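        // (action 1 earns 1, action 0 earns 0), so training should shift
        // probability mass towards action 1; the CHECKs below compare the
        // policy's action probabilities before and after training.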
        auto pre_game_probs = policy->get_probs(
            torch::ones({2, 1}),
            torch::zeros({2, 5}),
            torch::ones({2, 1}));

        for (int i = 0; i < 10; ++i)
        {
            for (int j = 0; j < 5; ++j)
            {
                auto observation = torch::randint(0, 2, {2, 1});

                std::vector<torch::Tensor> act_result;
                {
                    torch::NoGradGuard no_grad;
                    act_result = policy->act(observation,
                                             torch::Tensor(),
                                             torch::ones({2, 1}));
                }
                auto actions = act_result[1];

                auto rewards = actions;
                storage.insert(observation,
                               torch::zeros({2, 5}),
                               actions,
                               act_result[2],
                               act_result[0],
                               rewards,
                               torch::ones({2, 1}));
            }

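            // Bootstrap the return of the state after the last step with the
            // critic's value estimate, then compute the rollout's returns.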
            torch::Tensor next_value;
            {
                torch::NoGradGuard no_grad;
                next_value = policy->get_values(
                                       storage.get_observations()[-1],
                                       storage.get_hidden_states()[-1],
                                       storage.get_masks()[-1])
                                 .detach();
            }
            storage.compute_returns(next_value, false, 0., 0.9);

            a2c.update(storage);
            storage.after_update();
        }

        auto post_game_probs = policy->get_probs(
            torch::ones({2, 1}),
            torch::zeros({2, 5}),
            torch::ones({2, 1}));

        INFO("Pre-training probabilities: \n"
             << pre_game_probs << "\n");
        INFO("Post-training probabilities: \n"
             << post_game_probs << "\n");
        CHECK(post_game_probs[0][0].item().toDouble() <
              pre_game_probs[0][0].item().toDouble());
        CHECK(post_game_probs[0][1].item().toDouble() >
              pre_game_probs[0][1].item().toDouble());
    }

    SUBCASE("update() learns basic game")
    {
        auto base = std::make_shared<MlpBase>(1, false, 5);
        ActionSpace space{"Discrete", {2}};
        Policy policy(space, base);
        RolloutStorage storage(5, 2, {1}, space, 5, torch::kCPU);
        A2C a2c(policy, 0.5, 1e-7, 0.0001);

        // The game is: If the action matches the input, give a reward of 1, otherwise -1
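        // Unlike the first subcase, the reward depends on the current
        // observation, so the observation is kept between steps and resampled
        // after the reward is computed.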
        auto pre_game_probs = policy->get_probs(
            torch::ones({2, 1}),
            torch::zeros({2, 5}),
            torch::ones({2, 1}));

        auto observation = torch::randint(0, 2, {2, 1});
        storage.set_first_observation(observation);

        for (int i = 0; i < 10; ++i)
        {
            for (int j = 0; j < 5; ++j)
            {
                std::vector<torch::Tensor> act_result;
                {
                    torch::NoGradGuard no_grad;
                    act_result = policy->act(observation,
                                             torch::Tensor(),
                                             torch::ones({2, 1}));
                }
                auto actions = act_result[1];

                auto rewards = ((actions == observation.to(torch::kLong)).to(torch::kFloat) * 2) - 1;
                observation = torch::randint(0, 2, {2, 1});
                storage.insert(observation,
                               torch::zeros({2, 5}),
                               actions,
                               act_result[2],
                               act_result[0],
                               rewards,
                               torch::ones({2, 1}));
            }

            torch::Tensor next_value;
            {
                torch::NoGradGuard no_grad;
                next_value = policy->get_values(
                                       storage.get_observations()[-1],
                                       storage.get_hidden_states()[-1],
                                       storage.get_masks()[-1])
                                 .detach();
            }
            storage.compute_returns(next_value, false, 0.1, 0.9);

            a2c.update(storage);
            storage.after_update();
        }

        auto post_game_probs = policy->get_probs(
            torch::ones({2, 1}),
            torch::zeros({2, 5}),
            torch::ones({2, 1}));

        INFO("Pre-training probabilities: \n"
             << pre_game_probs << "\n");
        INFO("Post-training probabilities: \n"
             << post_game_probs << "\n");
        CHECK(post_game_probs[0][0].item().toDouble() <
              pre_game_probs[0][0].item().toDouble());
        CHECK(post_game_probs[0][1].item().toDouble() >
              pre_game_probs[0][1].item().toDouble());
    }
}
}