45 changes: 28 additions & 17 deletions python/mlc_llm/model/phi3/phi3_model.py
@@ -32,7 +32,10 @@ class Phi3Config(ConfigBase): # pylint: disable=too-many-instance-attributes
intermediate_size: int
rms_norm_eps: float
num_key_value_heads: int
max_position_embeddings: int
position_embedding_base: int = 0
rope_scaling: Optional[Dict[str, Any]] = None
original_max_position_embeddings: int = 0
context_window_size: int = 0
prefill_chunk_size: int = 0
head_dim: int = 0
@@ -46,23 +49,21 @@ def __post_init__(self):
self.position_embedding_base = self.kwargs.pop("rope_theta")
else:
self.position_embedding_base = 10000
-        if self.context_window_size == 0:
-            for name in ["max_position_embeddings", "max_sequence_length"]:
-                if name in self.kwargs:
-                    self.context_window_size = self.kwargs.pop(name)
-                    logger.info(
-                        "%s not found in config.json. Falling back to %s (%d)",
-                        bold("context_window_size"),
-                        bold(name),
-                        self.context_window_size,
-                    )
-                    break
+        if self.rope_scaling is not None:
+            if "type" not in self.rope_scaling:
+                self.rope_scaling = None
else:
-                raise ValueError(
-                    "Unable to determine the maximum sequence length, because none of "
-                    "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is "
-                    "provided in `config.json`."
-                )
+                assert (
+                    self.rope_scaling["type"] == "longrope"
+                ), f'Unsupported RoPE scaling type {self.rope_scaling["type"]} for Phi3'
+                self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+                (
+                    self.rope_scaling["max_position_embeddings"],
+                    self.rope_scaling["original_max_position_embeddings"],
+                ) = (self.max_position_embeddings, self.original_max_position_embeddings)

if self.context_window_size == 0:
self.context_window_size = self.max_position_embeddings

if self.prefill_chunk_size == 0:
logger.info(
@@ -123,6 +124,9 @@ def __init__(self, config: Phi3Config):
"must be divisible by tensor_parallel_shards"
)
self.head_dim = config.head_dim
self.rope_ext_factors = (
config.rope_scaling["long_factor"] if config.rope_scaling is not None else None
)

self.qkv_proj = nn.Linear(
in_features=config.hidden_size,
@@ -139,7 +143,12 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))
# Attention
output = op.reshape(
-            paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads),
+            paged_kv_cache.attention_with_fused_qkv(
+                layer_id,
+                qkv,
+                self.num_q_heads,
+                rope_ext_factors=self.rope_ext_factors,
+            ),
(b, s, h_q * d),
)
return self.out_proj(output)
@@ -215,6 +224,7 @@ def __init__(self, config: Phi3Config) -> None:
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.vocab_size = config.vocab_size
self.rope_scaling = config.rope_scaling
self.rope_theta = config.position_embedding_base
self.tensor_parallel_shards = config.tensor_parallel_shards
self.dtype = "float32"
@@ -306,6 +316,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,
head_dim=self.head_dim,
rope_mode=RopeMode.NORMAL,
rope_scaling=self.rope_scaling,
rope_scale=1,
rope_theta=self.rope_theta,
dtype=self.dtype,
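A note on the config handling above: after Phi3Config.__post_init__ runs, the downstream helpers (switch_rope_freq_func, create_paged_kv_cache, and the rope_ext_factors lookup in the attention block) expect the normalized rope_scaling dict to carry a few specific keys. A rough sketch of its shape, with placeholder values rather than numbers from any real config.json:

# Hypothetical shape of config.rope_scaling after __post_init__ (sketch only).
rope_scaling = {
    "type": "longrope",                        # as read from config.json
    "rope_type": "longrope",                   # mirrored from "type" for switch_rope_freq_func
    "long_factor": [...],                      # per-dimension factors, head_dim // 2 entries
    "short_factor": [...],
    "max_position_embeddings": 131072,         # placeholder
    "original_max_position_embeddings": 4096,  # placeholder
}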
17 changes: 11 additions & 6 deletions python/mlc_llm/nn/kv_cache.py
@@ -6,6 +6,7 @@
import math
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from tvm import relax as rx
from tvm import tir
from tvm.relax.frontend.nn import Object, Tensor
@@ -103,6 +104,7 @@ def attention_with_fused_qkv( # pylint: disable=invalid-name
qkv: Tensor,
num_qo_heads: int,
attn_score_scaling_factor: float = 1.0,
rope_ext_factors: Optional[List] = None,
) -> Tensor:
"""Compute attention with the given fused q/k/v data and in-cache k/v data
on the specified layer. Rotary position embeddings are applied to k/v
@@ -119,16 +121,19 @@ def attention_with_fused_qkv( # pylint: disable=invalid-name
# pylint: disable=protected-access
b, s, _, d = qkv._expr.struct_info.shape
qkv = qkv.reshape(b * s, qkv.shape[2], d)
+        args = [
+            self._expr,
+            rx.PrimValue(layer_id),  # type: ignore[arg-type]
+            rx.PrimValue(attn_score_scaling_factor),
+            qkv._expr,
+        ]
+        if rope_ext_factors is not None:
+            args.append(rx.const(np.array(rope_ext_factors, "float32")))
return Tensor(
_expr=rx.BlockBuilder.current().emit(
rx.call_dps_packed(
"vm.builtin.attention_kv_cache_attention_with_fused_qkv",
-                    [
-                        self._expr,
-                        rx.PrimValue(layer_id),  # type: ignore[arg-type]
-                        rx.PrimValue(attn_score_scaling_factor),
-                        qkv._expr,
-                    ],
+                    args,
out_sinfo=rx.TensorStructInfo((b * s, num_qo_heads, d), qkv.dtype),
)
)
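A note on the optional argument above: rope_ext_factors is a fixed per-model list (for Phi-3 it is rope_scaling["long_factor"]), so it can be baked into the IR as a constant tensor instead of being threaded through as a runtime argument. A tiny standalone sketch of that conversion, using placeholder factor values:

import numpy as np
from tvm import relax as rx

# Placeholder factors; a real model provides head_dim // 2 of them from config.json.
rope_ext_factors = [1.0, 1.25, 1.5, 2.0]
const_expr = rx.const(np.array(rope_ext_factors, "float32"))
# const_expr is a relax.Constant of shape (4,) and dtype float32; it is appended as the
# trailing argument of the packed attention call only when longrope scaling is in use.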
105 changes: 101 additions & 4 deletions python/mlc_llm/op/position_embedding.py
@@ -41,9 +41,10 @@ def rope_freq_default(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype:
The common expression map.
"""
freq = s / tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32"))
-    cos_freq = tir.cos(freq).astype(dtype)
-    sin_freq = tir.sin(freq).astype(dtype)
-    return cos_freq, sin_freq, {}
+    freq_var = tir.Var("freq", "float32")
+    cos_freq = tir.cos(freq_var).astype(dtype)
+    sin_freq = tir.sin(freq_var).astype(dtype)
+    return cos_freq, sin_freq, {freq_var: freq}


def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
@@ -76,6 +77,33 @@ def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals
return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}


def rope_freq_longrope( # pylint: disable=too-many-arguments
s: tir.Var,
d: tir.Var,
d_range: int,
theta: float,
dtype: str,
max_position_embeddings: int,
original_max_position_embeddings: int,
ext_factors: Optional[T.Buffer] = None,
):
"""Compute the inverse frequency of RoPE for longrope scaling."""
scale = max_position_embeddings / original_max_position_embeddings
scaling_factor = (
math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
if scale > 1.0
else 1.0
)
divisor = tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32"))
if ext_factors is not None:
divisor = ext_factors[d % (d_range // 2)] * divisor
freq = s / divisor
freq_var = tir.Var("freq", "float32")
cos_freq = (tir.cos(freq_var) * scaling_factor).astype(dtype)
sin_freq = (tir.sin(freq_var) * scaling_factor).astype(dtype)
return cos_freq, sin_freq, {freq_var: freq}
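
As a quick sanity check on the scaling_factor formula above, here is a standalone computation with assumed example context lengths (131072 extended vs. 4096 original, in the longrope style; not taken from any particular config):

import math

max_position_embeddings = 131072          # assumed example value
original_max_position_embeddings = 4096   # assumed example value

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
scaling_factor = (
    math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
    if scale > 1.0
    else 1.0
)
# log(32) / log(4096) = 5/12, so scaling_factor = sqrt(17/12) ~= 1.19
print(scale, scaling_factor)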


def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
"""Return the RoPE inverse frequency computation function based
on the given RoPE scaling.
@@ -90,6 +118,12 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
high_freq_factor=rope_scaling["high_freq_factor"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
if rope_scaling["rope_type"] == "longrope":
return partial(
rope_freq_longrope,
max_position_embeddings=rope_scaling["max_position_embeddings"],
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
)
raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}')
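
For illustration, the dispatch above can be exercised on its own. The dict below is a hand-written stand-in for a normalized rope_scaling entry (placeholder values):

# Hypothetical normalized rope_scaling dict (placeholder values).
rope_scaling = {
    "rope_type": "longrope",
    "max_position_embeddings": 131072,
    "original_max_position_embeddings": 4096,
}

freq_func = switch_rope_freq_func(rope_scaling)
# freq_func is a functools.partial over rope_freq_longrope with both context lengths
# bound; callers still pass (s, d, d_range, theta, dtype) and, optionally, ext_factors,
# exactly as _rope does below via **kwargs.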


@@ -265,16 +299,21 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments
if rotary_dim is None:
rotary_dim = head_dim
scale = tir.const(scale, "float32")
is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"

def _rope( # pylint: disable=too-many-arguments
x: T.Buffer,
s: tir.Var,
h: tir.Var,
d: tir.Var,
pos: tir.Var,
ext_factors: Optional[T.Buffer] = None,
):
kwargs = {}
if ext_factors:
kwargs["ext_factors"] = ext_factors
cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)(
-            pos * scale, d, rotary_dim, theta, "float32"
+            pos * scale, d, rotary_dim, theta, "float32", **kwargs
)
cos = cos_freq * x[s, h, d].astype("float32")
sin = sin_freq * tir.if_then_else(
@@ -329,4 +368,62 @@ def fused_rope( # pylint: disable=too-many-locals
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

@T.prim_func
def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
var_qkv: T.handle,
var_position_map: T.handle,
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
ext_factors: T.Buffer((head_dim // 2,), "float32"), # type: ignore
):
T.func_attr(
{
"op_pattern": 8, # 2 means injective, 8 means opaque
"tir.noalias": T.bool(True),
}
)
seq_len = T.int64()
position_map_elem_offset = T.int64()
qkv = T.match_buffer(var_qkv, (seq_len, fused_heads, head_dim), dtype)
q = T.match_buffer(var_q, (seq_len, num_q_heads, head_dim), dtype)
k = T.match_buffer(var_k, (seq_len, num_kv_heads, head_dim), dtype)
v = T.match_buffer(var_v, (seq_len, num_kv_heads, head_dim), dtype)
position_map = T.match_buffer(
var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
)
for iters in T.grid(seq_len, fused_heads, head_dim):
with T.block("llama_fused_rope"):
s, h, d = T.axis.remap("SSS", iters)
if h < num_q_heads:
q[s, h, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
ext_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
elif h < num_q_heads + num_kv_heads:
k[s, h - num_q_heads, d] = T.if_then_else(
d < rotary_dim,
_rope(
qkv,
s,
h,
d,
position_map[s],
ext_factors if is_longrope_scaling else None,
),
qkv[s, h, d],
)
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

if is_longrope_scaling:
return fused_rope_longrope_scaling
return fused_rope