# -*- coding: utf-8 -*-
12"""
23catch all for categorical functions
34"""
45from __future__ import (absolute_import , division , print_function ,
56 unicode_literals )
67
78import six
9+
810import numpy as np
911
1012import matplotlib .units as units
1113import matplotlib .ticker as ticker
1214
1315
# Pure hack for numpy 1.6 support: NP_NEW selects between the direct unicode
# array construction and the legacy fallback in ``to_array``.
try:
    from distutils.version import LooseVersion
except ImportError:
    # distutils was removed from the standard library in Python 3.12
    # (PEP 632); keep the name defined so the module still imports.
    LooseVersion = None

if LooseVersion is not None:
    NP_NEW = (LooseVersion(np.version.version) >= LooseVersion('1.7'))
else:
    # An interpreter that ships without distutils cannot run a numpy as old
    # as 1.6, so the modern code path is always the right one here.
    NP_NEW = True
20+
21+
def to_array(data, maxlen=100):
    """Convert *data* to a numpy array of strings.

    Parameters
    ----------
    data : scalar or sequence
        Values to convert.
    maxlen : int, optional
        Width of the fixed-size byte-string dtype used on the legacy
        (numpy < 1.7) path; ignored when ``NP_NEW`` is true.

    Returns
    -------
    numpy.ndarray
        Text array built from *data*.
    """
    if NP_NEW:
        # ``np.unicode`` was only an alias of the builtin text type and has
        # been removed from recent numpy releases; ``six.text_type`` names
        # the same builtin on both Python 2 and Python 3.
        return np.array(data, dtype=six.text_type)
    try:
        vals = np.array(data, dtype=('|S', maxlen))
    except UnicodeEncodeError:
        # pure hack: convert element by element when the byte-string dtype
        # cannot encode the data
        vals = np.array([convert_to_string(d) for d in data])
    return vals
31+
32+
1433class StrCategoryConverter (units .ConversionInterface ):
@staticmethod
def convert(value, unit, axis):
    """Encode categorical *value* as floats using ``axis.unit_data``.

    Parameters
    ----------
    value : string or sequence
        Category label(s) to encode.
    unit : object
        Unused; present to satisfy the converter interface.
    axis : object
        Must expose ``unit_data``, an iterable of ``(label, location)``
        pairs.

    Returns
    -------
    The location mapped to *value* when it is a single string (``KeyError``
    if the label is unknown), otherwise a float ndarray of locations.
    """
    vmap = dict(axis.unit_data)

    if isinstance(value, six.string_types):
        return vmap[value]

    vals = to_array(value)

    # Write locations into a separate float array.  Storing them back into
    # the string array truncates them to the array's item width (e.g. 10.0
    # becomes '1' in a one-character array) and lets an already converted
    # location be mistaken for a later label.
    locs = np.empty(vals.shape, dtype='float')
    matched = np.zeros(vals.shape, dtype=bool)
    for lab, loc in axis.unit_data:
        hits = (vals == lab)
        locs[hits] = loc
        matched |= hits

    # Entries without a category are parsed as floats, as before; a
    # non-numeric leftover still raises ValueError.
    if not matched.all():
        locs[~matched] = vals[~matched].astype('float')

    return locs
2849
2950 @staticmethod
@@ -41,7 +62,36 @@ def default_units(data, axis):
4162 return None
4263
4364
44- def map_categories (data , old_map = [], sort = True ):
class StrCategoryLocator(ticker.FixedLocator):
    """Fixed locator placing one tick at each category location."""

    def __init__(self, locs):
        """Create the locator for *locs* with no limit on the tick count."""
        super(StrCategoryLocator, self).__init__(locs, nbins=None)
68+
69+
class StrCategoryFormatter(ticker.FixedFormatter):
    """Fixed formatter labelling ticks with the category names."""

    def __init__(self, seq):
        """Create the formatter from *seq*, the sequence of tick labels."""
        super(StrCategoryFormatter, self).__init__(seq=seq)
73+
74+
def convert_to_string(value):
    """Return *value* as a string label (helper for numpy 1.6).

    Can be replaced with np.array(..., dtype=unicode) for all later
    versions of numpy.

    Strings are returned unchanged, finite numbers are converted through a
    numpy string array, and non-finite numbers become 'nan', 'inf' or
    '-inf'.  ValueError is raised if none of these cases applies.
    """
    if isinstance(value, six.string_types):
        return value

    if np.isfinite(value):
        return np.asarray(value, dtype=str)[np.newaxis][0]

    # Checked in the same order as the non-finite cases are listed above.
    special_cases = (
        (np.isnan, 'nan'),
        (np.isposinf, 'inf'),
        (np.isneginf, '-inf'),
    )
    for is_case, label in special_cases:
        if is_case(value):
            return label

    raise ValueError("Unconvertable {}".format(value))
92+
93+
94+ def map_categories (data , old_map = None ):
4595 """Create mapping between unique categorical
4696 values and numerical identifier.
4797
@@ -65,53 +115,37 @@ def map_categories(data, old_map=[], sort=True):
65115 # code typical missing data in the negative range because
66116 # everything else will always have positive encoding
67117 # question able if it even makes sense
68- spdict = {'nan' : - 1 , 'inf' : - 2 , '-inf' : - 3 }
69-
70- # cast all data to str
71- strdata = [str (d ) for d in data ]
118+ spdict = {'nan' : - 1.0 , 'inf' : - 2.0 , '-inf' : - 3.0 }
72119
73- uniq = set (strdata )
120+ if isinstance (data , six .string_types ):
121+ data = [data ]
74122
75- category_map = old_map .copy ()
123+ # will update this post cbook/dict support
124+ strdata = to_array (data )
125+ uniq = np .unique (strdata )
76126
77127 if old_map :
78128 olabs , okeys = zip (* old_map )
79- olabs , okeys = set (olabs ), set (okeys )
80129 svalue = max (okeys ) + 1
81130 else :
82- olabs , okeys = set (), set ()
131+ old_map , olabs , okeys = [], [], []
83132 svalue = 0
84133
85- new_labs = (uniq - olabs )
134+ category_map = old_map [:]
135+
136+ new_labs = [u for u in uniq if u not in olabs ]
137+ missing = [nl for nl in new_labs if nl in spdict .keys ()]
86138
87- missing = (new_labs & set (spdict .keys ()))
88139 category_map .extend ([(m , spdict [m ]) for m in missing ])
89140
90- new_labs = (new_labs - missing )
91- if sort :
92- new_labs = list (new_labs )
93- new_labs .sort ()
141+ new_labs = [nl for nl in new_labs if nl not in missing ]
94142
95- new_locs = range (svalue , svalue + len (new_labs ))
143+ new_locs = np . arange (svalue , svalue + len (new_labs ), dtype = 'float' )
96144 category_map .extend (list (zip (new_labs , new_locs )))
97145 return category_map
98146
99147
100- class StrCategoryLocator (ticker .FixedLocator ):
101- def __init__ (self , locs ):
102- super (StrCategoryLocator , self ).__init__ (locs , None )
103-
104-
105- class StrCategoryFormatter (ticker .FixedFormatter ):
106- def __init__ (self , seq ):
107- super (StrCategoryFormatter , self ).__init__ (seq )
108-
109-
# Connects the converter to matplotlib: every string-like type gets its own
# StrCategoryConverter instance (types that are aliases of each other on the
# running interpreter share a single registry entry).
units.registry.update({
    string_type: StrCategoryConverter()
    for string_type in (str, bytes, six.text_type)
})
0 commit comments