@@ -23,7 +23,9 @@ def test_array_namespace(library, api_version, use_compat):
2323 if library == "ndonnx" and api_version in ("2021.12" , "2022.12" ):
2424 pytest .skip ("Unsupported API version" )
2525
26- namespace = array_namespace (array , api_version = api_version , use_compat = use_compat )
26+ with warnings .catch_warnings ():
27+ warnings .simplefilter ('ignore' , UserWarning )
28+ namespace = array_namespace (array , api_version = api_version , use_compat = use_compat )
2729
2830 if use_compat is False or use_compat is None and library not in wrapped_libraries :
2931 if library == "jax.numpy" and use_compat is None :
@@ -45,10 +47,13 @@ def test_array_namespace(library, api_version, use_compat):
4547
4648 if library == "numpy" :
4749 # check that the same namespace is returned for NumPy scalars
48- scalar_namespace = array_namespace (
49- xp .float64 (0.0 ), api_version = api_version , use_compat = use_compat
50- )
51- assert scalar_namespace == namespace
50+ with warnings .catch_warnings ():
51+ warnings .simplefilter ('ignore' , UserWarning )
52+
53+ scalar_namespace = array_namespace (
54+ xp .float64 (0.0 ), api_version = api_version , use_compat = use_compat
55+ )
56+ assert scalar_namespace == namespace
5257
5358 # Check that array_namespace works even if jax.experimental.array_api
5459 # hasn't been imported yet (it monkeypatches __array_namespace__
@@ -97,7 +102,9 @@ def test_api_version_torch():
97102 torch = import_ ("torch" )
98103 x = torch .asarray ([1 , 2 ])
99104 torch_ = import_ ("torch" , wrapper = True )
100- assert array_namespace (x , api_version = "2023.12" ) == torch_
105+ with warnings .catch_warnings ():
106+ warnings .simplefilter ('ignore' , UserWarning )
107+ assert array_namespace (x , api_version = "2023.12" ) == torch_
101108 assert array_namespace (x , api_version = None ) == torch_
102109 assert array_namespace (x ) == torch_
103110 # Should issue a warning
0 commit comments