Instructions to use kernels-community/megablocks with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/megablocks with Kernels:
# !pip install kernels
from kernels import get_kernel
kernel = get_kernel("kernels-community/megablocks")
- Notebooks
- Google Colab
- Kaggle
| # Copyright 2024 Databricks | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import os | |
| from typing import List, Optional | |
| import pytest | |
| # from composer.utils import reproducibility | |
# Values accepted as the argument of the ``pytest.mark.world_size(...)`` marker.
WORLD_SIZE_OPTIONS = (1, 2)

# Deterministic mode should be switched on before any test starts; the call is
# commented out together with the ``composer`` import above:
# reproducibility.configure_deterministic_mode()

# Fixture modules that pytest loads as plugins for the whole test suite.
# TODO: re-enable the autouse fixture plugin once its dependencies are resolved.
pytest_plugins = [
    # 'tests.fixtures.autouse',
    'tests.fixtures.fixtures',
]
def _get_world_size(item: pytest.Item):
    """Look up the world size requested by the ``world_size`` marker nearest *item*.

    Tests that carry no such marker are treated as single-process (world size 1).
    """
    # Build a stand-in marker so unmarked items resolve to world size 1.
    unmarked_default = pytest.mark.world_size(1).mark
    nearest = item.get_closest_marker('world_size', default=unmarked_default)
    first_arg = nearest.args[0]
    return first_arg
def _get_option(
    config: pytest.Config,
    name: str,
    default: Optional[str] = None,
) -> str:  # type: ignore
    """Resolve option ``name`` from the command line, then the ini file, then a default.

    Lookup order: the ``--{name}`` command-line flag, then the ``name`` ini key,
    then ``default``.

    Args:
        config: The active pytest config object.
        name: Option name (used both as the CLI dest and the ini key).
        default: Fallback used when neither source sets the option. When this is
            also ``None`` the option is treated as required.

    Returns:
        The resolved option value.

    Calls ``pytest.fail`` (which raises) when the option is unset everywhere and
    no default was supplied.
    """
    from_cli = config.getoption(name)
    if from_cli is not None:
        assert isinstance(from_cli, str)
        return from_cli

    resolved = config.getini(name)
    # An unset ini option may come back as [] rather than None (presumably
    # depending on the pytest version); both mean "not set".
    unset = resolved is None or resolved == []
    if unset:
        if default is None:
            pytest.fail(f'Config option {name} is not specified but is required',)
        resolved = default

    assert isinstance(resolved, str)
    return resolved
def _add_option(
    parser: pytest.Parser,
    name: str,
    help: str,  # shadows the builtin, but kept: callers pass it by keyword
    choices: Optional[List[str]] = None,
) -> None:
    """Register ``name`` both as a ``--{name}`` CLI flag and as an ini key.

    Registering both lets ``_get_option`` read the command-line value first and
    fall back to the ini file.

    Args:
        parser: The pytest parser handed to ``pytest_addoption``.
        name: Option name; exposed as ``--{name}`` and as ini key ``name``.
        help: Help text shown for both the CLI flag and the ini key.
        choices: Optional list of accepted values. Only passed to the CLI flag;
            the ini registration below does not receive it.
    """
    # Use typing.List (as elsewhere in this file) rather than the builtin
    # ``list[str]``: annotations are evaluated at def time and builtin generics
    # raise TypeError on Python < 3.9.
    parser.addoption(
        f'--{name}',
        default=None,  # None lets _get_option detect "not given on the CLI"
        type=str,
        choices=choices,
        help=help,
    )
    parser.addini(
        name=name,
        help=help,
        type='string',
        default=None,
    )
def pytest_collection_modifyitems(
    config: pytest.Config,
    items: List[pytest.Item],
) -> None:
    """Deselect tests whose ``world_size`` marker differs from ``$WORLD_SIZE``.

    ``$WORLD_SIZE`` defaults to 1 when unset. Matching items are kept in their
    original order; the rest are reported through ``pytest_deselected`` and
    removed from ``items`` in place.
    """
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    print(f'world_size={world_size}')

    def _matches_world_size(item: pytest.Item) -> bool:
        return _get_world_size(item) == world_size

    selected: List[pytest.Item] = []
    dropped: List[pytest.Item] = []
    for item in items:
        bucket = selected if _matches_world_size(item) else dropped
        bucket.append(item)

    if dropped:
        config.hook.pytest_deselected(items=dropped)
    # Slice assignment mutates the list pytest passed in, which is how this hook
    # communicates the filtered collection back.
    items[:] = selected
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register this suite's custom options (currently only ``seed``).

    NOTE(review): no code visible in this file reads ``seed`` or calls
    ``reproducibility.seed_all`` (the import is commented out), so the help
    text may describe behavior that is currently disabled -- confirm.
    """
    seed_help = """\
Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked
before each test."""
    _add_option(parser, 'seed', help=seed_help)
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
    """Report a run in which no tests were collected as a success.

    ``pytest_collection_modifyitems`` deselects every test whose world size does
    not match ``$WORLD_SIZE``, which can leave a run with nothing to execute;
    that outcome should not fail the job.

    Args:
        session: The finishing pytest session; its ``exitstatus`` is overridden.
        exitstatus: The exit status pytest is about to return.
    """
    # Named exit codes instead of the magic numbers 5 / 0; ExitCode is an
    # IntEnum, so comparison with a plain int status still works.
    if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
        session.exitstatus = pytest.ExitCode.OK