diff --git a/examples/images_contours_and_fields/bivariate_demo.py b/examples/images_contours_and_fields/bivariate_demo.py new file mode 100644 index 000000000000..fead56aaae61 --- /dev/null +++ b/examples/images_contours_and_fields/bivariate_demo.py @@ -0,0 +1,44 @@ +""" +============== +Bivariate Demo +============== + +Plotting bivariate data. + +imshow, pcolor, pcolormesh, pcolorfast allows you to plot bivariate data +using a bivaraite colormap. + +In this example we use imshow to plot air temperature with surface pressure +alongwith a color square. +""" +import matplotlib.colors as colors +from matplotlib.cbook import get_sample_data +import matplotlib.pyplot as plt +import numpy as np + + +############################################################################### +# Bivariate plotting demo +# ----------------------- + +air_temp = np.load(get_sample_data('air_temperature.npy')) +surf_pres = np.load(get_sample_data('surface_pressure.npy')) + +fig, ax = plt.subplots() + +bivariate = [air_temp, surf_pres] + +############################################################################### +# To distinguish bivariate data either BivariateNorm or BivariateColormap must +# be passed in as argument + +cax = ax.imshow(bivariate, norm=colors.BivariateNorm(), + cmap=colors.BivariateColormap()) + +############################################################################### +# If input data is bivariate then colorbar automatically draws colorsquare +# instead of colorbar + +cbar = fig.colorbar(cax, xlabel='air_temp', ylabel='surf_pres') + +plt.show() diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index e2610703489d..404275d0c5e9 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -4044,7 +4044,8 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, if colors is None: if norm is not None and not isinstance(norm, mcolors.Normalize): - msg = "'norm' must be an instance of 'mcolors.Normalize'" + msg = ("'norm' must be an instance of 'mcolors.Normalize' or " + "'mcolors.BivariateNorm'") raise ValueError(msg) collection.set_array(np.asarray(c)) collection.set_cmap(cmap) @@ -4403,7 +4404,8 @@ def hexbin(self, x, y, C=None, gridsize=100, bins=None, accum = bins.searchsorted(accum) if norm is not None and not isinstance(norm, mcolors.Normalize): - msg = "'norm' must be an instance of 'mcolors.Normalize'" + msg = ("'norm' must be an instance of 'mcolors.Normalize' or " + "'mcolors.BivariateNorm'") raise ValueError(msg) collection.set_array(accum) collection.set_cmap(cmap) @@ -5037,23 +5039,25 @@ def imshow(self, X, cmap=None, norm=None, aspect=None, Parameters ---------- - X : array_like, shape (n, m) or (n, m, 3) or (n, m, 4) + X : array_like, shape (n, m) or (n, m, 3) or (n, m, 4) or (2, n, m) Display the image in `X` to current axes. `X` may be an array or a PIL image. If `X` is an array, it can have the following shapes and types: - - MxN -- values to be mapped (float or int) + - MxN -- univariate values to be mapped (float or int) - MxNx3 -- RGB (float or uint8) - MxNx4 -- RGBA (float or uint8) + - 2xMxN -- bivariate values to be mapped (float or int) The value for each component of MxNx3 and MxNx4 float arrays - should be in the range 0.0 to 1.0. MxN arrays are mapped + should be in the range 0.0 to 1.0. MxN and 2xMxN arrays are mapped to colors based on the `norm` (mapping scalar to scalar) and the `cmap` (mapping the normed scalar to a color). - cmap : `~matplotlib.colors.Colormap`, optional, default: None + cmap : `~matplotlib.colors.Colormap`, \ + `~matplotlib.colors.BivariateColormap`, optional, default: None If None, default to rc `image.cmap` value. `cmap` is ignored - if `X` is 3-D, directly specifying RGB(A) values. + if `X` is 3-D but not bivariate, directly specifying RGB(A) values. aspect : ['auto' | 'equal' | scalar], optional, default: None If 'auto', changes the image aspect ratio to match that of the @@ -5077,7 +5081,8 @@ def imshow(self, X, cmap=None, norm=None, aspect=None, on the Agg, ps and pdf backends. Other backends will fall back to 'nearest'. - norm : `~matplotlib.colors.Normalize`, optional, default: None + norm : `~matplotlib.colors.Normalize`, \ + `matplotlib.colors.BivariateNorm`, optional, default: None A `~matplotlib.colors.Normalize` instance is used to scale a 2-D float `X` input to the (0, 1) range for input to the `cmap`. If `norm` is None, use the default func:`normalize`. @@ -5137,16 +5142,29 @@ def imshow(self, X, cmap=None, norm=None, aspect=None, of pixel (0, 0). """ - if not self._hold: self.cla() - if norm is not None and not isinstance(norm, mcolors.Normalize): - msg = "'norm' must be an instance of 'mcolors.Normalize'" + if norm is not None and not isinstance(norm, mcolors.Norms): + msg = ("'norm' must be an instance of 'mcolors.Normalize' or " + "'mcolors.BivariateNorm'") raise ValueError(msg) + + temp = np.asarray(X) + is_bivari = (isinstance(norm, mcolors.BivariateNorm) or + isinstance(cmap, mcolors.BivariateColormap)) + if is_bivari: + if temp.ndim != 3 and temp.shape[0] != 2: + raise TypeError("Expected shape like (2, n, m)") + if cmap is None: + cmap = mcolors.BivariateColormap() + if norm is None: + norm = mcolors.BivariateNorm() + if aspect is None: aspect = rcParams['image.aspect'] self.set_aspect(aspect) + im = mimage.AxesImage(self, cmap, norm, interpolation, origin, extent, filternorm=filternorm, filterrad=filterrad, resample=resample, **kwargs) @@ -5173,7 +5191,6 @@ def imshow(self, X, cmap=None, norm=None, aspect=None, @staticmethod def _pcolorargs(funcname, *args, **kw): - # This takes one kwarg, allmatch. # If allmatch is True, then the incoming X, Y, C must # have matching dimensions, taking into account that # X and Y can be 1-D rather than 2-D. This perfect @@ -5186,10 +5203,17 @@ def _pcolorargs(funcname, *args, **kw): # is False. allmatch = kw.pop("allmatch", False) + norm = kw.pop("norm", None) + cmap = kw.pop("cmap", None) if len(args) == 1: C = np.asanyarray(args[0]) - numRows, numCols = C.shape + is_bivari = (isinstance(norm, mcolors.BivariateNorm) or + isinstance(cmap, mcolors.BivariateColormap)) + if is_bivari: + numRows, numCols = C.shape[1:] + else: + numRows, numCols = C.shape if allmatch: X, Y = np.meshgrid(np.arange(numCols), np.arange(numRows)) else: @@ -5200,7 +5224,12 @@ def _pcolorargs(funcname, *args, **kw): if len(args) == 3: X, Y, C = [np.asanyarray(a) for a in args] - numRows, numCols = C.shape + is_bivari = (isinstance(norm, mcolors.BivariateNorm) or + isinstance(cmap, mcolors.BivariateColormap)) + if is_bivari: + numRows, numCols = C.shape[1:] + else: + numRows, numCols = C.shape else: raise TypeError( 'Illegal arguments to %s; see help(%s)' % (funcname, funcname)) @@ -5235,7 +5264,7 @@ def _pcolorargs(funcname, *args, **kw): @docstring.dedent_interpd def pcolor(self, *args, **kwargs): """ - Create a pseudocolor plot of a 2-D array. + Create a pseudocolor plot of a 2-D univariate or 3-D bivariate array. Call signatures:: @@ -5273,10 +5302,12 @@ def pcolor(self, *args, **kwargs): vectors, they will be expanded as needed into the appropriate 2-D arrays, making a rectangular grid. - cmap : `~matplotlib.colors.Colormap`, optional, default: None + cmap : `~matplotlib.colors.Colormap` or \ + `matplotlib.colors.BivariateColormap`, optional, default: None If `None`, default to rc settings. - norm : `matplotlib.colors.Normalize`, optional, default: None + norm : `matplotlib.colors.Normalize` or \ + `matplotlib.colors.BivariateNorm`, optional, default: None An instance is used to scale luminance data to (0, 1). If `None`, defaults to :func:`normalize`. @@ -5382,9 +5413,20 @@ def pcolor(self, *args, **kwargs): vmin = kwargs.pop('vmin', None) vmax = kwargs.pop('vmax', None) - X, Y, C = self._pcolorargs('pcolor', *args, allmatch=False) + kw = {'norm': norm, 'cmap': cmap, 'allmatch': False} + X, Y, C = self._pcolorargs('pcolor', *args, **kw) Ny, Nx = X.shape + is_bivari = (isinstance(norm, mcolors.BivariateNorm) or + isinstance(cmap, mcolors.BivariateColormap)) + if is_bivari: + if C.ndim != 3 and C.shape[0] != 2: + raise TypeError("Expected shape like (2, n, m)") + if cmap is None: + cmap = mcolors.BivariateColormap() + if norm is None: + norm = mcolors.BivariateNorm() + # unit conversion allows e.g. datetime objects as axis values self._process_unit_info(xdata=X, ydata=Y, kwargs=kwargs) X = self.convert_xunits(X) @@ -5399,7 +5441,10 @@ def pcolor(self, *args, **kwargs): xymask = (mask[0:-1, 0:-1] + mask[1:, 1:] + mask[0:-1, 1:] + mask[1:, 0:-1]) # don't plot if C or any of the surrounding vertices are masked. - mask = ma.getmaskarray(C) + xymask + if isinstance(norm, mcolors.BivariateNorm): + mask = ma.getmaskarray(C[0]) + ma.getmaskarray(C[1]) + xymask + else: + mask = ma.getmaskarray(C) + xymask newaxis = np.newaxis compress = np.compress @@ -5423,7 +5468,15 @@ def pcolor(self, *args, **kwargs): axis=1) verts = xy.reshape((npoly, 5, 2)) - C = compress(ravelmask, ma.filled(C[0:Ny - 1, 0:Nx - 1]).ravel()) + if isinstance(norm, mcolors.BivariateNorm): + C = np.array([ + compress( + ravelmask, + ma.filled(c[0:Ny - 1, 0:Nx - 1]).ravel() + ) for c in C + ]) + else: + C = compress(ravelmask, ma.filled(C[0:Ny - 1, 0:Nx - 1]).ravel()) linewidths = (0.25,) if 'linewidth' in kwargs: @@ -5450,9 +5503,12 @@ def pcolor(self, *args, **kwargs): collection.set_alpha(alpha) collection.set_array(C) - if norm is not None and not isinstance(norm, mcolors.Normalize): - msg = "'norm' must be an instance of 'mcolors.Normalize'" + + if norm is not None and not isinstance(norm, mcolors.Norms): + msg = ("'norm' must be an instance of 'mcolors.Normalize' or " + "'mcolors.BivariateNorm'") raise ValueError(msg) + collection.set_cmap(cmap) collection.set_norm(norm) collection.set_clim(vmin, vmax) @@ -5518,11 +5574,13 @@ def pcolormesh(self, *args, **kwargs): Keyword arguments: *cmap*: [ *None* | Colormap ] - A :class:`matplotlib.colors.Colormap` instance. If *None*, use - rc settings. + A :class:`matplotlib.colors.Colormap` or + :class:`matplotlib.colors.BivariateColormap` instance. If *None*, + use rc settings. *norm*: [ *None* | Normalize ] - A :class:`matplotlib.colors.Normalize` instance is used to + A :class:`matplotlib.colors.Normalize` or + :class:`matplotlib.colors.BivariateNorm` instance is used to scale luminance data to 0,1. If *None*, defaults to :func:`normalize`. @@ -5582,16 +5640,31 @@ def pcolormesh(self, *args, **kwargs): allmatch = (shading == 'gouraud') - X, Y, C = self._pcolorargs('pcolormesh', *args, allmatch=allmatch) + kw = {'norm': norm, 'cmap': cmap, 'allmatch': allmatch} + X, Y, C = self._pcolorargs('pcolormesh', *args, **kw) Ny, Nx = X.shape + is_bivari = (isinstance(norm, mcolors.BivariateNorm) or + isinstance(cmap, mcolors.BivariateColormap)) + if is_bivari: + if C.ndim != 3 and C.shape[0] != 2: + raise TypeError("Expected shape like (2, n, m)") + if cmap is None: + cmap = mcolors.BivariateColormap() + if norm is None: + norm = mcolors.BivariateNorm() + # unit conversion allows e.g. datetime objects as axis values self._process_unit_info(xdata=X, ydata=Y, kwargs=kwargs) X = self.convert_xunits(X) Y = self.convert_yunits(Y) - # convert to one dimensional arrays - C = C.ravel() + # convert to one dimensional arrays if univariate + if isinstance(norm, mcolors.BivariateNorm): + C = np.asarray([c.ravel() for c in C]) + else: + C = C.ravel() + coords = np.column_stack((X.flat, Y.flat)).astype(float, copy=False) collection = mcoll.QuadMesh(Nx - 1, Ny - 1, coords, @@ -5599,8 +5672,9 @@ def pcolormesh(self, *args, **kwargs): **kwargs) collection.set_alpha(alpha) collection.set_array(C) - if norm is not None and not isinstance(norm, mcolors.Normalize): - msg = "'norm' must be an instance of 'mcolors.Normalize'" + if norm is not None and not isinstance(norm, mcolors.Norms): + msg = ("'norm' must be an instance of 'mcolors.Normalize' or " + "'mcolors.BivariateNorm'") raise ValueError(msg) collection.set_cmap(cmap) collection.set_norm(norm) @@ -5634,7 +5708,7 @@ def pcolormesh(self, *args, **kwargs): @docstring.dedent_interpd def pcolorfast(self, *args, **kwargs): """ - pseudocolor plot of a 2-D array + pseudocolor plot of a 2-D univariate or 3-D bivariate array Experimental; this is a pcolor-type method that provides the fastest possible rendering with the Agg @@ -5693,11 +5767,13 @@ def pcolorfast(self, *args, **kwargs): Optional keyword arguments: *cmap*: [ *None* | Colormap ] - A :class:`matplotlib.colors.Colormap` instance from cm. If *None*, - use rc settings. + A :class:`matplotlib.colors.Colormap` or + :class:`matplotlib.colors.BivariateColormap` instance from cm. + If *None*, use rc settings. *norm*: [ *None* | Normalize ] - A :class:`matplotlib.colors.Normalize` instance is used to scale + A :class:`matplotlib.colors.Normalize` or + :class:`matplotlib.colors.BivariateNorm` instance is used to scale luminance data to 0,1. If *None*, defaults to normalize() *vmin*/*vmax*: [ *None* | scalar ] @@ -5723,12 +5799,26 @@ def pcolorfast(self, *args, **kwargs): cmap = kwargs.pop('cmap', None) vmin = kwargs.pop('vmin', None) vmax = kwargs.pop('vmax', None) - if norm is not None and not isinstance(norm, mcolors.Normalize): - msg = "'norm' must be an instance of 'mcolors.Normalize'" + + if norm is not None and not isinstance(norm, mcolors.Norms): + msg = ("'norm' must be an instance of 'mcolors.Normalize' or " + "'mcolors.BivariateNorm'") raise ValueError(msg) - C = args[-1] - nr, nc = C.shape + C = np.asarray(args[-1]) + + is_bivari = (isinstance(norm, mcolors.BivariateNorm) or + isinstance(cmap, mcolors.BivariateColormap)) + if is_bivari: + if C.ndim != 3 and C.shape[0] != 2: + raise TypeError("Expected shape like (2, n, m)") + if cmap is None: + cmap = mcolors.BivariateColormap() + if norm is None: + norm = mcolors.BivariateNorm() + nr, nc = C.shape[1:] + else: + nr, nc = C.shape if len(args) == 1: style = "image" x = [0, nc] diff --git a/lib/matplotlib/axes/_subplots.py b/lib/matplotlib/axes/_subplots.py index 90d55d21cc4c..3d806106d4c0 100644 --- a/lib/matplotlib/axes/_subplots.py +++ b/lib/matplotlib/axes/_subplots.py @@ -177,7 +177,6 @@ def subplot_class_factory(axes_class=None): (SubplotBase, axes_class), {'_axes_class': axes_class}) _subplot_classes[axes_class] = new_class - return new_class # This is provided for backward compatibility diff --git a/lib/matplotlib/cm.py b/lib/matplotlib/cm.py index bdf3e1575653..0ff3b56d9a45 100644 --- a/lib/matplotlib/cm.py +++ b/lib/matplotlib/cm.py @@ -238,7 +238,7 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True): """ # First check for special case, image input: try: - if x.ndim == 3: + if x.ndim == 3 and (x.shape[-1] == 3 or x.shape[-1] == 4): if x.shape[2] == 3: if alpha is None: alpha = 1 diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 3acbaeceefbe..8c0984f162d8 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -731,7 +731,8 @@ def update_scalarmappable(self): """ if self._A is None: return - if self._A.ndim > 1: + if (self._A.ndim > 1 and + not isinstance(self.norm, mcolors.BivariateNorm)): raise ValueError('Collections can only map rank 1 arrays') if not self.check_update("array"): return diff --git a/lib/matplotlib/colorbar.py b/lib/matplotlib/colorbar.py index 0fb4e3b47c17..7e4d1c9ff212 100644 --- a/lib/matplotlib/colorbar.py +++ b/lib/matplotlib/colorbar.py @@ -900,6 +900,406 @@ def remove(self): fig.delaxes(self.ax) +class ColorsquareBase(cm.ScalarMappable): + + n_rasterize = 50 # rasterize solids if number of colors >= n_rasterize + + def __init__(self, ax, cmap=None, + norm=None, + alpha=None, + xvalues=None, + yvalues=None, + xboundaries=None, + yboundaries=None, + xticks=None, + yticks=None, + xformat=None, + yformat=None, + drawedges=False, + filled=True, + xlabel='', + ylabel='', + ): + #: The axes that this colorbar lives in. + self.ax = ax + self._patch_ax() + if cmap is None: + cmap = colors.BivariateColormap() + if norm is None: + norm = colors.BivariateNorm() + self.alpha = alpha + cm.ScalarMappable.__init__(self, cmap=cmap, norm=norm) + self.xvalues = xvalues + self.yvalues = yvalues + self.xboundaries = xboundaries + self.yboundaries = yboundaries + self._inside = slice(0, None) + self.drawedges = drawedges + self.filled = filled + self.solids = None + self.lines = list() + self.dividers = None + self.set_label(xlabel, ylabel) + + if cbook.iterable(xticks): + self.xlocator = ticker.FixedLocator(xticks, nbins=len(xticks)) + else: + self.xlocator = xticks + + if cbook.iterable(yticks): + self.ylocator = ticker.FixedLocator(yticks, nbins=len(yticks)) + else: + self.ylocator = yticks + + if xformat is None: + if isinstance(self.norm.norm1, colors.LogNorm): + self.xformatter = ticker.LogFormatterSciNotation() + elif isinstance(self.norm.norm1, colors.SymLogNorm): + self.xformatter = ticker.LogFormatterSciNotation( + linthresh=self.norm.norm1.linthresh) + else: + self.xformatter = ticker.ScalarFormatter() + elif isinstance(xformat, six.string_types): + self.xformatter = ticker.FormatStrFormatter(xformat) + else: + self.xformatter = xformat # Assume it is a Formatter + + if yformat is None: + if isinstance(self.norm.norm2, colors.LogNorm): + self.yformatter = ticker.LogFormatterSciNotation() + elif isinstance(self.norm.norm2, colors.SymLogNorm): + self.yformatter = ticker.LogFormatterSciNotation( + linthresh=self.norm.norm2.linthresh) + else: + self.yformatter = ticker.ScalarFormatter() + elif isinstance(yformat, six.string_types): + self.yformatter = ticker.FormatStrFormatter(yformat) + else: + self.yformatter = yformat # Assume it is a Formatter + + # The rest is in a method so we can recalculate when clim changes. + self.config_axis() + self.draw_all() + + def _patch_ax(self): + # bind some methods to the axes to warn users + # against using those methods. + self.ax.set_xticks = _set_ticks_on_axis_warn + self.ax.set_yticks = _set_ticks_on_axis_warn + + def draw_all(self): + normx = self.norm.norm1 + normy = self.norm.norm2 + self._xvalues, self._xboundaries = self._process_values(norm=normx) + self._yvalues, self._yboundaries = self._process_values(norm=normy) + X, Y = self._mesh() + CX, CY = np.meshgrid(self._xvalues, self._yvalues) + self.update_ticks() + if self.filled: + self._add_solids(X, Y, [CX, CY]) + + def config_axis(self): + ax = self.ax + ax.set_navigate(False) + + ax.yaxis.set_label_position('right') + ax.yaxis.set_ticks_position('right') + + ax.xaxis.set_label_position('bottom') + ax.xaxis.set_ticks_position('bottom') + + self._set_label() + + def update_ticks(self): + """ + Force the update of the ticks and ticklabels. This must be + called whenever the tick locator and/or tick formatter changes. + """ + def _make_ticker(norm): + """ + Return the sequence of ticks (colorbar data locations), + ticklabels (strings), and the corresponding offset string. + """ + if norm is self.norm.norm1: + _values = self._xvalues + _boundaries = self._xboundaries + boundaries = self.xboundaries + locator = self.xlocator + formatter = self.xformatter + else: + _values = self._yvalues + _boundaries = self._yboundaries + boundaries = self.yboundaries + locator = self.ylocator + formatter = self.yformatter + + if locator is None: + if boundaries is None: + if isinstance(norm, colors.NoNorm): + nv = len(_values) + base = 1 + int(nv / 10) + locator = ticker.IndexLocator(base=base, offset=0) + elif isinstance(norm, colors.BoundaryNorm): + b = norm.boundaries + locator = ticker.FixedLocator(b, nbins=10) + elif isinstance(norm, colors.LogNorm): + locator = ticker.LogLocator(subs='all') + elif isinstance(norm, colors.SymLogNorm): + # The subs setting here should be replaced + # by logic in the locator. + locator = ticker.SymmetricalLogLocator( + subs=np.arange(1, 10), + linthresh=norm.linthresh, + base=10) + else: + # locator = ticker.AutoLocator() + locator = ticker.MaxNLocator(nbins=5) + else: + b = _boundaries[self._inside] + locator = ticker.FixedLocator(b, nbins=10) + if isinstance(norm, colors.NoNorm) and boundaries is None: + intv = _values[0], _values[-1] + else: + b = _boundaries[self._inside] + intv = b[0], b[-1] + locator.create_dummy_axis(minpos=intv[0]) + formatter.create_dummy_axis(minpos=intv[0]) + locator.set_view_interval(*intv) + locator.set_data_interval(*intv) + formatter.set_view_interval(*intv) + formatter.set_data_interval(*intv) + + b = np.array(locator()) + if isinstance(locator, ticker.LogLocator): + eps = 1e-10 + b = b[(b <= intv[1] * (1 + eps)) & (b >= intv[0] * (1 - eps))] + else: + eps = (intv[1] - intv[0]) * 1e-10 + b = b[(b <= intv[1] + eps) & (b >= intv[0] - eps)] + # self._tick_data_values = b + ticks = self._locate(b, norm) + formatter.set_locs(b) + ticklabels = [formatter(t, i) for i, t in enumerate(b)] + offset_string = formatter.get_offset() + return ticks, ticklabels, offset_string + + ax = self.ax + xticks, xticklabels, xoffset_string = _make_ticker(self.norm.norm1) + yticks, yticklabels, yoffset_string = _make_ticker(self.norm.norm2) + + ax.xaxis.set_ticks(xticks) + ax.set_xticklabels(xticklabels) + ax.xaxis.get_major_formatter().set_offset_string(xoffset_string) + + ax.yaxis.set_ticks(yticks) + ax.set_yticklabels(yticklabels) + ax.yaxis.get_major_formatter().set_offset_string(yoffset_string) + + def set_ticks(self, xticks, yticks, update_ticks=True): + if cbook.iterable(xticks): + self.xlocator = ticker.FixedLocator(xticks, nbins=len(xticks)) + else: + self.xlocator = xticks + + if cbook.iterable(yticks): + self.ylocator = ticker.FixedLocator(yticks, nbins=len(yticks)) + else: + self.ylocator = yticks + + if update_ticks: + self.update_ticks() + self.stale = True + + def set_ticklabels(self, xticklabels=None, yticklabels=None, + update_ticks=True): + """ + set tick labels. Tick labels are updated immediately unless + update_ticks is *False*. To manually update the ticks, call + *update_ticks* method explicitly. + """ + if xticklabels is not None or yticklabels is not None: + if isinstance(self.xlocator, ticker.FixedLocator): + self.xformatter = ticker.FixedFormatter(xticklabels) + + if isinstance(self.ylocator, ticker.FixedLocator): + self.yformatter = ticker.FixedFormatter(yticklabels) + + if update_ticks: + self.update_ticks() + + self.stale = True + return + + warnings.warn("set_ticks() must have been called.") + self.stale = True + + def _set_label(self): + self.ax.set_ylabel(self._ylabel, **self._labelkw) + self.ax.set_xlabel(self._xlabel, **self._labelkw) + self.stale = True + + def set_label(self, xlabel, ylabel, **kw): + """ + Label the axes of the colorbar + """ + self._xlabel = '%s' % (xlabel, ) + self._ylabel = '%s' % (ylabel, ) + self._labelkw = kw + self._set_label() + + def _edges(self, X, Y): + ''' + Return the separator line segments; helper for _add_solids. + ''' + N = X.shape[0] + return ([list(zip(X[i], Y[i])) for i in xrange(1, N - 1)] + + [list(zip(Y[i], X[i])) for i in xrange(1, N - 1)]) + + def _add_solids(self, X, Y, C): + """ + Draw the colors using :meth:`~matplotlib.axes.Axes.pcolormesh`; + optionally add separators. + """ + args = (X, Y, C) + kw = dict(cmap=self.cmap, + norm=self.norm, + alpha=self.alpha, + edgecolors='None') + # Save, set, and restore hold state to keep pcolor from + # clearing the axes. Ordinarily this will not be needed, + # since the axes object should already have hold set. + _hold = self.ax._hold + self.ax._hold = True + col = self.ax.pcolormesh(*args, **kw) + self.ax._hold = _hold + + if self.solids is not None: + self.solids.remove() + self.solids = col + if self.dividers is not None: + self.dividers.remove() + self.dividers = None + if self.drawedges: + linewidths = (0.5 * mpl.rcParams['axes.linewidth'],) + self.dividers = collections.LineCollection(self._edges(X, Y), + colors=(mpl.rcParams['axes.edgecolor'],), + linewidths=linewidths) + self.ax.add_collection(self.dividers) + elif(len(self._y) >= self.n_rasterize + or len(self._x) >= self.n_rasterize): + self.solids.set_rasterized(True) + + def _process_values(self, b=None, norm=None): + if norm is self.norm.norm1: + boundaries = self.xboundaries + values = self.xvalues + else: + boundaries = self.yboundaries + values = self.yvalues + if b is None: + b = boundaries + if b is not None: + b = np.asarray(b, dtype=float) + if values is None: + v = 0.5 * (b[:-1] + b[1:]) + if isinstance(norm, colors.NoNorm): + v = (v + 0.00001).astype(np.int16) + return v, b + v = np.array(self.values) + return v, b + if values is not None: + v = np.array(values) + if boundaries is None: + b = np.zeros(len(values) + 1, 'd') + b[1:-1] = 0.5 * (v[:-1] - v[1:]) + b[0] = 2.0 * b[1] - b[2] + b[-1] = 2.0 * b[-2] - b[-3] + return v, b + b = np.array(boundaries) + return v, b + # Neither boundaries nor values are specified; + # make reasonable ones based on cmap and norm. + N = self.cmap.N + if isinstance(norm, colors.NoNorm): + b = np.linspace(0, 1, np.sqrt(N) + 1) * np.sqrt(N) - 0.5 + v = np.zeros((len(b) - 1,), dtype=np.int16) + v[self._inside] = np.arange(np.sqrt(N), dtype=np.int16) + return v, b + elif isinstance(norm, colors.BoundaryNorm): + b = list(norm.boundaries) + b = np.array(b) + v = np.zeros((len(b) - 1,), dtype=float) + bi = norm.boundaries + v[self._inside] = 0.5 * (bi[:-1] + bi[1:]) + return v, b + else: + if not norm.scaled(): + norm.vmin = 0 + norm.norm1.vmax = 1 + + norm.vmin, norm.vmax = mtransforms.nonsingular( + norm.vmin, norm.vmax, expander=0.1) + + b = norm.inverse(np.linspace(0, 1, np.sqrt(N) + 1)) + + return self._process_values(b=b, norm=norm) + + def _mesh(self): + """ + Return X,Y, the coordinate arrays for the colorbar pcolormesh. + """ + x = np.linspace(0, 1, len(self._xboundaries)) + y = np.linspace(0, 1, len(self._yboundaries)) + self._x = x + self._y = y + X, Y = np.meshgrid(x, y) + return X, Y + + def _locate(self, x, norm): + """ + Given a set of color data values, return their + corresponding colorbar data coordinates. + """ + if norm is self.norm.norm1: + boundaries = self._xboundaries + else: + boundaries = self._yboundaries + if isinstance(norm, (colors.NoNorm, colors.BoundaryNorm)): + b = boundaries + xn = x + else: + # Do calculations using normalized coordinates so + # as to make the interpolation more accurate. + b = norm(boundaries, clip=False).filled() + xn = norm(x, clip=False).filled() + + # The rest is linear interpolation with extrapolation at ends. + ii = np.searchsorted(b, xn) + i0 = ii - 1 + itop = (ii == len(b)) + ibot = (ii == 0) + i0[itop] -= 1 + ii[itop] -= 1 + i0[ibot] += 1 + ii[ibot] += 1 + + db = np.take(b, ii) - np.take(b, i0) + y = self._y + dy = np.take(y, ii) - np.take(y, i0) + z = np.take(y, i0) + (xn - np.take(b, i0)) * dy / db + return z + + def set_alpha(self, alpha): + self.alpha = alpha + + def remove(self): + """ + Remove this colorsquare from the figure + """ + fig = self.ax.figure + fig.delaxes(self.ax) + + class Colorbar(ColorbarBase): """ This class connects a :class:`ColorbarBase` to a @@ -1055,6 +1455,81 @@ def remove(self): ax.set_subplotspec(subplotspec) +class Colorsquare(ColorsquareBase): + """ + This class connects a :class:`Colorbarsquare` to a + :class:`~matplotlib.cm.ScalarMappable` such as a + :class:`~matplotlib.image.AxesImage` generated via + :meth:`~matplotlib.axes.Axes.imshow`. + + It is not intended to be instantiated directly; instead, + use :meth:`~matplotlib.figure.Figure.colorbar` or + :func:`~matplotlib.pyplot.colorbar` to make your colorsquare. + + """ + def __init__(self, ax, mappable, **kw): + # Ensure the given mappable's norm has appropriate vmin and vmax set + # even if mappable.draw has not yet been called. + mappable.autoscale_None() + + self.mappable = mappable + kw['cmap'] = cmap = mappable.cmap + kw['norm'] = norm = mappable.norm + + if isinstance(mappable, martist.Artist): + kw['alpha'] = mappable.get_alpha() + + ColorsquareBase.__init__(self, ax, **kw) + + def on_mappable_changed(self, mappable): + """ + Updates this colorsquare to match the mappable's properties. + + Typically this is automatically registered as an event handler + by :func:`colorbar_factory` and should not be called manually. + + """ + self.set_cmap(mappable.get_cmap()) + self.set_clim(mappable.get_clim()) + self.update_normal(mappable) + + def update_normal(self, mappable): + ''' + update solid, lines, etc. Unlike update_bruteforce, it does + not clear the axes. This is meant to be called when the image + or contour plot to which this colorsquare belongs is changed. + ''' + self.draw_all() + self.stale = True + + def remove(self): + """ + Remove this colorsquare from the figure. If the colorsquare was + created with ``use_gridspec=True`` then restore the gridspec to its + previous value. + """ + Colorbarsquare.remove(self) + self.mappable.callbacksSM.disconnect(self.mappable.colorbar_cid) + self.mappable.colorbar = None + self.mappable.colorbar_cid = None + + try: + ax = self.mappable.axes + except AttributeError: + return + + try: + gs = ax.get_subplotspec().get_gridspec() + subplotspec = gs.get_topmost_subplotspec() + except AttributeError: + # use_gridspec was False + pos = ax.get_position(original=True) + ax.set_position(pos) + else: + # use_gridspec was True + ax.set_subplotspec(subplotspec) + + @docstring.Substitution(make_axes_kw_doc) def make_axes(parents, location=None, orientation=None, fraction=0.15, shrink=1.0, aspect=20, **kw): @@ -1361,6 +1836,10 @@ def colorbar_factory(cax, mappable, **kwargs): if (isinstance(mappable, contour.ContourSet) and any([hatch is not None for hatch in mappable.hatches])): cb = ColorbarPatch(cax, mappable, **kwargs) + elif (isinstance(mappable.norm, colors.BivariateNorm)): + kwargs.pop('orientation', None) + kwargs.pop('ticklocation', None) + cb = Colorsquare(cax, mappable, **kwargs) else: cb = Colorbar(cax, mappable, **kwargs) diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index 45fc27a0b353..59cc9aac240d 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -68,6 +68,7 @@ import numpy as np import matplotlib.cbook as cbook from ._color_data import BASE_COLORS, TABLEAU_COLORS, CSS4_COLORS, XKCD_COLORS +from abc import ABCMeta class _ColorMapping(dict): @@ -472,11 +473,18 @@ def __call__(self, X, alpha=None, bytes=False): xa = np.array([X]) else: vtype = 'array' + if isinstance(self, BivariateColormap): + vals = np.array([1, 0], dtype=X.dtype) + almost_one = np.nextafter(*vals) + np.copyto(X, almost_one, where=X == 1.0) + X[0] = X[0] * 256 + X[1] = X[1] * 256 + X = X.astype(int) + X = X[0] + X[1] * 256 xma = np.ma.array(X, copy=True) # Copy here to avoid side effects. mask_bad = xma.mask # Mask will be used below. xa = xma.filled() # Fill to avoid infs, etc. del xma - # Calculations with native byteorder are faster, and avoid a # bug that otherwise can occur with putmask when the last # argument is a numpy scalar. @@ -858,7 +866,42 @@ def reversed(self, name=None): return ListedColormap(colors_r, name=name, N=self.N) -class Normalize(object): +class BivariateColormap(Colormap): + def __init__(self, name='bivariate', N=256): + Colormap.__init__(self, name, N) + self.N = self.N * self.N + + def _init(self): + red = np.linspace(0, 1, np.sqrt(self.N)) + green = np.linspace(0, 1, np.sqrt(self.N)) + red_mesh, green_mesh = np.meshgrid(red, green) + blue_mesh = np.zeros_like(red_mesh) + alpha_mesh = np.ones_like(red_mesh) + bivariate_cmap = np.dstack((red_mesh, green_mesh, blue_mesh, + alpha_mesh)) + self._lut = np.vstack(bivariate_cmap) + self._isinit = True + self._set_extremes() + + def _resample(self, lutsize): + """ + Return a new color map with *lutsize x lutsize* entries. + """ + return BivariateColormap(self.name, lutsize) + + def reversed(self, name=None): + raise NotImplementedError + + +@six.add_metaclass(ABCMeta) +class Norms: + """ + Abstract Base Class to group `Normalize` and `BivariateNorm` + """ + pass + + +class Normalize(Norms): """ A class which, when called, can normalize data into the ``[0.0, 1.0]`` interval. @@ -1350,6 +1393,67 @@ def inverse(self, value): return value +class BivariateNorm(Norms): + """ + Normalize a list of two values corresponding to two 1D normalizers + """ + def __init__(self, norm1=None, norm2=None): + """ + Parameters + ---------- + norm1 : + An instance of 1D normalizers + norm2 : + An instance of 1D normalizers + """ + if norm1 is None: + self.norm1 = Normalize() + else: + self.norm1 = norm1 + if norm2 is None: + self.norm2 = Normalize() + else: + self.norm2 = norm2 + + def __call__(self, values, clip=None): + """ + Parameters + ---------- + values : array-like + A list of two values to be normalized + clip : list of bools, None, optional + A list of two bools corresponding to value in values. + If clip is None then clip is set according to corresponding + normalizers. + + Returns + ------- + A list of two normalized values according to corresponding 1D + normalizers. + """ + if clip is None: + clip = [self.norm1.clip, self.norm2.clip] + + return np.asarray([self.norm1(values[0], clip=clip[0]), + self.norm2(values[1], clip=clip[1])]) + + def autoscale(self, A): + """ + Set *vmin*, *vmax* to min, max of *A*. + """ + self.norm1.autoscale(A[0]) + self.norm2.autoscale(A[1]) + + def autoscale_None(self, A): + 'autoscale only None-valued vmin or vmax' + self.norm1.autoscale_None(A[0]) + self.norm2.autoscale_None(A[1]) + + def scaled(self): + 'return true if vmin and vmax set for both normalizers' + return self.norm1.scaled() and self.norm2.scaled() + + def rgb_to_hsv(arr): """ convert float rgb values (in the range [0, 1]), in a numpy array to hsv diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index d47fa4d2d935..8cc95dc78569 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -35,6 +35,7 @@ from matplotlib.image import FigureImage import matplotlib.colorbar as cbar +import matplotlib.colors as mcolors from matplotlib.axes import Axes, SubplotBase, subplot_class_factory from matplotlib.blocking_input import BlockingMouseInput, BlockingKeyMouseInput @@ -1834,6 +1835,9 @@ def colorbar(self, mappable, cax=None, ax=None, use_gridspec=True, **kw): # Store the value of gca so that we can set it back later on. current_ax = self.gca() + if isinstance(mappable.norm, mcolors.BivariateNorm): + kw['fraction'] = 0.30 + kw['aspect'] = 1 if cax is None: if use_gridspec and isinstance(ax, SubplotBase): cax, kw = cbar.make_axes_gridspec(ax, **kw) diff --git a/lib/matplotlib/image.py b/lib/matplotlib/image.py index d34f98b4e45a..db57edda9407 100644 --- a/lib/matplotlib/image.py +++ b/lib/matplotlib/image.py @@ -253,7 +253,11 @@ def get_size(self): if self._A is None: raise RuntimeError('You must first set the image array') - return self._A.shape[:2] + if isinstance(self.norm, mcolors.BivariateNorm): + imshape = self._A.shape[1:] + else: + imshape = self._A.shape[:2] + return imshape def set_alpha(self, alpha): """ @@ -300,6 +304,12 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0, `trans` is the affine transformation from the image to pixel space. """ + if isinstance(self.norm, mcolors.BivariateNorm): + imwidth = A.shape[1] + imheight = A.shape[2] + else: + imwidth = A.shape[0] + imheight = A.shape[1] if A is None: raise RuntimeError('You must first set the image ' 'array or the image attribute') @@ -323,15 +333,15 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0, # Flip the input image using a transform. This avoids the # problem with flipping the array, which results in a copy # when it is converted to contiguous in the C wrapper - t0 = Affine2D().translate(0, -A.shape[0]).scale(1, -1) + t0 = Affine2D().translate(0, -imwidth).scale(1, -1) else: t0 = IdentityTransform() t0 += ( Affine2D() .scale( - in_bbox.width / A.shape[1], - in_bbox.height / A.shape[0]) + in_bbox.width / imheight, + in_bbox.height / imwidth) .translate(in_bbox.x0, in_bbox.y0) + self.get_transform()) @@ -362,16 +372,18 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0, if A.ndim not in (2, 3): raise ValueError("Invalid dimensions, got {}".format(A.shape)) - if A.ndim == 2: + if A.ndim == 2 or (A.ndim == 3 and + isinstance(self.norm, mcolors.BivariateNorm)): A = self.norm(A) - if A.dtype.kind == 'f': + if (A.dtype.kind == 'f' and + not isinstance(self.norm, mcolors.BivariateNorm)): # If the image is greyscale, convert to RGBA and # use the extra channels for resizing the over, # under, and bad pixels. This is needed because # Agg's resampler is very aggressive about # clipping to [0, 1] and we use out-of-bounds # values to carry the over/under/bad information - rgba = np.empty((A.shape[0], A.shape[1], 4), dtype=A.dtype) + rgba = np.empty((imwidth, imheight, 4), dtype=A.dtype) rgba[..., 0] = A # normalized data # this is to work around spurious warnings coming # out of masked arrays. @@ -410,9 +422,10 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0, if not created_rgba_mask: # Always convert to RGBA, even if only RGB input + isBivari = (A.ndim == 3 and A.shape[0] == 2) if A.shape[2] == 3: A = _rgb_to_rgba(A) - elif A.shape[2] != 4: + elif A.shape[2] != 4 and not isBivari: raise ValueError("Invalid dimensions, got %s" % (A.shape,)) output = np.zeros((out_height, out_width, 4), dtype=A.dtype) @@ -595,8 +608,9 @@ def set_data(self, A): not np.can_cast(self._A.dtype, float, "same_kind")): raise TypeError("Image data cannot be converted to float") - if not (self._A.ndim == 2 - or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]): + isRGB = (self._A.ndim == 3 and self._A.shape[-1] in [3, 4]) + isBivari = (self._A.ndim == 3 and self._A.shape[0] == 2) + if not (self._A.ndim == 2 or isRGB or isBivari): raise TypeError("Invalid dimensions for image data") self._imcache = None diff --git a/lib/matplotlib/mpl-data/sample_data/air_temperature.npy b/lib/matplotlib/mpl-data/sample_data/air_temperature.npy new file mode 100644 index 000000000000..cfcbf84f9b6a Binary files /dev/null and b/lib/matplotlib/mpl-data/sample_data/air_temperature.npy differ diff --git a/lib/matplotlib/mpl-data/sample_data/surface_pressure.npy b/lib/matplotlib/mpl-data/sample_data/surface_pressure.npy new file mode 100644 index 000000000000..19c6cedd5b47 Binary files /dev/null and b/lib/matplotlib/mpl-data/sample_data/surface_pressure.npy differ diff --git a/lib/matplotlib/tests/baseline_images/test_axes/bivar_imshow.png b/lib/matplotlib/tests/baseline_images/test_axes/bivar_imshow.png new file mode 100644 index 000000000000..757a1929ec66 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_axes/bivar_imshow.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_axes/bivar_pcolor.png b/lib/matplotlib/tests/baseline_images/test_axes/bivar_pcolor.png new file mode 100644 index 000000000000..051a32c265d4 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_axes/bivar_pcolor.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_axes/bivar_pcolorfast.png b/lib/matplotlib/tests/baseline_images/test_axes/bivar_pcolorfast.png new file mode 100644 index 000000000000..b060fecd5d5f Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_axes/bivar_pcolorfast.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_axes/bivar_pcolormesh.png b/lib/matplotlib/tests/baseline_images/test_axes/bivar_pcolormesh.png new file mode 100644 index 000000000000..051a32c265d4 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_axes/bivar_pcolormesh.png differ diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index f7e5ffe87cb2..294c6e4ae0cb 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -28,6 +28,7 @@ from numpy.testing import assert_allclose, assert_array_equal from matplotlib.cbook import IgnoredKeywordWarning from matplotlib.cbook._backports import broadcast_to +from matplotlib.cbook import get_sample_data # Note: Some test cases are run twice: once normally and once with labeled data # These two must be defined in the same test function or need to have @@ -5271,3 +5272,30 @@ def test_twinx_knows_limits(): def test_zero_linewidth(): # Check that setting a zero linewidth doesn't error plt.plot([0, 1], [0, 1], ls='--', lw=0) + + +@image_comparison( + baseline_images=['bivar_imshow', 'bivar_pcolor', 'bivar_pcolormesh', + 'bivar_pcolorfast'], + extensions=['png'] +) +def test_bivariates(): + air_temp = np.load(get_sample_data('air_temperature.npy')) + surf_pres = np.load(get_sample_data('surface_pressure.npy')) + bivariate = [air_temp, surf_pres] + + fig1, ax1 = plt.subplots() + cax1 = ax1.imshow(bivariate, norm=mcolors.BivariateNorm()) + cbar = fig1.colorbar(cax1, xlabel='air_temp', ylabel='surf_pres') + + fig2, ax2 = plt.subplots() + cax2 = ax2.pcolor(bivariate, norm=mcolors.BivariateNorm()) + cbar = fig2.colorbar(cax2, xlabel='air_temp', ylabel='surf_pres') + + fig3, ax3 = plt.subplots() + cax3 = ax3.pcolormesh(bivariate, norm=mcolors.BivariateNorm()) + cbar = fig3.colorbar(cax3, xlabel='air_temp', ylabel='surf_pres') + + fig4, ax4 = plt.subplots() + cax4 = ax4.pcolorfast(bivariate, norm=mcolors.BivariateNorm()) + cbar = fig4.colorbar(cax4, xlabel='air_temp', ylabel='surf_pres') diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py index 721813e62f8f..d18ab41ecf19 100644 --- a/lib/matplotlib/tests/test_colors.py +++ b/lib/matplotlib/tests/test_colors.py @@ -707,3 +707,15 @@ def __add__(self, other): mcolors.SymLogNorm(3, vmax=5, linscale=1), mcolors.PowerNorm(1)]: assert_array_equal(norm(data.view(MyArray)), norm(data)) + + +@pytest.mark.parametrize('norm', [ + mcolors.Normalize(), mcolors.LogNorm(), mcolors.BivariateNorm() + ] +) +def test_abstract_base_class_norms(norm): + """ + Test that all types of normalizers subclasses Abstract Base class + `colors.Norms` + """ + assert isinstance(norm, mcolors.Norms)