Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 3754dfc

Browse files
committed
ma.allequal testing
1 parent 5ff0dbf commit 3754dfc

File tree

3 files changed

+27
-11
lines changed

3 files changed

+27
-11
lines changed

lib/matplotlib/category.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
NP_NEW = (LooseVersion(np.version.version) >= LooseVersion('1.7'))
2121

2222

23-
def to_str_array(data, maxlen=100):
23+
def to_array(data, maxlen=100):
2424
if NP_NEW:
25-
return np.array(data, dtype=np.unicode)
25+
return np.array(data)
26+
# dtype=np.unicode)
2627
if cbook.is_scalar_or_string(data):
2728
data = [data]
2829
try:
@@ -38,12 +39,11 @@ class StrCategoryConverter(munits.ConversionInterface):
3839
3940
Conversion typically happens in the following order:
4041
1. default_units:
41-
creates unit_data category-integer mapping and binds to axis
42+
create unit_data category-integer mapping and binds to axis
4243
2. axis_info:
43-
sets ticks/locator and label/formatter
44+
set ticks/locator and labels/formatter
4445
3. convert:
45-
maps input category data to integers using unit_data
46-
46+
map input category data to integers using unit_data
4747
"""
4848
@staticmethod
4949
def convert(value, unit, axis):
@@ -55,7 +55,7 @@ def convert(value, unit, axis):
5555
if isinstance(value, six.string_types):
5656
return vmap.get(value, None)
5757

58-
vals = to_str_array(value)
58+
vals = to_array(value)
5959
for lab, loc in vmap.items():
6060
vals[vals == lab] = loc
6161

@@ -125,7 +125,7 @@ def __init__(self, data):
125125
self.unit_data = None
126126
self.units.default_units(data,
127127
self, sort=False)
128-
128+
self.loc2seq = dict(zip(self.unit_data.locs, self.unit_data.seq))
129129
self.vmin = min(self.unit_data.locs)
130130
self.vmax = max(self.unit_data.locs)
131131

@@ -139,6 +139,12 @@ def __call__(self, value, clip=None):
139139
ret /= self.vmax
140140
return np.ma.array(ret, mask=~mask)
141141

142+
def inverse(self, value):
143+
if not cbook.iterable(value):
144+
value = np.asarray(value)
145+
vscaled = np.asarray(value) * self.vmax
146+
return [self.loc2seq[int(vs)] for vs in vscaled]
147+
142148

143149
def colors_from_categories(codings):
144150
"""
@@ -209,7 +215,7 @@ def update(self, new_data, sort=True):
209215

210216
def _set_seq_locs(self, data, value, sort):
211217
# magic to make it work under np1.6
212-
strdata = to_str_array(data)
218+
strdata = to_array(data)
213219

214220
# np.unique makes dateframes work
215221
if sort:

lib/matplotlib/colorbar.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@ def __init__(self, ax, cmap=None,
313313
if format is None:
314314
if isinstance(self.norm, colors.LogNorm):
315315
self.formatter = ticker.LogFormatterMathtext()
316+
elif isinstance(self.norm, category.CategoryNorm):
317+
self.formatter = ticker.FixedFormatter(self.norm.unit_data.seq)
316318
else:
317319
self.formatter = ticker.ScalarFormatter()
318320
elif cbook.is_string_like(format):
@@ -582,7 +584,7 @@ def _ticker(self):
582584
elif isinstance(self.norm, colors.LogNorm):
583585
locator = ticker.LogLocator()
584586
elif isinstance(self.norm, category.CategoryNorm):
585-
locator = ticker.FixedLocator(self.norm.nvals + 0.5)
587+
locator = ticker.FixedLocator(self.norm.unit_data.locs)
586588
else:
587589
if mpl.rcParams['_internal.classic_mode']:
588590
locator = ticker.MaxNLocator()

lib/matplotlib/tests/test_category.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,15 @@ class TestCategoryNorm(object):
138138
@pytest.mark.parametrize("data, nmap", testdata, ids=ids)
139139
def test_norm(self, data, nmap):
140140
norm = cat.CategoryNorm([205, 101, 302])
141-
np.testing.assert_allclose(norm(data), nmap)
141+
masked_nmap = np.ma.masked_equal(nmap, np.nan)
142+
assert np.ma.allequal(norm(data), masked_nmap)
143+
144+
def test_invert(self):
145+
data = [205, 302, 101]
146+
strdata = ['205', '302', '101']
147+
value = [0, .5, 1]
148+
norm = cat.CategoryNorm(data)
149+
assert norm.inverse(value) == strdata
142150

143151

144152
class TestColorsFromCategories(object):

0 commit comments

Comments
 (0)