@@ -16,6 +16,7 @@
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.explained_variance import explained_variance
+from ray.rllib.models.action_dist import Categorical


 class VTraceLoss(object):
@@ -184,6 +185,14 @@ def to_batches(tensor):
             clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
             clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])

+        # KL divergence between the learner's and the workers' action
+        # distributions, tracked for debugging policy lag
+        model_dist = Categorical(self.model.outputs)
+        behaviour_dist = Categorical(behaviour_logits)
+        self.KLs = model_dist.kl(behaviour_dist)
+        self.mean_KL = tf.reduce_mean(self.KLs)
+        self.max_KL = tf.reduce_max(self.KLs)
+        self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0)
+
         # Initialize TFPolicyGraph
         loss_in = [
             ("actions", actions),
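For intuition, `Categorical.kl` here yields the per-sample KL divergence KL(learner ‖ worker) computed directly from the two sets of logits. Below is a minimal, self-contained NumPy sketch of that computation and the three summary statistics; it is illustrative only, not RLlib's implementation, and the variable names are made up for the example.

```python
import numpy as np


def categorical_kl(learner_logits, worker_logits):
    """Per-row KL(learner || worker) for categorical distributions,
    computed from raw logits.

    Logits are shifted by their row max before exponentiating for
    numerical stability; the softmax is invariant to that shift.
    """
    a = learner_logits - learner_logits.max(axis=1, keepdims=True)
    b = worker_logits - worker_logits.max(axis=1, keepdims=True)
    log_za = np.log(np.exp(a).sum(axis=1, keepdims=True))
    log_zb = np.log(np.exp(b).sum(axis=1, keepdims=True))
    p = np.exp(a - log_za)  # learner action probabilities
    # KL(p || q) = sum_i p_i * (log p_i - log q_i)
    return (p * ((a - log_za) - (b - log_zb))).sum(axis=1)


# Toy batch: 4 samples, 3 actions. In the patch, the two logit tensors
# are self.model.outputs (learner) and behaviour_logits (workers).
rng = np.random.RandomState(0)
learner = rng.randn(4, 3)
worker = learner + 0.1 * rng.randn(4, 3)  # slightly stale copy

kls = categorical_kl(learner, worker)
print("mean_KL:", kls.mean())        # cf. tf.reduce_mean(self.KLs)
print("max_KL:", kls.max())          # cf. tf.reduce_max(self.KLs)
print("median_KL:", np.median(kls))  # cf. the 50th-percentile op
```

Reporting the median alongside the mean and max helps distinguish a broad shift in policy lag from a few outlier trajectories produced by badly stale workers.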
@@ -225,6 +234,9 @@ def to_batches(tensor):
                 "vf_explained_var": explained_variance(
                     tf.reshape(self.loss.vtrace_returns.vs, [-1]),
                     tf.reshape(to_batches(values)[:-1], [-1])),
+                "mean_KL": self.mean_KL,
+                "max_KL": self.max_KL,
+                "median_KL": self.median_KL,
             },
         }

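Once merged into the stats dict, these scalars travel with the rest of the per-iteration learner statistics. A hedged sketch of how one might watch them during training, assuming an IMPALA agent of the same RLlib vintage as this patch and that policy-graph stats surface under `result["info"]["learner"]` (agent/trainer class names and the exact result layout vary across RLlib versions):

```python
# Sketch only: the ImpalaAgent import path and the result["info"]["learner"]
# layout are assumptions based on the RLlib version this patch appears to
# target; later releases renamed agents to trainers and moved stats around.
import ray
from ray.rllib.agents.impala import ImpalaAgent

ray.init()
agent = ImpalaAgent(env="CartPole-v0")

for i in range(10):
    result = agent.train()
    stats = result["info"].get("learner", {})
    # A persistently large max_KL suggests workers are acting with badly
    # stale policies relative to the learner.
    print(i, {k: stats.get(k) for k in ("mean_KL", "median_KL", "max_KL")})
```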