This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 7c09bb9

Author: Mesh TensorFlow Team
Size the second d_model_split dimension from the output shape instead of the input shape. This lets the reshape work for layers whose input size differs from their output size.
PiperOrigin-RevId: 391048566
1 parent 3922a8f commit 7c09bb9

1 file changed: +2 -0 lines changed

mesh_tensorflow/transformer/moe.py

Lines changed: 2 additions & 0 deletions
@@ -526,6 +526,8 @@ def _compute_output(hidden, layer_name):
     # Extra reshape reduces communication cost for model-parallel versions.
     # For model-parallel versions, this reshape causes an mtf.slice and for non-
     # model-parallel versions, this has no effect.
+    d_model_split_dim = mtf.Dimension(
+        "d_model_split", expert_output.shape[-1].size)
     expert_output = mtf.reshape(
         expert_output,
         mtf.Shape([
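
The motivation for the change can be illustrated with a small, library-free sketch. It assumes only that a reshape must preserve the total number of elements. The dimension names and sizes below are hypothetical and simplified; they are not the actual shapes used in moe.py, and check_reshape is an illustrative helper, not a Mesh TensorFlow function.

```python
from math import prod


def check_reshape(old_shape, new_shape):
    """Return True when both shapes hold the same number of elements."""
    return prod(old_shape) == prod(new_shape)


# Hypothetical sizes: a layer whose output width differs from its input width.
experts, capacity = 4, 8
d_model_in, d_model_out = 512, 1024

# Stand-in for expert_output: its last dimension has the output size.
expert_output_shape = (experts, capacity, d_model_out)

# Sizing the split dimension from the input width gives a mismatched target.
target_from_input = (experts, capacity, d_model_in)

# Sizing it from the tensor's own last dimension, as the commit does with
# expert_output.shape[-1].size, always gives a compatible target.
target_from_output = (experts, capacity, expert_output_shape[-1])

ok_before = check_reshape(expert_output_shape, target_from_input)
ok_after = check_reshape(expert_output_shape, target_from_output)
print(ok_before, ok_after)
```

When the input and output widths are equal, both targets are identical, which is consistent with the commit only affecting layers where the two sizes differ.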
