diff --git a/doc/api/api_changes.rst b/doc/api/api_changes.rst index 46a65f203e65..bdf3e1d6e76e 100644 --- a/doc/api/api_changes.rst +++ b/doc/api/api_changes.rst @@ -47,6 +47,7 @@ to ``'mask'``. As a side effect of this change, error bars which go negative now work as expected on log scales. + API Changes in 2.1.0 ==================== diff --git a/doc/api/api_changes/2018-01_TAC.rst b/doc/api/api_changes/2018-01_TAC.rst new file mode 100644 index 000000000000..ea1b2d948ae8 --- /dev/null +++ b/doc/api/api_changes/2018-01_TAC.rst @@ -0,0 +1,5 @@ + +Simplify String Categorical handling +------------------------------------ + + - Do not allow missing data. diff --git a/lib/matplotlib/axis.py b/lib/matplotlib/axis.py index 59405a10d93c..733d610f2791 100644 --- a/lib/matplotlib/axis.py +++ b/lib/matplotlib/axis.py @@ -719,7 +719,9 @@ def __init__(self, axes, pickradius=15): self.label = self._get_label() self.labelpad = rcParams['axes.labelpad'] self.offsetText = self._get_offset_text() - self.unit_data = None + + self.majorTicks = [] + self.minorTicks = [] self.pickradius = pickradius # Initialize here for testing; later add API @@ -777,15 +779,14 @@ def limit_range_for_scale(self, vmin, vmax): return self._scale.limit_range_for_scale(vmin, vmax, self.get_minpos()) @property + @cbook.deprecated("2.2") def unit_data(self): - """Holds data that a ConversionInterface subclass uses - to convert between labels and indexes - """ - return self._unit_data + return self.units @unit_data.setter + @cbook.deprecated("2.2") def unit_data(self, unit_data): - self._unit_data = unit_data + self.set_units(unit_data) def get_children(self): children = [self.label, self.offsetText] diff --git a/lib/matplotlib/category.py b/lib/matplotlib/category.py index d2754d32fd3a..783306234071 100644 --- a/lib/matplotlib/category.py +++ b/lib/matplotlib/category.py @@ -1,128 +1,115 @@ -# -*- coding: utf-8 OA-*-za """ catch all for categorical functions """ + from __future__ import (absolute_import, division, print_function, unicode_literals) + +from collections import OrderedDict +import itertools +from matplotlib import ticker, units import six import numpy as np -import matplotlib.units as units -import matplotlib.ticker as ticker - -# np 1.6/1.7 support -from distutils.version import LooseVersion -import collections - - -if LooseVersion(np.__version__) >= LooseVersion('1.8.0'): - def shim_array(data): - return np.array(data, dtype=np.unicode) -else: - def shim_array(data): - if (isinstance(data, six.string_types) or - not isinstance(data, collections.Iterable)): - data = [data] - try: - data = [str(d) for d in data] - except UnicodeEncodeError: - # this yields gibberish but unicode text doesn't - # render under numpy1.6 anyway - data = [d.encode('utf-8', 'ignore').decode('utf-8') - for d in data] - return np.array(data, dtype=np.unicode) - class StrCategoryConverter(units.ConversionInterface): @staticmethod def convert(value, unit, axis): - """Uses axis.unit_data map to encode - data as floats - """ - value = np.atleast_1d(value) - # try and update from here.... - if hasattr(axis.unit_data, 'update'): - for val in value: - if isinstance(val, six.string_types): - axis.unit_data.update(val) - vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs)) - - if isinstance(value, six.string_types): - return vmap[value] - - vals = shim_array(value) - - for lab, loc in vmap.items(): - vals[vals == lab] = loc - - return vals.astype('float') + """Use axis.units mapping to map categorical data to floats.""" + def getter(k): + if not isinstance(k, six.text_type): + k = k.decode('utf-8') + return axis.units._mapping[k] + # We also need to pass numbers through. + if np.issubdtype(np.asarray(value).dtype.type, np.number): + return value + else: + axis.units.update(value) + str2idx = np.vectorize(getter, otypes=[float]) + return str2idx(value) @staticmethod def axisinfo(unit, axis): - majloc = StrCategoryLocator(axis.unit_data.locs) - majfmt = StrCategoryFormatter(axis.unit_data.seq) + majloc = StrCategoryLocator(axis.units) + majfmt = StrCategoryFormatter(axis.units) return units.AxisInfo(majloc=majloc, majfmt=majfmt) @staticmethod def default_units(data, axis): - # the conversion call stack is: - # default_units->axis_info->convert - if axis.unit_data is None: - axis.unit_data = UnitData(data) - else: - axis.unit_data.update(data) - return None + return UnitData() + + +class StrCategoryLocator(ticker.Locator): + def __init__(self, unit_data): + self._unit_data = unit_data + + def __call__(self): + return list(self._unit_data._mapping.values()) + def tick_values(self, vmin, vmax): + return self() -class StrCategoryLocator(ticker.FixedLocator): - def __init__(self, locs): - self.locs = locs - self.nbins = None +class StrCategoryFormatter(ticker.Formatter): + def __init__(self, unit_data): + self._unit_data = unit_data -class StrCategoryFormatter(ticker.FixedFormatter): - def __init__(self, seq): - self.seq = seq - self.offset_string = '' + def __call__(self, x, pos=None): + if pos is None: + return "" + r_mapping = {v: k for k, v in self._unit_data._mapping.items()} + return r_mapping.get(int(x), '') class UnitData(object): - # debatable makes sense to special code missing values - spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0} + valid_types = tuple(set(six.string_types + + (bytes, six.text_type, np.str_, np.bytes_))) - def __init__(self, data): - """Create mapping between unique categorical values - and numerical identifier + def __init__(self, data=None): + """Create mapping between unique categorical values and numerical id. Parameters ---------- - data: iterable - sequence of values + data : Mapping[str, int] + The initial categories. May be `None`. + """ - self.seq, self.locs = [], [] - self._set_seq_locs(data, 0) - - def update(self, new_data): - # so as not to conflict with spdict - value = max(max(self.locs) + 1, 0) - self._set_seq_locs(new_data, value) - - def _set_seq_locs(self, data, value): - strdata = shim_array(data) - new_s = [d for d in np.unique(strdata) if d not in self.seq] - for ns in new_s: - self.seq.append(ns) - if ns in UnitData.spdict: - self.locs.append(UnitData.spdict[ns]) - else: - self.locs.append(value) - value += 1 + self._vals = [] + if data is None: + data = () + self._mapping = OrderedDict() + for k, v in OrderedDict(data).items(): + if not isinstance(k, self.valid_types): + raise TypeError("{val!r} is not a string".format(val=k)) + if not isinstance(k, six.text_type): + k = k.decode('utf-8') + self._mapping[k] = int(v) + if self._mapping: + start = max(self._mapping.values()) + 1 + else: + start = 0 + self._counter = itertools.count(start=start) + + def update(self, data): + if isinstance(data, self.valid_types): + data = [data] + sorted_unique = OrderedDict.fromkeys(data) + for val in sorted_unique: + if not isinstance(val, self.valid_types): + raise TypeError("{val!r} is not a string".format(val=val)) + if not isinstance(val, six.text_type): + val = val.decode('utf-8') + if val in self._mapping: + continue + self._vals.append(val) + self._mapping[val] = next(self._counter) # Connects the convertor to matplotlib + units.registry[str] = StrCategoryConverter() -units.registry[np.str_] = StrCategoryConverter() -units.registry[six.text_type] = StrCategoryConverter() units.registry[bytes] = StrCategoryConverter() +units.registry[np.str_] = StrCategoryConverter() units.registry[np.bytes_] = StrCategoryConverter() +units.registry[six.text_type] = StrCategoryConverter() diff --git a/lib/matplotlib/lines.py b/lib/matplotlib/lines.py index a9999b419f15..9fd097f648ee 100644 --- a/lib/matplotlib/lines.py +++ b/lib/matplotlib/lines.py @@ -676,8 +676,8 @@ def recache(self, always=False): if nanmask.any(): self._x_filled = self._x.copy() indices = np.arange(len(x)) - self._x_filled[nanmask] = np.interp(indices[nanmask], - indices[~nanmask], self._x[~nanmask]) + self._x_filled[nanmask] = np.interp( + indices[nanmask], indices[~nanmask], self._x[~nanmask]) else: self._x_filled = self._x diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 3e86a6296fa5..8bf81f042361 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -5158,6 +5158,7 @@ def test_pandas_pcolormesh(pd): fig, ax = plt.subplots() ax.pcolormesh(time, depth, data) + fig.canvas.draw() def test_pandas_indexing_dates(pd): @@ -5169,6 +5170,7 @@ def test_pandas_indexing_dates(pd): without_zero_index = df[np.array(df.index) % 2 == 1].copy() ax.plot('dates', 'values', data=without_zero_index) + ax.figure.canvas.draw() def test_pandas_errorbar_indexing(pd): @@ -5177,6 +5179,7 @@ def test_pandas_errorbar_indexing(pd): index=[1, 2, 3, 4, 5]) fig, ax = plt.subplots() ax.errorbar('x', 'y', xerr='xe', yerr='ye', data=df) + fig.canvas.draw() def test_pandas_indexing_hist(pd): @@ -5184,6 +5187,7 @@ def test_pandas_indexing_hist(pd): ser_2 = ser_1.iloc[1:] fig, axes = plt.subplots() axes.hist(ser_2) + fig.canvas.draw() def test_pandas_bar_align_center(pd): diff --git a/lib/matplotlib/tests/test_category.py b/lib/matplotlib/tests/test_category.py index 17a90d47f379..c744bf57bf69 100644 --- a/lib/matplotlib/tests/test_category.py +++ b/lib/matplotlib/tests/test_category.py @@ -2,58 +2,53 @@ """Catch all for categorical functions""" from __future__ import absolute_import, division, print_function +import six import pytest import numpy as np +from numpy.testing import assert_array_equal import matplotlib.pyplot as plt import matplotlib.category as cat - -import unittest +from matplotlib.axes import Axes class TestUnitData(object): - testdata = [("hello world", ["hello world"], [0]), - (u"Здравствуйте мир", [u"Здравствуйте мир"], [0]), - (['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf], - ['-inf', '3.14', 'A', 'B', 'inf', 'nan'], - [-3.0, 0, 1, 2, -2.0, -1.0])] + test_cases = [('single', ("hello world", ["hello world"], [0])), + ('unicode', (u"Здравствуйте мир", [u"Здравствуйте мир"], [0])), + ('mixed', (['A', 'A', 'B'], + ['A', 'B', ], + [0, 1]))] - ids = ["single", "unicode", "mixed"] + ids, data = zip(*test_cases) - @pytest.mark.parametrize("data, seq, locs", testdata, ids=ids) + @pytest.mark.parametrize("data, seq, locs", data, ids=ids) def test_unit(self, data, seq, locs): - act = cat.UnitData(data) - assert act.seq == seq - assert act.locs == locs + act = cat.UnitData() + for v in data: + act.update(data) + assert list(act._mapping.keys()) == seq + assert list(act._mapping.values()) == locs def test_update_map(self): - data = ['a', 'd'] oseq = ['a', 'd'] olocs = [0, 1] - data_update = ['b', 'd', 'e', np.inf] - useq = ['a', 'd', 'b', 'e', 'inf'] - ulocs = [0, 1, 2, 3, -2] + data_update = ['b', 'd', 'e'] + useq = ['a', 'd', 'b', 'e'] + ulocs = [0, 1, 2, 3] - unitdata = cat.UnitData(data) - assert unitdata.seq == oseq - assert unitdata.locs == olocs + unitdata = cat.UnitData(zip(oseq, olocs)) + assert list(unitdata._mapping.keys()) == oseq + assert list(unitdata._mapping.values()) == olocs unitdata.update(data_update) - assert unitdata.seq == useq - assert unitdata.locs == ulocs + assert list(unitdata._mapping.keys()) == useq + assert list(unitdata._mapping.values()) == ulocs class FakeAxis(object): def __init__(self, unit_data): - self.unit_data = unit_data - - -class MockUnitData(object): - def __init__(self, data): - seq, locs = zip(*data) - self.seq = list(seq) - self.locs = list(locs) + self.units = unit_data class TestStrCategoryConverter(object): @@ -62,30 +57,33 @@ class TestStrCategoryConverter(object): ref: /pandas/tseries/tests/test_converter.py /pandas/tests/test_algos.py:TestFactorize """ - testdata = [(u"Здравствуйте мир", [(u"Здравствуйте мир", 42)], 42), - ("hello world", [("hello world", 42)], 42), - (['a', 'b', 'b', 'a', 'a', 'c', 'c', 'c'], - [('a', 0), ('b', 1), ('c', 2)], - [0, 1, 1, 0, 0, 2, 2, 2]), - (['A', 'A', np.nan, 'B', -np.inf, 3.14, np.inf], - [('nan', -1), ('3.14', 0), ('A', 1), ('B', 2), - ('-inf', 100), ('inf', 200)], - [1, 1, -1, 2, 100, 0, 200])] - ids = ["unicode", "single", "basic", "mixed"] + + test_cases = [("unicode", {u"Здравствуйте мир": 42}), + ("ascii", {"hello world": 42}), + ("single", {'a': 0, 'b': 1, 'c': 2}), + ("single bytes", {b'a': 0, b'b': 1, b'c': 2}), + ("mixed bytes", {b'a': 0, 'b': 1, b'c': 2}), + ("single + values>10", {'A': 0, 'B': 1, 'C': 2, + 'D': 3, 'E': 4, 'F': 5, + 'G': 6, 'H': 7, 'I': 8, + 'J': 9, 'K': 10})] + + ids, unitmaps = zip(*test_cases) @pytest.fixture(autouse=True) def mock_axis(self, request): self.cc = cat.StrCategoryConverter() - @pytest.mark.parametrize("data, unitmap, exp", testdata, ids=ids) - def test_convert(self, data, unitmap, exp): - MUD = MockUnitData(unitmap) + @pytest.mark.parametrize("unitmap", unitmaps, ids=ids) + def test_convert(self, unitmap): + data, exp = zip(*six.iteritems(unitmap)) + MUD = cat.UnitData(unitmap) axis = FakeAxis(MUD) act = self.cc.convert(data, None, axis) - np.testing.assert_array_equal(act, exp) + np.testing.assert_allclose(act, exp) def test_axisinfo(self): - MUD = MockUnitData([(None, None)]) + MUD = cat.UnitData() axis = FakeAxis(MUD) ax = self.cc.axisinfo(None, axis) assert isinstance(ax.majloc, cat.StrCategoryLocator) @@ -93,145 +91,116 @@ def test_axisinfo(self): def test_default_units(self): axis = FakeAxis(None) - assert self.cc.default_units(["a"], axis) is None + assert isinstance(self.cc.default_units(["a"], axis), cat.UnitData) class TestStrCategoryLocator(object): def test_StrCategoryLocator(self): locs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - ticks = cat.StrCategoryLocator(locs) + u = cat.UnitData() + for j in range(11): + u.update(str(j)) + ticks = cat.StrCategoryLocator(u) np.testing.assert_array_equal(ticks.tick_values(None, None), locs) -class TestStrCategoryFormatter(unittest.TestCase): +class TestStrCategoryFormatter(object): def test_StrCategoryFormatter(self): seq = ["hello", "world", "hi"] - labels = cat.StrCategoryFormatter(seq) - assert labels('a', 1) == "world" + u = cat.UnitData() + u.update(seq) + labels = cat.StrCategoryFormatter(u) + assert labels(1, 1) == "world" def test_StrCategoryFormatterUnicode(self): - seq = ["Здравствуйте", "привет"] - labels = cat.StrCategoryFormatter(seq) - assert labels('a', 1) == "привет" + seq = [u"Здравствуйте", u"привет"] + u = cat.UnitData() + u.update(seq) + labels = cat.StrCategoryFormatter(u) + assert labels(1, 1) == u"привет" def lt(tl): return [l.get_text() for l in tl] -class TestPlot(object): - bytes_data = [ - ['a', 'b', 'c'], - [b'a', b'b', b'c'], - np.array([b'a', b'b', b'c']) - ] - - bytes_ids = ['string list', 'bytes list', 'bytes ndarray'] +def axis_test(axis, ticks, labels, unit_data): + assert axis.get_majorticklocs() == ticks + assert lt(axis.get_majorticklabels()) == labels + assert axis.units._mapping == unit_data._mapping - numlike_data = [ - ['1', '11', '3'], - np.array(['1', '11', '3']), - [b'1', b'11', b'3'], - np.array([b'1', b'11', b'3']), - ] - numlike_ids = [ - 'string list', 'string ndarray', 'bytes list', 'bytes ndarray' - ] +class TestBarsBytes(object): + bytes_cases = [('string list', ['a', 'b', 'c']), + ('bytes list', [b'a', b'b', b'c']), + ('mixed list', [b'a', 'b', b'c']), + ] - @pytest.fixture - def data(self): - self.d = ['a', 'b', 'c', 'a'] - self.dticks = [0, 1, 2] - self.dlabels = ['a', 'b', 'c'] - unitmap = [('a', 0), ('b', 1), ('c', 2)] - self.dunit_data = MockUnitData(unitmap) + bytes_ids, bytes_data = zip(*bytes_cases) - @pytest.fixture - def missing_data(self): - self.dm = ['here', np.nan, 'here', 'there'] - self.dmticks = [0, -1, 1] - self.dmlabels = ['here', 'nan', 'there'] - unitmap = [('here', 0), ('nan', -1), ('there', 1)] - self.dmunit_data = MockUnitData(unitmap) - - def axis_test(self, axis, ticks, labels, unit_data): - np.testing.assert_array_equal(axis.get_majorticklocs(), ticks) - assert lt(axis.get_majorticklabels()) == labels - np.testing.assert_array_equal(axis.unit_data.locs, unit_data.locs) - assert axis.unit_data.seq == unit_data.seq - - def test_plot_unicode(self): - words = [u'Здравствуйте', u'привет'] - locs = [0.0, 1.0] - unit_data = MockUnitData(zip(words, locs)) - - fig, ax = plt.subplots() - ax.plot(words) - fig.canvas.draw() + @pytest.mark.parametrize("bars", bytes_data, ids=bytes_ids) + def test_plot_bytes(self, bars): - self.axis_test(ax.yaxis, locs, words, unit_data) + unitmap = cat.UnitData([('a', 0), ('b', 1), ('c', 2)]) - @pytest.mark.usefixtures("data") - def test_plot_1d(self): + counts = np.array([4, 6, 5]) fig, ax = plt.subplots() - ax.plot(self.d) + ax.bar(bars, counts) fig.canvas.draw() + axis_test(ax.xaxis, [0, 1, 2], ['a', 'b', 'c'], unitmap) - self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data) - @pytest.mark.usefixtures("missing_data") - def test_plot_1d_missing(self): - fig, ax = plt.subplots() - ax.plot(self.dm) - fig.canvas.draw() +class TestBarsNumlike(object): + numlike_cases = [('string list', ['1', '11', '3']), + ('string ndarray', np.array(['1', '11', '3']))] - self.axis_test(ax.yaxis, self.dmticks, self.dmlabels, self.dmunit_data) + numlike_ids, numlike_data = zip(*numlike_cases) - @pytest.mark.usefixtures("data") - @pytest.mark.parametrize("bars", bytes_data, ids=bytes_ids) - def test_plot_bytes(self, bars): + @pytest.mark.parametrize("bars", numlike_data, ids=numlike_ids) + def test_plot_numlike(self, bars): counts = np.array([4, 6, 5]) fig, ax = plt.subplots() ax.bar(bars, counts) fig.canvas.draw() - self.axis_test(ax.xaxis, self.dticks, self.dlabels, self.dunit_data) + unitmap = cat.UnitData([('1', 0), ('11', 1), ('3', 2)]) + axis_test(ax.xaxis, [0, 1, 2], ['1', '11', '3'], unitmap) - @pytest.mark.parametrize("bars", numlike_data, ids=numlike_ids) - def test_plot_numlike(self, bars): - counts = np.array([4, 6, 5]) - fig, ax = plt.subplots() - ax.bar(bars, counts) - fig.canvas.draw() +class TestPlotTypes(object): + @pytest.fixture + def complete_data(self): + self.complete = ['a', 'b', 'c', 'a'] + self.complete_ticks = [0, 1, 2] + self.complete_labels = ['a', 'b', 'c'] + unitmap = [('a', 0), ('b', 1), ('c', 2)] + self.complete_unit_data = cat.UnitData(unitmap) - unitmap = MockUnitData([('1', 0), ('11', 1), ('3', 2)]) - self.axis_test(ax.xaxis, [0, 1, 2], ['1', '11', '3'], unitmap) + def test_plot_unicode(self): + words = [u'Здравствуйте', u'привет'] + locs = [0.0, 1.0] + unit_data = cat.UnitData(zip(words, locs)) - @pytest.mark.usefixtures("data", "missing_data") - def test_plot_2d(self): fig, ax = plt.subplots() - ax.plot(self.dm, self.d) + ax.plot(words) fig.canvas.draw() - self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data) - self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data) - - @pytest.mark.usefixtures("data", "missing_data") - def test_scatter_2d(self): + axis_test(ax.yaxis, locs, words, unit_data) + @pytest.mark.usefixtures("complete_data") + def test_plot_yaxis(self): fig, ax = plt.subplots() - ax.scatter(self.dm, self.d) + ax.plot(self.complete) fig.canvas.draw() + axis_test(ax.yaxis, self.complete_ticks, self.complete_labels, + self.complete_unit_data) - self.axis_test(ax.xaxis, self.dmticks, self.dmlabels, self.dmunit_data) - self.axis_test(ax.yaxis, self.dticks, self.dlabels, self.dunit_data) - def test_plot_update(self): - fig, ax = plt.subplots() +class TestUpdatePlot(object): + def test_update_plot(self): + fig, ax = plt.subplots() ax.plot(['a', 'b']) ax.plot(['a', 'b', 'd']) ax.plot(['b', 'c', 'd']) @@ -239,13 +208,12 @@ def test_plot_update(self): labels = ['a', 'b', 'd', 'c'] ticks = [0, 1, 2, 3] - unit_data = MockUnitData(zip(labels, ticks)) + unitmap = cat.UnitData(list(zip(labels, ticks))) - self.axis_test(ax.yaxis, ticks, labels, unit_data) + axis_test(ax.yaxis, ticks, labels, unitmap) - def test_scatter_update(self): + def test_update_scatter(self): fig, ax = plt.subplots() - ax.scatter(['a', 'b'], [0., 3.]) ax.scatter(['a', 'b', 'd'], [1., 2., 3.]) ax.scatter(['b', 'c', 'd'], [4., 1., 2.]) @@ -253,5 +221,65 @@ def test_scatter_update(self): labels = ['a', 'b', 'd', 'c'] ticks = [0, 1, 2, 3] - unit_data = MockUnitData(list(zip(labels, ticks))) - self.axis_test(ax.xaxis, ticks, labels, unit_data) + unitmap = cat.UnitData(list(zip(labels, ticks))) + + axis_test(ax.xaxis, ticks, labels, unitmap) + + +@pytest.fixture +def ax(): + return plt.figure().subplots() + + +@pytest.mark.parametrize( + "data, expected_indices, expected_labels", + [([u"Здравствуйте мир"], [0], [u"Здравствуйте мир"]), + (["a", "b", "b", "a", "c", "c"], [0, 1, 1, 0, 2, 2], ["a", "b", "c"]), + (["foo", "bar"], range(2), ["foo", "bar"]), + ([b"foo", "bar"], range(2), ["foo", "bar"]), + (np.array(["1", "11", "3"]), range(3), ["1", "11", "3"])]) +def test_simple(ax, data, expected_indices, expected_labels): + l, = ax.plot(data) + assert_array_equal(l.get_ydata(orig=False), expected_indices) + assert isinstance(ax.yaxis.major.locator, cat.StrCategoryLocator) + assert isinstance(ax.yaxis.major.formatter, cat.StrCategoryFormatter) + ax.figure.canvas.draw() + labels = [label.get_text() for label in ax.yaxis.get_majorticklabels()] + assert labels == expected_labels + + +def test_default_units(ax): + ax.plot(["a"]) + du = ax.yaxis.converter.default_units(["a"], ax.yaxis) + assert isinstance(du, cat.UnitData) + + +def test_update(ax): + l1, = ax.plot(["a", "d"]) + l2, = ax.plot(["b", "d", "e"]) + assert_array_equal(l1.get_ydata(orig=False), [0, 1]) + assert_array_equal(l2.get_ydata(orig=False), [2, 1, 3]) + assert ax.yaxis.units._vals == ["a", "d", "b", "e"] + assert ax.yaxis.units._mapping == {"a": 0, "d": 1, "b": 2, "e": 3} + + +@pytest.mark.parametrize("plotter", [Axes.plot, Axes.scatter, Axes.bar]) +def test_StrCategoryLocator(ax, plotter): + ax.plot(["a", "b", "c"]) + assert_array_equal(ax.yaxis.major.locator(), range(3)) + + +@pytest.mark.parametrize("plotter", [Axes.plot, Axes.scatter, Axes.bar]) +def test_StrCategoryFormatter(ax, plotter): + plotter(ax, range(2), ["hello", u"мир"]) + assert ax.yaxis.major.formatter(0, 0) == "hello" + assert ax.yaxis.major.formatter(1, 1) == u"мир" + assert ax.yaxis.major.formatter(2, 2) == "" + assert ax.yaxis.major.formatter(0, None) == "" + + +@pytest.mark.parametrize("plotter", [Axes.plot, Axes.scatter]) +@pytest.mark.parametrize("xdata", [[1, 'a'], ['a', 1]]) +def test_mixed_failures(ax, plotter, xdata): + with pytest.raises((ValueError, TypeError)): + plotter(ax, xdata, [1, 2])