|
1 |
| -# -*- coding: utf-8 OA-*-za |
| 1 | +# -*- coding: utf-8 -*- |
2 | 2 | """
|
3 |
| -catch all for categorical functions |
| 3 | +Module that allows plotting of string "category" data. i.e. |
| 4 | +``plot(['d', 'f', 'a'],[1, 2, 3])`` will plot three points with x-axis |
| 5 | +values of 'd', 'f', 'a'. |
| 6 | +
|
| 7 | +See :doc:`/gallery/lines_bars_and_markers/categorical_variables` for an |
| 8 | +example. |
| 9 | +
|
| 10 | +The module uses Matplotlib's `matplotlib.units` mechanism to convert from |
| 11 | +strings to integers, provides a tick locator and formatter, and the |
| 12 | +class:`.UnitData` that creates and stores the string-to-integer mapping. |
4 | 13 | """
|
5 | 14 | from __future__ import (absolute_import, division, print_function,
|
6 | 15 | unicode_literals)
|
| 16 | + |
| 17 | +from collections import OrderedDict |
| 18 | +import itertools |
| 19 | + |
7 | 20 | import six
|
8 | 21 |
|
| 22 | + |
9 | 23 | import numpy as np
|
10 | 24 |
|
11 | 25 | import matplotlib.units as units
|
12 | 26 | import matplotlib.ticker as ticker
|
13 | 27 |
|
14 | 28 | # np 1.6/1.7 support
|
15 | 29 | from distutils.version import LooseVersion
|
16 |
| -import collections |
17 |
| - |
18 |
| - |
19 |
| -if LooseVersion(np.__version__) >= LooseVersion('1.8.0'): |
20 |
| - def shim_array(data): |
21 |
| - return np.array(data, dtype=np.unicode) |
22 |
| -else: |
23 |
| - def shim_array(data): |
24 |
| - if (isinstance(data, six.string_types) or |
25 |
| - not isinstance(data, collections.Iterable)): |
26 |
| - data = [data] |
27 |
| - try: |
28 |
| - data = [str(d) for d in data] |
29 |
| - except UnicodeEncodeError: |
30 |
| - # this yields gibberish but unicode text doesn't |
31 |
| - # render under numpy1.6 anyway |
32 |
| - data = [d.encode('utf-8', 'ignore').decode('utf-8') |
33 |
| - for d in data] |
34 |
| - return np.array(data, dtype=np.unicode) |
| 30 | + |
| 31 | +VALID_TYPES = tuple(set(six.string_types + |
| 32 | + (bytes, six.text_type, np.str_, np.bytes_))) |
35 | 33 |
|
36 | 34 |
|
37 | 35 | class StrCategoryConverter(units.ConversionInterface):
|
38 | 36 | @staticmethod
|
39 | 37 | def convert(value, unit, axis):
|
40 |
| - """Uses axis.unit_data map to encode |
41 |
| - data as floats |
| 38 | + """Converts strings in value to floats using |
| 39 | + mapping information store in the unit object |
| 40 | +
|
| 41 | + Parameters |
| 42 | + ---------- |
| 43 | + value : string or iterable |
| 44 | + value or list of values to be converted |
| 45 | + unit : :class:`.UnitData` |
| 46 | + object string unit information for value |
| 47 | + axis : :class:`~matplotlib.Axis.axis` |
| 48 | + axis on which the converted value is plotted |
| 49 | +
|
| 50 | + Returns |
| 51 | + ------- |
| 52 | + mapped_ value : float or ndarray[float] |
| 53 | +
|
| 54 | + .. note:: axis is not used in this function |
42 | 55 | """
|
43 |
| - value = np.atleast_1d(value) |
44 |
| - # try and update from here.... |
45 |
| - if hasattr(axis.unit_data, 'update'): |
46 |
| - for val in value: |
47 |
| - if isinstance(val, six.string_types): |
48 |
| - axis.unit_data.update(val) |
49 |
| - vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs)) |
| 56 | + # dtype = object preserves numerical pass throughs |
| 57 | + values = np.atleast_1d(np.array(value, dtype=object)) |
50 | 58 |
|
51 |
| - if isinstance(value, six.string_types): |
52 |
| - return vmap[value] |
| 59 | + # pass through sequence of non binary numbers |
| 60 | + if all((units.ConversionInterface.is_numlike(v) and |
| 61 | + not isinstance(v, VALID_TYPES)) for v in values): |
| 62 | + return np.asarray(values, dtype=float) |
53 | 63 |
|
54 |
| - vals = shim_array(value) |
| 64 | + # force an update so it also does type checking |
| 65 | + unit.update(values) |
55 | 66 |
|
56 |
| - for lab, loc in vmap.items(): |
57 |
| - vals[vals == lab] = loc |
| 67 | + str2idx = np.vectorize(unit._mapping.__getitem__, |
| 68 | + otypes=[float]) |
58 | 69 |
|
59 |
| - return vals.astype('float') |
| 70 | + mapped_value = str2idx(values) |
| 71 | + return mapped_value |
60 | 72 |
|
61 | 73 | @staticmethod
|
62 | 74 | def axisinfo(unit, axis):
|
63 |
| - majloc = StrCategoryLocator(axis.unit_data.locs) |
64 |
| - majfmt = StrCategoryFormatter(axis.unit_data.seq) |
| 75 | + """Sets the default axis ticks and labels |
| 76 | +
|
| 77 | + Parameters |
| 78 | + --------- |
| 79 | + unit : :class:`.UnitData` |
| 80 | + object string unit information for value |
| 81 | + axis : :class:`~matplotlib.Axis.axis` |
| 82 | + axis for which information is being set |
| 83 | +
|
| 84 | + Returns |
| 85 | + ------- |
| 86 | + :class:~matplotlib.units.AxisInfo~ |
| 87 | + Information to support default tick labeling |
| 88 | +
|
| 89 | + .. note: axis is not used |
| 90 | + """ |
| 91 | + # locator and formatter take mapping dict because |
| 92 | + # args need to be pass by reference for updates |
| 93 | + majloc = StrCategoryLocator(unit._mapping) |
| 94 | + majfmt = StrCategoryFormatter(unit._mapping) |
65 | 95 | return units.AxisInfo(majloc=majloc, majfmt=majfmt)
|
66 | 96 |
|
67 | 97 | @staticmethod
|
68 | 98 | def default_units(data, axis):
|
69 |
| - # the conversion call stack is: |
| 99 | + """ Sets and updates the :class:`~matplotlib.Axis.axis~ units |
| 100 | +
|
| 101 | + Parameters |
| 102 | + ---------- |
| 103 | + data : string or iterable of strings |
| 104 | + axis : :class:`~matplotlib.Axis.axis` |
| 105 | + axis on which the data is plotted |
| 106 | +
|
| 107 | + Returns |
| 108 | + ------- |
| 109 | + class:~.UnitData~ |
| 110 | + object storing string to integer mapping |
| 111 | + """ |
| 112 | + # the conversion call stack is supposed to be |
70 | 113 | # default_units->axis_info->convert
|
71 |
| - if axis.unit_data is None: |
72 |
| - axis.unit_data = UnitData(data) |
| 114 | + if axis.units is None: |
| 115 | + axis.set_units(UnitData(data)) |
73 | 116 | else:
|
74 |
| - axis.unit_data.update(data) |
75 |
| - return None |
| 117 | + axis.units.update(data) |
| 118 | + return axis.units |
| 119 | + |
| 120 | + |
| 121 | +class StrCategoryLocator(ticker.Locator): |
| 122 | + """tick at every integer mapping of the string data""" |
| 123 | + def __init__(self, units_mapping): |
| 124 | + """ |
| 125 | + Parameters |
| 126 | + ----------- |
| 127 | + units: dict |
| 128 | + string:integer mapping |
| 129 | + """ |
| 130 | + self._units = units_mapping |
76 | 131 |
|
| 132 | + def __call__(self): |
| 133 | + return list(self._units.values()) |
77 | 134 |
|
78 |
| -class StrCategoryLocator(ticker.FixedLocator): |
79 |
| - def __init__(self, locs): |
80 |
| - self.locs = locs |
81 |
| - self.nbins = None |
| 135 | + def tick_values(self, vmin, vmax): |
| 136 | + return self() |
82 | 137 |
|
83 | 138 |
|
84 |
| -class StrCategoryFormatter(ticker.FixedFormatter): |
85 |
| - def __init__(self, seq): |
86 |
| - self.seq = seq |
87 |
| - self.offset_string = '' |
| 139 | +class StrCategoryFormatter(ticker.Formatter): |
| 140 | + """String representation of the data at every tick""" |
| 141 | + def __init__(self, units_mapping): |
| 142 | + """ |
| 143 | + Parameters |
| 144 | + ---------- |
| 145 | + units: dict |
| 146 | + string:integer mapping |
| 147 | + """ |
| 148 | + self._units = units_mapping |
88 | 149 |
|
| 150 | + def __call__(self, x, pos=None): |
| 151 | + if pos is None: |
| 152 | + return "" |
| 153 | + r_mapping = {v: StrCategoryFormatter._text(k) |
| 154 | + for k, v in self._units.items()} |
| 155 | + return r_mapping.get(int(np.round(x)), '') |
89 | 156 |
|
90 |
| -class UnitData(object): |
91 |
| - # debatable makes sense to special code missing values |
92 |
| - spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0} |
| 157 | + @staticmethod |
| 158 | + def _text(value): |
| 159 | + """Converts text values into `utf-8` or `ascii` strings |
| 160 | + """ |
| 161 | + if LooseVersion(np.__version__) < LooseVersion('1.7.0'): |
| 162 | + if (isinstance(value, (six.text_type, np.unicode))): |
| 163 | + value = value.encode('utf-8', 'ignore').decode('utf-8') |
| 164 | + if isinstance(value, (np.bytes_, six.binary_type)): |
| 165 | + value = value.decode(encoding='utf-8') |
| 166 | + elif not isinstance(value, (np.str_, six.string_types)): |
| 167 | + value = str(value) |
| 168 | + return value |
93 | 169 |
|
94 |
| - def __init__(self, data): |
95 |
| - """Create mapping between unique categorical values |
96 |
| - and numerical identifier |
97 | 170 |
|
98 |
| - Parameters |
| 171 | +class UnitData(object): |
| 172 | + def __init__(self, data=None): |
| 173 | + """Create mapping between unique categorical values |
| 174 | + and integer identifiers |
99 | 175 | ----------
|
100 | 176 | data: iterable
|
101 |
| - sequence of values |
| 177 | + sequence of string values |
102 | 178 | """
|
103 |
| - self.seq, self.locs = [], [] |
104 |
| - self._set_seq_locs(data, 0) |
105 |
| - |
106 |
| - def update(self, new_data): |
107 |
| - # so as not to conflict with spdict |
108 |
| - value = max(max(self.locs) + 1, 0) |
109 |
| - self._set_seq_locs(new_data, value) |
110 |
| - |
111 |
| - def _set_seq_locs(self, data, value): |
112 |
| - strdata = shim_array(data) |
113 |
| - new_s = [d for d in np.unique(strdata) if d not in self.seq] |
114 |
| - for ns in new_s: |
115 |
| - self.seq.append(ns) |
116 |
| - if ns in UnitData.spdict: |
117 |
| - self.locs.append(UnitData.spdict[ns]) |
118 |
| - else: |
119 |
| - self.locs.append(value) |
120 |
| - value += 1 |
| 179 | + self._mapping = OrderedDict() |
| 180 | + self._counter = itertools.count(start=0) |
| 181 | + if data is not None: |
| 182 | + self.update(data) |
| 183 | + |
| 184 | + def update(self, data): |
| 185 | + """Maps new values to integer identifiers. |
| 186 | +
|
| 187 | + Paramters |
| 188 | + --------- |
| 189 | + data: iterable |
| 190 | + sequence of string values |
| 191 | +
|
| 192 | + Raises |
| 193 | + ------ |
| 194 | + TypeError |
| 195 | + If the value in data is not a string, unicode, bytes type |
| 196 | + """ |
| 197 | + data = np.atleast_1d(np.array(data, dtype=object)) |
| 198 | + |
| 199 | + for val in OrderedDict.fromkeys(data): |
| 200 | + if not isinstance(val, VALID_TYPES): |
| 201 | + raise TypeError("{val!r} is not a string".format(val=val)) |
| 202 | + if val not in self._mapping: |
| 203 | + self._mapping[val] = next(self._counter) |
121 | 204 |
|
122 | 205 |
|
123 | 206 | # Connects the convertor to matplotlib
|
|
0 commit comments