Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions xarray_array_testing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ class DuckArrayTestMixin(ABC):
def xp() -> ModuleType:
pass

@property
@abc.abstractmethod
def array_type(self) -> type[duckarray]:
@staticmethod
def array_type(op: str) -> type[duckarray]:
pass

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion xarray_array_testing/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ class CreationTests(DuckArrayTestMixin):
def test_create_variable(self, data):
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))

assert isinstance(variable.data, self.array_type)
assert isinstance(variable.data, self.array_type("__init__"))
8 changes: 6 additions & 2 deletions xarray_array_testing/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def test_variable_isel_orthogonal(self, data):
raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}
expected = variable.data[*raw_indexers.values()]

assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
assert isinstance(
actual, self.array_type("orthogonal_indexing")
), f"wrong type: {type(actual)}"
self.assert_equal(actual, expected)

@given(st.data())
Expand All @@ -109,5 +111,7 @@ def test_variable_isel_vectorized(self, data):
raw_indexers = {dim: idx.get(dim, slice(None)) for dim in variable.dims}
expected = variable.data[*raw_indexers.values()]

assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
assert isinstance(
actual, self.array_type("vectorized_indexing")
), f"wrong type: {type(actual)}"
self.assert_equal(actual, expected)
8 changes: 4 additions & 4 deletions xarray_array_testing/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_variable_numerical_reduce(self, op, data):
# compute using xp.<OP>(array)
expected = getattr(self.xp, op)(variable.data)

assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
assert isinstance(actual, self.array_type(op)), f"wrong type: {type(actual)}"
self.assert_equal(actual, expected)

@pytest.mark.parametrize("op", ["all", "any"])
Expand All @@ -39,7 +39,7 @@ def test_variable_boolean_reduce(self, op, data):
# compute using xp.<OP>(array)
expected = getattr(self.xp, op)(variable.data)

assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
assert isinstance(actual, self.array_type(op)), f"wrong type: {type(actual)}"
self.assert_equal(actual, expected)

@pytest.mark.parametrize("op", ["max", "min"])
Expand All @@ -53,7 +53,7 @@ def test_variable_order_reduce(self, op, data):
# compute using xp.<OP>(array)
expected = getattr(self.xp, op)(variable.data)

assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
assert isinstance(actual, self.array_type(op)), f"wrong type: {type(actual)}"
self.assert_equal(actual, expected)

@pytest.mark.parametrize("op", ["argmax", "argmin"])
Expand Down Expand Up @@ -96,5 +96,5 @@ def test_variable_cumulative_reduce(self, op, data):
for axis in range(variable.ndim):
expected = getattr(self.xp, array_api_names[op])(expected, axis=axis)

assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
assert isinstance(actual, self.array_type(op)), f"wrong type: {type(actual)}"
self.assert_equal(actual, expected)
4 changes: 2 additions & 2 deletions xarray_array_testing/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class NumpyTestMixin(DuckArrayTestMixin):
def xp(self) -> ModuleType:
return np

@property
def array_type(self) -> type[np.ndarray]:
@staticmethod
def array_type(op: str) -> type[np.ndarray]:
return np.ndarray

@staticmethod
Expand Down
Loading