Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 8a96281

Browse files
committed
categorical axis support with np1.6 hack
1 parent b184842 commit 8a96281

4 files changed

Lines changed: 178 additions & 106 deletions

File tree

lib/matplotlib/axes/_axes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import matplotlib.collections as mcoll
2222
import matplotlib.colors as mcolors
2323
import matplotlib.contour as mcontour
24-
import matplotlib.category as _ # <-registers a date unit converter
24+
import matplotlib.category as _ # <-registers a category unit converter
2525
import matplotlib.dates as _ # <-registers a date unit converter
2626
from matplotlib import docstring
2727
import matplotlib.image as mimage

lib/matplotlib/category.py

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,50 @@
1+
# -*- coding: utf-8 OA-*-za
12
"""
23
catch all for categorical functions
34
"""
45
from __future__ import (absolute_import, division, print_function,
56
unicode_literals)
67

78
import six
9+
810
import numpy as np
911

1012
import matplotlib.units as units
1113
import matplotlib.ticker as ticker
1214

1315

16+
# pure hack for numpy 1.6 support
17+
from distutils.version import LooseVersion
18+
19+
NP_NEW = (LooseVersion(np.version.version) >= LooseVersion('1.7'))
20+
21+
22+
def to_array(data, maxlen=100):
23+
if NP_NEW:
24+
return np.array(data, dtype=np.unicode)
25+
try:
26+
vals = np.array(data, dtype=('|S', maxlen))
27+
except UnicodeEncodeError:
28+
# pure hack
29+
vals = np.array([convert_to_string(d) for d in data])
30+
return vals
31+
32+
1433
class StrCategoryConverter(units.ConversionInterface):
1534
@staticmethod
1635
def convert(value, unit, axis):
1736
"""Uses axis.unit_data map to encode
18-
data as integers
37+
data as floats
1938
"""
39+
vmap = dict(axis.unit_data)
2040

2141
if isinstance(value, six.string_types):
22-
return dict(axis.unit_data)[value]
42+
return vmap[value]
43+
44+
vals = to_array(value)
45+
for lab, loc in axis.unit_data:
46+
vals[vals == lab] = loc
2347

24-
vals = np.asarray(value, dtype='str')
25-
for label, loc in axis.unit_data:
26-
vals[vals == label] = loc
2748
return vals.astype('float')
2849

2950
@staticmethod
@@ -41,7 +62,36 @@ def default_units(data, axis):
4162
return None
4263

4364

44-
def map_categories(data, old_map=[], sort=True):
65+
class StrCategoryLocator(ticker.FixedLocator):
66+
def __init__(self, locs):
67+
super(StrCategoryLocator, self).__init__(locs, None)
68+
69+
70+
class StrCategoryFormatter(ticker.FixedFormatter):
71+
def __init__(self, seq):
72+
super(StrCategoryFormatter, self).__init__(seq)
73+
74+
75+
def convert_to_string(value):
76+
"""Helper function for numpy 1.6, can be replaced with
77+
np.array(...,dtype=unicode) for all later versions of numpy"""
78+
79+
if isinstance(value, six.string_types):
80+
return value
81+
if np.isfinite(value):
82+
value = np.asarray(value, dtype=str)[np.newaxis][0]
83+
elif np.isnan(value):
84+
value = 'nan'
85+
elif np.isposinf(value):
86+
value = 'inf'
87+
elif np.isneginf(value):
88+
value = '-inf'
89+
else:
90+
raise ValueError("Unconvertable {}".format(value))
91+
return value
92+
93+
94+
def map_categories(data, old_map=None):
4595
"""Create mapping between unique categorical
4696
values and numerical identifier.
4797
@@ -65,53 +115,37 @@ def map_categories(data, old_map=[], sort=True):
65115
# code typical missing data in the negative range because
66116
# everything else will always have positive encoding
67117
# question able if it even makes sense
68-
spdict = {'nan': -1, 'inf': -2, '-inf': -3}
69-
70-
# cast all data to str
71-
strdata = [str(d) for d in data]
118+
spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}
72119

73-
uniq = set(strdata)
120+
if isinstance(data, six.string_types):
121+
data = [data]
74122

75-
category_map = old_map.copy()
123+
# will update this post cbook/dict support
124+
strdata = to_array(data)
125+
uniq = np.unique(strdata)
76126

77127
if old_map:
78128
olabs, okeys = zip(*old_map)
79-
olabs, okeys = set(olabs), set(okeys)
80129
svalue = max(okeys) + 1
81130
else:
82-
olabs, okeys = set(), set()
131+
old_map, olabs, okeys = [], [], []
83132
svalue = 0
84133

85-
new_labs = (uniq - olabs)
134+
category_map = old_map[:]
135+
136+
new_labs = [u for u in uniq if u not in olabs]
137+
missing = [nl for nl in new_labs if nl in spdict.keys()]
86138

87-
missing = (new_labs & set(spdict.keys()))
88139
category_map.extend([(m, spdict[m]) for m in missing])
89140

90-
new_labs = (new_labs - missing)
91-
if sort:
92-
new_labs = list(new_labs)
93-
new_labs.sort()
141+
new_labs = [nl for nl in new_labs if nl not in missing]
94142

95-
new_locs = range(svalue, svalue + len(new_labs))
143+
new_locs = np.arange(svalue, svalue + len(new_labs), dtype='float')
96144
category_map.extend(list(zip(new_labs, new_locs)))
97145
return category_map
98146

99147

100-
class StrCategoryLocator(ticker.FixedLocator):
101-
def __init__(self, locs):
102-
super(StrCategoryLocator, self).__init__(locs, None)
103-
104-
105-
class StrCategoryFormatter(ticker.FixedFormatter):
106-
def __init__(self, seq):
107-
super(StrCategoryFormatter, self).__init__(seq)
108-
109-
110148
# Connects the convertor to matplotlib
111-
units.registry[bytearray] = StrCategoryConverter()
112149
units.registry[str] = StrCategoryConverter()
113-
114-
if six.PY3:
115-
units.registry[bytes] = StrCategoryConverter()
116-
elif six.PY2:
117-
units.registry[unicode] = StrCategoryConverter()
150+
units.registry[bytes] = StrCategoryConverter()
151+
units.registry[six.text_type] = StrCategoryConverter()

0 commit comments

Comments
 (0)