This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit ef0cb24

Author: Mesh TensorFlow Team (committed)
Add more options to Experts Attention. These options remove 1/3 of the all2all communication costs:
- Compute q only
- Compute kv only

PiperOrigin-RevId: 389986864
1 parent 3aaa765 commit ef0cb24
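
One plausible accounting of the "1/3" figure in the commit message, under the assumption that the activations dispatched to the experts and each expert output crossing the all2all are all d_model-sized per token: the default "qkv" mode sends one activation in and combines two expert outputs (q and the shared kv), while the new "q"-only and "kv"-only modes combine just one. The arithmetic sketch below is illustrative only and is not code from this commit.

# Illustrative accounting only (not from this commit): assume the tokens
# dispatched to the experts and each expert output are all d_model-sized.
D_MODEL = 1.0  # relative cost unit per token

def all2all_cost(num_expert_outputs):
  """One dispatch all2all in, `num_expert_outputs` combine all2alls out."""
  return D_MODEL + num_expert_outputs * D_MODEL

qkv_cost = all2all_cost(2)     # experts return both q and the shared kv
single_cost = all2all_cost(1)  # experts return only q (or only kv)
print(1.0 - single_cost / qkv_cost)  # -> 0.333..., i.e. roughly 1/3 saved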

2 files changed: +68, -14 lines changed

mesh_tensorflow/transformer/attention.py

Lines changed: 64 additions & 11 deletions
@@ -633,7 +633,8 @@ def __init__(self,
                keep_query_heads_dims=False,
                fold_scaling_into_initializer=True,
                context=None,
-               experts_hparams=None):
+               experts_hparams=None,
+               expert_computation="qkv"):
     super(ExpertsAttentionParams, self).__init__(
         mesh=mesh,
         query_input_dim=query_input_dim,
@@ -653,19 +654,48 @@ def __init__(self,
         make_attention_vars=False)
 
     self.context = context
+    self.expert_computation = expert_computation
+
+    # Unless we want to compute both q and kv, we can use the normal MoE
+    # settings.
+    if expert_computation == "qkv":
+      experts_attention_compute_qkv = True
+    elif expert_computation in ["q", "kv"]:
+      experts_attention_compute_qkv = False
+      if expert_computation == "q":
+        # Always assume shared_kv.
+        self.wkv = mtf.get_variable(
+            self.mesh,
+            "kv",
+            self.k_shape,
+            initializer=tf.random_normal_initializer(
+                stddev=self.memory_input_dim.size ** -0.5),
+            dtype=self.variable_dtype)
+      else:  # Computing kv with experts.
+        self.wq = mtf.get_variable(
+            self.mesh,
+            "q",
+            self.q_shape,
+            initializer=tf.random_normal_initializer(
+                stddev=self.query_input_dim.size ** -0.5),
+            dtype=self.variable_dtype)
+    else:
+      raise ValueError("Invalid expert computation mode: {}".format(
+          expert_computation))
 
     # ExpertsAttention, for simplicitly, asserts that combine_dims is True, and
     # for efficiency, that shared_kv is True.
     if not self.combine_dims:
-      raise ValueError("self.combine_dims must be True for ExpertsAttention")
+      raise ValueError("combine_dims must be True for ExpertsAttention.")
     if not self.shared_kv:
-      raise ValueError("self.shared_kv must be True for ExpertsAttention")
+      raise ValueError("shared_kv must be True for ExpertsAttention.")
     if mtf.layers.unit_scaling_convention():
       raise NotImplementedError
 
-    # TODO(barretzoph): Make this work for model parallelism by not outputing
-    # a tensor with `heads` dim.
-    moe_output_dims = self.q_shape[-1]
+    # Now replace "heads" dim with the "d_model" name to avoid conflicts when
+    # we want to partition both "experts_hidden" and "heads".
+    moe_output_dims = mtf.Dimension("d_model", self.q_shape[-1].size)
+
     tf.logging.info("ExpertsAttention moe_hidden_size: {}".format(
         experts_hparams.hidden_size))
     tf.logging.info("moe_output_dims: {}".format(moe_output_dims))
@@ -685,16 +715,39 @@ def __init__(self,
         ntlb_top_k=experts_hparams.ntlb_top_k,
         hidden_size=experts_hparams.hidden_size,
         output_dim=moe_output_dims,
-        use_experts_attention=experts_hparams.use_experts_attention,
+        use_experts_attention=experts_attention_compute_qkv,
         activation=experts_hparams.activation,
         z_loss=experts_hparams.z_loss)
 
   def _compute_merge_qkv(self, antecedent):
     """Computes qkv all in one call using MoE layer."""
-    # NOTE: This assumes querty and memory antecedent are the same
-    qk = self.moe_layer.call(self.context, antecedent)
-    # Split qk here since they went through experts-layers
-    q, k = qk
+    def _replace_d_model_dim(t):
+      """Used to replace the `d_model` dim with `heads`."""
+      new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
+      return mtf.reshape(
+          t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
+    if self.expert_computation == "qkv":
+      # NOTE: This assumes querty and memory antecedent are the same
+      qk = self.moe_layer.call(self.context, antecedent)
+      # Split qk here since they went through experts-layers
+      q, k = qk
+      q = _replace_d_model_dim(q)
+      k = _replace_d_model_dim(k)
+    elif self.expert_computation == "q":
+      q = self.moe_layer.call(self.context, antecedent)
+      q = _replace_d_model_dim(q)
+      # Compute key/value normally
+      k = mtf.layers.us_einsum(
+          [antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
+    elif self.expert_computation == "kv":
+      k = self.moe_layer.call(self.context, antecedent)
+      k = _replace_d_model_dim(k)
+      # Compute query normally
+      q = mtf.layers.us_einsum(
+          [antecedent, self.wq], reduced_dims=[self.query_input_dim])
+    else:
+      raise ValueError("Invalid expert computation mode: {}".format(
+          self.expert_computation))
 
     # Scale query
     q *= self.key_dim.size ** -0.5
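
For readers skimming the hunk above, here is a minimal, mtf-free sketch of the control flow that _compute_merge_qkv gains in this commit. The helpers moe_call, dense_project_q, and dense_project_kv are hypothetical stand-ins for self.moe_layer.call(...) and the mtf.layers.us_einsum projections against self.wq / self.wkv; the d_model-to-heads reshape and query scaling are omitted.

def compute_qk(antecedent, mode, moe_call, dense_project_q, dense_project_kv):
  """Sketch of which projection goes through the experts for each mode."""
  if mode == "qkv":
    # One expert call returns both projections, so both share the all2all.
    q, k = moe_call(antecedent)
  elif mode == "q":
    q = moe_call(antecedent)           # experts compute q only
    k = dense_project_kv(antecedent)   # shared kv via a plain dense projection
  elif mode == "kv":
    k = moe_call(antecedent)           # experts compute the shared kv only
    q = dense_project_q(antecedent)    # q via a plain dense projection
  else:
    raise ValueError("Invalid expert computation mode: {}".format(mode))
  return q, k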

mesh_tensorflow/transformer/transformer_layers.py

Lines changed: 4 additions & 3 deletions
@@ -404,11 +404,12 @@ def __init__(self,
                switch_jitter=1e-2,
                ntlb_top_k=4,
                hidden_size=3072,
-               use_experts_attention=True,
                activation="relu",
                z_loss=None,
+               expert_computation="qkv",
                **kwargs):
     super(ExpertsSelfAttention, self).__init__(**kwargs)
+    self.expert_computation = expert_computation
     self._hparams = mtf.transformer.moe.HParams(
         moe_gating=moe_gating,
         num_experts=num_experts,
@@ -424,7 +425,6 @@ def __init__(self,
         switch_jitter=switch_jitter,
         ntlb_top_k=ntlb_top_k,
         hidden_size=hidden_size,
-        use_experts_attention=use_experts_attention,
         activation=activation,
         z_loss=z_loss)
 
@@ -464,7 +464,8 @@ def make_params(self, context):
         keep_query_heads_dims=self.keep_query_heads_dims,
         fold_scaling_into_initializer=self.fold_scaling_into_initializer,
         context=context,
-        experts_hparams=self._hparams)
+        experts_hparams=self._hparams,
+        expert_computation=self.expert_computation)
 
 
 @gin.configurable
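
From the caller's side, the change surfaces as a single new keyword argument. The usage sketch below assumes ExpertsSelfAttention is constructible with defaults for its remaining hyperparameters (as the signature in the hunk above suggests) and, for the gin route, that the class is registered with @gin.configurable like the neighbouring layer classes in this file; both points are assumptions, not shown in this diff.

# Usage sketch under the assumptions stated above; "qkv" keeps the previous
# behaviour of computing both projections with the experts.
from mesh_tensorflow.transformer.transformer_layers import ExpertsSelfAttention

layer = ExpertsSelfAttention(expert_computation="q")     # experts compute q only
# layer = ExpertsSelfAttention(expert_computation="kv")  # experts compute kv only
# Through gin, the equivalent binding would be along the lines of:
#   ExpertsSelfAttention.expert_computation = "q"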
