From 44423046608768462b2e523b437bf292b1f644d8 Mon Sep 17 00:00:00 2001 From: hannah Date: Mon, 13 Nov 2017 17:36:17 -0500 Subject: [PATCH 01/22] MNT: start to re-factor the categorical implementation --- lib/matplotlib/category.py | 111 ++++++------ lib/matplotlib/tests/test_category.py | 244 ++++++++++++++------------ 2 files changed, 185 insertions(+), 170 deletions(-) diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index d2754d32fd3a..9f6d187021ed 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -4,6 +4,11 @@ """ from __future__ import (absolute_import, division, print_function, unicode_literals) + +from collections import Iterable, Sequence, OrderedDict +import itertools +import numbers + import six import numpy as np @@ -13,25 +18,19 @@ # np 1.6/1.7 support from distutils.version import LooseVersion -import collections -if LooseVersion(np.__version__) >= LooseVersion('1.8.0'): - def shim_array(data): - return np.array(data, dtype=np.unicode) -else: - def shim_array(data): - if (isinstance(data, six.string_types) or - not isinstance(data, collections.Iterable)): - data = [data] - try: - data = [str(d) for d in data] - except UnicodeEncodeError: - # this yields gibberish but unicode text doesn't - # render under numpy1.6 anyway - data = [d.encode('utf-8', 'ignore').decode('utf-8') - for d in data] - return np.array(data, dtype=np.unicode) +def to_str(value): + if LooseVersion(np.__version__) < LooseVersion('1.7.0'): + if (isinstance(value, (six.text_type, np.unicode))): + value = value.encode('utf-8', 'ignore').decode('utf-8') + if isinstance(value, (bytes, np.bytes_, six.binary_type)): + value = value.decode(encoding='utf-8') + elif isinstance(value, (bytes, np.bytes_, six.binary_type)): + return value.decode(encoding='utf-8') + elif not isinstance(value, (str, np.str_, six.text_type)): + value = str(value) + return value class StrCategoryConverter(units.ConversionInterface): @@ -40,28 +39,28 @@ def convert(value, 
unit, axis): """Uses axis.unit_data map to encode data as floats """ - value = np.atleast_1d(value) - # try and update from here.... - if hasattr(axis.unit_data, 'update'): - for val in value: - if isinstance(val, six.string_types): - axis.unit_data.update(val) - vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs)) - if isinstance(value, six.string_types): - return vmap[value] + return axis.unit_data._mapping[value] - vals = shim_array(value) + # dtype=object preserves 42, '42' distinction on scatter + values = np.atleast_1d(np.array(value, dtype=object)) + if units.ConversionInterface.is_numlike(value): + return np.array([axis.unit_data._mapping.get(v, v) + for v in values]) - for lab, loc in vmap.items(): - vals[vals == lab] = loc + if hasattr(axis.unit_data, 'update'): + axis.unit_data.update(values) - return vals.astype('float') + str2idx = np.vectorize(axis.unit_data._mapping.__getitem__, + otypes=[float]) + + mapped_value = str2idx(values) + return mapped_value @staticmethod def axisinfo(unit, axis): - majloc = StrCategoryLocator(axis.unit_data.locs) - majfmt = StrCategoryFormatter(axis.unit_data.seq) + majloc = StrCategoryLocator(axis.unit_data._locs) + majfmt = StrCategoryFormatter(axis.unit_data._seq) return units.AxisInfo(majloc=majloc, majfmt=majfmt) @staticmethod @@ -88,9 +87,6 @@ def __init__(self, seq): class UnitData(object): - # debatable makes sense to special code missing values - spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0} - def __init__(self, data): """Create mapping between unique categorical values and numerical identifier @@ -98,27 +94,34 @@ def __init__(self, data): Parameters ---------- data: iterable - sequence of values + sequence of values """ - self.seq, self.locs = [], [] - self._set_seq_locs(data, 0) - - def update(self, new_data): - # so as not to conflict with spdict - value = max(max(self.locs) + 1, 0) - self._set_seq_locs(new_data, value) - - def _set_seq_locs(self, data, value): - strdata = shim_array(data) - new_s = 
[d for d in np.unique(strdata) if d not in self.seq] - for ns in new_s: - self.seq.append(ns) - if ns in UnitData.spdict: - self.locs.append(UnitData.spdict[ns]) - else: - self.locs.append(value) - value += 1 + # seq, loc need to be pass by reference or there needs to be + # a callback from Locator/Formatter on update + self._seq, self._locs = [], [] + self._mapping = OrderedDict() + self._counter = itertools.count() + self.update(data) + + def _update_mapping(self, value): + if value in self._mapping: + return + if isinstance(value, (float, complex)) and np.isnan(value): + self._mapping[value] = np.nan + else: + self._mapping[value] = next(self._counter) + self._seq.append(to_str(value)) + self._locs.append(self._mapping[value]) + return + def update(self, data): + if (isinstance(data, six.string_types) or + not isinstance(data, Iterable)): + self._update_mapping(data) + else: + unsorted_unique = OrderedDict.fromkeys(data) + for ns in unsorted_unique: + self._update_mapping(ns) # Connects the convertor to matplotlib units.registry[str] = StrCategoryConverter() diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 17a90d47f379..32c17ac6dc17 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -2,29 +2,30 @@ """Catch all for categorical functions""" from __future__ import absolute_import, division, print_function +from collections import OrderedDict +import six import pytest import numpy as np + import matplotlib.pyplot as plt import matplotlib.category as cat -import unittest - class TestUnitData(object): - testdata = [("hello world", ["hello world"], [0]), - (u"Здравствуйте мир", [u"Здравствуйте мир"], [0]), - (['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf], - ['-inf', '3.14', 'A', 'B', 'inf', 'nan'], - [-3.0, 0, 1, 2, -2.0, -1.0])] + test_cases = [('single', ("hello world", ["hello world"], [0])), + ('unicode', (u"Здравствуйте мир", [u"Здравствуйте мир"], [0])), + ('mixed', 
(['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf], + ['A', 'nan', 'B', '-inf', '3.14', 'inf'], + [0, np.nan, 1, 2, 3, 4]))] - ids = ["single", "unicode", "mixed"] + ids, data = zip(*test_cases) - @pytest.mark.parametrize("data, seq, locs", testdata, ids=ids) + @pytest.mark.parametrize("data, seq, locs", data, ids=ids) def test_unit(self, data, seq, locs): act = cat.UnitData(data) - assert act.seq == seq - assert act.locs == locs + assert act._seq == seq + assert act._locs == locs def test_update_map(self): data = ['a', 'd'] @@ -33,15 +34,15 @@ def test_update_map(self): data_update = ['b', 'd', 'e', np.inf] useq = ['a', 'd', 'b', 'e', 'inf'] - ulocs = [0, 1, 2, 3, -2] + ulocs = [0, 1, 2, 3, 4] unitdata = cat.UnitData(data) - assert unitdata.seq == oseq - assert unitdata.locs == olocs + assert unitdata._seq == oseq + assert unitdata._locs == olocs unitdata.update(data_update) - assert unitdata.seq == useq - assert unitdata.locs == ulocs + assert unitdata._seq == useq + assert unitdata._locs == ulocs class FakeAxis(object): @@ -50,10 +51,13 @@ def __init__(self, unit_data): class MockUnitData(object): - def __init__(self, data): - seq, locs = zip(*data) - self.seq = list(seq) - self.locs = list(locs) + def __init__(self, data, labels=None): + self._mapping = OrderedDict(data) + if labels: + self._seq = labels + else: + self._seq = list(self._mapping.keys()) + self._locs = list(self._mapping.values()) class TestStrCategoryConverter(object): @@ -62,27 +66,29 @@ class TestStrCategoryConverter(object): ref: /pandas/tseries/tests/test_converter.py /pandas/tests/test_algos.py:TestFactorize """ - testdata = [(u"Здравствуйте мир", [(u"Здравствуйте мир", 42)], 42), - ("hello world", [("hello world", 42)], 42), - (['a', 'b', 'b', 'a', 'a', 'c', 'c', 'c'], - [('a', 0), ('b', 1), ('c', 2)], - [0, 1, 1, 0, 0, 2, 2, 2]), - (['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf], - [('nan', -1), ('3.14', 0), ('A', 1), ('B', 2), - ('-inf', 100), ('inf', 200)], - [1, 1, -1, 2, 100, 0, 200])] - 
ids = ["unicode", "single", "basic", "mixed"] + + test_cases = [("unicode", {u"Здравствуйте мир": 42}), + ("ascii", {"hello world": 42}), + ("single", {'a': 0, 'b': 1, 'c': 2}), + ("mixed", {3.14: 0, 'A': 1, 'B': 2, + -np.inf: 3, np.inf: 4, np.nan: 5}), + ("integer string", {"!": 0, "0": 1, 0: 1}), + ("number", {0.0: 0.0}), + ("number string", {'42': 0, 42: 1})] + + ids, unitmaps = zip(*test_cases) @pytest.fixture(autouse=True) def mock_axis(self, request): self.cc = cat.StrCategoryConverter() - @pytest.mark.parametrize("data, unitmap, exp", testdata, ids=ids) - def test_convert(self, data, unitmap, exp): + @pytest.mark.parametrize("unitmap", unitmaps, ids=ids) + def test_convert(self, unitmap): + data, exp = zip(*six.iteritems(unitmap)) MUD = MockUnitData(unitmap) axis = FakeAxis(MUD) act = self.cc.convert(data, None, axis) - np.testing.assert_array_equal(act, exp) + np.testing.assert_allclose(act, exp) def test_axisinfo(self): MUD = MockUnitData([(None, None)]) @@ -103,7 +109,7 @@ def test_StrCategoryLocator(self): np.testing.assert_array_equal(ticks.tick_values(None, None), locs) -class TestStrCategoryFormatter(unittest.TestCase): +class TestStrCategoryFormatter(object): def test_StrCategoryFormatter(self): seq = ["hello", "world", "hi"] labels = cat.StrCategoryFormatter(seq) @@ -119,47 +125,69 @@ def lt(tl): return [l.get_text() for l in tl] -class TestPlot(object): - bytes_data = [ - ['a', 'b', 'c'], - [b'a', b'b', b'c'], - np.array([b'a', b'b', b'c']) - ] +def axis_test(axis, ticks, labels, unit_data): + np.testing.assert_array_equal(axis.get_majorticklocs(), ticks) + assert lt(axis.get_majorticklabels()) == labels + np.testing.assert_array_equal(axis.unit_data._locs, unit_data._locs) + assert axis.unit_data._seq == unit_data._seq + + +class TestBarsBytes(object): + bytes_cases = [('string list', ['a', 'b', 'c']), + ('bytes list', [b'a', b'b', b'c']), + ('bytes ndarray', np.array([b'a', b'b', b'c']))] + + bytes_ids, bytes_data = zip(*bytes_cases) + + 
@pytest.mark.parametrize("bars", bytes_data, ids=bytes_ids) + def test_plot_bytes(self, bars): + + unitmap = MockUnitData([('a', 0), ('b', 1), ('c', 2)]) + + counts = np.array([4, 6, 5]) + fig, ax = plt.subplots() + ax.bar(bars, counts) + fig.canvas.draw() + axis_test(ax.xaxis, [0, 1, 2], ['a', 'b', 'c'], unitmap) + - bytes_ids = ['string list', 'bytes list', 'bytes ndarray'] +class TestBarsNumlike(object): + numlike_cases = [('string list', ['1', '11', '3']), + ('string ndarray', np.array(['1', '11', '3'])), + ('bytes list', [b'1', b'11', b'3']), + ('bytes ndarray', np.array([b'1', b'11', b'3']))] - numlike_data = [ - ['1', '11', '3'], - np.array(['1', '11', '3']), - [b'1', b'11', b'3'], - np.array([b'1', b'11', b'3']), - ] + numlike_ids, numlike_data = zip(*numlike_cases) - numlike_ids = [ - 'string list', 'string ndarray', 'bytes list', 'bytes ndarray' - ] + @pytest.mark.parametrize("bars", numlike_data, ids=numlike_ids) + def test_plot_numlike(self, bars): + counts = np.array([4, 6, 5]) + + fig, ax = plt.subplots() + ax.bar(bars, counts) + fig.canvas.draw() + + unitmap = MockUnitData([('1', 0), ('11', 1), ('3', 2)]) + axis_test(ax.xaxis, [0, 1, 2], ['1', '11', '3'], unitmap) + +class TestPlotTypes(object): @pytest.fixture - def data(self): - self.d = ['a', 'b', 'c', 'a'] - self.dticks = [0, 1, 2] - self.dlabels = ['a', 'b', 'c'] + def complete_data(self): + self.complete = ['a', 'b', 'c', 'a'] + self.complete_ticks = [0, 1, 2] + self.complete_labels = ['a', 'b', 'c'] unitmap = [('a', 0), ('b', 1), ('c', 2)] - self.dunit_data = MockUnitData(unitmap) + self.complete_unit_data = MockUnitData(unitmap) @pytest.fixture def missing_data(self): - self.dm = ['here', np.nan, 'here', 'there'] - self.dmticks = [0, -1, 1] - self.dmlabels = ['here', 'nan', 'there'] - unitmap = [('here', 0), ('nan', -1), ('there', 1)] - self.dmunit_data = MockUnitData(unitmap) - - def axis_test(self, axis, ticks, labels, unit_data): - np.testing.assert_array_equal(axis.get_majorticklocs(), 
ticks) - assert lt(axis.get_majorticklabels()) == labels - np.testing.assert_array_equal(axis.unit_data.locs, unit_data.locs) - assert axis.unit_data.seq == unit_data.seq + self.missing = ['here', np.nan, 'here', 'there'] + self.missing_ticks = [0, np.nan, 1] + self.missing_labels = ['here', 'nan', 'there'] + unitmap = [('here', 0), (np.nan, np.nan), ('there', 1)] + self.missing_unit_data = MockUnitData(unitmap, + labels=self.missing_labels) def test_plot_unicode(self): words = [u'Здравствуйте', u'привет'] @@ -170,68 +198,52 @@ def test_plot_unicode(self): ax.plot(words) fig.canvas.draw() - self.axis_test(ax.yaxis, locs, words, unit_data) + axis_test(ax.yaxis, locs, words, unit_data) - @pytest.mark.usefixtures("data") - def test_plot_1d(self): + @pytest.mark.usefixtures("complete_data") + def test_plot_yaxis(self): fig, ax = plt.subplots() - ax.plot(self.d) + ax.plot(self.complete) fig.canvas.draw() + axis_test(ax.yaxis, self.complete_ticks, self.complete_labels, + self.complete_unit_data) - self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data) - + @pytest.mark.xfail(reason="scatter/plot inconsistencies") @pytest.mark.usefixtures("missing_data") - def test_plot_1d_missing(self): + def test_plot_yaxis_missing_data(self): fig, ax = plt.subplots() - ax.plot(self.dm) + ax.plot(self.missing) fig.canvas.draw() + axis_test(ax.yaxis, self.missing_ticks, self.missing_labels, + self.missing_unit_data) - self.axis_test(ax.yaxis, self.dmticks, self.dmlabels, self.dmunit_data) - - @pytest.mark.usefixtures("data") - @pytest.mark.parametrize("bars", bytes_data, ids=bytes_ids) - def test_plot_bytes(self, bars): - counts = np.array([4, 6, 5]) - + @pytest.mark.xfail(reason="scatter/plot inconsistencies") + @pytest.mark.usefixtures("complete_data", "missing_data") + def test_plot_missing_xaxis_yaxis(self): fig, ax = plt.subplots() - ax.bar(bars, counts) + ax.plot(self.missing, self.complete) fig.canvas.draw() - self.axis_test(ax.xaxis, self.dticks, self.dlabels, 
self.dunit_data) - - @pytest.mark.parametrize("bars", numlike_data, ids=numlike_ids) - def test_plot_numlike(self, bars): - counts = np.array([4, 6, 5]) + axis_test(ax.xaxis, self.missing_ticks, self.missing_labels, + self.missing_unit_data) + axis_test(ax.yaxis, self.complete_ticks, self.complete_labels, + self.complete_unit_data) + @pytest.mark.usefixtures("complete_data", "missing_data") + def test_scatter_missing_xaxis_yaxis(self): fig, ax = plt.subplots() - ax.bar(bars, counts) + ax.scatter(self.missing, self.complete) fig.canvas.draw() + axis_test(ax.xaxis, self.missing_ticks, self.missing_labels, + self.missing_unit_data) + axis_test(ax.yaxis, self.complete_ticks, self.complete_labels, + self.complete_unit_data) - unitmap = MockUnitData([('1', 0), ('11', 1), ('3', 2)]) - self.axis_test(ax.xaxis, [0, 1, 2], ['1', '11', '3'], unitmap) - - @pytest.mark.usefixtures("data", "missing_data") - def test_plot_2d(self): - fig, ax = plt.subplots() - ax.plot(self.dm, self.d) - fig.canvas.draw() - - self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data) - self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data) - - @pytest.mark.usefixtures("data", "missing_data") - def test_scatter_2d(self): - fig, ax = plt.subplots() - ax.scatter(self.dm, self.d) - fig.canvas.draw() - - self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data) - self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data) +class TestUpdatePlot(object): - def test_plot_update(self): + def test_update_plot(self): fig, ax = plt.subplots() - ax.plot(['a', 'b']) ax.plot(['a', 'b', 'd']) ax.plot(['b', 'c', 'd']) @@ -239,13 +251,12 @@ def test_plot_update(self): labels = ['a', 'b', 'd', 'c'] ticks = [0, 1, 2, 3] - unit_data = MockUnitData(zip(labels, ticks)) + unitmap = MockUnitData(list(zip(labels, ticks))) - self.axis_test(ax.yaxis, ticks, labels, unit_data) + axis_test(ax.yaxis, ticks, labels, unitmap) - def test_scatter_update(self): + def 
test_update_scatter(self): fig, ax = plt.subplots() - ax.scatter(['a', 'b'], [0., 3.]) ax.scatter(['a', 'b', 'd'], [1., 2., 3.]) ax.scatter(['b', 'c', 'd'], [4., 1., 2.]) @@ -253,5 +264,6 @@ def test_scatter_update(self): labels = ['a', 'b', 'd', 'c'] ticks = [0, 1, 2, 3] - unit_data = MockUnitData(list(zip(labels, ticks))) - self.axis_test(ax.xaxis, ticks, labels, unit_data) + unitmap = MockUnitData(list(zip(labels, ticks))) + + axis_test(ax.xaxis, ticks, labels, unitmap) From 98096e267332c9a6dbf32c2a81955e7345240bcd Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Mon, 20 Nov 2017 21:02:04 -0500 Subject: [PATCH 02/22] MNT: remove typos and re-organize imports --- lib/matplotlib/category.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index 9f6d187021ed..a967aba17aba 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -1,24 +1,22 @@ -# -*- coding: utf-8 OA-*-za +# -*- coding: utf-8 OA -*- + """ catch all for categorical functions """ + from __future__ import (absolute_import, division, print_function, unicode_literals) from collections import Iterable, Sequence, OrderedDict import itertools import numbers - +from matplotlib import cbook, ticker, units import six -import numpy as np - -import matplotlib.units as units -import matplotlib.ticker as ticker - -# np 1.6/1.7 support -from distutils.version import LooseVersion +from collections import OrderedDict +import itertools +import numpy as np def to_str(value): if LooseVersion(np.__version__) < LooseVersion('1.7.0'): From 2bd832d261021cdf7f4773bd366a179766faf487 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Tue, 21 Nov 2017 22:21:09 -0500 Subject: [PATCH 03/22] API: reduce scope of StringCategorical & simplify internal structure - Expect that `axis.units` has a `_mapping` attribute. - Allow only strings. 
This prevents issues with mixed input where plotting different sub-sets of the data will or will not trigger the correct converter / unit behavior. - Do not allow missing data. Do not special case nan or inf for the same reasons as above. - deprecate `Axis.unit_data`. Axis already has a `units` attribute which holds a UnitData instance. --- lib/matplotlib/axis.py | 13 +++-- lib/matplotlib/category.py | 109 ++++++++++++++----------------------- 2 files changed, 48 insertions(+), 74 deletions(-) diff --git a/lib/matplotlib/axis.py b/lib/matplotlib/axis.py index 59405a10d93c..733d610f2791 100644 --- a/lib/matplotlib/axis.py +++ b/lib/matplotlib/axis.py @@ -719,7 +719,9 @@ def __init__(self, axes, pickradius=15): self.label = self._get_label() self.labelpad = rcParams['axes.labelpad'] self.offsetText = self._get_offset_text() - self.unit_data = None + + self.majorTicks = [] + self.minorTicks = [] self.pickradius = pickradius # Initialize here for testing; later add API @@ -777,15 +779,14 @@ def limit_range_for_scale(self, vmin, vmax): return self._scale.limit_range_for_scale(vmin, vmax, self.get_minpos()) @property + @cbook.deprecated("2.2") def unit_data(self): - """Holds data that a ConversionInterface subclass uses - to convert between labels and indexes - """ - return self._unit_data + return self.units @unit_data.setter + @cbook.deprecated("2.2") def unit_data(self, unit_data): - self._unit_data = unit_data + self.set_units(unit_data) def get_children(self): children = [self.label, self.offsetText] diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index a967aba17aba..9d3293790839 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -18,42 +18,20 @@ import numpy as np -def to_str(value): - if LooseVersion(np.__version__) < LooseVersion('1.7.0'): - if (isinstance(value, (six.text_type, np.unicode))): - value = value.encode('utf-8', 'ignore').decode('utf-8') - if isinstance(value, (bytes, np.bytes_, six.binary_type)): - value 
= value.decode(encoding='utf-8') - elif isinstance(value, (bytes, np.bytes_, six.binary_type)): - return value.decode(encoding='utf-8') - elif not isinstance(value, (str, np.str_, six.text_type)): - value = str(value) - return value - class StrCategoryConverter(units.ConversionInterface): @staticmethod def convert(value, unit, axis): - """Uses axis.unit_data map to encode - data as floats - """ - if isinstance(value, six.string_types): - return axis.unit_data._mapping[value] - - # dtype=object preserves 42, '42' distinction on scatter - values = np.atleast_1d(np.array(value, dtype=object)) - if units.ConversionInterface.is_numlike(value): - return np.array([axis.unit_data._mapping.get(v, v) - for v in values]) - - if hasattr(axis.unit_data, 'update'): - axis.unit_data.update(values) + """Use axis.units mapping tncode data as floats.""" - str2idx = np.vectorize(axis.unit_data._mapping.__getitem__, - otypes=[float]) - - mapped_value = str2idx(values) - return mapped_value + # We also need to pass numbers through. + if np.issubdtype(np.asarray(value).dtype.type, np.number): + return value + else: + axis.units.update(value) + str2idx = np.vectorize(axis.units._mapping.__getitem__, + otypes=[float]) + return str2idx(value) @staticmethod def axisinfo(unit, axis): @@ -63,13 +41,8 @@ def axisinfo(unit, axis): @staticmethod def default_units(data, axis): - # the conversion call stack is: - # default_units->axis_info->convert - if axis.unit_data is None: - axis.unit_data = UnitData(data) - else: - axis.unit_data.update(data) - return None + return UnitData() + class StrCategoryLocator(ticker.FixedLocator): @@ -85,45 +58,45 @@ def __init__(self, seq): class UnitData(object): - def __init__(self, data): - """Create mapping between unique categorical values - and numerical identifier + def __init__(self, data=None): + """Create mapping between unique categorical values and numerical id. 
Parameters ---------- - data: iterable - sequence of values + data : Mapping[str, int] + The initial categories. May be `None`. + """ - # seq, loc need to be pass by reference or there needs to be - # a callback from Locator/Formatter on update - self._seq, self._locs = [], [] - self._mapping = OrderedDict() - self._counter = itertools.count() - self.update(data) - - def _update_mapping(self, value): - if value in self._mapping: - return - if isinstance(value, (float, complex)) and np.isnan(value): - self._mapping[value] = np.nan + self._vals = [] + if data is None: + data = () + self._mapping = OrderedDict(data) + for k, v in self._mapping.items(): + if not isinstance(k, six.text_type): + raise TypeError("{val!r} is not a string".format(val=k)) + self._mapping[k] = int(v) + if self._mapping: + start = max(self._mapping.values()) + 1 else: - self._mapping[value] = next(self._counter) - self._seq.append(to_str(value)) - self._locs.append(self._mapping[value]) - return + start = 0 + self._counter = itertools.count(start=start) def update(self, data): - if (isinstance(data, six.string_types) or - not isinstance(data, Iterable)): - self._update_mapping(data) - else: - unsorted_unique = OrderedDict.fromkeys(data) - for ns in unsorted_unique: - self._update_mapping(ns) + if isinstance(data, six.string_types): + data = [data] + sorted_unique = OrderedDict.fromkeys(data) + for val in sorted_unique: + if val in self._mapping: + continue + if not isinstance(val, six.text_type): + raise TypeError("{val!r} is not a string".format(val)) + self._vals.append(val) + self._mapping[val] = next(self._counter) + # Connects the convertor to matplotlib units.registry[str] = StrCategoryConverter() -units.registry[np.str_] = StrCategoryConverter() -units.registry[six.text_type] = StrCategoryConverter() units.registry[bytes] = StrCategoryConverter() +units.registry[np.str_] = StrCategoryConverter() units.registry[np.bytes_] = StrCategoryConverter() +units.registry[six.text_type] = 
StrCategoryConverter() From 4195dbf25ce150db4adaf8fab3b5a036b65c0eec Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Tue, 21 Nov 2017 22:23:25 -0500 Subject: [PATCH 04/22] API: change baseclass of StrCategoryLocator and StrCategoryFormatter Change to using direct sub-classes of Locator and Formatter which hold a reference to the UnitData object. This makes the classes a bit less brittle and avoids having extra methods on the classes which will not work as intended. --- lib/matplotlib/category.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index 9d3293790839..0e5218d45870 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -35,8 +35,8 @@ def convert(value, unit, axis): @staticmethod def axisinfo(unit, axis): - majloc = StrCategoryLocator(axis.unit_data._locs) - majfmt = StrCategoryFormatter(axis.unit_data._seq) + majloc = StrCategoryLocator(axis.units) + majfmt = StrCategoryFormatter(axis.units) return units.AxisInfo(majloc=majloc, majfmt=majfmt) @staticmethod @@ -44,17 +44,26 @@ def default_units(data, axis): return UnitData() +class StrCategoryLocator(ticker.Locator): + def __init__(self, unit_data): + self._unit_data = unit_data -class StrCategoryLocator(ticker.FixedLocator): - def __init__(self, locs): - self.locs = locs - self.nbins = None + def __call__(self): + return list(self._unit_data._mapping.values()) + def tick_values(self, vmin, vmax): + return self() -class StrCategoryFormatter(ticker.FixedFormatter): - def __init__(self, seq): - self.seq = seq - self.offset_string = '' + +class StrCategoryFormatter(ticker.Formatter): + def __init__(self, unit_data): + self._unit_data = unit_data + + def __call__(self, x, pos=None): + if pos is None: + return "" + r_mapping = {v: k for k, v in self._unit_data._mapping.items()} + return r_mapping.get(int(x), '') class UnitData(object): From 04edaa3bb1aae9e16f5f1eab7dbc7d8d46960e76 
Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Mon, 20 Nov 2017 20:50:19 -0500 Subject: [PATCH 05/22] MNT: re-work how units are handled in axes._base Reattach the unitful data to Line2D. --- lib/matplotlib/axes/_base.py | 72 ++++++++++++++++++-------------- lib/matplotlib/cbook/__init__.py | 4 +- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index bb70972b6c51..34129b7ab8b4 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -214,24 +214,10 @@ def _xy_from_xy(self, x, y): if self.axes.xaxis is not None and self.axes.yaxis is not None: bx = self.axes.xaxis.update_units(x) by = self.axes.yaxis.update_units(y) - - if self.command != 'plot': - # the Line2D class can handle unitized data, with - # support for post hoc unit changes etc. Other mpl - # artists, e.g., Polygon which _process_plot_var_args - # also serves on calls to fill, cannot. So this is a - # hack to say: if you are not "plot", which is - # creating Line2D, then convert the data now to - # floats. If you are plot, pass the raw data through - # to Line2D which will handle the conversion. So - # polygons will not support post hoc conversions of - # the unit type since they are not storing the orig - # data. 
Hopefully we can rationalize this at a later - # date - JDH - if bx: - x = self.axes.convert_xunits(x) - if by: - y = self.axes.convert_yunits(y) + if bx: + x = self.axes.convert_xunits(x) + if by: + y = self.axes.convert_yunits(y) # like asanyarray, but converts scalar to array, and doesn't change # existing compatible sequences @@ -374,26 +360,48 @@ def _plot_args(self, tup, kwargs): if 'label' not in kwargs or kwargs['label'] is None: kwargs['label'] = get_label(tup[-1], None) - if len(tup) == 2: - x = _check_1d(tup[0]) - y = _check_1d(tup[-1]) + if len(tup) == 1: + x, y = index_of(tup[0]) + elif len(tup) == 2: + x, y = tup else: - x, y = index_of(tup[-1]) - - x, y = self._xy_from_xy(x, y) - - if self.command == 'plot': - func = self._makeline + assert False + + deunitized_x, deunitized_y = self._xy_from_xy(x, y) + # The previous call has registered the converters, if any, on the axes. + # This check will need to be replaced by a comparison with the + # DefaultConverter when that PR goes in. + if self.axes.xaxis.converter is None or self.command is not "plot": + xt, yt = deunitized_x.T, deunitized_y.T else: - kw['closed'] = kwargs.get('closed', True) - func = self._makefill - - ncx, ncy = x.shape[1], y.shape[1] + # np.asarray would destroy unit information so we need to construct + # the 1D arrays to pass to Line2D.set_xdata manually... (but this + # is only relevant if the command is "plot"). + + def to_list_of_lists(data): + ndim = np.ndim(data) + if ndim == 0: + return [[data]] + elif ndim == 1: + return [data] + elif ndim == 2: + return zip(*data) # Transpose it. 
+ + xt, yt = map(to_list_of_lists, [x, y]) + + ncx, ncy = deunitized_x.shape[1], deunitized_y.shape[1] if ncx > 1 and ncy > 1 and ncx != ncy: cbook.warn_deprecated("2.2", "cycling among columns of inputs " "with non-matching shapes is deprecated.") + for j in xrange(max(ncx, ncy)): - seg = func(x[:, j % ncx], y[:, j % ncy], kw, kwargs) + if self.command == "plot": + seg = self._makeline(xt[j % ncx], yt[j % ncy], kw, kwargs) + else: + kw['closed'] = kwargs.get('closed', True) + seg = self._makefill(deunitized_x[:, j % ncx], + deunitized_y[:, j % ncy], + kw, kwargs) ret.append(seg) return ret diff --git a/lib/matplotlib/cbook/__init__.py b/lib/matplotlib/cbook/__init__.py index 96c33fb3adb2..4861a17d8690 100644 --- a/lib/matplotlib/cbook/__init__.py +++ b/lib/matplotlib/cbook/__init__.py @@ -2312,8 +2312,8 @@ def index_of(y): try: return y.index.values, y.values except AttributeError: - y = _check_1d(y) - return np.arange(y.shape[0], dtype=float), y + # Ensure that scalar y gives x == [0]. 
+ return np.arange((np.shape(y) or (1,))[0], dtype=float), y def safe_first_element(obj): From 283ca08392e99b527a0a3282514278f8581a2872 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Wed, 22 Nov 2017 23:04:26 -0500 Subject: [PATCH 06/22] TST: fix up tests --- lib/matplotlib/tests/test_category.py | 141 ++++++++------------------ 1 file changed, 45 insertions(+), 96 deletions(-) diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 32c17ac6dc17..4c8c184d8268 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -2,62 +2,53 @@ """Catch all for categorical functions""" from __future__ import absolute_import, division, print_function -from collections import OrderedDict import six import pytest import numpy as np - +from numpy.testing import assert_array_equal import matplotlib.pyplot as plt import matplotlib.category as cat +from matplotlib.axes import Axes class TestUnitData(object): test_cases = [('single', ("hello world", ["hello world"], [0])), ('unicode', (u"Здравствуйте мир", [u"Здравствуйте мир"], [0])), - ('mixed', (['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf], - ['A', 'nan', 'B', '-inf', '3.14', 'inf'], - [0, np.nan, 1, 2, 3, 4]))] + ('mixed', (['A', 'A', 'B'], + ['A', 'B', ], + [0, 1]))] ids, data = zip(*test_cases) @pytest.mark.parametrize("data, seq, locs", data, ids=ids) def test_unit(self, data, seq, locs): - act = cat.UnitData(data) - assert act._seq == seq - assert act._locs == locs + act = cat.UnitData() + for v in data: + act.update(data) + assert list(act._mapping.keys()) == seq + assert list(act._mapping.values()) == locs def test_update_map(self): - data = ['a', 'd'] oseq = ['a', 'd'] olocs = [0, 1] - data_update = ['b', 'd', 'e', np.inf] - useq = ['a', 'd', 'b', 'e', 'inf'] - ulocs = [0, 1, 2, 3, 4] + data_update = ['b', 'd', 'e'] + useq = ['a', 'd', 'b', 'e'] + ulocs = [0, 1, 2, 3] - unitdata = cat.UnitData(data) - assert unitdata._seq == oseq - 
assert unitdata._locs == olocs + unitdata = cat.UnitData(zip(oseq, olocs)) + assert list(unitdata._mapping.keys()) == oseq + assert list(unitdata._mapping.values()) == olocs unitdata.update(data_update) - assert unitdata._seq == useq - assert unitdata._locs == ulocs + assert list(unitdata._mapping.keys()) == useq + assert list(unitdata._mapping.values()) == ulocs class FakeAxis(object): def __init__(self, unit_data): - self.unit_data = unit_data - - -class MockUnitData(object): - def __init__(self, data, labels=None): - self._mapping = OrderedDict(data) - if labels: - self._seq = labels - else: - self._seq = list(self._mapping.keys()) - self._locs = list(self._mapping.values()) + self.units = unit_data class TestStrCategoryConverter(object): @@ -69,12 +60,7 @@ class TestStrCategoryConverter(object): test_cases = [("unicode", {u"Здравствуйте мир": 42}), ("ascii", {"hello world": 42}), - ("single", {'a': 0, 'b': 1, 'c': 2}), - ("mixed", {3.14: 0, 'A': 1, 'B': 2, - -np.inf: 3, np.inf: 4, np.nan: 5}), - ("integer string", {"!": 0, "0": 1, 0: 1}), - ("number", {0.0: 0.0}), - ("number string", {'42': 0, 42: 1})] + ("single", {'a': 0, 'b': 1, 'c': 2})] ids, unitmaps = zip(*test_cases) @@ -85,13 +71,13 @@ def mock_axis(self, request): @pytest.mark.parametrize("unitmap", unitmaps, ids=ids) def test_convert(self, unitmap): data, exp = zip(*six.iteritems(unitmap)) - MUD = MockUnitData(unitmap) + MUD = cat.UnitData(unitmap) axis = FakeAxis(MUD) act = self.cc.convert(data, None, axis) np.testing.assert_allclose(act, exp) def test_axisinfo(self): - MUD = MockUnitData([(None, None)]) + MUD = cat.UnitData() axis = FakeAxis(MUD) ax = self.cc.axisinfo(None, axis) assert isinstance(ax.majloc, cat.StrCategoryLocator) @@ -99,26 +85,33 @@ def test_axisinfo(self): def test_default_units(self): axis = FakeAxis(None) - assert self.cc.default_units(["a"], axis) is None + assert isinstance(self.cc.default_units(["a"], axis), cat.UnitData) class TestStrCategoryLocator(object): def 
test_StrCategoryLocator(self): locs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - ticks = cat.StrCategoryLocator(locs) + u = cat.UnitData() + for j in range(11): + u.update(str(j)) + ticks = cat.StrCategoryLocator(u) np.testing.assert_array_equal(ticks.tick_values(None, None), locs) class TestStrCategoryFormatter(object): def test_StrCategoryFormatter(self): seq = ["hello", "world", "hi"] - labels = cat.StrCategoryFormatter(seq) - assert labels('a', 1) == "world" + u = cat.UnitData() + u.update(seq) + labels = cat.StrCategoryFormatter(u) + assert labels(1, 1) == "world" def test_StrCategoryFormatterUnicode(self): seq = ["Здравствуйте", "привет"] - labels = cat.StrCategoryFormatter(seq) - assert labels('a', 1) == "привет" + u = cat.UnitData() + u.update(seq) + labels = cat.StrCategoryFormatter(u) + assert labels(1, 1) == "привет" def lt(tl): @@ -126,23 +119,21 @@ def lt(tl): def axis_test(axis, ticks, labels, unit_data): - np.testing.assert_array_equal(axis.get_majorticklocs(), ticks) + assert axis.get_majorticklocs() == ticks assert lt(axis.get_majorticklabels()) == labels - np.testing.assert_array_equal(axis.unit_data._locs, unit_data._locs) - assert axis.unit_data._seq == unit_data._seq + assert axis.units._mapping == unit_data._mapping class TestBarsBytes(object): bytes_cases = [('string list', ['a', 'b', 'c']), - ('bytes list', [b'a', b'b', b'c']), - ('bytes ndarray', np.array([b'a', b'b', b'c']))] + ] bytes_ids, bytes_data = zip(*bytes_cases) @pytest.mark.parametrize("bars", bytes_data, ids=bytes_ids) def test_plot_bytes(self, bars): - unitmap = MockUnitData([('a', 0), ('b', 1), ('c', 2)]) + unitmap = cat.UnitData([('a', 0), ('b', 1), ('c', 2)]) counts = np.array([4, 6, 5]) fig, ax = plt.subplots() @@ -153,9 +144,7 @@ def test_plot_bytes(self, bars): class TestBarsNumlike(object): numlike_cases = [('string list', ['1', '11', '3']), - ('string ndarray', np.array(['1', '11', '3'])), - ('bytes list', [b'1', b'11', b'3']), - ('bytes ndarray', np.array([b'1', b'11', 
b'3']))] + ('string ndarray', np.array(['1', '11', '3']))] numlike_ids, numlike_data = zip(*numlike_cases) @@ -167,7 +156,7 @@ def test_plot_numlike(self, bars): ax.bar(bars, counts) fig.canvas.draw() - unitmap = MockUnitData([('1', 0), ('11', 1), ('3', 2)]) + unitmap = cat.UnitData([('1', 0), ('11', 1), ('3', 2)]) axis_test(ax.xaxis, [0, 1, 2], ['1', '11', '3'], unitmap) @@ -178,21 +167,12 @@ def complete_data(self): self.complete_ticks = [0, 1, 2] self.complete_labels = ['a', 'b', 'c'] unitmap = [('a', 0), ('b', 1), ('c', 2)] - self.complete_unit_data = MockUnitData(unitmap) - - @pytest.fixture - def missing_data(self): - self.missing = ['here', np.nan, 'here', 'there'] - self.missing_ticks = [0, np.nan, 1] - self.missing_labels = ['here', 'nan', 'there'] - unitmap = [('here', 0), (np.nan, np.nan), ('there', 1)] - self.missing_unit_data = MockUnitData(unitmap, - labels=self.missing_labels) + self.complete_unit_data = cat.UnitData(unitmap) def test_plot_unicode(self): words = [u'Здравствуйте', u'привет'] locs = [0.0, 1.0] - unit_data = MockUnitData(zip(words, locs)) + unit_data = cat.UnitData(zip(words, locs)) fig, ax = plt.subplots() ax.plot(words) @@ -208,37 +188,6 @@ def test_plot_yaxis(self): axis_test(ax.yaxis, self.complete_ticks, self.complete_labels, self.complete_unit_data) - @pytest.mark.xfail(reason="scatter/plot inconsistencies") - @pytest.mark.usefixtures("missing_data") - def test_plot_yaxis_missing_data(self): - fig, ax = plt.subplots() - ax.plot(self.missing) - fig.canvas.draw() - axis_test(ax.yaxis, self.missing_ticks, self.missing_labels, - self.missing_unit_data) - - @pytest.mark.xfail(reason="scatter/plot inconsistencies") - @pytest.mark.usefixtures("complete_data", "missing_data") - def test_plot_missing_xaxis_yaxis(self): - fig, ax = plt.subplots() - ax.plot(self.missing, self.complete) - fig.canvas.draw() - - axis_test(ax.xaxis, self.missing_ticks, self.missing_labels, - self.missing_unit_data) - axis_test(ax.yaxis, self.complete_ticks, 
self.complete_labels, - self.complete_unit_data) - - @pytest.mark.usefixtures("complete_data", "missing_data") - def test_scatter_missing_xaxis_yaxis(self): - fig, ax = plt.subplots() - ax.scatter(self.missing, self.complete) - fig.canvas.draw() - axis_test(ax.xaxis, self.missing_ticks, self.missing_labels, - self.missing_unit_data) - axis_test(ax.yaxis, self.complete_ticks, self.complete_labels, - self.complete_unit_data) - class TestUpdatePlot(object): @@ -251,7 +200,7 @@ def test_update_plot(self): labels = ['a', 'b', 'd', 'c'] ticks = [0, 1, 2, 3] - unitmap = MockUnitData(list(zip(labels, ticks))) + unitmap = cat.UnitData(list(zip(labels, ticks))) axis_test(ax.yaxis, ticks, labels, unitmap) @@ -264,6 +213,6 @@ def test_update_scatter(self): labels = ['a', 'b', 'd', 'c'] ticks = [0, 1, 2, 3] - unitmap = MockUnitData(list(zip(labels, ticks))) + unitmap = cat.UnitData(list(zip(labels, ticks))) axis_test(ax.xaxis, ticks, labels, unitmap) From 6f73c245b125856cad0a72f885b6c14891bb3eb9 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Wed, 22 Nov 2017 23:05:47 -0500 Subject: [PATCH 07/22] TST: add additional tests to categorical classes --- lib/matplotlib/tests/test_category.py | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 4c8c184d8268..eb8f95b8016b 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -216,3 +216,54 @@ def test_update_scatter(self): unitmap = cat.UnitData(list(zip(labels, ticks))) axis_test(ax.xaxis, ticks, labels, unitmap) + + +@pytest.fixture +def ax(): + return plt.figure().subplots() + + +@pytest.mark.parametrize( + "data, expected_indices, expected_labels", + [(["Здравствуйте мир"], [0], ["Здравствуйте мир"]), + (["a", "b", "b", "a", "c", "c"], [0, 1, 1, 0, 2, 2], ["a", "b", "c"]), + (["foo", "bar"], range(2), ["foo", "bar"]), + (np.array(["1", "11", "3"]), range(3), ["1", "11", "3"])]) 
+def test_simple(ax, data, expected_indices, expected_labels): + l, = ax.plot(data) + assert_array_equal(l.get_ydata(orig=False), expected_indices) + assert isinstance(ax.yaxis.major.locator, cat.StrCategoryLocator) + assert isinstance(ax.yaxis.major.formatter, cat.StrCategoryFormatter) + ax.figure.canvas.draw() + labels = [label.get_text() for label in ax.yaxis.get_majorticklabels()] + assert labels == expected_labels + + +def test_default_units(ax): + ax.plot(["a"]) + du = ax.yaxis.converter.default_units(["a"], ax.yaxis) + assert isinstance(du, cat.UnitData) + + +def test_update(ax): + l1, = ax.plot(["a", "d"]) + l2, = ax.plot(["b", "d", "e"]) + assert_array_equal(l1.get_ydata(orig=False), [0, 1]) + assert_array_equal(l2.get_ydata(orig=False), [2, 1, 3]) + assert ax.yaxis.units._vals == ["a", "d", "b", "e"] + assert ax.yaxis.units._mapping == {"a": 0, "d": 1, "b": 2, "e": 3} + + +@pytest.mark.parametrize("plotter", [Axes.plot, Axes.scatter, Axes.bar]) +def test_StrCategoryLocator(ax, plotter): + ax.plot(["a", "b", "c"]) + assert_array_equal(ax.yaxis.major.locator(), range(3)) + + +@pytest.mark.parametrize("plotter", [Axes.plot, Axes.scatter, Axes.bar]) +def test_StrCategoryFormatter(ax, plotter): + plotter(ax, range(2), ["hello", "мир"]) + assert ax.yaxis.major.formatter(0, 0) == "hello" + assert ax.yaxis.major.formatter(1, 1) == "мир" + assert ax.yaxis.major.formatter(2, 2) == "" + assert ax.yaxis.major.formatter(0, None) == "" From ac355c8dcbd6fa694b64a3854c5c63925bbd70c9 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 26 Nov 2017 15:41:18 -0500 Subject: [PATCH 08/22] WIP: add explicit test of conversion to float --- lib/matplotlib/axes/_base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 34129b7ab8b4..1aaeab3fadbe 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -211,6 +211,11 @@ def set_patchprops(self, fill_poly, 
**kwargs): fill_poly.set(**kwargs) def _xy_from_xy(self, x, y): + def _check_numeric(inp): + if inp.dtype.kind not in set('biufc'): + raise TypeError('input must be numeric after unit conversion') + return inp + if self.axes.xaxis is not None and self.axes.yaxis is not None: bx = self.axes.xaxis.update_units(x) by = self.axes.yaxis.update_units(y) @@ -221,8 +226,9 @@ def _xy_from_xy(self, x, y): # like asanyarray, but converts scalar to array, and doesn't change # existing compatible sequences - x = _check_1d(x) - y = _check_1d(y) + x = _check_numeric(_check_1d(x)) + y = _check_numeric(_check_1d(y)) + if x.shape[0] != y.shape[0]: raise ValueError("x and y must have same first dimension, but " "have shapes {} and {}".format(x.shape, y.shape)) From 6457f4fcb5814bc0fdd357d69a0009a745ef027f Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 26 Nov 2017 15:41:47 -0500 Subject: [PATCH 09/22] WIP: fix broken error message --- lib/matplotlib/category.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index 0e5218d45870..c5f82c76e767 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -98,7 +98,7 @@ def update(self, data): if val in self._mapping: continue if not isinstance(val, six.text_type): - raise TypeError("{val!r} is not a string".format(val)) + raise TypeError("{val!r} is not a string".format(val=val)) self._vals.append(val) self._mapping[val] = next(self._counter) From 3972c3852175247a98414d06a5b8fefaab0658dc Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 26 Nov 2017 15:42:06 -0500 Subject: [PATCH 10/22] STY: whitespace clean up --- lib/matplotlib/lines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/matplotlib/lines.py b/lib/matplotlib/lines.py index a9999b419f15..9fd097f648ee 100644 --- a/lib/matplotlib/lines.py +++ b/lib/matplotlib/lines.py @@ -676,8 +676,8 @@ def recache(self, always=False): if nanmask.any(): 
self._x_filled = self._x.copy() indices = np.arange(len(x)) - self._x_filled[nanmask] = np.interp(indices[nanmask], - indices[~nanmask], self._x[~nanmask]) + self._x_filled[nanmask] = np.interp( + indices[nanmask], indices[~nanmask], self._x[~nanmask]) else: self._x_filled = self._x From 3c7b1852c2f18293e8b51c3c4720ad88fd2c70c1 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 26 Nov 2017 15:42:21 -0500 Subject: [PATCH 11/22] DOC: document API changes --- doc/api/api_changes.rst | 1 + doc/api/api_changes/2018-01_TAC.rst | 8 ++++++++ 2 files changed, 9 insertions(+) create mode 100644 doc/api/api_changes/2018-01_TAC.rst diff --git a/doc/api/api_changes.rst b/doc/api/api_changes.rst index 46a65f203e65..bdf3e1d6e76e 100644 --- a/doc/api/api_changes.rst +++ b/doc/api/api_changes.rst @@ -47,6 +47,7 @@ to ``'mask'``. As a side effect of this change, error bars which go negative now work as expected on log scales. + API Changes in 2.1.0 ==================== diff --git a/doc/api/api_changes/2018-01_TAC.rst b/doc/api/api_changes/2018-01_TAC.rst new file mode 100644 index 000000000000..f3455e103f0b --- /dev/null +++ b/doc/api/api_changes/2018-01_TAC.rst @@ -0,0 +1,8 @@ + +Simplify String Categorical handling +------------------------------------ + + - Only handling `str` (not `bytes` in pyhon3). This disallows using integers or + floats as category labels. If you wish to have integer category labels, convert + to string before plotting. + - Do not allow missing data. This is a consequence of the first change. 
From 87bfe1e9b41946ccff1ed4b9038a8cc234233bd7 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 26 Nov 2017 16:13:11 -0500 Subject: [PATCH 12/22] TST: add mixed-type failure tests --- lib/matplotlib/tests/test_category.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index eb8f95b8016b..3748ba4f5bd8 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -267,3 +267,10 @@ def test_StrCategoryFormatter(ax, plotter): assert ax.yaxis.major.formatter(1, 1) == "мир" assert ax.yaxis.major.formatter(2, 2) == "" assert ax.yaxis.major.formatter(0, None) == "" + + +@pytest.mark.parametrize("plotter", [Axes.plot, Axes.scatter]) +@pytest.mark.parametrize("xdata", [[1, 'a'], ['a', 1]]) +def test_mixed_failures(ax, plotter, xdata): + with pytest.raises((ValueError, TypeError)): + plotter(ax, xdata, [1, 2]) From d7a32f66cc0804786640bb2a6f721762d251f747 Mon Sep 17 00:00:00 2001 From: hannah Date: Wed, 29 Nov 2017 14:28:12 -0500 Subject: [PATCH 13/22] added test for #9843 --- lib/matplotlib/tests/test_category.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 3748ba4f5bd8..8732c80b9048 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -60,7 +60,11 @@ class TestStrCategoryConverter(object): test_cases = [("unicode", {u"Здравствуйте мир": 42}), ("ascii", {"hello world": 42}), - ("single", {'a': 0, 'b': 1, 'c': 2})] + ("single", {'a': 0, 'b': 1, 'c': 2}), + ("single + values>10", {'A': 0, 'B': 1, 'C': 2, + 'D': 3, 'E': 4, 'F': 5, + 'G': 6, 'H': 7, 'I': 8, + 'J': 9, 'K': 10})] ids, unitmaps = zip(*test_cases) From f68231d3f38449daa8036fac6b22091c9608d297 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Tue, 9 Jan 2018 22:23:00 -0500 Subject: [PATCH 14/22] STY: remove utf marker --- 
lib/matplotlib/category.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index c5f82c76e767..42699449c243 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 OA -*- - """ catch all for categorical functions """ From 78c5909a903098bac13b728ee69a68b1eb7ea91f Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Tue, 9 Jan 2018 22:26:23 -0500 Subject: [PATCH 15/22] DOC: improve docstring --- lib/matplotlib/category.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index 42699449c243..0e8bffee785e 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -20,7 +20,7 @@ class StrCategoryConverter(units.ConversionInterface): @staticmethod def convert(value, unit, axis): - """Use axis.units mapping tncode data as floats.""" + """Use axis.units mapping to map categorical data to floats.""" # We also need to pass numbers through. 
if np.issubdtype(np.asarray(value).dtype.type, np.number): From c0e485104417a7cf451dd3219d6071be9320faa1 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Tue, 9 Jan 2018 22:26:32 -0500 Subject: [PATCH 16/22] MNT: remove unused imports in category.py --- lib/matplotlib/category.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index 0e8bffee785e..d1d468133331 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -5,14 +5,10 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) -from collections import Iterable, Sequence, OrderedDict -import itertools -import numbers -from matplotlib import cbook, ticker, units -import six - from collections import OrderedDict import itertools +from matplotlib import ticker, units +import six import numpy as np From 1c280f515b28c8dd0182bf0a069f4f9a9254dc3e Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Tue, 9 Jan 2018 23:06:23 -0500 Subject: [PATCH 17/22] TST: fix tests on py2 --- lib/matplotlib/tests/test_category.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 8732c80b9048..24aa369f4a3f 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -229,7 +229,7 @@ def ax(): @pytest.mark.parametrize( "data, expected_indices, expected_labels", - [(["Здравствуйте мир"], [0], ["Здравствуйте мир"]), + [([u"Здравствуйте мир"], [0], [u"Здравствуйте мир"]), (["a", "b", "b", "a", "c", "c"], [0, 1, 1, 0, 2, 2], ["a", "b", "c"]), (["foo", "bar"], range(2), ["foo", "bar"]), (np.array(["1", "11", "3"]), range(3), ["1", "11", "3"])]) @@ -266,9 +266,9 @@ def test_StrCategoryLocator(ax, plotter): @pytest.mark.parametrize("plotter", [Axes.plot, Axes.scatter, Axes.bar]) def test_StrCategoryFormatter(ax, plotter): - plotter(ax, range(2), ["hello", "мир"]) + 
plotter(ax, range(2), ["hello", u"мир"]) assert ax.yaxis.major.formatter(0, 0) == "hello" - assert ax.yaxis.major.formatter(1, 1) == "мир" + assert ax.yaxis.major.formatter(1, 1) == u"мир" assert ax.yaxis.major.formatter(2, 2) == "" assert ax.yaxis.major.formatter(0, None) == "" From 6beebcc0b3556282133d0f921cb90b19a6d574bb Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Tue, 9 Jan 2018 23:06:47 -0500 Subject: [PATCH 18/22] API: restore support for bytes --- lib/matplotlib/category.py | 28 ++++++++++++++++++--------- lib/matplotlib/tests/test_category.py | 9 +++++++-- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index d1d468133331..783306234071 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -17,14 +17,16 @@ class StrCategoryConverter(units.ConversionInterface): @staticmethod def convert(value, unit, axis): """Use axis.units mapping to map categorical data to floats.""" - + def getter(k): + if not isinstance(k, six.text_type): + k = k.decode('utf-8') + return axis.units._mapping[k] # We also need to pass numbers through. if np.issubdtype(np.asarray(value).dtype.type, np.number): return value else: axis.units.update(value) - str2idx = np.vectorize(axis.units._mapping.__getitem__, - otypes=[float]) + str2idx = np.vectorize(getter, otypes=[float]) return str2idx(value) @staticmethod @@ -61,6 +63,9 @@ def __call__(self, x, pos=None): class UnitData(object): + valid_types = tuple(set(six.string_types + + (bytes, six.text_type, np.str_, np.bytes_))) + def __init__(self, data=None): """Create mapping between unique categorical values and numerical id. 
@@ -73,10 +78,12 @@ def __init__(self, data=None): self._vals = [] if data is None: data = () - self._mapping = OrderedDict(data) - for k, v in self._mapping.items(): - if not isinstance(k, six.text_type): + self._mapping = OrderedDict() + for k, v in OrderedDict(data).items(): + if not isinstance(k, self.valid_types): raise TypeError("{val!r} is not a string".format(val=k)) + if not isinstance(k, six.text_type): + k = k.decode('utf-8') self._mapping[k] = int(v) if self._mapping: start = max(self._mapping.values()) + 1 @@ -85,19 +92,22 @@ def __init__(self, data=None): self._counter = itertools.count(start=start) def update(self, data): - if isinstance(data, six.string_types): + if isinstance(data, self.valid_types): data = [data] sorted_unique = OrderedDict.fromkeys(data) for val in sorted_unique: + if not isinstance(val, self.valid_types): + raise TypeError("{val!r} is not a string".format(val=val)) + if not isinstance(val, six.text_type): + val = val.decode('utf-8') if val in self._mapping: continue - if not isinstance(val, six.text_type): - raise TypeError("{val!r} is not a string".format(val=val)) self._vals.append(val) self._mapping[val] = next(self._counter) # Connects the convertor to matplotlib + units.registry[str] = StrCategoryConverter() units.registry[bytes] = StrCategoryConverter() units.registry[np.str_] = StrCategoryConverter() diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 24aa369f4a3f..c744bf57bf69 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -61,6 +61,8 @@ class TestStrCategoryConverter(object): test_cases = [("unicode", {u"Здравствуйте мир": 42}), ("ascii", {"hello world": 42}), ("single", {'a': 0, 'b': 1, 'c': 2}), + ("single bytes", {b'a': 0, b'b': 1, b'c': 2}), + ("mixed bytes", {b'a': 0, 'b': 1, b'c': 2}), ("single + values>10", {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, @@ -111,11 +113,11 @@ def 
test_StrCategoryFormatter(self): assert labels(1, 1) == "world" def test_StrCategoryFormatterUnicode(self): - seq = ["Здравствуйте", "привет"] + seq = [u"Здравствуйте", u"привет"] u = cat.UnitData() u.update(seq) labels = cat.StrCategoryFormatter(u) - assert labels(1, 1) == "привет" + assert labels(1, 1) == u"привет" def lt(tl): @@ -130,6 +132,8 @@ def axis_test(axis, ticks, labels, unit_data): class TestBarsBytes(object): bytes_cases = [('string list', ['a', 'b', 'c']), + ('bytes list', [b'a', b'b', b'c']), + ('mixed list', [b'a', 'b', b'c']), ] bytes_ids, bytes_data = zip(*bytes_cases) @@ -232,6 +236,7 @@ def ax(): [([u"Здравствуйте мир"], [0], [u"Здравствуйте мир"]), (["a", "b", "b", "a", "c", "c"], [0, 1, 1, 0, 2, 2], ["a", "b", "c"]), (["foo", "bar"], range(2), ["foo", "bar"]), + ([b"foo", "bar"], range(2), ["foo", "bar"]), (np.array(["1", "11", "3"]), range(3), ["1", "11", "3"])]) def test_simple(ax, data, expected_indices, expected_labels): l, = ax.plot(data) From 7896cdc7db94e462f7ca506ccccadca8b6768f2c Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 4 Feb 2018 21:31:57 -0500 Subject: [PATCH 19/22] Revert "WIP: add explicit test of conversion to float" This reverts commit aaefde4fcb27932f98851b02b6f4e020111c5371. 
--- lib/matplotlib/axes/_base.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 1aaeab3fadbe..34129b7ab8b4 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -211,11 +211,6 @@ def set_patchprops(self, fill_poly, **kwargs): fill_poly.set(**kwargs) def _xy_from_xy(self, x, y): - def _check_numeric(inp): - if inp.dtype.kind not in set('biufc'): - raise TypeError('input must be numeric after unit conversion') - return inp - if self.axes.xaxis is not None and self.axes.yaxis is not None: bx = self.axes.xaxis.update_units(x) by = self.axes.yaxis.update_units(y) @@ -226,9 +221,8 @@ def _check_numeric(inp): # like asanyarray, but converts scalar to array, and doesn't change # existing compatible sequences - x = _check_numeric(_check_1d(x)) - y = _check_numeric(_check_1d(y)) - + x = _check_1d(x) + y = _check_1d(y) if x.shape[0] != y.shape[0]: raise ValueError("x and y must have same first dimension, but " "have shapes {} and {}".format(x.shape, y.shape)) From 6cc1841ec23cd2f7b9978fcaffabab5b286393d9 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 4 Feb 2018 21:44:04 -0500 Subject: [PATCH 20/22] Revert "MNT: re-work how units are handled in axes._base" This reverts commit ec692d87f4e28a4e37fa83c6865deb8eb1d83fab. 
--- lib/matplotlib/axes/_base.py | 72 ++++++++++++++------------------ lib/matplotlib/cbook/__init__.py | 4 +- 2 files changed, 34 insertions(+), 42 deletions(-) diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 34129b7ab8b4..bb70972b6c51 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -214,10 +214,24 @@ def _xy_from_xy(self, x, y): if self.axes.xaxis is not None and self.axes.yaxis is not None: bx = self.axes.xaxis.update_units(x) by = self.axes.yaxis.update_units(y) - if bx: - x = self.axes.convert_xunits(x) - if by: - y = self.axes.convert_yunits(y) + + if self.command != 'plot': + # the Line2D class can handle unitized data, with + # support for post hoc unit changes etc. Other mpl + # artists, e.g., Polygon which _process_plot_var_args + # also serves on calls to fill, cannot. So this is a + # hack to say: if you are not "plot", which is + # creating Line2D, then convert the data now to + # floats. If you are plot, pass the raw data through + # to Line2D which will handle the conversion. So + # polygons will not support post hoc conversions of + # the unit type since they are not storing the orig + # data. Hopefully we can rationalize this at a later + # date - JDH + if bx: + x = self.axes.convert_xunits(x) + if by: + y = self.axes.convert_yunits(y) # like asanyarray, but converts scalar to array, and doesn't change # existing compatible sequences @@ -360,48 +374,26 @@ def _plot_args(self, tup, kwargs): if 'label' not in kwargs or kwargs['label'] is None: kwargs['label'] = get_label(tup[-1], None) - if len(tup) == 1: - x, y = index_of(tup[0]) - elif len(tup) == 2: - x, y = tup + if len(tup) == 2: + x = _check_1d(tup[0]) + y = _check_1d(tup[-1]) else: - assert False - - deunitized_x, deunitized_y = self._xy_from_xy(x, y) - # The previous call has registered the converters, if any, on the axes. - # This check will need to be replaced by a comparison with the - # DefaultConverter when that PR goes in. 
- if self.axes.xaxis.converter is None or self.command is not "plot": - xt, yt = deunitized_x.T, deunitized_y.T + x, y = index_of(tup[-1]) + + x, y = self._xy_from_xy(x, y) + + if self.command == 'plot': + func = self._makeline else: - # np.asarray would destroy unit information so we need to construct - # the 1D arrays to pass to Line2D.set_xdata manually... (but this - # is only relevant if the command is "plot"). - - def to_list_of_lists(data): - ndim = np.ndim(data) - if ndim == 0: - return [[data]] - elif ndim == 1: - return [data] - elif ndim == 2: - return zip(*data) # Transpose it. - - xt, yt = map(to_list_of_lists, [x, y]) - - ncx, ncy = deunitized_x.shape[1], deunitized_y.shape[1] + kw['closed'] = kwargs.get('closed', True) + func = self._makefill + + ncx, ncy = x.shape[1], y.shape[1] if ncx > 1 and ncy > 1 and ncx != ncy: cbook.warn_deprecated("2.2", "cycling among columns of inputs " "with non-matching shapes is deprecated.") - for j in xrange(max(ncx, ncy)): - if self.command == "plot": - seg = self._makeline(xt[j % ncx], yt[j % ncy], kw, kwargs) - else: - kw['closed'] = kwargs.get('closed', True) - seg = self._makefill(deunitized_x[:, j % ncx], - deunitized_y[:, j % ncy], - kw, kwargs) + seg = func(x[:, j % ncx], y[:, j % ncy], kw, kwargs) ret.append(seg) return ret diff --git a/lib/matplotlib/cbook/__init__.py b/lib/matplotlib/cbook/__init__.py index 4861a17d8690..96c33fb3adb2 100644 --- a/lib/matplotlib/cbook/__init__.py +++ b/lib/matplotlib/cbook/__init__.py @@ -2312,8 +2312,8 @@ def index_of(y): try: return y.index.values, y.values except AttributeError: - # Ensure that scalar y gives x == [0]. 
- return np.arange((np.shape(y) or (1,))[0], dtype=float), y + y = _check_1d(y) + return np.arange(y.shape[0], dtype=float), y def safe_first_element(obj): From 0ca66d223081e40563927c7a4184fc178d7a5502 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 4 Feb 2018 21:46:49 -0500 Subject: [PATCH 21/22] TST: make sure pandas tests actually draw --- lib/matplotlib/tests/test_axes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 3e86a6296fa5..8bf81f042361 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -5158,6 +5158,7 @@ def test_pandas_pcolormesh(pd): fig, ax = plt.subplots() ax.pcolormesh(time, depth, data) + fig.canvas.draw() def test_pandas_indexing_dates(pd): @@ -5169,6 +5170,7 @@ def test_pandas_indexing_dates(pd): without_zero_index = df[np.array(df.index) % 2 == 1].copy() ax.plot('dates', 'values', data=without_zero_index) + ax.figure.canvas.draw() def test_pandas_errorbar_indexing(pd): @@ -5177,6 +5179,7 @@ def test_pandas_errorbar_indexing(pd): index=[1, 2, 3, 4, 5]) fig, ax = plt.subplots() ax.errorbar('x', 'y', xerr='xe', yerr='ye', data=df) + fig.canvas.draw() def test_pandas_indexing_hist(pd): @@ -5184,6 +5187,7 @@ def test_pandas_indexing_hist(pd): ser_2 = ser_1.iloc[1:] fig, axes = plt.subplots() axes.hist(ser_2) + fig.canvas.draw() def test_pandas_bar_align_center(pd): From 5515d3a41b8e87b63b0037639e2b3e396e2c0e81 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 4 Feb 2018 21:48:52 -0500 Subject: [PATCH 22/22] DOC: fix API docs --- doc/api/api_changes/2018-01_TAC.rst | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/doc/api/api_changes/2018-01_TAC.rst b/doc/api/api_changes/2018-01_TAC.rst index f3455e103f0b..ea1b2d948ae8 100644 --- a/doc/api/api_changes/2018-01_TAC.rst +++ b/doc/api/api_changes/2018-01_TAC.rst @@ -2,7 +2,4 @@ Simplify String Categorical handling 
------------------------------------ - - Only handling `str` (not `bytes` in pyhon3). This disallows using integers or - floats as category labels. If you wish to have integer category labels, convert - to string before plotting. - - Do not allow missing data. This is a consequence of the first change. + - Do not allow missing data.