Instructions to use kernels-community/megablocks with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/megablocks with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("kernels-community/megablocks")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import megablocks | |
def randn(bs, x, y, device="cuda"):
    """Return a small-magnitude random bfloat16 tensor of shape (bs, x, y).

    Values are drawn uniformly from [-1, 1) and then divided by ``x * y``.

    Args:
        bs: Size of the leading (batch/group) dimension.
        x: Size of the second dimension.
        y: Size of the third dimension.
        device: Device the result is moved to. Defaults to ``"cuda"``, which
            matches the previous hard-coded ``.cuda()`` call.

    Returns:
        A ``torch.bfloat16`` tensor of shape ``(bs, x, y)`` on ``device``.
    """
    # The samples are generated with the default (CPU) generator and only then
    # moved, so torch.manual_seed() controls them regardless of the target device.
    uniform = torch.rand(bs, x, y)
    # Parenthesised on purpose: the earlier form `rand - 0.5 * 2` parsed as
    # `rand - 1.0`, which produced only negative values in [-1, 0).
    centered = (uniform - 0.5) * 2
    out = centered / (y * x)
    return out.to(device).to(torch.bfloat16)
def gmm(a, b, batch_sizes, trans_b=False):
    """Reference grouped matrix multiply built from plain PyTorch ops.

    The rows of ``a`` are consumed in consecutive groups: group ``i`` takes the
    next ``batch_sizes[i]`` rows and multiplies them by ``b[i]`` (transposed
    first when ``trans_b`` is true). The per-group products are concatenated
    along dimension 0.

    Args:
        a: Left operand; sliced along dimension 0 into groups.
        b: Stacked right operands, indexed by group along dimension 0.
        batch_sizes: Integer tensor with the number of rows in each group.
        trans_b: When true, each ``b[i]`` is transposed before the matmul.

    Returns:
        The concatenation of all per-group products.
    """
    # Pull the group sizes back to host integers so they can drive slicing.
    group_rows = batch_sizes.tolist()
    products = []
    offset = 0
    for group, rows in enumerate(group_rows):
        weight = b[group, :, :]
        if trans_b:
            weight = weight.t()
        stop = offset + rows
        products.append(a[offset:stop, :] @ weight)
        offset = stop
    return torch.cat(products)
def test_gmm():
    """Check ``megablocks.gg_ops.gmm`` against the pure-PyTorch ``gmm`` reference.

    Builds seeded inputs with this file's ``randn`` helper, runs both
    implementations, and compares the forward output as well as the gradients
    with respect to both operands, each with an absolute tolerance of 1e-3.
    """
    num_groups, rows, cols, inner = 1, 128, 128, 128
    trans_b = False
    # TODO: fix to enable batch_sizes_on_device (i.e. flip this flag to True).
    batch_sizes_on_device = False

    torch.manual_seed(0)
    a = randn(num_groups, rows, inner).view(-1, inner)
    b_dims = (num_groups, cols, inner) if trans_b else (num_groups, inner, cols)
    b = randn(*b_dims)

    batch_sizes = torch.tensor([rows] * num_groups)
    if batch_sizes_on_device:
        batch_sizes = batch_sizes.cuda()

    # Detached clones give the reference path its own leaves, so gradients
    # accumulate separately for the kernel under test and for the reference.
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)
    a.requires_grad_(True)
    b.requires_grad_(True)

    # Forward pass: kernel under test first, then the reference.
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print("out", out)
    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
    assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"

    # Backward pass: reduce each output to a scalar and compare input grads.
    out.sum().backward()
    expected_out.sum().backward()
    for got, want in ((a.grad, a_ref.grad), (b.grad, b_ref.grad)):
        assert torch.allclose(got, want, atol=1e-3), f"Expected {want}, got {got}"

    print("Test passed successfully!")