Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 11 additions & 19 deletions esutil/numpy_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,7 @@ def rem_dup(arr, flag, values=False):
return s


def match(arr1input, arr2input, presorted=False):
def match(arr1input, arr2input, presorted=False, isunique=False):
"""
Match two arrays, returning the indices of matches for each array, or
empty arrays if no matches are found. This means arr1[ind1] == arr2[ind2]
Expand All @@ -1529,6 +1529,8 @@ def match(arr1input, arr2input, presorted=False):
The second array.
presorted: bool, optional
If set to True, the first array is assumed to be sorted.
isunique: bool, optional
If set to True, the first array is assumed to contain only unique
values and the uniqueness check is skipped.

Returns
-------
Expand All @@ -1544,35 +1546,25 @@ def match(arr1input, arr2input, presorted=False):
arr1 = np.atleast_1d(arr1input)
arr2 = np.atleast_1d(arr2input)

el = arr1[0]

if isinstance(el, str) or isinstance(el, bytes):
is_string = True
else:
is_string = False

if (arr1.size == 0) or (arr2.size == 0):
mess = "Error: arr1 and arr2 must each be non-zero length"
raise ValueError(mess)

# make sure that arr1 has unique values...
test = np.unique(arr1)
if test.size != arr1.size:
raise ValueError("Error: the arr1input must be unique")
if not isunique:
# make sure that arr1 has unique values...
test = np.unique(arr1)
if test.size != arr1.size:
raise ValueError("Error: the arr1input must be unique")

# sort arr1 if not presorted
if not presorted:
st1 = np.argsort(arr1)
else:
st1 = None

# search the sorted array
sub1 = np.searchsorted(arr1, arr2, sorter=st1)

# check for out-of-bounds at the high end if necessary
if is_string or arr2.max() > arr1.max():
(bad,) = np.where(sub1 == arr1.size)
sub1[bad] = arr1.size - 1
# search the sorted array;
# clip to bounds.
sub1 = np.clip(np.searchsorted(arr1, arr2, sorter=st1), 0, arr1.size - 1)

if not presorted:
(sub2,) = np.where(arr1[st1[sub1]] == arr2)
Expand Down