From b11fb2f2618bf7e628ee964dc6834ddb548ba2ce Mon Sep 17 00:00:00 2001 From: David Stansby Date: Tue, 7 Nov 2017 10:37:35 +0000 Subject: [PATCH] Backport PR #9705: Fix scatterplot categorical support --- lib/matplotlib/category.py | 6 ++++++ lib/matplotlib/tests/test_category.py | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index d043c5b154a5..d2754d32fd3a 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -40,6 +40,12 @@ def convert(value, unit, axis): """Uses axis.unit_data map to encode data as floats """ + value = np.atleast_1d(value) + # try and update from here.... + if hasattr(axis.unit_data, 'update'): + for val in value: + if isinstance(val, six.string_types): + axis.unit_data.update(val) vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs)) if isinstance(value, six.string_types): diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 6e5c43d76fb9..7156dc59933c 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -243,3 +243,16 @@ def test_plot_update(self): unit_data = MockUnitData(list(zip(labels, ticks))) self.axis_test(ax.yaxis, ticks, labels, unit_data) + + def test_scatter_update(self): + fig, ax = plt.subplots() + + ax.scatter(['a', 'b'], [0., 3.]) + ax.scatter(['a', 'b', 'd'], [1., 2., 3.]) + ax.scatter(['b', 'c', 'd'], [4., 1., 2.]) + fig.canvas.draw() + + labels = ['a', 'b', 'd', 'c'] + ticks = [0, 1, 2, 3] + unit_data = MockUnitData(list(zip(labels, ticks))) + self.axis_test(ax.xaxis, ticks, labels, unit_data)