Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Rethink categoricals. #9774

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 39 additions & 32 deletions lib/matplotlib/axes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,24 +216,10 @@ def _xy_from_xy(self, x, y):
if self.axes.xaxis is not None and self.axes.yaxis is not None:
bx = self.axes.xaxis.update_units(x)
by = self.axes.yaxis.update_units(y)

if self.command != 'plot':
# the Line2D class can handle unitized data, with
# support for post hoc unit changes etc. Other mpl
# artists, e.g., Polygon which _process_plot_var_args
# also serves on calls to fill, cannot. So this is a
# hack to say: if you are not "plot", which is
# creating Line2D, then convert the data now to
# floats. If you are plot, pass the raw data through
# to Line2D which will handle the conversion. So
# polygons will not support post hoc conversions of
# the unit type since they are not storing the orig
# data. Hopefully we can rationalize this at a later
# date - JDH
if bx:
x = self.axes.convert_xunits(x)
if by:
y = self.axes.convert_yunits(y)
if bx:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, but this breaks line.get_xdata() which will return the original x-data, not the converted...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can manually reattach the original data to the lines, but of course fell into another dragonhole on the way... (#9784).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reattached the unitful data manually.

x = self.axes.convert_xunits(x)
if by:
y = self.axes.convert_yunits(y)

# like asanyarray, but converts scalar to array, and doesn't change
# existing compatible sequences
Expand Down Expand Up @@ -376,26 +362,47 @@ def _plot_args(self, tup, kwargs):
if 'label' not in kwargs or kwargs['label'] is None:
kwargs['label'] = get_label(tup[-1], None)

if len(tup) == 2:
x = _check_1d(tup[0])
y = _check_1d(tup[-1])
if len(tup) == 1:
x, y = index_of(tup[0])
elif len(tup) == 2:
x, y = tup
else:
x, y = index_of(tup[-1])

x, y = self._xy_from_xy(x, y)

if self.command == 'plot':
func = self._makeline
assert False

deunitized_x, deunitized_y = self._xy_from_xy(x, y)
# The previous call has registered the converters, if any, on the axes.
# This check will need to be replaced by a comparison with the
# DefaultConverter when that PR goes in.
if self.axes.xaxis.converter is None or self.command is not "plot":
xt, yt = deunitized_x.T, deunitized_y.T
else:
kw['closed'] = kwargs.get('closed', True)
func = self._makefill

ncx, ncy = x.shape[1], y.shape[1]
# np.asarray would destroy unit information so we need to construct
# the 1D arrays to pass to Line2D.set_xdata manually... (but this
# is only relevant if the command is "plot").

def to_list_of_lists(data):
ndim = np.ndim(data)
if ndim == 0:
return [[data]]
elif ndim == 1:
return [data]
elif ndim == 2:
return zip(*data) # Transpose it.

xt, yt = map(to_list_of_lists, [x, y])

ncx, ncy = deunitized_x.shape[1], deunitized_y.shape[1]
if ncx > 1 and ncy > 1 and ncx != ncy:
cbook.warn_deprecated("2.2", "cycling among columns of inputs "
"with non-matching shapes is deprecated.")
for j in xrange(max(ncx, ncy)):
seg = func(x[:, j % ncx], y[:, j % ncy], kw, kwargs)
if self.command == "plot":
seg = self._makeline(xt[j % ncx], yt[j % ncy], kw, kwargs)
else:
kw['closed'] = kwargs.get('closed', True)
seg = self._makefill(deunitized_x[:, j % ncx],
deunitized_y[:, j % ncy],
kw, kwargs)
ret.append(seg)
return ret

Expand Down
12 changes: 5 additions & 7 deletions lib/matplotlib/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,6 @@ def __init__(self, axes, pickradius=15):
self.offsetText = self._get_offset_text()
self.majorTicks = []
self.minorTicks = []
self.unit_data = None
self.pickradius = pickradius

# Initialize here for testing; later add API
Expand Down Expand Up @@ -720,15 +719,14 @@ def limit_range_for_scale(self, vmin, vmax):
return self._scale.limit_range_for_scale(vmin, vmax, self.get_minpos())

@property
@cbook.deprecated("2.1.1")
def unit_data(self):
"""Holds data that a ConversionInterface subclass uses
to convert between labels and indexes
"""
return self._unit_data
return self.units

@unit_data.setter
def unit_data(self, unit_data):
self._unit_data = unit_data
@cbook.deprecated("2.1.1")
def unit_data(self, value):
self.set_units = value

def get_children(self):
children = [self.label, self.offsetText]
Expand Down
147 changes: 56 additions & 91 deletions lib/matplotlib/category.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,88 @@
# -*- coding: utf-8 OA-*-za
"""
catch all for categorical functions
"""Helpers for categorical data.
"""

from __future__ import (absolute_import, division, print_function,
unicode_literals)
import six

import numpy as np

import matplotlib.units as units
import matplotlib.ticker as ticker

# np 1.6/1.7 support
from distutils.version import LooseVersion
import collections
from collections import OrderedDict
import itertools

import numpy as np

if LooseVersion(np.__version__) >= LooseVersion('1.8.0'):
def shim_array(data):
return np.array(data, dtype=np.unicode)
else:
def shim_array(data):
if (isinstance(data, six.string_types) or
not isinstance(data, collections.Iterable)):
data = [data]
try:
data = [str(d) for d in data]
except UnicodeEncodeError:
# this yields gibberish but unicode text doesn't
# render under numpy1.6 anyway
data = [d.encode('utf-8', 'ignore').decode('utf-8')
for d in data]
return np.array(data, dtype=np.unicode)
from matplotlib import cbook, ticker, units


class StrCategoryConverter(units.ConversionInterface):
@staticmethod
def convert(value, unit, axis):
"""Uses axis.unit_data map to encode
data as floats
"""
value = np.atleast_1d(value)
# try and update from here....
if hasattr(axis.unit_data, 'update'):
for val in value:
if isinstance(val, six.string_types):
axis.unit_data.update(val)
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))

if isinstance(value, six.string_types):
return vmap[value]

vals = shim_array(value)

for lab, loc in vmap.items():
vals[vals == lab] = loc

return vals.astype('float')
"""Encode data as floats."""
# We also need to pass numbers through.
if np.issubdtype(np.asarray(value).dtype.type, np.number):
return value
else:
unit.update(value)
return np.vectorize(unit._val_to_idx.__getitem__)(value)

@staticmethod
def axisinfo(unit, axis):
majloc = StrCategoryLocator(axis.unit_data.locs)
majfmt = StrCategoryFormatter(axis.unit_data.seq)
return units.AxisInfo(majloc=majloc, majfmt=majfmt)
# Note that mapping may get mutated by later calls to plotting methods,
# so the locator and formatter must dynamically recompute locs and seq.
return units.AxisInfo(
majloc=StrCategoryLocator(unit),
majfmt=StrCategoryFormatter(unit))

@staticmethod
def default_units(data, axis):
# the conversion call stack is:
# default_units->axis_info->convert
if axis.unit_data is None:
axis.unit_data = UnitData(data)
else:
axis.unit_data.update(data)
return None
return UnitData()


class StrCategoryLocator(ticker.FixedLocator):
def __init__(self, locs):
self.locs = locs
self.nbins = None
class StrCategoryLocator(ticker.Locator):
def __init__(self, unit_data):
self._unit_data = unit_data

def __call__(self):
return list(self._unit_data._val_to_idx.values())

class StrCategoryFormatter(ticker.FixedFormatter):
def __init__(self, seq):
self.seq = seq
self.offset_string = ''

class StrCategoryFormatter(ticker.Formatter):
def __init__(self, unit_data):
self._unit_data = unit_data

class UnitData(object):
# debatable makes sense to special code missing values
spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}
def __call__(self, x, pos=None):
if pos in range(len(self._unit_data._vals)):
s = self._unit_data._vals[pos]
if isinstance(s, bytes):
s = s.decode("utf-8")
return s
else:
return ""

def __init__(self, data):
"""Create mapping between unique categorical values
and numerical identifier

Parameters
----------
data: iterable
sequence of values
class UnitData(object):
def __init__(self, data=()):
"""Create mapping between unique categorical values and numerical id.
"""
self.seq, self.locs = [], []
self._set_seq_locs(data, 0)
self._vals = []
self._val_to_idx = OrderedDict()
self._counter = itertools.count()
if np.size(data):
cbook.warn_deprecated(
"2.1.1",
"Passing data to the UnitData constructor is deprecated.")
self.update(data)

def update(self, new_data):
# so as not to conflict with spdict
value = max(max(self.locs) + 1, 0)
self._set_seq_locs(new_data, value)

def _set_seq_locs(self, data, value):
strdata = shim_array(data)
new_s = [d for d in np.unique(strdata) if d not in self.seq]
for ns in new_s:
self.seq.append(ns)
if ns in UnitData.spdict:
self.locs.append(UnitData.spdict[ns])
else:
self.locs.append(value)
value += 1
if isinstance(new_data, six.string_types):
new_data = [new_data]
sorted_unique = OrderedDict.fromkeys(new_data)
for val in sorted_unique:
if val in self._val_to_idx:
continue
if not isinstance(val, (six.text_type, six.binary_type)):
raise TypeError("Not a string")
self._vals.append(val)
self._val_to_idx[val] = next(self._counter)


# Connects the convertor to matplotlib
Expand Down
4 changes: 2 additions & 2 deletions lib/matplotlib/cbook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2295,8 +2295,8 @@ def index_of(y):
try:
return y.index.values, y.values
except AttributeError:
y = _check_1d(y)
return np.arange(y.shape[0], dtype=float), y
# Ensure that scalar y gives x == [0].
return np.arange((np.shape(y) or (1,))[0], dtype=float), y


def safe_first_element(obj):
Expand Down
Loading