9 changes: 5 additions & 4 deletions paddle/phi/infermeta/multiary.cc
@@ -6128,10 +6128,11 @@ void MoePermuteInferMeta(const MetaTensor& X,
2,
common::errors::InvalidArgument("Input X's dims should be 2, but got %u.",
X.dims().size()));
PADDLE_ENFORCE_EQ(
X.dtype() == phi::DataType::BFLOAT16,
true,
common::errors::InvalidArgument("Input X's dtype should be BFLOAT16"));
PADDLE_ENFORCE_EQ(X.dtype() == phi::DataType::BFLOAT16 ||
X.dtype() == phi::DataType::FLOAT8_E4M3FN,
true,
common::errors::InvalidArgument(
"Input X's dtype should be BFLOAT16 or FLOAT8_E4M3FN"));
PADDLE_ENFORCE_EQ(expert_routemap_topk.dtype() == phi::DataType::INT32,
true,
common::errors::InvalidArgument(
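The widened check above means InferMeta now admits either dtype. A quick hedged illustration of what the new whitelist accepts (assuming a GPU build with fp8 support; the error text is the one added in this hunk):

import paddle

x = paddle.randn([4, 256])
x_bf16 = x.astype("bfloat16")      # passes the check (unchanged behavior)
x_fp8 = x.astype("float8_e4m3fn")  # newly accepted by this change
# Passing a float16 tensor to moe_permute would still fail at InferMeta time:
# InvalidArgument: Input X's dtype should be BFLOAT16 or FLOAT8_E4M3FN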
9 changes: 6 additions & 3 deletions paddle/phi/kernels/gpu/moe_permute_kernel.cu
@@ -348,6 +348,9 @@ void MoePermuteKernel(const Context &dev_ctx,
#undef MAX_NUM_EXPERTS
} // namespace phi

PD_REGISTER_KERNEL(
moe_permute, GPU, ALL_LAYOUT, phi::MoePermuteKernel, phi::dtype::bfloat16) {
}
PD_REGISTER_KERNEL(moe_permute,
GPU,
ALL_LAYOUT,
phi::MoePermuteKernel,
phi::dtype::float8_e4m3fn,
phi::dtype::bfloat16) {}
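With the kernel now registered for float8_e4m3fn as well, callers must supply the per-128-column scale tensor documented in the Python API below. A hedged sketch of one way such a scale could be produced (quantize_fp8_blockwise is a hypothetical helper; the per-block absmax scheme and the 448.0 E4M3 maximum are assumptions, not the kernel's confirmed quantization recipe):

import math
import paddle

def quantize_fp8_blockwise(x, block=128):
    # Hypothetical helper: per-128-column block quantization producing the
    # float32 scale tensor of shape [seq_len, ceil(token_len / block)] that
    # moe_permute expects for fp8 inputs.
    seq_len, token_len = x.shape
    n_blocks = math.ceil(token_len / block)
    x_f32 = x.astype("float32")
    out = paddle.empty_like(x_f32)
    scale = paddle.empty([seq_len, n_blocks], dtype="float32")
    for b in range(n_blocks):
        lo, hi = b * block, min((b + 1) * block, token_len)
        amax = x_f32[:, lo:hi].abs().max(axis=-1, keepdim=True)
        s = (amax / 448.0).clip(min=1e-12)  # 448.0: E4M3 representable max
        scale[:, b] = s.squeeze(-1)
        out[:, lo:hi] = x_f32[:, lo:hi] / s
    return out.astype("float8_e4m3fn"), scale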
99 changes: 78 additions & 21 deletions python/paddle/nn/functional/moe_permute.py
@@ -34,33 +34,90 @@ def moe_permute(
name: str | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
r"""
Permute tokens for Mixture of Experts (MoE) computation.
Permute tokens for Mixture of Experts (MoE) computation in distributed training scenarios.

Note:
This function reorganizes input tokens based on expert assignments to prepare for expert computation.
It handles both bfloat16 and float8_e4m3fn data types with proper scaling for float8 inputs.

1. This function is typically used in tandem with moe_unpermute to provide complete MoE functionality.
2. For float8 inputs, proper scaling must be provided via the scale parameter.
3. The padding_alignment parameter affects memory efficiency but not correctness.
4. Every output token has an exact match among the original input tokens.
5. This permute function overcomes the aadiff issue and is deterministic.

Args:
hidden_states (Tensor): Input tensor storing tokens in row-major layout.
Shape: [seq_len, token_len], dtype: bfloat16 or float8_e4m3fn.
scale (Tensor|None): Input tensor required when hidden_states is fp8 type.
Shape: [seq_len, (token_len + 127) // 128], dtype: float32.
expert_routemap_topk (Tensor): Tensor recording which expert each token is dispatched to.
Shape: [seq_len, topk], dtype: int32, value range: [-1, num_experts).
expert_prob_topk (Tensor): Tensor storing expert probabilities.
Shape: [seq_len, topk], dtype: float32.
num_experts (int): Number of experts.
tokens_per_expert (list[int]): List indicating how many tokens each expert receives.
padding_alignment (int): Alignment requirement for expert buffers (must be multiple of this value).
name (str|None, optional): Name for the operation. Defaults to None.
hidden_states (Tensor): The input tensor containing tokens to be permuted, stored in row-major layout.
Supported data types: bfloat16 or float8_e4m3fn.
Shape: [sequence_length, token_dimension]
scale (Tensor|None): Scaling factors required when hidden_states is of float8 type.
For float8 inputs, this tensor provides the scaling factors for dequantization.
Shape: [sequence_length, ceil(token_dimension / 128)]
Data type: float32
expert_routemap_topk (Tensor): Tensor indicating expert assignments for each token (top-k experts).
Each value represents the expert index the token is assigned to (-1 indicates not assigned).
Shape: [sequence_length, top_k_experts]
Data type: int32
Value range: [-1, num_experts)
expert_prob_topk (Tensor): Tensor containing routing probabilities for top-k experts.
Shape: [sequence_length, top_k_experts]
Data type: float32
num_experts (int): Total number of experts in the MoE layer; must be between 1 and 64.
tokens_per_expert (list[int]): List where each element indicates the number of tokens
assigned to the corresponding expert.
padding_alignment (int): Token alignment requirement for expert buffers (in bytes).
Must be a power of 2; typical values are 16, 32, or 64 for optimal memory access.
name (str|None, optional): Name prefix for the operation (optional).
Default: None

Returns:
tuple[Tensor, Tensor, Tensor, Tensor]:
- hidden_states_unzipped: Permuted and broadcasted tensor.
Shape: [seqlen_broadcasted, token_len], dtype same as input.
- zipped_expertwise_rowmap: Mapping tensor for unpermute operation.
Shape: [seqlen, num_experts], dtype: int32.
- token_prob_unzipped: Flattened expert probabilities aligned with hidden_states_unzipped.
Shape: [seqlen_broadcasted, 1], dtype: float32.
- scale_unzipped: Scaled tensor (only valid when hidden_states is fp8).
Shape: [seqlen_broadcasted, (token_len + 127) // 128], dtype: float32.
- hidden_states_unzipped (Tensor): The permuted and broadcasted input tensor.
Shape: [total_tokens_after_broadcast, token_dimension]
Data type: same as input hidden_states
- zipped_expertwise_rowmap (Tensor): Mapping tensor used to restore original order (unpermute).
Shape: [sequence_length, num_experts]
Data type: int32
- token_prob_unzipped (Tensor): Flattened expert probabilities aligned with permuted tokens.
Shape: [total_tokens_after_broadcast, 1]
Data type: float32
- scale_unzipped (Tensor): Broadcasted scale tensor (only valid for float8 inputs).
Shape: [total_tokens_after_broadcast, ceil(token_dimension / 128)]
Data type: float32

Examples:
.. code-block:: python

>>> # doctest: +REQUIRES(env:GPU)
>>> # doctest: +SKIP('This is only supported on CUDA 12.0+')
>>> import paddle
>>> import numpy as np
>>> import paddle.nn.functional as F
>>> hidden_states = paddle.randn([3, 128], dtype='bfloat16')
>>> expert_routemap_topk = paddle.to_tensor([[-1, 0, -1, -1, 2, -1, -1, -1],
... [1, -1, -1, -1, -1, -1, -1, -1],
... [-1, -1, -1, -1, -1, -1, 1, -1]],
... dtype='int32')
>>> expert_prob_topk = paddle.to_tensor([[0.0, 0.6, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0],
... [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]],
... dtype='float32')
>>> num_experts = 3
>>> tokens_per_expert = [1, 2, 1]
>>> padding_alignment = 2
>>> hidden_states_unzipped, zipped_expertwise_rowmap, token_prob_unzipped, scale_unzipped = F.moe_permute(
... hidden_states,
... None,
... expert_routemap_topk,
... expert_prob_topk,
... num_experts,
... tokens_per_expert,
... padding_alignment,
... )
>>> # weighted by probs.
>>> hidden_states_unzipped = (hidden_states_unzipped.astype("float32") * token_prob_unzipped.astype("float32").unsqueeze(-1)).astype("bfloat16")
>>> zipped_tokens, zipped_probs = F.moe_unpermute(hidden_states_unzipped, zipped_expertwise_rowmap, expert_routemap_topk, token_prob_unzipped, 3, 3)
>>> np.testing.assert_allclose(zipped_tokens.numpy(), hidden_states.numpy(), rtol=1e-05, atol=1e-06)
"""
if in_dynamic_or_pir_mode():
(
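The bfloat16 doctest above carries over to fp8 with one extra step: quantize first, then pass the scale. A hedged sketch reusing the doctest's variables and the hypothetical quantize_fp8_blockwise helper sketched earlier in this review:

import paddle.nn.functional as F

# Assumes hidden_states, expert_routemap_topk, expert_prob_topk, num_experts,
# tokens_per_expert and padding_alignment from the docstring example above.
x_fp8, scale = quantize_fp8_blockwise(hidden_states)
out_fp8, rowmap, probs_unzipped, scale_unzipped = F.moe_permute(
    x_fp8,
    scale,
    expert_routemap_topk,
    expert_prob_topk,
    num_experts,
    tokens_per_expert,
    padding_alignment,
)
# scale_unzipped is broadcast row-for-row with out_fp8, so block b of row i
# dequantizes as out_fp8[i, 128*b:128*(b+1)].astype("float32") * scale_unzipped[i, b]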
74 changes: 51 additions & 23 deletions python/paddle/nn/functional/moe_unpermute.py
@@ -34,33 +34,61 @@ def moe_unpermute(
name: str | None = None,
) -> tuple[Tensor, Tensor]:
r"""
Unpermute tokens for Mixture of Experts (MoE) computation, restoring tokens to their original order after expert processing.

Args:
hidden_states (Tensor): Input tensor storing tokens in row-major layout.
Shape: [seq_len, token_len], dtype: bfloat16 or float8_e4m3fn.
scale (Tensor|None): Input tensor required when hidden_states is fp8 type.
Shape: [seq_len, (token_len + 127) // 128], dtype: float32.
expert_routemap_topk (Tensor): Tensor recording which expert each token is dispatched to.
Shape: [seq_len, topk], dtype: int32, value range: [-1, num_experts).
expert_prob_topk (Tensor): Tensor storing expert probabilities.
Shape: [seq_len, topk], dtype: float32.
num_experts (int): Number of experts.
tokens_per_expert (list[int]): List indicating how many tokens each expert receives.
padding_alignment (int): Alignment requirement for expert buffers (must be multiple of this value).
name (str|None, optional): Name for the operation. Defaults to None.
hidden_states_unzipped (Tensor): The input Tensor containing broadcasted and permuted hidden states.
Shape: (seqlen_broadcasted, token_len). Dtype: bfloat16.
zipped_expertwise_rowmap (Tensor): The input Tensor recording the mapping relationship for unpermute operation.
Shape: (seqlen, num_experts). Dtype: int32.
expert_routemap_topk (Tensor): The input Tensor indicating which expert each token is assigned to.
Shape: (seqlen, topk). Value range: [-1, num_experts). Dtype: int32.
token_prob_unzipped (Tensor): The input Tensor containing flattened expert probabilities corresponding to hidden_states_unzipped.
Shape: (seqlen_broadcasted, 1). Dtype: float32.
total_zipped_tokens_num (int): The total number of tokens before permutation for output buffer allocation. Dtype: int32.
num_experts (int): The number of experts. Dtype: int32.
use_mix_precision (bool, optional): Whether to use mixed precision during accumulation.
This option significantly improves precision when the number of experts is greater than 4. Default: True.
name (str|None, optional): Name for the operation. Default: None.

Returns:
tuple[Tensor, Tensor, Tensor, Tensor]:
- hidden_states_unzipped: Permuted and broadcasted tensor.
Shape: [seqlen_broadcasted, token_len], dtype same as input.
- zipped_expertwise_rowmap: Mapping tensor for unpermute operation.
Shape: [seqlen, num_experts], dtype: int32.
- token_prob_unzipped: Flattened expert probabilities aligned with hidden_states_unzipped.
Shape: [seqlen_broadcasted, 1], dtype: float32.
- scale_unzipped: Scaled tensor (only valid when hidden_states is fp8).
Shape: [seqlen_broadcasted, (token_len + 127) // 128], dtype: float32.
tuple[Tensor, Tensor]: A tuple containing:
- hidden_states (Tensor): The output Tensor with unpermuted tokens.
Shape: (seqlen, token_len). Dtype: bfloat16.
- expert_prob_topk (Tensor): The output Tensor with unpermuted probabilities.
Shape: (seqlen, topk). Dtype: float32.

Examples:
.. code-block:: python

>>> # doctest: +REQUIRES(env:GPU)
>>> # doctest: +SKIP('This is only supported on CUDA 12.0+')
>>> import paddle
>>> import numpy as np
>>> import paddle.nn.functional as F
>>> hidden_states = paddle.randn([3, 128], dtype='bfloat16')
>>> expert_routemap_topk = paddle.to_tensor([[-1, 0, -1, -1, 2, -1, -1, -1],
... [1, -1, -1, -1, -1, -1, -1, -1],
... [-1, -1, -1, -1, -1, -1, 1, -1]],
... dtype='int32')
>>> expert_prob_topk = paddle.to_tensor([[0.0, 0.6, 0.0, 0.0, 0.4, 0.0, 0.0, 0.0],
... [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
... [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]],
... dtype='float32')
>>> num_experts = 3
>>> tokens_per_expert = [1, 2, 1]
>>> padding_alignment = 2
>>> hidden_states_unzipped, zipped_expertwise_rowmap, token_prob_unzipped, scale_unzipped = F.moe_permute(
... hidden_states,
... None,
... expert_routemap_topk,
... expert_prob_topk,
... num_experts,
... tokens_per_expert,
... padding_alignment,
... )
>>> # weighted by probs.
>>> hidden_states_unzipped = (hidden_states_unzipped.astype("float32") * token_prob_unzipped.astype("float32").unsqueeze(-1)).astype("bfloat16")
>>> zipped_tokens, zipped_probs = F.moe_unpermute(hidden_states_unzipped, zipped_expertwise_rowmap, expert_routemap_topk, token_prob_unzipped, 3, 3)
>>> np.testing.assert_allclose(zipped_tokens.numpy(), hidden_states.numpy(), rtol=1e-05, atol=1e-06)
"""
if in_dynamic_or_pir_mode():
zipped_tokens, zipped_probs_topk = _C_ops.moe_unpermute(
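The use_mix_precision flag documented above accumulates the weighted expert outputs in float32. A small hedged illustration (synthetic data, not the kernel's actual code path) of why this matters once several experts contribute to a token:

import paddle

token_len, k = 7168, 8  # k expert contributions per token
parts = [paddle.randn([token_len]).astype("bfloat16") for _ in range(k)]

# Naive accumulation: every partial sum is rounded back to bfloat16
# (~8 mantissa bits), so error grows with each of the k additions.
acc_bf16 = paddle.zeros([token_len], dtype="bfloat16")
for p in parts:
    acc_bf16 = acc_bf16 + p / k

# Mixed-precision accumulation: sum in float32, round once at the end.
acc_mixed = sum(p.astype("float32") / k for p in parts).astype("bfloat16")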
12 changes: 7 additions & 5 deletions test/legacy_test/test_moe_permute_unpermute.py
@@ -30,7 +30,7 @@ def fabricate_dispatch_result(
broadcast_ratio=0.5,
):
"""Helper function to generate test data."""
hidden_states = paddle.randn([seqlen, token_length], dtype=data_type)
hidden_states = paddle.randn([seqlen, token_length]).astype(data_type)

scale = paddle.empty([0])
if data_type == "float8_e4m3fn":
@@ -93,7 +93,7 @@ class TestFusedMoePermuteUnpermute(unittest.TestCase):

SEQLEN = 16384
TOKEN_LEN = 7168
DTYPES = ["bfloat16"]
DTYPES = ["float8_e4m3fn", "bfloat16"]
EXPERT_NUMS = [4, 8, 16, 32, 64]
TOPKS = [4, 8, 16]

@@ -141,7 +141,8 @@ def test_permute_unpermute_consistency(self):
)

unpermute_input = (
unzipped_tokens * unzipped_probs.unsqueeze(-1)
unzipped_tokens.astype("float32")
* unzipped_probs.unsqueeze(-1)
).astype("bfloat16")

unzipped_tokens_recovered, expert_prob_topk_recovered = (
@@ -157,12 +157,13 @@ def test_permute_unpermute_consistency(self):

# Check tensor recovery
max_abs_err, max_rel_err = tensor_max_abs_rel_err(
hidden_states, unzipped_tokens_recovered
hidden_states.astype("float32"),
unzipped_tokens_recovered.astype("float32"),
)

self.assertLess(
max_rel_err,
1e-2,
1e-1 if dt == "float8_e4m3fn" else 1e-2,
f"Tokens relative error too large, permute-unpermute tokens max relative error: {max_rel_err}",
)

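tensor_max_abs_rel_err is not part of this diff; for readers of the tolerance change above, a hedged sketch of what such a helper typically computes (the test's actual definition may differ, for instance in its epsilon):

import paddle

def tensor_max_abs_rel_err(ref, out, eps=1e-6):
    # Hypothetical reimplementation: worst-case absolute and relative
    # error of `out` against the reference tensor `ref`.
    abs_err = (ref - out).abs()
    rel_err = abs_err / (ref.abs() + eps)
    return float(abs_err.max()), float(rel_err.max())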