diff --git a/CHANGELOG.rst b/CHANGELOG.rst index fafd2e0b..f6199857 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,12 +1,12 @@ -0.0.2 +0.4.0 ===== -New features ------------- -- `HistogramWidget` now shows individual histograms for RGB channels when - present. - - -Bug fixes ---------- -- `HistogramWidget` now works properly with 2D images. +Changes +------- +- The scatter widgets no longer use a LogNorm() for 2D histogram scaling. + This is to move the widget in line with the philosophy of using Matplotlib default + settings throughout ``napari-matplotlib``. This still leaves open the option of + adding the option to change the normalization in the future. If this is something + you would be interested in please open an issue at https://github.com/matplotlib/napari-matplotlib. +- Labels plotting with the features scatter widget no longer have underscores + replaced with spaces. diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index cb1e8498..405b7b09 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,6 +1,5 @@ from typing import Any, List, Optional, Tuple -import matplotlib.colors as mcolor import napari import numpy.typing as npt from magicgui import magicgui @@ -17,15 +16,8 @@ class ScatterBaseWidget(NapariMPLWidget): Base class for widgets that scatter two datasets against each other. """ - # opacity value for the markers - _marker_alpha = 0.5 - - # flag set to True if histogram should be used - # for plotting large points - _histogram_for_large_data = True - # if the number of points is greater than this value, - # the scatter is plotted as a 2dhist + # the scatter is plotted as a 2D histogram _threshold_to_switch_to_histogram = 500 def __init__(self, napari_viewer: napari.viewer.Viewer): @@ -44,40 +36,32 @@ def draw(self) -> None: """ Scatter the currently selected layers. """ - data, x_axis_name, y_axis_name = self._get_data() - - if len(data) == 0: - # don't plot if there isn't data - return + x, y, x_axis_name, y_axis_name = self._get_data() - if self._histogram_for_large_data and ( - data[0].size > self._threshold_to_switch_to_histogram - ): + if x.size > self._threshold_to_switch_to_histogram: self.axes.hist2d( - data[0].ravel(), - data[1].ravel(), + x.ravel(), + y.ravel(), bins=100, - norm=mcolor.LogNorm(), ) else: - self.axes.scatter(data[0], data[1], alpha=self._marker_alpha) + self.axes.scatter(x, y, alpha=0.5) self.axes.set_xlabel(x_axis_name) self.axes.set_ylabel(y_axis_name) - def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: - """Get the plot data. + def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: + """ + Get the plot data. This must be implemented on the subclass. Returns ------- - data : np.ndarray - The list containing the scatter plot data. - x_axis_name : str - The label to display on the x axis - y_axis_name: str - The label to display on the y axis + x, y : np.ndarray + x and y values of plot data. + x_axis_name, y_axis_name : str + Label to display on the x/y axis """ raise NotImplementedError @@ -93,7 +77,7 @@ class ScatterWidget(ScatterBaseWidget): n_layers_input = Interval(2, 2) input_layer_types = (napari.layers.Image,) - def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: + def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: """ Get the plot data. @@ -106,11 +90,12 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: y_axis_name: str The title to display on the y axis """ - data = [layer.data[self.current_z] for layer in self.layers] + x = self.layers[0].data[self.current_z] + y = self.layers[1].data[self.current_z] x_axis_name = self.layers[0].name y_axis_name = self.layers[1].name - return data, x_axis_name, y_axis_name + return x, y, x_axis_name, y_axis_name class FeaturesScatterWidget(ScatterBaseWidget): @@ -191,9 +176,33 @@ def _get_valid_axis_keys( else: return self.layers[0].features.keys() - def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: + def _ready_to_scatter(self) -> bool: """ - Get the plot data. + Return True if selected layer has a feature table we can scatter with, + and the two columns to be scatterd have been selected. + """ + if not hasattr(self.layers[0], "features"): + return False + + feature_table = self.layers[0].features + return ( + feature_table is not None + and len(feature_table) > 0 + and self.x_axis_key is not None + and self.y_axis_key is not None + ) + + def draw(self) -> None: + """ + Scatter two features from the currently selected layer. + """ + if self._ready_to_scatter(): + super().draw() + + def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]: + """ + Get the plot data from the ``features`` attribute of the first + selected layer. Returns ------- @@ -207,28 +216,15 @@ def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: The title to display on the y axis. Returns an empty string if nothing to plot. """ - if not hasattr(self.layers[0], "features"): - # if the selected layer doesn't have a featuretable, - # skip draw - return [], "", "" - feature_table = self.layers[0].features - if ( - (len(feature_table) == 0) - or (self.x_axis_key is None) - or (self.y_axis_key is None) - ): - return [], "", "" - - data_x = feature_table[self.x_axis_key] - data_y = feature_table[self.y_axis_key] - data = [data_x, data_y] + x = feature_table[self.x_axis_key] + y = feature_table[self.y_axis_key] - x_axis_name = self.x_axis_key.replace("_", " ") - y_axis_name = self.y_axis_key.replace("_", " ") + x_axis_name = str(self.x_axis_key) + y_axis_name = str(self.y_axis_key) - return data, x_axis_name, y_axis_name + return x, y, x_axis_name, y_axis_name def _on_update_layers(self) -> None: """ diff --git a/src/napari_matplotlib/tests/test_scatter.py b/src/napari_matplotlib/tests/test_scatter.py index fe07655d..88e0584c 100644 --- a/src/napari_matplotlib/tests/test_scatter.py +++ b/src/napari_matplotlib/tests/test_scatter.py @@ -39,7 +39,9 @@ def make_labels_layer_with_features() -> ( def test_features_scatter_get_data(make_napari_viewer): - """Test the get data method""" + """ + Test the get data method. + """ # make the label image label_image, feature_table = make_labels_layer_with_features() @@ -55,17 +57,16 @@ def test_features_scatter_get_data(make_napari_viewer): y_column = "feature_2" scatter_widget.y_axis_key = y_column - data, x_axis_name, y_axis_name = scatter_widget._get_data() - np.testing.assert_allclose( - data, np.stack((feature_table[x_column], feature_table[y_column])) - ) - assert x_axis_name == x_column.replace("_", " ") - assert y_axis_name == y_column.replace("_", " ") + x, y, x_axis_name, y_axis_name = scatter_widget._get_data() + np.testing.assert_allclose(x, feature_table[x_column]) + np.testing.assert_allclose(y, np.stack(feature_table[y_column])) + assert x_axis_name == x_column + assert y_axis_name == y_column def test_get_valid_axis_keys(make_napari_viewer): - """Test the values returned from - FeaturesScatterWidget._get_valid_keys() when there + """ + Test the values returned from _get_valid_keys() when there are valid keys. """ # make the label image