Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 34b8eb4

Browse files
committed
category bug fix + new tests + refactor
1 parent 4c9353f commit 34b8eb4

File tree

5 files changed

+361
-292
lines changed

5 files changed

+361
-292
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5892,7 +5892,7 @@ def hist(self, x, bins=None, range=None, density=None, weights=None,
58925892
Parameters
58935893
----------
58945894
x : (n,) array or sequence of (n,) arrays
5895-
Input values, this takes either a single array or a sequency of
5895+
Input values, this takes either a single array or a sequence of
58965896
arrays which are not required to be of the same length
58975897
58985898
bins : integer or sequence or 'auto', optional
@@ -6104,30 +6104,31 @@ def hist(self, x, bins=None, range=None, density=None, weights=None,
61046104
"Please only use 'density', since 'normed'"
61056105
"will be deprecated.")
61066106

6107-
# process the unit information
6108-
self._process_unit_info(xdata=x, kwargs=kwargs)
6109-
x = self.convert_xunits(x)
6110-
if bin_range is not None:
6111-
bin_range = self.convert_xunits(bin_range)
6112-
6113-
# Check whether bins or range are given explicitly.
6114-
binsgiven = (cbook.iterable(bins) or bin_range is not None)
6115-
61166107
# basic input validation
61176108
input_empty = np.size(x) == 0
6118-
61196109
# Massage 'x' for processing.
61206110
if input_empty:
6121-
x = np.array([[]])
6111+
x = [np.array([])]
61226112
else:
61236113
x = cbook._reshape_2D(x, 'x')
61246114
nx = len(x) # number of datasets
61256115

6116+
# Process unit information
6117+
# Unit conversion is done individually on each dataset
6118+
self._process_unit_info(xdata=x[0], kwargs=kwargs)
6119+
x = [self.convert_xunits(xi) for xi in x]
6120+
6121+
if bin_range is not None:
6122+
bin_range = self.convert_xunits(bin_range)
6123+
6124+
# Check whether bins or range are given explicitly.
6125+
binsgiven = (cbook.iterable(bins) or bin_range is not None)
6126+
61266127
# We need to do to 'weights' what was done to 'x'
61276128
if weights is not None:
61286129
w = cbook._reshape_2D(weights, 'weights')
61296130
else:
6130-
w = [None]*nx
6131+
w = [None] * nx
61316132

61326133
if len(w) != nx:
61336134
raise ValueError('weights should have the same shape as x')

lib/matplotlib/axis.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,6 @@ def __init__(self, axes, pickradius=15):
668668
self.offsetText = self._get_offset_text()
669669
self.majorTicks = []
670670
self.minorTicks = []
671-
self.unit_data = None
672671
self.pickradius = pickradius
673672

674673
# Initialize here for testing; later add API
@@ -720,15 +719,14 @@ def limit_range_for_scale(self, vmin, vmax):
720719
return self._scale.limit_range_for_scale(vmin, vmax, self.get_minpos())
721720

722721
@property
722+
@cbook.deprecated("2.1.1")
723723
def unit_data(self):
724-
"""Holds data that a ConversionInterface subclass uses
725-
to convert between labels and indexes
726-
"""
727-
return self._unit_data
724+
return self._units
728725

729726
@unit_data.setter
727+
@cbook.deprecated("2.1.1")
730728
def unit_data(self, unit_data):
731-
self._unit_data = unit_data
729+
self.set_units = unit_data
732730

733731
def get_children(self):
734732
children = [self.label, self.offsetText]

lib/matplotlib/category.py

Lines changed: 117 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
# -*- coding: utf-8 OA-*-za
1+
# -*- coding: utf-8 -*-
22
"""
33
catch all for categorical functions
44
"""
55
from __future__ import (absolute_import, division, print_function,
66
unicode_literals)
7+
8+
from collections import Iterable, OrderedDict
9+
import itertools
10+
711
import six
812

913
import numpy as np
@@ -13,111 +17,149 @@
1317

1418
# np 1.6/1.7 support
1519
from distutils.version import LooseVersion
16-
import collections
1720

21+
VALID_TYPES = tuple(set(six.string_types +
22+
(bytes, six.text_type, np.str_, np.bytes_)))
1823

19-
if LooseVersion(np.__version__) >= LooseVersion('1.8.0'):
20-
def shim_array(data):
21-
return np.array(data, dtype=np.unicode)
22-
else:
23-
def shim_array(data):
24-
if (isinstance(data, six.string_types) or
25-
not isinstance(data, collections.Iterable)):
26-
data = [data]
27-
try:
28-
data = [str(d) for d in data]
29-
except UnicodeEncodeError:
30-
# this yields gibberish but unicode text doesn't
31-
# render under numpy1.6 anyway
32-
data = [d.encode('utf-8', 'ignore').decode('utf-8')
33-
for d in data]
34-
return np.array(data, dtype=np.unicode)
24+
25+
def to_str(value):
26+
"""Helper function to turn values to strings.
27+
"""
28+
# Note: This function is only used by StrCategoryFormatter
29+
if LooseVersion(np.__version__) < LooseVersion('1.7.0'):
30+
if (isinstance(value, (six.text_type, np.unicode))):
31+
value = value.encode('utf-8', 'ignore').decode('utf-8')
32+
if isinstance(value, (np.bytes_, six.binary_type)):
33+
value = value.decode(encoding='utf-8')
34+
elif not isinstance(value, (np.str_, six.string_types)):
35+
value = str(value)
36+
return value
3537

3638

3739
class StrCategoryConverter(units.ConversionInterface):
3840
@staticmethod
3941
def convert(value, unit, axis):
40-
"""Uses axis.unit_data map to encode
41-
data as floats
42+
"""Uses axis.units to encode string data as floats
43+
44+
Parameters
45+
----------
46+
value: string, iterable
47+
value or list of values to plot
48+
unit:
49+
axis:
4250
"""
43-
value = np.atleast_1d(value)
44-
# try and update from here....
45-
if hasattr(axis.unit_data, 'update'):
46-
for val in value:
47-
if isinstance(val, six.string_types):
48-
axis.unit_data.update(val)
49-
vmap = dict(zip(axis.unit_data.seq, axis.unit_data.locs))
51+
# dtype = object preserves numerical pass throughs
52+
values = np.atleast_1d(np.array(value, dtype=object))
5053

51-
if isinstance(value, six.string_types):
52-
return vmap[value]
54+
# pass through sequence of non binary numbers
55+
if all((units.ConversionInterface.is_numlike(v) and
56+
not isinstance(v, VALID_TYPES)) for v in values):
57+
return np.asarray(values, dtype=float)
5358

54-
vals = shim_array(value)
59+
# force an update so it also does type checking
60+
axis.units.update(values)
5561

56-
for lab, loc in vmap.items():
57-
vals[vals == lab] = loc
62+
str2idx = np.vectorize(axis.units._mapping.__getitem__,
63+
otypes=[float])
5864

59-
return vals.astype('float')
65+
mapped_value = str2idx(values)
66+
return mapped_value
6067

6168
@staticmethod
6269
def axisinfo(unit, axis):
63-
majloc = StrCategoryLocator(axis.unit_data.locs)
64-
majfmt = StrCategoryFormatter(axis.unit_data.seq)
70+
"""Sets the axis ticks and labels
71+
"""
72+
# locator and formatter take mapping dict because
73+
# args need to be pass by reference for updates
74+
majloc = StrCategoryLocator(axis.units)
75+
majfmt = StrCategoryFormatter(axis.units)
6576
return units.AxisInfo(majloc=majloc, majfmt=majfmt)
6677

6778
@staticmethod
68-
def default_units(data, axis):
69-
# the conversion call stack is:
79+
def default_units(data=None, axis=None):
80+
# the conversion call stack is supposed to be
7081
# default_units->axis_info->convert
71-
if axis.unit_data is None:
72-
axis.unit_data = UnitData(data)
82+
if axis.units is None:
83+
axis.set_units(UnitData(data))
7384
else:
74-
axis.unit_data.update(data)
75-
return None
85+
axis.units.update(data)
86+
return axis.units
7687

7788

78-
class StrCategoryLocator(ticker.FixedLocator):
79-
def __init__(self, locs):
80-
self.locs = locs
81-
self.nbins = None
89+
class StrCategoryLocator(ticker.Locator):
90+
"""tick at every integer mapping of the string data"""
91+
def __init__(self, units):
92+
"""
93+
Parameters
94+
-----------
95+
units: dict
96+
(string, integer) mapping
97+
"""
98+
self._units = units
8299

100+
def __call__(self):
101+
return list(self._units._mapping.values())
83102

84-
class StrCategoryFormatter(ticker.FixedFormatter):
85-
def __init__(self, seq):
86-
self.seq = seq
87-
self.offset_string = ''
103+
def tick_values(self, vmin, vmax):
104+
return self()
88105

89106

90-
class UnitData(object):
91-
# debatable makes sense to special code missing values
92-
spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}
107+
class StrCategoryFormatter(ticker.Formatter):
108+
"""String representation of the data at every tick"""
109+
def __init__(self, units):
110+
"""
111+
Parameters
112+
----------
113+
units: dict
114+
(string, integer) mapping
115+
"""
116+
self._units = units
93117

94-
def __init__(self, data):
95-
"""Create mapping between unique categorical values
96-
and numerical identifier
118+
def __call__(self, x, pos=None):
119+
if pos is None:
120+
return ""
121+
r_mapping = {v: to_str(k) for k, v in self._units._mapping.items()}
122+
return r_mapping.get(int(np.round(x)), '')
97123

98-
Parameters
124+
125+
class UnitData(object):
126+
def __init__(self, data=None):
127+
"""Create mapping between unique categorical values
128+
and integer identifiers
99129
----------
100130
data: iterable
101-
sequence of values
131+
sequence of string values
132+
"""
133+
if data is None:
134+
data = ()
135+
self._mapping = OrderedDict()
136+
self._counter = itertools.count(start=0)
137+
self.update(data)
138+
139+
def update(self, data):
140+
"""Maps new values to integer identifiers.
141+
142+
Paramters
143+
---------
144+
data: iterable
145+
sequence of string values
146+
147+
Raises
148+
------
149+
TypeError
150+
If the value in data is not a string, unicode, bytes type
102151
"""
103-
self.seq, self.locs = [], []
104-
self._set_seq_locs(data, 0)
105-
106-
def update(self, new_data):
107-
# so as not to conflict with spdict
108-
value = max(max(self.locs) + 1, 0)
109-
self._set_seq_locs(new_data, value)
110-
111-
def _set_seq_locs(self, data, value):
112-
strdata = shim_array(data)
113-
new_s = [d for d in np.unique(strdata) if d not in self.seq]
114-
for ns in new_s:
115-
self.seq.append(ns)
116-
if ns in UnitData.spdict:
117-
self.locs.append(UnitData.spdict[ns])
118-
else:
119-
self.locs.append(value)
120-
value += 1
152+
153+
if (isinstance(data, VALID_TYPES) or
154+
not isinstance(data, Iterable)):
155+
data = [data]
156+
157+
unsorted_unique = OrderedDict.fromkeys(data)
158+
for val in unsorted_unique:
159+
if not isinstance(val, VALID_TYPES):
160+
raise TypeError("{val!r} is not a string".format(val=val))
161+
if val not in self._mapping:
162+
self._mapping[val] = next(self._counter)
121163

122164

123165
# Connects the convertor to matplotlib

lib/matplotlib/tests/test_axes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,14 @@ def test_hist_unequal_bins_density():
15921592
assert_allclose(mpl_heights, np_heights)
15931593

15941594

1595+
def test_hist_datetime_datasets():
1596+
data = [[datetime.datetime(2017, 1, 1), datetime.datetime(2017, 1, 1)],
1597+
[datetime.datetime(2017, 1, 1), datetime.datetime(2017, 1, 2)]]
1598+
fig, ax = plt.subplots()
1599+
ax.hist(data, stacked=True)
1600+
ax.hist(data, stacked=False)
1601+
1602+
15951603
def contour_dat():
15961604
x = np.linspace(-3, 5, 150)
15971605
y = np.linspace(-3, 5, 120)

0 commit comments

Comments
 (0)