|
1 |
| -# -*- coding: utf-8 OA-*-za |
| 1 | +# -*- coding: utf-8 -*- |
2 | 2 | """
|
3 | 3 | catch all for categorical functions
|
4 | 4 | """
|
5 | 5 | from __future__ import (absolute_import, division, print_function,
|
6 | 6 | unicode_literals)
|
| 7 | + |
| 8 | +from collections import Iterable, OrderedDict |
| 9 | +import itertools |
| 10 | + |
7 | 11 | import six
|
8 | 12 |
|
9 | 13 | import numpy as np
|
|
13 | 17 |
|
14 | 18 | # np 1.6/1.7 support
|
15 | 19 | from distutils.version import LooseVersion
|
16 |
| -import collections |
17 | 20 |
|
| 21 | +VALID_TYPES = tuple(set(six.string_types + |
| 22 | + (bytes, six.text_type, np.str_, np.bytes_))) |
18 | 23 |
|
19 |
| -if LooseVersion(np.__version__) >= LooseVersion('1.8.0'): |
20 |
| - def shim_array(data): |
21 |
| - return np.array(data, dtype=np.unicode) |
22 |
| -else: |
23 |
| - def shim_array(data): |
24 |
| - if (isinstance(data, six.string_types) or |
25 |
| - not isinstance(data, collections.Iterable)): |
26 |
| - data = [data] |
27 |
| - try: |
28 |
| - data = [str(d) for d in data] |
29 |
| - except UnicodeEncodeError: |
30 |
| - # this yields gibberish but unicode text doesn't |
31 |
| - # render under numpy1.6 anyway |
32 |
| - data = [d.encode('utf-8', 'ignore').decode('utf-8') |
33 |
| - for d in data] |
34 |
| - return np.array(data, dtype=np.unicode) |
| 24 | + |
| 25 | +def to_str(value): |
| 26 | + """Helper function to turn values to strings. |
| 27 | + """ |
| 28 | + # Note: This function is only used by StrCategoryFormatter |
| 29 | + if LooseVersion(np.__version__) < LooseVersion('1.7.0'): |
| 30 | + if (isinstance(value, (six.text_type, np.unicode))): |
| 31 | + value = value.encode('utf-8', 'ignore').decode('utf-8') |
| 32 | + if isinstance(value, (np.bytes_, six.binary_type)): |
| 33 | + value = value.decode(encoding='utf-8') |
| 34 | + elif not isinstance(value, (np.str_, six.string_types)): |
| 35 | + value = str(value) |
| 36 | + return value |
35 | 37 |
|
36 | 38 |
|
37 | 39 | class StrCategoryConverter(units.ConversionInterface):
|
38 | 40 | @staticmethod
|
39 | 41 | def convert(value, unit, axis):
|
40 |
| - """Uses axis.unit_data map to encode |
41 |
| - data as floats |
| 42 | + """Uses axis.units to encode string data as floats |
| 43 | +
|
| 44 | + Parameters |
| 45 | + ---------- |
| 46 | + value: string, iterable |
| 47 | + value or list of values to plot |
| 48 | + unit: |
| 49 | + axis: |
42 | 50 | """
|
43 |
| - value = np.atleast_1d(value) |
44 |
| - # try and update from here.... |
45 |
| - if hasattr(axis.unit_data, 'update'): |
46 |
| - for val in value: |
47 |
| - if isinstance(val, six.string_types): |
48 |
| - axis.unit_data.update(val) |
49 |
| - vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs)) |
| 51 | + # dtype = object preserves numerical pass throughs |
| 52 | + values = np.atleast_1d(np.array(value, dtype=object)) |
50 | 53 |
|
51 |
| - if isinstance(value, six.string_types): |
52 |
| - return vmap[value] |
| 54 | + # pass through sequences of numbers that are not str/bytes types |
| 55 | + if all((units.ConversionInterface.is_numlike(v) and |
| 56 | + not isinstance(v, VALID_TYPES)) for v in values): |
| 57 | + return np.asarray(values, dtype=float) |
53 | 58 |
|
54 |
| - vals = shim_array(value) |
| 59 | + # force an update so it also does type checking |
| 60 | + axis.units.update(values) |
55 | 61 |
|
56 |
| - for lab, loc in vmap.items(): |
57 |
| - vals[vals == lab] = loc |
| 62 | + str2idx = np.vectorize(axis.units._mapping.__getitem__, |
| 63 | + otypes=[float]) |
58 | 64 |
|
59 |
| - return vals.astype('float') |
| 65 | + mapped_value = str2idx(values) |
| 66 | + return mapped_value |
60 | 67 |
|
61 | 68 | @staticmethod
|
62 | 69 | def axisinfo(unit, axis):
|
63 |
| - majloc = StrCategoryLocator(axis.unit_data.locs) |
64 |
| - majfmt = StrCategoryFormatter(axis.unit_data.seq) |
| 70 | + """Sets the axis ticks and labels |
| 71 | + """ |
| 72 | + # locator and formatter take mapping dict because |
| 73 | + # args need to be passed by reference for updates |
| 74 | + majloc = StrCategoryLocator(axis.units) |
| 75 | + majfmt = StrCategoryFormatter(axis.units) |
65 | 76 | return units.AxisInfo(majloc=majloc, majfmt=majfmt)
|
66 | 77 |
|
67 | 78 | @staticmethod
|
68 |
| - def default_units(data, axis): |
69 |
| - # the conversion call stack is: |
| 79 | + def default_units(data=None, axis=None): |
| 80 | + # the conversion call stack is supposed to be |
70 | 81 | # default_units->axis_info->convert
|
71 |
| - if axis.unit_data is None: |
72 |
| - axis.unit_data = UnitData(data) |
| 82 | + if axis.units is None: |
| 83 | + axis.set_units(UnitData(data)) |
73 | 84 | else:
|
74 |
| - axis.unit_data.update(data) |
75 |
| - return None |
| 85 | + axis.units.update(data) |
| 86 | + return axis.units |
76 | 87 |
|
77 | 88 |
|
78 |
| -class StrCategoryLocator(ticker.FixedLocator): |
79 |
| - def __init__(self, locs): |
80 |
| - self.locs = locs |
81 |
| - self.nbins = None |
| 89 | +class StrCategoryLocator(ticker.Locator): |
| 90 | + """tick at every integer mapping of the string data""" |
| 91 | + def __init__(self, units): |
| 92 | + """ |
| 93 | + Parameters |
| 94 | + ----------- |
| 95 | + units: UnitData |
| 96 | + object whose ``_mapping`` holds the (string, integer) mapping |
| 97 | + """ |
| 98 | + self._units = units |
82 | 99 |
|
| 100 | + def __call__(self): |
| 101 | + return list(self._units._mapping.values()) |
83 | 102 |
|
84 |
| -class StrCategoryFormatter(ticker.FixedFormatter): |
85 |
| - def __init__(self, seq): |
86 |
| - self.seq = seq |
87 |
| - self.offset_string = '' |
| 103 | + def tick_values(self, vmin, vmax): |
| 104 | + return self() |
88 | 105 |
|
89 | 106 |
|
90 |
| -class UnitData(object): |
91 |
| - # debatable makes sense to special code missing values |
92 |
| - spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0} |
| 107 | +class StrCategoryFormatter(ticker.Formatter): |
| 108 | + """String representation of the data at every tick""" |
| 109 | + def __init__(self, units): |
| 110 | + """ |
| 111 | + Parameters |
| 112 | + ---------- |
| 113 | + units: UnitData |
| 114 | + object whose ``_mapping`` holds the (string, integer) mapping |
| 115 | + """ |
| 116 | + self._units = units |
93 | 117 |
|
94 |
| - def __init__(self, data): |
95 |
| - """Create mapping between unique categorical values |
96 |
| - and numerical identifier |
| 118 | + def __call__(self, x, pos=None): |
| 119 | + if pos is None: |
| 120 | + return "" |
| 121 | + r_mapping = {v: to_str(k) for k, v in self._units._mapping.items()} |
| 122 | + return r_mapping.get(int(np.round(x)), '') |
97 | 123 |
|
98 |
| - Parameters |
| 124 | + |
| 125 | +class UnitData(object): |
| 126 | + def __init__(self, data=None): |
| 127 | + """Create mapping between unique categorical values |
| 128 | + and integer identifiers |
99 | 129 | ----------
|
100 | 130 | data: iterable
|
101 |
| - sequence of values |
| 131 | + sequence of string values |
| 132 | + """ |
| 133 | + if data is None: |
| 134 | + data = () |
| 135 | + self._mapping = OrderedDict() |
| 136 | + self._counter = itertools.count(start=0) |
| 137 | + self.update(data) |
| 138 | + |
| 139 | + def update(self, data): |
| 140 | + """Maps new values to integer identifiers. |
| 141 | +
|
| 142 | + Parameters |
| 143 | + ---------- |
| 144 | + data: iterable |
| 145 | + sequence of string values |
| 146 | +
|
| 147 | + Raises |
| 148 | + ------ |
| 149 | + TypeError |
| 150 | + If any value in data is not a string, unicode, or bytes type |
102 | 151 | """
|
103 |
| - self.seq, self.locs = [], [] |
104 |
| - self._set_seq_locs(data, 0) |
105 |
| - |
106 |
| - def update(self, new_data): |
107 |
| - # so as not to conflict with spdict |
108 |
| - value = max(max(self.locs) + 1, 0) |
109 |
| - self._set_seq_locs(new_data, value) |
110 |
| - |
111 |
| - def _set_seq_locs(self, data, value): |
112 |
| - strdata = shim_array(data) |
113 |
| - new_s = [d for d in np.unique(strdata) if d not in self.seq] |
114 |
| - for ns in new_s: |
115 |
| - self.seq.append(ns) |
116 |
| - if ns in UnitData.spdict: |
117 |
| - self.locs.append(UnitData.spdict[ns]) |
118 |
| - else: |
119 |
| - self.locs.append(value) |
120 |
| - value += 1 |
| 152 | + |
| 153 | + if (isinstance(data, VALID_TYPES) or |
| 154 | + not isinstance(data, Iterable)): |
| 155 | + data = [data] |
| 156 | + |
| 157 | + unsorted_unique = OrderedDict.fromkeys(data) |
| 158 | + for val in unsorted_unique: |
| 159 | + if not isinstance(val, VALID_TYPES): |
| 160 | + raise TypeError("{val!r} is not a string".format(val=val)) |
| 161 | + if val not in self._mapping: |
| 162 | + self._mapping[val] = next(self._counter) |
121 | 163 |
|
122 | 164 |
|
123 | 165 | # Connects the converter to matplotlib
|
|
0 commit comments