3232import torch ._dynamo .testing
3333import torch ._inductor .test_case
3434import torch .onnx .operators
35- import torch .utils ._pytree as pytree
35+ import torch .utils ._pytree as python_pytree
3636import torch .utils .cpp_extension
3737from torch import Tensor
3838from torch ._C import FileCheck
8989from torch .testing ._internal .logging_utils import logs_to_string
9090
9191
92- HAS_OPTREE = importlib . util . find_spec ( "optree" )
92+ HAS_OPTREE = python_pytree . _cxx_pytree_exists
9393if HAS_OPTREE :
94- import optree
94+ import torch .utils ._cxx_pytree as cxx_pytree
95+ else :
96+ cxx_pytree = None
9597
9698MyTuple = collections .namedtuple ("MyTuple" , ["a" , "b" , "ab" ])
9799T = typing .TypeVar ("T" )
@@ -293,9 +295,9 @@ def fn(x):
293295
294296 @unittest .skipIf (not HAS_OPTREE , "missing optree package" )
295297 def test_optree_graph_break_message (self ):
296- @ torch . compile (
297- backend = "eager" ,
298- )
298+ import optree
299+
300+ @ torch . compile ( backend = "eager" )
299301 def fn (x ):
300302 d = {"a" : 1 }
301303 optree .tree_flatten (d )
@@ -8722,9 +8724,9 @@ def fn():
87228724
87238725 def test_tracing_py_tree (self ):
87248726 def fn (xs ):
8725- flat_xs , spec = pytree .tree_flatten (xs )
8727+ flat_xs , spec = python_pytree .tree_flatten (xs )
87268728 res = [x .clone () for x in flat_xs ]
8727- return pytree .tree_unflatten (res , spec )
8729+ return python_pytree .tree_unflatten (res , spec )
87288730
87298731 xs = [torch .tensor (i ) for i in range (3 )]
87308732
@@ -8734,12 +8736,10 @@ def fn(xs):
87348736 self .assertEqual (counter .op_count , 3 )
87358737
87368738 def test_tracing_nested_py_tree (self ):
8737- import torch .utils ._pytree as pytree
8738-
87398739 def fn (xs ):
8740- flat_xs , spec = pytree .tree_flatten (xs )
8740+ flat_xs , spec = python_pytree .tree_flatten (xs )
87418741 res = [x .clone () for x in flat_xs ]
8742- return pytree .tree_unflatten (res , spec )
8742+ return python_pytree .tree_unflatten (res , spec )
87438743
87448744 xs = [torch .tensor (i ) for i in range (3 )]
87458745 xsl = [xs , xs , xs , xs ]
@@ -8752,12 +8752,10 @@ def fn(xs):
87528752 self .assertEqual (counter .op_count , 12 )
87538753
87548754 def test_tracing_nested_py_tree_tuples (self ):
8755- import torch .utils ._pytree as pytree
8756-
87578755 def fn (xs ):
8758- flat_xs , spec = pytree .tree_flatten (xs )
8756+ flat_xs , spec = python_pytree .tree_flatten (xs )
87598757 res = [x .clone () for x in flat_xs ]
8760- return pytree .tree_unflatten (res , spec )
8758+ return python_pytree .tree_unflatten (res , spec )
87618759
87628760 xs = [torch .tensor (i ) for i in range (3 )]
87638761 xsl = (xs , xs , xs , xs )
@@ -8770,12 +8768,10 @@ def fn(xs):
87708768 self .assertEqual (counter .op_count , 12 )
87718769
87728770 def test_tracing_nested_py_tree_dicts (self ):
8773- import torch .utils ._pytree as pytree
8774-
87758771 def fn (xs ):
8776- flat_xs , spec = pytree .tree_flatten (xs )
8772+ flat_xs , spec = python_pytree .tree_flatten (xs )
87778773 res = [x .clone () for x in flat_xs ]
8778- return pytree .tree_unflatten (res , spec )
8774+ return python_pytree .tree_unflatten (res , spec )
87798775
87808776 xs = [torch .tensor (i ) for i in range (3 )]
87818777 xsl = {
@@ -8808,12 +8804,10 @@ def fn(x):
88088804 self .assertEqual (counter .op_count , 2 )
88098805
88108806 def test_tracing_nested_py_tree_mixed_all (self ):
8811- import torch .utils ._pytree as pytree
8812-
88138807 def fn (xs ):
8814- flat_xs , spec = pytree .tree_flatten (xs )
8808+ flat_xs , spec = python_pytree .tree_flatten (xs )
88158809 res = [x .clone () for x in flat_xs ]
8816- return pytree .tree_unflatten (res , spec )
8810+ return python_pytree .tree_unflatten (res , spec )
88178811
88188812 xs = [torch .tensor (i ) for i in range (3 )]
88198813 xsa = (xs , xs )
@@ -8858,13 +8852,12 @@ def fn(x):
88588852 self .assertEqual (cnt .frame_count , 2 )
88598853
88608854 def test_tracing_py_tree_tensor_subclass (self ):
8861- import torch .utils ._pytree as pytree
88628855 from torch .testing ._internal .two_tensor import TwoTensor
88638856 from torch .utils .checkpoint import checkpoint
88648857
88658858 def fn (xs ):
88668859 nested_xs = [[xs ]]
8867- flat_xs , spec = pytree .tree_flatten (xs )
8860+ flat_xs , spec = python_pytree .tree_flatten (xs )
88688861 return flat_xs [0 ].clone ()
88698862
88708863 # use checkpoint to trigger a "sourceless" tensor subclass
@@ -8879,13 +8872,11 @@ def checkpoint_fn(xs):
88798872 self .assertEqual (counter .op_count , 2 )
88808873
88818874 def test_tracing_tree_map_only (self ):
8882- import torch .utils ._pytree as pytree
8883-
88848875 def fn (xs ):
88858876 def mapper (x ):
88868877 return x .clone ()
88878878
8888- y = pytree .tree_map_only (torch .Tensor , mapper , xs )
8879+ y = python_pytree .tree_map_only (torch .Tensor , mapper , xs )
88898880 return y
88908881
88918882 xs = [torch .tensor (i ) for i in range (3 )] + ["hi" ]
@@ -10235,7 +10226,9 @@ def fn(x, y):
1023510226 self .assertEqual (actual , expected )
1023610227
1023710228 def test_pytree_tree_leaves (self ):
10238- implemtations = [("python" , pytree )]
10229+ implemtations = [("python" , python_pytree )]
10230+ if cxx_pytree is not None :
10231+ implemtations .append (("cxx" , cxx_pytree ))
1023910232
1024010233 for name , module in implemtations :
1024110234 with self .subTest (f"pytree implement: { name } " ):
@@ -10267,7 +10260,7 @@ def fn(x):
1026710260 self .assertEqual (actual , expected )
1026810261
1026910262 def test_pytree_tree_flatten_unflatten (self ):
10270- implemtations = [("python" , pytree )]
10263+ implemtations = [("python" , python_pytree )]
1027110264
1027210265 for name , module in implemtations :
1027310266 with self .subTest (f"pytree implement: { name } " ):
@@ -10316,7 +10309,7 @@ def fn(x, y):
1031610309 self .assertEqual (actual , expected )
1031710310
1031810311 def test_pytree_tree_map (self ):
10319- implemtations = [("python" , pytree )]
10312+ implemtations = [("python" , python_pytree )]
1032010313
1032110314 for name , module in implemtations :
1032210315 with self .subTest (f"pytree implement: { name } " ):
0 commit comments