
Commit 8837332

add file
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 448628f commit 8837332

1 file changed: +353 −0 lines

@@ -0,0 +1,353 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import rose
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import (
    _validate_scale_shape,
    moe_kernel_quantize_input,
)
from vllm.utils import cdiv, round_up

logger = init_logger(__name__)


def rose_hidden_dim_scale(
    hidden_dim: int,
    quant_dtype: torch.dtype | str | None,
    per_act_token_quant: bool,
    block_shape: list[int] | None,
) -> int:
    # For blocked per token: set to
    # ceil_div(hidden_dim, block_size) * sizeof(float32)
    # For per-token: set to 4 * sizeof(float32) (x4 for alignment)
    if quant_dtype is not None:
        assert isinstance(quant_dtype, torch.dtype)
        assert quant_dtype.itemsize == 1
        hidden_dim = hidden_dim

        if per_act_token_quant:
            # per-token (M x 1)
            assert block_shape is None
            hidden_dim_scale = 1
        elif block_shape is not None:
            # per-group (M x K_tiles)
            block_size = block_shape[1]
            hidden_dim_scale = cdiv(hidden_dim, block_size)
        else:
            # per-tensor (1 x 1)
            hidden_dim_scale = 1
    else:
        hidden_dim_scale = 0

    return hidden_dim_scale


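# For illustration only (hypothetical values; assumes an fp8 dtype with
# itemsize 1, which is what the assert above requires):
#   rose_hidden_dim_scale(4096, torch.float8_e4m3fn, False, [128, 128]) -> 32
#   rose_hidden_dim_scale(4096, torch.float8_e4m3fn, True, None)        -> 1
#   rose_hidden_dim_scale(4096, None, False, None)                      -> 0

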
class RosePrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
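    """Batched-experts prepare/finalize built on a rose all-to-all kernel.

    prepare_async() quantizes the activations and dispatches them to the
    local experts in two phases (send now, receive via the returned hook);
    finalize_async() combines the expert outputs the same way.  The
    synchronous prepare()/finalize() wrappers just run the hook and the
    receiver back-to-back.
    """
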
    def __init__(
        self,
        a2a: rose.AllToAllKernel,
        max_num_tokens: int,
        num_local_experts: int,
        num_dispatchers: int,
    ):
        super().__init__()
        assert max_num_tokens > 0
        assert num_local_experts > 0
        self.a2a = a2a
        self.max_num_tokens = max_num_tokens
        self.num_local_experts = num_local_experts
        self.num_dispatchers_ = num_dispatchers

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.BatchedExperts

    def max_num_tokens_per_rank(self) -> int | None:
        return self.max_num_tokens

    def topk_indices_dtype(self) -> torch.dtype | None:
        return torch.uint32

    def num_dispatchers(self) -> int:
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        return True

    def supports_async(self) -> bool:
        return True

    def prepare_async(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> tuple[Callable, mk.ReceiverType]:
        num_tokens = a1.size(0)  # M
        hidden_dim = a1.size(-1)  # K

        assert topk_ids.size(0) == num_tokens
        # expert_map should be None because with expert map, -1 id is used for
        # non-local token; this causes error when casting ids to the
        # topk_indices_dtype() int32
        #
        if expert_map is not None:
            logger.warning_once(
                "The PPLX Rose backend does not support expert mapping. "
                "The provided `expert_map` will be ignored."
            )
            expert_map = None  # noqa: F841

        # Is this always going to be a1.device?
        device = a1.device

        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1"
            )
            a1 = a1 * topk_weights.to(a1.dtype)

        repeat_cols = 4
        repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
        # TODO(bnell): always pass quant_config.a1_scale?
        a1q, a1q_scale = moe_kernel_quantize_input(
            a1,
            (None if quant_config.per_act_token_quant else quant_config.a1_scale),
            quant_dtype=quant_config.quant_dtype,
            per_act_token_quant=quant_config.per_act_token_quant,
            block_shape=quant_config.block_shape,
        )

        _validate_scale_shape(
            a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
        )

        orig_a_scale_block_shape: int | None = None

        if a1q_scale is not None:
            scalar_scales = a1q_scale.numel() == 1

            # Rose requires 2-d scales even for scalar scales
            if a1q_scale.dim() <= 1:
                assert scalar_scales
                a1q_scale = a1q_scale.view(1, 1)

            orig_a_scale_block_shape = a1q_scale.shape[-1]

            if not quant_config.is_block_quantized:
                # TODO (bnell): use group_broadcast instead?
                a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)

        assert a1q_scale is None or a1q_scale.ndim == 2, (
            f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
        )

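        # Allocate the receive-side buffers that the dispatch all-to-all will
        # fill: per-expert token counts plus the batched expert activations
        # (and, when quantized, their scales).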
        expert_num_tokens = torch.empty(
            self.num_local_experts,
            dtype=torch.int32,
            device=device,
        )

        expert_x = torch.empty(
            (
                self.num_local_experts,
                self.max_num_tokens * self.num_dispatchers(),
                hidden_dim,
            ),
            dtype=a1q.dtype,
            device=device,
        )

        expert_x_scale: torch.Tensor | None = None
        if a1q.dtype.itemsize == 1:
            if quant_config.is_per_act_token:
                # (M x 1) -> (E x M x K)
                final_dim = expert_x.size(2)
            elif quant_config.is_per_tensor:
                # (1 x 1) -> (E x 1 x 1)
                final_dim = 1
            else:
                # (M x K_tiles) -> (E x M x K_tiles)
                assert quant_config.block_shape is not None
                num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1])
                final_dim = num_blocks

            expert_x_scale_shape = (
                self.num_local_experts,
                expert_x.size(1),
                round_up(final_dim, 4),  # round up for alignment
            )

            expert_x_scale = torch.empty(
                expert_x_scale_shape,
                dtype=torch.float32,
                device=expert_x.device,
            )

        # This argument is optional, defaults to indices.size(0)
        # There's not much point setting this unless it is != indices.size(0)
        bound_m: torch.Tensor | None = None

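        # Two-phase dispatch: issue the send half now and return a hook that
        # performs the matching receive, so the caller can overlap other work
        # with the all-to-all.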
        self.a2a.dispatch(
            out_expert_num_tokens=expert_num_tokens,
            out_expert_x=expert_x,
            out_expert_x_scale=expert_x_scale,
            dp_x=a1q,
            dp_x_scale=a1q_scale,
            indices=topk_ids,
            bound_m=bound_m,
            do_send=True,
            do_recv=False,
        )

        hook = lambda: self.a2a.dispatch(
            out_expert_num_tokens=expert_num_tokens,
            out_expert_x=expert_x,
            out_expert_x_scale=expert_x_scale,
            dp_x=a1q,
            dp_x_scale=a1q_scale,
            indices=topk_ids,
            bound_m=bound_m,
            do_send=False,
            do_recv=True,
        )

        return (
            hook,
            lambda: self._receiver(
                expert_num_tokens,
                expert_x,
                expert_x_scale,
                orig_a_scale_block_shape,
            ),
        )

    def _receiver(
        self,
        expert_num_tokens: torch.Tensor,
        expert_x: torch.Tensor,
        expert_x_scale: torch.Tensor | None,
        orig_a_scale_block_shape: int | None,
    ) -> mk.PrepareResultType:
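        # Drop the padding that was added to the scale's last dimension for
        # alignment, restoring its original width.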
        if expert_x_scale is not None:
            expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
            assert expert_x_scale.ndim == 3

        expert_tokens_meta = mk.ExpertTokensMetadata(
            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
        )

        return expert_x, expert_x_scale, expert_tokens_meta, None, None

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> mk.PrepareResultType:
        hook, receiver = self.prepare_async(
            a1,
            topk_weights,
            topk_ids,
            num_experts,
            expert_map,
            apply_router_weight_on_input,
            quant_config,
        )
        hook()
        return receiver()

    def finalize_async(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> Callable:
        assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
            "Weight application and reduction happens in the combine kernel."
        )

        # This argument is optional
        # There's not much point setting this unless it is != topk_ids.size(0)
        bound_m: torch.Tensor | None = None

        # TODO (bnell): fails in test_rose_moe.py, figure out what's going on
        # num_tokens = output.size(0)  # M
        # assert topk_ids.size(0) == num_tokens, (
        #     f"{topk_ids.size(0)} == {num_tokens}")
        assert topk_ids.size() == topk_weights.size(), (
            f"{topk_ids.size()} == {topk_weights.size()}"
        )
        assert output.size(0) <= self.max_num_tokens, (
            f"{output.size(0)} <= {self.max_num_tokens}"
        )
        assert output.size(1) == fused_expert_output.size(-1)

        # Set weights to 1 if we did them in dispatch. This is hacky.
        if apply_router_weight_on_input:
            topk_weights = torch.ones_like(topk_weights)

        topk_ids_u32 = topk_ids.view(dtype=torch.uint32)

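        # Two-phase combine, mirroring prepare_async(): send now and return a
        # callable that completes the receive into `output`.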
        self.a2a.combine(
            out_tokens=output,
            indices=topk_ids_u32,
            weights=topk_weights,
            expert_y=fused_expert_output,
            bound_m=bound_m,
            do_send=True,
            do_recv=False,
            # Note: new kernels allow accumulate.
        )

        return lambda: self.a2a.combine(
            out_tokens=output,
            indices=topk_ids_u32,
            weights=topk_weights,
            expert_y=fused_expert_output,
            bound_m=bound_m,
            do_send=False,
            do_recv=True,
            # Note: new kernels allow accumulate.
        )

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        receiver = self.finalize_async(
            output,
            fused_expert_output,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            weight_and_reduce_impl,
        )
        receiver()
