diff --git a/lib/mpl_toolkits/mplot3d/art3d.py b/lib/mpl_toolkits/mplot3d/art3d.py index 822ff79fcc24..a521263a62f2 100644 --- a/lib/mpl_toolkits/mplot3d/art3d.py +++ b/lib/mpl_toolkits/mplot3d/art3d.py @@ -302,8 +302,6 @@ def do_3d_projection(self, renderer=None): """ Project the points according to renderer matrix. """ - # see _update_scalarmappable docstring for why this must be here - _update_scalarmappable(self) xyslist = [proj3d.proj_trans_points(points, self.axes.M) for points in self._segments3d] segments_2d = [np.column_stack([xs, ys]) for xs, ys, zs in xyslist] @@ -448,16 +446,6 @@ def set_depthshade(self, depthshade): self._depthshade = depthshade self.stale = True - def set_facecolor(self, c): - # docstring inherited - super().set_facecolor(c) - self._facecolor3d = self.get_facecolor() - - def set_edgecolor(self, c): - # docstring inherited - super().set_edgecolor(c) - self._edgecolor3d = self.get_edgecolor() - def set_sort_zpos(self, val): """Set the position to use for z-sorting.""" self._sort_zpos = val @@ -474,27 +462,15 @@ def set_3d_properties(self, zs, zdir): xs = [] ys = [] self._offsets3d = juggle_axes(xs, ys, np.atleast_1d(zs), zdir) - self._facecolor3d = self.get_facecolor() - self._edgecolor3d = self.get_edgecolor() + self._vzs = None self.stale = True @_api.delete_parameter('3.4', 'renderer') def do_3d_projection(self, renderer=None): - # see _update_scalarmappable docstring for why this must be here - _update_scalarmappable(self) xs, ys, zs = self._offsets3d vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs, self.axes.M) - - fcs = (_zalpha(self._facecolor3d, vzs) if self._depthshade else - self._facecolor3d) - fcs = mcolors.to_rgba_array(fcs, self._alpha) - super().set_facecolor(fcs) - - ecs = (_zalpha(self._edgecolor3d, vzs) if self._depthshade else - self._edgecolor3d) - ecs = mcolors.to_rgba_array(ecs, self._alpha) - super().set_edgecolor(ecs) + self._vzs = vzs super().set_offsets(np.column_stack([vxs, vys])) if vzs.size > 0: @@ -502,6 +478,27 @@ def do_3d_projection(self, renderer=None): else: return np.nan + def _maybe_depth_shade_and_sort_colors(self, color_array): + color_array = ( + _zalpha(color_array, self._vzs) + if self._vzs is not None and self._depthshade + else color_array + ) + if len(color_array) > 1: + color_array = color_array[self._z_markers_idx] + return mcolors.to_rgba_array(color_array, self._alpha) + + def get_facecolor(self): + return self._maybe_depth_shade_and_sort_colors(super().get_facecolor()) + + def get_edgecolor(self): + # We need this check here to make sure we do not double-apply the depth + # based alpha shading when the edge color is "face" which means the + # edge colour should be identical to the face colour. + if cbook._str_equal(self._edgecolors, 'face'): + return self.get_facecolor() + return self._maybe_depth_shade_and_sort_colors(super().get_edgecolor()) + class Path3DCollection(PathCollection): """ @@ -525,9 +522,14 @@ def __init__(self, *args, zs=0, zdir='z', depthshade=True, **kwargs): This is typically desired in scatter plots. """ self._depthshade = depthshade + self._in_draw = False super().__init__(*args, **kwargs) self.set_3d_properties(zs, zdir) + def draw(self, renderer): + with cbook._setattr_cm(self, _in_draw=True): + super().draw(renderer) + def set_sort_zpos(self, val): """Set the position to use for z-sorting.""" self._sort_zpos = val @@ -544,12 +546,37 @@ def set_3d_properties(self, zs, zdir): xs = [] ys = [] self._offsets3d = juggle_axes(xs, ys, np.atleast_1d(zs), zdir) - self._facecolor3d = self.get_facecolor() - self._edgecolor3d = self.get_edgecolor() - self._sizes3d = self.get_sizes() - self._linewidth3d = self.get_linewidth() + # In the base draw methods we access the attributes directly which + # means we can not resolve the shuffling in the getter methods like + # we do for the edge and face colors. + # + # This means we need to carry around a cache of the unsorted sizes and + # widths (postfixed with 3d) and in `do_3d_projection` set the + # depth-sorted version of that data into the private state used by the + # base collection class in its draw method. + # + # Grab the current sizes and linewidths to preserve them. + self._sizes3d = self._sizes + self._linewidths3d = self._linewidths + xs, ys, zs = self._offsets3d + + # Sort the points based on z coordinates + # Performance optimization: Create a sorted index array and reorder + # points and point properties according to the index array + self._z_markers_idx = slice(-1) + self._vzs = None self.stale = True + def set_sizes(self, sizes, dpi=72.0): + super().set_sizes(sizes, dpi) + if not self._in_draw: + self._sizes3d = sizes + + def set_linewidth(self, lw): + super().set_linewidth(lw) + if not self._in_draw: + self._linewidth3d = lw + def get_depthshade(self): return self._depthshade @@ -566,142 +593,57 @@ def set_depthshade(self, depthshade): self._depthshade = depthshade self.stale = True - def set_facecolor(self, c): - # docstring inherited - super().set_facecolor(c) - self._facecolor3d = self.get_facecolor() - - def set_edgecolor(self, c): - # docstring inherited - super().set_edgecolor(c) - self._edgecolor3d = self.get_edgecolor() - - def set_sizes(self, sizes, dpi=72.0): - # docstring inherited - super().set_sizes(sizes, dpi=dpi) - self._sizes3d = self.get_sizes() - - def set_linewidth(self, lw): - # docstring inherited - super().set_linewidth(lw) - self._linewidth3d = self.get_linewidth() - @_api.delete_parameter('3.4', 'renderer') def do_3d_projection(self, renderer=None): - # see _update_scalarmappable docstring for why this must be here - _update_scalarmappable(self) xs, ys, zs = self._offsets3d vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs, self.axes.M) - - fcs = (_zalpha(self._facecolor3d, vzs) if self._depthshade else - self._facecolor3d) - ecs = (_zalpha(self._edgecolor3d, vzs) if self._depthshade else - self._edgecolor3d) - sizes = self._sizes3d - lws = self._linewidth3d - # Sort the points based on z coordinates # Performance optimization: Create a sorted index array and reorder # points and point properties according to the index array - z_markers_idx = np.argsort(vzs)[::-1] + z_markers_idx = self._z_markers_idx = np.argsort(vzs)[::-1] + self._vzs = vzs + + # we have to special case the sizes because of code in collections.py + # as the draw method does + # self.set_sizes(self._sizes, self.figure.dpi) + # so we can not rely on doing the sorting on the way out via get_* + + if len(self._sizes3d) > 1: + self._sizes = self._sizes3d[z_markers_idx] + + if len(self._linewidths3d) > 1: + self._linewidths = self._linewidths3d[z_markers_idx] # Re-order items vzs = vzs[z_markers_idx] vxs = vxs[z_markers_idx] vys = vys[z_markers_idx] - if len(fcs) > 1: - fcs = fcs[z_markers_idx] - if len(ecs) > 1: - ecs = ecs[z_markers_idx] - if len(sizes) > 1: - sizes = sizes[z_markers_idx] - if len(lws) > 1: - lws = lws[z_markers_idx] - vps = np.column_stack((vxs, vys)) - - fcs = mcolors.to_rgba_array(fcs, self._alpha) - ecs = mcolors.to_rgba_array(ecs, self._alpha) - - super().set_edgecolor(ecs) - super().set_facecolor(fcs) - super().set_sizes(sizes) - super().set_linewidth(lws) - - PathCollection.set_offsets(self, vps) - return np.min(vzs) if vzs.size else np.nan + PathCollection.set_offsets(self, np.column_stack((vxs, vys))) + return np.min(vzs) if vzs.size else np.nan -def _update_scalarmappable(sm): - """ - Update a 3D ScalarMappable. - - With ScalarMappable objects if the data, colormap, or norm are - changed, we need to update the computed colors. This is handled - by the base class method update_scalarmappable. This method works - by detecting if work needs to be done, and if so stashing it on - the ``self._facecolors`` attribute. - - With 3D collections we internally sort the components so that - things that should be "in front" are rendered later to simulate - having a z-buffer (in addition to doing the projections). This is - handled in the ``do_3d_projection`` methods which are called from the - draw method of the 3D Axes. These methods: - - - do the projection from 3D -> 2D - - internally sort based on depth - - stash the results of the above in the 2D analogs of state - - return the z-depth of the whole artist - - the last step is so that we can, at the Axes level, sort the children by - depth. - - The base `draw` method of the 2D artists unconditionally calls - update_scalarmappable and rely on the method's internal caching logic to - lazily evaluate. - - These things together mean you can have the sequence of events: - - - we create the artist, do the color mapping and stash the results - in a 3D specific state. - - change something about the ScalarMappable that marks it as in - need of an update (`ScalarMappable.changed` and friends). - - We call do_3d_projection and shuffle the stashed colors into the - 2D version of face colors - - the draw method calls the update_scalarmappable method which - overwrites our shuffled colors - - we get a render that is wrong - - if we re-render (either with a second save or implicitly via - tight_layout / constrained_layout / bbox_inches='tight' (ex via - inline's defaults)) we again shuffle the 3D colors - - because the CM is not marked as changed update_scalarmappable is - a no-op and we get a correct looking render. - - This function is an internal helper to: - - - sort out if we need to do the color mapping at all (has data!) - - sort out if update_scalarmappable is going to be a no-op - - copy the data over from the 2D -> 3D version - - This must be called first thing in do_3d_projection to make sure that - the correct colors get shuffled. + def _maybe_depth_shade_and_sort_colors(self, color_array): + color_array = ( + _zalpha(color_array, self._vzs) + if self._vzs is not None and self._depthshade + else color_array + ) + if len(color_array) > 1: + color_array = color_array[self._z_markers_idx] + return mcolors.to_rgba_array(color_array, self._alpha) - Parameters - ---------- - sm : ScalarMappable - The ScalarMappable to update and stash the 3D data from + def get_facecolor(self): + return self._maybe_depth_shade_and_sort_colors(super().get_facecolor()) - """ - if sm._A is None: - return - copy_state = sm._update_dict['array'] - ret = sm.update_scalarmappable() - if copy_state: - if sm._face_is_mapped: - sm._facecolor3d = sm._facecolors - elif sm._edge_is_mapped: # Should this be plain "if"? - sm._edgecolor3d = sm._edgecolors + def get_edgecolor(self): + # We need this check here to make sure we do not double-apply the depth + # based alpha shading when the edge color is "face" which means the + # edge colour should be identical to the face colour. + if cbook._str_equal(self._edgecolors, 'face'): + return self.get_facecolor() + return self._maybe_depth_shade_and_sort_colors(super().get_edgecolor()) def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True): @@ -727,6 +669,7 @@ def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True): elif isinstance(col, PatchCollection): col.__class__ = Patch3DCollection col._depthshade = depthshade + col._in_draw = False col.set_3d_properties(zs, zdir) @@ -841,9 +784,21 @@ def do_3d_projection(self, renderer=None): """ Perform the 3D projection for this object. """ - # see _update_scalarmappable docstring for why this must be here - _update_scalarmappable(self) - + if self._A is not None: + # force update of color mapping because we re-order them + # below. If we do not do this here, the 2D draw will call + # this, but we will never port the color mapped values back + # to the 3D versions. + # + # We hold the 3D versions in a fixed order (the order the user + # passed in) and sort the 2D version by view depth. + copy_state = self._update_dict['array'] + self.update_scalarmappable() + if copy_state: + if self._face_is_mapped: + self._facecolor3d = self._facecolors + if self._edge_is_mapped: + self._edgecolor3d = self._edgecolors txs, tys, tzs = proj3d._proj_transform_vec(self._vec, self.axes.M) xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices] diff --git a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/scatter_spiral.png b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/scatter_spiral.png new file mode 100644 index 000000000000..134e75e170cc Binary files /dev/null and b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/scatter_spiral.png differ diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index f55edaa4aa1f..65049f5f9835 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -1427,3 +1427,16 @@ def test_subfigure_simple(): sf = fig.subfigures(1, 2) ax = sf[0].add_subplot(1, 1, 1, projection='3d') ax = sf[1].add_subplot(1, 1, 1, projection='3d', label='other') + + +@image_comparison(baseline_images=['scatter_spiral.png'], + remove_text=True, + style='default') +def test_scatter_spiral(): + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + th = np.linspace(0, 2 * np.pi * 6, 256) + sc = ax.scatter(np.sin(th), np.cos(th), th, s=(1 + th * 5), c=th ** 2) + + # force at least 1 draw! + fig.canvas.draw()