def find_sample_from_2dscan(I_arr, xy_arr, Q_arr=None,
                            eps=0.05, min_samples=20, n_jobs=1,
                            b_ratio_thres=0.5, qrange=(1, 5),
                            use_unclassified=True):
    """Find sample positions from an xy-scan.

    The intensity patterns are clustered with DBSCAN, using
    ``1 - spearmanr(i, j)[0]`` as the distance between two patterns.
    Clusters holding more than ``b_ratio_thres`` of all scan points are
    treated as background; the scan positions of every other cluster are
    returned as the sample positions, together with their mean.

    See the DBSCAN documentation for ``eps``, ``min_samples`` and ``n_jobs``:
    http://scikit-learn.org/stable/modules/generated/sklearn.cluster.DBSCAN.html

    Parameters
    ----------
    I_arr : array-like, 2-D
        Intensities for each scan point, one row per scan point.
    xy_arr : array-like
        x, y of the scan points, one entry per row of ``I_arr``.
    Q_arr : array-like, optional
        Q-points of the intensity columns. When given, only the columns
        with ``qrange[0] < Q < qrange[1]`` are used for clustering. When
        ``None``, all columns are used and a message is printed.
    eps : float, optional
        The maximum distance between two samples for them to be considered
        as in the same neighborhood.
    min_samples : int, optional
        The number of samples (or total weight) in a neighborhood for a
        point to be considered as a core point. This includes the point
        itself.
    n_jobs : int, optional
        The number of parallel jobs to run. None means 1 unless in a
        joblib.parallel_backend context. -1 means using all processors.
    b_ratio_thres : float, optional
        Clusters whose share of the scan points is more than this threshold
        are considered background (not belonging to the sample).
    qrange : tuple of float, optional
        Exclusive lower and upper Q bounds used when ``Q_arr`` is given.
    use_unclassified : bool, optional
        Sometimes DBSCAN is unable to classify points around the sample
        boundary. In that case it labels them -1. If this keyword is True,
        those points are included in the sample positions (``pts``).

    Returns
    -------
    center : ndarray
        xy coordinates of the center of the sample (mean of ``pts``).
        Filled with NaN when no point was selected.
    pts : ndarray
        xy coordinates of points considered within the sample, grouped
        cluster by cluster in ascending label order.

    Raises
    ------
    ValueError
        If ``I_arr`` is not 2-D or if ``I_arr`` and ``xy_arr`` do not hold
        the same number of scan points.
    """
    intensities = np.asarray(I_arr)
    xy = np.asarray(xy_arr)

    # Validate before clustering: the callable metric makes DBSCAN slow, so
    # a shape mistake should be reported up front rather than afterwards.
    if intensities.ndim != 2:
        raise ValueError(
            'I_arr must be 2-D (one row per scan point), got {} '
            'dimension(s)'.format(intensities.ndim))
    if xy.ndim < 1 or xy.shape[0] != intensities.shape[0]:
        raise ValueError(
            'I_arr and xy_arr must have the same number of scan points, '
            'got {} and {}'.format(intensities.shape[0],
                                   xy.shape[0] if xy.ndim else 0))

    if Q_arr is not None:
        # Trim to selected Q range. Because we do not want mess around
        # beam stopper and high q. This also speedups DBSCAN calculation
        q = np.asarray(Q_arr)
        in_range = (q > qrange[0]) & (q < qrange[1])
        intensities = intensities[:, in_range]
    else:
        print('Q array is not provided. Using all points')

    # Use DBSCAN package to cluster the intensity patterns.
    dbs = DBSCAN(eps=eps, min_samples=min_samples,
                 metric=lambda i, j: 1 - spearmanr(i, j)[0],
                 n_jobs=n_jobs)
    labels = dbs.fit_predict(intensities)
    uniques, counts = np.unique(labels, return_counts=True)
    # Fraction of all scan points that fell into each label.
    ratios = counts.astype(float) / labels.size

    # Collect x,y data for the labels that should correspond to the sample.
    selected = []
    for label, ratio in zip(uniques, ratios):
        if label == -1:
            # -1 is DBSCAN's "unclassified" label, not a real cluster, so
            # the background threshold is not applied to it.
            keep = use_unclassified
        else:
            keep = ratio <= b_ratio_thres
        if keep:
            selected.append(xy[labels == label])

    if selected:
        pts = np.concatenate(selected)
        center = np.mean(pts, axis=0)
    else:
        # Nothing survived the selection; keep the coordinate shape so
        # callers can still index the result, and mark the center as NaN.
        pts = np.empty((0,) + xy.shape[1:], dtype=xy.dtype)
        center = np.full(xy.shape[1:], np.nan)

    # TODO: Get rid of s_ratio

    return center, pts