From 2a54864595b3305aa3ece7c025b972190a55defd Mon Sep 17 00:00:00 2001 From: Augusto Borges Date: Sun, 8 Oct 2023 18:36:20 +0200 Subject: [PATCH] Scatter ravel is performed before _process_unit_info() is called. --- lib/matplotlib/axes/_axes.py | 4 ++-- lib/matplotlib/tests/test_category.py | 34 ++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 0fcabac8c7c0..323bded4669a 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -4657,14 +4657,14 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, # further processed by the rest of the function linewidths = kwargs.pop('linewidth', None) edgecolors = kwargs.pop('edgecolor', None) - # Process **kwargs to handle aliases, conflicts with explicit kwargs: - x, y = self._process_unit_info([("x", x), ("y", y)], kwargs) # np.ma.ravel yields an ndarray, not a masked array, # unless its argument is a masked array. x = np.ma.ravel(x) y = np.ma.ravel(y) if x.size != y.size: raise ValueError("x and y must be the same size") + # Process **kwargs to handle aliases, conflicts with explicit kwargs: + x, y = self._process_unit_info([("x", x), ("y", y)], kwargs) if s is None: s = (20 if mpl.rcParams['_internal.classic_mode'] else diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index fd4aec88b574..a25a786238ff 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -249,11 +249,13 @@ def test_update_plot(self, plotter): failing_test_cases = [("mixed", ['A', 3.14]), ("number integer", ['1', 1]), ("string integer", ['42', 42]), + ("nested categorical", [["a", "b"], ["c", "d"]]), ("missing", ['12', np.nan])] fids, fvalues = zip(*failing_test_cases) - plotters = [Axes.scatter, Axes.bar, + plotters = [pytest.param(Axes.scatter, marks=pytest.mark.xfail), + Axes.bar, pytest.param(Axes.plot, marks=pytest.mark.xfail)] @pytest.mark.parametrize("plotter", plotters) @@ -321,3 +323,33 @@ def test_set_lim(): ax.plot(["a", "b", "c", "d"], [1, 2, 3, 4]) with warnings.catch_warnings(): ax.set_xlim("b", "c") + + +categorical_examples = [("nested categorical", [["a", "b"], ["c", "d"]]), + ("nested with nan", [['0', np.nan], ["aa", "bb"]]), + ("nested mixed", [[1, 'a'], ['b', np.nan]])] +cids, cvalues = zip(*categorical_examples) + + +@pytest.mark.parametrize("xdata", cvalues, ids=cids) +@pytest.mark.parametrize("ydata", cvalues, ids=cids) +def test_nested_categorical(xdata, ydata): + ax = plt.figure().subplots() + ax.scatter(xdata, ydata) + + xtexts = [xelement._text for xelement in ax.get_xticklabels()] + + assert np.all(xtexts == np.ma.ravel(xdata)) + + +@pytest.mark.parametrize("xdata", cvalues, ids=cids) +def test_nested_categorical_and_numerical(xdata): + ydata = [[0, 1], [2, 3]] + ax = plt.figure().subplots() + splot = ax.scatter(xdata, ydata) + + xtexts = [xelement._text for xelement in ax.get_xticklabels()] + y_offset_processed = list(zip(*splot.get_offsets()))[1] + + assert np.all(xtexts == np.ma.ravel(xdata)) + assert np.all(np.ma.ravel(ydata) == y_offset_processed)