Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cd31afd

Browse files
committed
API: restore support for bytes
1 parent 98012f3 commit cd31afd

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

lib/matplotlib/category.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@ class StrCategoryConverter(units.ConversionInterface):
1717
@staticmethod
1818
def convert(value, unit, axis):
1919
"""Use axis.units mapping to map categorical data to floats."""
20-
20+
def getter(k):
21+
if not isinstance(k, six.text_type):
22+
k = k.decode('utf-8')
23+
return axis.units._mapping[k]
2124
# We also need to pass numbers through.
2225
if np.issubdtype(np.asarray(value).dtype.type, np.number):
2326
return value
2427
else:
2528
axis.units.update(value)
26-
str2idx = np.vectorize(axis.units._mapping.__getitem__,
27-
otypes=[float])
29+
str2idx = np.vectorize(getter, otypes=[float])
2830
return str2idx(value)
2931

3032
@staticmethod
@@ -61,6 +63,9 @@ def __call__(self, x, pos=None):
6163

6264

6365
class UnitData(object):
66+
valid_types = tuple(set(six.string_types +
67+
(bytes, six.text_type, np.str_, np.bytes_)))
68+
6469
def __init__(self, data=None):
6570
"""Create mapping between unique categorical values and numerical id.
6671
@@ -73,10 +78,12 @@ def __init__(self, data=None):
7378
self._vals = []
7479
if data is None:
7580
data = ()
76-
self._mapping = OrderedDict(data)
77-
for k, v in self._mapping.items():
78-
if not isinstance(k, six.text_type):
81+
self._mapping = OrderedDict()
82+
for k, v in OrderedDict(data).items():
83+
if not isinstance(k, self.valid_types):
7984
raise TypeError("{val!r} is not a string".format(val=k))
85+
if not isinstance(k, six.text_type):
86+
k = k.decode('utf-8')
8087
self._mapping[k] = int(v)
8188
if self._mapping:
8289
start = max(self._mapping.values()) + 1
@@ -85,19 +92,22 @@ def __init__(self, data=None):
8592
self._counter = itertools.count(start=start)
8693

8794
def update(self, data):
88-
if isinstance(data, six.string_types):
95+
if isinstance(data, self.valid_types):
8996
data = [data]
9097
sorted_unique = OrderedDict.fromkeys(data)
9198
for val in sorted_unique:
99+
if not isinstance(val, self.valid_types):
100+
raise TypeError("{val!r} is not a string".format(val=val))
101+
if not isinstance(val, six.text_type):
102+
val = val.decode('utf-8')
92103
if val in self._mapping:
93104
continue
94-
if not isinstance(val, six.text_type):
95-
raise TypeError("{val!r} is not a string".format(val=val))
96105
self._vals.append(val)
97106
self._mapping[val] = next(self._counter)
98107

99108

100109
# Connects the convertor to matplotlib
110+
101111
units.registry[str] = StrCategoryConverter()
102112
units.registry[bytes] = StrCategoryConverter()
103113
units.registry[np.str_] = StrCategoryConverter()

lib/matplotlib/tests/test_category.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class TestStrCategoryConverter(object):
6161
test_cases = [("unicode", {u"Здравствуйте мир": 42}),
6262
("ascii", {"hello world": 42}),
6363
("single", {'a': 0, 'b': 1, 'c': 2}),
64+
("single bytes", {b'a': 0, b'b': 1, b'c': 2}),
65+
("mixed bytes", {b'a': 0, 'b': 1, b'c': 2}),
6466
("single + values>10", {'A': 0, 'B': 1, 'C': 2,
6567
'D': 3, 'E': 4, 'F': 5,
6668
'G': 6, 'H': 7, 'I': 8,
@@ -111,11 +113,11 @@ def test_StrCategoryFormatter(self):
111113
assert labels(1, 1) == "world"
112114

113115
def test_StrCategoryFormatterUnicode(self):
114-
seq = ["Здравствуйте", "привет"]
116+
seq = [u"Здравствуйте", u"привет"]
115117
u = cat.UnitData()
116118
u.update(seq)
117119
labels = cat.StrCategoryFormatter(u)
118-
assert labels(1, 1) == "привет"
120+
assert labels(1, 1) == u"привет"
119121

120122

121123
def lt(tl):
@@ -130,6 +132,8 @@ def axis_test(axis, ticks, labels, unit_data):
130132

131133
class TestBarsBytes(object):
132134
bytes_cases = [('string list', ['a', 'b', 'c']),
135+
('bytes list', [b'a', b'b', b'c']),
136+
('mixed list', [b'a', 'b', b'c']),
133137
]
134138

135139
bytes_ids, bytes_data = zip(*bytes_cases)
@@ -232,6 +236,7 @@ def ax():
232236
[([u"Здравствуйте мир"], [0], [u"Здравствуйте мир"]),
233237
(["a", "b", "b", "a", "c", "c"], [0, 1, 1, 0, 2, 2], ["a", "b", "c"]),
234238
(["foo", "bar"], range(2), ["foo", "bar"]),
239+
([b"foo", "bar"], range(2), ["foo", "bar"]),
235240
(np.array(["1", "11", "3"]), range(3), ["1", "11", "3"])])
236241
def test_simple(ax, data, expected_indices, expected_labels):
237242
l, = ax.plot(data)

0 commit comments

Comments
 (0)