Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7eb0e8c

Browse files
committed
CategoricalNorm kinda sort works, relies on StrCategorical
1 parent 624991c commit 7eb0e8c

File tree

4 files changed

+66
-79
lines changed

4 files changed

+66
-79
lines changed

build_alllocal.cmd

Lines changed: 0 additions & 36 deletions
This file was deleted.

lib/matplotlib/category.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
NP_NEW = (LooseVersion(np.version.version) >= LooseVersion('1.7'))
2121

2222

23-
def to_array(data, maxlen=100):
23+
def to_str_array(data, maxlen=100):
2424
if NP_NEW:
2525
return np.array(data, dtype=np.unicode)
2626
if cbook.is_scalar_or_string(data):
@@ -53,13 +53,13 @@ def convert(value, unit, axis):
5353
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))
5454

5555
if isinstance(value, six.string_types):
56-
return vmap[value]
56+
return vmap.get(value, None)
5757

58-
vals = to_array(value)
58+
vals = to_str_array(value)
5959
for lab, loc in vmap.items():
6060
vals[vals == lab] = loc
6161

62-
return vals.astype('float')
62+
return vals.astype('float64')
6363

6464
@staticmethod
6565
def axisinfo(unit, axis):
@@ -74,16 +74,20 @@ def axisinfo(unit, axis):
7474
return munits.AxisInfo(majloc=majloc, majfmt=majfmt)
7575

7676
@staticmethod
77-
def default_units(data, axis):
77+
def default_units(data, axis, sort=True):
7878
"""
7979
Create mapping between string categories in *data*
8080
and integers, then store in *axis.unit_data*
8181
"""
82-
if axis.unit_data is None:
83-
axis.unit_data = UnitData(data)
84-
else:
85-
axis.unit_data.update(data)
86-
return None
82+
83+
if axis and axis.unit_data:
84+
axis.unit_data.update(data, sort)
85+
return
86+
87+
unit_data = UnitData(data, sort)
88+
if axis:
89+
axis.unit_data = unit_data
90+
return unit_data
8791

8892

8993
class StrCategoryLocator(mticker.FixedLocator):
@@ -115,30 +119,26 @@ def __init__(self, categories):
115119
*categories*
116120
distinct values for mapping
117121
118-
Out-of-range values are mapped to a value not in categories;
119-
these are then converted to valid indices by :meth:`Colormap.__call__`.
122+
Out-of-range values are mapped to np.nan
120123
"""
121-
self.categories = categories
124+
125+
self.unit_data = StrCategoryConverter.default_units(categories,
126+
None, sort=False)
127+
self.categories = to_str_array(categories)
122128
self.N = len(self.categories)
123-
self.vmin = 0
124-
self.vmax = self.N
125-
self._interp = False
129+
self.nvals = self.unit_data.locs
130+
self.vmin = min(self.nvals)
131+
self.vmax = max(self.nvals)
126132

127133
def __call__(self, value, clip=None):
128-
if not cbook.iterable(value):
129-
value = [value]
130-
131-
value = np.asarray(value)
132-
ret = np.ones(value.shape) * np.nan
133-
134-
for i, c in enumerate(self.categories):
135-
ret[value == c] = i / (self.N * 1.0)
136-
137-
return np.ma.array(ret, mask=np.isnan(ret))
138-
139-
def inverse(self, value):
140-
# not quite sure what invertible means in this context
141-
return ValueError("CategoryNorm is not invertible")
134+
# gonna have to go into imshow and undo casting
135+
value = np.asarray(value, dtype=int)
136+
ret = StrCategoryConverter.convert(value, None, self)
137+
# knock out values not in the norm
138+
mask = np.in1d(ret, self.unit_data.locs).reshape(ret.shape)
139+
# normalize ret
140+
ret /= self.vmax
141+
return np.ma.array(ret, mask=~mask)
142142

143143

144144
def colors_from_categories(codings):
@@ -187,27 +187,40 @@ class UnitData(object):
187187
# debatable makes sense to special code missing values
188188
spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}
189189

190-
def __init__(self, data):
190+
def __init__(self, data, sort=True):
191191
"""Create mapping between unique categorical values
192192
and numerical identifier
193193
Paramters
194194
---------
195195
data: iterable
196196
sequence of values
197+
sort: bool
198+
sort input data, default is True
199+
False preserves input order
197200
"""
198201
self.seq, self.locs = [], []
199-
self._set_seq_locs(data, 0)
202+
self._set_seq_locs(data, 0, sort)
203+
self.sort = sort
200204

201-
def update(self, new_data):
205+
def update(self, new_data, sort=None):
206+
if sort:
207+
self.sort = sort
202208
# so as not to conflict with spdict
203209
value = max(max(self.locs) + 1, 0)
204-
self._set_seq_locs(new_data, value)
210+
self._set_seq_locs(new_data, value, self.sort)
205211

206-
def _set_seq_locs(self, data, value):
212+
def _set_seq_locs(self, data, value, sort):
207213
# magic to make it work under np1.6
208-
strdata = to_array(data)
214+
strdata = to_str_array(data)
215+
209216
# np.unique makes dateframes work
210-
new_s = [d for d in np.unique(strdata) if d not in self.seq]
217+
if sort:
218+
unq = np.unique(strdata)
219+
else:
220+
_, idx = np.unique(strdata, return_index=~sort)
221+
unq = strdata[np.sort(idx)]
222+
223+
new_s = [d for d in unq if d not in self.seq]
211224
for ns in new_s:
212225
self.seq.append(convert_to_string(ns))
213226
if ns in UnitData.spdict.keys():

lib/matplotlib/colorbar.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import matplotlib as mpl
3232
import matplotlib.artist as martist
33+
import matplotlib.category as category
3334
import matplotlib.cbook as cbook
3435
import matplotlib.collections as collections
3536
import matplotlib.colors as colors
@@ -580,6 +581,8 @@ def _ticker(self):
580581
locator = ticker.FixedLocator(b, nbins=10)
581582
elif isinstance(self.norm, colors.LogNorm):
582583
locator = ticker.LogLocator()
584+
elif isinstance(self.norm, category.CategoryNorm):
585+
locator = ticker.FixedLocator(self.norm.nvals + 0.5)
583586
else:
584587
if mpl.rcParams['_internal.classic_mode']:
585588
locator = ticker.MaxNLocator()

lib/matplotlib/tests/test_category.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_axisinfo(self):
106106

107107
def test_default_units(self):
108108
axis = FakeAxis(None)
109-
assert self.cc.default_units(["a"], axis) is None
109+
assert isinstance(self.cc.default_units(["a"], axis), cat.UnitData)
110110

111111

112112
class TestStrCategoryLocator(object):
@@ -129,17 +129,24 @@ def test_StrCategoryFormatterUnicode(self):
129129

130130

131131
class TestCategoryNorm(object):
132-
testdata = [[[205, 302, 205, 101], [0, 2. / 3., 0, 1. / 3.]],
133-
[[205, np.nan, 101, 305], [0, 9999, 1. / 3., 2. / 3.]],
134-
[[205, 101, 504, 101], [0, 9999, 1. / 3., 1. / 3.]]]
132+
testdata = [[[205, 302, 205, 101], [0, 2, 0, 1]],
133+
[[205, np.nan, 101, 305], [0, np.nan, 1, 2]],
134+
[[205, 101, 504, 101], [0, 1, np.nan, 1]]]
135135

136136
ids = ["regular", "nan", "exclude"]
137137

138138
@pytest.mark.parametrize("data, nmap", testdata, ids=ids)
139139
def test_norm(self, data, nmap):
140140
norm = cat.CategoryNorm([205, 101, 302])
141-
test = np.ma.masked_equal(nmap, 9999)
142-
np.testing.assert_allclose(norm(data), test)
141+
np.testing.assert_array_equal(norm(data), nmap)
142+
143+
144+
def test_colors_from_categories():
145+
codings = {205: "red", 101: "blue", 302: "green"}
146+
cmap, norm = cat.colors_from_categories(codings)
147+
assert cmap.colors == ['red', 'green', 'blue']
148+
np.testing.assert_array_equal(norm.categories, ['205', '302', '101'])
149+
assert cmap.N == norm.N
143150

144151

145152
def lt(tl):

0 commit comments

Comments
 (0)