diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index d043c5b154a5..5d651b26c5b0 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -1,67 +1,40 @@ -# -*- coding: utf-8 OA-*-za -""" -catch all for categorical functions +"""Helpers for categorical data. """ from __future__ import (absolute_import, division, print_function, unicode_literals) import six -import numpy as np +from collections import OrderedDict +import itertools -import matplotlib.units as units -import matplotlib.ticker as ticker +import numpy as np -# np 1.6/1.7 support -from distutils.version import LooseVersion -import collections +from matplotlib import units, ticker -if LooseVersion(np.__version__) >= LooseVersion('1.8.0'): - def shim_array(data): - return np.array(data, dtype=np.unicode) -else: - def shim_array(data): - if (isinstance(data, six.string_types) or - not isinstance(data, collections.Iterable)): - data = [data] - try: - data = [str(d) for d in data] - except UnicodeEncodeError: - # this yields gibberish but unicode text doesn't - # render under numpy1.6 anyway - data = [d.encode('utf-8', 'ignore').decode('utf-8') - for d in data] - return np.array(data, dtype=np.unicode) +def _to_str(s): + return s.decode("ascii") if isinstance(s, bytes) else str(s) class StrCategoryConverter(units.ConversionInterface): @staticmethod def convert(value, unit, axis): - """Uses axis.unit_data map to encode - data as floats - """ - vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs)) - - if isinstance(value, six.string_types): - return vmap[value] - - vals = shim_array(value) - - for lab, loc in vmap.items(): - vals[vals == lab] = loc - - return vals.astype('float') + """Uses axis.unit_data map to encode data as floats.""" + mapping = axis.unit_data._mapping + return (mapping[_to_str(value)] if np.isscalar(value) + else np.array([mapping[_to_str(v)] for v in value], float)) @staticmethod def axisinfo(unit, axis): - majloc = StrCategoryLocator(axis.unit_data.locs) - majfmt = StrCategoryFormatter(axis.unit_data.seq) + # Note that mapping may get mutated by later calls to plotting methods, + # so the locator and formatter must dynamically recompute locs and seq. + majloc = StrCategoryLocator(axis.unit_data._mapping) + majfmt = StrCategoryFormatter(axis.unit_data._mapping) return units.AxisInfo(majloc=majloc, majfmt=majfmt) @staticmethod def default_units(data, axis): - # the conversion call stack is: - # default_units->axis_info->convert + # the conversion call stack is default_units->axis_info->convert if axis.unit_data is None: axis.unit_data = UnitData(data) else: @@ -70,48 +43,46 @@ def default_units(data, axis): class StrCategoryLocator(ticker.FixedLocator): - def __init__(self, locs): - self.locs = locs + def __init__(self, mapping): + self._mapping = mapping self.nbins = None + @property + def locs(self): + return list(self._mapping.values()) + class StrCategoryFormatter(ticker.FixedFormatter): - def __init__(self, seq): - self.seq = seq - self.offset_string = '' + def __init__(self, mapping): + self._mapping = mapping + self.offset_string = "" + @property + def seq(self): + return list(self._mapping) -class UnitData(object): - # debatable makes sense to special code missing values - spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0} +class UnitData(object): def __init__(self, data): - """Create mapping between unique categorical values - and numerical identifier + """Create mapping between unique categorical values and numerical id. Parameters ---------- data: iterable sequence of values """ - self.seq, self.locs = [], [] - self._set_seq_locs(data, 0) - - def update(self, new_data): - # so as not to conflict with spdict - value = max(max(self.locs) + 1, 0) - self._set_seq_locs(new_data, value) - - def _set_seq_locs(self, data, value): - strdata = shim_array(data) - new_s = [d for d in np.unique(strdata) if d not in self.seq] - for ns in new_s: - self.seq.append(ns) - if ns in UnitData.spdict: - self.locs.append(UnitData.spdict[ns]) - else: - self.locs.append(value) - value += 1 + self._mapping = {} + self._counter = itertools.count() + self.update(data) + + def update(self, data): + if isinstance(data, six.string_types): + data = [data] + sorted_unique = OrderedDict.fromkeys(map(_to_str, data)) + for s in sorted_unique: + if s in self._mapping: + continue + self._mapping[s] = next(self._counter) # Connects the convertor to matplotlib diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 6e5c43d76fb9..06ef63dae215 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -13,36 +13,26 @@ class TestUnitData(object): - testdata = [("hello world", ["hello world"], [0]), - ("Здравствуйте мир", ["Здравствуйте мир"], [0]), + testdata = [("hello world", {"hello world": 0}), + ("Здравствуйте мир", {"Здравствуйте мир": 0}), (['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf], - ['-inf', '3.14', 'A', 'B', 'inf', 'nan'], - [-3.0, 0, 1, 2, -2.0, -1.0])] - + {'A': 0, 'nan': 1, 'B': 2, '-inf': 3, '3.14': 4, 'inf': 5})] ids = ["single", "unicode", "mixed"] - @pytest.mark.parametrize("data, seq, locs", testdata, ids=ids) - def test_unit(self, data, seq, locs): - act = cat.UnitData(data) - assert act.seq == seq - assert act.locs == locs + @pytest.mark.parametrize("data, mapping", testdata, ids=ids) + def test_unit(self, data, mapping): + assert cat.UnitData(data)._mapping == mapping def test_update_map(self): - data = ['a', 'd'] - oseq = ['a', 'd'] - olocs = [0, 1] - - data_update = ['b', 'd', 'e', np.inf] - useq = ['a', 'd', 'b', 'e', 'inf'] - ulocs = [0, 1, 2, 3, -2] + unitdata = cat.UnitData(['a', 'd']) + assert unitdata._mapping == {'a': 0, 'd': 1} + unitdata.update(['b', 'd', 'e']) + assert unitdata._mapping == {'a': 0, 'd': 1, 'b': 2, 'e': 3} - unitdata = cat.UnitData(data) - assert unitdata.seq == oseq - assert unitdata.locs == olocs - unitdata.update(data_update) - assert unitdata.seq == useq - assert unitdata.locs == ulocs +class MockUnitData: + def __init__(self, mapping): + self._mapping = mapping class FakeAxis(object): @@ -50,28 +40,20 @@ def __init__(self, unit_data): self.unit_data = unit_data -class MockUnitData(object): - def __init__(self, data): - seq, locs = zip(*data) - self.seq = list(seq) - self.locs = list(locs) - - class TestStrCategoryConverter(object): """Based on the pandas conversion and factorization tests: ref: /pandas/tseries/tests/test_converter.py /pandas/tests/test_algos.py:TestFactorize """ - testdata = [("Здравствуйте мир", [("Здравствуйте мир", 42)], 42), - ("hello world", [("hello world", 42)], 42), + testdata = [("Здравствуйте мир", {"Здравствуйте мир": 42}, 42), + ("hello world", {"hello world": 42}, 42), (['a', 'b', 'b', 'a', 'a', 'c', 'c', 'c'], - [('a', 0), ('b', 1), ('c', 2)], + {'a': 0, 'b': 1, 'c': 2}, [0, 1, 1, 0, 0, 2, 2, 2]), - (['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf], - [('nan', -1), ('3.14', 0), ('A', 1), ('B', 2), - ('-inf', 100), ('inf', 200)], - [1, 1, -1, 2, 100, 0, 200])] + (['A', 'A', 'B', 3.14], + {'3.14': 0, 'A': 1, 'B': 2}, + [1, 1, 2, 0])] ids = ["unicode", "single", "basic", "mixed"] @pytest.fixture(autouse=True) @@ -86,7 +68,7 @@ def test_convert(self, data, unitmap, exp): np.testing.assert_array_equal(act, exp) def test_axisinfo(self): - MUD = MockUnitData([(None, None)]) + MUD = MockUnitData({None: None}) axis = FakeAxis(MUD) ax = self.cc.axisinfo(None, axis) assert isinstance(ax.majloc, cat.StrCategoryLocator) @@ -99,8 +81,8 @@ def test_default_units(self): class TestStrCategoryLocator(object): def test_StrCategoryLocator(self): - locs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - ticks = cat.StrCategoryLocator(locs) + locs = list(range(10)) + ticks = cat.StrCategoryLocator({str(x): x for x in locs}) np.testing.assert_array_equal(ticks.tick_values(None, None), locs) @@ -145,27 +127,18 @@ def data(self): self.d = ['a', 'b', 'c', 'a'] self.dticks = [0, 1, 2] self.dlabels = ['a', 'b', 'c'] - unitmap = [('a', 0), ('b', 1), ('c', 2)] + unitmap = {'a': 0, 'b': 1, 'c': 2} self.dunit_data = MockUnitData(unitmap) - @pytest.fixture - def missing_data(self): - self.dm = ['here', np.nan, 'here', 'there'] - self.dmticks = [0, -1, 1] - self.dmlabels = ['here', 'nan', 'there'] - unitmap = [('here', 0), ('nan', -1), ('there', 1)] - self.dmunit_data = MockUnitData(unitmap) - def axis_test(self, axis, ticks, labels, unit_data): np.testing.assert_array_equal(axis.get_majorticklocs(), ticks) assert lt(axis.get_majorticklabels()) == labels - np.testing.assert_array_equal(axis.unit_data.locs, unit_data.locs) - assert axis.unit_data.seq == unit_data.seq + assert axis.unit_data._mapping == unit_data._mapping def test_plot_unicode(self): words = ['Здравствуйте', 'привет'] locs = [0.0, 1.0] - unit_data = MockUnitData(zip(words, locs)) + unit_data = MockUnitData(dict(zip(words, locs))) fig, ax = plt.subplots() ax.plot(words) @@ -181,14 +154,6 @@ def test_plot_1d(self): self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data) - @pytest.mark.usefixtures("missing_data") - def test_plot_1d_missing(self): - fig, ax = plt.subplots() - ax.plot(self.dm) - fig.canvas.draw() - - self.axis_test(ax.yaxis, self.dmticks, self.dmlabels, self.dmunit_data) - @pytest.mark.usefixtures("data") @pytest.mark.parametrize("bars", bytes_data, ids=bytes_ids) def test_plot_bytes(self, bars): @@ -208,28 +173,9 @@ def test_plot_numlike(self, bars): ax.bar(bars, counts) fig.canvas.draw() - unitmap = MockUnitData([('1', 0), ('11', 1), ('3', 2)]) + unitmap = MockUnitData({'1': 0, '11': 1, '3': 2}) self.axis_test(ax.xaxis, [0, 1, 2], ['1', '11', '3'], unitmap) - @pytest.mark.usefixtures("data", "missing_data") - def test_plot_2d(self): - fig, ax = plt.subplots() - ax.plot(self.dm, self.d) - fig.canvas.draw() - - self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data) - self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data) - - @pytest.mark.usefixtures("data", "missing_data") - def test_scatter_2d(self): - - fig, ax = plt.subplots() - ax.scatter(self.dm, self.d) - fig.canvas.draw() - - self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data) - self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data) - def test_plot_update(self): fig, ax = plt.subplots() @@ -240,6 +186,6 @@ def test_plot_update(self): labels = ['a', 'b', 'd', 'c'] ticks = [0, 1, 2, 3] - unit_data = MockUnitData(list(zip(labels, ticks))) + unit_data = MockUnitData(dict(zip(labels, ticks))) self.axis_test(ax.yaxis, ticks, labels, unit_data)