diff --git a/chaco/abstract_data_source.py b/chaco/abstract_data_source.py index 0cfb7c21d..07f8aa94e 100644 --- a/chaco/abstract_data_source.py +++ b/chaco/abstract_data_source.py @@ -57,7 +57,7 @@ class AbstractDataSource(HasTraits): # Abstract methods # ------------------------------------------------------------------------ - def get_data(self): + def get_data(self, lod=None): """get_data() -> data_array Returns a data array of the dimensions of the data source. This data @@ -66,6 +66,12 @@ def get_data(self): In the case of structured (gridded) 2-D data, this method may return two 1-D ArrayDataSources as an optimization. + + Parameters + ---------- + lod : int + Level of detail for data to retrieve. if None, then return the + orignal data without downsampling. """ raise NotImplementedError diff --git a/chaco/image_data.py b/chaco/image_data.py index dbf8842d0..edf2b5ad1 100644 --- a/chaco/image_data.py +++ b/chaco/image_data.py @@ -4,11 +4,12 @@ from numpy import fmax, fmin, swapaxes # Enthought library imports -from traits.api import Bool, Int, Property, ReadOnly, Tuple +from traits.api import Any, Bool, Int, Property, ReadOnly, Tuple, Unicode # Local relative imports -from .base import DimensionTrait, ImageTrait from .abstract_data_source import AbstractDataSource +from .base import DimensionTrait, ImageTrait + class ImageData(AbstractDataSource): @@ -62,6 +63,16 @@ class ImageData(AbstractDataSource): #: A read-only attribute that exposes the underlying array. raw_value = Property(ImageTrait) + #: Flag that data source support retrieving data with specified + #: level of details (LOD) + support_downsampling = Bool(False) + + #: An entry point to the LOD data which maps LOD to corresponding data + lod_data_entry = Any + + #: Key pattern for lod data stored in the **lod_data_entry** + lod_key_pattern = Unicode + # ------------------------------------------------------------------------ # Private traits # ------------------------------------------------------------------------ @@ -102,38 +113,61 @@ def fromfile(cls, filename): ) return imgdata - def get_width(self): - """Returns the shape of the x-axis.""" + def get_width(self, lod=None): + """ Returns the shape of the x-axis.""" + data = self.get_data(lod, transpose_inplace=False) if self.transposed: - return self._data.shape[0] + return data.shape[0] else: - return self._data.shape[1] + return data.shape[1] - def get_height(self): - """Returns the shape of the y-axis.""" + def get_height(self, lod=None): + """ Returns the shape of the y-axis.""" + data = self.get_data(lod, transpose_inplace=False) if self.transposed: - return self._data.shape[1] + return data.shape[1] else: - return self._data.shape[0] + return data.shape[0] - def get_array_bounds(self): - """Always returns ((0, width), (0, height)) for x-bounds and y-bounds.""" + def get_array_bounds(self, lod=None): + """ Always returns ((0, width), (0, height)) for x-bounds and y-bounds.""" + data = self.get_data(lod, transpose_inplace=False) if self.transposed: - b = ((0, self._data.shape[0]), (0, self._data.shape[1])) + b = ((0, data.shape[0]), (0, data.shape[1])) else: - b = ((0, self._data.shape[1]), (0, self._data.shape[0])) + b = ((0, data.shape[1]), (0, data.shape[0])) return b # ------------------------------------------------------------------------ # Datasource interface # ------------------------------------------------------------------------ - def get_data(self): - """Returns the data for this data source. + def get_data(self, lod=None, transpose_inplace=True): + """ Returns the data for this data source. Implements AbstractDataSource. + + Parameters + ---------- + lod : int + Level of detail for data to retrieve. If None, use the in-memory + `self._data` + transpose_inplace : bool + Whether to transpose the data before returning it when the raw data + stored is transposed. + + Returns + ------- + data : array-like + Requested image data """ - return self.data + if lod is None: + data = self._data + else: + data = self.get_lod_data(lod) + if self.transposed and transpose_inplace: + data = swapaxes(data, 0, 1) + return data def is_masked(self): """is_masked() -> False @@ -161,13 +195,15 @@ def get_bounds(self): self._bounds_cache_valid = True return self._cached_bounds - def get_size(self): + def get_size(self, lod=None): """get_size() -> int Implements AbstractDataSource. """ - if self._data is not None and self._data.shape[0] != 0: - return self._data.shape[0] * self._data.shape[1] + image = self.get_data(lod) + + if image is not None and image.shape[0] != 0: + return image.shape[0] * image.shape[1] else: return 0 @@ -181,6 +217,13 @@ def set_data(self, data): """ self._set_data(data) + def get_lod_data(self, lod): + if not self.lod_key_pattern: + key = str(lod) + else: + key = self.lod_key_pattern.format(lod) + return self.lod_data_entry[key] + # ------------------------------------------------------------------------ # Private methods # ------------------------------------------------------------------------ diff --git a/chaco/image_plot.py b/chaco/image_plot.py index 7131d1843..d294dc8ce 100644 --- a/chaco/image_plot.py +++ b/chaco/image_plot.py @@ -14,6 +14,7 @@ from math import ceil, floor, pi from contextlib import contextmanager +# Enthought library imports. import numpy as np # Enthought library imports. @@ -28,7 +29,9 @@ Tuple, Property, cached_property, + on_trait_change, ) +from traits_futures.api import CallFuture, TraitsExecutor from kiva.agg import GraphicsContextArray # Local relative imports @@ -71,6 +74,13 @@ class ImagePlot(Base2DPlot): #: Bool indicating whether y-axis is flipped. y_axis_is_flipped = Property(observe=["orientation", "origin"]) + #: Does the plot use downsampling? + use_downsampling = Bool(False) + + #: The Traits executor for the background jobs. + #: Required if **use_downsampling** is True. + traits_executor = Either(None, TraitsExecutor) + # ------------------------------------------------------------------------ # Private traits # ------------------------------------------------------------------------ @@ -89,6 +99,9 @@ class ImagePlot(Base2DPlot): # The name "principal diagonal" is borrowed from linear algebra. _origin_on_principal_diagonal = Property(observe="origin") + #: Submitted job. Only keeping track of the last submitted one. + _future = Instance(CallFuture) + # ------------------------------------------------------------------------ # Properties # ------------------------------------------------------------------------ @@ -121,6 +134,30 @@ def _value_data_changed_fired(self): self._image_cache_valid = False self.request_redraw() + @on_trait_change("index_mapper:updated, bounds[]") + def _update_lod_cache_image(self): + if not self.use_downsampling: + return + if self.traits_executor is None: + msg = "A traits_futures.TraitsExecutor is required to update" \ + "the plot at higher resolutions as a background job." + raise RuntimeError(msg) + lod = self._calculate_necessary_lod() + # Only keep the most recent job as bounds have been changed + # FIXME: call a public method of TraitExecutor to clean previous jobs + for future in self.traits_executor._futures: + if future.cancellable: + future.cancel() + self._future = self.traits_executor.submit_call( + self._compute_cached_image, lod=lod + ) + + @on_trait_change("_future:done", dispatch='ui') + def _handle_lod_cached_image(self): + self._cached_image, self._cached_dest_rect = self._future.result + self._image_cache_valid = True + self.request_redraw() + # ------------------------------------------------------------------------ # Base2DPlot interface # ------------------------------------------------------------------------ @@ -131,7 +168,9 @@ def _render(self, gc): Implements the Base2DPlot interface. """ if not self._image_cache_valid: - self._compute_cached_image() + self._cached_image, self._cached_dest_rect = \ + self._compute_cached_image() + self._image_cache_valid = True scale_x = -1 if self.x_axis_is_flipped else 1 scale_y = 1 if self.y_axis_is_flipped else -1 @@ -250,31 +289,46 @@ def _calc_virtual_screen_bbox(self): y_min += 0.5 return [x_min, y_min, virtual_x_size, virtual_y_size] - def _compute_cached_image(self, data=None, mapper=None): - """Computes the correct screen coordinates and renders an image into - `self._cached_image`. + def _compute_cached_image(self, mapper=None, lod=None): + """ Computes the correct screen coordinates and and renders an image + into `self._cached_image`. Parameters ---------- - data : array - Image data. If None, image is derived from the `value` attribute. mapper : function Allows subclasses to transform the displayed values for the visible region. This may be used to adapt grayscale images to RGB(A) images. + lod : int + Level of detail for cached image. If None, use the in-memory part + `self.value._data`. + + Returns + ------- + cache_image : `kiva.agg.GraphicsContextArray` + Computed cache image. + cache_dest_rect : 4-tuple + (x, y, width, height) rectangle describing the pixels bounds where + the image will be rendered in the plot """ - if data is None: - data = self.value.data + # Not to transpose the full matrix ahead in case it is too large + data = self.value.get_data(lod=lod, transpose_inplace=False) virtual_rect = self._calc_virtual_screen_bbox() - index_bounds, screen_rect = self._calc_zoom_coords(virtual_rect) + index_bounds, screen_rect = self._calc_zoom_coords(virtual_rect, + lod=lod) col_min, col_max, row_min, row_max = index_bounds view_rect = self.position + self.bounds sub_array_size = (col_max - col_min, row_max - row_min) screen_rect = trim_screen_rect(screen_rect, view_rect, sub_array_size) - data = data[row_min:row_max, col_min:col_max] + if self.value.transposed: + # Swap after slicing to avoid transposing the whole matrix + data = data[col_min:col_max, row_min:row_max] + data = data.swapaxes(0, 1) + else: + data = data[row_min:row_max, col_min:col_max] if mapper is not None: data = mapper(data) @@ -282,10 +336,9 @@ def _compute_cached_image(self, data=None, mapper=None): if len(data.shape) != 3: raise RuntimeError("`ImagePlot` requires color images.") - # Update cached image and rectangle. - self._cached_image = self._kiva_array_from_numpy_array(data) - self._cached_dest_rect = screen_rect - self._image_cache_valid = True + cached_image = self._kiva_array_from_numpy_array(data) + cached_dest_rect = screen_rect + return cached_image, cached_dest_rect def _kiva_array_from_numpy_array(self, data): if data.shape[2] not in KIVA_DEPTH_MAP: @@ -297,8 +350,8 @@ def _kiva_array_from_numpy_array(self, data): data = np.ascontiguousarray(data) return GraphicsContextArray(data, pix_format=kiva_depth) - def _calc_zoom_coords(self, image_rect): - """Calculates the coordinates of a zoomed sub-image. + def _calc_zoom_coords(self, image_rect, lod=None): + """ Calculates the coordinates of a zoomed sub-image. Because of floating point limitations, it is not advisable to request a extreme level of zoom, e.g., idx or idy > 10^10. @@ -323,12 +376,12 @@ def _calc_zoom_coords(self, image_rect): if 0 in (image_width, image_height) or 0 in self.bounds: return ((0, 0, 0, 0), (0, 0, 0, 0)) - array_bounds = self._array_bounds_from_screen_rect(image_rect) + array_bounds = self._array_bounds_from_screen_rect(image_rect, lod=lod) col_min, col_max, row_min, row_max = array_bounds # Convert array indices back into screen coordinates after its been # clipped to fit within the bounds. - array_width = self.value.get_width() - array_height = self.value.get_height() + array_width = self.value.get_width(lod=lod) + array_height = self.value.get_height(lod=lod) x_min = float(col_min) / array_width * image_width + ix x_max = float(col_max) / array_width * image_width + ix y_min = float(row_min) / array_height * image_height + iy @@ -349,8 +402,8 @@ def _calc_zoom_coords(self, image_rect): screen_rect = [x_min, y_min, x_max - x_min, y_max - y_min] return index_bounds, screen_rect - def _array_bounds_from_screen_rect(self, image_rect): - """Transform virtual-image rectangle into array indices. + def _array_bounds_from_screen_rect(self, image_rect, lod=None): + """ Transform virtual-image rectangle into array indices. The virtual-image rectangle is in screen coordinates and can be outside the plot bounds. This method converts the rectangle into array indices @@ -373,8 +426,8 @@ def _array_bounds_from_screen_rect(self, image_rect): x_max = x_min + plot_width y_max = y_min + plot_height - array_width = self.value.get_width() - array_height = self.value.get_height() + array_width = self.value.get_width(lod=lod) + array_height = self.value.get_height(lod=lod) # Convert screen coordinates to array indexes col_min = floor(float(x_min) / image_width * array_width) col_max = ceil(float(x_max) / image_width * array_width) @@ -388,3 +441,18 @@ def _array_bounds_from_screen_rect(self, image_rect): row_max = min(row_max, array_height) return col_min, col_max, row_min, row_max + + def _calculate_necessary_lod(self): + """ Computes the necessary lod so that array has more pixels than + the screen rectangle. + """ + virtual_rect = self._calc_virtual_screen_bbox() + # NOTE: LOD numbers are assumed to be continuous integers + # starting from 0 + for lod in range(len(self.value.lod_data_entry))[::-1]: + index_bounds, screen_rect = self._calc_zoom_coords(virtual_rect, lod=lod) + array_width = index_bounds[1] - index_bounds[0] + array_height = index_bounds[3] - index_bounds[2] + if (array_width >= screen_rect[2]) and (array_height >= screen_rect[3]): + break + return lod diff --git a/ci/edmtool.py b/ci/edmtool.py index ec9b96f46..7b8ca0fcb 100644 --- a/ci/edmtool.py +++ b/ci/edmtool.py @@ -94,6 +94,7 @@ "enable", # Needed to install enable from source "swig", + "traits_futures" } # Dependencies we install from source for cron tests diff --git a/examples/demo/advanced/lod_image_viewer.py b/examples/demo/advanced/lod_image_viewer.py new file mode 100644 index 000000000..d9823cac4 --- /dev/null +++ b/examples/demo/advanced/lod_image_viewer.py @@ -0,0 +1,116 @@ +""" +Renders high resolution image based on user interactions while keeping the GUI +responsive. + +Move the scrollbar to move around the image. Note the scrollbar stays +responsive even though the high resolution image may take longer to load. +""" +import numpy as np + +from enable.api import ComponentEditor, Container +from traits.api import HasTraits, Instance +from traitsui.api import Item, View +from traits_futures.api import TraitsExecutor + +from chaco.api import ( + DataRange2D, GridDataSource, GridMapper, HPlotContainer, + ImageData, ImagePlot +) +from chaco.tools.api import PanTool, ZoomTool + + +LOD_PATH = "LOD_{}" + + +def mandelbrot_set(xmin, xmax, ymin, ymax, xn, yn, maxiter, horizon): + """ Generates Mandelbrot dataset. """ + X = np.linspace(xmin, xmax, xn).astype(np.float32) + Y = np.linspace(ymin, ymax, yn).astype(np.float32) + C = X + Y[:, None] * 1j + N = np.zeros_like(C, dtype=int) + Z = np.zeros_like(C) + for n in range(maxiter): + I = abs(Z) < horizon + N[I] = n + Z[I] = Z[I]**2 + C[I] + N[N == maxiter-1] = 0 + return Z, N + + +def sample_big_data(): + """ Generates the Mandelbrot fractal with different resolutions stored as + multiple LOD images + Ref: https://matplotlib.org/examples/showcase/mandelbrot.html + """ + xmin, xmax = -2.25, +0.75 + ymin, ymax = -1.25, +1.25 + maxiter = 200 + horizon = 2.0 ** 40 + log_horizon = np.log2(np.log(horizon)) + + xn = 3000 + yn = 2500 + sample = {} + + for lod in range(10): + Z, N = mandelbrot_set(xmin, xmax, ymin, ymax, + xn // (2 ** lod), yn // (2 ** lod), + maxiter, horizon) + with np.errstate(invalid='ignore'): + M = np.nan_to_num(N + 1 - np.log2(np.log(abs(Z))) + log_horizon) + + sample[LOD_PATH.format(lod)] = np.stack( + [M/2, M, M/4], axis=2).astype('uint8') + + return sample + + +def _create_lod_plot(executor): + sample = sample_big_data() + sample_image_data = ImageData(data=sample[LOD_PATH.format(5)], + support_downsampling=True, + lod_data_entry=sample, + lod_key_pattern=LOD_PATH, + transposed=False) + + h = sample_image_data.get_height(lod=0) + w = sample_image_data.get_width(lod=0) + index = GridDataSource(np.arange(h), np.arange(w)) + index_mapper = GridMapper( + range=DataRange2D(low=(0, 0), high=(h-1, w-1)) + ) + renderer = ImagePlot( + value=sample_image_data, + index=index, + index_mapper=index_mapper, + use_downsampling=True, + traits_executor=executor + ) + + container = HPlotContainer(bounds=(1200, 1000)) + container.add(renderer) + renderer.tools.append(PanTool(renderer, constrain_key="shift")) + renderer.overlays.append(ZoomTool(component=renderer, + tool_mode="box", always_on=False)) + return container + + +class LODImageDemo(HasTraits): + + plot_container = Instance(Container) + + traits_view = View( + Item( + 'plot_container', + editor=ComponentEditor(size=(600, 500)), + show_label=False, + ), + resizable=True, + ) + + +if __name__ == "__main__": + executor = TraitsExecutor() + lod_demo = LODImageDemo(plot_container=_create_lod_plot(executor)) + lod_demo.configure_traits() + executor.stop()