From b53940b441fe0b0edae57f8e6cd89bf46cf0943b Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 26 Oct 2023 22:45:38 -0400 Subject: [PATCH 1/6] move render and animation methods from Subplot to PlotBase --- fastplotlib/layouts/_base.py | 85 +++++++++++++++++++++++++++++++- fastplotlib/layouts/_subplot.py | 86 --------------------------------- 2 files changed, 84 insertions(+), 87 deletions(-) diff --git a/fastplotlib/layouts/_base.py b/fastplotlib/layouts/_base.py index ecd68dc3f..3cfdbbd41 100644 --- a/fastplotlib/layouts/_base.py +++ b/fastplotlib/layouts/_base.py @@ -1,5 +1,7 @@ +from inspect import getfullargspec from typing import * import weakref +from warnings import warn import numpy as np @@ -92,6 +94,9 @@ def __init__( self.viewport, ) + self._animate_funcs_pre = list() + self._animate_funcs_post = list() + self.renderer.add_event_handler(self.set_viewport_rect, "resize") # list of hex id strings for all graphics managed by this PlotArea @@ -224,12 +229,90 @@ def set_viewport_rect(self, *args): self.viewport.rect = self.get_rect() def render(self): - # does not flush + self._call_animate_functions(self._animate_funcs_pre) + + # does not flush, flush must be implemented in user-facing Plot objects self.viewport.render(self.scene, self.camera) for child in self.children: child.render() + self._call_animate_functions(self._animate_funcs_post) + + def _call_animate_functions(self, funcs: Iterable[callable]): + for fn in funcs: + try: + args = getfullargspec(fn).args + + if len(args) > 0: + if args[0] == "self" and not len(args) > 1: + fn() + else: + fn(self) + else: + fn() + except (ValueError, TypeError): + warn( + f"Could not resolve argspec of {self.__class__.__name__} animation function: {fn}, " + f"calling it without arguments." + ) + fn() + + def add_animations( + self, + *funcs: Iterable[callable], + pre_render: bool = True, + post_render: bool = False, + ): + """ + Add function(s) that are called on every render cycle. + These are called at the Subplot level. + + Parameters + ---------- + *funcs: callable or iterable of callable + function(s) that are called on each render cycle + + pre_render: bool, default ``True``, optional keyword-only argument + if true, these function(s) are called before a render cycle + + post_render: bool, default ``False``, optional keyword-only argument + if true, these function(s) are called after a render cycle + + """ + for f in funcs: + if not callable(f): + raise TypeError( + f"all positional arguments to add_animations() must be callable types, you have passed a: {type(f)}" + ) + if pre_render: + self._animate_funcs_pre += funcs + if post_render: + self._animate_funcs_post += funcs + + def remove_animation(self, func): + """ + Removes the passed animation function from both pre and post render. + + Parameters + ---------- + func: callable + The function to remove, raises a error if it's not registered as a pre or post animation function. + + """ + if func not in self._animate_funcs_pre and func not in self._animate_funcs_post: + raise KeyError( + f"The passed function: {func} is not registered as an animation function. These are the animation " + f" functions that are currently registered:\n" + f"pre: {self._animate_funcs_pre}\n\npost: {self._animate_funcs_post}" + ) + + if func in self._animate_funcs_pre: + self._animate_funcs_pre.remove(func) + + if func in self._animate_funcs_post: + self._animate_funcs_post.remove(func) + def add_graphic(self, graphic: Graphic, center: bool = True): """ Add a Graphic to the scene diff --git a/fastplotlib/layouts/_subplot.py b/fastplotlib/layouts/_subplot.py index a8cd4852b..c32737b51 100644 --- a/fastplotlib/layouts/_subplot.py +++ b/fastplotlib/layouts/_subplot.py @@ -1,6 +1,4 @@ from typing import * -from inspect import getfullargspec -from warnings import warn import numpy as np @@ -97,9 +95,6 @@ def __init__( self._grid: GridHelper = GridHelper(size=100, thickness=1) - self._animate_funcs_pre = list() - self._animate_funcs_post = list() - super(Subplot, self).__init__( parent=parent, position=position, @@ -192,87 +187,6 @@ def get_rect(self): return rect - def render(self): - self._call_animate_functions(self._animate_funcs_pre) - - super(Subplot, self).render() - - self._call_animate_functions(self._animate_funcs_post) - - def _call_animate_functions(self, funcs: Iterable[callable]): - for fn in funcs: - try: - args = getfullargspec(fn).args - - if len(args) > 0: - if args[0] == "self" and not len(args) > 1: - fn() - else: - fn(self) - else: - fn() - except (ValueError, TypeError): - warn( - f"Could not resolve argspec of {self.__class__.__name__} animation function: {fn}, " - f"calling it without arguments." - ) - fn() - - def add_animations( - self, - *funcs: Iterable[callable], - pre_render: bool = True, - post_render: bool = False, - ): - """ - Add function(s) that are called on every render cycle. - These are called at the Subplot level. - - Parameters - ---------- - *funcs: callable or iterable of callable - function(s) that are called on each render cycle - - pre_render: bool, default ``True``, optional keyword-only argument - if true, these function(s) are called before a render cycle - - post_render: bool, default ``False``, optional keyword-only argument - if true, these function(s) are called after a render cycle - - """ - for f in funcs: - if not callable(f): - raise TypeError( - f"all positional arguments to add_animations() must be callable types, you have passed a: {type(f)}" - ) - if pre_render: - self._animate_funcs_pre += funcs - if post_render: - self._animate_funcs_post += funcs - - def remove_animation(self, func): - """ - Removes the passed animation function from both pre and post render. - - Parameters - ---------- - func: callable - The function to remove, raises a error if it's not registered as a pre or post animation function. - - """ - if func not in self._animate_funcs_pre and func not in self._animate_funcs_post: - raise KeyError( - f"The passed function: {func} is not registered as an animation function. These are the animation " - f" functions that are currently registered:\n" - f"pre: {self._animate_funcs_pre}\n\npost: {self._animate_funcs_post}" - ) - - if func in self._animate_funcs_pre: - self._animate_funcs_pre.remove(func) - - if func in self._animate_funcs_post: - self._animate_funcs_post.remove(func) - def set_axes_visibility(self, visible: bool): """Toggles axes visibility.""" if visible: From a4428f9c6d5607068f736aa36aa5586de7c66b6f Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 26 Oct 2023 22:46:04 -0400 Subject: [PATCH 2/6] edge_thickness param for linear_region --- fastplotlib/graphics/selectors/_linear_region.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fastplotlib/graphics/selectors/_linear_region.py b/fastplotlib/graphics/selectors/_linear_region.py index 602215467..c59265fe8 100644 --- a/fastplotlib/graphics/selectors/_linear_region.py +++ b/fastplotlib/graphics/selectors/_linear_region.py @@ -44,6 +44,7 @@ def __init__( resizable: bool = True, fill_color=(0, 0, 0.35), edge_color=(0.8, 0.8, 0), + edge_thickness: int = 3, arrow_keys_modifier: str = "Shift", name: str = None, ): @@ -168,7 +169,7 @@ def __init__( left_line = pygfx.Line( pygfx.Geometry(positions=left_line_data), - pygfx.LineMaterial(thickness=3, color=edge_color), + pygfx.LineMaterial(thickness=edge_thickness, color=edge_color), ) # position data for the right edge line @@ -181,7 +182,7 @@ def __init__( right_line = pygfx.Line( pygfx.Geometry(positions=right_line_data), - pygfx.LineMaterial(thickness=3, color=edge_color), + pygfx.LineMaterial(thickness=edge_thickness, color=edge_color), ) self.edges: Tuple[pygfx.Line, pygfx.Line] = (left_line, right_line) @@ -197,7 +198,7 @@ def __init__( bottom_line = pygfx.Line( pygfx.Geometry(positions=bottom_line_data), - pygfx.LineMaterial(thickness=3, color=edge_color), + pygfx.LineMaterial(thickness=edge_thickness, color=edge_color), ) # position data for the right edge line @@ -210,7 +211,7 @@ def __init__( top_line = pygfx.Line( pygfx.Geometry(positions=top_line_data), - pygfx.LineMaterial(thickness=3, color=edge_color), + pygfx.LineMaterial(thickness=edge_thickness, color=edge_color), ) self.edges: Tuple[pygfx.Line, pygfx.Line] = (bottom_line, top_line) From 83d28975c2e66da985612cb4518936bdf4b7e9d3 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 26 Oct 2023 22:46:20 -0400 Subject: [PATCH 3/6] histogram LUT widget basically works! --- fastplotlib/widgets/histogram_lut.py | 131 +++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 fastplotlib/widgets/histogram_lut.py diff --git a/fastplotlib/widgets/histogram_lut.py b/fastplotlib/widgets/histogram_lut.py new file mode 100644 index 000000000..d93d22075 --- /dev/null +++ b/fastplotlib/widgets/histogram_lut.py @@ -0,0 +1,131 @@ +import numpy as np + +from pygfx import Group + +from ..graphics import LineGraphic, ImageGraphic +from ..graphics._base import Graphic +from ..graphics.selectors import LinearRegionSelector + + +# TODO: This is a widget, we can think about a BaseWidget class later if necessary +class HistogramLUT(Graphic): + def __init__( + self, + data: np.ndarray, + image_graphic: ImageGraphic, + nbins: int = 100, + **kwargs + ): + super().__init__(**kwargs) + + self.nbins = nbins + self._image_graphic = image_graphic + + hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) + + line_data = np.column_stack([hist_scaled, edges_flanked]) + + self.line = LineGraphic(line_data) + + bounds = (edges[0], edges[-1]) + limits = (edges_flanked[0], edges_flanked[-1]) + size = 120 # since it's scaled to 100 + origin = (hist_scaled.max() / 2, 0) + + self.linear_region = LinearRegionSelector( + bounds=bounds, + limits=limits, + size=size, + origin=origin, + axis="y", + ) + + widget_wo = Group() + widget_wo.add(self.line.world_object, self.linear_region.world_object) + + self._set_world_object(widget_wo) + + self.world_object.local.scale_x *= -1 + + self.linear_region.selection.add_event_handler( + self._set_vmin_vmax + ) + + def _add_plot_area_hook(self, plot_area): + self._plot_area = plot_area + self.linear_region._add_plot_area_hook(plot_area) + self.line._add_plot_area_hook(plot_area) + + def _calculate_histogram(self, data): + if data.ndim > 2: + # subsample to max of 500 x 100 x 100, + # np.histogram takes ~30ms with this size on a 8 core Ryzen laptop + # dim0 is usually time, allow max of 500 timepoints + ss0 = int(data.shape[0] / 500) + # allow max of 100 for x and y if ndim > 2 + ss1 = int(data.shape[1] / 100) + ss2 = int(data.shape[2] / 100) + + hist, edges = np.histogram(data[::ss0, ::ss1, ::ss2], bins=self.nbins) + + else: + # allow max of 1000 x 1000 + # this takes ~4ms on a 8 core Ryzen laptop + ss0 = int(data.shape[0] / 1_000) + ss1 = int(data.shape[1] / 1_000) + + hist, edges = np.histogram(data[::ss0, ::ss1], bins=self.nbins) + + bin_width = edges[1] - edges[0] + + flank_nbins = int(self.nbins / 3) + flank_size = flank_nbins * bin_width + + flank_left = np.arange(edges[0] - flank_size, edges[0], bin_width) + flank_right = np.arange(edges[-1] + bin_width, edges[-1] + flank_size, bin_width) + + edges_flanked = np.concatenate((flank_left, edges, flank_right)) + np.unique(np.diff(edges_flanked)) + + hist_flanked = np.concatenate((np.zeros(flank_nbins), hist, np.zeros(flank_nbins))) + + # scale 0-100 to make it easier to see + # float32 data can produce unnecessarily high values + hist_scaled = hist_flanked / (hist_flanked.max() / 100) + + return hist, edges, hist_scaled, edges_flanked + + def _set_vmin_vmax(self, ev): + selected = self.linear_region.get_selected_data(self.line)[:, 1] + self.image_graphic.cmap.vmin = selected[0] + self.image_graphic.cmap.vmax = selected[-1] + + def set_data(self, data): + hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) + + line_data = np.column_stack([hist_scaled, edges_flanked]) + + self.line.data = line_data + + bounds = (edges[0], edges[-1]) + limits = (edges_flanked[0], edges_flanked[-11]) + origin = (hist_scaled.max() / 2, 0) + + self.linear_region.limits = limits + self.linear_region.selection = bounds + # self.linear_region.fill.world.position = (*origin, -2) + + # def nbins(self): + + @property + def image_graphic(self) -> ImageGraphic: + return self._image_graphic + + @image_graphic.setter + def image_graphic(self, graphic): + if not isinstance(graphic, ImageGraphic): + raise TypeError( + f"HistogramLUT can only use ImageGraphic types, you have passed: {type(graphic)}" + ) + + self._image_graphic = graphic From 20a84eebf26f840905b51b10b5733d639c1f3e9e Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 26 Oct 2023 23:04:39 -0400 Subject: [PATCH 4/6] bidirectional vmin vmax with linear region and image graphic, start imagewidget integration --- fastplotlib/widgets/histogram_lut.py | 17 +++++++++++++++++ fastplotlib/widgets/image.py | 19 +++---------------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/fastplotlib/widgets/histogram_lut.py b/fastplotlib/widgets/histogram_lut.py index d93d22075..2a112538c 100644 --- a/fastplotlib/widgets/histogram_lut.py +++ b/fastplotlib/widgets/histogram_lut.py @@ -38,6 +38,7 @@ def __init__( size=size, origin=origin, axis="y", + edge_thickness=8 ) widget_wo = Group() @@ -51,6 +52,8 @@ def __init__( self._set_vmin_vmax ) + self.image_graphic.cmap.add_event_handler(self._set_selection_from_cmap) + def _add_plot_area_hook(self, plot_area): self._plot_area = plot_area self.linear_region._add_plot_area_hook(plot_area) @@ -97,9 +100,23 @@ def _calculate_histogram(self, data): def _set_vmin_vmax(self, ev): selected = self.linear_region.get_selected_data(self.line)[:, 1] + + self.image_graphic.cmap.block_events(True) + self.image_graphic.cmap.vmin = selected[0] self.image_graphic.cmap.vmax = selected[-1] + self.image_graphic.cmap.block_events(False) + + def _set_selection_from_cmap(self, ev): + vmin, vmax = ev.pick_info["vmin"], ev.pick_info["vmax"] + + self.linear_region.selection.block_events(True) + + self.linear_region.selection = (vmin, vmax) + + self.linear_region.selection.block_events(False) + def set_data(self, data): hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) diff --git a/fastplotlib/widgets/image.py b/fastplotlib/widgets/image.py index 589e75f83..cc857a07a 100644 --- a/fastplotlib/widgets/image.py +++ b/fastplotlib/widgets/image.py @@ -838,20 +838,8 @@ def reset_vmin_vmax(self): """ Reset the vmin and vmax w.r.t. the currently displayed image(s) """ - for i, ig in enumerate(self.managed_graphics): - mm = self._get_vmin_vmax_range(ig.data()) - - if len(self.vmin_vmax_sliders) != 0: - state = { - "value": mm[0], - "step": mm[1] / 150, - "min": mm[2], - "max": mm[3], - } - - self.vmin_vmax_sliders[i].set_state(state) - else: - ig.cmap.vmin, ig.cmap.vmax = mm[0] + for ig in self.managed_graphics: + ig.cmap.reset_vmin_vmax() def set_data( self, @@ -1068,8 +1056,7 @@ def __init__(self, iw: ImageWidget): self.reset_vminvmax_button.on_click(self._reset_vminvmax) def _reset_vminvmax(self, obj): - if len(self.iw.vmin_vmax_sliders) != 0: - self.iw.reset_vmin_vmax() + self.iw.reset_vmin_vmax() def _change_stepsize(self, obj): self.iw.sliders["t"].step = self.step_size_setter.value From 9c7eb633e749609b0e3e44a2eb9f93f969da44d4 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 27 Oct 2023 21:28:54 -0400 Subject: [PATCH 5/6] more hlut functionality, integrate in imagewidget --- .../graphics/selectors/_base_selector.py | 13 +- .../graphics/selectors/_linear_region.py | 3 + fastplotlib/widgets/histogram_lut.py | 141 ++++++++++++++---- fastplotlib/widgets/image.py | 83 +++-------- 4 files changed, 146 insertions(+), 94 deletions(-) diff --git a/fastplotlib/graphics/selectors/_base_selector.py b/fastplotlib/graphics/selectors/_base_selector.py index a84c02c6c..e892ca32d 100644 --- a/fastplotlib/graphics/selectors/_base_selector.py +++ b/fastplotlib/graphics/selectors/_base_selector.py @@ -86,6 +86,8 @@ def __init__( # sets to `True` on "pointer_down", sets to `False` on "pointer_up" self._moving = False #: indicates if the selector is currently being moved + self._initial_controller_state: bool = None + # used to disable fill area events if the edge is being actively hovered # otherwise annoying and requires too much accuracy to move just an edge self._edge_hovered: bool = False @@ -201,6 +203,8 @@ def _move_start(self, event_source: WorldObject, ev): self._move_info = MoveInfo(last_position=last_position, source=event_source) self._moving = True + self._initial_controller_state = self._plot_area.controller.enabled + def _move(self, ev): """ Called on pointer move events @@ -235,7 +239,9 @@ def _move(self, ev): # update last position self._move_info.last_position = world_pos - self._plot_area.controller.enabled = True + # restore the initial controller state + # if it was disabled, keep it disabled + self._plot_area.controller.enabled = self._initial_controller_state def _move_graphic(self, delta: np.ndarray): raise NotImplementedError("Must be implemented in subclass") @@ -243,7 +249,10 @@ def _move_graphic(self, delta: np.ndarray): def _move_end(self, ev): self._move_info = None self._moving = False - self._plot_area.controller.enabled = True + + # restore the initial controller state + # if it was disabled, keep it disabled + self._plot_area.controller.enabled = self._initial_controller_state def _move_to_pointer(self, ev): """ diff --git a/fastplotlib/graphics/selectors/_linear_region.py b/fastplotlib/graphics/selectors/_linear_region.py index c59265fe8..2a7547d5b 100644 --- a/fastplotlib/graphics/selectors/_linear_region.py +++ b/fastplotlib/graphics/selectors/_linear_region.py @@ -57,6 +57,9 @@ def __init__( Holding the right mouse button while dragging an edge will force the entire region selector to move. This is a when using transparent fill areas due to ``pygfx`` picking limitations. + **Note:** Events get very weird if the values of bounds, limits and origin are close to zero. If you need + a linear selector with small data, we recommend scaling the data and then using the selector. + Parameters ---------- bounds: (int, int) diff --git a/fastplotlib/widgets/histogram_lut.py b/fastplotlib/widgets/histogram_lut.py index 2a112538c..6b4088033 100644 --- a/fastplotlib/widgets/histogram_lut.py +++ b/fastplotlib/widgets/histogram_lut.py @@ -1,3 +1,5 @@ +import weakref + import numpy as np from pygfx import Group @@ -14,13 +16,30 @@ def __init__( data: np.ndarray, image_graphic: ImageGraphic, nbins: int = 100, + flank_divisor: float = 5.0, **kwargs ): + """ + + Parameters + ---------- + data + image_graphic + nbins + flank_divisor: float, default 5.0 + set `np.inf` for no flanks + kwargs + """ super().__init__(**kwargs) - self.nbins = nbins + self._nbins = nbins + self._flank_divisor = flank_divisor self._image_graphic = image_graphic + self._data = weakref.proxy(data) + + self._scale_factor: float = 1.0 + hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) line_data = np.column_stack([hist_scaled, edges_flanked]) @@ -41,6 +60,12 @@ def __init__( edge_thickness=8 ) + # there will be a small difference with the histogram edges so this makes them both line up exactly + self.linear_region.selection = (image_graphic.cmap.vmin, image_graphic.cmap.vmax) + + self._vmin = self.image_graphic.cmap.vmin + self._vmax = self.image_graphic.cmap.vmax + widget_wo = Group() widget_wo.add(self.line.world_object, self.linear_region.world_object) @@ -49,10 +74,10 @@ def __init__( self.world_object.local.scale_x *= -1 self.linear_region.selection.add_event_handler( - self._set_vmin_vmax + self._linear_region_handler ) - self.image_graphic.cmap.add_event_handler(self._set_selection_from_cmap) + self.image_graphic.cmap.add_event_handler(self._image_cmap_handler) def _add_plot_area_hook(self, plot_area): self._plot_area = plot_area @@ -64,24 +89,35 @@ def _calculate_histogram(self, data): # subsample to max of 500 x 100 x 100, # np.histogram takes ~30ms with this size on a 8 core Ryzen laptop # dim0 is usually time, allow max of 500 timepoints - ss0 = int(data.shape[0] / 500) + ss0 = max(1, int(data.shape[0] / 500)) # max to prevent step = 0 # allow max of 100 for x and y if ndim > 2 - ss1 = int(data.shape[1] / 100) - ss2 = int(data.shape[2] / 100) + ss1 = max(1, int(data.shape[1] / 100)) + ss2 = max(1, int(data.shape[2] / 100)) + + data_ss = data[::ss0, ::ss1, ::ss2] - hist, edges = np.histogram(data[::ss0, ::ss1, ::ss2], bins=self.nbins) + hist, edges = np.histogram(data_ss, bins=self._nbins) else: # allow max of 1000 x 1000 # this takes ~4ms on a 8 core Ryzen laptop - ss0 = int(data.shape[0] / 1_000) - ss1 = int(data.shape[1] / 1_000) + ss0 = max(1, int(data.shape[0] / 1_000)) + ss1 = max(1, int(data.shape[1] / 1_000)) - hist, edges = np.histogram(data[::ss0, ::ss1], bins=self.nbins) + data_ss = data[::ss0, ::ss1] + + hist, edges = np.histogram(data_ss, bins=self._nbins) + + # used if data ptp <= 10 because event things get weird + # with tiny world objects due to floating point error + # so if ptp <= 10, scale up by a factor + self._scale_factor: int = max(1, 100 * int(10 / data_ss.ptp())) + + edges = edges * self._scale_factor bin_width = edges[1] - edges[0] - flank_nbins = int(self.nbins / 3) + flank_nbins = int(self._nbins / self._flank_divisor) flank_size = flank_nbins * bin_width flank_left = np.arange(edges[0] - flank_size, edges[0], bin_width) @@ -96,28 +132,60 @@ def _calculate_histogram(self, data): # float32 data can produce unnecessarily high values hist_scaled = hist_flanked / (hist_flanked.max() / 100) + if edges_flanked.size > hist_scaled.size: + edges_flanked = edges_flanked[:-1] + return hist, edges, hist_scaled, edges_flanked - def _set_vmin_vmax(self, ev): - selected = self.linear_region.get_selected_data(self.line)[:, 1] + def _linear_region_handler(self, ev): + # must use world coordinate values directly from selection() + # otherwise the linear region bounds jump to the closest bin edges + vmin, vmax = self.linear_region.selection() + vmin, vmax = vmin / self._scale_factor, vmax / self._scale_factor + self.vmin, self.vmax = vmin, vmax + + def _image_cmap_handler(self, ev): + self.vmin, self.vmax = ev.pick_info["vmin"], ev.pick_info["vmax"] + + def _block_events(self, b: bool): + self.image_graphic.cmap.block_events(b) + self.linear_region.selection.block_events(b) - self.image_graphic.cmap.block_events(True) + @property + def vmin(self) -> float: + return self._vmin + + @vmin.setter + def vmin(self, value: float): + self._block_events(True) + + # must use world coordinate values directly from selection() + # otherwise the linear region bounds jump to the closest bin edges + self.linear_region.selection = (value * self._scale_factor, self.linear_region.selection()[1]) + self.image_graphic.cmap.vmin = value - self.image_graphic.cmap.vmin = selected[0] - self.image_graphic.cmap.vmax = selected[-1] + self._block_events(False) - self.image_graphic.cmap.block_events(False) + self._vmin = value - def _set_selection_from_cmap(self, ev): - vmin, vmax = ev.pick_info["vmin"], ev.pick_info["vmax"] + @property + def vmax(self) -> float: + return self._vmax + + @vmax.setter + def vmax(self, value: float): + self._block_events(True) - self.linear_region.selection.block_events(True) + # must use world coordinate values directly from selection() + # otherwise the linear region bounds jump to the closest bin edges + self.linear_region.selection = (self.linear_region.selection()[0], value * self._scale_factor) + self.image_graphic.cmap.vmax = value - self.linear_region.selection = (vmin, vmax) + self._block_events(False) - self.linear_region.selection.block_events(False) + self._vmax = value - def set_data(self, data): + def set_data(self, data, reset_vmin_vmax: bool = True): hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) line_data = np.column_stack([hist_scaled, edges_flanked]) @@ -127,12 +195,19 @@ def set_data(self, data): bounds = (edges[0], edges[-1]) limits = (edges_flanked[0], edges_flanked[-11]) origin = (hist_scaled.max() / 2, 0) - - self.linear_region.limits = limits - self.linear_region.selection = bounds # self.linear_region.fill.world.position = (*origin, -2) - # def nbins(self): + if reset_vmin_vmax: + # reset according to the new data + self.linear_region.limits = limits + self.linear_region.selection = bounds + else: + # don't change the current selection + self._block_events(True) + self.linear_region.limits = limits + self._block_events(False) + + self._data = weakref.proxy(data) @property def image_graphic(self) -> ImageGraphic: @@ -145,4 +220,16 @@ def image_graphic(self, graphic): f"HistogramLUT can only use ImageGraphic types, you have passed: {type(graphic)}" ) + # cleanup events from current image graphic + self._image_graphic.cmap.remove_event_handler( + self._image_cmap_handler + ) + self._image_graphic = graphic + + self.image_graphic.cmap.add_event_handler(self._image_cmap_handler) + + def _cleanup(self): + self.linear_region._cleanup() + del self.line + del self.linear_region diff --git a/fastplotlib/widgets/image.py b/fastplotlib/widgets/image.py index cc857a07a..d421de28b 100644 --- a/fastplotlib/widgets/image.py +++ b/fastplotlib/widgets/image.py @@ -9,7 +9,6 @@ VBox, HBox, Layout, - FloatRangeSlider, Button, BoundedIntText, Play, @@ -21,6 +20,7 @@ from ..layouts import GridPlot from ..graphics import ImageGraphic from ..utils import quick_min_max, calculate_gridshape +from .histogram_lut import HistogramLUT DEFAULT_DIMS_ORDER = { @@ -220,7 +220,6 @@ def __init__( slider_dims: Union[str, int, List[Union[str, int]]] = None, window_funcs: Union[int, Dict[str, int]] = None, frame_apply: Union[callable, Dict[int, callable]] = None, - vmin_vmax_sliders: bool = False, grid_shape: Tuple[int, int] = None, names: List[str] = None, grid_plot_kwargs: dict = None, @@ -527,8 +526,6 @@ def __init__( # current_index stores {dimension_index: slice_index} for every dimension self._current_index: Dict[str, int] = {sax: 0 for sax in self.slider_dims} - self.vmin_vmax_sliders: List[FloatRangeSlider] = list() - # get max bound for all data arrays for all dimensions self._dims_max_bounds: Dict[str, int] = {k: np.inf for k in self.slider_dims} for _dim in list(self._dims_max_bounds.keys()): @@ -543,34 +540,10 @@ def __init__( self._gridplot: GridPlot = GridPlot(shape=grid_shape, **grid_plot_kwargs) for data_ix, (d, subplot) in enumerate(zip(self.data, self.gridplot)): - minmax = quick_min_max(self.data[data_ix]) - if self._names is not None: name = self._names[data_ix] - name_slider = name else: name = None - name_slider = "" - - if vmin_vmax_sliders: - data_range = np.ptp(minmax) - data_range_40p = np.ptp(minmax) * 0.4 - - minmax_slider = FloatRangeSlider( - value=minmax, - min=minmax[0] - data_range_40p, - max=minmax[1] + data_range_40p, - step=data_range / 150, - description=f"mm: {name_slider}", - readout=True, - readout_format=".3f", - ) - - minmax_slider.observe( - partial(self._vmin_vmax_slider_changed, data_ix), names="value" - ) - - self.vmin_vmax_sliders.append(minmax_slider) frame = self._process_indices(d, slice_indices=self._current_index) frame = self._process_frame_apply(frame, data_ix) @@ -579,6 +552,17 @@ def __init__( subplot.name = name subplot.set_title(name) + hlut = HistogramLUT( + data=d, + image_graphic=ig, + name="histogram_lut" + ) + + subplot.docks["right"].add_graphic(hlut) + subplot.docks["right"].size = 50 + subplot.docks["right"].auto_scale(maintain_aspect=False) + subplot.docks["right"].controller.enabled = False + self.gridplot.renderer.add_event_handler(self._set_slider_layout, "resize") for sdm in self.slider_dims: @@ -601,7 +585,7 @@ def __init__( # TODO: So just stack everything vertically for now self._vbox_sliders = VBox( - [*list(self._sliders.values()), *self.vmin_vmax_sliders] + [*list(self._sliders.values())] ) @property @@ -795,45 +779,11 @@ def _slider_value_changed(self, dimension: str, change: dict): return self.current_index = {dimension: change["new"]} - def _vmin_vmax_slider_changed(self, data_ix: int, change: dict): - vmin, vmax = change["new"] - self.managed_graphics[data_ix].cmap.vmin = vmin - self.managed_graphics[data_ix].cmap.vmax = vmax - def _set_slider_layout(self, *args): w, h = self.gridplot.renderer.logical_size for k, v in self.sliders.items(): v.layout = Layout(width=f"{w}px") - for mm in self.vmin_vmax_sliders: - mm.layout = Layout(width=f"{w}px") - - def _get_vmin_vmax_range(self, data: np.ndarray) -> tuple: - """ - Parameters - ---------- - data - - Returns - ------- - Tuple[Tuple[float, float], float, float, float] - (min, max), data_range, min - (data_range * 0.4), max + (data_range * 0.4) - """ - - minmax = quick_min_max(data) - - data_range = np.ptp(minmax) - data_range_40p = data_range * 0.4 - - _range = ( - minmax, - data_range, - minmax[0] - data_range_40p, - minmax[1] + data_range_40p, - ) - - return _range - def reset_vmin_vmax(self): """ Reset the vmin and vmax w.r.t. the currently displayed image(s) @@ -914,6 +864,9 @@ def set_data( if new_array.ndim > 3: # tzxy max_lengths["z"] = min(max_lengths["z"], new_array.shape[1] - 1) + # set histogram widget + subplot.docks["right"]["histogram_lut"].set_data(new_array, reset_vmin_vmax=reset_vmin_vmax) + # set slider maxes # TODO: maybe make this stuff a property, like ndims, n_frames etc. and have it set the sliders for key in self.sliders.keys(): @@ -923,8 +876,8 @@ def set_data( # force graphics to update self.current_index = self.current_index - if reset_vmin_vmax: - self.reset_vmin_vmax() + # if reset_vmin_vmax: + # self.reset_vmin_vmax() def show(self, toolbar: bool = True, sidecar: bool = True, sidecar_kwargs: dict = None): """ From 724da88591f5387f650d9d60ca219867f0dd89ec Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 27 Oct 2023 21:40:13 -0400 Subject: [PATCH 6/6] fix image widget example nb --- examples/notebooks/image_widget.ipynb | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/notebooks/image_widget.ipynb b/examples/notebooks/image_widget.ipynb index 5b7de6145..449f13229 100644 --- a/examples/notebooks/image_widget.ipynb +++ b/examples/notebooks/image_widget.ipynb @@ -44,7 +44,6 @@ "source": [ "iw = ImageWidget(\n", " data=a,\n", - " vmin_vmax_sliders=True,\n", " cmap=\"viridis\"\n", ")" ] @@ -113,7 +112,6 @@ "iw = ImageWidget(\n", " data=a, \n", " slider_dims=[\"t\"],\n", - " vmin_vmax_sliders=True,\n", " cmap=\"gnuplot2\"\n", ")" ] @@ -247,7 +245,6 @@ " data=data, \n", " slider_dims=[\"t\"], \n", " # dims_order=\"txy\", # you can set this manually if dim order is not the usual\n", - " vmin_vmax_sliders=True,\n", " names=[\"zero\", \"one\", \"two\", \"three\"],\n", " window_funcs={\"t\": (np.mean, 5)},\n", " cmap=\"gnuplot2\", \n", @@ -338,7 +335,6 @@ " data=data, \n", " slider_dims=[\"t\", \"z\"], \n", " dims_order=\"xyzt\", # example of how you can set this for non-standard orders\n", - " vmin_vmax_sliders=True,\n", " names=[\"zero\", \"one\", \"two\", \"three\"],\n", " # window_funcs={\"t\": (np.mean, 5)}, # window functions can be slow when indexing multiple dims\n", " cmap=\"gnuplot2\", \n", @@ -402,7 +398,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.11.2" } }, "nbformat": 4,