@@ -1640,6 +1640,55 @@ def _top_2_gating(
   position_in_expert_2 = mtf.reduce_sum(
       position_in_expert_2, reduced_dim=experts_dim)
 
+  if train:
+    # Gate entropy.
+    if importance is not None:
+      raw_gates *= mtf.to_float(mtf.greater(importance, 0.0))
+    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
+                             reduced_dim=experts_dim)
+    batch_entropy = mtf.reduce_mean(entropy)
+    mtf.scalar_summary(name + "/entropy", batch_entropy)
+
+    # Mean top-1 and top-2 normalized gate probabilities.
+    if importance is not None:
+      gate_2 *= mtf.to_float(mtf.greater(importance, 0.0))
+    mtf.scalar_summary("top1_gate_normalized", mtf.reduce_mean(gate_1))
+    mtf.scalar_summary("top2_gate_normalized", mtf.reduce_mean(gate_2))
+    top1_routed = mtf.reduce_sum(mask_1_flat)
+    top2_routed = mtf.reduce_sum(mask_2_flat)
+    importance = mtf.cast(importance, dtype=top1_routed.dtype)
+
+    # What fraction of the top-1 and top-2 tokens are being routed to any
+    # expert.
+    mtf.scalar_summary("top1_fraction_routed",
+                       top1_routed / mtf.reduce_sum(importance))
+    mtf.scalar_summary("top2_fraction_routed",
+                       top2_routed / mtf.reduce_sum(importance))
+    # One or zero if that token got routed anywhere.
+    total_routed = mtf.reduce_sum(mtf.minimum(
+        mask_1_flat + mask_2_flat, mtf.ones_like(top1_routed)))
+    mtf.scalar_summary("all_fraction_routed",
+                       total_routed / mtf.reduce_sum(importance))
+    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
+
+    # Log what fraction of tokens are going to each expert.
+    def _log_per_expert_fraction(mask, name):
+      # mask: [batch, group, experts]
+      tokens_per_expert = mtf.reduce_sum(mask, output_shape=[experts_dim])
+      total_routed = mtf.reduce_sum(tokens_per_expert)
+      expert_fraction = mtf.to_float(tokens_per_expert / total_routed)
+      split_fractions = mtf.split(
+          expert_fraction,
+          split_dim=experts_dim,
+          num_or_size_splits=experts_dim.size)
+      for fraction in split_fractions:
+        mtf.scalar_summary(name + "_experts/" + fraction.name.replace(":", "/"),
+                           mtf.reduce_mean(fraction))
+
+    _log_per_expert_fraction(mask_1, "top1")
+    _log_per_expert_fraction(mask_2, "top2")
+    _log_per_expert_fraction(mask_1 + mask_2, "all")
+
   # [batch, group, experts, expert_capacity]
   combine_tensor = (
       gate_1 * mask_1_flat