From a550b8ba149bab614a3b860add8c61e7e9832c05 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 19 Apr 2023 01:42:06 -0400 Subject: [PATCH] support numpy fancy indexing for colors and data features --- fastplotlib/graphics/features/_base.py | 58 +++++++++++++++++++++++- fastplotlib/graphics/features/_colors.py | 26 +++++++---- fastplotlib/graphics/features/_data.py | 12 +++-- 3 files changed, 83 insertions(+), 13 deletions(-) diff --git a/fastplotlib/graphics/features/_base.py b/fastplotlib/graphics/features/_base.py index 1119bed6b..ed08a9008 100644 --- a/fastplotlib/graphics/features/_base.py +++ b/fastplotlib/graphics/features/_base.py @@ -170,6 +170,12 @@ def cleanup_slice(key: Union[int, slice], upper_bound) -> Union[slice, int]: if isinstance(key, int): return key + if isinstance(key, np.ndarray): + return cleanup_array_slice(key, upper_bound) + + # if isinstance(key, np.integer): + # return int(key) + if isinstance(key, tuple): # if tuple of slice we only need the first obj # since the first obj is the datapoint indices @@ -197,13 +203,54 @@ def cleanup_slice(key: Union[int, slice], upper_bound) -> Union[slice, int]: stop = upper_bound elif stop > upper_bound: - raise IndexError("Index out of bounds") + raise IndexError(f"Index: `{stop}` out of bounds for feature array of size: `{upper_bound}`") step = key.step if step is None: step = 1 return slice(start, stop, step) + # return slice(int(start), int(stop), int(step)) + + +def cleanup_array_slice(key: np.ndarray, upper_bound) -> np.ndarray: + """ + Cleanup numpy array used for fancy indexing, make sure key[-1] <= upper_bound. + + Parameters + ---------- + key: np.ndarray + integer or boolean array + + upper_bound + + Returns + ------- + np.ndarray + integer indexing array + + """ + + if key.ndim > 1: + raise TypeError( + f"Can only use 1D boolean or integer arrays for fancy indexing" + ) + + # if boolean array convert to integer array of indices + if key.dtype == bool: + key = np.nonzero(key)[0] + + # make sure indices within bounds of feature buffer range + if key[-1] > upper_bound: + raise IndexError(f"Index: `{key[-1]}` out of bounds for feature array of size: `{upper_bound}`") + + # make sure indices are integers + if np.issubdtype(key.dtype, np.integer): + return key + + raise TypeError( + f"Can only use 1D boolean or integer arrays for fancy indexing" + ) class GraphicFeatureIndexable(GraphicFeature): @@ -236,7 +283,8 @@ def _upper_bound(self) -> int: def _update_range_indices(self, key): """Currently used by colors and positions data""" - key = cleanup_slice(key, self._upper_bound) + if not isinstance(key, np.ndarray): + key = cleanup_slice(key, self._upper_bound) if isinstance(key, int): self.buffer.update_range(key, size=1) @@ -254,6 +302,12 @@ def _update_range_indices(self, key): ixs = range(key.start, key.stop, step) for ix in ixs: self.buffer.update_range(ix, size=1) + + # TODO: See how efficient this is with large indexing + elif isinstance(key, np.ndarray): + for ix in key: + self.buffer.update_range(int(ix), size=1) + else: raise TypeError("must pass int or slice to update range") diff --git a/fastplotlib/graphics/features/_colors.py b/fastplotlib/graphics/features/_colors.py index 7813df61f..5ff82ca72 100644 --- a/fastplotlib/graphics/features/_colors.py +++ b/fastplotlib/graphics/features/_colors.py @@ -1,6 +1,6 @@ import numpy as np -from ._base import GraphicFeature, GraphicFeatureIndexable, cleanup_slice, FeatureEvent +from ._base import GraphicFeature, GraphicFeatureIndexable, cleanup_slice, FeatureEvent, cleanup_array_slice from ...utils import make_colors, get_cmap_texture, make_pygfx_colors from pygfx import Color @@ -102,9 +102,8 @@ def __setitem__(self, key, value): indices = range(_key.start, _key.stop, _key.step) # or single numerical index - elif isinstance(key, int): - if key > self._upper_bound: - raise IndexError("Index out of bounds") + elif isinstance(key, (int, np.integer)): + key = cleanup_slice(key, self._upper_bound) indices = [key] elif isinstance(key, tuple): @@ -128,6 +127,10 @@ def __setitem__(self, key, value): self._feature_changed(key, value) return + elif isinstance(key, np.ndarray): + key = cleanup_array_slice(key, self._upper_bound) + indices = key + else: raise TypeError("Graphic features only support integer and numerical fancy indexing") @@ -181,6 +184,8 @@ def _feature_changed(self, key, new_data): indices = [key] elif isinstance(key, slice): indices = range(key.start, key.stop, key.step) + elif isinstance(key, np.ndarray): + indices = key else: raise TypeError("feature changed key must be slice or int") @@ -205,11 +210,16 @@ def __init__(self, parent, colors): def __setitem__(self, key, value): key = cleanup_slice(key, self._upper_bound) - if not isinstance(key, slice): - raise TypeError("Cannot set cmap on single indices, must pass a slice object or " - "set it on the entire data.") + if not isinstance(key, (slice, np.ndarray)): + raise TypeError("Cannot set cmap on single indices, must pass a slice object, " + "numpy.ndarray or set it on the entire data.") + + if isinstance(key, slice): + n_colors = len(range(key.start, key.stop, key.step)) - n_colors = len(range(key.start, key.stop, key.step)) + else: + # numpy array + n_colors = key.size colors = make_colors(n_colors, cmap=value).astype(self._data.dtype) super(CmapFeature, self).__setitem__(key, colors) diff --git a/fastplotlib/graphics/features/_data.py b/fastplotlib/graphics/features/_data.py index 5063b4200..6c7dbfa75 100644 --- a/fastplotlib/graphics/features/_data.py +++ b/fastplotlib/graphics/features/_data.py @@ -3,7 +3,7 @@ import numpy as np from pygfx import Buffer, Texture -from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent, to_gpu_supported_dtype +from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent, to_gpu_supported_dtype, cleanup_array_slice class PointsDataFeature(GraphicFeatureIndexable): @@ -48,8 +48,12 @@ def _fix_data(self, data, parent): return data def __setitem__(self, key, value): + if isinstance(key, np.ndarray): + # make sure 1D array of int or boolean + key = cleanup_array_slice(key, self._upper_bound) + # put data into right shape if they're only indexing datapoints - if isinstance(key, (slice, int)): + if isinstance(key, (slice, int, np.ndarray, np.integer)): value = self._fix_data(value, self._parent) # otherwise assume that they have the right shape # numpy will throw errors if it can't broadcast @@ -66,10 +70,12 @@ def _update_range(self, key): def _feature_changed(self, key, new_data): if key is not None: key = cleanup_slice(key, self._upper_bound) - if isinstance(key, int): + if isinstance(key, (int, np.integer)): indices = [key] elif isinstance(key, slice): indices = range(key.start, key.stop, key.step) + elif isinstance(key, np.ndarray): + indices = key elif key is None: indices = None