From cdf458a8f9828a1281a2dc4c0c4b7bc7941aceac Mon Sep 17 00:00:00 2001
From: Rithwik Ediga Lakhamsani
Date: Fri, 7 Mar 2025 11:15:38 -0800
Subject: [PATCH 1/6] Updated pytorch and disabled sparse tests

---
 .github/workflows/pr-gpu.yaml  | 8 ++++----
 .pre-commit-config.yaml        | 2 +-
 megablocks/layers/arguments.py | 6 ++++++
 pyproject.toml                 | 2 +-
 setup.py                       | 4 ++--
 tests/layers/dmoe_test.py      | 6 ++++++
 tests/layers/glu_test.py       | 6 ++++++
 tests/layers/moe_test.py       | 6 ++++++
 8 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml
index dda6e68d..1f8de746 100644
--- a/.github/workflows/pr-gpu.yaml
+++ b/.github/workflows/pr-gpu.yaml
@@ -21,14 +21,14 @@ jobs:
       fail-fast: false
       matrix:
         include:
-        - name: "python3.11-pytorch2.5.1-gpus1"
+        - name: "python3.11-pytorch2.6.0-gpus1"
           gpu_num: 1
           python_version: 3.11
-          container: mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04
-        - name: "python3.11-pytorch2.5.1-gpus2"
+          container: mosaicml/pytorch:2.6.0_cu124-python3.11-ubuntu22.04
+        - name: "python3.11-pytorch2.6.0-gpus2"
           gpu_num: 2
           python_version: 3.11
-          container: mosaicml/pytorch:2.5.1_cu124-python3.11-ubuntu20.04
+          container: mosaicml/pytorch:2.6.0_cu124-python3.11-ubuntu22.04
     steps:
     - name: Run PR GPU tests
       uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c754b29d..2cf68c06 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,7 +32,7 @@ repos:
     additional_dependencies:
     - toml
 - repo: https://github.com/hadialqattan/pycln
-  rev: v2.1.2
+  rev: v2.5.0
   hooks:
   - id: pycln
     args: [. --all]
diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index ec0eb4a1..5f811df9 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -8,6 +8,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+import triton

 import megablocks.grouped_gemm_util as grouped_gemm

@@ -73,6 +74,11 @@ class Arguments:
     moe_zloss_in_fp32: bool = False

     def __post_init__(self):
+        # Sparse MLP is not supported with triton >=3.2.0
+        # TODO: Remove this once sparse is supported with triton >=3.2.0
+        if self.__getattribute__('mlp_impl') == 'sparse' and triton.__version__ >= '3.2.0':
+            raise ValueError('Sparse MLP is not supported with triton >=3.2.0')
+
         if self.__getattribute__('mlp_impl') == 'grouped':
             grouped_gemm.assert_grouped_gemm_is_available()

diff --git a/pyproject.toml b/pyproject.toml
index 28b9135c..7fcc670b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@

 # build requirements
 [build-system]
-requires = ["setuptools < 70.0.0", "torch >= 2.5.1, < 2.5.2"]
+requires = ["setuptools < 70.0.0", "torch >= 2.6.0, < 2.6.1"]
 build-backend = "setuptools.build_meta"

 # Pytest
diff --git a/setup.py b/setup.py
index 5338d6b3..26cecf4c 100644
--- a/setup.py
+++ b/setup.py
@@ -62,8 +62,8 @@
 install_requires = [
     'numpy>=1.21.5,<2.1.0',
     'packaging>=21.3.0,<24.2',
-    'torch>=2.5.1,<2.5.2',
-    'triton>=2.1.0',
+    'torch>=2.6.0,<2.6.1',
+    'triton>=3.2.0,<3.3.0',
     'stanford-stk==0.7.1',
 ]

diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py
index d8f34ba8..bc3b88dd 100644
--- a/tests/layers/dmoe_test.py
+++ b/tests/layers/dmoe_test.py
@@ -5,6 +5,7 @@

 import pytest
 import torch
+import triton

 from megablocks import grouped_gemm_util as gg
 from megablocks.layers.arguments import Arguments
@@ -53,6 +54,11 @@ def construct_moes(
     mlp_impl: str = 'sparse',
     moe_zloss_weight: float = 0,
 ):
+    # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
+    # TODO: Remove this once sparse is supported with triton >=3.2.0
+    if mlp_impl == 'sparse' and triton.__version__ >= '3.2.0':
+        pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+
     init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
     args = Arguments(
         hidden_size=hidden_size,
diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py
index 1e031ded..960159d0 100644
--- a/tests/layers/glu_test.py
+++ b/tests/layers/glu_test.py
@@ -6,6 +6,7 @@
 import pytest
 import stk
 import torch
+import triton

 from megablocks.layers import dmlp_registry
 from megablocks.layers.arguments import Arguments
@@ -23,6 +24,11 @@ def construct_dmoe_glu(
     mlp_impl: str = 'sparse',
     memory_optimized_mlp: bool = False,
 ):
+    # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
+    # TODO: Remove this once sparse is supported with triton >=3.2.0
+    if mlp_impl == 'sparse' and triton.__version__ >= '3.2.0':
+        pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+
     init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
     args = Arguments(
         hidden_size=hidden_size,
diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py
index 24d42c9f..b58cffcf 100644
--- a/tests/layers/moe_test.py
+++ b/tests/layers/moe_test.py
@@ -5,6 +5,7 @@

 import pytest
 import torch
+import triton

 from megablocks.layers.arguments import Arguments
 from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
@@ -41,6 +42,11 @@ def construct_moe(
     moe_top_k: int = 1,
     moe_zloss_weight: float = 0,
 ):
+    # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
+    # TODO: Remove this once sparse is supported with triton >=3.2.0
+    if triton.__version__ >= '3.2.0':
+        pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+
     init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
     args = Arguments(
         hidden_size=hidden_size,

From c8f95713cc593d82e54321526735835348980491 Mon Sep 17 00:00:00 2001
From: Rithwik Ediga Lakhamsani
Date: Mon, 10 Mar 2025 13:59:43 -0700
Subject: [PATCH 2/6] added pull request target

---
 .github/workflows/pr-gpu.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml
index 1f8de746..f34c490b 100644
--- a/.github/workflows/pr-gpu.yaml
+++ b/.github/workflows/pr-gpu.yaml
@@ -4,7 +4,7 @@ on:
     branches:
     - main
     - release/*
-  pull_request_target:
+  pull_request:
     branches:
     - main
     - release/**

From 6b012e2d02d2115071ef7bea536083536090034e Mon Sep 17 00:00:00 2001
From: Rithwik Ediga Lakhamsani
Date: Tue, 11 Mar 2025 12:42:26 -0700
Subject: [PATCH 3/6] updated pull_request_target

---
 .github/workflows/pr-gpu.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml
index f34c490b..1f8de746 100644
--- a/.github/workflows/pr-gpu.yaml
+++ b/.github/workflows/pr-gpu.yaml
@@ -4,7 +4,7 @@ on:
     branches:
     - main
     - release/*
-  pull_request:
+  pull_request_target:
     branches:
     - main
     - release/**

From ebca28ad2aaf67ae2c3601ebbe655bbcb447844b Mon Sep 17 00:00:00 2001
From: Rithwik Ediga Lakhamsani
Date: Tue, 11 Mar 2025 16:30:14 -0700
Subject: [PATCH 4/6] made triton import inline and added try/catch

---
 megablocks/layers/arguments.py | 10 +++++++---
 tests/layers/dmoe_test.py      | 10 +++++++---
 tests/layers/glu_test.py       | 10 +++++++---
 tests/layers/moe_test.py       |  9 ++++++---
 4 files changed, 27 insertions(+), 12 deletions(-)

diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 5f811df9..64830758 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -8,7 +8,6 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-import triton

 import megablocks.grouped_gemm_util as grouped_gemm

@@ -76,8 +75,13 @@ class Arguments:
     def __post_init__(self):
         # Sparse MLP is not supported with triton >=3.2.0
         # TODO: Remove this once sparse is supported with triton >=3.2.0
-        if self.__getattribute__('mlp_impl') == 'sparse' and triton.__version__ >= '3.2.0':
-            raise ValueError('Sparse MLP is not supported with triton >=3.2.0')
+        if self.__getattribute__('mlp_impl') == 'sparse':
+            try:
+                import triton
+                if triton.__version__ >= '3.2.0':
+                    raise ValueError('Sparse MLP is not supported with triton >=3.2.0')
+            except ImportError:
+                raise ImportError("Triton is required for sparse MLP implementation")

         if self.__getattribute__('mlp_impl') == 'grouped':
             grouped_gemm.assert_grouped_gemm_is_available()

diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py
index bc3b88dd..1f66999e 100644
--- a/tests/layers/dmoe_test.py
+++ b/tests/layers/dmoe_test.py
@@ -5,7 +5,6 @@

 import pytest
 import torch
-import triton

 from megablocks import grouped_gemm_util as gg
 from megablocks.layers.arguments import Arguments
@@ -56,8 +55,13 @@ def construct_moes(
 ):
     # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
     # TODO: Remove this once sparse is supported with triton >=3.2.0
-    if mlp_impl == 'sparse' and triton.__version__ >= '3.2.0':
-        pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+    if mlp_impl == 'sparse':
+        try:
+            import triton
+            if triton.__version__ >= '3.2.0':
+                pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+        except ImportError:
+            pass

     init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
     args = Arguments(
diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py
index 960159d0..0ce13397 100644
--- a/tests/layers/glu_test.py
+++ b/tests/layers/glu_test.py
@@ -6,7 +6,6 @@
 import pytest
 import stk
 import torch
-import triton

 from megablocks.layers import dmlp_registry
 from megablocks.layers.arguments import Arguments
@@ -26,8 +25,13 @@ def construct_dmoe_glu(
 ):
     # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
     # TODO: Remove this once sparse is supported with triton >=3.2.0
-    if mlp_impl == 'sparse' and triton.__version__ >= '3.2.0':
-        pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+    if mlp_impl == 'sparse':
+        try:
+            import triton
+            if triton.__version__ >= '3.2.0':
+                pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+        except ImportError:
+            pass

     init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
     args = Arguments(
diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py
index b58cffcf..46af3679 100644
--- a/tests/layers/moe_test.py
+++ b/tests/layers/moe_test.py
@@ -5,7 +5,6 @@

 import pytest
 import torch
-import triton

 from megablocks.layers.arguments import Arguments
 from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
@@ -44,8 +43,12 @@ def construct_moe(
 ):
     # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
     # TODO: Remove this once sparse is supported with triton >=3.2.0
-    if triton.__version__ >= '3.2.0':
-        pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+    try:
+        import triton
+        if triton.__version__ >= '3.2.0':
+            pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
+    except ImportError:
+        pass

     init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
     args = Arguments(

From 93f837f32d0233efb9979be47d92f48ac115027d Mon Sep 17 00:00:00 2001
From: Rithwik Ediga Lakhamsani
Date: Tue, 11 Mar 2025 16:39:10 -0700
Subject: [PATCH 5/6] formatted

---
 megablocks/layers/arguments.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 64830758..f92f6b56 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -81,7 +81,7 @@ def __post_init__(self):
                 if triton.__version__ >= '3.2.0':
                     raise ValueError('Sparse MLP is not supported with triton >=3.2.0')
             except ImportError:
-                raise ImportError("Triton is required for sparse MLP implementation")
+                raise ImportError('Triton is required for sparse MLP implementation')

         if self.__getattribute__('mlp_impl') == 'grouped':
             grouped_gemm.assert_grouped_gemm_is_available()

From 129335b22c027f64b84d52fe17829c09d00be63a Mon Sep 17 00:00:00 2001
From: Rithwik Ediga Lakhamsani
Date: Wed, 12 Mar 2025 10:46:33 -0700
Subject: [PATCH 6/6] added note to use grouped instead

---
 megablocks/layers/arguments.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index f92f6b56..3962c771 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -79,7 +79,9 @@ def __post_init__(self):
             try:
                 import triton
                 if triton.__version__ >= '3.2.0':
-                    raise ValueError('Sparse MLP is not supported with triton >=3.2.0')
+                    raise ValueError(
+                        'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+                    )
             except ImportError:
                 raise ImportError('Triton is required for sparse MLP implementation')
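
For reference, the gate this series adds compares triton.__version__ to '3.2.0' as plain strings. That happens to behave correctly for the 3.2.x through 3.9.x range, but a hypothetical triton 3.10.0 would compare as "older" than 3.2.0 and slip past the check. Below is a minimal sketch of the same gate using packaging.version (packaging is already listed in setup.py's install_requires); the standalone helper name check_sparse_support is illustrative only and is not part of the patch series.

    from packaging.version import Version


    def check_sparse_support(mlp_impl: str) -> None:
        # Sketch of the __post_init__ gate from the patches above, using a real
        # version comparison instead of a lexicographic string comparison.
        if mlp_impl != 'sparse':
            return
        try:
            import triton
        except ImportError:
            raise ImportError('Triton is required for sparse MLP implementation')
        # Version('3.10.0') >= Version('3.2.0') is True, unlike the string compare.
        if Version(triton.__version__) >= Version('3.2.0'):
            raise ValueError(
                'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
            )

The same substitution would apply to the pytest.skip guards added to the test files.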