Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit f1f1919

Browse files
committed
Rethink categoricals.
Don't support mixed type inputs. Don't sort keys.
1 parent c508c35 commit f1f1919

File tree

2 files changed

+72
-161
lines changed

2 files changed

+72
-161
lines changed

lib/matplotlib/category.py

Lines changed: 45 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,40 @@
1-
# -*- coding: utf-8 OA-*-za
2-
"""
3-
catch all for categorical functions
1+
"""Helpers for categorical data.
42
"""
3+
54
from __future__ import (absolute_import, division, print_function,
65
unicode_literals)
76
import six
87

9-
import numpy as np
10-
11-
import matplotlib.units as units
12-
import matplotlib.ticker as ticker
13-
14-
# np 1.6/1.7 support
15-
from distutils.version import LooseVersion
16-
import collections
8+
from collections import OrderedDict
9+
import itertools
1710

11+
import numpy as np
1812

19-
if LooseVersion(np.__version__) >= LooseVersion('1.8.0'):
20-
def shim_array(data):
21-
return np.array(data, dtype=np.unicode)
22-
else:
23-
def shim_array(data):
24-
if (isinstance(data, six.string_types) or
25-
not isinstance(data, collections.Iterable)):
26-
data = [data]
27-
try:
28-
data = [str(d) for d in data]
29-
except UnicodeEncodeError:
30-
# this yields gibberish but unicode text doesn't
31-
# render under numpy1.6 anyway
32-
data = [d.encode('utf-8', 'ignore').decode('utf-8')
33-
for d in data]
34-
return np.array(data, dtype=np.unicode)
13+
from matplotlib import cbook, ticker, units
3514

3615

3716
class StrCategoryConverter(units.ConversionInterface):
3817
@staticmethod
3918
def convert(value, unit, axis):
40-
"""Uses axis.unit_data map to encode
41-
data as floats
42-
"""
43-
value = np.atleast_1d(value)
44-
# try and update from here....
45-
if hasattr(axis.unit_data, 'update'):
46-
for val in value:
47-
if isinstance(val, six.string_types):
48-
axis.unit_data.update(val)
49-
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))
50-
51-
if isinstance(value, six.string_types):
52-
return vmap[value]
53-
54-
vals = shim_array(value)
55-
56-
for lab, loc in vmap.items():
57-
vals[vals == lab] = loc
58-
59-
return vals.astype('float')
19+
"""Uses axis.unit_data map to encode data as floats."""
20+
# We also need to pass numbers through.
21+
if np.issubdtype(np.asarray(value).dtype.type, np.number):
22+
return value
23+
else:
24+
axis.unit_data.update(value)
25+
return np.vectorize(axis.unit_data._mapping.__getitem__)(value)
6026

6127
@staticmethod
6228
def axisinfo(unit, axis):
63-
majloc = StrCategoryLocator(axis.unit_data.locs)
64-
majfmt = StrCategoryFormatter(axis.unit_data.seq)
65-
return units.AxisInfo(majloc=majloc, majfmt=majfmt)
29+
# Note that mapping may get mutated by later calls to plotting methods,
30+
# so the locator and formatter must dynamically recompute locs and seq.
31+
return units.AxisInfo(
32+
majloc=StrCategoryLocator(axis.unit_data._mapping),
33+
majfmt=StrCategoryFormatter(axis.unit_data._mapping))
6634

6735
@staticmethod
6836
def default_units(data, axis):
69-
# the conversion call stack is:
70-
# default_units->axis_info->convert
37+
# the conversion call stack is default_units->axis_info->convert
7138
if axis.unit_data is None:
7239
axis.unit_data = UnitData(data)
7340
else:
@@ -76,48 +43,46 @@ def default_units(data, axis):
7643

7744

7845
class StrCategoryLocator(ticker.FixedLocator):
79-
def __init__(self, locs):
80-
self.locs = locs
46+
def __init__(self, mapping):
47+
self._mapping = mapping
8148
self.nbins = None
8249

50+
@property
51+
def locs(self):
52+
return list(self._mapping.values())
53+
8354

8455
class StrCategoryFormatter(ticker.FixedFormatter):
85-
def __init__(self, seq):
86-
self.seq = seq
87-
self.offset_string = ''
56+
def __init__(self, mapping):
57+
self._mapping = mapping
58+
self.offset_string = ""
8859

60+
@property
61+
def seq(self):
62+
return list(self._mapping)
8963

90-
class UnitData(object):
91-
# debatable makes sense to special code missing values
92-
spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}
9364

65+
class UnitData(object):
9466
def __init__(self, data):
95-
"""Create mapping between unique categorical values
96-
and numerical identifier
67+
"""Create mapping between unique categorical values and numerical id.
9768
9869
Parameters
9970
----------
10071
data: iterable
10172
sequence of values
10273
"""
103-
self.seq, self.locs = [], []
104-
self._set_seq_locs(data, 0)
105-
106-
def update(self, new_data):
107-
# so as not to conflict with spdict
108-
value = max(max(self.locs) + 1, 0)
109-
self._set_seq_locs(new_data, value)
110-
111-
def _set_seq_locs(self, data, value):
112-
strdata = shim_array(data)
113-
new_s = [d for d in np.unique(strdata) if d not in self.seq]
114-
for ns in new_s:
115-
self.seq.append(ns)
116-
if ns in UnitData.spdict:
117-
self.locs.append(UnitData.spdict[ns])
118-
else:
119-
self.locs.append(value)
120-
value += 1
74+
self._mapping = {}
75+
self._counter = itertools.count()
76+
self.update(data)
77+
78+
def update(self, data):
79+
if isinstance(data, six.string_types):
80+
data = [data]
81+
sorted_unique = OrderedDict.fromkeys(data)
82+
for s in sorted_unique:
83+
if s in self._mapping:
84+
continue
85+
self._mapping[s] = next(self._counter)
12186

12287

12388
# Connects the converter to matplotlib

lib/matplotlib/tests/test_category.py

Lines changed: 27 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -13,65 +13,47 @@
1313

1414

1515
class TestUnitData(object):
16-
testdata = [("hello world", ["hello world"], [0]),
17-
("Здравствуйте мир", ["Здравствуйте мир"], [0]),
16+
testdata = [("hello world", {"hello world": 0}),
17+
("Здравствуйте мир", {"Здравствуйте мир": 0}),
1818
(['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf],
19-
['-inf', '3.14', 'A', 'B', 'inf', 'nan'],
20-
[-3.0, 0, 1, 2, -2.0, -1.0])]
21-
19+
{'A': 0, 'nan': 1, 'B': 2, '-inf': 3, '3.14': 4, 'inf': 5})]
2220
ids = ["single", "unicode", "mixed"]
2321

24-
@pytest.mark.parametrize("data, seq, locs", testdata, ids=ids)
25-
def test_unit(self, data, seq, locs):
26-
act = cat.UnitData(data)
27-
assert act.seq == seq
28-
assert act.locs == locs
22+
@pytest.mark.parametrize("data, mapping", testdata, ids=ids)
23+
def test_unit(self, data, mapping):
24+
assert cat.UnitData(data)._mapping == mapping
2925

3026
def test_update_map(self):
31-
data = ['a', 'd']
32-
oseq = ['a', 'd']
33-
olocs = [0, 1]
34-
35-
data_update = ['b', 'd', 'e', np.inf]
36-
useq = ['a', 'd', 'b', 'e', 'inf']
37-
ulocs = [0, 1, 2, 3, -2]
27+
unitdata = cat.UnitData(['a', 'd'])
28+
assert unitdata._mapping == {'a': 0, 'd': 1}
29+
unitdata.update(['b', 'd', 'e'])
30+
assert unitdata._mapping == {'a': 0, 'd': 1, 'b': 2, 'e': 3}
3831

39-
unitdata = cat.UnitData(data)
40-
assert unitdata.seq == oseq
41-
assert unitdata.locs == olocs
4232

43-
unitdata.update(data_update)
44-
assert unitdata.seq == useq
45-
assert unitdata.locs == ulocs
33+
class MockUnitData:
34+
def __init__(self, mapping):
35+
self._mapping = mapping
4636

4737

4838
class FakeAxis(object):
4939
def __init__(self, unit_data):
5040
self.unit_data = unit_data
5141

5242

53-
class MockUnitData(object):
54-
def __init__(self, data):
55-
seq, locs = zip(*data)
56-
self.seq = list(seq)
57-
self.locs = list(locs)
58-
59-
6043
class TestStrCategoryConverter(object):
6144
"""Based on the pandas conversion and factorization tests:
6245
6346
ref: /pandas/tseries/tests/test_converter.py
6447
/pandas/tests/test_algos.py:TestFactorize
6548
"""
66-
testdata = [("Здравствуйте мир", [("Здравствуйте мир", 42)], 42),
67-
("hello world", [("hello world", 42)], 42),
49+
testdata = [("Здравствуйте мир", {"Здравствуйте мир": 42}, 42),
50+
("hello world", {"hello world": 42}, 42),
6851
(['a', 'b', 'b', 'a', 'a', 'c', 'c', 'c'],
69-
[('a', 0), ('b', 1), ('c', 2)],
52+
{'a': 0, 'b': 1, 'c': 2},
7053
[0, 1, 1, 0, 0, 2, 2, 2]),
71-
(['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf],
72-
[('nan', -1), ('3.14', 0), ('A', 1), ('B', 2),
73-
('-inf', 100), ('inf', 200)],
74-
[1, 1, -1, 2, 100, 0, 200])]
54+
(['A', 'A', 'B', 3.14],
55+
{'3.14': 0, 'A': 1, 'B': 2},
56+
[1, 1, 2, 0])]
7557
ids = ["unicode", "single", "basic", "mixed"]
7658

7759
@pytest.fixture(autouse=True)
@@ -86,7 +68,7 @@ def test_convert(self, data, unitmap, exp):
8668
np.testing.assert_array_equal(act, exp)
8769

8870
def test_axisinfo(self):
89-
MUD = MockUnitData([(None, None)])
71+
MUD = MockUnitData({None: None})
9072
axis = FakeAxis(MUD)
9173
ax = self.cc.axisinfo(None, axis)
9274
assert isinstance(ax.majloc, cat.StrCategoryLocator)
@@ -99,8 +81,8 @@ def test_default_units(self):
9981

10082
class TestStrCategoryLocator(object):
10183
def test_StrCategoryLocator(self):
102-
locs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
103-
ticks = cat.StrCategoryLocator(locs)
84+
locs = list(range(10))
85+
ticks = cat.StrCategoryLocator({str(x): x for x in locs})
10486
np.testing.assert_array_equal(ticks.tick_values(None, None), locs)
10587

10688

@@ -145,27 +127,18 @@ def data(self):
145127
self.d = ['a', 'b', 'c', 'a']
146128
self.dticks = [0, 1, 2]
147129
self.dlabels = ['a', 'b', 'c']
148-
unitmap = [('a', 0), ('b', 1), ('c', 2)]
130+
unitmap = {'a': 0, 'b': 1, 'c': 2}
149131
self.dunit_data = MockUnitData(unitmap)
150132

151-
@pytest.fixture
152-
def missing_data(self):
153-
self.dm = ['here', np.nan, 'here', 'there']
154-
self.dmticks = [0, -1, 1]
155-
self.dmlabels = ['here', 'nan', 'there']
156-
unitmap = [('here', 0), ('nan', -1), ('there', 1)]
157-
self.dmunit_data = MockUnitData(unitmap)
158-
159133
def axis_test(self, axis, ticks, labels, unit_data):
160134
np.testing.assert_array_equal(axis.get_majorticklocs(), ticks)
161135
assert lt(axis.get_majorticklabels()) == labels
162-
np.testing.assert_array_equal(axis.unit_data.locs, unit_data.locs)
163-
assert axis.unit_data.seq == unit_data.seq
136+
assert axis.unit_data._mapping == unit_data._mapping
164137

165138
def test_plot_unicode(self):
166139
words = ['Здравствуйте', 'привет']
167140
locs = [0.0, 1.0]
168-
unit_data = MockUnitData(zip(words, locs))
141+
unit_data = MockUnitData(dict(zip(words, locs)))
169142

170143
fig, ax = plt.subplots()
171144
ax.plot(words)
@@ -181,14 +154,6 @@ def test_plot_1d(self):
181154

182155
self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data)
183156

184-
@pytest.mark.usefixtures("missing_data")
185-
def test_plot_1d_missing(self):
186-
fig, ax = plt.subplots()
187-
ax.plot(self.dm)
188-
fig.canvas.draw()
189-
190-
self.axis_test(ax.yaxis, self.dmticks, self.dmlabels, self.dmunit_data)
191-
192157
@pytest.mark.usefixtures("data")
193158
@pytest.mark.parametrize("bars", bytes_data, ids=bytes_ids)
194159
def test_plot_bytes(self, bars):
@@ -208,28 +173,9 @@ def test_plot_numlike(self, bars):
208173
ax.bar(bars, counts)
209174
fig.canvas.draw()
210175

211-
unitmap = MockUnitData([('1', 0), ('11', 1), ('3', 2)])
176+
unitmap = MockUnitData({'1': 0, '11': 1, '3': 2})
212177
self.axis_test(ax.xaxis, [0, 1, 2], ['1', '11', '3'], unitmap)
213178

214-
@pytest.mark.usefixtures("data", "missing_data")
215-
def test_plot_2d(self):
216-
fig, ax = plt.subplots()
217-
ax.plot(self.dm, self.d)
218-
fig.canvas.draw()
219-
220-
self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data)
221-
self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data)
222-
223-
@pytest.mark.usefixtures("data", "missing_data")
224-
def test_scatter_2d(self):
225-
226-
fig, ax = plt.subplots()
227-
ax.scatter(self.dm, self.d)
228-
fig.canvas.draw()
229-
230-
self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data)
231-
self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data)
232-
233179
def test_plot_update(self):
234180
fig, ax = plt.subplots()
235181

@@ -240,7 +186,7 @@ def test_plot_update(self):
240186

241187
labels = ['a', 'b', 'd', 'c']
242188
ticks = [0, 1, 2, 3]
243-
unit_data = MockUnitData(list(zip(labels, ticks)))
189+
unit_data = MockUnitData(dict(zip(labels, ticks)))
244190

245191
self.axis_test(ax.yaxis, ticks, labels, unit_data)
246192

0 commit comments

Comments
 (0)