Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 647a966

Browse files
authored
Merge pull request #17851 from thangleiter/fix_add_collection3d_issues
Fix Axes3D.add_collection3d issues
2 parents 7b56226 + f24663f commit 647a966

File tree

5 files changed

+62
-2
lines changed

5 files changed

+62
-2
lines changed

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,12 @@ def _path_to_3d_segment(path, zs=0, zdir='z'):
207207
def _paths_to_3d_segments(paths, zs=0, zdir='z'):
208208
"""Convert paths from a collection object to 3D segments."""
209209

210-
zs = np.broadcast_to(zs, len(paths))
210+
if not np.iterable(zs):
211+
zs = np.broadcast_to(zs, len(paths))
212+
else:
213+
if len(zs) != len(paths):
214+
raise ValueError('Number of z-coordinates does not match paths.')
215+
211216
segs = [_path_to_3d_segment(path, pathz, zdir)
212217
for path, pathz in zip(paths, zs)]
213218
return segs

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2250,7 +2250,8 @@ def add_collection3d(self, col, zs=0, zdir='z'):
22502250
art3d.patch_collection_2d_to_3d(col, zs=zs, zdir=zdir)
22512251
col.set_sort_zpos(zsortval)
22522252

2253-
super().add_collection(col)
2253+
collection = super().add_collection(col)
2254+
return collection
22542255

22552256
def scatter(self, xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True,
22562257
*args, **kwargs):

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,60 @@ def test_poly3dcollection_alpha():
543543
ax.add_collection3d(c2)
544544

545545

546+
@mpl3d_image_comparison(['add_collection3d_zs_array.png'])
547+
def test_add_collection3d_zs_array():
548+
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
549+
z = np.linspace(-2, 2, 100)
550+
r = z**2 + 1
551+
x = r * np.sin(theta)
552+
y = r * np.cos(theta)
553+
554+
points = np.column_stack([x, y, z]).reshape(-1, 1, 3)
555+
segments = np.concatenate([points[:-1], points[1:]], axis=1)
556+
557+
fig = plt.figure()
558+
ax = fig.gca(projection='3d')
559+
560+
norm = plt.Normalize(0, 2*np.pi)
561+
# 2D LineCollection from x & y values
562+
lc = LineCollection(segments[:, :, :2], cmap='twilight', norm=norm)
563+
lc.set_array(np.mod(theta, 2*np.pi))
564+
# Add 2D collection at z values to ax
565+
line = ax.add_collection3d(lc, zs=segments[:, :, 2])
566+
567+
assert line is not None
568+
569+
ax.set_xlim(-5, 5)
570+
ax.set_ylim(-4, 6)
571+
ax.set_zlim(-2, 2)
572+
573+
574+
@mpl3d_image_comparison(['add_collection3d_zs_scalar.png'])
575+
def test_add_collection3d_zs_scalar():
576+
theta = np.linspace(0, 2 * np.pi, 100)
577+
z = 1
578+
r = z**2 + 1
579+
x = r * np.sin(theta)
580+
y = r * np.cos(theta)
581+
582+
points = np.column_stack([x, y]).reshape(-1, 1, 2)
583+
segments = np.concatenate([points[:-1], points[1:]], axis=1)
584+
585+
fig = plt.figure()
586+
ax = fig.gca(projection='3d')
587+
588+
norm = plt.Normalize(0, 2*np.pi)
589+
lc = LineCollection(segments, cmap='twilight', norm=norm)
590+
lc.set_array(theta)
591+
line = ax.add_collection3d(lc, zs=z)
592+
593+
assert line is not None
594+
595+
ax.set_xlim(-5, 5)
596+
ax.set_ylim(-4, 6)
597+
ax.set_zlim(0, 2)
598+
599+
546600
@mpl3d_image_comparison(['axes3d_labelpad.png'], remove_text=False)
547601
def test_axes3d_labelpad():
548602
fig = plt.figure()

0 commit comments

Comments
 (0)