Adding megablox gmm standalone #6940
Changes from all commits: 5abcb18 a9c8653 6250bef 816b717 2112d7e 6cd4bca 6a7e786 c5ba27b 90c850a 36b7ac1 e416985 7b50136 351097e b50a371 3cc89d3 b06f06e 3cbc14f da821f5
New test file:

@@ -0,0 +1,161 @@

```python
"""Grouped matrix multiplication kernels for TPU written in Pallas."""

import logging
import sys  # needed for sys.exit() at the bottom of this file
import unittest

from typing import Optional, Union, Callable

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.megablox as megablox
from torch_xla import runtime as xr
from torch_xla._internal import tpu

import numpy as np

if xr.device_type() == 'TPU':
  from torch_xla.experimental.custom_kernel import jax_import_guard
  jax_import_guard()
  import jax
  import jax.numpy as jnp
  from jax.experimental import pallas as pl


class MegabloxTest(unittest.TestCase):
```

> Review comment on `class MegabloxTest`: Why can't we merge this into test_pallas.py?
```python
  def _reference_gmm(
      self,
      lhs: np.array,
      rhs: np.array,
      group_sizes: np.array,
      preferred_element_type: np.dtype = np.float32,
  ) -> np.array:

    start = 0
    out = []
    for i, size in enumerate(group_sizes):
      result = np.dot(lhs[start:start + size, :], rhs[i, :, :])

      result = result.astype(preferred_element_type)
      out.append(result)
      start += group_sizes[i]
    return np.array(np.concatenate(out, axis=0))
```

> Review comment on `_reference_gmm`: Can we just do it in torch instead of np?
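As a sketch of the reviewer's suggestion (not the PR's code), the same reference computation can be written with torch ops, which removes the numpy round-trip and the dtype bookkeeping:

```python
import torch


def reference_gmm_torch(lhs: torch.Tensor, rhs: torch.Tensor,
                        group_sizes: torch.Tensor) -> torch.Tensor:
  """Grouped matmul reference: lhs is [m, k], rhs is [num_groups, k, n]."""
  start = 0
  out = []
  for i, size in enumerate(group_sizes.tolist()):
    # Each group multiplies a [size, k] slice of lhs by that group's [k, n] weight.
    out.append(lhs[start:start + size, :] @ rhs[i, :, :])
    start += size
  return torch.cat(out, dim=0)
```

With a torch reference, the comparison in `test_gmm` could run directly on CPU tensors and feed `torch.allclose`.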
```python
  def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
    # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer
    # sample with replacement so that it's possible to get zero-sized groups. Get
    # 'num_groups - 1' run ends. The final group will end at 'm'.
    ends_no_final = np.sort(
        np.array(
            [np.random.randint(low=0, high=m) for _ in range(num_groups - 1)],
            dtype=np.int32,
        ),)
    ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)])

    # Calculate the run starts by shifting ends 1 to the right. The first run
    # starts at zero.
    starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
    return torch.from_numpy(ends - starts).to(torch.int32)
```

> Review comment on `_group_sizes_strategy`: As far as I can tell, we just need to make sure our piping is correct; we don't need to ensure gmm itself is correct, since that's JAX's job. So let's remove this and pick one or two cases that are tuned to our wrapper.
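A minimal sketch of the reviewer's suggestion, using one or two hand-picked splits instead of fuzzed ones (the concrete values are illustrative, not taken from the PR):

```python
import torch

m, num_groups = 512, 2
# One even split and one deliberately uneven split covering the same total m.
even_split = torch.full((num_groups,), m // num_groups, dtype=torch.int32)
uneven_split = torch.tensor([m - 128, 128], dtype=torch.int32)
assert int(even_split.sum()) == m and int(uneven_split.sum()) == m
```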
```python
  def _tolerances(self, lhs_dtype: torch.dtype, rhs_dtype: torch.dtype,
                  out_dtype: torch.dtype) -> tuple[float, float]:
    if (lhs_dtype == torch.bfloat16 or rhs_dtype == torch.bfloat16 or
        out_dtype == torch.bfloat16):
      return 1e-3, 1e-2  # atol, rtol
    return 1e-4, 1e-2  # atol, rtol
```

> Review comment on `_tolerances`: If we use torch, we don't need this; we can just use torch.allclose.
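For illustration, a comparison against a torch reference could look like the sketch below; the tensors and tolerances here are placeholders, with bf16 cases presumably still needing the looser bounds noted above:

```python
import torch

# Stand-in tensors for the kernel output and the torch reference.
ref_out = torch.randn(256, 128)
out = ref_out + 1e-6 * torch.randn_like(ref_out)

# fp32 case; a bf16 case would use something like atol=1e-3, rtol=1e-2.
assert torch.allclose(out, ref_out, atol=1e-4, rtol=1e-2)
```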
```python
  LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]
```

> Review comment on `LutFn`: What's this?
```python
  def _init_test_cases(self):
    self.tests_cases = [
        {'dtype': torch.float32, 'm': 128, 'k': 128, 'n': 128, 'num_groups': 1},
        {'dtype': torch.float32, 'm': 256, 'k': 128, 'n': 128, 'num_groups': 1},
        {'dtype': torch.float32, 'm': 128, 'k': 256, 'n': 128, 'num_groups': 8},
        {'dtype': torch.float32, 'm': 512, 'k': 128, 'n': 256, 'num_groups': 2},
        {'dtype': torch.bfloat16, 'm': 128, 'k': 128, 'n': 128, 'num_groups': 1},
        {'dtype': torch.bfloat16, 'm': 256, 'k': 128, 'n': 128, 'num_groups': 1},
        {'dtype': torch.bfloat16, 'm': 128, 'k': 256, 'n': 128, 'num_groups': 8},
        {'dtype': torch.bfloat16, 'm': 512, 'k': 128, 'n': 256, 'num_groups': 2},
    ]
```

> Review comment on `_init_test_cases`: We might not need all of these.
```python
  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_gmm(self):
    self._init_test_cases()
    for test_case in self.tests_cases:
      num_groups = test_case['num_groups']
      k = test_case['k']
      m = test_case['m']
      n = test_case['n']
      lhs_dtype = rhs_dtype = test_case['dtype']
      out_dtype = torch.float32

      lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
      out = megablox.gmm(lhs, rhs, group_sizes)

      ref_out = self._reference_gmm(lhs.cpu().float().numpy(),
                                    rhs.cpu().float().numpy(),
                                    group_sizes.numpy())

      atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype)
      np.testing.assert_allclose(
          ref_out, np.array(out[0].cpu()), rtol=rtol, atol=atol)
```

> Review comment on `group_sizes = self._group_sizes_strategy(...)`: This is a CPU tensor!

> Review comment on `out = megablox.gmm(lhs, rhs, group_sizes)`: We always output fp32 in this test case, regardless of the input dtypes.
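If the intent is for every operand to live on the XLA device, the group sizes would be moved there before the kernel call; a sketch under that assumption (not necessarily the fix the PR adopted):

```python
import torch

group_sizes = torch.tensor([64, 64], dtype=torch.int32)
# Assumes an XLA/TPU device is configured; 'xla' resolves to the default device,
# matching how lhs and rhs are placed in test_gmm above.
group_sizes = group_sizes.to('xla')
```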
```python
if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  torch.set_default_dtype(torch.float32)
  torch.manual_seed(42)
  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
      use_full_mat_mul_precision=True)
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
```
New package init file:

@@ -0,0 +1 @@

```python
from .gmm import gmm
```

> Review comment: Once we remove all the duplicated code, we can move this method back to custom_kernel.py.
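For reference, a sketch of calling the exported `gmm` the way `test_gmm` above does; shapes follow the test, and running it requires a TPU-backed XLA device:

```python
import torch
import torch_xla.experimental.megablox as megablox

m, k, n, num_groups = 128, 128, 128, 1
lhs = torch.rand(m, k, dtype=torch.float32).to('xla')              # [m, k]
rhs = torch.rand(num_groups, k, n, dtype=torch.float32).to('xla')  # [num_groups, k, n]
group_sizes = torch.tensor([m], dtype=torch.int32)                 # entries sum to m
out = megablox.gmm(lhs, rhs, group_sizes)  # the test reads the result as out[0]
```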
New utilities file:

@@ -0,0 +1,22 @@

```python
"""Common utilities for Pallas kernels."""

from typing import Union
import torch
from torch_xla._internal import tpu


def assert_is_supported_dtype(dtype: torch.dtype) -> None:
  if dtype != torch.bfloat16 and dtype != torch.float32:
    raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.")


def select_input_dtype(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.dtype:
  """The type to which both inputs should be adapted before the dot product."""
  # bf16 x bf16 matmul is only supported from the TPU v4 generation onward. In
  # case of mixed input precision, we need to convert the bf16 argument to fp32
  # beforehand.
  if (tpu.version() >= 4 and lhs.dtype == torch.bfloat16 and
      rhs.dtype == torch.bfloat16):
    return torch.bfloat16
  else:
    return torch.float32
```

> Review comment: This file can be deleted if we directly use the helper from JAX.
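For context, a sketch of how these helpers might be used together before dispatching to a kernel; this is an assumed usage pattern, and `tpu.version()` only resolves on a TPU runtime:

```python
import torch

# Assume assert_is_supported_dtype and select_input_dtype from above are in scope.
lhs = torch.rand(128, 256, dtype=torch.bfloat16)
rhs = torch.rand(256, 128, dtype=torch.bfloat16)

# Reject anything other than bf16/fp32, then pick the common compute dtype:
# bf16 on TPU v4+ when both operands are bf16, fp32 otherwise.
assert_is_supported_dtype(lhs.dtype)
assert_is_supported_dtype(rhs.dtype)
compute_dtype = select_input_dtype(lhs, rhs)
lhs, rhs = lhs.to(compute_dtype), rhs.to(compute_dtype)
```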