-
-
Notifications
You must be signed in to change notification settings - Fork 7.9k
Categorical refactor #10212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Categorical refactor #10212
Changes from all commits
4442304
98096e2
2bd832d
4195dbf
04edaa3
283ca08
6f73c24
ac355c8
6457f4f
3972c38
3c7b185
87bfe1e
d7a32f6
f68231d
78c5909
c0e4851
1c280f5
6beebcc
7896cdc
6cc1841
0ca66d2
5515d3a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
|
||
Simplify String Categorical handling | ||
------------------------------------ | ||
|
||
- Do not allow missing data. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,128 +1,115 @@ | ||
# -*- coding: utf-8 OA-*-za | ||
""" | ||
catch all for categorical functions | ||
""" | ||
|
||
from __future__ import (absolute_import, division, print_function, | ||
unicode_literals) | ||
|
||
from collections import OrderedDict | ||
import itertools | ||
from matplotlib import ticker, units | ||
import six | ||
|
||
import numpy as np | ||
|
||
import matplotlib.units as units | ||
import matplotlib.ticker as ticker | ||
|
||
# np 1.6/1.7 support | ||
from distutils.version import LooseVersion | ||
import collections | ||
|
||
|
||
if LooseVersion(np.__version__) >= LooseVersion('1.8.0'): | ||
def shim_array(data): | ||
return np.array(data, dtype=np.unicode) | ||
else: | ||
def shim_array(data): | ||
if (isinstance(data, six.string_types) or | ||
not isinstance(data, collections.Iterable)): | ||
data = [data] | ||
try: | ||
data = [str(d) for d in data] | ||
except UnicodeEncodeError: | ||
# this yields gibberish but unicode text doesn't | ||
# render under numpy1.6 anyway | ||
data = [d.encode('utf-8', 'ignore').decode('utf-8') | ||
for d in data] | ||
return np.array(data, dtype=np.unicode) | ||
|
||
|
||
class StrCategoryConverter(units.ConversionInterface): | ||
@staticmethod | ||
def convert(value, unit, axis): | ||
"""Uses axis.unit_data map to encode | ||
data as floats | ||
""" | ||
value = np.atleast_1d(value) | ||
# try and update from here.... | ||
if hasattr(axis.unit_data, 'update'): | ||
for val in value: | ||
if isinstance(val, six.string_types): | ||
axis.unit_data.update(val) | ||
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs)) | ||
|
||
if isinstance(value, six.string_types): | ||
return vmap[value] | ||
|
||
vals = shim_array(value) | ||
|
||
for lab, loc in vmap.items(): | ||
vals[vals == lab] = loc | ||
|
||
return vals.astype('float') | ||
"""Use axis.units mapping to map categorical data to floats.""" | ||
def getter(k): | ||
if not isinstance(k, six.text_type): | ||
k = k.decode('utf-8') | ||
return axis.units._mapping[k] | ||
# We also need to pass numbers through. | ||
if np.issubdtype(np.asarray(value).dtype.type, np.number): | ||
return value | ||
else: | ||
axis.units.update(value) | ||
str2idx = np.vectorize(getter, otypes=[float]) | ||
return str2idx(value) | ||
|
||
@staticmethod | ||
def axisinfo(unit, axis): | ||
majloc = StrCategoryLocator(axis.unit_data.locs) | ||
majfmt = StrCategoryFormatter(axis.unit_data.seq) | ||
majloc = StrCategoryLocator(axis.units) | ||
majfmt = StrCategoryFormatter(axis.units) | ||
return units.AxisInfo(majloc=majloc, majfmt=majfmt) | ||
|
||
@staticmethod | ||
def default_units(data, axis): | ||
# the conversion call stack is: | ||
# default_units->axis_info->convert | ||
if axis.unit_data is None: | ||
axis.unit_data = UnitData(data) | ||
else: | ||
axis.unit_data.update(data) | ||
return None | ||
return UnitData() | ||
|
||
|
||
class StrCategoryLocator(ticker.Locator): | ||
def __init__(self, unit_data): | ||
self._unit_data = unit_data | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for consistency, shouldn't this be units? |
||
|
||
def __call__(self): | ||
return list(self._unit_data._mapping.values()) | ||
|
||
def tick_values(self, vmin, vmax): | ||
return self() | ||
|
||
class StrCategoryLocator(ticker.FixedLocator): | ||
def __init__(self, locs): | ||
self.locs = locs | ||
self.nbins = None | ||
|
||
class StrCategoryFormatter(ticker.Formatter): | ||
def __init__(self, unit_data): | ||
self._unit_data = unit_data | ||
|
||
class StrCategoryFormatter(ticker.FixedFormatter): | ||
def __init__(self, seq): | ||
self.seq = seq | ||
self.offset_string = '' | ||
def __call__(self, x, pos=None): | ||
if pos is None: | ||
return "" | ||
r_mapping = {v: k for k, v in self._unit_data._mapping.items()} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since the mapping is now an ordered dict starting at 0, isn't it really just an array where the index is the mapping? so basically list(self._unit_data.keys())[int(x)] ? (which makes me wonder if that shouldn't just be how the mapping is stored anyway....) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or maybe the self._unit_data. should be a proper dict/hash like object? That has the added benefit of moving the getter function into the object implementation and then if not isinstance(k, six.text_type):
k = k.decode('utf-8')
return axis.units._mapping[k] can be in one place. |
||
return r_mapping.get(int(x), '') | ||
|
||
|
||
class UnitData(object): | ||
# debatable makes sense to special code missing values | ||
spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0} | ||
valid_types = tuple(set(six.string_types + | ||
(bytes, six.text_type, np.str_, np.bytes_))) | ||
|
||
def __init__(self, data): | ||
"""Create mapping between unique categorical values | ||
and numerical identifier | ||
def __init__(self, data=None): | ||
"""Create mapping between unique categorical values and numerical id. | ||
|
||
Parameters | ||
---------- | ||
data: iterable | ||
sequence of values | ||
data : Mapping[str, int] | ||
The initial categories. May be `None`. | ||
|
||
""" | ||
self.seq, self.locs = [], [] | ||
self._set_seq_locs(data, 0) | ||
|
||
def update(self, new_data): | ||
# so as not to conflict with spdict | ||
value = max(max(self.locs) + 1, 0) | ||
self._set_seq_locs(new_data, value) | ||
|
||
def _set_seq_locs(self, data, value): | ||
strdata = shim_array(data) | ||
new_s = [d for d in np.unique(strdata) if d not in self.seq] | ||
for ns in new_s: | ||
self.seq.append(ns) | ||
if ns in UnitData.spdict: | ||
self.locs.append(UnitData.spdict[ns]) | ||
else: | ||
self.locs.append(value) | ||
value += 1 | ||
self._vals = [] | ||
if data is None: | ||
data = () | ||
self._mapping = OrderedDict() | ||
for k, v in OrderedDict(data).items(): | ||
if not isinstance(k, self.valid_types): | ||
raise TypeError("{val!r} is not a string".format(val=k)) | ||
if not isinstance(k, six.text_type): | ||
k = k.decode('utf-8') | ||
self._mapping[k] = int(v) | ||
if self._mapping: | ||
start = max(self._mapping.values()) + 1 | ||
else: | ||
start = 0 | ||
self._counter = itertools.count(start=start) | ||
|
||
def update(self, data): | ||
if isinstance(data, self.valid_types): | ||
data = [data] | ||
sorted_unique = OrderedDict.fromkeys(data) | ||
for val in sorted_unique: | ||
if not isinstance(val, self.valid_types): | ||
raise TypeError("{val!r} is not a string".format(val=val)) | ||
if not isinstance(val, six.text_type): | ||
val = val.decode('utf-8') | ||
if val in self._mapping: | ||
continue | ||
self._vals.append(val) | ||
self._mapping[val] = next(self._counter) | ||
|
||
|
||
# Connects the convertor to matplotlib | ||
|
||
units.registry[str] = StrCategoryConverter() | ||
units.registry[np.str_] = StrCategoryConverter() | ||
units.registry[six.text_type] = StrCategoryConverter() | ||
units.registry[bytes] = StrCategoryConverter() | ||
units.registry[np.str_] = StrCategoryConverter() | ||
units.registry[np.bytes_] = StrCategoryConverter() | ||
units.registry[six.text_type] = StrCategoryConverter() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It bugs me that locator and formatter are getting the mapping and there's not a separation, so at the least I think there should be a comment saying that this is because the Formatter and Locator require references rather than objects/copies.