Skip to content

Commit 57c7b42

Browse files
andrewztan authored and ericl committed
KL Divergence Metrics (ray-project#3300)
* added KL divergence metrics
* fix
1 parent 1660c9d commit 57c7b42

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

python/ray/rllib/agents/impala/vtrace_policy_graph.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from ray.rllib.models.catalog import ModelCatalog
1717
from ray.rllib.utils.error import UnsupportedSpaceException
1818
from ray.rllib.utils.explained_variance import explained_variance
19+
from ray.rllib.models.action_dist import Categorical
1920

2021

2122
class VTraceLoss(object):
@@ -184,6 +185,14 @@ def to_batches(tensor):
184185
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
185186
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])
186187

188+
# KL divergence between worker and learner logits for debugging
189+
model_dist = Categorical(self.model.outputs)
190+
behaviour_dist = Categorical(behaviour_logits)
191+
self.KLs = model_dist.kl(behaviour_dist)
192+
self.mean_KL = tf.reduce_mean(self.KLs)
193+
self.max_KL = tf.reduce_max(self.KLs)
194+
self.median_KL = tf.contrib.distributions.percentile(self.KLs, 50.0)
195+
187196
# Initialize TFPolicyGraph
188197
loss_in = [
189198
("actions", actions),
@@ -225,6 +234,9 @@ def to_batches(tensor):
225234
"vf_explained_var": explained_variance(
226235
tf.reshape(self.loss.vtrace_returns.vs, [-1]),
227236
tf.reshape(to_batches(values)[:-1], [-1])),
237+
"mean_KL": self.mean_KL,
238+
"max_KL": self.max_KL,
239+
"median_KL": self.median_KL,
228240
},
229241
}
230242

0 commit comments

Comments
 (0)