From 463a84fb7f61e06c319a9a878aae4d90611e4eab Mon Sep 17 00:00:00 2001
From: zpcore
Date: Mon, 17 Nov 2025 22:54:24 -0800
Subject: [PATCH 1/2] Fix sharding prop cache clear

---
 .github/workflows/test_cuda.yml    | 2 +-
 autoparallel/dtensor_util/utils.py | 8 +++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml
index 1e91703a..dcd2cf6b 100644
--- a/.github/workflows/test_cuda.yml
+++ b/.github/workflows/test_cuda.yml
@@ -40,7 +40,7 @@ jobs:
         pip uninstall -y torch
         pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
         pip install --quiet .
-        pytest tests --deselect=tests/test_dtensor.py::ImplicitRegistrationTest::test_implicit_registration
+        pytest tests
         python examples/example_autoparallel.py
         python examples/example_llama3.py
         python examples/example_dcp.py

diff --git a/autoparallel/dtensor_util/utils.py b/autoparallel/dtensor_util/utils.py
index 3341e2e9..fec580fa 100644
--- a/autoparallel/dtensor_util/utils.py
+++ b/autoparallel/dtensor_util/utils.py
@@ -24,6 +24,11 @@
 )
 from torch.distributed.tensor.placement_types import Placement, Replicate, Shard

+from torch.distributed.tensor.debug import (
+    _clear_fast_path_sharding_prop_cache,
+    _clear_python_sharding_prop_cache,
+)
+
 try:
     from torch.utils._cxx_pytree import tree_leaves
 except ImportError:
@@ -82,7 +87,8 @@ def op_strategy_context(op_overload, strategy_func, schema_info=None):
             del propagator.op_to_schema_info[op_overload]
         else:
             propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema
-        propagator.propagate_op_sharding.cache.cache_clear()
+        _clear_fast_path_sharding_prop_cache()
+        _clear_python_sharding_prop_cache()


 # -------------define universal op strategy-------------

From a973def9ab82ca543ad7970bd62c7eea4f7ddbb9 Mon Sep 17 00:00:00 2001
From: zpcore
Date: Mon, 17 Nov 2025 22:59:51 -0800
Subject: [PATCH 2/2] lint

---
 autoparallel/dtensor_util/utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/autoparallel/dtensor_util/utils.py b/autoparallel/dtensor_util/utils.py
index fec580fa..6b11ca58 100644
--- a/autoparallel/dtensor_util/utils.py
+++ b/autoparallel/dtensor_util/utils.py
@@ -22,12 +22,11 @@
     is_tensor_shardable,
     register_op_strategy,
 )
-from torch.distributed.tensor.placement_types import Placement, Replicate, Shard
-
 from torch.distributed.tensor.debug import (
     _clear_fast_path_sharding_prop_cache,
     _clear_python_sharding_prop_cache,
 )
+from torch.distributed.tensor.placement_types import Placement, Replicate, Shard

 try:
     from torch.utils._cxx_pytree import tree_leaves
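
Note for reviewers: a minimal usage sketch of the context manager this patch touches. The op choice and strategy function below are hypothetical, not part of the patch; the point is that on exit the original schema info is restored and, with this change, both the fast-path and the Python sharding-propagation caches are cleared instead of only the old LRU cache.

    import torch
    from autoparallel.dtensor_util.utils import op_strategy_context

    def my_mm_strategy(op_schema):
        # Hypothetical strategy function; a real one returns the op's
        # allowed input/output sharding strategies.
        ...

    # Inside the block, sharding propagation for aten.mm uses the
    # temporary strategy; on exit the original registration is restored
    # and no stale cache entries from the override survive.
    with op_strategy_context(torch.ops.aten.mm.default, my_mm_strategy):
        ...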