diff --git a/docs/source/api/widgets/ImageWidget.rst b/docs/source/api/widgets/ImageWidget.rst index 08bce8d7a..2b4708007 100644 --- a/docs/source/api/widgets/ImageWidget.rst +++ b/docs/source/api/widgets/ImageWidget.rst @@ -23,9 +23,11 @@ Properties ImageWidget.cmap ImageWidget.current_index ImageWidget.data - ImageWidget.dims_order + ImageWidget.frame_apply ImageWidget.gridplot ImageWidget.managed_graphics + ImageWidget.n_img_dims + ImageWidget.n_scrollable_dims ImageWidget.ndim ImageWidget.slider_dims ImageWidget.sliders diff --git a/examples/notebooks/image_widget.ipynb b/examples/notebooks/image_widget.ipynb index 56d5c8a81..a7527601a 100644 --- a/examples/notebooks/image_widget.ipynb +++ b/examples/notebooks/image_widget.ipynb @@ -115,7 +115,6 @@ "source": [ "iw_movie = ImageWidget(\n", " data=gray_movie, \n", - " slider_dims=[\"t\"],\n", " cmap=\"gray\"\n", ")" ] diff --git a/examples/notebooks/image_widget_test.ipynb b/examples/notebooks/image_widget_test.ipynb index c236ce9b7..39cf0b887 100644 --- a/examples/notebooks/image_widget_test.ipynb +++ b/examples/notebooks/image_widget_test.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "07019035-83f2-4753-9e7c-628ae439b441", "metadata": { "tags": [] @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "10b8ab40-944d-472c-9b7e-cae8a129e7ce", "metadata": {}, "outputs": [], @@ -130,7 +130,6 @@ "source": [ "iw_movie = ImageWidget(\n", " data=gray_movie, \n", - " slider_dims=[\"t\"],\n", " cmap=\"gray\",\n", " grid_plot_kwargs={\"size\": (900, 600)},\n", ")" @@ -275,9 +274,6 @@ "execution_count": null, "id": "76535d56-e514-4c16-aa48-a6359f8019d5", "metadata": { - "jupyter": { - "source_hidden": true - }, "tags": [] }, "outputs": [], @@ -444,23 +440,66 @@ "iw_z.close()" ] }, + { + "cell_type": "markdown", + "id": "6716f255-44c2-400d-a2bf-254683e4cd9d", + "metadata": {}, + "source": [ + "# Test Mixed Shapes, RGB (and set data)" + ] + }, { "cell_type": "code", - "execution_count": null, - "id": "870627ef-09d8-44e4-8952-aedb702d1526", + "execution_count": 30, + "id": "ed783360-992d-40f8-bb6f-152a59edff43", "metadata": {}, "outputs": [], "source": [ - "notebook_finished()" + "zfish_data = np.load(\"./zfish_test.npy\")\n", + "zfish_frame_1 = zfish_data[0, 0, :, :]\n", + "zfish_frame_2 = zfish_data[20, 3, :, :]\n", + "movie = iio.imread(\"imageio:cockatoo.mp4\")\n", + "\n", + "iw_mixed_shapes = ImageWidget(\n", + " data=[zfish_frame_1, movie], # you can also provide a list of tzxy arrays\n", + " rgb=[False, True],\n", + " histogram_widget=True,\n", + " cmap=\"gnuplot2\", \n", + " grid_plot_kwargs = {\"controller_ids\": None},\n", + ")\n", + "\n", + "iw_mixed_shapes.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "274c67b4-aa07-4fcf-a094-1b1e70d0378a", + "metadata": {}, + "outputs": [], + "source": [ + "iw_mixed_shapes.sliders[\"t\"].value = 50\n", + "plot_test(\"image-widget-zfish-mixed-rgb-cockatoo-frame-50\", iw_mixed_shapes.gridplot)\n", + "\n", + "#Set the data, changing the first array and also the size of the \"T\" slider\n", + "iw_mixed_shapes.set_data([zfish_frame_2, movie[:200, :, :, :]], reset_indices=True)\n", + "plot_test(\"image-widget-zfish-mixed-rgb-cockatoo-set-data\", iw_mixed_shapes.gridplot)\n", + "\n", + "#Check how a window function might work on the RGB data\n", + "iw_mixed_shapes.window_funcs = {\"t\": (np.mean, 4)}\n", + "iw_mixed_shapes.sliders[\"t\"].value = 20\n", + "plot_test(\"image-widget-zfish-mixed-rgb-cockatoo-windowrgb\", iw_mixed_shapes.gridplot)" ] }, { "cell_type": "code", "execution_count": null, - "id": "b8fff1a6-119e-4f03-ba3a-4c7b9e8c212b", + "id": "870627ef-09d8-44e4-8952-aedb702d1526", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "notebook_finished()" + ] } ], "metadata": { @@ -479,7 +518,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/examples/notebooks/screenshots/nb-image-widget-zfish-mixed-rgb-cockatoo-frame-50.png b/examples/notebooks/screenshots/nb-image-widget-zfish-mixed-rgb-cockatoo-frame-50.png new file mode 100644 index 000000000..5e0750ac8 --- /dev/null +++ b/examples/notebooks/screenshots/nb-image-widget-zfish-mixed-rgb-cockatoo-frame-50.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f8f74a0a5fa24e10a88d3723836306913243fa5fc23f46f44bbdae4c0209075 +size 58878 diff --git a/examples/notebooks/screenshots/nb-image-widget-zfish-mixed-rgb-cockatoo-set-data.png b/examples/notebooks/screenshots/nb-image-widget-zfish-mixed-rgb-cockatoo-set-data.png new file mode 100644 index 000000000..8df83fe33 --- /dev/null +++ b/examples/notebooks/screenshots/nb-image-widget-zfish-mixed-rgb-cockatoo-set-data.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0809b2dda0e773b7f100386f97144c40d36d51cd935c86ef1dcd4a938fce3981 +size 56319 diff --git a/examples/notebooks/screenshots/nb-image-widget-zfish-mixed-rgb-cockatoo-windowrgb.png b/examples/notebooks/screenshots/nb-image-widget-zfish-mixed-rgb-cockatoo-windowrgb.png new file mode 100644 index 000000000..5bbefc7ae --- /dev/null +++ b/examples/notebooks/screenshots/nb-image-widget-zfish-mixed-rgb-cockatoo-windowrgb.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a2e2e2cf7ac6be1a4fccec54494c3fd48af673765653675438fa2469c549e90c +size 55055 diff --git a/fastplotlib/widgets/image.py b/fastplotlib/widgets/image.py index acef26a7d..86671b1fc 100644 --- a/fastplotlib/widgets/image.py +++ b/fastplotlib/widgets/image.py @@ -3,20 +3,30 @@ import numpy as np - from ..layouts import GridPlot from ..graphics import ImageGraphic from ..utils import calculate_gridshape from .histogram_lut import HistogramLUT -DEFAULT_DIMS_ORDER = { - 2: "xy", - 3: "txy", - 4: "tzxy", - 5: "tzcxy", +# Number of dimensions that represent one image/one frame. For grayscale shape will be [x, y], i.e. 2 dims, for RGB(A) +# shape will be [x, y, c] where c is of size 3 (RGB) or 4 (RGBA) +IMAGE_DIM_COUNTS = {"gray": 2, "rgb": 3} + +# Map boolean (indicating whether we use RGB or grayscale) to the string. Used to index RGB_DIM_MAP +RGB_BOOL_MAP = {False: "gray", True: "rgb"} + +# Dimensions that can be scrolled from a given data array +SCROLLABLE_DIMS_ORDER = { + 0: "", + 1: "t", + 2: "tz", } +ALLOWED_SLIDER_DIMS = {0: "t", 1: "z"} + +ALLOWED_WINDOW_DIMS = {"t", "z"} + def _is_arraylike(obj) -> bool: """ @@ -149,13 +159,16 @@ def data(self) -> List[np.ndarray]: @property def ndim(self) -> int: - """number of dimensions in the image data displayed in the widget""" + """Number of dimensions of grayscale data displayed in the widget (it will be 1 more for RGB(A) data)""" return self._ndim @property - def dims_order(self) -> List[str]: - """dimension order of the data displayed in the widget""" - return self._dims_order + def n_scrollable_dims(self) -> List[int]: + """ + list indicating the number of dimenensions that are scrollable for each data array + All other dimensions are frame/image data, i.e. [x, y] or [x, y, c] + """ + return self._n_scrollable_dims @property def sliders(self) -> Dict[str, Any]: @@ -184,6 +197,56 @@ def current_index(self) -> Dict[str, int]: """ return self._current_index + @property + def n_img_dims(self) -> list[int]: + """ + list indicating the number of dimensions that contain image/single frame data for each data array. + if 2: data are grayscale, i.e. [x, y] dims, if 3: data are [x, y, c] where c is RGB or RGBA, + this is the complement of `n_scrollable_dims` + """ + return self._n_img_dims + + def _get_n_scrollable_dims(self, curr_arr: np.ndarray, rgb: bool) -> list[int]: + """ + For a given ``array`` displayed in the ImageWidget, this function infers how many of the dimensions are + supported by sliders (aka scrollable). Ex: "xy" data has 0 scrollable dims, "txy" has 1, "tzxy" has 2. + + Parameters + ---------- + curr_arr: np.ndarray + np.ndarray or a list of array-like + + rgb: bool + True if we view this as RGB(A) and False if grayscale + + Returns + ------- + int + Number of scrollable dimensions for each ``array`` in the dataset. + """ + + n_img_dims = IMAGE_DIM_COUNTS[RGB_BOOL_MAP[rgb]] + # Make sure each image stack at least ``n_img_dims`` dimensions + if len(curr_arr.shape) < n_img_dims: + raise ValueError( + f"Your array has shape {curr_arr.shape} " + f"but you specified that each image in your array is {n_img_dims}D " + ) + + # If RGB(A), last dim must be 3 or 4 + if n_img_dims == 3: + if not (curr_arr.shape[-1] == 3 or curr_arr.shape[-1] == 4): + raise ValueError( + f"Expected size 3 or 4 for last dimension of RGB(A) array, got: {curr_arr.shape[-1]}." + ) + + n_scrollable_dims = len(curr_arr.shape) - n_img_dims + + if n_scrollable_dims not in SCROLLABLE_DIMS_ORDER.keys(): + raise ValueError(f"Array had shape {curr_arr.shape} which is not supported") + + return n_scrollable_dims + @current_index.setter def current_index(self, index: Dict[str, int]): # ignore if output context has not been created yet @@ -223,30 +286,29 @@ def current_index(self, index: Dict[str, int]): def __init__( self, data: Union[np.ndarray, List[np.ndarray]], - dims_order: Union[str, Dict[int, str]] = None, - slider_dims: Union[str, int, List[Union[str, int]]] = None, window_funcs: Union[int, Dict[str, int]] = None, frame_apply: Union[callable, Dict[int, callable]] = None, grid_shape: Tuple[int, int] = None, names: List[str] = None, grid_plot_kwargs: dict = None, histogram_widget: bool = True, + rgb: list[bool] = None, **kwargs, ): """ - A high level widget for displaying n-dimensional image data in conjunction with automatically generated - sliders for navigating through 1-2 selected dimensions within image data. - - Can display a single n-dimensional image array or a grid of n-dimensional images. + This widget facilitates high-level navigation through image stacks, which are arrays containing one or more + images. It includes sliders for key dimensions such as "t" (time) and "z", enabling users to smoothly navigate + through one or multiple image stacks simultaneously. - Default dimension orders: + Allowed dimensions orders for each image stack: Note that each has a an optional (c) channel which refers to + RGB(A) a channel. So this channel should be either 3 or 4. ======= ========== n_dims dims order ======= ========== - 2 "xy" - 3 "txy" - 4 "tzxy" + 2 "xy(c)" + 3 "txy(c)" + 4 "tzxy(c)" ======= ========== Parameters @@ -254,31 +316,18 @@ def __init__( data: Union[np.ndarray, List[np.ndarray] array-like or a list of array-like - dims_order: Optional[Union[str, Dict[np.ndarray, str]]] - | ``str`` or a dict mapping to indicate dimension order - | a single ``str`` if ``data`` is a single array, or a list of arrays with the same dimension order - | examples: ``"xyt"``, ``"tzxy"`` - | ``dict`` mapping of ``{array_index: axis_order}`` if specific arrays have a non-default axes order. - | "array_index" is the position of the corresponding array in the data list. - | examples: ``{array_index: "tzxy", another_array_index: "xytz"}`` - - slider_dims: Optional[Union[str, int, List[Union[str, int]]]] - | The dimensions for which to create a slider - | can be a single ``str`` such as **"t"**, **"z"** or a numerical ``int`` that indexes the desired dimension - | can also be a list of ``str`` or ``int`` if multiple sliders are desired for multiple dimensions - | examples: ``"t"``, ``["t", "z"]`` - window_funcs: Dict[Union[int, str], int] - | average one or more dimensions using a given window - | if a slider exists for only one dimension this can be an ``int``. - | if multiple sliders exist, then it must be a `dict`` mapping in the form of: ``{dimension: window_size}`` - | dimension/axes can be specified using ``str`` such as "t", "z" etc. or ``int`` that indexes the dimension - | if window_size is not an odd number, adds 1 - | use ``None`` to disable averaging for a dimension, example: ``{"t": 5, "z": None}`` + | Apply function(s) with rolling windows along "t" and/or "z" dimensions of the `data` arrays. + | Pass a dict in the form: {dimension: (func, window_size)}, `func` must take a slice of the data array as the + | first argument and must take `axis` as a kwarg. + | Ex: mean along "t" dimension: {"t": (np.mean, 11)}, if `current_index` of "t" is 50, it will pass frames + | 45 to 55 to `np.mean` with `axis = 0`. + | Ex2: max along z dim: {"z": (np.max, 3)}, passes current, previous and next frame to `np.max` with `axis = 1` frame_apply: Union[callable, Dict[int, callable]] - | apply a function to slices of the array before displaying the frame - | pass a single function or a dict of functions to apply to each array individually + | Apply function(s) to `data` arrays before to generate final 2D image that is displayed. + | Ex: apply a spatial Gaussian filter + | Pass a single function or a dict of functions to apply to each array individually | examples: ``{array_index: to_grayscale}``, ``{0: to_grayscale, 2: threshold_img}`` | "array_index" is the position of the corresponding array in the data list. | if `window_funcs` is used, then this function is applied after `window_funcs` @@ -297,19 +346,27 @@ def __init__( histogram_widget: bool, default False make histogram LUT widget for each subplot + rgb: bool | list[bool], default None + Includes a True or False for each ``array`` in the ImageWidget, indicating whether images are displayed as + grayscale or RGB(A). + kwargs: Any passed to fastplotlib.graphics.Image """ - self._names = None # output context self._output = None + if _is_arraylike(data): + data = [data] + if isinstance(data, list): # verify that it's a list of np.ndarray if all([_is_arraylike(d) for d in data]): + + # Grid computations if grid_shape is None: grid_shape = calculate_gridshape(len(data)) @@ -320,17 +377,44 @@ def __init__( f"Invalid `grid_shape` passed, setting grid shape to: {grid_shape}" ) - _ndim = [d.ndim for d in data] + self._data: List[np.ndarray] = data - # verify that all image arrays have same number of dimensions - # sliders get messy otherwise - if not len(set(_ndim)) == 1: + # Establish number of image dimensions and number of scrollable dimensions for each array + if rgb is None: + rgb = [False] * len(self.data) + if rgb is bool: + rgb = [rgb] + if not isinstance(rgb, list): + raise TypeError( + f"rgb_disp parameter must be a list, a {type(rgb)} was provided" + ) + if not len(rgb) == len(self.data): raise ValueError( - f"Number of dimensions of all data arrays must match, your ndims are: {_ndim}" + f"rgb had length {len(rgb)} but there are {len(self.data)} data arrays; these must be equal" ) - self._data: List[np.ndarray] = data - self._ndim = self.data[0].ndim # all ndim must be same + self._rgb = rgb + + self._n_img_dims = [ + IMAGE_DIM_COUNTS[RGB_BOOL_MAP[self._rgb[i]]] + for i in range(len(self.data)) + ] + + self._n_scrollable_dims = [ + self._get_n_scrollable_dims(self.data[i], self._rgb[i]) + for i in range(len(self.data)) + ] + + # Define ndim of ImageWidget instance as largest number of scrollable dims + 2 (grayscale dimensions) + self._ndim = ( + max( + [ + self.n_scrollable_dims[i] + for i in range(len(self.n_scrollable_dims)) + ] + ) + + IMAGE_DIM_COUNTS[RGB_BOOL_MAP[False]] + ) if names is not None: if not all([isinstance(n, str) for n in names]): @@ -351,12 +435,6 @@ def __init__( f"You have passed the following types:\n" f"{[type(a) for a in data]}" ) - - elif _is_arraylike(data): - self._data = [data] - self._ndim = self.data[0].ndim - - grid_shape = calculate_gridshape(len(self._data)) else: raise TypeError( f"`data` must be an array-like type representing an n-dimensional image " @@ -364,149 +442,20 @@ def __init__( f"You have passed the following type {type(data)}" ) - # default dims order if not passed - # updated later if passed - self._dims_order: List[str] = [DEFAULT_DIMS_ORDER[self.ndim]] * len(self.data) - - if dims_order is not None: - if isinstance(dims_order, str): - dims_order = dims_order.lower() - if len(dims_order) != self.ndim: - raise ValueError( - f"number of dims '{len(dims_order)} passed to `dims_order` " - f"does not match ndim '{self.ndim}' of data" - ) - self._dims_order: List[str] = [dims_order] * len(self.data) - elif isinstance(dims_order, dict): - self._dims_order: List[str] = [DEFAULT_DIMS_ORDER[self.ndim]] * len( - self.data - ) - - # dict of {array_ix: dims_order_str} - for data_ix in list(dims_order.keys()): - if not isinstance(data_ix, int): - raise TypeError("`dims_order` dict keys must be ") - if len(dims_order[data_ix]) != self.ndim: - raise ValueError( - f"number of dims '{len(dims_order)} passed to `dims_order` " - f"does not match ndim '{self.ndim}' of data" - ) - _do = dims_order[data_ix].lower() - # make sure the same dims are present - if not set(_do) == set(DEFAULT_DIMS_ORDER[self.ndim]): - raise ValueError( - f"Invalid `dims_order` passed for one of your arrays, " - f"valid `dims_order` for given number of dimensions " - f"can only contain the following characters: " - f"{DEFAULT_DIMS_ORDER[self.ndim]}" - ) - try: - self.dims_order[data_ix] = _do - except Exception: - raise IndexError( - f"index {data_ix} out of bounds for `dims_order`, the bounds are 0 - {len(self.data)}" - ) - else: - raise TypeError( - f"`dims_order` must be a or , you have passed a: <{type(dims_order)}>" - ) - - if not len(self.dims_order[0]) == self.ndim: - raise ValueError( - f"Number of dims specified by `dims_order`: {len(self.dims_order[0])} does not" - f" match number of dimensions in the `data`: {self.ndim}" - ) - - ao = np.array([sorted(v) for v in self.dims_order]) - - if not np.all(ao == ao[0]): - raise ValueError( - f"`dims_order` for all arrays must contain the same combination of dimensions, your `dims_order` are: " - f"{self.dims_order}" - ) - - # if slider_dims not provided - if slider_dims is None: - # by default sliders are made for all dimensions except the last 2 - default_dim_names = {0: "t", 1: "z", 2: "c"} - slider_dims = list() - for dim in range(self.ndim - 2): - if dim in default_dim_names.keys(): - slider_dims.append(default_dim_names[dim]) - else: - slider_dims.append(f"{dim}") - - # slider for only one of the dimensions - if isinstance(slider_dims, (int, str)): - # if numerical dimension is specified - if isinstance(slider_dims, int): - ao = np.array([v for v in self.dims_order]) - if not np.all(ao == ao[0]): - raise ValueError( - f"`dims_order` for all arrays must be identical if passing in a `slider_dims` argument. " - f"Pass in a argument if the `dims_order` are different for each array." - ) - self._slider_dims: List[str] = [self.dims_order[0][slider_dims]] - - # if dimension specified by str - elif isinstance(slider_dims, str): - if slider_dims not in self.dims_order[0]: - raise ValueError( - f"if `slider_dims` is a , it must be a character found in `dims_order`. " - f"Your `dims_order` characters are: {set(self.dims_order[0])}." - ) - self._slider_dims: List[str] = [slider_dims] - - # multiple sliders, one for each dimension - elif isinstance(slider_dims, list): - self._slider_dims: List[str] = list() - - # make sure window_funcs and frame_apply are dicts if multiple sliders are desired - if (not isinstance(window_funcs, dict)) and (window_funcs is not None): - raise TypeError( - f"`window_funcs` must be a if multiple `slider_dims` are provided. You must specify the " - f"window for each dimension." - ) - if (not isinstance(frame_apply, dict)) and (frame_apply is not None): - raise TypeError( - f"`frame_apply` must be a if multiple `slider_dims` are provided. You must specify a " - f"function for each dimension." - ) - - for sdm in slider_dims: - if isinstance(sdm, int): - ao = np.array([v for v in self.dims_order]) - if not np.all(ao == ao[0]): - raise ValueError( - f"`dims_order` for all arrays must be identical if passing in a `slider_dims` argument. " - f"Pass in a argument if the `dims_order` are different for each array." - ) - # parse int to a str - self.slider_dims.append(self.dims_order[0][sdm]) - - elif isinstance(sdm, str): - if sdm not in self.dims_order[0]: - raise ValueError( - f"if `slider_dims` is a , it must be a character found in `dims_order`. " - f"Your `dims_order` characters are: {set(self.dims_order[0])}." - ) - self.slider_dims.append(sdm) - - else: - raise TypeError( - "If passing a list for `slider_dims` each element must be either an or " - ) - - else: - raise TypeError( - f"`slider_dims` must a , or , you have passed a: {type(slider_dims)}" - ) + # Sliders are made for all dimensions except the image dimensions + self._slider_dims = list() + max_scrollable = max( + [self.n_scrollable_dims[i] for i in range(len(self.n_scrollable_dims))] + ) + for dim in range(max_scrollable): + if dim in ALLOWED_SLIDER_DIMS.keys(): + self.slider_dims.append(ALLOWED_SLIDER_DIMS[dim]) self._frame_apply: Dict[int, callable] = dict() if frame_apply is not None: if callable(frame_apply): - self._frame_apply = {0: frame_apply} + self._frame_apply = frame_apply elif isinstance(frame_apply, dict): self._frame_apply: Dict[int, callable] = dict.fromkeys( @@ -537,13 +486,19 @@ def __init__( self._sliders: Dict[str, Any] = dict() - # get max bound for all data arrays for all dimensions - self._dims_max_bounds: Dict[str, int] = {k: np.inf for k in self.slider_dims} - for _dim in list(self._dims_max_bounds.keys()): - for array, order in zip(self.data, self.dims_order): - self._dims_max_bounds[_dim] = min( - self._dims_max_bounds[_dim], array.shape[order.index(_dim)] - ) + # get max bound for all data arrays for all slider dimensions and ensure compatibility across slider dims + self._dims_max_bounds: Dict[str, int] = {k: 0 for k in self.slider_dims} + for i, _dim in enumerate(list(self._dims_max_bounds.keys())): + for array, partition in zip(self.data, self.n_scrollable_dims): + if partition <= i: + continue + else: + if 0 < self._dims_max_bounds[_dim] != array.shape[i]: + raise ValueError(f"Two arrays differ along dimension {_dim}") + else: + self._dims_max_bounds[_dim] = max( + self._dims_max_bounds[_dim], array.shape[i] + ) grid_plot_kwargs_default = {"controller_ids": "sync"} if grid_plot_kwargs is None: @@ -557,6 +512,7 @@ def __init__( shape=grid_shape, **grid_plot_kwargs_default ) + self._histogram_widget = histogram_widget for data_ix, (d, subplot) in enumerate(zip(self.data, self.gridplot)): if self._names is not None: name = self._names[data_ix] @@ -570,7 +526,7 @@ def __init__( subplot.name = name subplot.set_title(name) - if histogram_widget: + if self._histogram_widget: hlut = HistogramLUT(data=d, image_graphic=ig, name="histogram_lut") subplot.docks["right"].add_graphic(hlut) @@ -607,62 +563,54 @@ def window_funcs(self) -> Dict[str, _WindowFunctions]: return self._window_funcs @window_funcs.setter - def window_funcs(self, sa: Union[int, Dict[str, int]]): - if sa is None: + def window_funcs(self, callable_dict: Dict[str, int]): + if callable_dict is None: self._window_funcs = None # force frame to update self.current_index = self.current_index return - # for a single dim - elif isinstance(sa, tuple): - if len(self.slider_dims) > 1: - raise TypeError( - "Must pass dict argument to window_funcs if using multiple sliders. See the docstring." - ) - if not callable(sa[0]) or not isinstance(sa[1], int): - raise TypeError( - "Tuple argument to `window_funcs` must be in the form of (func, window_size). See the docstring." + elif isinstance(callable_dict, dict): + if not set(callable_dict.keys()).issubset(ALLOWED_WINDOW_DIMS): + raise ValueError( + f"The only allowed keys to window funcs are {list(ALLOWED_WINDOW_DIMS)} " + f"Your window func passed in these keys: {list(callable_dict.keys())}" ) - - dim_str = self.slider_dims[0] - self._window_funcs = dict() - self._window_funcs[dim_str] = _WindowFunctions(self, *sa) - - # for multiple dims - elif isinstance(sa, dict): if not all( - [isinstance(_sa, tuple) or (_sa is None) for _sa in sa.values()] + [ + isinstance(_callable_dict, tuple) + for _callable_dict in callable_dict.values() + ] ): raise TypeError( "dict argument to `window_funcs` must be in the form of: " "`{dimension: (func, window_size)}`. " "See the docstring." ) - for v in sa.values(): - if v is not None: - if not callable(v[0]) or not ( - isinstance(v[1], int) or v[1] is None - ): - raise TypeError( - "dict argument to `window_funcs` must be in the form of: " - "`{dimension: (func, window_size)}`. " - "See the docstring." - ) + for v in callable_dict.values(): + if not callable(v[0]): + raise TypeError( + "dict argument to `window_funcs` must be in the form of: " + "`{dimension: (func, window_size)}`. " + "See the docstring." + ) + if not isinstance(v[1], int): + raise TypeError( + f"dict argument to `window_funcs` must be in the form of: " + "`{dimension: (func, window_size)}`. " + f"where window_size is integer. you passed in {v[1]} for window_size" + ) if not isinstance(self._window_funcs, dict): self._window_funcs = dict() - for k in list(sa.keys()): - if sa[k] is None: - self._window_funcs[k] = None - else: - self._window_funcs[k] = _WindowFunctions(self, *sa[k]) + for k in list(callable_dict.keys()): + self._window_funcs[k] = _WindowFunctions(self, *callable_dict[k]) else: raise TypeError( - f"`window_funcs` must be of type `int` if using a single slider or a dict if using multiple sliders. " - f"You have passed a {type(sa)}. See the docstring." + f"`window_funcs` must be either Nonetype or dict." + f"You have passed a {type(callable_dict)}. See the docstring." ) # force frame to update @@ -684,7 +632,7 @@ def _process_indices( dict in form of {dimension_index: slice_index} For example if an array has shape [1000, 30, 512, 512] corresponding to [t, z, x, y]: To get the 100th timepoint and 3rd z-plane pass: - {"t": 100, "z": 3}, or {0: 100, 1: 3} + {"t": 100, "z": 3} Returns ------- @@ -692,26 +640,32 @@ def _process_indices( array-like, 2D slice """ - indexer = [slice(None)] * self.ndim + + data_ix = None + for i in range(len(self.data)): + if self.data[i] is array: + data_ix = i + break numerical_dims = list() + + # Totally number of dimensions for this specific array + curr_ndim = self.data[data_ix].ndim + + # Initialize slices for each dimension of array + indexer = [slice(None)] * curr_ndim + + # Maps from n_scrollable_dims to one of "", "t", "tz", etc. + curr_scrollable_format = SCROLLABLE_DIMS_ORDER[self.n_scrollable_dims[data_ix]] for dim in list(slice_indices.keys()): - if isinstance(dim, str): - data_ix = None - for i in range(len(self.data)): - if self.data[i] is array: - data_ix = i - break - if data_ix is None: - raise ValueError(f"Given `array` not found in `self.data`") - # get axes order for that specific array - numerical_dim = self.dims_order[data_ix].index(dim) - else: - numerical_dim = dim + if dim not in curr_scrollable_format: + continue + # get axes order for that specific array + numerical_dim = curr_scrollable_format.index(dim) indices_dim = slice_indices[dim] - # takes care of averaging if it was specified + # takes care of index selection (window slicing) for this specific axis indices_dim = self._get_window_indices(data_ix, numerical_dim, indices_dim) # set the indices for this dimension @@ -724,9 +678,9 @@ def _process_indices( if self.window_funcs is not None: a = array for i, dim in enumerate(sorted(numerical_dims)): - dim_str = self.dims_order[data_ix][dim] + dim_str = curr_scrollable_format[dim] dim = dim - i # since we loose a dimension every iteration - _indexer = [slice(None)] * (self.ndim - i) + _indexer = [slice(None)] * (curr_ndim - i) _indexer[dim] = indexer[dim + i] # if the indexer is an int, this dim has no window func @@ -737,7 +691,6 @@ def _process_indices( func = self.window_funcs[dim_str].func window = a[tuple(_indexer)] a = func(window, axis=dim) - # a = np.mean(a[tuple(_indexer)], axis=dim) return a else: return array[tuple(indexer)] @@ -749,7 +702,7 @@ def _get_window_indices(self, data_ix, dim, indices_dim): else: ix = indices_dim - dim_str = self.dims_order[data_ix][dim] + dim_str = SCROLLABLE_DIMS_ORDER[self.n_scrollable_dims[data_ix]][dim] # if no window stuff specified for this dim if dim_str not in self.window_funcs.keys(): @@ -848,9 +801,11 @@ def set_data( self.sliders[key].value = 0 # set slider max according to new data - max_lengths = {"t": np.inf, "z": np.inf} + max_lengths = dict() + for scroll_dim in self.slider_dims: + max_lengths[scroll_dim] = np.inf - if isinstance(new_data, np.ndarray): + if _is_arraylike(new_data): new_data = [new_data] if len(self._data) != len(new_data): @@ -866,16 +821,24 @@ def set_data( f"does not equal current data ndim {current_array.ndim}" ) + # Computes the number of scrollable dims and also validates new_array + new_scrollable_dims = self._get_n_scrollable_dims(new_array, self._rgb[i]) + + if self.n_scrollable_dims[i] != new_scrollable_dims: + raise ValueError( + f"number of dimensions of data arrays must match number of dimensions of " + f"existing data arrays" + ) + # if checks pass, update with new data for i, (new_array, current_array, subplot) in enumerate( zip(new_data, self._data, self.gridplot) ): # check last two dims (x and y) to see if data shape is changing - old_data_shape = self._data[i].shape[-2:] + old_data_shape = self._data[i].shape[-self.n_img_dims[i] :] self._data[i] = new_array - if old_data_shape != new_array.shape[-2:]: - # make a new graphic with the new xy dims + if old_data_shape != new_array.shape[-self.n_img_dims[i] :]: frame = self._process_indices( new_array, slice_indices=self._current_index ) @@ -886,23 +849,31 @@ def set_data( # set hlut tool to use new graphic subplot.docks["right"]["histogram_lut"].image_graphic = new_graphic - # delete old graphic after setting hlut tool to new graphic # this ensures gc subplot.delete_graphic(graphic=subplot["image_widget_managed"]) subplot.insert_graphic(graphic=new_graphic) - if new_array.ndim > 2: - # to set max of time slider, txy or tzxy - max_lengths["t"] = min(max_lengths["t"], new_array.shape[0] - 1) - - if new_array.ndim > 3: # tzxy - max_lengths["z"] = min(max_lengths["z"], new_array.shape[1] - 1) + # Returns "", "t", or "tz" + curr_scrollable_format = SCROLLABLE_DIMS_ORDER[self.n_scrollable_dims[i]] + + for scroll_dim in self.slider_dims: + if scroll_dim in curr_scrollable_format: + new_length = new_array.shape[ + curr_scrollable_format.index(scroll_dim) + ] + if max_lengths[scroll_dim] == np.inf: + max_lengths[scroll_dim] = new_length + elif max_lengths[scroll_dim] != new_length: + raise ValueError( + f"New arrays have differing values along dim {scroll_dim}" + ) # set histogram widget - subplot.docks["right"]["histogram_lut"].set_data( - new_array, reset_vmin_vmax=reset_vmin_vmax - ) + if self._histogram_widget: + subplot.docks["right"]["histogram_lut"].set_data( + new_array, reset_vmin_vmax=reset_vmin_vmax + ) # set slider maxes # TODO: maybe make this stuff a property, like ndims, n_frames etc. and have it set the sliders