Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
uses: conda-incubator/setup-miniconda@v2
with:
python-version: ${{ matrix.python-version }}
mamba-version: "*"
miniforge-variant: Mambaforge
channels: conda-forge,defaults
channel-priority: true
auto-activate-base: false
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ repos:
- id: double-quote-string-fixer

- repo: https://github.com/ambv/black
rev: 20.8b1
rev: 22.3.0 # https://github.com/psf/black/issues/2964
hooks:
- id: black
args: ["--line-length", "100", "--skip-string-normalization"]

- repo: https://gitlab.com/PyCQA/flake8
- repo: https://github.com/pycqa/flake8
rev: 3.8.4
hooks:
- id: flake8
Expand Down
6 changes: 5 additions & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Required
version: 2

build:
os: "ubuntu-20.04"
tools:
python: "mambaforge-4.10"

# Build documentation in the doc/ directory with Sphinx
sphinx:
configuration: doc/conf.py
Expand All @@ -9,7 +14,6 @@ conda:
environment: environment_doc.yml

python:
version: 3.8
install:
- method: pip
path: .
234 changes: 209 additions & 25 deletions doc/examples/introduction.ipynb

Large diffs are not rendered by default.

44 changes: 40 additions & 4 deletions src/xoak/accessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Hashable, Iterable, List, Mapping, Tuple, Type, Union
from typing import Any, Hashable, Iterable, List, Mapping, Optional, Tuple, Type, Union

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -134,12 +134,14 @@ def index(self) -> Union[None, Index, Iterable[Index]]:
return [wrp.index for wrp in index_wrappers]

def _query(self, indexers):
"""Find the distance(s) and indices of nearest point(s)."""
X = coords_to_point_array([indexers[c] for c in self._index_coords])

if isinstance(X, np.ndarray) and isinstance(self._index, XoakIndexWrapper):
# directly call index wrapper's query method
res = self._index.query(X)
results = res['indices'][:, 0]
distances = res['distances'][:, 0]

else:
# Two-stage lazy query with dask
Expand Down Expand Up @@ -195,7 +197,7 @@ def _query(self, indexers):
concatenate=True,
)

return results
return results, distances

def _get_pos_indexers(self, indices, indexers):
"""Returns positional indexers based on the query results and the
Expand Down Expand Up @@ -225,7 +227,10 @@ def _get_pos_indexers(self, indices, indexers):
return pos_indexers

def sel(
self, indexers: Mapping[Hashable, Any] = None, **indexers_kwargs: Any
self,
indexers: Mapping[Hashable, Any] = None,
distances_name: Optional[str] = None,
**indexers_kwargs: Any,
) -> Union[xr.Dataset, xr.DataArray]:
"""Selection based on a ball tree index.

Expand All @@ -243,14 +248,26 @@ def sel(
This triggers :func:`dask.compute` if the given indexers and/or the index
coordinates are chunked.

Parameters
----------
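distances_name: str, optional
If given, the distances to the nearest points are saved in the returned
xarray object as a variable with this name. Distances are reported as
returned by the underlying index, without unit conversion (e.g., radians
for `sklearn_geo_balltree`).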

Returns
-------
xr.Dataset, xr.DataArray
Normally, the returned object has the same type as the object being
indexed. However, if `distances_name` is given, a DataArray is converted
to a Dataset so that the distances can be stored as an additional variable.
"""
if not getattr(self, '_index', False):
raise ValueError(
'The index(es) has/have not been built yet. Call `.xoak.set_index()` first'
)

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'xoak.sel')
indices = self._query(indexers)
indices, distances = self._query(indexers)

if not isinstance(indices, np.ndarray):
# TODO: remove (see todo below)
Expand All @@ -263,4 +280,23 @@ def sel(
# This would also allow lazy selection
result = self._xarray_obj.isel(indexers=pos_indexers)

# save distances as a new variable in xarray object if name is input
if distances_name is not None or (
isinstance(distances_name, str) and len(distances_name) == 0
):
# need to have a Dataset instead of DataArray to add a new variable
# otherwise goes in as a coordinate
if not isinstance(result, xr.Dataset):
result = result.to_dataset()
# use same dimensions as indexers
attrs = {
'long_name': 'Distance from location to nearest comparison point.',
}

indexer_dim = list(indexers.values())[0].dims
indexer_shape = indexers[list(indexers.keys())[0]].shape
result[distances_name] = xr.Variable(
indexer_dim, distances.reshape(indexer_shape), attrs
)

return result
1 change: 1 addition & 0 deletions src/xoak/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dask
import dask.array
import numpy as np
import pytest
import xarray as xr
Expand Down
22 changes: 22 additions & 0 deletions src/xoak/tests/test_accessor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import xarray as xr
from scipy.spatial import cKDTree
Expand Down Expand Up @@ -80,3 +81,24 @@ def test_index_property():
ds_chunk = ds.chunk(2)
ds_chunk.xoak.set_index(['x', 'y'], 'scipy_kdtree')
assert isinstance(ds_chunk.xoak.index, list)


def test_distances():
    """Check that ``xoak.sel`` stores nearest-point distances when a name is given."""
    # Indexed points (0, 0), (1, 1), (2, 2), (3, 3), all along dimension 'a'.
    ds = xr.Dataset(
        coords={
            'x': ('a', [0, 1, 2, 3]),
            'y': ('a', [0, 1, 2, 3]),
        }
    )
    ds.xoak.set_index(['y', 'x'], 'sklearn_geo_balltree')

    # Two query points: the first coincides with the indexed point (0, 0),
    # the second is offset by 0.5 along 'x' only.
    ds_to_find = xr.Dataset(
        {
            'lat_to_find': ('a', [0, 0]),
            'lon_to_find': ('a', [0, 0.5]),
        }
    )
    query = {'y': ds_to_find.lat_to_find, 'x': ds_to_find.lon_to_find}

    output = ds.xoak.sel(query, distances_name='distances')

    # Asking for distances always yields a Dataset.
    assert isinstance(output, xr.Dataset)

    # this distance is in radians
    expected = [0, 0.008726646259971648]
    np.testing.assert_allclose(output['distances'], expected)