diff --git a/examples/multimodal/layer_specs.py b/examples/multimodal/layer_specs.py index 24f6b2f19df..96a75038056 100644 --- a/examples/multimodal/layer_specs.py +++ b/examples/multimodal/layer_specs.py @@ -17,6 +17,9 @@ from megatron.core.typed_torch import not_none try: + import transformer_engine # pylint: disable=unused-import + + HAVE_TE = True from megatron.core.extensions.transformer_engine import ( TEColumnParallelLinear, TEDotProductAttention, @@ -24,9 +27,8 @@ TENorm, TERowParallelLinear, ) - - HAVE_TE = True except ImportError: + HAVE_TE = False ( TEColumnParallelLinear, TEDotProductAttention, @@ -34,7 +36,6 @@ TENorm, TERowParallelLinear, ) = (None, None, None, None, None) - HAVE_TE = False try: import apex diff --git a/examples/multimodal/radio/radio_g.py b/examples/multimodal/radio/radio_g.py index d39a9083722..2458464fc7d 100644 --- a/examples/multimodal/radio/radio_g.py +++ b/examples/multimodal/radio/radio_g.py @@ -18,6 +18,9 @@ from megatron.core.typed_torch import not_none try: + import transformer_engine # pylint: disable=unused-import + + HAVE_TE = True from megatron.core.extensions.transformer_engine import ( TEColumnParallelLinear, TEDotProductAttention, @@ -25,9 +28,8 @@ TENorm, TERowParallelLinear, ) - - HAVE_TE = True except ImportError: + HAVE_TE = False ( TEColumnParallelLinear, TEDotProductAttention, @@ -35,7 +37,6 @@ TENorm, TERowParallelLinear, ) = (None, None, None, None, None) - HAVE_TE = False try: import apex diff --git a/tests/unit_tests/transformer/moe/test_upcycling.py b/tests/unit_tests/transformer/moe/test_upcycling.py index b4f544b8930..152a4772255 100644 --- a/tests/unit_tests/transformer/moe/test_upcycling.py +++ b/tests/unit_tests/transformer/moe/test_upcycling.py @@ -34,11 +34,13 @@ from tests.unit_tests.test_utilities import Utils try: - from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear + import transformer_engine # pylint: disable=unused-import HAVE_TE = True + from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear except ImportError: HAVE_TE = False + TEColumnParallelGroupedLinear = None _SEED = 42 diff --git a/tests/unit_tests/transformer/test_multi_token_prediction.py b/tests/unit_tests/transformer/test_multi_token_prediction.py index 5a9cad7f5f7..520f980d034 100644 --- a/tests/unit_tests/transformer/test_multi_token_prediction.py +++ b/tests/unit_tests/transformer/test_multi_token_prediction.py @@ -41,11 +41,13 @@ from tests.unit_tests.test_utilities import Utils try: - from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear + import transformer_engine # pylint: disable=unused-import HAVE_TE = True + from megatron.core.extensions.transformer_engine import TEColumnParallelGroupedLinear except ImportError: HAVE_TE = False + TEColumnParallelGroupedLinear = None _SEED = 42