From 976a0452c353547936d17078162ff46938085179 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Sun, 5 Apr 2015 12:45:29 +0900 Subject: [PATCH] BUG: Repeated time-series plot causes memory leak --- doc/source/whatsnew/v0.17.0.txt | 2 +- pandas/tests/test_graphics.py | 30 ++ pandas/tools/plotting.py | 535 +++++++++++++------------- pandas/tseries/plotting.py | 151 ++++++-- pandas/tseries/tests/test_plotting.py | 39 ++ 5 files changed, 440 insertions(+), 317 deletions(-) diff --git a/doc/source/whatsnew/v0.17.0.txt b/doc/source/whatsnew/v0.17.0.txt index d59b6120163ff..287963ed2c825 100644 --- a/doc/source/whatsnew/v0.17.0.txt +++ b/doc/source/whatsnew/v0.17.0.txt @@ -381,7 +381,7 @@ Bug Fixes - Bug in ``Series.plot(kind='hist')`` Y Label not informative (:issue:`10485`) - Bug in ``read_csv`` when using a converter which generates a ``uint8`` type (:issue:`9266`) - +- Bug causes memory leak in time-series line and area plot (:issue:`9003`) - Bug in line and kde plot cannot accept multiple colors when ``subplots=True`` (:issue:`9894`) diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index 800c6f83f4902..3271493f59219 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -3281,6 +3281,36 @@ def test_sharey_and_ax(self): self.assertTrue(ax.yaxis.get_label().get_visible(), "y label is invisible but shouldn't") + def test_memory_leak(self): + """ Check that every plot type gets properly collected. """ + import weakref + import gc + + results = {} + for kind in plotting._plot_klass.keys(): + args = {} + if kind in ['hexbin', 'scatter', 'pie']: + df = self.hexbin_df + args = {'x': 'A', 'y': 'B'} + elif kind == 'area': + df = self.tdf.abs() + else: + df = self.tdf + + # Use a weakref so we can see if the object gets collected without + # also preventing it from being collected + results[kind] = weakref.proxy(df.plot(kind=kind, **args)) + + # have matplotlib delete all the figures + tm.close() + # force a garbage collection + gc.collect() + for key in results: + # check that every plot was collected + with tm.assertRaises(ReferenceError): + # need to actually access something to get an error + results[key].lines + @slow def test_df_grid_settings(self): # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792 diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 6a822a0231a2b..c16e2686c5a3a 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -774,6 +774,7 @@ class MPLPlot(object): data : """ + _kind = 'base' _layout_type = 'vertical' _default_rot = 0 orientation = None @@ -830,10 +831,7 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=None, self._rot_set = True else: self._rot_set = False - if isinstance(self._default_rot, dict): - self.rot = self._default_rot[self.kind] - else: - self.rot = self._default_rot + self.rot = self._default_rot if grid is None: grid = False if secondary_y else self.plt.rcParams['axes.grid'] @@ -1217,34 +1215,25 @@ def _get_xticks(self, convert_period=False): return x - def _is_datetype(self): - index = self.data.index - return (isinstance(index, (PeriodIndex, DatetimeIndex)) or - index.inferred_type in ('datetime', 'date', 'datetime64', - 'time')) + @classmethod + def _plot(cls, ax, x, y, style=None, is_errorbar=False, **kwds): + mask = com.isnull(y) + if mask.any(): + y = np.ma.array(y) + y = np.ma.masked_where(mask, y) - def _get_plot_function(self): - ''' - Returns the matplotlib plotting function (plot or errorbar) based on - the presence of errorbar keywords. - ''' - errorbar = any(e is not None for e in self.errors.values()) - def plotf(ax, x, y, style=None, **kwds): - mask = com.isnull(y) - if mask.any(): - y = np.ma.array(y) - y = np.ma.masked_where(mask, y) - - if errorbar: - return self.plt.Axes.errorbar(ax, x, y, **kwds) + if isinstance(x, Index): + x = x._mpl_repr() + + if is_errorbar: + return ax.errorbar(x, y, **kwds) + else: + # prevent style kwarg from going to errorbar, where it is unsupported + if style is not None: + args = (x, y, style) else: - # prevent style kwarg from going to errorbar, where it is unsupported - if style is not None: - args = (ax, x, y, style) - else: - args = (ax, x, y) - return self.plt.Axes.plot(*args, **kwds) - return plotf + args = (x, y) + return ax.plot(*args, **kwds) def _get_index_name(self): if isinstance(self.data.index, MultiIndex): @@ -1431,6 +1420,7 @@ def _get_axes_layout(self): return (len(y_set), len(x_set)) class ScatterPlot(MPLPlot): + _kind = 'scatter' _layout_type = 'single' def __init__(self, data, x, y, c=None, **kwargs): @@ -1509,6 +1499,7 @@ def _post_plot_logic(self): class HexBinPlot(MPLPlot): + _kind = 'hexbin' _layout_type = 'single' def __init__(self, data, x, y, C=None, **kwargs): @@ -1564,7 +1555,7 @@ def _post_plot_logic(self): class LinePlot(MPLPlot): - + _kind = 'line' _default_rot = 0 orientation = 'vertical' @@ -1576,65 +1567,30 @@ def __init__(self, data, **kwargs): if 'x_compat' in self.kwds: self.x_compat = bool(self.kwds.pop('x_compat')) - def _index_freq(self): - freq = getattr(self.data.index, 'freq', None) - if freq is None: - freq = getattr(self.data.index, 'inferred_freq', None) - if freq == 'B': - weekdays = np.unique(self.data.index.dayofweek) - if (5 in weekdays) or (6 in weekdays): - freq = None - return freq - - def _is_dynamic_freq(self, freq): - if isinstance(freq, DateOffset): - freq = freq.rule_code - else: - freq = frequencies.get_base_alias(freq) - freq = frequencies.get_period_alias(freq) - return freq is not None and self._no_base(freq) - - def _no_base(self, freq): - # hack this for 0.10.1, creating more technical debt...sigh - if isinstance(self.data.index, DatetimeIndex): - base = frequencies.get_freq(freq) - x = self.data.index - if (base <= frequencies.FreqGroup.FR_DAY): - return x[:1].is_normalized - - return Period(x[0], freq).to_timestamp(tz=x.tz) == x[0] - return True - - def _use_dynamic_x(self): - freq = self._index_freq() - - ax = self._get_ax(0) - ax_freq = getattr(ax, 'freq', None) - if freq is None: # convert irregular if axes has freq info - freq = ax_freq - else: # do not use tsplot if irregular was plotted first - if (ax_freq is None) and (len(ax.get_lines()) > 0): - return False - - return (freq is not None) and self._is_dynamic_freq(freq) - def _is_ts_plot(self): # this is slightly deceptive return not self.x_compat and self.use_index and self._use_dynamic_x() - def _make_plot(self): - self._initialize_prior(len(self.data)) + def _use_dynamic_x(self): + from pandas.tseries.plotting import _use_dynamic_x + return _use_dynamic_x(self._get_ax(0), self.data) + def _make_plot(self): if self._is_ts_plot(): - data = self._maybe_convert_index(self.data) + from pandas.tseries.plotting import _maybe_convert_index + data = _maybe_convert_index(self._get_ax(0), self.data) + x = data.index # dummy, not used - plotf = self._get_ts_plot_function() + plotf = self._ts_plot it = self._iter_data(data=data, keep_index=True) else: x = self._get_xticks(convert_period=True) - plotf = self._get_plot_function() + plotf = self._plot it = self._iter_data() + stacking_id = self._get_stacking_id() + is_errorbar = any(e is not None for e in self.errors.values()) + colors = self._get_colors() for i, (label, y) in enumerate(it): ax = self._get_ax(i) @@ -1647,84 +1603,87 @@ def _make_plot(self): label = com.pprint_thing(label) # .encode('utf-8') kwds['label'] = label - newlines = plotf(ax, x, y, style=style, column_num=i, **kwds) + newlines = plotf(ax, x, y, style=style, column_num=i, + stacking_id=stacking_id, + is_errorbar=is_errorbar, + **kwds) self._add_legend_handle(newlines[0], label, index=i) lines = _get_all_lines(ax) left, right = _get_xlim(lines) ax.set_xlim(left, right) - def _get_stacked_values(self, y, label): + @classmethod + def _plot(cls, ax, x, y, style=None, column_num=None, + stacking_id=None, **kwds): + # column_num is used to get the target column from protf in line and area plots + if column_num == 0: + cls._initialize_stacker(ax, stacking_id, len(y)) + y_values = cls._get_stacked_values(ax, stacking_id, y, kwds['label']) + lines = MPLPlot._plot(ax, x, y_values, style=style, **kwds) + cls._update_stacker(ax, stacking_id, y) + return lines + + @classmethod + def _ts_plot(cls, ax, x, data, style=None, **kwds): + from pandas.tseries.plotting import (_maybe_resample, + _decorate_axes, + format_dateaxis) + # accept x to be consistent with normal plot func, + # x is not passed to tsplot as it uses data.index as x coordinate + # column_num must be in kwds for stacking purpose + freq, data = _maybe_resample(data, ax, kwds) + + # Set ax with freq info + _decorate_axes(ax, freq, kwds) + ax._plot_data.append((data, cls._kind, kwds)) + + lines = cls._plot(ax, data.index, data.values, style=style, **kwds) + # set date formatter, locators and rescale limits + format_dateaxis(ax, ax.freq) + return lines + + def _get_stacking_id(self): if self.stacked: - if (y >= 0).all(): - return self._pos_prior + y - elif (y <= 0).all(): - return self._neg_prior + y - else: - raise ValueError('When stacked is True, each column must be either all positive or negative.' - '{0} contains both positive and negative values'.format(label)) + return id(self.data) else: - return y - - def _get_plot_function(self): - f = MPLPlot._get_plot_function(self) - def plotf(ax, x, y, style=None, column_num=None, **kwds): - # column_num is used to get the target column from protf in line and area plots - if column_num == 0: - self._initialize_prior(len(self.data)) - y_values = self._get_stacked_values(y, kwds['label']) - lines = f(ax, x, y_values, style=style, **kwds) - self._update_prior(y) - return lines - return plotf - - def _get_ts_plot_function(self): - from pandas.tseries.plotting import tsplot - plotf = self._get_plot_function() - def _plot(ax, x, data, style=None, **kwds): - # accept x to be consistent with normal plot func, - # x is not passed to tsplot as it uses data.index as x coordinate - lines = tsplot(data, plotf, ax=ax, style=style, **kwds) - return lines - return _plot - - def _initialize_prior(self, n): - self._pos_prior = np.zeros(n) - self._neg_prior = np.zeros(n) - - def _update_prior(self, y): - if self.stacked and not self.subplots: - # tsplot resample may changedata length - if len(self._pos_prior) != len(y): - self._initialize_prior(len(y)) - if (y >= 0).all(): - self._pos_prior += y - elif (y <= 0).all(): - self._neg_prior += y - - def _maybe_convert_index(self, data): - # tsplot converts automatically, but don't want to convert index - # over and over for DataFrames - if isinstance(data.index, DatetimeIndex): - freq = getattr(data.index, 'freq', None) - - if freq is None: - freq = getattr(data.index, 'inferred_freq', None) - if isinstance(freq, DateOffset): - freq = freq.rule_code - - if freq is None: - ax = self._get_ax(0) - freq = getattr(ax, 'freq', None) - - if freq is None: - raise ValueError('Could not get frequency alias for plotting') - - freq = frequencies.get_base_alias(freq) - freq = frequencies.get_period_alias(freq) - - data.index = data.index.to_period(freq=freq) - return data + return None + + @classmethod + def _initialize_stacker(cls, ax, stacking_id, n): + if stacking_id is None: + return + if not hasattr(ax, '_stacker_pos_prior'): + ax._stacker_pos_prior = {} + if not hasattr(ax, '_stacker_neg_prior'): + ax._stacker_neg_prior = {} + ax._stacker_pos_prior[stacking_id] = np.zeros(n) + ax._stacker_neg_prior[stacking_id] = np.zeros(n) + + @classmethod + def _get_stacked_values(cls, ax, stacking_id, values, label): + if stacking_id is None: + return values + if not hasattr(ax, '_stacker_pos_prior'): + # stacker may not be initialized for subplots + cls._initialize_stacker(ax, stacking_id, len(values)) + + if (values >= 0).all(): + return ax._stacker_pos_prior[stacking_id] + values + elif (values <= 0).all(): + return ax._stacker_neg_prior[stacking_id] + values + + raise ValueError('When stacked is True, each column must be either all positive or negative.' + '{0} contains both positive and negative values'.format(label)) + + @classmethod + def _update_stacker(cls, ax, stacking_id, values): + if stacking_id is None: + return + if (values >= 0).all(): + ax._stacker_pos_prior[stacking_id] += values + elif (values <= 0).all(): + ax._stacker_neg_prior[stacking_id] += values def _post_plot_logic(self): df = self.data @@ -1749,6 +1708,7 @@ def _post_plot_logic(self): class AreaPlot(LinePlot): + _kind = 'area' def __init__(self, data, **kwargs): kwargs.setdefault('stacked', True) @@ -1759,35 +1719,36 @@ def __init__(self, data, **kwargs): # use smaller alpha to distinguish overlap self.kwds.setdefault('alpha', 0.5) - def _get_plot_function(self): if self.logy or self.loglog: raise ValueError("Log-y scales are not supported in area plot") - else: - f = MPLPlot._get_plot_function(self) - def plotf(ax, x, y, style=None, column_num=None, **kwds): - if column_num == 0: - self._initialize_prior(len(self.data)) - y_values = self._get_stacked_values(y, kwds['label']) - lines = f(ax, x, y_values, style=style, **kwds) - - # get data from the line to get coordinates for fill_between - xdata, y_values = lines[0].get_data(orig=False) - - if (y >= 0).all(): - start = self._pos_prior - elif (y <= 0).all(): - start = self._neg_prior - else: - start = np.zeros(len(y)) - if not 'color' in kwds: - kwds['color'] = lines[0].get_color() + @classmethod + def _plot(cls, ax, x, y, style=None, column_num=None, + stacking_id=None, is_errorbar=False, **kwds): + if column_num == 0: + cls._initialize_stacker(ax, stacking_id, len(y)) + y_values = cls._get_stacked_values(ax, stacking_id, y, kwds['label']) + lines = MPLPlot._plot(ax, x, y_values, style=style, **kwds) + + # get data from the line to get coordinates for fill_between + xdata, y_values = lines[0].get_data(orig=False) + + # unable to use ``_get_stacked_values`` here to get starting point + if stacking_id is None: + start = np.zeros(len(y)) + elif (y >= 0).all(): + start = ax._stacker_pos_prior[stacking_id] + elif (y <= 0).all(): + start = ax._stacker_neg_prior[stacking_id] + else: + start = np.zeros(len(y)) - self.plt.Axes.fill_between(ax, xdata, start, y_values, **kwds) - self._update_prior(y) - return lines + if not 'color' in kwds: + kwds['color'] = lines[0].get_color() - return plotf + ax.fill_between(xdata, start, y_values, **kwds) + cls._update_stacker(ax, stacking_id, y) + return lines def _add_legend_handle(self, handle, label, index=None): from matplotlib.patches import Rectangle @@ -1810,8 +1771,9 @@ def _post_plot_logic(self): class BarPlot(MPLPlot): - - _default_rot = {'bar': 90, 'barh': 0} + _kind = 'bar' + _default_rot = 90 + orientation = 'vertical' def __init__(self, data, **kwargs): self.bar_width = kwargs.pop('width', 0.5) @@ -1848,20 +1810,13 @@ def _args_adjust(self): if com.is_list_like(self.left): self.left = np.array(self.left) - def _get_plot_function(self): - if self.kind == 'bar': - def f(ax, x, y, w, start=None, **kwds): - start = start + self.bottom - return ax.bar(x, y, w, bottom=start, log=self.log, **kwds) - elif self.kind == 'barh': - - def f(ax, x, y, w, start=None, log=self.log, **kwds): - start = start + self.left - return ax.barh(x, y, w, left=start, log=self.log, **kwds) - else: - raise ValueError("BarPlot kind must be either 'bar' or 'barh'") + @classmethod + def _plot(cls, ax, x, y, w, start=0, log=False, **kwds): + return ax.bar(x, y, w, bottom=start, log=log, **kwds) - return f + @property + def _start_base(self): + return self.bottom def _make_plot(self): import matplotlib as mpl @@ -1869,7 +1824,6 @@ def _make_plot(self): colors = self._get_colors() ncolors = len(colors) - bar_f = self._get_plot_function() pos_prior = neg_prior = np.zeros(len(self.data)) K = self.nseries @@ -1890,24 +1844,25 @@ def _make_plot(self): start = 0 if self.log and (y >= 1).all(): start = 1 + start = start + self._start_base if self.subplots: w = self.bar_width / 2 - rect = bar_f(ax, self.ax_pos + w, y, self.bar_width, - start=start, label=label, **kwds) + rect = self._plot(ax, self.ax_pos + w, y, self.bar_width, + start=start, label=label, log=self.log, **kwds) ax.set_title(label) elif self.stacked: mask = y > 0 - start = np.where(mask, pos_prior, neg_prior) + start = np.where(mask, pos_prior, neg_prior) + self._start_base w = self.bar_width / 2 - rect = bar_f(ax, self.ax_pos + w, y, self.bar_width, - start=start, label=label, **kwds) + rect = self._plot(ax, self.ax_pos + w, y, self.bar_width, + start=start, label=label, log=self.log, **kwds) pos_prior = pos_prior + np.where(mask, y, 0) neg_prior = neg_prior + np.where(mask, 0, y) else: w = self.bar_width / K - rect = bar_f(ax, self.ax_pos + (i + 0.5) * w, y, w, - start=start, label=label, **kwds) + rect = self._plot(ax, self.ax_pos + (i + 0.5) * w, y, w, + start=start, label=label, log=self.log, **kwds) self._add_legend_handle(rect, label, index=i) def _post_plot_logic(self): @@ -1922,33 +1877,40 @@ def _post_plot_logic(self): s_edge = self.ax_pos[0] - 0.25 + self.lim_offset e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset - if self.kind == 'bar': - ax.set_xlim((s_edge, e_edge)) - ax.set_xticks(self.tick_pos) - ax.set_xticklabels(str_index) - if name is not None and self.use_index: - ax.set_xlabel(name) - elif self.kind == 'barh': - # horizontal bars - ax.set_ylim((s_edge, e_edge)) - ax.set_yticks(self.tick_pos) - ax.set_yticklabels(str_index) - if name is not None and self.use_index: - ax.set_ylabel(name) - else: - raise NotImplementedError(self.kind) + self._decorate_ticks(ax, name, str_index, s_edge, e_edge) + + def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge): + ax.set_xlim((start_edge, end_edge)) + ax.set_xticks(self.tick_pos) + ax.set_xticklabels(ticklabels) + if name is not None and self.use_index: + ax.set_xlabel(name) + + +class BarhPlot(BarPlot): + _kind = 'barh' + _default_rot = 0 + orientation = 'horizontal' @property - def orientation(self): - if self.kind == 'bar': - return 'vertical' - elif self.kind == 'barh': - return 'horizontal' - else: - raise NotImplementedError(self.kind) + def _start_base(self): + return self.left + + @classmethod + def _plot(cls, ax, x, y, w, start=0, log=False, **kwds): + return ax.barh(x, y, w, left=start, log=log, **kwds) + + def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge): + # horizontal bars + ax.set_ylim((start_edge, end_edge)) + ax.set_yticks(self.tick_pos) + ax.set_yticklabels(ticklabels) + if name is not None and self.use_index: + ax.set_ylabel(name) class HistPlot(LinePlot): + _kind = 'hist' def __init__(self, data, bins=10, bottom=0, **kwargs): self.bins = bins # use mpl default @@ -1971,22 +1933,24 @@ def _args_adjust(self): if com.is_list_like(self.bottom): self.bottom = np.array(self.bottom) - def _get_plot_function(self): - def plotf(ax, y, style=None, column_num=None, **kwds): - if column_num == 0: - self._initialize_prior(len(self.bins) - 1) - y = y[~com.isnull(y)] - bottom = self._pos_prior + self.bottom - # ignore style - n, bins, patches = self.plt.Axes.hist(ax, y, bins=self.bins, - bottom=bottom, **kwds) - self._update_prior(n) - return patches - return plotf + @classmethod + def _plot(cls, ax, y, style=None, bins=None, bottom=0, column_num=0, + stacking_id=None, **kwds): + if column_num == 0: + cls._initialize_stacker(ax, stacking_id, len(bins) - 1) + y = y[~com.isnull(y)] + + base = np.zeros(len(bins) - 1) + bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds['label']) + # ignore style + n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds) + cls._update_stacker(ax, stacking_id, n) + return patches def _make_plot(self): - plotf = self._get_plot_function() colors = self._get_colors() + stacking_id = self._get_stacking_id() + for i, (label, y) in enumerate(self._iter_data()): ax = self._get_ax(i) @@ -1999,9 +1963,18 @@ def _make_plot(self): if style is not None: kwds['style'] = style - artists = plotf(ax, y, column_num=i, **kwds) + kwds = self._make_plot_keywords(kwds, y) + artists = self._plot(ax, y, column_num=i, + stacking_id=stacking_id, **kwds) self._add_legend_handle(artists[0], label, index=i) + def _make_plot_keywords(self, kwds, y): + """merge BoxPlot/KdePlot properties to passed kwds""" + # y is required for KdePlot + kwds['bottom'] = self.bottom + kwds['bins'] = self.bins + return kwds + def _post_plot_logic(self): if self.orientation == 'horizontal': for ax in self.axes: @@ -2019,6 +1992,7 @@ def orientation(self): class KdePlot(HistPlot): + _kind = 'kde' orientation = 'vertical' def __init__(self, data, bw_method=None, ind=None, **kwargs): @@ -2038,26 +2012,31 @@ def _get_ind(self, y): ind = self.ind return ind - def _get_plot_function(self): + @classmethod + def _plot(cls, ax, y, style=None, bw_method=None, ind=None, + column_num=None, stacking_id=None, **kwds): from scipy.stats import gaussian_kde from scipy import __version__ as spv - f = MPLPlot._get_plot_function(self) - def plotf(ax, y, style=None, column_num=None, **kwds): - y = remove_na(y) - if LooseVersion(spv) >= '0.11.0': - gkde = gaussian_kde(y, bw_method=self.bw_method) - else: - gkde = gaussian_kde(y) - if self.bw_method is not None: - msg = ('bw_method was added in Scipy 0.11.0.' + - ' Scipy version in use is %s.' % spv) - warnings.warn(msg) - - ind = self._get_ind(y) - y = gkde.evaluate(ind) - lines = f(ax, ind, y, style=style, **kwds) - return lines - return plotf + + y = remove_na(y) + + if LooseVersion(spv) >= '0.11.0': + gkde = gaussian_kde(y, bw_method=bw_method) + else: + gkde = gaussian_kde(y) + if bw_method is not None: + msg = ('bw_method was added in Scipy 0.11.0.' + + ' Scipy version in use is %s.' % spv) + warnings.warn(msg) + + y = gkde.evaluate(ind) + lines = MPLPlot._plot(ax, ind, y, style=style, **kwds) + return lines + + def _make_plot_keywords(self, kwds, y): + kwds['bw_method'] = self.bw_method + kwds['ind'] = self._get_ind(y) + return kwds def _post_plot_logic(self): for ax in self.axes: @@ -2065,6 +2044,7 @@ def _post_plot_logic(self): class PiePlot(MPLPlot): + _kind = 'pie' _layout_type = 'horizontal' def __init__(self, data, kind=None, **kwargs): @@ -2083,8 +2063,8 @@ def _validate_color_args(self): pass def _make_plot(self): - self.kwds.setdefault('colors', self._get_colors(num_colors=len(self.data), - color_kwds='colors')) + colors = self._get_colors(num_colors=len(self.data), color_kwds='colors') + self.kwds.setdefault('colors', colors) for i, (label, y) in enumerate(self._iter_data()): ax = self._get_ax(i) @@ -2129,6 +2109,7 @@ def blank_labeler(label, value): class BoxPlot(LinePlot): + _kind = 'box' _layout_type = 'horizontal' _valid_return_types = (None, 'axes', 'dict', 'both') @@ -2151,25 +2132,24 @@ def _args_adjust(self): else: self.sharey = False - def _get_plot_function(self): - def plotf(ax, y, column_num=None, **kwds): - if y.ndim == 2: - y = [remove_na(v) for v in y] - # Boxplot fails with empty arrays, so need to add a NaN - # if any cols are empty - # GH 8181 - y = [v if v.size > 0 else np.array([np.nan]) for v in y] - else: - y = remove_na(y) - bp = ax.boxplot(y, **kwds) + @classmethod + def _plot(cls, ax, y, column_num=None, return_type=None, **kwds): + if y.ndim == 2: + y = [remove_na(v) for v in y] + # Boxplot fails with empty arrays, so need to add a NaN + # if any cols are empty + # GH 8181 + y = [v if v.size > 0 else np.array([np.nan]) for v in y] + else: + y = remove_na(y) + bp = ax.boxplot(y, **kwds) - if self.return_type == 'dict': - return bp, bp - elif self.return_type == 'both': - return self.BP(ax=ax, lines=bp), bp - else: - return ax, bp - return plotf + if return_type == 'dict': + return bp, bp + elif return_type == 'both': + return cls.BP(ax=ax, lines=bp), bp + else: + return ax, bp def _validate_color_args(self): if 'color' in self.kwds: @@ -2223,7 +2203,6 @@ def maybe_color_bp(self, bp): setp(bp['caps'], color=caps, alpha=1) def _make_plot(self): - plotf = self._get_plot_function() if self.subplots: self._return_obj = compat.OrderedDict() @@ -2231,7 +2210,8 @@ def _make_plot(self): ax = self._get_ax(i) kwds = self.kwds.copy() - ret, bp = plotf(ax, y, column_num=i, **kwds) + ret, bp = self._plot(ax, y, column_num=i, + return_type=self.return_type, **kwds) self.maybe_color_bp(bp) self._return_obj[label] = ret @@ -2242,7 +2222,8 @@ def _make_plot(self): ax = self._get_ax(0) kwds = self.kwds.copy() - ret, bp = plotf(ax, y, column_num=0, **kwds) + ret, bp = self._plot(ax, y, column_num=0, + return_type=self.return_type, **kwds) self.maybe_color_bp(bp) self._return_obj = ret @@ -2287,10 +2268,12 @@ def result(self): _series_kinds = ['pie'] _all_kinds = _common_kinds + _dataframe_kinds + _series_kinds -_plot_klass = {'line': LinePlot, 'bar': BarPlot, 'barh': BarPlot, - 'kde': KdePlot, 'hist': HistPlot, 'box': BoxPlot, - 'scatter': ScatterPlot, 'hexbin': HexBinPlot, - 'area': AreaPlot, 'pie': PiePlot} +_klasses = [LinePlot, BarPlot, BarhPlot, KdePlot, HistPlot, BoxPlot, + ScatterPlot, HexBinPlot, AreaPlot, PiePlot] + +_plot_klass = {} +for klass in _klasses: + _plot_klass[klass._kind] = klass def _plot(data, x=None, y=None, subplots=False, diff --git a/pandas/tseries/plotting.py b/pandas/tseries/plotting.py index 9d28fa11f646f..ad27b412cddb9 100644 --- a/pandas/tseries/plotting.py +++ b/pandas/tseries/plotting.py @@ -4,12 +4,16 @@ """ #!!! TODO: Use the fact that axis can have units to simplify the process + +import numpy as np + from matplotlib import pylab from pandas.tseries.period import Period from pandas.tseries.offsets import DateOffset import pandas.tseries.frequencies as frequencies from pandas.tseries.index import DatetimeIndex import pandas.core.common as com +import pandas.compat as compat from pandas.tseries.converter import (TimeSeries_DateLocator, TimeSeries_DateFormatter) @@ -18,7 +22,7 @@ # Plotting functions and monkey patches -def tsplot(series, plotf, **kwargs): +def tsplot(series, plotf, ax=None, **kwargs): """ Plots a Series on the given Matplotlib axes or the current axes @@ -33,46 +37,33 @@ def tsplot(series, plotf, **kwargs): """ # Used inferred freq is possible, need a test case for inferred - if 'ax' in kwargs: - ax = kwargs.pop('ax') - else: + if ax is None: import matplotlib.pyplot as plt ax = plt.gca() - freq = _get_freq(ax, series) - # resample against axes freq if necessary - if freq is None: # pragma: no cover - raise ValueError('Cannot use dynamic axis without frequency info') - else: - # Convert DatetimeIndex to PeriodIndex - if isinstance(series.index, DatetimeIndex): - series = series.to_period(freq=freq) - freq, ax_freq, series = _maybe_resample(series, ax, freq, plotf, - kwargs) + freq, series = _maybe_resample(series, ax, kwargs) # Set ax with freq info _decorate_axes(ax, freq, kwargs) - - # how to make sure ax.clear() flows through? - if not hasattr(ax, '_plot_data'): - ax._plot_data = [] ax._plot_data.append((series, plotf, kwargs)) lines = plotf(ax, series.index._mpl_repr(), series.values, **kwargs) # set date formatter, locators and rescale limits format_dateaxis(ax, ax.freq) + return lines - # x and y coord info - ax.format_coord = lambda t, y: ("t = {0} " - "y = {1:8f}".format(Period(ordinal=int(t), - freq=ax.freq), - y)) - return lines +def _maybe_resample(series, ax, kwargs): + # resample against axes freq if necessary + freq, ax_freq = _get_freq(ax, series) + + if freq is None: # pragma: no cover + raise ValueError('Cannot use dynamic axis without frequency info') + # Convert DatetimeIndex to PeriodIndex + if isinstance(series.index, DatetimeIndex): + series = series.to_period(freq=freq) -def _maybe_resample(series, ax, freq, plotf, kwargs): - ax_freq = _get_ax_freq(ax) if ax_freq is not None and freq != ax_freq: if frequencies.is_superperiod(freq, ax_freq): # upsample input series = series.copy() @@ -84,21 +75,11 @@ def _maybe_resample(series, ax, freq, plotf, kwargs): series = series.resample(ax_freq, how=how).dropna() freq = ax_freq elif frequencies.is_subperiod(freq, ax_freq) or _is_sub(freq, ax_freq): - _upsample_others(ax, freq, plotf, kwargs) + _upsample_others(ax, freq, kwargs) ax_freq = freq else: # pragma: no cover raise ValueError('Incompatible frequency conversion') - return freq, ax_freq, series - - -def _get_ax_freq(ax): - ax_freq = getattr(ax, 'freq', None) - if ax_freq is None: - if hasattr(ax, 'left_ax'): - ax_freq = getattr(ax.left_ax, 'freq', None) - elif hasattr(ax, 'right_ax'): - ax_freq = getattr(ax.right_ax, 'freq', None) - return ax_freq + return freq, series def _is_sub(f1, f2): @@ -111,9 +92,10 @@ def _is_sup(f1, f2): (f2.startswith('W') and frequencies.is_superperiod(f1, 'D'))) -def _upsample_others(ax, freq, plotf, kwargs): +def _upsample_others(ax, freq, kwargs): legend = ax.get_legend() lines, labels = _replot_ax(ax, freq, kwargs) + _replot_ax(ax, freq, kwargs) other_ax = None if hasattr(ax, 'left_ax'): @@ -136,8 +118,11 @@ def _upsample_others(ax, freq, plotf, kwargs): def _replot_ax(ax, freq, kwargs): data = getattr(ax, '_plot_data', None) + + # clear current axes and data ax._plot_data = [] ax.clear() + _decorate_axes(ax, freq, kwargs) lines = [] @@ -147,7 +132,13 @@ def _replot_ax(ax, freq, kwargs): series = series.copy() idx = series.index.asfreq(freq, how='S') series.index = idx - ax._plot_data.append(series) + ax._plot_data.append((series, plotf, kwds)) + + # for tsplot + if isinstance(plotf, compat.string_types): + from pandas.tools.plotting import _plot_klass + plotf = _plot_klass[plotf]._plot + lines.append(plotf(ax, series.index._mpl_repr(), series.values, **kwds)[0]) labels.append(com.pprint_thing(series.name)) @@ -155,6 +146,10 @@ def _replot_ax(ax, freq, kwargs): def _decorate_axes(ax, freq, kwargs): + """Initialize axes for time-series plotting""" + if not hasattr(ax, '_plot_data'): + ax._plot_data = [] + ax.freq = freq xaxis = ax.get_xaxis() xaxis.freq = freq @@ -173,6 +168,11 @@ def _get_freq(ax, series): freq = getattr(series.index, 'inferred_freq', None) ax_freq = getattr(ax, 'freq', None) + if ax_freq is None: + if hasattr(ax, 'left_ax'): + ax_freq = getattr(ax.left_ax, 'freq', None) + elif hasattr(ax, 'right_ax'): + ax_freq = getattr(ax.right_ax, 'freq', None) # use axes freq if no data freq if freq is None: @@ -185,10 +185,76 @@ def _get_freq(ax, series): freq = frequencies.get_base_alias(freq) freq = frequencies.get_period_alias(freq) + return freq, ax_freq + + +def _use_dynamic_x(ax, data): + freq = _get_index_freq(data) + ax_freq = getattr(ax, 'freq', None) + + if freq is None: # convert irregular if axes has freq info + freq = ax_freq + else: # do not use tsplot if irregular was plotted first + if (ax_freq is None) and (len(ax.get_lines()) > 0): + return False + + if freq is None: + return False + + if isinstance(freq, DateOffset): + freq = freq.rule_code + else: + freq = frequencies.get_base_alias(freq) + freq = frequencies.get_period_alias(freq) + if freq is None: + return False + + # hack this for 0.10.1, creating more technical debt...sigh + if isinstance(data.index, DatetimeIndex): + base = frequencies.get_freq(freq) + x = data.index + if (base <= frequencies.FreqGroup.FR_DAY): + return x[:1].is_normalized + return Period(x[0], freq).to_timestamp(tz=x.tz) == x[0] + return True + + +def _get_index_freq(data): + freq = getattr(data.index, 'freq', None) + if freq is None: + freq = getattr(data.index, 'inferred_freq', None) + if freq == 'B': + weekdays = np.unique(data.index.dayofweek) + if (5 in weekdays) or (6 in weekdays): + freq = None return freq +def _maybe_convert_index(ax, data): + # tsplot converts automatically, but don't want to convert index + # over and over for DataFrames + if isinstance(data.index, DatetimeIndex): + freq = getattr(data.index, 'freq', None) + + if freq is None: + freq = getattr(data.index, 'inferred_freq', None) + if isinstance(freq, DateOffset): + freq = freq.rule_code + + if freq is None: + freq = getattr(ax, 'freq', None) + + if freq is None: + raise ValueError('Could not get frequency alias for plotting') + + freq = frequencies.get_base_alias(freq) + freq = frequencies.get_period_alias(freq) + + data = data.to_period(freq=freq) + return data + + # Patch methods for subplot. Only format_dateaxis is currently used. # Do we need the rest for convenience? @@ -219,4 +285,9 @@ def format_dateaxis(subplot, freq): plot_obj=subplot) subplot.xaxis.set_major_formatter(majformatter) subplot.xaxis.set_minor_formatter(minformatter) + + # x and y coord info + subplot.format_coord = lambda t, y: ("t = {0} " + "y = {1:8f}".format(Period(ordinal=int(t), freq=freq), y)) + pylab.draw_if_interactive() diff --git a/pandas/tseries/tests/test_plotting.py b/pandas/tseries/tests/test_plotting.py index 2ba65c07aa114..74f2a4550780b 100644 --- a/pandas/tseries/tests/test_plotting.py +++ b/pandas/tseries/tests/test_plotting.py @@ -105,6 +105,12 @@ def test_tsplot(self): for s in self.datetime_ser: _check_plot_works(f, s.index.freq.rule_code, ax=ax, series=s) + for s in self.period_ser: + _check_plot_works(s.plot, ax=ax) + + for s in self.datetime_ser: + _check_plot_works(s.plot, ax=ax) + ax = ts.plot(style='k') self.assertEqual((0., 0., 0.), ax.get_lines()[0].get_color()) @@ -151,6 +157,15 @@ def check_format_of_first_point(ax, expected_string): # note this is added to the annual plot already in existence, and changes its freq field daily = Series(1, index=date_range('2014-01-01', periods=3, freq='D')) check_format_of_first_point(daily.plot(), 't = 2014-01-01 y = 1.000000') + tm.close() + + # tsplot + import matplotlib.pyplot as plt + from pandas.tseries.plotting import tsplot + tsplot(annual, plt.Axes.plot) + check_format_of_first_point(plt.gca(), 't = 2014 y = 1.000000') + tsplot(daily, plt.Axes.plot) + check_format_of_first_point(plt.gca(), 't = 2014-01-01 y = 1.000000') @slow def test_line_plot_period_series(self): @@ -746,6 +761,15 @@ def test_to_weekly_resampling(self): for l in ax.get_lines(): self.assertTrue(PeriodIndex(data=l.get_xdata()).freq.startswith('W')) + # tsplot + from pandas.tseries.plotting import tsplot + import matplotlib.pyplot as plt + + tsplot(high, plt.Axes.plot) + lines = tsplot(low, plt.Axes.plot) + for l in lines: + self.assertTrue(PeriodIndex(data=l.get_xdata()).freq.startswith('W')) + @slow def test_from_weekly_resampling(self): idxh = date_range('1/1/1999', periods=52, freq='W') @@ -760,7 +784,22 @@ def test_from_weekly_resampling(self): 1553, 1558, 1562]) for l in ax.get_lines(): self.assertTrue(PeriodIndex(data=l.get_xdata()).freq.startswith('W')) + xdata = l.get_xdata(orig=False) + if len(xdata) == 12: # idxl lines + self.assert_numpy_array_equal(xdata, expected_l) + else: + self.assert_numpy_array_equal(xdata, expected_h) + tm.close() + + # tsplot + from pandas.tseries.plotting import tsplot + import matplotlib.pyplot as plt + + tsplot(low, plt.Axes.plot) + lines = tsplot(high, plt.Axes.plot) + for l in lines: + self.assertTrue(PeriodIndex(data=l.get_xdata()).freq.startswith('W')) xdata = l.get_xdata(orig=False) if len(xdata) == 12: # idxl lines self.assert_numpy_array_equal(xdata, expected_l)