Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit fa54783

Browse files
authored
Merge pull request #18929 from tacaswell/fix_3D_scatter
FIX: make sure scalarmappable updates are handled correctly in 3D
2 parents 02af61b + 2bbe60a commit fa54783

File tree

2 files changed

+113
-16
lines changed

2 files changed

+113
-16
lines changed

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,8 @@ def do_3d_projection(self, renderer=None):
301301
"""
302302
Project the points according to renderer matrix.
303303
"""
304+
# see _update_scalarmappable docstring for why this must be here
305+
_update_scalarmappable(self)
304306
xyslist = [proj3d.proj_trans_points(points, self.axes.M)
305307
for points in self._segments3d]
306308
segments_2d = [np.column_stack([xs, ys]) for xs, ys, zs in xyslist]
@@ -486,6 +488,8 @@ def set_3d_properties(self, zs, zdir):
486488

487489
@cbook._delete_parameter('3.4', 'renderer')
488490
def do_3d_projection(self, renderer=None):
491+
# see _update_scalarmappable docstring for why this must be here
492+
_update_scalarmappable(self)
489493
xs, ys, zs = self._offsets3d
490494
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
491495
self.axes.M)
@@ -592,6 +596,8 @@ def set_linewidth(self, lw):
592596

593597
@cbook._delete_parameter('3.4', 'renderer')
594598
def do_3d_projection(self, renderer=None):
599+
# see _update_scalarmappable docstring for why this must be here
600+
_update_scalarmappable(self)
595601
xs, ys, zs = self._offsets3d
596602
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
597603
self.axes.M)
@@ -635,6 +641,77 @@ def do_3d_projection(self, renderer=None):
635641
return np.min(vzs) if vzs.size else np.nan
636642

637643

644+
def _update_scalarmappable(sm):
645+
"""
646+
Update a 3D ScalarMappable.
647+
648+
With ScalarMappable objects if the data, colormap, or norm are
649+
changed, we need to update the computed colors. This is handled
650+
by the base class method update_scalarmappable. This method works
651+
by, detecting if work needs to be done, and if so stashing it on
652+
the ``self._facecolors`` attribute.
653+
654+
With 3D collections we internally sort the components so that
655+
things that should be "in front" are rendered later to simulate
656+
having a z-buffer (in addition to doing the projections). This is
657+
handled in the ``do_3d_projection`` methods which are called from the
658+
draw method of the 3D Axes. These methods:
659+
660+
- do the projection from 3D -> 2D
661+
- internally sort based on depth
662+
- stash the results of the above in the 2D analogs of state
663+
- return the z-depth of the whole artist
664+
665+
the last step is so that we can, at the Axes level, sort the children by
666+
depth.
667+
668+
The base `draw` method of the 2D artists unconditionally calls
669+
update_scalarmappable and rely on the method's internal caching logic to
670+
lazily evaluate.
671+
672+
These things together mean you can have the sequence of events:
673+
674+
- we create the artist, do the color mapping and stash the results
675+
in a 3D specific state.
676+
- change something about the ScalarMappable that marks it as in
677+
need of an update (`ScalarMappable.changed` and friends).
678+
- We call do_3d_projection and shuffle the stashed colors into the
679+
2D version of face colors
680+
- the draw method calls the update_scalarmappable method which
681+
overwrites our shuffled colors
682+
- we get a render that is wrong
683+
- if we re-render (either with a second save or implicitly via
684+
tight_layout / constrained_layout / bbox_inches='tight' (ex via
685+
inline's defaults)) we again shuffle the 3D colors
686+
- because the CM is not marked as changed update_scalarmappable is
687+
a no-op and we get a correct looking render.
688+
689+
This function is an internal helper to:
690+
691+
- sort out if we need to do the color mapping at all (has data!)
692+
- sort out if update_scalarmappable is going to be a no-op
693+
- copy the data over from the 2D -> 3D version
694+
695+
This must be called first thing in do_3d_projection to make sure that
696+
the correct colors get shuffled.
697+
698+
Parameters
699+
----------
700+
sm : ScalarMappable
701+
The ScalarMappable to update and stash the 3D data from
702+
703+
"""
704+
if sm._A is None:
705+
return
706+
copy_state = sm._update_dict['array']
707+
ret = sm.update_scalarmappable()
708+
if copy_state:
709+
if sm._is_filled:
710+
sm._facecolor3d = sm._facecolors
711+
elif sm._is_stroked:
712+
sm._edgecolor3d = sm._edgecolors
713+
714+
638715
def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True):
639716
"""
640717
Convert a :class:`~matplotlib.collections.PatchCollection` into a
@@ -757,8 +834,8 @@ def set_3d_properties(self):
757834
self.update_scalarmappable()
758835
self._sort_zpos = None
759836
self.set_zsort('average')
760-
self._facecolors3d = PolyCollection.get_facecolor(self)
761-
self._edgecolors3d = PolyCollection.get_edgecolor(self)
837+
self._facecolor3d = PolyCollection.get_facecolor(self)
838+
self._edgecolor3d = PolyCollection.get_edgecolor(self)
762839
self._alpha3d = PolyCollection.get_alpha(self)
763840
self.stale = True
764841

@@ -772,17 +849,15 @@ def do_3d_projection(self, renderer=None):
772849
"""
773850
Perform the 3D projection for this object.
774851
"""
775-
# FIXME: This may no longer be needed?
776-
if self._A is not None:
777-
self.update_scalarmappable()
778-
self._facecolors3d = self._facecolors
852+
# see _update_scalarmappable docstring for why this must be here
853+
_update_scalarmappable(self)
779854

780855
txs, tys, tzs = proj3d._proj_transform_vec(self._vec, self.axes.M)
781856
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]
782857

783858
# This extra fuss is to re-order face / edge colors
784-
cface = self._facecolors3d
785-
cedge = self._edgecolors3d
859+
cface = self._facecolor3d
860+
cedge = self._edgecolor3d
786861
if len(cface) != len(xyzlist):
787862
cface = cface.repeat(len(xyzlist), axis=0)
788863
if len(cedge) != len(xyzlist):
@@ -807,8 +882,8 @@ def do_3d_projection(self, renderer=None):
807882
else:
808883
PolyCollection.set_verts(self, segments_2d, self._closed)
809884

810-
if len(self._edgecolors3d) != len(cface):
811-
self._edgecolors2d = self._edgecolors3d
885+
if len(self._edgecolor3d) != len(cface):
886+
self._edgecolors2d = self._edgecolor3d
812887

813888
# Return zorder value
814889
if self._sort_zpos is not None:
@@ -826,24 +901,24 @@ def do_3d_projection(self, renderer=None):
826901
def set_facecolor(self, colors):
827902
# docstring inherited
828903
super().set_facecolor(colors)
829-
self._facecolors3d = PolyCollection.get_facecolor(self)
904+
self._facecolor3d = PolyCollection.get_facecolor(self)
830905

831906
def set_edgecolor(self, colors):
832907
# docstring inherited
833908
super().set_edgecolor(colors)
834-
self._edgecolors3d = PolyCollection.get_edgecolor(self)
909+
self._edgecolor3d = PolyCollection.get_edgecolor(self)
835910

836911
def set_alpha(self, alpha):
837912
# docstring inherited
838913
artist.Artist.set_alpha(self, alpha)
839914
try:
840-
self._facecolors3d = mcolors.to_rgba_array(
841-
self._facecolors3d, self._alpha)
915+
self._facecolor3d = mcolors.to_rgba_array(
916+
self._facecolor3d, self._alpha)
842917
except (AttributeError, TypeError, IndexError):
843918
pass
844919
try:
845920
self._edgecolors = mcolors.to_rgba_array(
846-
self._edgecolors3d, self._alpha)
921+
self._edgecolor3d, self._alpha)
847922
except (AttributeError, TypeError, IndexError):
848923
pass
849924
self.stale = True

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_bar3d_lightsource():
108108
# the top facecolors compared to the default, and that those colors are
109109
# precisely the colors from the colormap, due to the illumination parallel
110110
# to the z-axis.
111-
np.testing.assert_array_equal(color, collection._facecolors3d[1::6])
111+
np.testing.assert_array_equal(color, collection._facecolor3d[1::6])
112112

113113

114114
@mpl3d_image_comparison(['contour3d.png'])
@@ -1302,3 +1302,25 @@ def convert_lim(dmin, dmax):
13021302
assert x_center != pytest.approx(x_center0)
13031303
assert y_center != pytest.approx(y_center0)
13041304
assert z_center != pytest.approx(z_center0)
1305+
1306+
1307+
@pytest.mark.style('default')
1308+
@check_figures_equal(extensions=["png"])
1309+
def test_scalarmap_update(fig_test, fig_ref):
1310+
1311+
x, y, z = np.array((list(itertools.product(*[np.arange(0, 5, 1),
1312+
np.arange(0, 5, 1),
1313+
np.arange(0, 5, 1)])))).T
1314+
c = x + y
1315+
1316+
# test
1317+
ax_test = fig_test.add_subplot(111, projection='3d')
1318+
sc_test = ax_test.scatter(x, y, z, c=c, s=40, cmap='viridis')
1319+
# force a draw
1320+
fig_test.canvas.draw()
1321+
# mark it as "stale"
1322+
sc_test.changed()
1323+
1324+
# ref
1325+
ax_ref = fig_ref.add_subplot(111, projection='3d')
1326+
sc_ref = ax_ref.scatter(x, y, z, c=c, s=40, cmap='viridis')

0 commit comments

Comments
 (0)