diff --git a/sidpy/sid/__init__.py b/sidpy/sid/__init__.py index 970f8e6e..abc487dc 100644 --- a/sidpy/sid/__init__.py +++ b/sidpy/sid/__init__.py @@ -4,8 +4,9 @@ from .dimension import Dimension, DimensionType from .translator import Translator -from .dataset import Dataset, DataType, convert_hyperspy +from .dataset import Dataset, convert_hyperspy from .reader import Reader +from .datatype import DataType __all__ = ['Dimension', 'DimensionType', 'Dataset', 'DataType', 'Reader', 'Translator', 'convert_hyperspy'] diff --git a/sidpy/sid/dataset.py b/sidpy/sid/dataset.py index 301b0685..1cb07dde 100644 --- a/sidpy/sid/dataset.py +++ b/sidpy/sid/dataset.py @@ -38,7 +38,7 @@ from ..base.dict_utils import print_nested_dict from ..viz.dataset_viz import CurveVisualizer, ImageVisualizer, ImageStackVisualizer from ..viz.dataset_viz import SpectralImageVisualizer, FourDimImageVisualizer, ComplexSpectralImageVisualizer -from ..viz.dataset_viz import PointCloudVisualizer, DictionaryVisualizer +from ..viz.dataset_viz import PointCloudVisualizer, DictionaryVisualizer, DP_PointCloudVisualizer # from ..hdf.hdf_utils import is_editable_h5 from .dimension import DimensionType from copy import deepcopy, copy @@ -46,23 +46,15 @@ import logging from ..__version__ import version +from .datatype import DataType + def is_simple_list(lst): if isinstance(lst, list): return any(hasattr(item, '__getitem__') for item in lst) return False -class DataType(Enum): - UNKNOWN = -1 - SPECTRUM = 1 - LINE_PLOT = 2 - LINE_PLOT_FAMILY = 3 - IMAGE = 4 - IMAGE_MAP = 5 - IMAGE_STACK = 6 # 3d - SPECTRAL_IMAGE = 7 - IMAGE_4D = 8 - POINT_CLOUD = 9 + def view_subclass(dask_array, cls): @@ -281,7 +273,7 @@ def from_array(cls, x, title='generic', chunks='auto', lock=False, sid_dataset._axes = {} - if sid_dataset.data_type == DataType.POINT_CLOUD and coordinates is None: + if sid_dataset.data_type in [DataType.POINT_CLOUD, DataType.DP_POINT_CLOUD] and coordinates is None: raise ValueError("coordinates must be specified for a point cloud dataset") for dim in range(sid_dataset.ndim): @@ -293,6 +285,13 @@ def from_array(cls, x, title='generic', chunks='auto', lock=False, 2: 'channel' } dimension_type = dimension_map.get(dim, None) + elif datatype == "dp_point_cloud": + dimension_map = { + 0: 'point_cloud', + 1: 'reciprocal', + 2: 'reciprocal' + } + dimension_type = dimension_map.get(dim, None) else: dimension_type = 'unknown' sid_dataset.set_dimension(dim, Dimension(np.arange(sid_dataset.shape[dim]), @@ -309,6 +308,8 @@ def from_array(cls, x, title='generic', chunks='auto', lock=False, sid_dataset.point_cloud = {'coordinates': coordinates} else: sid_dataset.point_cloud = None + + return sid_dataset def like_data(self, data, title=None, chunks='auto', lock=False, @@ -667,7 +668,7 @@ def plot(self, verbose=False, figure=None, dict_data=None, **kwargs): self.view.fig: matplotlib figure reference """ - + if verbose: print('Shape of dataset is: ', self.shape) @@ -692,8 +693,9 @@ def plot(self, verbose=False, figure=None, dict_data=None, **kwargs): elif self.data_type.value <= DataType['LINE_PLOT'].value: # self.data_type in ['spectrum_family', 'line_family', 'line_plot_family', 'spectra']: self.view = CurveVisualizer(self, figure=figure, **kwargs) - elif self.data_type == DataType.POINT_CLOUD: + elif self.data_type in [DataType.POINT_CLOUD, DataType.DP_POINT_CLOUD]: self.view = PointCloudVisualizer(self, figure=figure, **kwargs) + elif self.data_type == DataType.SPECTRAL_IMAGE: print('sp') self.view = SpectralImageVisualizer(self, figure=figure, **kwargs) @@ -721,6 +723,10 @@ def plot(self, verbose=False, figure=None, dict_data=None, **kwargs): self.view = SpectralImageVisualizer(self, figure=figure, **kwargs) elif self.data_type == DataType.POINT_CLOUD: self.view = PointCloudVisualizer(self, figure=figure, **kwargs) + elif self.data_type == DataType.DP_POINT_CLOUD: + + self.view = DP_PointCloudVisualizer(self, figure=figure, **kwargs) + else: raise NotImplementedError('Datasets with data_type {} cannot be plotted, yet.'.format(self.data_type)) elif len(self.shape) == 4: diff --git a/sidpy/sid/datatype.py b/sidpy/sid/datatype.py new file mode 100644 index 00000000..d5931f09 --- /dev/null +++ b/sidpy/sid/datatype.py @@ -0,0 +1,13 @@ +from enum import Enum +class DataType(Enum): + UNKNOWN = -1 + SPECTRUM = 1 + LINE_PLOT = 2 + LINE_PLOT_FAMILY = 3 + IMAGE = 4 + IMAGE_MAP = 5 + IMAGE_STACK = 6 # 3d + SPECTRAL_IMAGE = 7 + IMAGE_4D = 8 + POINT_CLOUD = 9 + DP_POINT_CLOUD = 10 \ No newline at end of file diff --git a/sidpy/viz/dataset_viz.py b/sidpy/viz/dataset_viz.py index 28d8bb79..b8a5bc4a 100644 --- a/sidpy/viz/dataset_viz.py +++ b/sidpy/viz/dataset_viz.py @@ -27,6 +27,9 @@ import dill import base64 import dask.array as da +from ..sid.datatype import DataType + +# from ..sid.dataset import DataType # import matplotlib.animation as animation @@ -1088,7 +1091,7 @@ def update_image(self, event_value): class PointCloudVisualizerBase(object): def __init__(self, dset, base_image=None, figure=None, horizontal=True, **kwargs): - if not isinstance(dset, sidpy.Dataset): + if not isinstance(dset, sidpy.sid.Dataset): raise TypeError('dset should be a sidpy.Dataset object') self.dset = dset @@ -1097,10 +1100,19 @@ def __init__(self, dset, base_image=None, figure=None, horizontal=True, **kwargs raise ValueError('Variance array must have the same dimensionality as the dataset') #kwargs parsing + self.data_type = None + if dset.data_type == DataType.POINT_CLOUD: + self.data_type = 'spectrum' + elif dset.data_type == DataType.DP_POINT_CLOUD: + self.data_type = 'dp' + if self.data_type not in ['spectrum', 'dp']: + raise ValueError("data_type should be 'spectrum' or 'dp'") + + temp = kwargs.pop('figsize', None) + self.set_title = kwargs.pop('set_title', True) scale_bar = kwargs.pop('scale_bar', False) amp_phase = kwargs.pop('amp_phase', False) - self.set_title = kwargs.pop('set_title', True) - temp = kwargs.pop('figsize', None) + self.verify_dataset() @@ -1118,10 +1130,13 @@ def __init__(self, dset, base_image=None, figure=None, horizontal=True, **kwargs if base_image is not None: self.image, self.px_coord = self._base_image(base_image) else: - if len(self.channel_dim) > 0: - self.cloud = dset.mean(axis=(self.spectral_dim[0], self.channel_dim[0])) - else: - self.cloud = dset.mean(axis=(self.spectral_dim[0],)) + if self.data_type == 'spectrum': + if len(self.channel_dim) > 0: + self.cloud = dset.mean(axis=(self.spectral_dim[0], self.channel_dim[0])) + else: + self.cloud = dset.mean(axis=(self.spectral_dim[0],)) + elif self.data_type == 'dp': + self.cloud = dset.mean(axis=(self.reciprocal_dim[0], self.reciprocal_dim[1])) self.image, self.px_coord = self._mask_image() if self.dset.dtype == 'complex': @@ -1168,27 +1183,33 @@ def __init__(self, dset, base_image=None, figure=None, horizontal=True, **kwargs self._scale_bar() self.spectrum, self.variance = self.get_spectrum(_point_number) - self.energy_axis = self.spectral_dim[0] - if len(self.channel_dim)>0: self.channel_axis = self.channel_dim - self.energy_scale = self.dset._axes[self.energy_axis].values - #spectrum self.spectrum_plot = [] - if self.ri_ap is not None: - if self.ri_ap == 'Real and Imaginary': - _complex = ['real', 'imag'] - elif self.ri_ap == 'Amplitude and Phase': - _complex = ['amp', 'phase'] - self.spectrum_plot, self.fill_between = self.set_spectrum(_point_number, complex=_complex[0], axis=self.axes[1], **kwargs) - #add colorbar for multi chanell figures - if len(self.spectrum_plot) > 1: - self.add_colorbar(self.spectrum_plot) - _spectrum_plots, _fill_between = self.set_spectrum(_point_number, complex=_complex[1], axis=self.axes[2], **kwargs) - for sp in _spectrum_plots: self.spectrum_plot.append(sp) - for fb in _fill_between: self.fill_between.append(fb) - else: + if self.data_type == 'spectrum': + self.energy_axis = self.spectral_dim[0] + if len(self.channel_dim)>0: self.channel_axis = self.channel_dim + self.energy_scale = self.dset._axes[self.energy_axis].values + #spectrum + + if self.ri_ap is not None: + if self.ri_ap == 'Real and Imaginary': + _complex = ['real', 'imag'] + elif self.ri_ap == 'Amplitude and Phase': + _complex = ['amp', 'phase'] + self.spectrum_plot, self.fill_between = self.set_spectrum(_point_number, complex=_complex[0], axis=self.axes[1], **kwargs) + #add colorbar for multi chanell figures + if len(self.spectrum_plot) > 1: + self.add_colorbar(self.spectrum_plot) + _spectrum_plots, _fill_between = self.set_spectrum(_point_number, complex=_complex[1], axis=self.axes[2], **kwargs) + for sp in _spectrum_plots: self.spectrum_plot.append(sp) + for fb in _fill_between: self.fill_between.append(fb) + else: + self.spectrum_plot, self.fill_between = self.set_spectrum(_point_number, **kwargs) + if len(self.spectrum_plot) > 1: + self.add_colorbar(self.spectrum_plot) + + if self.data_type == 'dp': self.spectrum_plot, self.fill_between = self.set_spectrum(_point_number, **kwargs) - if len(self.spectrum_plot) > 1: - self.add_colorbar(self.spectrum_plot) + if not isinstance(self.fig, SubFigure): # for sparse array vis #TODO self.fig.tight_layout() @@ -1220,6 +1241,7 @@ def verify_dataset(self): point_dims = [] spectral_dim = [] channel_dim = [] + reciprocal_dim = [] for dim, axis in dset._axes.items(): if axis.dimension_type == sidpy.DimensionType.POINT_CLOUD: @@ -1230,27 +1252,38 @@ def verify_dataset(self): spectral_dim.append(dim) elif axis.dimension_type == sidpy.DimensionType.CHANNEL: channel_dim.append(dim) + elif axis.dimension_type == sidpy.DimensionType.RECIPROCAL: + selection.append(slice(None)) # what does this do? + reciprocal_dim.append(dim) + else: selection.append(slice(0, 1)) # checking dimension types - if len(channel_dim) > 1: - raise ValueError("We have more than one Channel Dimension, this won't work for the visualizer") - if len(spectral_dim) > 1: - raise ValueError("We have more than one Spectral Dimension, this won't work for the visualizer...") - if len(dset.shape) == 3: - if len(channel_dim) != 1: - raise TypeError("We need one dimension with type CHANNEL \ - for a spectral image plot for a 4D dataset") - elif len(dset.shape) == 2: - if len(spectral_dim) != 1: - raise TypeError("We need one dimension with dimension_type SPECTRAL \ - to plot a spectra for a 3D dataset") + if self.data_type == 'spectrum': + if len(channel_dim) > 1: + raise ValueError("We have more than one Channel Dimension, this won't work for the visualizer") + if len(spectral_dim) > 1: + raise ValueError("We have more than one Spectral Dimension, this won't work for the visualizer...") + if len(dset.shape) == 3: + if len(channel_dim) != 1: + raise TypeError("We need one dimension with type CHANNEL \ + for a spectral image plot for a 4D dataset") + elif len(dset.shape) == 2: + if len(spectral_dim) != 1: + raise TypeError("We need one dimension with dimension_type SPECTRAL \ + to plot a spectra for a 3D dataset") + + if self.data_type == 'dp': + if len(reciprocal_dim) != 2: + raise ValueError("We need two dimensions with dimension_type RECIPROCAL to plot a diffraction pattern") + self.selection = selection self.point_dims = point_dims self.spectral_dim = spectral_dim self.channel_dim = channel_dim + self.reciprocal_dim = reciprocal_dim return True def set_image(self, quantity, **kwargs): @@ -1295,50 +1328,59 @@ def set_spectrum(self, point_number, complex = None, axis = None, **kwargs): axis = self.axes[1] spectrum_plot = [] # list is required for the case of several channels - if len(self.spectrum.shape) > 1: - for i in range(len(self.spectrum)): - _spectrum_plot, = axis.plot(self.energy_scale, _spectrums[i]) - spectrum_plot.append(_spectrum_plot) - else: - _spectrum_plot, = axis.plot(self.energy_scale, _spectrums) - spectrum_plot.append(_spectrum_plot) + print('visualizer data_type', self.data_type) + + if self.data_type == 'spectrum': + if len(self.spectrum.shape) > 1: + for i in range(len(self.spectrum)): + _spectrum_plot, = axis.plot(self.energy_scale, _spectrums[i]) + spectrum_plot.append(_spectrum_plot) + else: + _spectrum_plot, = axis.plot(self.energy_scale, _spectrums) + elif self.data_type == 'dp': + _spectrum_plot = axis.imshow(_spectrums) + spectrum_plot.append(_spectrum_plot) fill_between = [] - if self.variance is not None: - if complex == 'real': - _variance = self.variance.real - elif complex == 'imag': - _variance = self.variance.imag - elif complex == 'amp': - _variance = da.abs(self.variance) - elif complex == 'phase': - _variance = da.angle(self.variance) - else: - _variance = self.variance + + if self.data_type == 'spectrum': + if self.variance is not None: + if complex == 'real': + _variance = self.variance.real + elif complex == 'imag': + _variance = self.variance.imag + elif complex == 'amp': + _variance = da.abs(self.variance) + elif complex == 'phase': + _variance = da.angle(self.variance) + else: + _variance = self.variance - # 3d - many curves - if len(self.variance.shape) > 1: - for i in range(len(self.variance)): + # 3d - many curves + if len(self.variance.shape) > 1: + for i in range(len(self.variance)): + _fill_between = axis.fill_between(self.energy_scale, + _spectrums[i] - _variance[i], + _spectrums[i] + _variance[i], + alpha=0.3, **kwargs) + fill_between.append(_fill_between) + # 2d - one curve at each point + else: _fill_between = axis.fill_between(self.energy_scale, - _spectrums[i] - _variance[i], - _spectrums[i] + _variance[i], - alpha=0.3, **kwargs) + _spectrums - _variance, + _spectrums + _variance, + alpha=0.3, **kwargs) fill_between.append(_fill_between) - # 2d - one curve at each point - else: - _fill_between = axis.fill_between(self.energy_scale, - _spectrums - _variance, - _spectrums + _variance, - alpha=0.3, **kwargs) - fill_between.append(_fill_between) if complex is None: axis.set_title('Point {}'.format(point_number)) else: axis.set_title('Point {}, {}'.format(point_number, complex)) - axis.set_xlabel(self.dset.labels[self.energy_axis]) - axis.set_ylabel(self.dset.data_descriptor) - axis.ticklabel_format(style='sci', scilimits=(-2, 3)) + + if self.data_type == 'spectrum': + axis.set_xlabel(self.dset.labels[self.energy_axis]) + axis.set_ylabel(self.dset.data_descriptor) + axis.ticklabel_format(style='sci', scilimits=(-2, 3)) return spectrum_plot, fill_between def get_spectrum(self, point_number): @@ -1360,6 +1402,8 @@ def get_spectrum(self, point_number): selection.append(slice(None)) elif axis.dimension_type == sidpy.DimensionType.CHANNEL: selection.append(slice(None)) + elif axis.dimension_type == sidpy.DimensionType.RECIPROCAL: + selection.append(slice(None)) else: selection.append(slice(0, 1)) @@ -1464,32 +1508,34 @@ def _onclick(self, event): #self.axes[1].set_title('Point {}'.format(_point_number)) - if self.variance is not None: - # 3d - many curves - if len(self.variance.shape) > 1: - for i in range(len(self.variance)): - _c = self.fill_between[i].get_facecolor()[0] - self.fill_between[i].remove() - self.fill_between[i] = self.axes[1].fill_between(self.energy_scale, - self.spectrum[i] - self.variance[i], - self.spectrum[i] + self.variance[i], - color= _c) - else: - _c = self.fill_between[0].get_facecolor()[0] - self.fill_between[0].remove() - self.fill_between[0] = self.axes[1].fill_between(self.energy_scale, - self.spectrum - self.variance, - self.spectrum + self.variance, - color=_c) + if self.data_type == 'spectrum': + if self.variance is not None: + # 3d - many curves + if len(self.variance.shape) > 1: + for i in range(len(self.variance)): + _c = self.fill_between[i].get_facecolor()[0] + self.fill_between[i].remove() + self.fill_between[i] = self.axes[1].fill_between(self.energy_scale, + self.spectrum[i] - self.variance[i], + self.spectrum[i] + self.variance[i], + color= _c) + else: + _c = self.fill_between[0].get_facecolor()[0] + self.fill_between[0].remove() + self.fill_between[0] = self.axes[1].fill_between(self.energy_scale, + self.spectrum - self.variance, + self.spectrum + self.variance, + color=_c) - self.axes[1].set_title('point {}'.format(_point_number)) + - _sp_min, _sp_max = np.min(self.spectrum.compute()), np.max(self.spectrum.compute()) - if self.variance is not None: - _sp_min, _sp_max = _sp_min - np.max(self.variance.compute()), _sp_max + np.max(self.variance.compute()) - _sp_d = _sp_max - _sp_min - self.axes[1].set_ylim(_sp_min-0.2*_sp_d, _sp_max+0.2*_sp_d) + _sp_min, _sp_max = np.min(self.spectrum.compute()), np.max(self.spectrum.compute()) + if self.variance is not None: + _sp_min, _sp_max = _sp_min - np.max(self.variance.compute()), _sp_max + np.max(self.variance.compute()) + _sp_d = _sp_max - _sp_min + self.axes[1].set_ylim(_sp_min-0.2*_sp_d, _sp_max+0.2*_sp_d) + self.axes[1].set_title('point {}'.format(_point_number)) self.sel_point.set_offsets(np.column_stack((self.px_coord[_point_number, 0], self.px_coord[_point_number, 1]))) @@ -1512,79 +1558,84 @@ def _onclick(self, event): def _update_spectrum(self): _point_number = self.tree.query(np.array([self.x, self.y]))[1] self.spectrum, self.variance = self.get_spectrum(_point_number) - if self.ri_ap == 'Real and Imaginary': - _s = self.real_imag(self.spectrum) - _v = self.real_imag(self.variance) - self.axes[1].set_title('Point {}, real'.format(_point_number)) - self.axes[2].set_title('Point {}, imag'.format(_point_number)) - elif self.ri_ap == 'Amplitude and Phase': - _s = self.amp_ph(self.spectrum) - _v = self.amp_ph(self.variance) - self.axes[1].set_title('Point {}, amp'.format(_point_number)) - self.axes[2].set_title('Point {}, phase'.format(_point_number)) - else: - _s = self.spectrum - _v = self.variance - self.axes[1].set_title('Point {}'.format(_point_number)) - - self.real = self.real_imag(self.spectrum) - if len(self.spectrum.shape) > 1: - if self.ri_ap is None: - for i in range(len(self.spectrum)): - self.spectrum_plot[i].set_data(self.energy_scale, _s.compute()[i]) - else: - for i in range(len(self.spectrum)): - k = i + len(self.spectrum) - self.spectrum_plot[i].set_data(self.energy_scale, _s[0].compute()[i]) - self.spectrum_plot[k].set_data(self.energy_scale, _s[1].compute()[i]) - else: - if self.ri_ap is None: - self.spectrum_plot[0].set_data(self.energy_scale, _s.compute()) + + if self.data_type == 'spectrum': + if self.ri_ap == 'Real and Imaginary': + _s = self.real_imag(self.spectrum) + _v = self.real_imag(self.variance) + self.axes[1].set_title('Point {}, real'.format(_point_number)) + self.axes[2].set_title('Point {}, imag'.format(_point_number)) + elif self.ri_ap == 'Amplitude and Phase': + _s = self.amp_ph(self.spectrum) + _v = self.amp_ph(self.variance) + self.axes[1].set_title('Point {}, amp'.format(_point_number)) + self.axes[2].set_title('Point {}, phase'.format(_point_number)) else: - self.spectrum_plot[0].set_data(self.energy_scale, _s[0].compute()) - self.spectrum_plot[1].set_data(self.energy_scale, _s[1].compute()) + _s = self.spectrum + _v = self.variance + self.axes[1].set_title('Point {}'.format(_point_number)) - if self.variance is not None: - # 3d - many curves - if len(self.variance.shape) > 1: + self.real = self.real_imag(self.spectrum) + if len(self.spectrum.shape) > 1: if self.ri_ap is None: - for i in range(len(self.variance)): - _c = self.fill_between[i].get_facecolor()[0] - self.fill_between[i].remove() - self.fill_between[i] = self.axes[1].fill_between(self.energy_scale, - _s[i] - _v[i], - _s[i] + _v[i], - color=_c) + for i in range(len(self.spectrum)): + self.spectrum_plot[i].set_data(self.energy_scale, _s.compute()[i]) else: - for i in range(len(self.variance)): - k = i + len(self.variance) - _c = self.fill_between[i].get_facecolor()[0] - self.fill_between[i].remove() - self.fill_between[k].remove() - self.fill_between[i] = self.axes[1].fill_between(self.energy_scale, - _s[0].compute()[i] - _v[0].compute()[i], - _s[0].compute()[i] + _v[0].compute()[i], - color=_c) - self.fill_between[k] = self.axes[2].fill_between(self.energy_scale, - _s[1].compute()[i] - _v[1].compute()[i], - _s[1].compute()[i] + _v[1].compute()[i], - color=_c) + for i in range(len(self.spectrum)): + k = i + len(self.spectrum) + self.spectrum_plot[i].set_data(self.energy_scale, _s[0].compute()[i]) + self.spectrum_plot[k].set_data(self.energy_scale, _s[1].compute()[i]) else: if self.ri_ap is None: - _c = self.fill_between[0].get_facecolor()[0] - self.fill_between[0].remove() - self.fill_between[0] = self.axes[1].fill_between(self.energy_scale, - _s - _v, - _s + _v, - color=_c) + self.spectrum_plot[0].set_data(self.energy_scale, _s.compute()) + else: + self.spectrum_plot[0].set_data(self.energy_scale, _s[0].compute()) + self.spectrum_plot[1].set_data(self.energy_scale, _s[1].compute()) + + if self.variance is not None: + # 3d - many curves + if len(self.variance.shape) > 1: + if self.ri_ap is None: + for i in range(len(self.variance)): + _c = self.fill_between[i].get_facecolor()[0] + self.fill_between[i].remove() + self.fill_between[i] = self.axes[1].fill_between(self.energy_scale, + _s[i] - _v[i], + _s[i] + _v[i], + color=_c) + else: + for i in range(len(self.variance)): + k = i + len(self.variance) + _c = self.fill_between[i].get_facecolor()[0] + self.fill_between[i].remove() + self.fill_between[k].remove() + self.fill_between[i] = self.axes[1].fill_between(self.energy_scale, + _s[0].compute()[i] - _v[0].compute()[i], + _s[0].compute()[i] + _v[0].compute()[i], + color=_c) + self.fill_between[k] = self.axes[2].fill_between(self.energy_scale, + _s[1].compute()[i] - _v[1].compute()[i], + _s[1].compute()[i] + _v[1].compute()[i], + color=_c) else: - for i in range(len(self.fill_between)): - _c = self.fill_between[i].get_facecolor()[0] - self.fill_between[i].remove() - self.fill_between[i] = self.axes[i + 1].fill_between(self.energy_scale, - _s[i] - _v[i], - _s[i] + _v[i], - color=_c) + if self.ri_ap is None: + _c = self.fill_between[0].get_facecolor()[0] + self.fill_between[0].remove() + self.fill_between[0] = self.axes[1].fill_between(self.energy_scale, + _s - _v, + _s + _v, + color=_c) + else: + for i in range(len(self.fill_between)): + _c = self.fill_between[i].get_facecolor()[0] + self.fill_between[i].remove() + self.fill_between[i] = self.axes[i + 1].fill_between(self.energy_scale, + _s[i] - _v[i], + _s[i] + _v[i], + color=_c) + + elif self.data_type == 'dp': + self.spectrum_plot[0].set_data(self.spectrum.compute()) @staticmethod def real_imag(array): @@ -1733,6 +1784,99 @@ def _update_image(self, event_value): self.axes[0].set_xlabel('{}'.format(_quantity[0])) self.axes[0].set_ylabel('{}'.format(_quantity[1])) + + +class DP_PointCloudVisualizer(PointCloudVisualizerBase): + """ + Interactive DP point cloud visualization + """ + def __init__(self, dset, base_image=None, figure=None, horizontal=True, **kwargs): + super().__init__(dset, base_image, figure, horizontal, **kwargs) + + scale_bar = kwargs.pop('scale_bar', False) + amp_phase = kwargs.pop('amp_phase', False) + if amp_phase: + amp_phase_val = 'Amplitude and Phase' + else: + amp_phase_val = 'Real and Imaginary' + + + #from here + buttons = [] + if self.ri_ap is not None: + #self.button0 = ipywidgets.Dropdown(options=['Real and Imaginary', 'Amplitude and Phase'], + # value=amp_phase_val, + # tooltip='How to plot complex data') + #self.button0.observe(self._ri_ap, 'value') # real/imag or amp/phase + #buttons.append(self.button0) + + pass + + if scale_bar == False: + self.button = ipywidgets.widgets.Dropdown( options=[('Pixel Wise', 1), ('Units Wise', 2)], + value=1, + descrption='Image', + tooltip='How to plot spatial data: Pixel Wise (by px), Units wise (in given units)') + self.button.observe(self._pw_uw, 'value') #pixel or unit wise + buttons.append(self.button) + #self.fig.canvas.draw_idle() + + if len(buttons) > 0: + widg = ipywidgets.HBox(buttons) + display(widg) + + def _pw_uw(self, event): + pw_uw = event.new + self._update_image(pw_uw) + + def _ri_ap(self, event): + self.ri_ap = event.new + self._update_spectrum() + + def _update_image(self, event_value): + # pixel wise or unit wise listener + if 'spacial_units' in self.dset.point_cloud: + _sp_units = self.dset.point_cloud['spacial_units'] + if isinstance(_sp_units, str): + _sp_units = (_sp_units, _sp_units) + elif not (isinstance(_sp_units, list) or isinstance(_sp_units, tuple)): + raise ValueError('Spacial units in Dataset.point_cloud should be str or list, or tuple.') + + if 'quantity' in self.dset.point_cloud: + _quantity = self.dset.point_cloud['quantity'] + if isinstance(_quantity, str): + _quantity = (_quantity, _quantity) + elif not (isinstance(_quantity, list) or isinstance(_quantity, tuple)): + raise ValueError('Quantity in Dataset.point_cloud should be str or list, or tuple.') + else: + _quantity = ('distance', 'distance') + + if event_value == 1: + self.axes[0].set_xticks(np.linspace(self.extent[0], self.extent[1], 5),) + self.axes[0].set_xticklabels(np.round(np.linspace(self.extent[0], self.extent[1], 5), 1)) + + self.axes[0].set_yticks(np.linspace(self.extent[2], self.extent[3], 5),) + self.axes[0].set_yticklabels(np.round(np.linspace(self.extent[2], self.extent[3], 5), 1)) + + self.axes[0].set_xlabel('{} [{}]'.format(_quantity[0], 'px')) + self.axes[0].set_ylabel('{} [{}]'.format(_quantity[1], 'px')) + else: + self.axes[0].set_xticks(np.linspace(self.extent[0], self.extent[1], 5),) + self.axes[0].set_xticklabels(np.round(np.linspace(self.real_extent[0], self.real_extent[1], 5), 2)) + + self.axes[0].set_yticks(np.linspace(self.extent[2], self.extent[3], 5),) + self.axes[0].set_yticklabels(np.round(np.linspace(self.real_extent[2], self.real_extent[3], 5), 2)) + + if 'spacial_units' in self.dset.point_cloud: + self.axes[0].set_xlabel('{} [{}]'.format(_quantity[0], _sp_units[0])) + self.axes[0].set_ylabel('{} [{}]'.format(_quantity[1], _sp_units[1])) + else: + self.axes[0].set_xlabel('{}'.format(_quantity[0])) + self.axes[0].set_ylabel('{}'.format(_quantity[1])) + + + + class FourDimImageVisualizer(object): """