Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,15 +709,20 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
return res


def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
group_sizes: torch.Tensor) -> torch.Tensor:
def gmm(
lhs: torch.Tensor,
rhs: torch.Tensor,
group_sizes: torch.Tensor,
tiling: tuple[int, int, int] = (512, 512, 512)
) -> torch.Tensor:
"""Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.

Args:
lhs: A 2d, jnp.ndarray with shape [m, k].
rhs: A 3d, jnp.ndarray with shape [num_groups, k, n].
group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
preferred_element_type: jnp.dtype, the element type for the output matrix.
tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.

Returns:
A 2d, jnp.ndarray with shape [m, n].
Expand All @@ -727,17 +732,24 @@ def gmm(lhs: torch.Tensor, rhs: torch.Tensor,
jax_import_guard()
from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm

payload, _ = trace_pallas(gmm, lhs, rhs, group_sizes)
m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[2]
tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
payload, _ = trace_pallas(
gmm,
lhs,
rhs,
group_sizes,
static_argnames=["tiling"],
tiling=(tm, tk, tn))

m, n = lhs.shape[0], rhs.shape[2]
# Create the metadata we need for computation.
# TODO (alanwaketan): The following assuumes groups_sizes is a cpu tensor.
# That means we need to materialize this input in order to use this gmm
# kernel, and that will introduce graph breaks in the computation.
group_offsets, group_ids, m_tile_ids, num_tiles = _make_group_metadata(
group_sizes=group_sizes,
m=lhs.shape[0],
tm=128 # TODO (alanwaketan): Tune this later.
m=m,
tm=tm,
)
group_offset_torch = torch.tensor([0], dtype=torch.int32).to("xla")

Expand Down