diff --git a/esutil/numpy_util.py b/esutil/numpy_util.py index e5e3cbf..56c8ec1 100644 --- a/esutil/numpy_util.py +++ b/esutil/numpy_util.py @@ -1507,7 +1507,7 @@ def rem_dup(arr, flag, values=False): return s -def match(arr1input, arr2input, presorted=False): +def match(arr1input, arr2input, presorted=False, isunique=False): """ Match two arrays, returning the indicies of matches for each array, or empty arrays if no matches are found. This means arr1[ind1] == arr2[ind2] @@ -1529,6 +1529,8 @@ def match(arr1input, arr2input, presorted=False): The second array. presorted: bool, optional If set to True, the first array is assumed to be sorted. + isunique: bool, optional + If set to True, the first array is assumed to be unique (check is skipped). Returns ------- @@ -1544,21 +1546,15 @@ def match(arr1input, arr2input, presorted=False): arr1 = np.atleast_1d(arr1input) arr2 = np.atleast_1d(arr2input) - el = arr1[0] - - if isinstance(el, str) or isinstance(el, bytes): - is_string = True - else: - is_string = False - if (arr1.size == 0) or (arr2.size == 0): mess = "Error: arr1 and arr2 must each be non-zero length" raise ValueError(mess) - # make sure that arr1 has unique values... - test = np.unique(arr1) - if test.size != arr1.size: - raise ValueError("Error: the arr1input must be unique") + if not isunique: + # make sure that arr1 has unique values... + test = np.unique(arr1) + if test.size != arr1.size: + raise ValueError("Error: the arr1input must be unique") # sort arr1 if not presorted if not presorted: @@ -1566,13 +1562,9 @@ def match(arr1input, arr2input, presorted=False): else: st1 = None - # search the sorted array - sub1 = np.searchsorted(arr1, arr2, sorter=st1) - - # check for out-of-bounds at the high end if necessary - if is_string or arr2.max() > arr1.max(): - (bad,) = np.where(sub1 == arr1.size) - sub1[bad] = arr1.size - 1 + # search the sorted array; + # clip to bounds. 
+ sub1 = np.clip(np.searchsorted(arr1, arr2, sorter=st1), 0, arr1.size - 1) if not presorted: (sub2,) = np.where(arr1[st1[sub1]] == arr2)