From 26773fb5a605f7aac86aa1799440b72f40439548 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Wed, 12 May 2021 18:37:04 +0200 Subject: [PATCH] Move AxisArtistHelpers to toplevel. The axisartist has a concept of "axis_artist_helper", which computes various computations to draw slanted/curved axises. Confusingly, `AxisArtistHelper` (and likewise `AxisArtistHelperRectlinear`) do *not* define such helper classes; they are simply namespaces that hold the `{AxisArtistHelper,AxisArtistHelperRectlinear}.{Fixed,Floating}` nested classes which *do* define helpers. More specifically, `AxisArtistHelper.{Fixed,Floating}` act as abstract base classes for `AxisArtistHelperRectlinear.{Fixed,Floating}` which are actually usable. In order to slightly disentangle this move the actual helper classes to the toplevel (as `_{Fixed,Floating}AxisArtistHelperBase` and `_{Fixed,Floating}AxisArtistHelperRectlinear`), keeping the old "purely namespace" classes around for backcompat. (But note that end users should never have to directly interact with these classes anyways -- normally, they only construct GridHelpers which take care of the interaction with AxisArtistHelpers; see e.g. the various axisartist examples.) More simply, this commit simply dedents most of the definitions of the Helper classes. --- doc/missing-references.json | 11 +- lib/mpl_toolkits/axisartist/axislines.py | 372 +++++++++--------- .../axisartist/grid_helper_curvelinear.py | 8 +- 3 files changed, 204 insertions(+), 187 deletions(-) diff --git a/doc/missing-references.json b/doc/missing-references.json index 28ed7abd4297..0e33c97dab1d 100644 --- a/doc/missing-references.json +++ b/doc/missing-references.json @@ -277,8 +277,15 @@ "mpl_toolkits.axisartist.axisline_style._FancyAxislineStyle.SimpleArrow": [ "lib/mpl_toolkits/axisartist/axisline_style.py:docstring of mpl_toolkits.axisartist.axisline_style.AxislineStyle:1" ], - "mpl_toolkits.axisartist.axislines.AxisArtistHelper._Base": [ - "lib/mpl_toolkits/axisartist/axislines.py:docstring of mpl_toolkits.axisartist.axislines.AxisArtistHelper:1" + "mpl_toolkits.axisartist.axislines._FixedAxisArtistHelperBase": [ + "lib/mpl_toolkits/axisartist/axislines.py:docstring of mpl_toolkits.axisartist.axislines.AxisArtistHelper:1", + "lib/mpl_toolkits/axisartist/axislines.py:docstring of mpl_toolkits.axisartist.axislines.FixedAxisArtistHelperRectilinear:1", + "lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py:docstring of mpl_toolkits.axisartist.grid_helper_curvelinear.FixedAxisArtistHelper:1" + ], + "mpl_toolkits.axisartist.axislines._FloatingAxisArtistHelperBase": [ + "lib/mpl_toolkits/axisartist/axislines.py:docstring of mpl_toolkits.axisartist.axislines.AxisArtistHelper:1", + "lib/mpl_toolkits/axisartist/axislines.py:docstring of mpl_toolkits.axisartist.axislines.FloatingAxisArtistHelperRectilinear:1", + "lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py:docstring of mpl_toolkits.axisartist.grid_helper_curvelinear.FloatingAxisArtistHelper:1" ], "mpl_toolkits.axisartist.floating_axes.FloatingAxesHostAxes": [ "doc/api/_as_gen/mpl_toolkits.axisartist.floating_axes.rst:32::1", diff --git a/lib/mpl_toolkits/axisartist/axislines.py b/lib/mpl_toolkits/axisartist/axislines.py index 9752a85139b3..cfb03989b0c6 100644 --- a/lib/mpl_toolkits/axisartist/axislines.py +++ b/lib/mpl_toolkits/axisartist/axislines.py @@ -50,12 +50,12 @@ from .axis_artist import AxisArtist, GridlinesCollection -class AxisArtistHelper: +class _AxisArtistHelperBase: """ - Axis helpers should define the methods listed below. The *axes* argument - will be the axes attribute of the caller artist. + Base class for axis helper. - :: + Subclasses should define the methods listed below. The *axes* + argument will be the ``.axes`` attribute of the caller artist. :: # Construct the spine. @@ -84,211 +84,219 @@ def get_tick_iterators(self, axes): return iter_major, iter_minor """ - class _Base: - """Base class for axis helper.""" - - def update_lim(self, axes): - pass - - def _to_xy(self, values, const): - """ - Create a (*values.shape, 2)-shape array representing (x, y) pairs. - - *values* go into the coordinate determined by ``self.nth_coord``. - The other coordinate is filled with the constant *const*. - - Example:: - - >>> self.nth_coord = 0 - >>> self._to_xy([1, 2, 3], const=0) - array([[1, 0], - [2, 0], - [3, 0]]) - """ - if self.nth_coord == 0: - return np.stack(np.broadcast_arrays(values, const), axis=-1) - elif self.nth_coord == 1: - return np.stack(np.broadcast_arrays(const, values), axis=-1) - else: - raise ValueError("Unexpected nth_coord") - - class Fixed(_Base): - """Helper class for a fixed (in the axes coordinate) axis.""" - - passthru_pt = _api.deprecated("3.7")(property( - lambda self: {"left": (0, 0), "right": (1, 0), - "bottom": (0, 0), "top": (0, 1)}[self._loc])) - - def __init__(self, loc, nth_coord=None): - """``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis.""" - self.nth_coord = ( - nth_coord if nth_coord is not None else - _api.check_getitem( - {"bottom": 0, "top": 0, "left": 1, "right": 1}, loc=loc)) - if (nth_coord == 0 and loc not in ["left", "right"] - or nth_coord == 1 and loc not in ["bottom", "top"]): - _api.warn_deprecated( - "3.7", message=f"{loc=!r} is incompatible with " - "{nth_coord=}; support is deprecated since %(since)s") - self._loc = loc - self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc] - super().__init__() - # axis line in transAxes - self._path = Path(self._to_xy((0, 1), const=self._pos)) - - def get_nth_coord(self): - return self.nth_coord - - # LINE + def update_lim(self, axes): + pass - def get_line(self, axes): - return self._path + def _to_xy(self, values, const): + """ + Create a (*values.shape, 2)-shape array representing (x, y) pairs. - def get_line_transform(self, axes): - return axes.transAxes + The other coordinate is filled with the constant *const*. - # LABEL + Example:: - def get_axislabel_transform(self, axes): - return axes.transAxes + >>> self.nth_coord = 0 + >>> self._to_xy([1, 2, 3], const=0) + array([[1, 0], + [2, 0], + [3, 0]]) + """ + if self.nth_coord == 0: + return np.stack(np.broadcast_arrays(values, const), axis=-1) + elif self.nth_coord == 1: + return np.stack(np.broadcast_arrays(const, values), axis=-1) + else: + raise ValueError("Unexpected nth_coord") + + +class _FixedAxisArtistHelperBase(_AxisArtistHelperBase): + """Helper class for a fixed (in the axes coordinate) axis.""" + + passthru_pt = _api.deprecated("3.7")(property( + lambda self: {"left": (0, 0), "right": (1, 0), + "bottom": (0, 0), "top": (0, 1)}[self._loc])) + + def __init__(self, loc, nth_coord=None): + """``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis.""" + self.nth_coord = ( + nth_coord if nth_coord is not None else + _api.check_getitem( + {"bottom": 0, "top": 0, "left": 1, "right": 1}, loc=loc)) + if (nth_coord == 0 and loc not in ["left", "right"] + or nth_coord == 1 and loc not in ["bottom", "top"]): + _api.warn_deprecated( + "3.7", message=f"{loc=!r} is incompatible with " + "{nth_coord=}; support is deprecated since %(since)s") + self._loc = loc + self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc] + super().__init__() + # axis line in transAxes + self._path = Path(self._to_xy((0, 1), const=self._pos)) - def get_axislabel_pos_angle(self, axes): - """ - Return the label reference position in transAxes. + def get_nth_coord(self): + return self.nth_coord - get_label_transform() returns a transform of (transAxes+offset) - """ - return dict(left=((0., 0.5), 90), # (position, angle_tangent) - right=((1., 0.5), 90), - bottom=((0.5, 0.), 0), - top=((0.5, 1.), 0))[self._loc] + # LINE - # TICK + def get_line(self, axes): + return self._path - def get_tick_transform(self, axes): - return [axes.get_xaxis_transform(), - axes.get_yaxis_transform()][self.nth_coord] + def get_line_transform(self, axes): + return axes.transAxes - class Floating(_Base): + # LABEL - def __init__(self, nth_coord, value): - self.nth_coord = nth_coord - self._value = value - super().__init__() + def get_axislabel_transform(self, axes): + return axes.transAxes - def get_nth_coord(self): - return self.nth_coord + def get_axislabel_pos_angle(self, axes): + """ + Return the label reference position in transAxes. - def get_line(self, axes): - raise RuntimeError( - "get_line method should be defined by the derived class") + get_label_transform() returns a transform of (transAxes+offset) + """ + return dict(left=((0., 0.5), 90), # (position, angle_tangent) + right=((1., 0.5), 90), + bottom=((0.5, 0.), 0), + top=((0.5, 1.), 0))[self._loc] + # TICK -class AxisArtistHelperRectlinear: + def get_tick_transform(self, axes): + return [axes.get_xaxis_transform(), + axes.get_yaxis_transform()][self.nth_coord] - class Fixed(AxisArtistHelper.Fixed): - def __init__(self, axes, loc, nth_coord=None): - """ - nth_coord = along which coordinate value varies - in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis - """ - super().__init__(loc, nth_coord) - self.axis = [axes.xaxis, axes.yaxis][self.nth_coord] +class _FloatingAxisArtistHelperBase(_AxisArtistHelperBase): - # TICK + def __init__(self, nth_coord, value): + self.nth_coord = nth_coord + self._value = value + super().__init__() - def get_tick_iterators(self, axes): - """tick_loc, tick_angle, tick_label""" - if self._loc in ["bottom", "top"]: - angle_normal, angle_tangent = 90, 0 - else: # "left", "right" - angle_normal, angle_tangent = 0, 90 - - major = self.axis.major - major_locs = major.locator() - major_labels = major.formatter.format_ticks(major_locs) - - minor = self.axis.minor - minor_locs = minor.locator() - minor_labels = minor.formatter.format_ticks(minor_locs) - - tick_to_axes = self.get_tick_transform(axes) - axes.transAxes - - def _f(locs, labels): - for loc, label in zip(locs, labels): - c = self._to_xy(loc, const=self._pos) - # check if the tick point is inside axes - c2 = tick_to_axes.transform(c) - if mpl.transforms._interval_contains_close( - (0, 1), c2[self.nth_coord]): - yield c, angle_normal, angle_tangent, label - - return _f(major_locs, major_labels), _f(minor_locs, minor_labels) - - class Floating(AxisArtistHelper.Floating): - def __init__(self, axes, nth_coord, - passingthrough_point, axis_direction="bottom"): - super().__init__(nth_coord, passingthrough_point) - self._axis_direction = axis_direction - self.axis = [axes.xaxis, axes.yaxis][self.nth_coord] + def get_nth_coord(self): + return self.nth_coord - def get_line(self, axes): - fixed_coord = 1 - self.nth_coord - data_to_axes = axes.transData - axes.transAxes - p = data_to_axes.transform([self._value, self._value]) - return Path(self._to_xy((0, 1), const=p[fixed_coord])) + def get_line(self, axes): + raise RuntimeError( + "get_line method should be defined by the derived class") - def get_line_transform(self, axes): - return axes.transAxes - def get_axislabel_transform(self, axes): - return axes.transAxes +class FixedAxisArtistHelperRectilinear(_FixedAxisArtistHelperBase): - def get_axislabel_pos_angle(self, axes): - """ - Return the label reference position in transAxes. - - get_label_transform() returns a transform of (transAxes+offset) - """ - angle = [0, 90][self.nth_coord] - fixed_coord = 1 - self.nth_coord - data_to_axes = axes.transData - axes.transAxes - p = data_to_axes.transform([self._value, self._value]) - verts = self._to_xy(0.5, const=p[fixed_coord]) - if 0 <= verts[fixed_coord] <= 1: - return verts, angle - else: - return None, None + def __init__(self, axes, loc, nth_coord=None): + """ + nth_coord = along which coordinate value varies + in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis + """ + super().__init__(loc, nth_coord) + self.axis = [axes.xaxis, axes.yaxis][self.nth_coord] - def get_tick_transform(self, axes): - return axes.transData + # TICK - def get_tick_iterators(self, axes): - """tick_loc, tick_angle, tick_label""" - if self.nth_coord == 0: - angle_normal, angle_tangent = 90, 0 - else: - angle_normal, angle_tangent = 0, 90 + def get_tick_iterators(self, axes): + """tick_loc, tick_angle, tick_label""" + if self._loc in ["bottom", "top"]: + angle_normal, angle_tangent = 90, 0 + else: # "left", "right" + angle_normal, angle_tangent = 0, 90 + + major = self.axis.major + major_locs = major.locator() + major_labels = major.formatter.format_ticks(major_locs) + + minor = self.axis.minor + minor_locs = minor.locator() + minor_labels = minor.formatter.format_ticks(minor_locs) + + tick_to_axes = self.get_tick_transform(axes) - axes.transAxes + + def _f(locs, labels): + for loc, label in zip(locs, labels): + c = self._to_xy(loc, const=self._pos) + # check if the tick point is inside axes + c2 = tick_to_axes.transform(c) + if mpl.transforms._interval_contains_close( + (0, 1), c2[self.nth_coord]): + yield c, angle_normal, angle_tangent, label + + return _f(major_locs, major_labels), _f(minor_locs, minor_labels) + + +class FloatingAxisArtistHelperRectilinear(_FloatingAxisArtistHelperBase): + + def __init__(self, axes, nth_coord, + passingthrough_point, axis_direction="bottom"): + super().__init__(nth_coord, passingthrough_point) + self._axis_direction = axis_direction + self.axis = [axes.xaxis, axes.yaxis][self.nth_coord] + + def get_line(self, axes): + fixed_coord = 1 - self.nth_coord + data_to_axes = axes.transData - axes.transAxes + p = data_to_axes.transform([self._value, self._value]) + return Path(self._to_xy((0, 1), const=p[fixed_coord])) + + def get_line_transform(self, axes): + return axes.transAxes + + def get_axislabel_transform(self, axes): + return axes.transAxes + + def get_axislabel_pos_angle(self, axes): + """ + Return the label reference position in transAxes. + + get_label_transform() returns a transform of (transAxes+offset) + """ + angle = [0, 90][self.nth_coord] + fixed_coord = 1 - self.nth_coord + data_to_axes = axes.transData - axes.transAxes + p = data_to_axes.transform([self._value, self._value]) + verts = self._to_xy(0.5, const=p[fixed_coord]) + if 0 <= verts[fixed_coord] <= 1: + return verts, angle + else: + return None, None + + def get_tick_transform(self, axes): + return axes.transData + + def get_tick_iterators(self, axes): + """tick_loc, tick_angle, tick_label""" + if self.nth_coord == 0: + angle_normal, angle_tangent = 90, 0 + else: + angle_normal, angle_tangent = 0, 90 + + major = self.axis.major + major_locs = major.locator() + major_labels = major.formatter.format_ticks(major_locs) + + minor = self.axis.minor + minor_locs = minor.locator() + minor_labels = minor.formatter.format_ticks(minor_locs) + + data_to_axes = axes.transData - axes.transAxes + + def _f(locs, labels): + for loc, label in zip(locs, labels): + c = self._to_xy(loc, const=self._value) + c1, c2 = data_to_axes.transform(c) + if 0 <= c1 <= 1 and 0 <= c2 <= 1: + yield c, angle_normal, angle_tangent, label - major = self.axis.major - major_locs = major.locator() - major_labels = major.formatter.format_ticks(major_locs) + return _f(major_locs, major_labels), _f(minor_locs, minor_labels) - minor = self.axis.minor - minor_locs = minor.locator() - minor_labels = minor.formatter.format_ticks(minor_locs) - data_to_axes = axes.transData - axes.transAxes +class AxisArtistHelper: # Backcompat. + Fixed = _FixedAxisArtistHelperBase + Floating = _FloatingAxisArtistHelperBase - def _f(locs, labels): - for loc, label in zip(locs, labels): - c = self._to_xy(loc, const=self._value) - c1, c2 = data_to_axes.transform(c) - if 0 <= c1 <= 1 and 0 <= c2 <= 1: - yield c, angle_normal, angle_tangent, label - return _f(major_locs, major_labels), _f(minor_locs, minor_labels) +class AxisArtistHelperRectlinear: # Backcompat. + Fixed = FixedAxisArtistHelperRectilinear + Floating = FloatingAxisArtistHelperRectilinear class GridHelperBase: @@ -338,7 +346,7 @@ def new_fixed_axis(self, loc, if axis_direction is None: axis_direction = loc - helper = AxisArtistHelperRectlinear.Fixed(axes, loc, nth_coord) + helper = FixedAxisArtistHelperRectilinear(axes, loc, nth_coord) axisline = AxisArtist(axes, helper, offset=offset, axis_direction=axis_direction) return axisline @@ -352,7 +360,7 @@ def new_floating_axis(self, nth_coord, value, "'new_floating_axis' explicitly requires the axes keyword.") axes = self.axes - helper = AxisArtistHelperRectlinear.Floating( + helper = FloatingAxisArtistHelperRectilinear( axes, nth_coord, value, axis_direction) axisline = AxisArtist(axes, helper, axis_direction=axis_direction) axisline.line.set_clip_on(True) diff --git a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py index e5ba2c824257..71ada7bcf03b 100644 --- a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py +++ b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py @@ -10,7 +10,8 @@ import matplotlib as mpl from matplotlib.path import Path from matplotlib.transforms import Affine2D, IdentityTransform -from .axislines import AxisArtistHelper, GridHelperBase +from .axislines import ( + _FixedAxisArtistHelperBase, _FloatingAxisArtistHelperBase, GridHelperBase) from .axis_artist import AxisArtist from .grid_finder import GridFinder @@ -40,7 +41,7 @@ def _value_and_jacobian(func, xs, ys, xlims, ylims): return (val, (val_dx - val) / xeps, (val_dy - val) / yeps) -class FixedAxisArtistHelper(AxisArtistHelper.Fixed): +class FixedAxisArtistHelper(_FixedAxisArtistHelperBase): """ Helper class for a fixed axis. """ @@ -80,7 +81,8 @@ def get_tick_iterators(self, axes): return chain(ti1, ti2), iter([]) -class FloatingAxisArtistHelper(AxisArtistHelper.Floating): +class FloatingAxisArtistHelper(_FloatingAxisArtistHelperBase): + def __init__(self, grid_helper, nth_coord, value, axis_direction=None): """ nth_coord = along which coordinate value varies.