Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit ee9500a

Browse files
authored
Merge pull request #19812 from tacaswell/fix_scatter3D
FIX: size and color rendering for Path3DCollection
2 parents 45e3d1f + f888285 commit ee9500a

File tree

3 files changed

+117
-149
lines changed

3 files changed

+117
-149
lines changed

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 104 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,6 @@ def do_3d_projection(self, renderer=None):
302302
"""
303303
Project the points according to renderer matrix.
304304
"""
305-
# see _update_scalarmappable docstring for why this must be here
306-
_update_scalarmappable(self)
307305
xyslist = [proj3d.proj_trans_points(points, self.axes.M)
308306
for points in self._segments3d]
309307
segments_2d = [np.column_stack([xs, ys]) for xs, ys, zs in xyslist]
@@ -448,16 +446,6 @@ def set_depthshade(self, depthshade):
448446
self._depthshade = depthshade
449447
self.stale = True
450448

451-
def set_facecolor(self, c):
452-
# docstring inherited
453-
super().set_facecolor(c)
454-
self._facecolor3d = self.get_facecolor()
455-
456-
def set_edgecolor(self, c):
457-
# docstring inherited
458-
super().set_edgecolor(c)
459-
self._edgecolor3d = self.get_edgecolor()
460-
461449
def set_sort_zpos(self, val):
462450
"""Set the position to use for z-sorting."""
463451
self._sort_zpos = val
@@ -474,34 +462,43 @@ def set_3d_properties(self, zs, zdir):
474462
xs = []
475463
ys = []
476464
self._offsets3d = juggle_axes(xs, ys, np.atleast_1d(zs), zdir)
477-
self._facecolor3d = self.get_facecolor()
478-
self._edgecolor3d = self.get_edgecolor()
465+
self._vzs = None
479466
self.stale = True
480467

481468
@_api.delete_parameter('3.4', 'renderer')
482469
def do_3d_projection(self, renderer=None):
483-
# see _update_scalarmappable docstring for why this must be here
484-
_update_scalarmappable(self)
485470
xs, ys, zs = self._offsets3d
486471
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
487472
self.axes.M)
488-
489-
fcs = (_zalpha(self._facecolor3d, vzs) if self._depthshade else
490-
self._facecolor3d)
491-
fcs = mcolors.to_rgba_array(fcs, self._alpha)
492-
super().set_facecolor(fcs)
493-
494-
ecs = (_zalpha(self._edgecolor3d, vzs) if self._depthshade else
495-
self._edgecolor3d)
496-
ecs = mcolors.to_rgba_array(ecs, self._alpha)
497-
super().set_edgecolor(ecs)
473+
self._vzs = vzs
498474
super().set_offsets(np.column_stack([vxs, vys]))
499475

500476
if vzs.size > 0:
501477
return min(vzs)
502478
else:
503479
return np.nan
504480

481+
def _maybe_depth_shade_and_sort_colors(self, color_array):
482+
color_array = (
483+
_zalpha(color_array, self._vzs)
484+
if self._vzs is not None and self._depthshade
485+
else color_array
486+
)
487+
if len(color_array) > 1:
488+
color_array = color_array[self._z_markers_idx]
489+
return mcolors.to_rgba_array(color_array, self._alpha)
490+
491+
def get_facecolor(self):
492+
return self._maybe_depth_shade_and_sort_colors(super().get_facecolor())
493+
494+
def get_edgecolor(self):
495+
# We need this check here to make sure we do not double-apply the depth
496+
# based alpha shading when the edge color is "face" which means the
497+
# edge colour should be identical to the face colour.
498+
if cbook._str_equal(self._edgecolors, 'face'):
499+
return self.get_facecolor()
500+
return self._maybe_depth_shade_and_sort_colors(super().get_edgecolor())
501+
505502

506503
class Path3DCollection(PathCollection):
507504
"""
@@ -525,9 +522,14 @@ def __init__(self, *args, zs=0, zdir='z', depthshade=True, **kwargs):
525522
This is typically desired in scatter plots.
526523
"""
527524
self._depthshade = depthshade
525+
self._in_draw = False
528526
super().__init__(*args, **kwargs)
529527
self.set_3d_properties(zs, zdir)
530528

529+
def draw(self, renderer):
530+
with cbook._setattr_cm(self, _in_draw=True):
531+
super().draw(renderer)
532+
531533
def set_sort_zpos(self, val):
532534
"""Set the position to use for z-sorting."""
533535
self._sort_zpos = val
@@ -544,12 +546,37 @@ def set_3d_properties(self, zs, zdir):
544546
xs = []
545547
ys = []
546548
self._offsets3d = juggle_axes(xs, ys, np.atleast_1d(zs), zdir)
547-
self._facecolor3d = self.get_facecolor()
548-
self._edgecolor3d = self.get_edgecolor()
549-
self._sizes3d = self.get_sizes()
550-
self._linewidth3d = self.get_linewidth()
549+
# In the base draw methods we access the attributes directly which
550+
# means we can not resolve the shuffling in the getter methods like
551+
# we do for the edge and face colors.
552+
#
553+
# This means we need to carry around a cache of the unsorted sizes and
554+
# widths (postfixed with 3d) and in `do_3d_projection` set the
555+
# depth-sorted version of that data into the private state used by the
556+
# base collection class in its draw method.
557+
#
558+
# grab the current sizes and linewidths to preserve them
559+
self._sizes3d = self._sizes
560+
self._linewidths3d = self._linewidths
561+
xs, ys, zs = self._offsets3d
562+
563+
# Sort the points based on z coordinates
564+
# Performance optimization: Create a sorted index array and reorder
565+
# points and point properties according to the index array
566+
self._z_markers_idx = slice(-1)
567+
self._vzs = None
551568
self.stale = True
552569

570+
def set_sizes(self, sizes, dpi=72.0):
571+
super().set_sizes(sizes, dpi)
572+
if not self._in_draw:
573+
self._sizes3d = sizes
574+
575+
def set_linewidth(self, lw):
576+
super().set_linewidth(lw)
577+
if not self._in_draw:
578+
self._linewidth3d = lw
579+
553580
def get_depthshade(self):
554581
return self._depthshade
555582

@@ -566,140 +593,57 @@ def set_depthshade(self, depthshade):
566593
self._depthshade = depthshade
567594
self.stale = True
568595

569-
def set_facecolor(self, c):
570-
# docstring inherited
571-
super().set_facecolor(c)
572-
self._facecolor3d = self.get_facecolor()
573-
574-
def set_edgecolor(self, c):
575-
# docstring inherited
576-
super().set_edgecolor(c)
577-
self._edgecolor3d = self.get_edgecolor()
578-
579-
def set_sizes(self, sizes, dpi=72.0):
580-
# docstring inherited
581-
super().set_sizes(sizes, dpi=dpi)
582-
self._sizes3d = self.get_sizes()
583-
584-
def set_linewidth(self, lw):
585-
# docstring inherited
586-
super().set_linewidth(lw)
587-
self._linewidth3d = self.get_linewidth()
588-
589596
@_api.delete_parameter('3.4', 'renderer')
590597
def do_3d_projection(self, renderer=None):
591-
# see _update_scalarmappable docstring for why this must be here
592-
_update_scalarmappable(self)
593598
xs, ys, zs = self._offsets3d
594599
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
595600
self.axes.M)
596-
597-
fcs = (_zalpha(self._facecolor3d, vzs) if self._depthshade else
598-
self._facecolor3d)
599-
ecs = (_zalpha(self._edgecolor3d, vzs) if self._depthshade else
600-
self._edgecolor3d)
601-
sizes = self._sizes3d
602-
lws = self._linewidth3d
603-
604601
# Sort the points based on z coordinates
605602
# Performance optimization: Create a sorted index array and reorder
606603
# points and point properties according to the index array
607-
z_markers_idx = np.argsort(vzs)[::-1]
604+
z_markers_idx = self._z_markers_idx = np.argsort(vzs)[::-1]
605+
self._vzs = vzs
606+
607+
# we have to special case the sizes because of code in collections.py
608+
# as the draw method does
609+
# self.set_sizes(self._sizes, self.figure.dpi)
610+
# so we can not rely on doing the sorting on the way out via get_*
611+
612+
if len(self._sizes3d) > 1:
613+
self._sizes = self._sizes3d[z_markers_idx]
614+
615+
if len(self._linewidths3d) > 1:
616+
self._linewidths = self._linewidths3d[z_markers_idx]
608617

609618
# Re-order items
610619
vzs = vzs[z_markers_idx]
611620
vxs = vxs[z_markers_idx]
612621
vys = vys[z_markers_idx]
613-
if len(fcs) > 1:
614-
fcs = fcs[z_markers_idx]
615-
if len(ecs) > 1:
616-
ecs = ecs[z_markers_idx]
617-
if len(sizes) > 1:
618-
sizes = sizes[z_markers_idx]
619-
if len(lws) > 1:
620-
lws = lws[z_markers_idx]
621-
vps = np.column_stack((vxs, vys))
622-
623-
fcs = mcolors.to_rgba_array(fcs, self._alpha)
624-
ecs = mcolors.to_rgba_array(ecs, self._alpha)
625-
626-
super().set_edgecolor(ecs)
627-
super().set_facecolor(fcs)
628-
super().set_sizes(sizes)
629-
super().set_linewidth(lws)
630-
631-
PathCollection.set_offsets(self, vps)
632622

633-
return np.min(vzs) if vzs.size else np.nan
623+
PathCollection.set_offsets(self, np.column_stack((vxs, vys)))
634624

625+
return np.min(vzs) if vzs.size else np.nan
635626

636-
def _update_scalarmappable(sm):
637-
"""
638-
Update a 3D ScalarMappable.
639-
640-
With ScalarMappable objects if the data, colormap, or norm are
641-
changed, we need to update the computed colors. This is handled
642-
by the base class method update_scalarmappable. This method works
643-
by detecting if work needs to be done, and if so stashing it on
644-
the ``self._facecolors`` attribute.
645-
646-
With 3D collections we internally sort the components so that
647-
things that should be "in front" are rendered later to simulate
648-
having a z-buffer (in addition to doing the projections). This is
649-
handled in the ``do_3d_projection`` methods which are called from the
650-
draw method of the 3D Axes. These methods:
651-
652-
- do the projection from 3D -> 2D
653-
- internally sort based on depth
654-
- stash the results of the above in the 2D analogs of state
655-
- return the z-depth of the whole artist
656-
657-
the last step is so that we can, at the Axes level, sort the children by
658-
depth.
659-
660-
The base `draw` method of the 2D artists unconditionally calls
661-
update_scalarmappable and rely on the method's internal caching logic to
662-
lazily evaluate.
663-
664-
These things together mean you can have the sequence of events:
665-
666-
- we create the artist, do the color mapping and stash the results
667-
in a 3D specific state.
668-
- change something about the ScalarMappable that marks it as in
669-
need of an update (`ScalarMappable.changed` and friends).
670-
- We call do_3d_projection and shuffle the stashed colors into the
671-
2D version of face colors
672-
- the draw method calls the update_scalarmappable method which
673-
overwrites our shuffled colors
674-
- we get a render that is wrong
675-
- if we re-render (either with a second save or implicitly via
676-
tight_layout / constrained_layout / bbox_inches='tight' (ex via
677-
inline's defaults)) we again shuffle the 3D colors
678-
- because the CM is not marked as changed update_scalarmappable is
679-
a no-op and we get a correct looking render.
680-
681-
This function is an internal helper to:
682-
683-
- sort out if we need to do the color mapping at all (has data!)
684-
- sort out if update_scalarmappable is going to be a no-op
685-
- copy the data over from the 2D -> 3D version
686-
687-
This must be called first thing in do_3d_projection to make sure that
688-
the correct colors get shuffled.
627+
def _maybe_depth_shade_and_sort_colors(self, color_array):
628+
color_array = (
629+
_zalpha(color_array, self._vzs)
630+
if self._vzs is not None and self._depthshade
631+
else color_array
632+
)
633+
if len(color_array) > 1:
634+
color_array = color_array[self._z_markers_idx]
635+
return mcolors.to_rgba_array(color_array, self._alpha)
689636

690-
Parameters
691-
----------
692-
sm : ScalarMappable
693-
The ScalarMappable to update and stash the 3D data from
637+
def get_facecolor(self):
638+
return self._maybe_depth_shade_and_sort_colors(super().get_facecolor())
694639

695-
"""
696-
if sm._A is None:
697-
return
698-
sm.update_scalarmappable()
699-
if sm._face_is_mapped:
700-
sm._facecolor3d = sm._facecolors
701-
elif sm._edge_is_mapped: # Should this be plain "if"?
702-
sm._edgecolor3d = sm._edgecolors
640+
def get_edgecolor(self):
641+
# We need this check here to make sure we do not double-apply the depth
642+
# based alpha shading when the edge color is "face" which means the
643+
# edge colour should be identical to the face colour.
644+
if cbook._str_equal(self._edgecolors, 'face'):
645+
return self.get_facecolor()
646+
return self._maybe_depth_shade_and_sort_colors(super().get_edgecolor())
703647

704648

705649
def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True):
@@ -725,6 +669,7 @@ def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True):
725669
elif isinstance(col, PatchCollection):
726670
col.__class__ = Patch3DCollection
727671
col._depthshade = depthshade
672+
col._in_draw = False
728673
col.set_3d_properties(zs, zdir)
729674

730675

@@ -839,9 +784,19 @@ def do_3d_projection(self, renderer=None):
839784
"""
840785
Perform the 3D projection for this object.
841786
"""
842-
# see _update_scalarmappable docstring for why this must be here
843-
_update_scalarmappable(self)
844-
787+
if self._A is not None:
788+
# force update of color mapping because we re-order them
789+
# below. If we do not do this here, the 2D draw will call
790+
# this, but we will never port the color mapped values back
791+
# to the 3D versions.
792+
#
793+
# We hold the 3D versions in a fixed order (the order the user
794+
# passed in) and sort the 2D version by view depth.
795+
self.update_scalarmappable()
796+
if self._face_is_mapped:
797+
self._facecolor3d = self._facecolors
798+
if self._edge_is_mapped:
799+
self._edgecolor3d = self._edgecolors
845800
txs, tys, tzs = proj3d._proj_transform_vec(self._vec, self.axes.M)
846801
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]
847802

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,3 +1510,16 @@ def test_computed_zorder():
15101510
zorder=4)
15111511
ax.view_init(azim=-20, elev=20)
15121512
ax.axis('off')
1513+
1514+
1515+
@image_comparison(baseline_images=['scatter_spiral.png'],
1516+
remove_text=True,
1517+
style='default')
1518+
def test_scatter_spiral():
1519+
fig = plt.figure()
1520+
ax = fig.add_subplot(projection='3d')
1521+
th = np.linspace(0, 2 * np.pi * 6, 256)
1522+
sc = ax.scatter(np.sin(th), np.cos(th), th, s=(1 + th * 5), c=th ** 2)
1523+
1524+
# force at least 1 draw!
1525+
fig.canvas.draw()

0 commit comments

Comments
 (0)