Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Don't sort categorical keys. #9318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 42 additions & 71 deletions lib/matplotlib/category.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,40 @@
# -*- coding: utf-8 OA-*-za
"""
catch all for categorical functions
"""Helpers for categorical data.
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import six

import numpy as np
from collections import OrderedDict
import itertools

import matplotlib.units as units
import matplotlib.ticker as ticker
import numpy as np

# np 1.6/1.7 support
from distutils.version import LooseVersion
import collections
from matplotlib import units, ticker


if LooseVersion(np.__version__) >= LooseVersion('1.8.0'):
def shim_array(data):
return np.array(data, dtype=np.unicode)
else:
def shim_array(data):
if (isinstance(data, six.string_types) or
not isinstance(data, collections.Iterable)):
data = [data]
try:
data = [str(d) for d in data]
except UnicodeEncodeError:
# this yields gibberish but unicode text doesn't
# render under numpy1.6 anyway
data = [d.encode('utf-8', 'ignore').decode('utf-8')
for d in data]
return np.array(data, dtype=np.unicode)
def _to_str(s):
return s.decode("ascii") if isinstance(s, bytes) else str(s)


class StrCategoryConverter(units.ConversionInterface):
@staticmethod
def convert(value, unit, axis):
"""Uses axis.unit_data map to encode
data as floats
"""
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))

if isinstance(value, six.string_types):
return vmap[value]

vals = shim_array(value)

for lab, loc in vmap.items():
vals[vals == lab] = loc

return vals.astype('float')
"""Uses axis.unit_data map to encode data as floats."""
mapping = axis.unit_data._mapping
return (mapping[_to_str(value)] if np.isscalar(value)
else np.array([mapping[_to_str(v)] for v in value], float))

@staticmethod
def axisinfo(unit, axis):
majloc = StrCategoryLocator(axis.unit_data.locs)
majfmt = StrCategoryFormatter(axis.unit_data.seq)
# Note that mapping may get mutated by later calls to plotting methods,
# so the locator and formatter must dynamically recompute locs and seq.
majloc = StrCategoryLocator(axis.unit_data._mapping)
majfmt = StrCategoryFormatter(axis.unit_data._mapping)
return units.AxisInfo(majloc=majloc, majfmt=majfmt)

@staticmethod
def default_units(data, axis):
# the conversion call stack is:
# default_units->axis_info->convert
# the conversion call stack is default_units->axis_info->convert
if axis.unit_data is None:
axis.unit_data = UnitData(data)
else:
Expand All @@ -70,48 +43,46 @@ def default_units(data, axis):


class StrCategoryLocator(ticker.FixedLocator):
def __init__(self, locs):
self.locs = locs
def __init__(self, mapping):
self._mapping = mapping
self.nbins = None

@property
def locs(self):
return list(self._mapping.values())


class StrCategoryFormatter(ticker.FixedFormatter):
def __init__(self, seq):
self.seq = seq
self.offset_string = ''
def __init__(self, mapping):
self._mapping = mapping
self.offset_string = ""

@property
def seq(self):
return list(self._mapping)

class UnitData(object):
# debatable makes sense to special code missing values
spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}

class UnitData(object):
def __init__(self, data):
"""Create mapping between unique categorical values
and numerical identifier
"""Create mapping between unique categorical values and numerical id.

Parameters
----------
data: iterable
sequence of values
"""
self.seq, self.locs = [], []
self._set_seq_locs(data, 0)

def update(self, new_data):
# so as not to conflict with spdict
value = max(max(self.locs) + 1, 0)
self._set_seq_locs(new_data, value)

def _set_seq_locs(self, data, value):
strdata = shim_array(data)
new_s = [d for d in np.unique(strdata) if d not in self.seq]
for ns in new_s:
self.seq.append(ns)
if ns in UnitData.spdict:
self.locs.append(UnitData.spdict[ns])
else:
self.locs.append(value)
value += 1
self._mapping = {}
self._counter = itertools.count()
self.update(data)

def update(self, data):
if isinstance(data, six.string_types):
data = [data]
sorted_unique = OrderedDict.fromkeys(map(_to_str, data))
for s in sorted_unique:
if s in self._mapping:
continue
self._mapping[s] = next(self._counter)


# Connects the convertor to matplotlib
Expand Down
108 changes: 27 additions & 81 deletions lib/matplotlib/tests/test_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,65 +13,47 @@


class TestUnitData(object):
testdata = [("hello world", ["hello world"], [0]),
("Здравствуйте мир", ["Здравствуйте мир"], [0]),
testdata = [("hello world", {"hello world": 0}),
("Здравствуйте мир", {"Здравствуйте мир": 0}),
(['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf],
['-inf', '3.14', 'A', 'B', 'inf', 'nan'],
[-3.0, 0, 1, 2, -2.0, -1.0])]

{'A': 0, 'nan': 1, 'B': 2, '-inf': 3, '3.14': 4, 'inf': 5})]
ids = ["single", "unicode", "mixed"]

@pytest.mark.parametrize("data, seq, locs", testdata, ids=ids)
def test_unit(self, data, seq, locs):
act = cat.UnitData(data)
assert act.seq == seq
assert act.locs == locs
@pytest.mark.parametrize("data, mapping", testdata, ids=ids)
def test_unit(self, data, mapping):
assert cat.UnitData(data)._mapping == mapping

def test_update_map(self):
data = ['a', 'd']
oseq = ['a', 'd']
olocs = [0, 1]

data_update = ['b', 'd', 'e', np.inf]
useq = ['a', 'd', 'b', 'e', 'inf']
ulocs = [0, 1, 2, 3, -2]
unitdata = cat.UnitData(['a', 'd'])
assert unitdata._mapping == {'a': 0, 'd': 1}
unitdata.update(['b', 'd', 'e'])
assert unitdata._mapping == {'a': 0, 'd': 1, 'b': 2, 'e': 3}

unitdata = cat.UnitData(data)
assert unitdata.seq == oseq
assert unitdata.locs == olocs

unitdata.update(data_update)
assert unitdata.seq == useq
assert unitdata.locs == ulocs
class MockUnitData:
def __init__(self, mapping):
self._mapping = mapping


class FakeAxis(object):
def __init__(self, unit_data):
self.unit_data = unit_data


class MockUnitData(object):
def __init__(self, data):
seq, locs = zip(*data)
self.seq = list(seq)
self.locs = list(locs)


class TestStrCategoryConverter(object):
"""Based on the pandas conversion and factorization tests:

ref: /pandas/tseries/tests/test_converter.py
/pandas/tests/test_algos.py:TestFactorize
"""
testdata = [("Здравствуйте мир", [("Здравствуйте мир", 42)], 42),
("hello world", [("hello world", 42)], 42),
testdata = [("Здравствуйте мир", {"Здравствуйте мир": 42}, 42),
("hello world", {"hello world": 42}, 42),
(['a', 'b', 'b', 'a', 'a', 'c', 'c', 'c'],
[('a', 0), ('b', 1), ('c', 2)],
{'a': 0, 'b': 1, 'c': 2},
[0, 1, 1, 0, 0, 2, 2, 2]),
(['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf],
[('nan', -1), ('3.14', 0), ('A', 1), ('B', 2),
('-inf', 100), ('inf', 200)],
[1, 1, -1, 2, 100, 0, 200])]
(['A', 'A', 'B', 3.14],
{'3.14': 0, 'A': 1, 'B': 2},
[1, 1, 2, 0])]
ids = ["unicode", "single", "basic", "mixed"]

@pytest.fixture(autouse=True)
Expand All @@ -86,7 +68,7 @@ def test_convert(self, data, unitmap, exp):
np.testing.assert_array_equal(act, exp)

def test_axisinfo(self):
MUD = MockUnitData([(None, None)])
MUD = MockUnitData({None: None})
axis = FakeAxis(MUD)
ax = self.cc.axisinfo(None, axis)
assert isinstance(ax.majloc, cat.StrCategoryLocator)
Expand All @@ -99,8 +81,8 @@ def test_default_units(self):

class TestStrCategoryLocator(object):
def test_StrCategoryLocator(self):
locs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
ticks = cat.StrCategoryLocator(locs)
locs = list(range(10))
ticks = cat.StrCategoryLocator({str(x): x for x in locs})
np.testing.assert_array_equal(ticks.tick_values(None, None), locs)


Expand Down Expand Up @@ -145,27 +127,18 @@ def data(self):
self.d = ['a', 'b', 'c', 'a']
self.dticks = [0, 1, 2]
self.dlabels = ['a', 'b', 'c']
unitmap = [('a', 0), ('b', 1), ('c', 2)]
unitmap = {'a': 0, 'b': 1, 'c': 2}
self.dunit_data = MockUnitData(unitmap)

@pytest.fixture
def missing_data(self):
self.dm = ['here', np.nan, 'here', 'there']
self.dmticks = [0, -1, 1]
self.dmlabels = ['here', 'nan', 'there']
unitmap = [('here', 0), ('nan', -1), ('there', 1)]
self.dmunit_data = MockUnitData(unitmap)

def axis_test(self, axis, ticks, labels, unit_data):
np.testing.assert_array_equal(axis.get_majorticklocs(), ticks)
assert lt(axis.get_majorticklabels()) == labels
np.testing.assert_array_equal(axis.unit_data.locs, unit_data.locs)
assert axis.unit_data.seq == unit_data.seq
assert axis.unit_data._mapping == unit_data._mapping

def test_plot_unicode(self):
words = ['Здравствуйте', 'привет']
locs = [0.0, 1.0]
unit_data = MockUnitData(zip(words, locs))
unit_data = MockUnitData(dict(zip(words, locs)))

fig, ax = plt.subplots()
ax.plot(words)
Expand All @@ -181,14 +154,6 @@ def test_plot_1d(self):

self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data)

@pytest.mark.usefixtures("missing_data")
def test_plot_1d_missing(self):
fig, ax = plt.subplots()
ax.plot(self.dm)
fig.canvas.draw()

self.axis_test(ax.yaxis, self.dmticks, self.dmlabels, self.dmunit_data)

@pytest.mark.usefixtures("data")
@pytest.mark.parametrize("bars", bytes_data, ids=bytes_ids)
def test_plot_bytes(self, bars):
Expand All @@ -208,28 +173,9 @@ def test_plot_numlike(self, bars):
ax.bar(bars, counts)
fig.canvas.draw()

unitmap = MockUnitData([('1', 0), ('11', 1), ('3', 2)])
unitmap = MockUnitData({'1': 0, '11': 1, '3': 2})
self.axis_test(ax.xaxis, [0, 1, 2], ['1', '11', '3'], unitmap)

@pytest.mark.usefixtures("data", "missing_data")
def test_plot_2d(self):
fig, ax = plt.subplots()
ax.plot(self.dm, self.d)
fig.canvas.draw()

self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data)
self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data)

@pytest.mark.usefixtures("data", "missing_data")
def test_scatter_2d(self):

fig, ax = plt.subplots()
ax.scatter(self.dm, self.d)
fig.canvas.draw()

self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data)
self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data)

def test_plot_update(self):
fig, ax = plt.subplots()

Expand All @@ -240,6 +186,6 @@ def test_plot_update(self):

labels = ['a', 'b', 'd', 'c']
ticks = [0, 1, 2, 3]
unit_data = MockUnitData(list(zip(labels, ticks)))
unit_data = MockUnitData(dict(zip(labels, ticks)))

self.axis_test(ax.yaxis, ticks, labels, unit_data)