diff --git a/examples/color/colormap_normalizations_funcnorm.py b/examples/color/colormap_normalizations_funcnorm.py new file mode 100644 index 000000000000..dacddbb97c2b --- /dev/null +++ b/examples/color/colormap_normalizations_funcnorm.py @@ -0,0 +1,73 @@ +""" +===================================================================== +Examples of normalization using :class:`~matplotlib.colors.FuncNorm` +===================================================================== + +This is an example on how to perform a normalization using an arbitrary +function with :class:`~matplotlib.colors.FuncNorm`. A logarithm normalization +and a square root normalization will be use as examples. + +""" + +import matplotlib.cm as cm +import matplotlib.colors as colors +import matplotlib.pyplot as plt + +import numpy as np + +norm_log = colors.FuncNorm(f='log10', vmin=0.01) +# The same can be achieved with +# norm_log = colors.FuncNorm(f=np.log10, +# finv=lambda x: 10.**(x), vmin=0.01) + +norm_sqrt = colors.FuncNorm(f='sqrt', vmin=0.0) +# The same can be achieved with +# norm_sqrt = colors.FuncNorm(f='root{2}', vmin=0.) +# or with +# norm_sqrt = colors.FuncNorm(f=lambda x: x**0.5, +# finv=lambda x: x**2, vmin=0.0) + +normalizations = [(None, 'Regular linear scale'), + (norm_log, 'Log normalization'), + (norm_sqrt, 'Root normalization')] + +# Fabricating some data +x = np.linspace(0, 1, 300) +y = np.linspace(-1, 1, 90) +X, Y = np.meshgrid(x, y) + +data = np.zeros(X.shape) + + +def gauss2d(x, y, a0, x0, y0, wx, wy): + return a0 * np.exp(-(x - x0)**2 / wx**2 - (y - y0)**2 / wy**2) + +for x in np.linspace(0., 1, 15): + data += gauss2d(X, Y, x, x, 0, 0.25 / 15, 0.25) + +data -= data.min() +data /= data.max() + +# Using the custom normalizations to plot the data +fig, axes = plt.subplots(3, 2, sharex='col', + gridspec_kw={'width_ratios': [1, 3.5]}, + figsize=plt.figaspect(0.6)) + +for (ax_left, ax_right), (norm, title) in zip(axes, normalizations): + + # Showing the normalization effect on an image + cax = ax_right.imshow(data, cmap=cm.afmhot, norm=norm, aspect='auto') + fig.colorbar(cax, format='%.3g', ax=ax_right) + ax_right.set_title(title) + ax_right.xaxis.set_ticks([]) + ax_right.yaxis.set_ticks([]) + + # Plotting the behaviour of the normalization + d_values = np.linspace(cax.norm.vmin, cax.norm.vmax, 100) + cm_values = cax.norm(d_values) + ax_left.plot(d_values, cm_values) + ax_left.set_ylabel('Colormap values') + +ax_left.set_xlabel('Data values') + +plt.show() diff --git a/lib/matplotlib/colorbar.py b/lib/matplotlib/colorbar.py index df488bba7811..072243846670 100644 --- a/lib/matplotlib/colorbar.py +++ b/lib/matplotlib/colorbar.py @@ -583,6 +583,9 @@ def _ticker(self): locator = ticker.FixedLocator(b, nbins=10) elif isinstance(self.norm, colors.LogNorm): locator = ticker.LogLocator(subs='all') + elif isinstance(self.norm, colors.FuncNorm): + locator = ticker.FuncLocator(self.norm.__call__, + self.norm.inverse) elif isinstance(self.norm, colors.SymLogNorm): # The subs setting here should be replaced # by logic in the locator. diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index 4a58161cdc61..925434c847e7 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -960,6 +960,164 @@ def scaled(self): return (self.vmin is not None and self.vmax is not None) +class FuncNorm(Normalize): + """ + A norm based on a monotonic custom function. + + The norm will use a provided custom function to map the data + values into colormap values in the [0,1] range. It will be calculated + as (f(x)-f(vmin))/(f(vmax)-f(vmin)). + + Parameters + ---------- + f : callable or string + Function to be used for the normalization receiving a single + parameter, compatible with scalar values and arrays. + Alternatively some predefined functions may be specified + as a string (See Notes). The chosen function must be strictly + increasing and bounded in the [`vmin`, `vmax`] interval. + finv : callable, optional + Inverse of `f` satisfying finv(f(x)) == x. Optional and ignored + when `f` is a string; otherwise, required. + vmin, vmax : None or float, optional + Data values to be mapped to 0 and 1. If either is None, it is + assigned the minimum or maximum value of the data supplied to + the first call of the norm. Default None. + clip : bool, optional + If True, clip data values to [`vmin`, `vmax`]. This effectively + defeats the purpose of setting the over and under values of the + color map. If False, values below `vmin` and above `vmax` will + be mapped to -0.1 and 1.1 respectively. Default False. + + Notes + ----- + Valid predefined functions are ['linear', 'quadratic', + 'cubic', 'x**{p}', 'sqrt', 'cbrt', 'root{p}(x)', 'log', 'log10', + 'log2', 'log{p}(x)', 'log(x+{p}) 'log10(x+{p})', 'log{p}(x+{p})] + where 'p' must be replaced by the corresponding value of the + parameter when present. + + Examples + -------- + Creating a logarithmic normalization using the predefined strings: + + >>> import matplotlib.colors as colors + >>> norm = colors.FuncNorm(f='log10', vmin=0.01, vmax=2) + + Or manually: + + >>> import matplotlib.colors as colors + >>> norm = colors.FuncNorm(f=lambda x: np.log10(x), + ... finv=lambda x: 10.**(x), + ... vmin=0.01, vmax=2) + + """ + + def __init__(self, f, finv=None, vmin=None, vmax=None, clip=False): + super(FuncNorm, self).__init__(vmin=vmin, vmax=vmax, clip=clip) + + if isinstance(f, six.string_types): + func_parser = cbook._StringFuncParser(f) + f = func_parser.function + finv = func_parser.inverse + if not callable(f): + raise ValueError("`f` must be a callable or a string.") + if finv is None: + raise ValueError("Inverse function `finv` not provided.") + if not callable(finv): + raise ValueError("`finv` must be a callable.") + + self._f = f + self._finv = finv + + def _update_f(self, vmin, vmax): + # This method is to be used by derived classes in cases where + # the limits vmin and vmax may require changing/updating the + # function depending on vmin/vmax, for example rescaling it + # to accommodate to the new interval. + pass + + def __call__(self, value, clip=None): + """ + Normalizes `value` data in the ``[vmin, vmax]`` interval into + the ``[0.0, 1.0]`` interval and returns it. + + Parameters + ---------- + value : scalar or array-like + Data to be normalized. + clip : boolean, optional + Whether to clip the data outside the ``[`vmin`, `vmax`]`` limits. + Default `self.clip` from `Normalize` (which defaults to `False`). + + Returns + ------- + result : masked array of floats + Normalized data to the ``[0.0, 1.0]`` interval. If `clip` == False, + values smaller than `vmin` or greater than `vmax` will be clipped + to -0.1 and 1.1 respectively. + + """ + if clip is None: + clip = self.clip + + result, is_scalar = self.process_value(value) + self.autoscale_None(result) + + vmin, vmax = self._check_vmin_vmax() + + self._update_f(vmin, vmax) + + if clip: + result = np.clip(result, vmin, vmax) + resultnorm = ((self._f(result) - self._f(vmin)) / + (self._f(vmax) - self._f(vmin))) + else: + resultnorm = result.copy() + mask_over = result > vmax + mask_under = result < vmin + mask = ~(mask_over | mask_under) + # Since the non linear function is arbitrary and may not be + # defined outside the boundaries, we just set obvious under + # and over values + resultnorm[mask_over] = 1.1 + resultnorm[mask_under] = -0.1 + resultnorm[mask] = ((self._f(result[mask]) - self._f(vmin)) / + (self._f(vmax) - self._f(vmin))) + + if is_scalar: + return resultnorm[0] + else: + return resultnorm + + def inverse(self, value): + """ + Performs the inverse normalization from the ``[0.0, 1.0]`` into the + ``[`vmin`, `vmax`]`` interval and returns it. + + Parameters + ---------- + value : float or ndarray of floats + Data in the ``[0.0, 1.0]`` interval. + + Returns + ------- + result : float or ndarray of floats + Data before normalization. + + """ + vmin, vmax = self._check_vmin_vmax() + self._update_f(vmin, vmax) + value = self._finv( + value * (self._f(vmax) - self._f(vmin)) + self._f(vmin)) + return value + + def _check_vmin_vmax(self): + if self.vmin >= self.vmax: + raise ValueError("vmin must be smaller than vmax") + return float(self.vmin), float(self.vmax) + + class LogNorm(Normalize): """ Normalize a given value to the 0-1 range on a log scale diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py index ebfc41aa53e3..78497ec62751 100644 --- a/lib/matplotlib/tests/test_colors.py +++ b/lib/matplotlib/tests/test_colors.py @@ -146,6 +146,63 @@ def test_BoundaryNorm(): assert_true(np.all(bn(vals).mask)) +class TestFuncNorm(object): + def test_limits_with_string(self): + norm = mcolors.FuncNorm(f='log10', vmin=0.01, vmax=2.) + assert_array_equal(norm([0.01, 2]), [0, 1.0]) + + def test_limits_with_lambda(self): + norm = mcolors.FuncNorm(f=lambda x: np.log10(x), + finv=lambda x: 10.**(x), + vmin=0.01, vmax=2.) + assert_array_equal(norm([0.01, 2]), [0, 1.0]) + + def test_limits_without_vmin_vmax(self): + norm = mcolors.FuncNorm(f='log10') + assert_array_equal(norm([0.01, 2]), [0, 1.0]) + + def test_limits_without_vmin(self): + norm = mcolors.FuncNorm(f='log10', vmax=2.) + assert_array_equal(norm([0.01, 2]), [0, 1.0]) + + def test_limits_without_vmax(self): + norm = mcolors.FuncNorm(f='log10', vmin=0.01) + assert_array_equal(norm([0.01, 2]), [0, 1.0]) + + def test_clip_true(self): + norm = mcolors.FuncNorm(f='log10', vmin=0.01, vmax=2., + clip=True) + assert_array_equal(norm([0.0, 2.5]), [0.0, 1.0]) + + def test_clip_false(self): + norm = mcolors.FuncNorm(f='log10', vmin=0.01, vmax=2., + clip=False) + assert_array_equal(norm([0.0, 2.5]), [-0.1, 1.1]) + + def test_clip_default_false(self): + norm = mcolors.FuncNorm(f='log10', vmin=0.01, vmax=2.) + assert_array_equal(norm([0.0, 2.5]), [-0.1, 1.1]) + + def test_intermediate_values(self): + norm = mcolors.FuncNorm(f='log10') + assert_array_almost_equal(norm([0.01, 0.5, 2]), + [0, 0.73835195870437, 1.0]) + + def test_inverse(self): + norm = mcolors.FuncNorm(f='log10', vmin=0.01, vmax=2.) + x = np.linspace(0.01, 2, 10) + assert_array_almost_equal(x, norm.inverse(norm(x))) + + def test_scalar(self): + norm = mcolors.FuncNorm(f='linear', vmin=1., vmax=2., + clip=True) + assert_equal(norm(1.5), 0.5) + assert_equal(norm(1.), 0.) + assert_equal(norm(0.5), 0.) + assert_equal(norm(2.), 1.) + assert_equal(norm(2.5), 1.) + + def test_LogNorm(): """ LogNorm ignored clip, now it has the same diff --git a/lib/matplotlib/tests/test_ticker.py b/lib/matplotlib/tests/test_ticker.py index e390237d543c..790160091eee 100644 --- a/lib/matplotlib/tests/test_ticker.py +++ b/lib/matplotlib/tests/test_ticker.py @@ -1,7 +1,8 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) -from numpy.testing import assert_almost_equal +from numpy.testing import (assert_almost_equal, + assert_array_almost_equal) import numpy as np import pytest @@ -75,6 +76,31 @@ def test_LogLocator(): assert_almost_equal(loc.tick_values(1, 100), test_value) +class TestFuncLocator(object): + def test_call(self): + loc = mticker.FuncLocator(np.sqrt, lambda x: x**2) + expected = [0., 0.01, 0.04, 0.09, 0.16, 0.25, 0.4, + 0.49, 0.6, 0.8, 1.] + assert_array_almost_equal(loc(), expected) + + def test_tick_values(self): + loc = mticker.FuncLocator(np.sqrt, lambda x: x**2) + expected = [0., 0.01, 0.04, 0.09, 0.16, 0.25, 0.4, + 0.49, 0.6, 0.8, 1.] + assert_array_almost_equal(loc.tick_values(), expected) + + def test_set_params(self): + loc = mticker.FuncLocator(lambda x: x, lambda x: x, 6) + expected = [0., 0.2, 0.4, 0.6, 0.8, 1.] + assert_array_almost_equal(loc.tick_values(), expected) + loc.set_params(function=np.sqrt, + inverse=lambda x: x**2, + numticks=11) + expected = [0., 0.01, 0.04, 0.09, 0.16, 0.25, 0.4, + 0.49, 0.6, 0.8, 1.] + assert_array_almost_equal(loc.tick_values(), expected) + + def test_LinearLocator_set_params(): """ Create linear locator with presets={}, numticks=2 and change it to diff --git a/lib/matplotlib/ticker.py b/lib/matplotlib/ticker.py index afea9f748e05..ee6e5e55dbcf 100644 --- a/lib/matplotlib/ticker.py +++ b/lib/matplotlib/ticker.py @@ -1973,6 +1973,93 @@ def is_close_to_int(x): return abs(x - nearest_long(x)) < 1e-10 +class FuncLocator(Locator): + """ + Determines the tick locations for using user provided functions. + + It attempts to provide a fixed number `numticks` of tick locations + relatively uniformly spread across the axis, while rounding the + values as much as possible. + + Parameters + ---------- + function : callable + Transformation of the axis using the ticks. + inverse : callable + Inverse transformation of `function`. + numticks : integer, optional + Number of ticks to include. Default 11. + + """ + def __init__(self, function, inverse, numticks=None): + self._numticks = numticks + self._function = function + self._inverse = inverse + + def tick_values(self, vmin=None, vmax=None): + """ + Returns the tick locations + + Parameters + ---------- + vmin, vmax : integer, optional + Maximum and minimum values. Not used. + + Returns + ------- + ticks : ndarray + 1d array of length `numticks` with the proposed tick locations. + + """ + + if self._numticks is None: + self._set_numticks() + + ticks = self._inverse(np.linspace(0, 1, self._numticks)) + finalticks = np.zeros(ticks.shape, dtype=np.bool) + finalticks[0] = True + finalticks[-1] = True + ticks = FuncLocator._round_ticks(ticks, finalticks) + return ticks + + def _set_numticks(self): + self._numticks = 11 + + def set_params(self, function=None, inverse=None, numticks=None): + """Set parameters within this locator.""" + if inverse is not None: + self._inverse = inverse + if function is not None: + self._function = function + if numticks is not None: + self._numticks = numticks + + def __call__(self): + """ + Returns the tick locations + + Returns + ------- + ticks : ndarray + 1d array of length `numticks` with the proposed tick locations. + + """ + return self.tick_values() + + @staticmethod + def _round_ticks(ticks, permanent_tick): + ticks = ticks.copy() + for i in range(len(ticks)): + if i == 0 or i == len(ticks) - 1 or permanent_tick[i]: + continue + d1 = ticks[i] - ticks[i - 1] + d2 = ticks[i + 1] - ticks[i] + d = min([d1, d2]) + order = -np.floor(np.log10(d)) + ticks[i] = float(np.round(ticks[i] * 10**order)) / 10**order + return ticks + + class LogLocator(Locator): """ Determine the tick locations for log axes