20
20
NP_NEW = (LooseVersion (np .version .version ) >= LooseVersion ('1.7' ))
21
21
22
22
23
- def to_array (data , maxlen = 100 ):
23
+ def to_str_array (data , maxlen = 100 ):
24
24
if NP_NEW :
25
25
return np .array (data , dtype = np .unicode )
26
26
if cbook .is_scalar_or_string (data ):
@@ -53,13 +53,13 @@ def convert(value, unit, axis):
53
53
vmap = dict (zip (axis .unit_data .seq , axis .unit_data .locs ))
54
54
55
55
if isinstance (value , six .string_types ):
56
- return vmap [ value ]
56
+ return vmap . get ( value , None )
57
57
58
- vals = to_array (value )
58
+ vals = to_str_array (value )
59
59
for lab , loc in vmap .items ():
60
60
vals [vals == lab ] = loc
61
61
62
- return vals .astype ('float ' )
62
+ return vals .astype ('float64 ' )
63
63
64
64
@staticmethod
65
65
def axisinfo (unit , axis ):
@@ -74,16 +74,20 @@ def axisinfo(unit, axis):
74
74
return munits .AxisInfo (majloc = majloc , majfmt = majfmt )
75
75
76
76
@staticmethod
77
- def default_units (data , axis ):
77
+ def default_units (data , axis , sort = True ):
78
78
"""
79
79
Create mapping between string categories in *data*
80
80
and integers, then store in *axis.unit_data*
81
81
"""
82
- if axis .unit_data is None :
83
- axis .unit_data = UnitData (data )
84
- else :
85
- axis .unit_data .update (data )
86
- return None
82
+
83
+ if axis and axis .unit_data :
84
+ axis .unit_data .update (data , sort )
85
+ return
86
+
87
+ unit_data = UnitData (data , sort )
88
+ if axis :
89
+ axis .unit_data = unit_data
90
+ return unit_data
87
91
88
92
89
93
class StrCategoryLocator (mticker .FixedLocator ):
@@ -115,30 +119,26 @@ def __init__(self, categories):
115
119
*categories*
116
120
distinct values for mapping
117
121
118
- Out-of-range values are mapped to a value not in categories;
119
- these are then converted to valid indices by :meth:`Colormap.__call__`.
122
+ Out-of-range values are mapped to np.nan
120
123
"""
121
- self .categories = categories
124
+
125
+ self .unit_data = StrCategoryConverter .default_units (categories ,
126
+ None , sort = False )
127
+ self .categories = to_str_array (categories )
122
128
self .N = len (self .categories )
123
- self .vmin = 0
124
- self .vmax = self .N
125
- self ._interp = False
129
+ self .nvals = self . unit_data . locs
130
+ self .vmin = min ( self .nvals )
131
+ self .vmax = max ( self . nvals )
126
132
127
133
def __call__ (self , value , clip = None ):
128
- if not cbook .iterable (value ):
129
- value = [value ]
130
-
131
- value = np .asarray (value )
132
- ret = np .ones (value .shape ) * np .nan
133
-
134
- for i , c in enumerate (self .categories ):
135
- ret [value == c ] = i / (self .N * 1.0 )
136
-
137
- return np .ma .array (ret , mask = np .isnan (ret ))
138
-
139
- def inverse (self , value ):
140
- # not quite sure what invertible means in this context
141
- return ValueError ("CategoryNorm is not invertible" )
134
+ # gonna have to go into imshow and undo casting
135
+ value = np .asarray (value , dtype = int )
136
+ ret = StrCategoryConverter .convert (value , None , self )
137
+ # knock out values not in the norm
138
+ mask = np .in1d (ret , self .unit_data .locs ).reshape (ret .shape )
139
+ # normalize ret
140
+ ret /= self .vmax
141
+ return np .ma .array (ret , mask = ~ mask )
142
142
143
143
144
144
def colors_from_categories (codings ):
@@ -187,27 +187,40 @@ class UnitData(object):
187
187
# debatable makes sense to special code missing values
188
188
spdict = {'nan' : - 1.0 , 'inf' : - 2.0 , '-inf' : - 3.0 }
189
189
190
- def __init__ (self , data ):
190
+ def __init__ (self , data , sort = True ):
191
191
"""Create mapping between unique categorical values
192
192
and numerical identifier
193
193
Paramters
194
194
---------
195
195
data: iterable
196
196
sequence of values
197
+ sort: bool
198
+ sort input data, default is True
199
+ False preserves input order
197
200
"""
198
201
self .seq , self .locs = [], []
199
- self ._set_seq_locs (data , 0 )
202
+ self ._set_seq_locs (data , 0 , sort )
203
+ self .sort = sort
200
204
201
- def update (self , new_data ):
205
+ def update (self , new_data , sort = None ):
206
+ if sort :
207
+ self .sort = sort
202
208
# so as not to conflict with spdict
203
209
value = max (max (self .locs ) + 1 , 0 )
204
- self ._set_seq_locs (new_data , value )
210
+ self ._set_seq_locs (new_data , value , self . sort )
205
211
206
- def _set_seq_locs (self , data , value ):
212
+ def _set_seq_locs (self , data , value , sort ):
207
213
# magic to make it work under np1.6
208
- strdata = to_array (data )
214
+ strdata = to_str_array (data )
215
+
209
216
# np.unique makes dateframes work
210
- new_s = [d for d in np .unique (strdata ) if d not in self .seq ]
217
+ if sort :
218
+ unq = np .unique (strdata )
219
+ else :
220
+ _ , idx = np .unique (strdata , return_index = ~ sort )
221
+ unq = strdata [np .sort (idx )]
222
+
223
+ new_s = [d for d in unq if d not in self .seq ]
211
224
for ns in new_s :
212
225
self .seq .append (convert_to_string (ns ))
213
226
if ns in UnitData .spdict .keys ():
0 commit comments