Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b4296f4

Browse files
committed
More broadcasting in mplot3d.
1 parent d5e7341 commit b4296f4

File tree

2 files changed

+26
-45
lines changed

2 files changed

+26
-45
lines changed

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,9 @@ def line_2d_to_3d(line, zs=0, zdir='z'):
149149
def path_to_3d_segment(path, zs=0, zdir='z'):
150150
'''Convert a path to a 3D segment.'''
151151

152-
if not iterable(zs):
153-
zs = np.ones(len(path)) * zs
154-
155-
seg = []
152+
zs = _backports_np.broadcast_to(zs, len(path))
156153
pathsegs = path.iter_segments(simplify=False, curves=False)
157-
for (((x, y), code), z) in zip(pathsegs, zs):
158-
seg.append((x, y, z))
154+
seg = [(x, y, z) for (((x, y), code), z) in zip(pathsegs, zs)]
159155
seg3d = [juggle_axes(x, y, z, zdir) for (x, y, z) in seg]
160156
return seg3d
161157

@@ -165,21 +161,16 @@ def paths_to_3d_segments(paths, zs=0, zdir='z'):
165161
Convert paths from a collection object to 3D segments.
166162
'''
167163

168-
if not iterable(zs):
169-
zs = np.ones(len(paths)) * zs
170-
171-
segments = []
172-
for path, pathz in zip(paths, zs):
173-
segments.append(path_to_3d_segment(path, pathz, zdir))
174-
return segments
164+
zs = _backports_np.broadcast_to(zs, len(paths))
165+
segs = [path_to_3d_segment(path, pathz, zdir)
166+
for path, pathz in zip(paths, zs)]
167+
return segs
175168

176169

177170
def path_to_3d_segment_with_codes(path, zs=0, zdir='z'):
178171
'''Convert a path to a 3D segment with path codes.'''
179172

180-
if not iterable(zs):
181-
zs = np.ones(len(path)) * zs
182-
173+
zs = _backports_np.broadcast_to(zs, len(path))
183174
seg = []
184175
codes = []
185176
pathsegs = path.iter_segments(simplify=False, curves=False)
@@ -195,9 +186,7 @@ def paths_to_3d_segments_with_codes(paths, zs=0, zdir='z'):
195186
Convert paths from a collection object to 3D segments with path codes.
196187
'''
197188

198-
if not iterable(zs):
199-
zs = np.ones(len(paths)) * zs
200-
189+
zs = _backports_np.broadcast_to(zs, len(paths))
201190
segments = []
202191
codes_list = []
203192
for path, pathz in zip(paths, zs):
@@ -271,11 +260,9 @@ def __init__(self, *args, **kwargs):
271260
self.set_3d_properties(zs, zdir)
272261

273262
def set_3d_properties(self, verts, zs=0, zdir='z'):
274-
if not iterable(zs):
275-
zs = np.ones(len(verts)) * zs
276-
277-
self._segment3d = [juggle_axes(x, y, z, zdir) \
278-
for ((x, y), z) in zip(verts, zs)]
263+
zs = _backports_np.broadcast_to(zs, len(verts))
264+
self._segment3d = [juggle_axes(x, y, z, zdir)
265+
for ((x, y), z) in zip(verts, zs)]
279266
self._facecolor3d = Patch.get_facecolor(self)
280267

281268
def get_path(self):

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,23 @@
1111
"""
1212
from __future__ import (absolute_import, division, print_function,
1313
unicode_literals)
14-
import math
1514

1615
import six
1716
from six.moves import map, xrange, zip, reduce
1817

18+
import math
1919
import warnings
2020

2121
import numpy as np
22-
import matplotlib.axes as maxes
22+
23+
from matplotlib import (
24+
axes as maxes, cbook, collections as mcoll, colors as mcolors, docstring,
25+
scale as mscale, transforms as mtransforms)
26+
from matplotlib._backports import numpy as _backports_np
2327
from matplotlib.axes import Axes, rcParams
24-
from matplotlib import cbook
25-
import matplotlib.transforms as mtransforms
28+
from matplotlib.colors import Normalize, LightSource
2629
from matplotlib.transforms import Bbox
27-
import matplotlib.collections as mcoll
28-
from matplotlib import docstring
29-
import matplotlib.scale as mscale
3030
from matplotlib.tri.triangulation import Triangulation
31-
from matplotlib import colors as mcolors
32-
from matplotlib.colors import Normalize, LightSource
3331

3432
from . import art3d
3533
from . import proj3d
@@ -1536,8 +1534,7 @@ def plot(self, xs, ys, *args, **kwargs):
15361534
zdir = kwargs.pop('zdir', 'z')
15371535

15381536
# Match length
1539-
if not cbook.iterable(zs):
1540-
zs = np.ones(len(xs)) * zs
1537+
zs = _backports_np.broadcast_to(zs, len(xs))
15411538

15421539
lines = super(Axes3D, self).plot(xs, ys, *args, **kwargs)
15431540
for line in lines:
@@ -2332,18 +2329,16 @@ def scatter(self, xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True,
23322329

23332330
had_data = self.has_data()
23342331

2335-
xs, ys, zs = np.broadcast_arrays(*map(np.ma.ravel, [xs, ys, zs]))
2332+
xs, ys, zs = np.broadcast_arrays(
2333+
*[np.ravel(np.ma.filled(t, np.nan)) for t in [xs, ys, zs]])
23362334
s = np.ma.ravel(s) # This doesn't have to match x, y in size.
23372335

23382336
xs, ys, zs, s, c = cbook.delete_masked_points(xs, ys, zs, s, c)
23392337

2340-
patches = super(Axes3D, self).scatter(xs, ys, s=s, c=c, *args,
2341-
**kwargs)
2342-
if not cbook.iterable(zs):
2343-
is_2d = True
2344-
zs = np.ones(len(xs)) * zs
2345-
else:
2346-
is_2d = False
2338+
patches = super(Axes3D, self).scatter(
2339+
xs, ys, s=s, c=c, *args, **kwargs)
2340+
is_2d = not cbook.iterable(zs)
2341+
zs = _backports_np.broadcast_to(zs, len(xs))
23472342
art3d.patch_collection_2d_to_3d(patches, zs=zs, zdir=zdir,
23482343
depthshade=depthshade)
23492344

@@ -2382,8 +2377,7 @@ def bar(self, left, height, zs=0, zdir='z', *args, **kwargs):
23822377

23832378
patches = super(Axes3D, self).bar(left, height, *args, **kwargs)
23842379

2385-
if not cbook.iterable(zs):
2386-
zs = np.ones(len(left)) * zs
2380+
zs = _backports_np.broadcast_to(zs, len(left))
23872381

23882382
verts = []
23892383
verts_zs = []

0 commit comments

Comments
 (0)