From 9f9ea545301abe1a26aa03d57be266b6ee20abc2 Mon Sep 17 00:00:00 2001 From: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> Date: Fri, 18 Sep 2020 01:38:19 +0200 Subject: [PATCH 1/2] Move cbook._check_shape() to _api.check_shape() --- lib/matplotlib/_api.py | 40 ++++++++++++++++++++++++++++++ lib/matplotlib/cbook/__init__.py | 39 ----------------------------- lib/matplotlib/path.py | 4 +-- lib/matplotlib/tests/test_api.py | 21 ++++++++++++++++ lib/matplotlib/tests/test_cbook.py | 17 ------------- 5 files changed, 63 insertions(+), 58 deletions(-) create mode 100644 lib/matplotlib/tests/test_api.py diff --git a/lib/matplotlib/_api.py b/lib/matplotlib/_api.py index bd9a3d7da2ab..4f5b3dc21ae4 100644 --- a/lib/matplotlib/_api.py +++ b/lib/matplotlib/_api.py @@ -1,3 +1,4 @@ +import itertools def check_in_list(_values, *, _print_supported_values=True, **kwargs): @@ -31,3 +32,42 @@ def check_in_list(_values, *, _print_supported_values=True, **kwargs): f"supported values are {', '.join(map(repr, values))}") else: raise ValueError(f"{val!r} is not a valid value for {key}") + + +def check_shape(_shape, **kwargs): + """ + For each *key, value* pair in *kwargs*, check that *value* has the shape + *_shape*, if not, raise an appropriate ValueError. + + *None* in the shape is treated as a "free" size that can have any length. + e.g. (None, 2) -> (N, 2) + + The values checked must be numpy arrays. + + Examples + -------- + To check for (N, 2) shaped arrays + + >>> _api.check_shape((None, 2), arg=arg, other_arg=other_arg) + """ + target_shape = _shape + for k, v in kwargs.items(): + data_shape = v.shape + + if len(target_shape) != len(data_shape) or any( + t not in [s, None] + for t, s in zip(target_shape, data_shape) + ): + dim_labels = iter(itertools.chain( + 'MNLIJKLH', + (f"D{i}" for i in itertools.count()))) + text_shape = ", ".join((str(n) + if n is not None + else next(dim_labels) + for n in target_shape)) + + raise ValueError( + f"{k!r} must be {len(target_shape)}D " + f"with shape ({text_shape}). " + f"Your input has shape {v.shape}." + ) diff --git a/lib/matplotlib/cbook/__init__.py b/lib/matplotlib/cbook/__init__.py index ebec27cdf615..68e01cb6360a 100644 --- a/lib/matplotlib/cbook/__init__.py +++ b/lib/matplotlib/cbook/__init__.py @@ -2281,45 +2281,6 @@ def type_name(tp): type_name(type(v)))) -def _check_shape(_shape, **kwargs): - """ - For each *key, value* pair in *kwargs*, check that *value* has the shape - *_shape*, if not, raise an appropriate ValueError. - - *None* in the shape is treated as a "free" size that can have any length. - e.g. (None, 2) -> (N, 2) - - The values checked must be numpy arrays. - - Examples - -------- - To check for (N, 2) shaped arrays - - >>> _api.check_in_list((None, 2), arg=arg, other_arg=other_arg) - """ - target_shape = _shape - for k, v in kwargs.items(): - data_shape = v.shape - - if len(target_shape) != len(data_shape) or any( - t not in [s, None] - for t, s in zip(target_shape, data_shape) - ): - dim_labels = iter(itertools.chain( - 'MNLIJKLH', - (f"D{i}" for i in itertools.count()))) - text_shape = ", ".join((str(n) - if n is not None - else next(dim_labels) - for n in target_shape)) - - raise ValueError( - f"{k!r} must be {len(target_shape)}D " - f"with shape ({text_shape}). " - f"Your input has shape {v.shape}." - ) - - def _check_getitem(_mapping, **kwargs): """ *kwargs* must consist of a single *key, value* pair. If *key* is in diff --git a/lib/matplotlib/path.py b/lib/matplotlib/path.py index e1daf35eed45..bc65a5d6fe27 100644 --- a/lib/matplotlib/path.py +++ b/lib/matplotlib/path.py @@ -15,7 +15,7 @@ import numpy as np import matplotlib as mpl -from . import _path, cbook +from . import _api, _path, cbook from .cbook import _to_unmasked_float_array, simple_linear_interpolation from .bezier import BezierSegment @@ -129,7 +129,7 @@ def __init__(self, vertices, codes=None, _interpolation_steps=1, and codes as read-only arrays. """ vertices = _to_unmasked_float_array(vertices) - cbook._check_shape((None, 2), vertices=vertices) + _api.check_shape((None, 2), vertices=vertices) if codes is not None: codes = np.asarray(codes, self.code_type) diff --git a/lib/matplotlib/tests/test_api.py b/lib/matplotlib/tests/test_api.py new file mode 100644 index 000000000000..be2d80bb4244 --- /dev/null +++ b/lib/matplotlib/tests/test_api.py @@ -0,0 +1,21 @@ +import re + +import numpy as np +import pytest + +from matplotlib import _api + + +@pytest.mark.parametrize('target,test_shape', + [((None, ), (1, 3)), + ((None, 3), (1,)), + ((None, 3), (1, 2)), + ((1, 5), (1, 9)), + ((None, 2, None), (1, 3, 1)) + ]) +def test_check_shape(target, test_shape): + error_pattern = (f"^'aardvark' must be {len(target)}D.*" + + re.escape(f'has shape {test_shape}')) + data = np.zeros(test_shape) + with pytest.raises(ValueError, match=error_pattern): + _api.check_shape(target, aardvark=data) diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index de99701d7757..f4671c75b4aa 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -1,6 +1,5 @@ import itertools import pickle -import re from weakref import ref from unittest.mock import patch, Mock @@ -675,22 +674,6 @@ def divisors(n): check(x, rstride=rstride, cstride=cstride) -@pytest.mark.parametrize('target,test_shape', - [((None, ), (1, 3)), - ((None, 3), (1,)), - ((None, 3), (1, 2)), - ((1, 5), (1, 9)), - ((None, 2, None), (1, 3, 1)) - ]) -def test_check_shape(target, test_shape): - error_pattern = (f"^'aardvark' must be {len(target)}D.*" + - re.escape(f'has shape {test_shape}')) - data = np.zeros(test_shape) - with pytest.raises(ValueError, - match=error_pattern): - cbook._check_shape(target, aardvark=data) - - def test_setattr_cm(): class A: From c725179565c42f2c0a6e1eb29a3813367707fe24 Mon Sep 17 00:00:00 2001 From: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> Date: Fri, 18 Sep 2020 01:45:31 +0200 Subject: [PATCH 2/2] Move cbook._check_getitem() to _api.check_getitem() --- lib/matplotlib/_api.py | 22 +++++++++++++++++++++ lib/matplotlib/axes/_axes.py | 6 +++--- lib/matplotlib/axes/_base.py | 4 ++-- lib/matplotlib/axis.py | 6 +++--- lib/matplotlib/backends/backend_cairo.py | 6 +++--- lib/matplotlib/backends/backend_ps.py | 2 +- lib/matplotlib/cbook/__init__.py | 22 --------------------- lib/matplotlib/collections.py | 2 +- lib/matplotlib/colorbar.py | 6 +++--- lib/matplotlib/mathtext.py | 4 ++-- lib/matplotlib/offsetbox.py | 2 +- lib/matplotlib/quiver.py | 2 +- lib/matplotlib/scale.py | 2 +- lib/matplotlib/testing/jpl_units/UnitDbl.py | 6 +++--- lib/mpl_toolkits/axisartist/axis_artist.py | 10 +++++----- lib/mpl_toolkits/mplot3d/axes3d.py | 2 +- 16 files changed, 52 insertions(+), 52 deletions(-) diff --git a/lib/matplotlib/_api.py b/lib/matplotlib/_api.py index 4f5b3dc21ae4..856a7e0a063a 100644 --- a/lib/matplotlib/_api.py +++ b/lib/matplotlib/_api.py @@ -71,3 +71,25 @@ def check_shape(_shape, **kwargs): f"with shape ({text_shape}). " f"Your input has shape {v.shape}." ) + + +def check_getitem(_mapping, **kwargs): + """ + *kwargs* must consist of a single *key, value* pair. If *key* is in + *_mapping*, return ``_mapping[value]``; else, raise an appropriate + ValueError. + + Examples + -------- + >>> _api.check_getitem({"foo": "bar"}, arg=arg) + """ + mapping = _mapping + if len(kwargs) != 1: + raise ValueError("check_getitem takes a single keyword argument") + (k, v), = kwargs.items() + try: + return mapping[v] + except KeyError: + raise ValueError( + "{!r} is not a valid value for {}; supported values are {}" + .format(v, k, ', '.join(map(repr, mapping)))) from None diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 04a63d1a7339..30b2b6a4dfb4 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -86,7 +86,7 @@ def get_title(self, loc="center"): titles = {'left': self._left_title, 'center': self.title, 'right': self._right_title} - title = cbook._check_getitem(titles, loc=loc.lower()) + title = _api.check_getitem(titles, loc=loc.lower()) return title.get_text() def set_title(self, label, fontdict=None, loc=None, pad=None, *, y=None, @@ -149,7 +149,7 @@ def set_title(self, label, fontdict=None, loc=None, pad=None, *, y=None, titles = {'left': self._left_title, 'center': self.title, 'right': self._right_title} - title = cbook._check_getitem(titles, loc=loc.lower()) + title = _api.check_getitem(titles, loc=loc.lower()) default = { 'fontsize': rcParams['axes.titlesize'], 'fontweight': rcParams['axes.titleweight'], @@ -7195,7 +7195,7 @@ def magnitude_spectrum(self, x, Fs=None, Fc=None, window=None, pad_to=pad_to, sides=sides) freqs += Fc - yunits = cbook._check_getitem( + yunits = _api.check_getitem( {None: 'energy', 'default': 'energy', 'linear': 'energy', 'dB': 'dB'}, scale=scale) diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index e0d14b009254..40a4af4630b0 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -3002,10 +3002,10 @@ def ticklabel_format(self, *, axis='both', style='', scilimits=None, raise ValueError("scilimits must be a sequence of 2 integers" ) from err STYLES = {'sci': True, 'scientific': True, 'plain': False, '': None} - is_sci_style = cbook._check_getitem(STYLES, style=style) + is_sci_style = _api.check_getitem(STYLES, style=style) axis_map = {**{k: [v] for k, v in self._get_axis_map().items()}, 'both': self._get_axis_list()} - axises = cbook._check_getitem(axis_map, axis=axis) + axises = _api.check_getitem(axis_map, axis=axis) try: for axis in axises: if is_sci_style is not None: diff --git a/lib/matplotlib/axis.py b/lib/matplotlib/axis.py index 2e4671e54876..264d8b5b99da 100644 --- a/lib/matplotlib/axis.py +++ b/lib/matplotlib/axis.py @@ -2048,7 +2048,7 @@ def set_label_position(self, position): ---------- position : {'top', 'bottom'} """ - self.label.set_verticalalignment(cbook._check_getitem({ + self.label.set_verticalalignment(_api.check_getitem({ 'top': 'baseline', 'bottom': 'top', }, position=position)) self.label_position = position @@ -2340,7 +2340,7 @@ def set_label_position(self, position): """ self.label.set_rotation_mode('anchor') self.label.set_horizontalalignment('center') - self.label.set_verticalalignment(cbook._check_getitem({ + self.label.set_verticalalignment(_api.check_getitem({ 'left': 'bottom', 'right': 'top', }, position=position)) self.label_position = position @@ -2425,7 +2425,7 @@ def set_offset_position(self, position): position : {'left', 'right'} """ x, y = self.offsetText.get_position() - x = cbook._check_getitem({'left': 0, 'right': 1}, position=position) + x = _api.check_getitem({'left': 0, 'right': 1}, position=position) self.offsetText.set_ha(position) self.offsetText.set_position((x, y)) diff --git a/lib/matplotlib/backends/backend_cairo.py b/lib/matplotlib/backends/backend_cairo.py index 5d78b3e6af32..756d88009121 100644 --- a/lib/matplotlib/backends/backend_cairo.py +++ b/lib/matplotlib/backends/backend_cairo.py @@ -24,7 +24,7 @@ "cairo backend requires that pycairo>=1.11.0 or cairocffi " "is installed") from err -from .. import cbook, font_manager +from .. import _api, cbook, font_manager from matplotlib.backend_bases import ( _Backend, _check_savefig_extra_args, FigureCanvasBase, FigureManagerBase, GraphicsContextBase, RendererBase) @@ -358,7 +358,7 @@ def set_alpha(self, alpha): # one for False. def set_capstyle(self, cs): - self.ctx.set_line_cap(cbook._check_getitem(self._capd, capstyle=cs)) + self.ctx.set_line_cap(_api.check_getitem(self._capd, capstyle=cs)) self._capstyle = cs def set_clip_rectangle(self, rectangle): @@ -401,7 +401,7 @@ def get_rgb(self): return self.ctx.get_source().get_rgba()[:3] def set_joinstyle(self, js): - self.ctx.set_line_join(cbook._check_getitem(self._joind, joinstyle=js)) + self.ctx.set_line_join(_api.check_getitem(self._joind, joinstyle=js)) self._joinstyle = js def set_linewidth(self, w): diff --git a/lib/matplotlib/backends/backend_ps.py b/lib/matplotlib/backends/backend_ps.py index 350785164a10..07ba63c093ef 100644 --- a/lib/matplotlib/backends/backend_ps.py +++ b/lib/matplotlib/backends/backend_ps.py @@ -817,7 +817,7 @@ def _print_ps( papertype = papertype.lower() _api.check_in_list(['auto', *papersize], papertype=papertype) - orientation = cbook._check_getitem( + orientation = _api.check_getitem( _Orientation, orientation=orientation.lower()) printer = (self._print_figure_tex diff --git a/lib/matplotlib/cbook/__init__.py b/lib/matplotlib/cbook/__init__.py index 68e01cb6360a..2360f90f5aed 100644 --- a/lib/matplotlib/cbook/__init__.py +++ b/lib/matplotlib/cbook/__init__.py @@ -2281,28 +2281,6 @@ def type_name(tp): type_name(type(v)))) -def _check_getitem(_mapping, **kwargs): - """ - *kwargs* must consist of a single *key, value* pair. If *key* is in - *_mapping*, return ``_mapping[value]``; else, raise an appropriate - ValueError. - - Examples - -------- - >>> cbook._check_getitem({"foo": "bar"}, arg=arg) - """ - mapping = _mapping - if len(kwargs) != 1: - raise ValueError("_check_getitem takes a single keyword argument") - (k, v), = kwargs.items() - try: - return mapping[v] - except KeyError: - raise ValueError( - "{!r} is not a valid value for {}; supported values are {}" - .format(v, k, ', '.join(map(repr, mapping)))) from None - - class _classproperty: """ Like `property`, but also triggers on access via the class, and it is the diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 81a08ae57f9e..2392d2489b51 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -1619,7 +1619,7 @@ def set_orientation(self, orientation=None): orientation : {'horizontal', 'vertical'} """ try: - is_horizontal = cbook._check_getitem( + is_horizontal = _api.check_getitem( {"horizontal": True, "vertical": False}, orientation=orientation) except ValueError: diff --git a/lib/matplotlib/colorbar.py b/lib/matplotlib/colorbar.py index e1231c079b2c..6eb5bd7fa633 100644 --- a/lib/matplotlib/colorbar.py +++ b/lib/matplotlib/colorbar.py @@ -461,7 +461,7 @@ def __init__(self, ax, cmap=None, self.values = values self.boundaries = boundaries self.extend = extend - self._inside = cbook._check_getitem( + self._inside = _api.check_getitem( {'neither': slice(0, None), 'both': slice(1, -1), 'min': slice(1, None), 'max': slice(0, -1)}, extend=extend) @@ -1372,10 +1372,10 @@ def remove(self): def _normalize_location_orientation(location, orientation): if location is None: - location = cbook._check_getitem( + location = _api.check_getitem( {None: "right", "vertical": "right", "horizontal": "bottom"}, orientation=orientation) - loc_settings = cbook._check_getitem({ + loc_settings = _api.check_getitem({ "left": {"location": "left", "orientation": "vertical", "anchor": (1.0, 0.5), "panchor": (0.0, 0.5), "pad": 0.10}, "right": {"location": "right", "orientation": "vertical", diff --git a/lib/matplotlib/mathtext.py b/lib/matplotlib/mathtext.py index 81b9742bf85a..bc77d76e4669 100644 --- a/lib/matplotlib/mathtext.py +++ b/lib/matplotlib/mathtext.py @@ -24,7 +24,7 @@ import numpy as np from PIL import Image -from matplotlib import cbook, colors as mcolors, rcParams, _mathtext +from matplotlib import _api, cbook, colors as mcolors, rcParams, _mathtext from matplotlib.ft2font import FT2Image, LOAD_NO_HINTING from matplotlib.font_manager import FontProperties # Backcompat imports, all are deprecated as of 3.4. @@ -444,7 +444,7 @@ def _parse_cached(self, s, dpi, prop, force_standard_ps_fonts): fontset_class = ( _mathtext.StandardPsFonts if force_standard_ps_fonts - else cbook._check_getitem( + else _api.check_getitem( self._font_type_mapping, fontset=prop.get_math_fontfamily())) backend = self._backend_mapping[self._output]() font_output = fontset_class(prop, backend) diff --git a/lib/matplotlib/offsetbox.py b/lib/matplotlib/offsetbox.py index b222bec548a5..692c61d6f891 100644 --- a/lib/matplotlib/offsetbox.py +++ b/lib/matplotlib/offsetbox.py @@ -1095,7 +1095,7 @@ def __init__(self, loc, self.set_child(child) if isinstance(loc, str): - loc = cbook._check_getitem(self.codes, loc=loc) + loc = _api.check_getitem(self.codes, loc=loc) self.loc = loc self.borderpad = borderpad diff --git a/lib/matplotlib/quiver.py b/lib/matplotlib/quiver.py index 164597742c4a..844a23cc5c4b 100644 --- a/lib/matplotlib/quiver.py +++ b/lib/matplotlib/quiver.py @@ -352,7 +352,7 @@ def draw(self, renderer): self.stale = False def _set_transform(self): - self.set_transform(cbook._check_getitem({ + self.set_transform(_api.check_getitem({ "data": self.Q.axes.transData, "axes": self.Q.axes.transAxes, "figure": self.Q.axes.figure.transFigure, diff --git a/lib/matplotlib/scale.py b/lib/matplotlib/scale.py index 5aad48f50e06..e5e9d012f3b0 100644 --- a/lib/matplotlib/scale.py +++ b/lib/matplotlib/scale.py @@ -200,7 +200,7 @@ def __init__(self, base, nonpositive='clip'): if base <= 0 or base == 1: raise ValueError('The log base cannot be <= 0 or == 1') self.base = base - self._clip = cbook._check_getitem( + self._clip = _api.check_getitem( {"clip": True, "mask": False}, nonpositive=nonpositive) def __str__(self): diff --git a/lib/matplotlib/testing/jpl_units/UnitDbl.py b/lib/matplotlib/testing/jpl_units/UnitDbl.py index 14d77cd8faf7..68481a80fd57 100644 --- a/lib/matplotlib/testing/jpl_units/UnitDbl.py +++ b/lib/matplotlib/testing/jpl_units/UnitDbl.py @@ -2,7 +2,7 @@ import operator -from matplotlib import cbook +from matplotlib import _api class UnitDbl: @@ -48,7 +48,7 @@ def __init__(self, value, units): - value The numeric value of the UnitDbl. - units The string name of the units the value is in. """ - data = cbook._check_getitem(self.allowed, units=units) + data = _api.check_getitem(self.allowed, units=units) self._value = float(value * data[0]) self._units = data[1] @@ -68,7 +68,7 @@ def convert(self, units): """ if self._units == units: return self._value - data = cbook._check_getitem(self.allowed, units=units) + data = _api.check_getitem(self.allowed, units=units) if self._units != data[1]: raise ValueError(f"Error trying to convert to different units.\n" f" Invalid conversion requested.\n" diff --git a/lib/mpl_toolkits/axisartist/axis_artist.py b/lib/mpl_toolkits/axisartist/axis_artist.py index 23e508c75e0b..d25200d2c899 100644 --- a/lib/mpl_toolkits/axisartist/axis_artist.py +++ b/lib/mpl_toolkits/axisartist/axis_artist.py @@ -90,7 +90,7 @@ import numpy as np -from matplotlib import cbook, rcParams +from matplotlib import _api, cbook, rcParams import matplotlib.artist as martist import matplotlib.text as mtext @@ -416,7 +416,7 @@ def get_text(self): top=("bottom", "center")) def set_default_alignment(self, d): - va, ha = cbook._check_getitem(self._default_alignments, d=d) + va, ha = _api.check_getitem(self._default_alignments, d=d) self.set_va(va) self.set_ha(ha) @@ -426,7 +426,7 @@ def set_default_alignment(self, d): top=180) def set_default_angle(self, d): - self.set_rotation(cbook._check_getitem(self._default_angles, d=d)) + self.set_rotation(_api.check_getitem(self._default_angles, d=d)) def set_axis_direction(self, d): """ @@ -807,7 +807,7 @@ def set_ticklabel_direction(self, tick_direction): ---------- tick_direction : {"+", "-"} """ - self._ticklabel_add_angle = cbook._check_getitem( + self._ticklabel_add_angle = _api.check_getitem( {"+": 0, "-": 180}, tick_direction=tick_direction) def invert_ticklabel_direction(self): @@ -826,7 +826,7 @@ def set_axislabel_direction(self, label_direction): ---------- tick_direction : {"+", "-"} """ - self._axislabel_add_angle = cbook._check_getitem( + self._axislabel_add_angle = _api.check_getitem( {"+": 0, "-": 180}, label_direction=label_direction) def get_transform(self): diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 80294e043c77..5ebec4ba61b8 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -1015,7 +1015,7 @@ def set_proj_type(self, proj_type): ---------- proj_type : {'persp', 'ortho'} """ - self._projection = cbook._check_getitem({ + self._projection = _api.check_getitem({ 'persp': proj3d.persp_transformation, 'ortho': proj3d.ortho_transformation, }, proj_type=proj_type)