@@ -74,15 +74,14 @@ def axisinfo(unit, axis):
74
74
return munits .AxisInfo (majloc = majloc , majfmt = majfmt )
75
75
76
76
@staticmethod
77
- def default_units (data , axis , sort = True ):
77
+ def default_units (data , axis , sort = True , normed = False ):
78
78
"""
79
79
Create mapping between string categories in *data*
80
- and integers, then store in *axis.unit_data*
80
+ and integers, and store in *axis.unit_data*
81
81
"""
82
-
83
82
if axis and axis .unit_data :
84
83
axis .unit_data .update (data , sort )
85
- return
84
+ return axis . unit_data
86
85
87
86
unit_data = UnitData (data , sort )
88
87
if axis :
@@ -122,22 +121,29 @@ def __init__(self, categories):
122
121
Out-of-range values are mapped to np.nan
123
122
"""
124
123
125
- self .unit_data = StrCategoryConverter .default_units (categories ,
126
- None , sort = False )
127
- self .categories = to_str_array (categories )
128
- self .N = len (self .categories )
129
- self .nvals = self .unit_data .locs
130
- self .vmin = min (self .nvals )
131
- self .vmax = max (self .nvals )
124
+ # facilitates cleaner DuckTyping of axis interface
125
+
126
+ class CatAxis (object ):
127
+ def __init__ (self ):
128
+ self .unit_data = None
129
+ self .units = StrCategoryConverter ()
130
+
131
+ self .axis = CatAxis ()
132
+ self .axis .units .default_units (categories , self .axis ,
133
+ sort = False )
134
+
135
+ nvals = self .axis .unit_data .locs
136
+ self .vmin = min (nvals )
137
+ self .vmax = max (nvals )
132
138
133
139
def __call__ (self , value , clip = None ):
134
140
# gonna have to go into imshow and undo casting
135
- value = np .asarray (value , dtype = int )
136
- ret = StrCategoryConverter . convert (value , None , self )
141
+ value = np .asarray (value , dtype = np . int )
142
+ ret = self . axis . units . convert (value , None , self . axis )
137
143
# knock out values not in the norm
138
- mask = np .in1d (ret , self .unit_data .locs ).reshape (ret .shape )
139
- # normalize ret
140
- ret /= self .vmax
144
+ mask = np .in1d (ret , self .axis . unit_data .locs ).reshape (ret .shape )
145
+ # normalize ret & locs
146
+ ret /= self .vmax
141
147
return np .ma .array (ret , mask = ~ mask )
142
148
143
149
@@ -184,7 +190,7 @@ def convert_to_string(value):
184
190
185
191
186
192
class UnitData (object ):
187
- # debatable makes sense to special code missing values
193
+ # debatable if it makes sense to special code missing values
188
194
spdict = {'nan' : - 1.0 , 'inf' : - 2.0 , '-inf' : - 3.0 }
189
195
190
196
def __init__ (self , data , sort = True ):
@@ -202,7 +208,7 @@ def __init__(self, data, sort=True):
202
208
self ._set_seq_locs (data , 0 , sort )
203
209
self .sort = sort
204
210
205
- def update (self , new_data , sort = None ):
211
+ def update (self , new_data , sort = True ):
206
212
if sort :
207
213
self .sort = sort
208
214
# so as not to conflict with spdict
0 commit comments