diff --git a/lib/mpl_toolkits/mplot3d/art3d.py b/lib/mpl_toolkits/mplot3d/art3d.py index 85c4f45cc4d8..33703bb06a2f 100644 --- a/lib/mpl_toolkits/mplot3d/art3d.py +++ b/lib/mpl_toolkits/mplot3d/art3d.py @@ -20,6 +20,12 @@ from . import proj3d +class Object3D(object): + """A base class for any 3D object""" + def __init__(self): + self._sort_zpos = -np.inf # defaults to background (plotted first) + + def norm_angle(a): """Return the given angle normalized to -180 < *a* <= 180 degrees.""" a = (a + 360) % 360 @@ -71,7 +77,7 @@ def get_dir_vector(zdir): raise ValueError("'x', 'y', 'z', None or vector of length 3 expected") -class Text3D(mtext.Text): +class Text3D(mtext.Text, Object3D): """ Text object with 3D position and direction. @@ -93,6 +99,7 @@ class Text3D(mtext.Text): def __init__(self, x=0, y=0, z=0, text='', zdir='z', **kwargs): mtext.Text.__init__(self, x, y, text, **kwargs) + Object3D.__init__(self) self.set_3d_properties(z, zdir) def set_3d_properties(self, z=0, zdir='z'): @@ -120,7 +127,7 @@ def text_2d_to_3d(obj, z=0, zdir='z'): obj.set_3d_properties(z, zdir) -class Line3D(lines.Line2D): +class Line3D(lines.Line2D, Object3D): """ 3D line object. """ @@ -130,8 +137,14 @@ def __init__(self, xs, ys, zs, *args, **kwargs): Keyword arguments are passed onto :func:`~matplotlib.lines.Line2D`. """ lines.Line2D.__init__(self, [], [], *args, **kwargs) + Object3D.__init__(self) self._verts3d = xs, ys, zs + def set_sort_zpos(self, val): + """Set the position to use for z-sorting.""" + self._sort_zpos = val + self.stale = True + def set_3d_properties(self, zs=0, zdir='z'): xs = self.get_xdata() ys = self.get_ydata() @@ -209,11 +222,15 @@ def paths_to_3d_segments_with_codes(paths, zs=0, zdir='z'): return segments, codes_list -class Line3DCollection(LineCollection): +class Line3DCollection(LineCollection, Object3D): """ A collection of 3D lines. """ + def __init__(self, *args, **kwargs): + LineCollection.__init__(self, *args, **kwargs) + Object3D.__init__(self) + def set_sort_zpos(self, val): """Set the position to use for z-sorting.""" self._sort_zpos = val @@ -256,13 +273,14 @@ def line_collection_2d_to_3d(col, zs=0, zdir='z'): col.set_segments(segments3d) -class Patch3D(Patch): +class Patch3D(Patch, Object3D): """ 3D patch object. """ def __init__(self, *args, zs=(), zdir='z', **kwargs): Patch.__init__(self, *args, **kwargs) + Object3D.__init__(self) self.set_3d_properties(zs, zdir) def set_3d_properties(self, verts, zs=0, zdir='z'): @@ -287,13 +305,14 @@ def do_3d_projection(self, renderer): return min(vzs) -class PathPatch3D(Patch3D): +class PathPatch3D(Patch3D, Object3D): """ 3D PathPatch object. """ def __init__(self, path, *, zs=(), zdir='z', **kwargs): Patch.__init__(self, **kwargs) + Object3D.__init__(self) self.set_3d_properties(path, zs, zdir) def set_3d_properties(self, path, zs=0, zdir='z'): @@ -338,7 +357,7 @@ def pathpatch_2d_to_3d(pathpatch, z=0, zdir='z'): pathpatch.set_3d_properties(mpath, z, zdir) -class Patch3DCollection(PatchCollection): +class Patch3DCollection(PatchCollection, Object3D): """ A collection of 3D patches. """ @@ -360,7 +379,8 @@ def __init__(self, *args, zs=0, zdir='z', depthshade=True, **kwargs): This is typically desired in scatter plots. """ self._depthshade = depthshade - super().__init__(*args, **kwargs) + PatchCollection.__init__(self, *args, **kwargs) + Object3D.__init__(self) self.set_3d_properties(zs, zdir) def set_sort_zpos(self, val): @@ -404,7 +424,7 @@ def do_3d_projection(self, renderer): return np.nan -class Path3DCollection(PathCollection): +class Path3DCollection(PathCollection, Object3D): """ A collection of 3D paths. """ @@ -426,7 +446,8 @@ def __init__(self, *args, zs=0, zdir='z', depthshade=True, **kwargs): This is typically desired in scatter plots. """ self._depthshade = depthshade - super().__init__(*args, **kwargs) + PathCollection.__init__(self, *args, **kwargs) + Object3D.__init__(self) self.set_3d_properties(zs, zdir) def set_sort_zpos(self, val): @@ -493,7 +514,7 @@ def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True): col.set_3d_properties(zs, zdir) -class Poly3DCollection(PolyCollection): +class Poly3DCollection(PolyCollection, Object3D): """ A collection of 3D polygons. """ @@ -510,7 +531,8 @@ def __init__(self, verts, *args, zsort='average', **kwargs): Note that this class does a bit of magic with the _facecolors and _edgecolors properties. """ - super().__init__(verts, *args, **kwargs) + PolyCollection.__init__(self, verts, *args, **kwargs) + Object3D.__init__(self) self.set_zsort(zsort) self._codes3d = None diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 3aa24d5e6e02..5276364562aa 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -281,16 +281,23 @@ def draw(self, renderer): # Make sure they are drawn above the grids. zorder_offset = max(axis.get_zorder() for axis in self._get_axis_list()) + 1 - for i, col in enumerate( - sorted(self.collections, - key=lambda col: col.do_3d_projection(renderer), - reverse=True)): - col.zorder = zorder_offset + i - for i, patch in enumerate( - sorted(self.patches, - key=lambda patch: patch.do_3d_projection(renderer), - reverse=True)): - patch.zorder = zorder_offset + i + + # Group of objects to be dynamically reordered + select_children = list() + select_children.extend(self.collections) + select_children.extend(self.patches) + select_children.extend(self.lines) + select_children.extend(self.texts) + for i, h in enumerate( + sorted( + select_children, + key=lambda h: ( + h._sort_zpos if hasattr(h, '_sort_zpos') else np.inf, + -h.do_3d_projection(renderer) if hasattr(h, 'do_3d_projection') else -np.inf + ) + ) + ): + h.zorder = zorder_offset + i if self._axis3don: # Draw panes first