Skip to content

Commit 419eb3e

Browse files
authored
FIX: Gracefully handle numpy arrays as input to check_in_list() (matplotlib#30714)
Closes matplotlib#30706
1 parent 677a2ea commit 419eb3e

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

lib/matplotlib/_api/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def check_in_list(values, /, *, _print_supported_values=True, **kwargs):
116116
----------
117117
values : iterable
118118
Sequence of values to check on.
119+
120+
Note: All values must support == comparisons.
121+
This means in particular the entries must not be numpy arrays.
119122
_print_supported_values : bool, default: True
120123
Whether to print *values* when raising ValueError.
121124
**kwargs : dict
@@ -133,7 +136,18 @@ def check_in_list(values, /, *, _print_supported_values=True, **kwargs):
133136
if not kwargs:
134137
raise TypeError("No argument to check!")
135138
for key, val in kwargs.items():
136-
if val not in values:
139+
try:
140+
exists = val in values
141+
except ValueError:
142+
# `in` internally uses `val == values[i]`. There are some objects
143+
# that do not support == to arbitrary other objects, in particular
144+
# numpy arrays.
145+
# Since such objects are not allowed in values, we can gracefully
146+
# handle the case that val (typically provided by users) is of such
147+
# type and directly state it's not in the list instead of letting
148+
# the individual `val == values[i]` ValueError surface.
149+
exists = False
150+
if not exists:
137151
msg = f"{val!r} is not a valid value for {key}"
138152
if _print_supported_values:
139153
msg += f"; supported values are {', '.join(map(repr, values))}"

lib/matplotlib/tests/test_api.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,8 @@ def f() -> None:
150150
def test_empty_check_in_list() -> None:
151151
with pytest.raises(TypeError, match="No argument to check!"):
152152
_api.check_in_list(["a"])
153+
154+
155+
def test_check_in_list_numpy() -> None:
156+
with pytest.raises(ValueError, match=r"array\(5\) is not a valid value"):
157+
_api.check_in_list(['a', 'b'], value=np.array(5))

0 commit comments

Comments
 (0)