Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f679026

Browse files
committed
Backport PR #18929: FIX: make sure scalarmappable updates are handled correctly in 3D
1 parent fe81185 commit f679026

File tree

2 files changed

+113
-16
lines changed

2 files changed

+113
-16
lines changed

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def do_3d_projection(self, renderer):
264264
"""
265265
Project the points according to renderer matrix.
266266
"""
267+
# see _update_scalarmappable docstring for why this must be here
268+
_update_scalarmappable(self)
267269
xyslist = [
268270
proj3d.proj_trans_points(points, renderer.M) for points in
269271
self._segments3d]
@@ -418,6 +420,8 @@ def set_3d_properties(self, zs, zdir):
418420
self.stale = True
419421

420422
def do_3d_projection(self, renderer):
423+
# see _update_scalarmappable docstring for why this must be here
424+
_update_scalarmappable(self)
421425
xs, ys, zs = self._offsets3d
422426
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs, renderer.M)
423427

@@ -486,6 +490,8 @@ def set_3d_properties(self, zs, zdir):
486490
self.stale = True
487491

488492
def do_3d_projection(self, renderer):
493+
# see _update_scalarmappable docstring for why this must be here
494+
_update_scalarmappable(self)
489495
xs, ys, zs = self._offsets3d
490496
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs, renderer.M)
491497

@@ -528,6 +534,77 @@ def do_3d_projection(self, renderer):
528534
return np.min(vzs) if vzs.size else np.nan
529535

530536

537+
def _update_scalarmappable(sm):
538+
"""
539+
Update a 3D ScalarMappable.
540+
541+
With ScalarMappable objects if the data, colormap, or norm are
542+
changed, we need to update the computed colors. This is handled
543+
by the base class method update_scalarmappable. This method works
544+
by, detecting if work needs to be done, and if so stashing it on
545+
the ``self._facecolors`` attribute.
546+
547+
With 3D collections we internally sort the components so that
548+
things that should be "in front" are rendered later to simulate
549+
having a z-buffer (in addition to doing the projections). This is
550+
handled in the ``do_3d_projection`` methods which are called from the
551+
draw method of the 3D Axes. These methods:
552+
553+
- do the projection from 3D -> 2D
554+
- internally sort based on depth
555+
- stash the results of the above in the 2D analogs of state
556+
- return the z-depth of the whole artist
557+
558+
the last step is so that we can, at the Axes level, sort the children by
559+
depth.
560+
561+
The base `draw` method of the 2D artists unconditionally calls
562+
update_scalarmappable and rely on the method's internal caching logic to
563+
lazily evaluate.
564+
565+
These things together mean you can have the sequence of events:
566+
567+
- we create the artist, do the color mapping and stash the results
568+
in a 3D specific state.
569+
- change something about the ScalarMappable that marks it as in
570+
need of an update (`ScalarMappable.changed` and friends).
571+
- We call do_3d_projection and shuffle the stashed colors into the
572+
2D version of face colors
573+
- the draw method calls the update_scalarmappable method which
574+
overwrites our shuffled colors
575+
- we get a render that is wrong
576+
- if we re-render (either with a second save or implicitly via
577+
tight_layout / constrained_layout / bbox_inches='tight' (ex via
578+
inline's defaults)) we again shuffle the 3D colors
579+
- because the CM is not marked as changed update_scalarmappable is
580+
a no-op and we get a correct looking render.
581+
582+
This function is an internal helper to:
583+
584+
- sort out if we need to do the color mapping at all (has data!)
585+
- sort out if update_scalarmappable is going to be a no-op
586+
- copy the data over from the 2D -> 3D version
587+
588+
This must be called first thing in do_3d_projection to make sure that
589+
the correct colors get shuffled.
590+
591+
Parameters
592+
----------
593+
sm : ScalarMappable
594+
The ScalarMappable to update and stash the 3D data from
595+
596+
"""
597+
if sm._A is None:
598+
return
599+
copy_state = sm._update_dict['array']
600+
ret = sm.update_scalarmappable()
601+
if copy_state:
602+
if sm._is_filled:
603+
sm._facecolor3d = sm._facecolors
604+
elif sm._is_stroked:
605+
sm._edgecolor3d = sm._edgecolors
606+
607+
531608
def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True):
532609
"""
533610
Convert a :class:`~matplotlib.collections.PatchCollection` into a
@@ -650,8 +727,8 @@ def set_3d_properties(self):
650727
self.update_scalarmappable()
651728
self._sort_zpos = None
652729
self.set_zsort('average')
653-
self._facecolors3d = PolyCollection.get_facecolor(self)
654-
self._edgecolors3d = PolyCollection.get_edgecolor(self)
730+
self._facecolor3d = PolyCollection.get_facecolor(self)
731+
self._edgecolor3d = PolyCollection.get_edgecolor(self)
655732
self._alpha3d = PolyCollection.get_alpha(self)
656733
self.stale = True
657734

@@ -664,17 +741,15 @@ def do_3d_projection(self, renderer):
664741
"""
665742
Perform the 3D projection for this object.
666743
"""
667-
# FIXME: This may no longer be needed?
668-
if self._A is not None:
669-
self.update_scalarmappable()
670-
self._facecolors3d = self._facecolors
744+
# see _update_scalarmappable docstring for why this must be here
745+
_update_scalarmappable(self)
671746

672747
txs, tys, tzs = proj3d._proj_transform_vec(self._vec, renderer.M)
673748
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]
674749

675750
# This extra fuss is to re-order face / edge colors
676-
cface = self._facecolors3d
677-
cedge = self._edgecolors3d
751+
cface = self._facecolor3d
752+
cedge = self._edgecolor3d
678753
if len(cface) != len(xyzlist):
679754
cface = cface.repeat(len(xyzlist), axis=0)
680755
if len(cedge) != len(xyzlist):
@@ -699,8 +774,8 @@ def do_3d_projection(self, renderer):
699774
else:
700775
PolyCollection.set_verts(self, segments_2d, self._closed)
701776

702-
if len(self._edgecolors3d) != len(cface):
703-
self._edgecolors2d = self._edgecolors3d
777+
if len(self._edgecolor3d) != len(cface):
778+
self._edgecolors2d = self._edgecolor3d
704779

705780
# Return zorder value
706781
if self._sort_zpos is not None:
@@ -717,23 +792,23 @@ def do_3d_projection(self, renderer):
717792

718793
def set_facecolor(self, colors):
719794
PolyCollection.set_facecolor(self, colors)
720-
self._facecolors3d = PolyCollection.get_facecolor(self)
795+
self._facecolor3d = PolyCollection.get_facecolor(self)
721796

722797
def set_edgecolor(self, colors):
723798
PolyCollection.set_edgecolor(self, colors)
724-
self._edgecolors3d = PolyCollection.get_edgecolor(self)
799+
self._edgecolor3d = PolyCollection.get_edgecolor(self)
725800

726801
def set_alpha(self, alpha):
727802
# docstring inherited
728803
artist.Artist.set_alpha(self, alpha)
729804
try:
730-
self._facecolors3d = mcolors.to_rgba_array(
731-
self._facecolors3d, self._alpha)
805+
self._facecolor3d = mcolors.to_rgba_array(
806+
self._facecolor3d, self._alpha)
732807
except (AttributeError, TypeError, IndexError):
733808
pass
734809
try:
735810
self._edgecolors = mcolors.to_rgba_array(
736-
self._edgecolors3d, self._alpha)
811+
self._edgecolor3d, self._alpha)
737812
except (AttributeError, TypeError, IndexError):
738813
pass
739814
self.stale = True

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_bar3d_lightsource():
106106
# the top facecolors compared to the default, and that those colors are
107107
# precisely the colors from the colormap, due to the illumination parallel
108108
# to the z-axis.
109-
np.testing.assert_array_equal(color, collection._facecolors3d[1::6])
109+
np.testing.assert_array_equal(color, collection._facecolor3d[1::6])
110110

111111

112112
@mpl3d_image_comparison(['contour3d.png'])
@@ -1087,3 +1087,25 @@ def test_colorbar_pos():
10871087
fig.canvas.draw()
10881088
# check that actually on the bottom
10891089
assert cbar.ax.get_position().extents[1] < 0.2
1090+
1091+
1092+
@pytest.mark.style('default')
1093+
@check_figures_equal(extensions=["png"])
1094+
def test_scalarmap_update(fig_test, fig_ref):
1095+
1096+
x, y, z = np.array((list(itertools.product(*[np.arange(0, 5, 1),
1097+
np.arange(0, 5, 1),
1098+
np.arange(0, 5, 1)])))).T
1099+
c = x + y
1100+
1101+
# test
1102+
ax_test = fig_test.add_subplot(111, projection='3d')
1103+
sc_test = ax_test.scatter(x, y, z, c=c, s=40, cmap='viridis')
1104+
# force a draw
1105+
fig_test.canvas.draw()
1106+
# mark it as "stale"
1107+
sc_test.changed()
1108+
1109+
# ref
1110+
ax_ref = fig_ref.add_subplot(111, projection='3d')
1111+
sc_ref = ax_ref.scatter(x, y, z, c=c, s=40, cmap='viridis')

0 commit comments

Comments
 (0)