diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml
index 1e91703a..dcd2cf6b 100644
--- a/.github/workflows/test_cuda.yml
+++ b/.github/workflows/test_cuda.yml
@@ -40,7 +40,7 @@ jobs:
           pip uninstall -y torch
           pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
           pip install --quiet .
-          pytest tests --deselect=tests/test_dtensor.py::ImplicitRegistrationTest::test_implicit_registration
+          pytest tests
           python examples/example_autoparallel.py
           python examples/example_llama3.py
           python examples/example_dcp.py
diff --git a/autoparallel/dtensor_util/utils.py b/autoparallel/dtensor_util/utils.py
index 3341e2e9..6b11ca58 100644
--- a/autoparallel/dtensor_util/utils.py
+++ b/autoparallel/dtensor_util/utils.py
@@ -22,6 +22,10 @@
     is_tensor_shardable,
     register_op_strategy,
 )
+from torch.distributed.tensor.debug import (
+    _clear_fast_path_sharding_prop_cache,
+    _clear_python_sharding_prop_cache,
+)
 from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
 
 try:
@@ -82,7 +86,8 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
             del propagator.op_to_schema_info[op_overload]
         else:
             propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
-        propagator.propagate_op_sharding.cache.cache_clear()
+        _clear_fast_path_sharding_prop_cache()
+        _clear_python_sharding_prop_cache()
 
 
 # -------------define universal op strategy-------------
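
Context for the `utils.py` change: `op_strategy_context` is a context manager that temporarily registers a custom sharding strategy for an op and restores the original registration on exit. The sharding propagator now keeps both a C++ fast-path cache and a Python cache, so exiting the context clears both through the `torch.distributed.tensor.debug` helpers instead of calling `cache_clear()` on the old single LRU cache. The sketch below illustrates the resulting pattern; it is simplified, assumes the propagator is reached via `DTensor._op_dispatcher.sharding_propagator`, and its strategy bookkeeping is illustrative rather than a copy of the real helper (which also restores `_origin_op_strategy_schema`, as in the hunk above).

```python
# Simplified sketch of the op_strategy_context pattern after this change.
# Assumes DTensor._op_dispatcher.sharding_propagator exposes op_strategy_funcs
# and op_to_schema_info; the real helper lives in autoparallel/dtensor_util/utils.py.
from contextlib import contextmanager

from torch.distributed.tensor import DTensor
from torch.distributed.tensor.debug import (
    _clear_fast_path_sharding_prop_cache,
    _clear_python_sharding_prop_cache,
)


@contextmanager
def op_strategy_context(op_overload, strategy_func, schema_info=None):
    propagator = DTensor._op_dispatcher.sharding_propagator
    original_strategy = propagator.op_strategy_funcs.get(op_overload)  # may be None
    try:
        # Temporarily install the custom strategy for this op.
        propagator.op_strategy_funcs[op_overload] = strategy_func
        if schema_info is not None:
            propagator.op_to_schema_info[op_overload] = schema_info
        yield
    finally:
        # Restore the previous registration (or drop ours if there was none).
        if original_strategy is None:
            propagator.op_strategy_funcs.pop(op_overload, None)
        else:
            propagator.op_strategy_funcs[op_overload] = original_strategy
        # Clearing only the old LRU cache is no longer enough: both the C++
        # fast-path cache and the Python sharding-prop cache can hold entries
        # for op_overload, so drop both before leaving the context.
        _clear_fast_path_sharding_prop_cache()
        _clear_python_sharding_prop_cache()
```

With both caches cleared on exit, stale sharding decisions should no longer leak across tests, which is presumably why the `--deselect` for `ImplicitRegistrationTest::test_implicit_registration` can be dropped from the CUDA workflow.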