diff --git a/lib/matplotlib/cm.py b/lib/matplotlib/cm.py index 0949563ca19a..e914acf71888 100644 --- a/lib/matplotlib/cm.py +++ b/lib/matplotlib/cm.py @@ -314,9 +314,9 @@ def set_clim(self, vmin=None, vmax=None): except (TypeError, ValueError): pass if vmin is not None: - self.norm.vmin = vmin + self.norm.vmin = colors._sanitize_extrema(vmin) if vmax is not None: - self.norm.vmax = vmax + self.norm.vmax = colors._sanitize_extrema(vmax) self.changed() def set_cmap(self, cmap): diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index 681b1bc32d38..c2aa7fca82d3 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -94,6 +94,16 @@ def get_named_colors_mapping(): return _colors_full_map +def _sanitize_extrema(ex): + if ex is None: + return ex + try: + ret = np.asscalar(ex) + except AttributeError: + ret = float(ex) + return ret + + def _is_nth_color(c): """Return whether *c* can be interpreted as an item in the color cycle.""" return isinstance(c, six.string_types) and re.match(r"\AC[0-9]\Z", c) @@ -878,8 +888,8 @@ def __init__(self, vmin=None, vmax=None, clip=False): likely to lead to surprises; therefore the default is *clip* = *False*. """ - self.vmin = vmin - self.vmax = vmax + self.vmin = _sanitize_extrema(vmin) + self.vmax = _sanitize_extrema(vmax) self.clip = clip @staticmethod diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py index 7de686665c86..7a19d5f135d0 100644 --- a/lib/matplotlib/tests/test_colors.py +++ b/lib/matplotlib/tests/test_colors.py @@ -705,11 +705,18 @@ def __add__(self, other): raise RuntimeError data = np.arange(-10, 10, 1, dtype=float) + data.shape = (10, 2) + mydata = data.view(MyArray) for norm in [mcolors.Normalize(), mcolors.LogNorm(), mcolors.SymLogNorm(3, vmax=5, linscale=1), + mcolors.Normalize(vmin=mydata.min(), vmax=mydata.max()), + mcolors.SymLogNorm(3, vmin=mydata.min(), vmax=mydata.max()), mcolors.PowerNorm(1)]: - assert_array_equal(norm(data.view(MyArray)), norm(data)) + assert_array_equal(norm(mydata), norm(data)) + fig, ax = plt.subplots() + ax.imshow(mydata, norm=norm) + fig.canvas.draw() if isinstance(norm, mcolors.PowerNorm): assert len(recwarn) == 1 warn = recwarn.pop(UserWarning)