This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 21c4ef3
Author: Mesh TensorFlow Team

Add more extensive top-2 logging.

PiperOrigin-RevId: 392699230
1 parent 1f381ce commit 21c4ef3

File tree: mesh_tensorflow/transformer (1 file changed, +49 −0 lines)


mesh_tensorflow/transformer/moe.py

Lines changed: 49 additions & 0 deletions
@@ -1640,6 +1640,55 @@ def _top_2_gating(
   position_in_expert_2 = mtf.reduce_sum(
       position_in_expert_2, reduced_dim=experts_dim)
 
+  if train:
+    # Gate entropy.
+    if importance is not None:
+      raw_gates *= mtf.to_float(mtf.greater(importance, 0.0))
+    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
+                             reduced_dim=experts_dim)
+    batch_entropy = mtf.reduce_mean(entropy)
+    mtf.scalar_summary(name + "/entropy", batch_entropy)
+
+    # Mean top-1 and top-2 normalized gate probabilities.
+    if importance is not None:
+      gate_2 *= mtf.to_float(mtf.greater(importance, 0.0))
+    mtf.scalar_summary("top1_gate_normalized", mtf.reduce_mean(gate_1))
+    mtf.scalar_summary("top2_gate_normalized", mtf.reduce_mean(gate_2))
+    top1_routed = mtf.reduce_sum(mask_1_flat)
+    top2_routed = mtf.reduce_sum(mask_2_flat)
+    importance = mtf.cast(importance, dtype=top1_routed.dtype)
+
+    # What fraction of the top-1 and top-2 tokens are being routed to any
+    # expert.
+    mtf.scalar_summary("top1_fraction_routed",
+                       top1_routed / mtf.reduce_sum(importance))
+    mtf.scalar_summary("top2_fraction_routed",
+                       top2_routed / mtf.reduce_sum(importance))
+    # One or zero if that token got routed anywhere.
+    total_routed = mtf.reduce_sum(mtf.minimum(
+        mask_1_flat + mask_2_flat, mtf.ones_like(top1_routed)))
+    mtf.scalar_summary("all_fraction_routed",
+                       total_routed / mtf.reduce_sum(importance))
+    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
+
+    # Log what fraction of tokens are going to each expert.
+    def _log_per_expert_fraction(mask, name):
+      # mask: [batch, group, experts]
+      tokens_per_expert = mtf.reduce_sum(mask, output_shape=[experts_dim])
+      total_routed = mtf.reduce_sum(tokens_per_expert)
+      expert_fraction = mtf.to_float(tokens_per_expert / total_routed)
+      split_fractions = mtf.split(
+          expert_fraction,
+          split_dim=experts_dim,
+          num_or_size_splits=experts_dim.size)
+      for fraction in split_fractions:
+        mtf.scalar_summary(name + "_experts/" + fraction.name.replace(":", "/"),
+                           mtf.reduce_mean(fraction))
+
+    _log_per_expert_fraction(mask_1, "top1")
+    _log_per_expert_fraction(mask_2, "top2")
+    _log_per_expert_fraction(mask_1 + mask_2, "all")
+
   # [batch, group, experts, expert_capacity]
   combine_tensor = (
       gate_1 * mask_1_flat

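For reference, here is a small self-contained sketch (not part of the commit) of the quantities the new scalar summaries report, written with plain NumPy on dense arrays. The function name top2_logging_stats and the input shapes are assumptions for illustration: raw_gates is [batch, group, experts], mask_1_flat / mask_2_flat are [batch, group] indicators that a token's top-1 / top-2 choice fit into expert capacity, and importance is [batch, group] with zeros at padded positions.

import numpy as np

def top2_logging_stats(raw_gates, mask_1_flat, mask_2_flat, importance):
  """Hypothetical helper mirroring the summaries added in this commit."""
  # Zero out gates at padded positions, as the commit does when
  # `importance` is provided.
  raw_gates = raw_gates * (importance > 0.0)[..., None]
  # Per-token entropy of the gate distribution over experts, then the
  # batch mean (the "/entropy" summary).
  entropy = -np.sum(raw_gates * np.log(raw_gates + 1e-9), axis=-1)
  batch_entropy = entropy.mean()
  # Fraction of real tokens whose top-1 / top-2 choice was actually routed
  # ("top1_fraction_routed", "top2_fraction_routed").
  num_tokens = importance.sum()
  top1_fraction = mask_1_flat.sum() / num_tokens
  top2_fraction = mask_2_flat.sum() / num_tokens
  # Fraction routed to at least one expert ("all_fraction_routed"): each
  # token counts once even if both of its choices were routed.
  all_fraction = np.minimum(mask_1_flat + mask_2_flat, 1.0).sum() / num_tokens
  return batch_entropy, top1_fraction, top2_fraction, all_fraction

Note that the routed fractions divide by the sum of importance rather than the raw position count, so positions with zero importance (padding) do not contribute to the denominator; this is also why the commit casts importance to the dtype of the routed counts before dividing.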