From ea4b65284ed25396080cdc95cdbf941f157c41e2 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 4 Mar 2023 03:59:09 -0500 Subject: [PATCH 1/2] HeatmapGraphic for dims larger than 8192, data, cmaps, vminmax are modifiable --- fastplotlib/graphics/__init__.py | 4 +- fastplotlib/graphics/features/__init__.py | 4 +- fastplotlib/graphics/features/_base.py | 27 +++- fastplotlib/graphics/features/_colors.py | 13 ++ fastplotlib/graphics/features/_data.py | 57 +++++-- fastplotlib/graphics/image.py | 176 +++++++++++++++++++++- fastplotlib/layouts/_subplot.py | 3 - 7 files changed, 265 insertions(+), 19 deletions(-) diff --git a/fastplotlib/graphics/__init__.py b/fastplotlib/graphics/__init__.py index 0c883e456..66ad5820f 100644 --- a/fastplotlib/graphics/__init__.py +++ b/fastplotlib/graphics/__init__.py @@ -1,8 +1,8 @@ from .histogram import HistogramGraphic from .line import LineGraphic from .scatter import ScatterGraphic -from .image import ImageGraphic -from .heatmap import HeatmapGraphic +from .image import ImageGraphic, HeatmapGraphic +# from .heatmap import HeatmapGraphic from .text import TextGraphic from .line_collection import LineCollection, LineStack diff --git a/fastplotlib/graphics/features/__init__.py b/fastplotlib/graphics/features/__init__.py index 3a5ea551c..0e1e5f512 100644 --- a/fastplotlib/graphics/features/__init__.py +++ b/fastplotlib/graphics/features/__init__.py @@ -1,5 +1,5 @@ -from ._colors import ColorFeature, CmapFeature, ImageCmapFeature -from ._data import PointsDataFeature, ImageDataFeature +from ._colors import ColorFeature, CmapFeature, ImageCmapFeature, HeatmapCmapFeature +from ._data import PointsDataFeature, ImageDataFeature, HeatmapDataFeature from ._present import PresentFeature from ._thickness import ThicknessFeature from ._base import GraphicFeature, GraphicFeatureIndexable diff --git a/fastplotlib/graphics/features/_base.py b/fastplotlib/graphics/features/_base.py index a9aafc2b7..65b7ff5b0 100644 --- a/fastplotlib/graphics/features/_base.py +++ b/fastplotlib/graphics/features/_base.py @@ -7,6 +7,30 @@ from pygfx import Buffer +supported_dtypes = [ + np.int8, + np.int16, + np.int32, + np.float16, + np.float32 +] + + +def to_gpu_supported_dtype(array): + if isinstance(array, np.ndarray): + if array.dtype not in supported_dtypes: + if np.issubdtype(array.dtype, np.integer): + warn(f"converting {array.dtype} array to int32") + return array.astype(np.int32) + elif np.issubdtype(array.dtype, np.floating): + warn(f"converting {array.dtype} array to float32") + return array.astype(np.float32, copy=False) + else: + raise TypeError("Unsupported type, supported array types must be int or float dtypes") + + return array + + class FeatureEvent: """ type: , example: "colors" @@ -43,7 +67,7 @@ def __init__(self, parent, data: Any, collection_index: int = None): """ self._parent = parent if isinstance(data, np.ndarray): - data = data.astype(np.float32) + data = to_gpu_supported_dtype(data) self._data = data @@ -227,3 +251,4 @@ def _update_range_indices(self, key): self._buffer.update_range(ix, size=1) else: raise TypeError("must pass int or slice to update range") + diff --git a/fastplotlib/graphics/features/_colors.py b/fastplotlib/graphics/features/_colors.py index 41f03e711..6f2bd6a92 100644 --- a/fastplotlib/graphics/features/_colors.py +++ b/fastplotlib/graphics/features/_colors.py @@ -238,3 +238,16 @@ def _feature_changed(self, key, new_data): event_data = FeatureEvent(type="cmap", pick_info=pick_info) self._call_event_handlers(event_data) + + +class HeatmapCmapFeature(ImageCmapFeature): + """ + Colormap for HeatmapGraphic + """ + + def _set(self, cmap_name: str): + self._parent._material.map.texture.data[:] = make_colors(256, cmap_name) + self._parent._material.map.texture.update_range((0, 0, 0), size=(256, 1, 1)) + self.name = cmap_name + + self._feature_changed(key=None, new_data=self.name) diff --git a/fastplotlib/graphics/features/_data.py b/fastplotlib/graphics/features/_data.py index 0002d6697..2884723ea 100644 --- a/fastplotlib/graphics/features/_data.py +++ b/fastplotlib/graphics/features/_data.py @@ -3,14 +3,7 @@ import numpy as np from pygfx import Buffer, Texture -from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent - - -def to_float32(array): - if isinstance(array, np.ndarray): - return array.astype(np.float32, copy=False) - - return array +from ._base import GraphicFeatureIndexable, cleanup_slice, FeatureEvent, to_gpu_supported_dtype class PointsDataFeature(GraphicFeatureIndexable): @@ -102,7 +95,7 @@ def __init__(self, parent, data: Any): "``[x_dim, y_dim]`` or ``[x_dim, y_dim, rgb]``" ) - data = to_float32(data) + data = to_gpu_supported_dtype(data) super(ImageDataFeature, self).__init__(parent, data) @property @@ -114,7 +107,7 @@ def __getitem__(self, item): def __setitem__(self, key, value): # make sure float32 - value = to_float32(value) + value = to_gpu_supported_dtype(value) self._buffer.data[key] = value self._update_range(key) @@ -145,3 +138,47 @@ def _feature_changed(self, key, new_data): event_data = FeatureEvent(type="data", pick_info=pick_info) self._call_event_handlers(event_data) + + +class HeatmapDataFeature(ImageDataFeature): + @property + def _buffer(self) -> List[Texture]: + return [img.geometry.grid.texture for img in self._parent.world_object.children] + + def __getitem__(self, item): + return self._data[item] + + def __setitem__(self, key, value): + # make sure supported type, not float64 etc. + value = to_gpu_supported_dtype(value) + + self._data[key] = value + self._update_range(key) + + # avoid creating dicts constantly if there are no events to handle + if len(self._event_handlers) > 0: + self._feature_changed(key, value) + + def _update_range(self, key): + for buffer in self._buffer: + buffer.update_range((0, 0, 0), size=buffer.size) + + def _feature_changed(self, key, new_data): + if key is not None: + key = cleanup_slice(key, self._upper_bound) + if isinstance(key, int): + indices = [key] + elif isinstance(key, slice): + indices = range(key.start, key.stop, key.step) + elif key is None: + indices = None + + pick_info = { + "index": indices, + "world_object": self._parent.world_object, + "new_data": new_data + } + + event_data = FeatureEvent(type="data", pick_info=pick_info) + + self._call_event_handlers(event_data) diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index b9f15f5dc..854f757f2 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -1,9 +1,12 @@ from typing import * +from math import ceil +from itertools import product import pygfx +from pygfx.utils import unpack_bitfield from ._base import Graphic, Interaction, PreviouslyModifiedData -from .features import ImageCmapFeature, ImageDataFeature +from .features import ImageCmapFeature, ImageDataFeature, HeatmapDataFeature, HeatmapCmapFeature from ..utils import quick_min_max @@ -119,5 +122,176 @@ def _reset_feature(self, feature: str): pass +class _ImageTile(pygfx.Image): + """ + Similar to pygfx.Image, only difference is that it contains a few properties to keep track of + row chunk index, column chunk index + """ + def _wgpu_get_pick_info(self, pick_value): + tex = self.geometry.grid + if hasattr(tex, "texture"): + tex = tex.texture # tex was a view + # This should match with the shader + values = unpack_bitfield(pick_value, wobject_id=20, x=22, y=22) + x = values["x"] / 4194304 * tex.size[0] - 0.5 + y = values["y"] / 4194304 * tex.size[1] - 0.5 + ix, iy = int(x + 0.5), int(y + 0.5) + return { + "index": (ix, iy), + "pixel_coord": (x - ix, y - iy), + "row_chunk_index": self.row_chunk_index, + "col_chunk_index": self.col_chunk_index + } + + @property + def row_chunk_index(self) -> int: + return self._row_chunk_index + + @row_chunk_index.setter + def row_chunk_index(self, index: int): + self._row_chunk_index = index + + @property + def col_chunk_index(self) -> int: + return self._col_chunk_index + + @col_chunk_index.setter + def col_chunk_index(self, index: int): + self._col_chunk_index = index + + +class HeatmapGraphic(Graphic, Interaction): + feature_events = ( + "data", + "cmap", + ) + + def __init__( + self, + data: Any, + vmin: int = None, + vmax: int = None, + cmap: str = 'plasma', + filter: str = "nearest", + chunk_size: int = 8192, + *args, + **kwargs + ): + """ + Create an Image Graphic + + Parameters + ---------- + data: array-like + array-like, usually numpy.ndarray, must support ``memoryview()`` + Tensorflow Tensors also work **probably**, but not thoroughly tested + | shape must be ``[x_dim, y_dim]`` + vmin: int, optional + minimum value for color scaling, calculated from data if not provided + vmax: int, optional + maximum value for color scaling, calculated from data if not provided + cmap: str, optional, default "plasma" + colormap to use to display the data + filter: str, optional, default "nearest" + interpolation filter, one of "nearest" or "linear" + chunk_size: int, default 8192, max 8192 + chunk size for each tile used to make up the heatmap texture + args: + additional arguments passed to Graphic + kwargs: + additional keyword arguments passed to Graphic + + Examples + -------- + .. code-block:: python + + from fastplotlib import Plot + # create a `Plot` instance + plot = Plot() + # make some random 2D image data + data = np.random.rand(512, 512) + # plot the image data + plot.add_image(data=data) + # show the plot + plot.show() + """ + + super().__init__(*args, **kwargs) + + if chunk_size > 8192: + raise ValueError("Maximum chunk size is 8192") + + self.data = HeatmapDataFeature(self, data) + + row_chunks = range(ceil(data.shape[0] / chunk_size)) + col_chunks = range(ceil(data.shape[1] / chunk_size)) + + chunks = list(product(row_chunks, col_chunks)) + # chunks is the index position of each chunk + + start_ixs = [list(map(lambda c: c * chunk_size, chunk)) for chunk in chunks] + stop_ixs = [list(map(lambda c: c + chunk_size, chunk)) for chunk in start_ixs] + + self._world_object = pygfx.Group() + + if (vmin is None) or (vmax is None): + vmin, vmax = quick_min_max(data) + + self.cmap = HeatmapCmapFeature(self, cmap) + self._material = pygfx.ImageBasicMaterial(clim=(vmin, vmax), map=self.cmap()) + + for start, stop, chunk in zip(start_ixs, stop_ixs, chunks): + row_start, col_start = start + row_stop, col_stop = stop + + # x and y positions of the Tile in world space coordinates + y_pos, x_pos = row_start, col_start + + tex_view = pygfx.Texture(data[row_start:row_stop, col_start:col_stop], dim=2).get_view(filter=filter) + geometry = pygfx.Geometry(grid=tex_view) + # material = pygfx.ImageBasicMaterial(clim=(0, 1), map=self.cmap()) + + img = _ImageTile(geometry, self._material) + + # row and column chunk index for this Tile + img.row_chunk_index = chunk[0] + img.col_chunk_index = chunk[1] + + img.position.set_x(x_pos) + img.position.set_y(y_pos) + + self.world_object.add(img) + + @property + def vmin(self) -> float: + """Minimum contrast limit.""" + return self._material.clim[0] + + @vmin.setter + def vmin(self, value: float): + """Minimum contrast limit.""" + self._material.clim = ( + value, + self._material.clim[1] + ) + + @property + def vmax(self) -> float: + """Maximum contrast limit.""" + return self._material.clim[1] + + @vmax.setter + def vmax(self, value: float): + """Maximum contrast limit.""" + self._material.clim = ( + self._material.clim[0], + value + ) + + def _set_feature(self, feature: str, new_data: Any, indices: Any): + pass + + def _reset_feature(self, feature: str): + pass diff --git a/fastplotlib/layouts/_subplot.py b/fastplotlib/layouts/_subplot.py index 7df118d1d..41d065648 100644 --- a/fastplotlib/layouts/_subplot.py +++ b/fastplotlib/layouts/_subplot.py @@ -255,9 +255,6 @@ def add_graphic(self, graphic, center: bool = True): graphic.world_object.position.z = len(self._graphics) super(Subplot, self).add_graphic(graphic, center) - if isinstance(graphic, graphics.HeatmapGraphic): - self.controller.scale.y = copysign(self.controller.scale.y, -1) - def set_axes_visibility(self, visible: bool): """Toggles axes visibility.""" if visible: From e4649253f559b9fb2918f285a22101d5baba128f Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 5 Mar 2023 02:36:17 -0500 Subject: [PATCH 2/2] add unsigned int types to supported --- fastplotlib/graphics/features/_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fastplotlib/graphics/features/_base.py b/fastplotlib/graphics/features/_base.py index 65b7ff5b0..58f9aca9f 100644 --- a/fastplotlib/graphics/features/_base.py +++ b/fastplotlib/graphics/features/_base.py @@ -8,6 +8,9 @@ supported_dtypes = [ + np.uint8, + np.uint16, + np.uint32, np.int8, np.int16, np.int32,