@@ -633,7 +633,8 @@ def __init__(self,
                keep_query_heads_dims=False,
                fold_scaling_into_initializer=True,
                context=None,
-               experts_hparams=None):
+               experts_hparams=None,
+               expert_computation="qkv"):
     super(ExpertsAttentionParams, self).__init__(
         mesh=mesh,
         query_input_dim=query_input_dim,
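
The new `expert_computation` argument selects which attention projections the experts layer produces: "qkv" (both, the previous behavior), "q" only, or "kv" only. A minimal standalone sketch of the mode handling added below; the helper name is hypothetical, not part of the library:

    # Hypothetical helper mirroring the mode check in the constructor.
    VALID_MODES = ("qkv", "q", "kv")

    def experts_compute_qkv(mode):
      """True when the experts layer must emit both q and k/v."""
      if mode not in VALID_MODES:
        raise ValueError("Invalid expert computation mode: {}".format(mode))
      return mode == "qkv"

    assert experts_compute_qkv("qkv")
    assert not experts_compute_qkv("kv")
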
@@ -653,19 +654,48 @@ def __init__(self,
         make_attention_vars=False)
 
     self.context = context
+    self.expert_computation = expert_computation
+
+    # Unless we want to compute both q and kv, we can use the normal MoE
+    # settings.
+    if expert_computation == "qkv":
+      experts_attention_compute_qkv = True
+    elif expert_computation in ["q", "kv"]:
+      experts_attention_compute_qkv = False
+      if expert_computation == "q":
+        # Always assume shared_kv.
+        self.wkv = mtf.get_variable(
+            self.mesh,
+            "kv",
+            self.k_shape,
+            initializer=tf.random_normal_initializer(
+                stddev=self.memory_input_dim.size ** -0.5),
+            dtype=self.variable_dtype)
+      else:  # Computing kv with experts.
+        self.wq = mtf.get_variable(
+            self.mesh,
+            "q",
+            self.q_shape,
+            initializer=tf.random_normal_initializer(
+                stddev=self.query_input_dim.size ** -0.5),
+            dtype=self.variable_dtype)
+    else:
+      raise ValueError("Invalid expert computation mode: {}".format(
+          expert_computation))
 
     # ExpertsAttention, for simplicity, asserts that combine_dims is True, and
     # for efficiency, that shared_kv is True.
     if not self.combine_dims:
-      raise ValueError("self.combine_dims must be True for ExpertsAttention")
+      raise ValueError("combine_dims must be True for ExpertsAttention.")
     if not self.shared_kv:
-      raise ValueError("self.shared_kv must be True for ExpertsAttention")
+      raise ValueError("shared_kv must be True for ExpertsAttention.")
     if mtf.layers.unit_scaling_convention():
       raise NotImplementedError
 
-    # TODO(barretzoph): Make this work for model parallelism by not outputting
-    # a tensor with `heads` dim.
-    moe_output_dims = self.q_shape[-1]
+    # Now replace the "heads" dim with the "d_model" name to avoid conflicts
+    # when we want to partition both "experts_hidden" and "heads".
+    moe_output_dims = mtf.Dimension("d_model", self.q_shape[-1].size)
+
     tf.logging.info("ExpertsAttention moe_hidden_size: {}".format(
         experts_hparams.hidden_size))
     tf.logging.info("moe_output_dims: {}".format(moe_output_dims))
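
Mesh TensorFlow shards tensors by dimension name, so giving the experts-layer output the name "d_model" instead of "heads" keeps its layout from colliding when both "experts_hidden" and "heads" are mapped to mesh axes; `_compute_merge_qkv` renames the last dimension back afterwards. A self-contained sketch of that rename round trip, using a stand-in tuple rather than the real `mtf.Dimension` graph machinery:

    import collections

    # Stand-in for mtf.Dimension: a (name, size) pair.
    Dimension = collections.namedtuple("Dimension", ["name", "size"])

    def rename_last_dim(shape, new_name):
      """Returns `shape` with the last dimension renamed; sizes unchanged."""
      return shape[:-1] + [Dimension(new_name, shape[-1].size)]

    q_shape = [Dimension("batch", 8), Dimension("length", 128),
               Dimension("heads", 512)]
    moe_shape = rename_last_dim(q_shape, "d_model")  # handed to the MoE layer
    assert rename_last_dim(moe_shape, "heads") == q_shape  # restored later
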
@@ -685,16 +715,39 @@ def __init__(self,
         ntlb_top_k=experts_hparams.ntlb_top_k,
         hidden_size=experts_hparams.hidden_size,
         output_dim=moe_output_dims,
-        use_experts_attention=experts_hparams.use_experts_attention,
+        use_experts_attention=experts_attention_compute_qkv,
         activation=experts_hparams.activation,
         z_loss=experts_hparams.z_loss)
 
   def _compute_merge_qkv(self, antecedent):
     """Computes qkv all in one call using MoE layer."""
-    # NOTE: This assumes query and memory antecedent are the same.
-    qk = self.moe_layer.call(self.context, antecedent)
-    # Split qk here since they went through experts-layers.
-    q, k = qk
+    def _replace_d_model_dim(t):
+      """Used to replace the `d_model` dim with `heads`."""
+      new_last_dim = mtf.Dimension(self.q_shape[-1].name, t.shape[-1].size)
+      return mtf.reshape(
+          t, new_shape=mtf.Shape(t.shape[:-1] + [new_last_dim]))
+    if self.expert_computation == "qkv":
+      # NOTE: This assumes query and memory antecedent are the same.
+      qk = self.moe_layer.call(self.context, antecedent)
+      # Split qk here since they went through experts-layers.
+      q, k = qk
+      q = _replace_d_model_dim(q)
+      k = _replace_d_model_dim(k)
+    elif self.expert_computation == "q":
+      q = self.moe_layer.call(self.context, antecedent)
+      q = _replace_d_model_dim(q)
+      # Compute key/value normally.
+      k = mtf.layers.us_einsum(
+          [antecedent, self.wkv], reduced_dims=[self.memory_input_dim])
+    elif self.expert_computation == "kv":
+      k = self.moe_layer.call(self.context, antecedent)
+      k = _replace_d_model_dim(k)
+      # Compute query normally.
+      q = mtf.layers.us_einsum(
+          [antecedent, self.wq], reduced_dims=[self.query_input_dim])
+    else:
+      raise ValueError("Invalid expert computation mode: {}".format(
+          self.expert_computation))
 
     # Scale query.
     q *= self.key_dim.size ** -0.5
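
Whichever branch produced them, the query still gets the standard 1/sqrt(d_k) scaling before the attention dot-product, so downstream attention code is unchanged. A hedged NumPy sketch of the dispatch-and-scale flow (function and parameter names are illustrative, not the mtf API):

    import numpy as np

    def merge_qkv(mode, x, moe, wq=None, wkv=None, key_dim=64):
      """Mirrors _compute_merge_qkv: experts compute q, kv, or both."""
      if mode == "qkv":
        q, k = moe(x)  # experts return the pair (q, k)
      elif mode == "q":
        q = moe(x)
        k = x @ wkv    # shared key/value projection, computed normally
      elif mode == "kv":
        k = moe(x)
        q = x @ wq     # query projection, computed normally
      else:
        raise ValueError("Invalid expert computation mode: {}".format(mode))
      return q * key_dim ** -0.5, k

    # Smoke test: experts compute q; k comes from a dense projection.
    d_model, length = 16, 4
    x = np.random.randn(length, d_model)
    moe = lambda t: t @ np.random.randn(d_model, d_model)
    q, k = merge_qkv("q", x, moe, wkv=np.random.randn(d_model, d_model))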