Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit b068739

Browse files
committed
Rewrite category.py.
1 parent 70df945 commit b068739

File tree

2 files changed

+65
-125
lines changed

2 files changed

+65
-125
lines changed

lib/matplotlib/category.py

Lines changed: 38 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,61 +4,45 @@
44
unicode_literals)
55
import six
66

7-
import collections
87
from collections import OrderedDict
9-
from distutils.version import LooseVersion
108
import itertools
9+
from numbers import Number
1110

1211
import numpy as np
1312

14-
import matplotlib.units as units
15-
import matplotlib.ticker as ticker
13+
from matplotlib import units, ticker
1614

1715

18-
if LooseVersion(np.__version__) >= LooseVersion('1.8.0'):
19-
def shim_array(data):
20-
return np.array(data, dtype=np.unicode)
21-
else:
22-
def shim_array(data):
23-
if (isinstance(data, six.string_types) or
24-
not isinstance(data, collections.Iterable)):
25-
data = [data]
26-
try:
27-
data = [str(d) for d in data]
28-
except UnicodeEncodeError:
29-
# this yields gibberish but unicode text doesn't
30-
# render under numpy1.6 anyway
31-
data = [d.encode('utf-8', 'ignore').decode('utf-8')
32-
for d in data]
33-
return np.array(data, dtype=np.unicode)
16+
def _to_str(s):
17+
return s.decode("ascii") if isinstance(s, bytes) else str(s)
3418

3519

3620
class StrCategoryConverter(units.ConversionInterface):
3721
@staticmethod
3822
def convert(value, unit, axis):
3923
"""Uses axis.unit_data map to encode data as floats."""
40-
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))
41-
42-
if isinstance(value, six.string_types):
43-
return vmap[value]
44-
45-
vals = shim_array(value)
46-
47-
for lab, loc in vmap.items():
48-
vals[vals == lab] = loc
49-
50-
return vals.astype('float')
24+
mapping = axis.unit_data._mapping
25+
if isinstance(value, (Number, np.number)):
26+
return value
27+
elif isinstance(value, (str, bytes)):
28+
return mapping[_to_str(value)]
29+
else:
30+
return np.array([v if isinstance(v, (Number, np.number))
31+
else mapping[_to_str(v)]
32+
for v in value],
33+
float)
5134

5235
@staticmethod
5336
def axisinfo(unit, axis):
54-
majloc = StrCategoryLocator(axis.unit_data.locs)
55-
majfmt = StrCategoryFormatter(axis.unit_data.seq)
37+
# Note that mapping may get mutated by later calls to plotting methods,
38+
# so the locator and formatter must dynamically recompute locs and seq.
39+
majloc = StrCategoryLocator(axis.unit_data._mapping)
40+
majfmt = StrCategoryFormatter(axis.unit_data._mapping)
5641
return units.AxisInfo(majloc=majloc, majfmt=majfmt)
5742

5843
@staticmethod
5944
def default_units(data, axis):
60-
# the conversion call stack is:
61-
# default_units->axis_info->convert
45+
# the conversion call stack is default_units->axis_info->convert
6246
if axis.unit_data is None:
6347
axis.unit_data = UnitData(data)
6448
else:
@@ -67,21 +51,26 @@ def default_units(data, axis):
6751

6852

6953
class StrCategoryLocator(ticker.FixedLocator):
70-
def __init__(self, locs):
71-
self.locs = locs
54+
def __init__(self, mapping):
55+
self._mapping = mapping
7256
self.nbins = None
7357

58+
@property
59+
def locs(self):
60+
return list(self._mapping.values())
61+
7462

7563
class StrCategoryFormatter(ticker.FixedFormatter):
76-
def __init__(self, seq):
77-
self.seq = seq
78-
self.offset_string = ''
64+
def __init__(self, mapping):
65+
self._mapping = mapping
66+
self.offset_string = ""
7967

68+
@property
69+
def seq(self):
70+
return list(self._mapping)
8071

81-
class UnitData(object):
82-
# debatable makes sense to special code missing values
83-
spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}
8472

73+
class UnitData(object):
8574
def __init__(self, data):
8675
"""Create mapping between unique categorical values and numerical id.
8776
@@ -90,21 +79,18 @@ def __init__(self, data):
9079
data: iterable
9180
sequence of values
9281
"""
93-
self.seq, self.locs = [], []
82+
self._mapping = {}
9483
self._counter = itertools.count()
9584
self.update(data)
9685

9786
def update(self, data):
98-
data = np.atleast_1d(shim_array(data))
99-
sorted_unique = list(OrderedDict(zip(data, itertools.repeat(None))))
87+
if isinstance(data, six.string_types):
88+
data = [data]
89+
sorted_unique = OrderedDict((_to_str(d), None) for d in data)
10090
for s in sorted_unique:
101-
if s in self.seq:
91+
if s in self._mapping:
10292
continue
103-
self.seq.append(s)
104-
if s in UnitData.spdict:
105-
self.locs.append(UnitData.spdict[s])
106-
else:
107-
self.locs.append(next(self._counter))
93+
self._mapping[s] = next(self._counter)
10894

10995

11096
# Connects the convertor to matplotlib

lib/matplotlib/tests/test_category.py

Lines changed: 27 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -13,57 +13,47 @@
1313

1414

1515
class TestUnitData(object):
16-
testdata = [("hello world", ["hello world"], [0]),
17-
("Здравствуйте мир", ["Здравствуйте мир"], [0]),
16+
testdata = [("hello world", {"hello world": 0}),
17+
("Здравствуйте мир", {"Здравствуйте мир": 0}),
1818
(['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf],
19-
['A', 'nan', 'B', '-inf', '3.14', 'inf'],
20-
[0, -1, 1, -3, 2, -2])]
21-
19+
{'A': 0, 'nan': 1, 'B': 2, '-inf': 3, '3.14': 4, 'inf': 5})]
2220
ids = ["single", "unicode", "mixed"]
2321

24-
@pytest.mark.parametrize("data, seq, locs", testdata, ids=ids)
25-
def test_unit(self, data, seq, locs):
26-
act = cat.UnitData(data)
27-
assert act.seq == seq
28-
assert act.locs == locs
22+
@pytest.mark.parametrize("data, mapping", testdata, ids=ids)
23+
def test_unit(self, data, mapping):
24+
assert cat.UnitData(data)._mapping == mapping
2925

3026
def test_update_map(self):
3127
unitdata = cat.UnitData(['a', 'd'])
32-
assert unitdata.seq == ['a', 'd']
33-
assert unitdata.locs == [0, 1]
28+
assert unitdata._mapping == {'a': 0, 'd': 1}
29+
unitdata.update(['b', 'd', 'e'])
30+
assert unitdata._mapping == {'a': 0, 'd': 1, 'b': 2, 'e': 3}
31+
3432

35-
unitdata.update(['b', 'd', 'e', np.inf])
36-
assert unitdata.seq == ['a', 'd', 'b', 'e', 'inf']
37-
assert unitdata.locs == [0, 1, 2, 3, -2]
33+
class MockUnitData:
34+
def __init__(self, mapping):
35+
self._mapping = mapping
3836

3937

4038
class FakeAxis(object):
4139
def __init__(self, unit_data):
4240
self.unit_data = unit_data
4341

4442

45-
class MockUnitData(object):
46-
def __init__(self, data):
47-
seq, locs = zip(*data)
48-
self.seq = list(seq)
49-
self.locs = list(locs)
50-
51-
5243
class TestStrCategoryConverter(object):
5344
"""Based on the pandas conversion and factorization tests:
5445
5546
ref: /pandas/tseries/tests/test_converter.py
5647
/pandas/tests/test_algos.py:TestFactorize
5748
"""
58-
testdata = [("Здравствуйте мир", [("Здравствуйте мир", 42)], 42),
59-
("hello world", [("hello world", 42)], 42),
49+
testdata = [("Здравствуйте мир", {"Здравствуйте мир": 42}, 42),
50+
("hello world", {"hello world": 42}, 42),
6051
(['a', 'b', 'b', 'a', 'a', 'c', 'c', 'c'],
61-
[('a', 0), ('b', 1), ('c', 2)],
52+
{'a': 0, 'b': 1, 'c': 2},
6253
[0, 1, 1, 0, 0, 2, 2, 2]),
63-
(['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf],
64-
[('nan', -1), ('3.14', 0), ('A', 1), ('B', 2),
65-
('-inf', 100), ('inf', 200)],
66-
[1, 1, -1, 2, 100, 0, 200])]
54+
(['A', 'A', 'B', 3.14],
55+
{'A': 1, 'B': 2},
56+
[1, 1, 2, 3.14])]
6757
ids = ["unicode", "single", "basic", "mixed"]
6858

6959
@pytest.fixture(autouse=True)
@@ -78,7 +68,7 @@ def test_convert(self, data, unitmap, exp):
7868
np.testing.assert_array_equal(act, exp)
7969

8070
def test_axisinfo(self):
81-
MUD = MockUnitData([(None, None)])
71+
MUD = MockUnitData({None: None})
8272
axis = FakeAxis(MUD)
8373
ax = self.cc.axisinfo(None, axis)
8474
assert isinstance(ax.majloc, cat.StrCategoryLocator)
@@ -91,8 +81,8 @@ def test_default_units(self):
9181

9282
class TestStrCategoryLocator(object):
9383
def test_StrCategoryLocator(self):
94-
locs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
95-
ticks = cat.StrCategoryLocator(locs)
84+
locs = list(range(10))
85+
ticks = cat.StrCategoryLocator({str(x): x for x in locs})
9686
np.testing.assert_array_equal(ticks.tick_values(None, None), locs)
9787

9888

@@ -137,27 +127,18 @@ def data(self):
137127
self.d = ['a', 'b', 'c', 'a']
138128
self.dticks = [0, 1, 2]
139129
self.dlabels = ['a', 'b', 'c']
140-
unitmap = [('a', 0), ('b', 1), ('c', 2)]
130+
unitmap = {'a': 0, 'b': 1, 'c': 2}
141131
self.dunit_data = MockUnitData(unitmap)
142132

143-
@pytest.fixture
144-
def missing_data(self):
145-
self.dm = ['here', np.nan, 'here', 'there']
146-
self.dmticks = [0, -1, 1]
147-
self.dmlabels = ['here', 'nan', 'there']
148-
unitmap = [('here', 0), ('nan', -1), ('there', 1)]
149-
self.dmunit_data = MockUnitData(unitmap)
150-
151133
def axis_test(self, axis, ticks, labels, unit_data):
152134
np.testing.assert_array_equal(axis.get_majorticklocs(), ticks)
153135
assert lt(axis.get_majorticklabels()) == labels
154-
np.testing.assert_array_equal(axis.unit_data.locs, unit_data.locs)
155-
assert axis.unit_data.seq == unit_data.seq
136+
assert axis.unit_data._mapping == unit_data._mapping
156137

157138
def test_plot_unicode(self):
158139
words = ['Здравствуйте', 'привет']
159140
locs = [0.0, 1.0]
160-
unit_data = MockUnitData(zip(words, locs))
141+
unit_data = MockUnitData(dict(zip(words, locs)))
161142

162143
fig, ax = plt.subplots()
163144
ax.plot(words)
@@ -173,14 +154,6 @@ def test_plot_1d(self):
173154

174155
self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data)
175156

176-
@pytest.mark.usefixtures("missing_data")
177-
def test_plot_1d_missing(self):
178-
fig, ax = plt.subplots()
179-
ax.plot(self.dm)
180-
fig.canvas.draw()
181-
182-
self.axis_test(ax.yaxis, self.dmticks, self.dmlabels, self.dmunit_data)
183-
184157
@pytest.mark.usefixtures("data")
185158
@pytest.mark.parametrize("bars", bytes_data, ids=bytes_ids)
186159
def test_plot_bytes(self, bars):
@@ -200,28 +173,9 @@ def test_plot_numlike(self, bars):
200173
ax.bar(bars, counts)
201174
fig.canvas.draw()
202175

203-
unitmap = MockUnitData([('1', 0), ('11', 1), ('3', 2)])
176+
unitmap = MockUnitData({'1': 0, '11': 1, '3': 2})
204177
self.axis_test(ax.xaxis, [0, 1, 2], ['1', '11', '3'], unitmap)
205178

206-
@pytest.mark.usefixtures("data", "missing_data")
207-
def test_plot_2d(self):
208-
fig, ax = plt.subplots()
209-
ax.plot(self.dm, self.d)
210-
fig.canvas.draw()
211-
212-
self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data)
213-
self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data)
214-
215-
@pytest.mark.usefixtures("data", "missing_data")
216-
def test_scatter_2d(self):
217-
218-
fig, ax = plt.subplots()
219-
ax.scatter(self.dm, self.d)
220-
fig.canvas.draw()
221-
222-
self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data)
223-
self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data)
224-
225179
def test_plot_update(self):
226180
fig, ax = plt.subplots()
227181

@@ -232,6 +186,6 @@ def test_plot_update(self):
232186

233187
labels = ['a', 'b', 'd', 'c']
234188
ticks = [0, 1, 2, 3]
235-
unit_data = MockUnitData(list(zip(labels, ticks)))
189+
unit_data = MockUnitData(dict(zip(labels, ticks)))
236190

237191
self.axis_test(ax.yaxis, ticks, labels, unit_data)

0 commit comments

Comments
 (0)