Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 94cda01

Browse files
committed
mplot3d updates:
* Fix scatter markers * Add facecolor support for plot_surface * Fix XYZ-pane order drawing * Add examples (animations, colored surface) svn path=/trunk/matplotlib/; revision=8015
1 parent 0026439 commit 94cda01

10 files changed

Lines changed: 237 additions & 53 deletions

File tree

examples/mplot3d/bars3d_demo.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
for c, z in zip(['r', 'g', 'b', 'y'], [30, 20, 10, 0]):
88
xs = np.arange(20)
99
ys = np.random.rand(20)
10-
ax.bar(xs, ys, zs=z, zdir='y', color=c, alpha=0.8)
10+
11+
# You can provide either a single color or an array. To demonstrate this,
12+
# the first bar of each set will be colored cyan.
13+
cs = [c] * len(xs)
14+
cs[0] = 'c'
15+
ax.bar(xs, ys, zs=z, zdir='y', color=cs, alpha=0.8)
1116

1217
ax.set_xlabel('X')
1318
ax.set_ylabel('Y')

examples/mplot3d/hist3d_demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
dx = 0.5 * np.ones_like(zpos)
1717
dy = dx.copy()
1818
dz = hist.flatten()
19+
1920
ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color='b')
2021

2122
plt.show()
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from mpl_toolkits.mplot3d import axes3d
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
5+
plt.ion()
6+
7+
fig = plt.figure()
8+
ax = axes3d.Axes3D(fig)
9+
X, Y, Z = axes3d.get_test_data(0.1)
10+
ax.plot_wireframe(X, Y, Z, rstride=5, cstride=5)
11+
12+
for angle in range(0, 360):
13+
ax.view_init(30, angle)
14+
plt.draw()
15+

examples/mplot3d/scatter3d_demo.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,17 @@
22
from mpl_toolkits.mplot3d import Axes3D
33
import matplotlib.pyplot as plt
44

5-
65
def randrange(n, vmin, vmax):
76
return (vmax-vmin)*np.random.rand(n) + vmin
87

98
fig = plt.figure()
109
ax = Axes3D(fig)
1110
n = 100
12-
for c, zl, zh in [('r', -50, -25), ('b', -30, -5)]:
11+
for c, m, zl, zh in [('r', 'o', -50, -25), ('b', '^', -30, -5)]:
1312
xs = randrange(n, 23, 32)
1413
ys = randrange(n, 0, 100)
1514
zs = randrange(n, zl, zh)
16-
ax.scatter(xs, ys, zs, c=c)
15+
ax.scatter(xs, ys, zs, c=c, marker=m)
1716

1817
ax.set_xlabel('X Label')
1918
ax.set_ylabel('Y Label')

examples/mplot3d/surface3d_demo.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from mpl_toolkits.mplot3d import Axes3D
22
from matplotlib import cm
3+
from matplotlib.ticker import LinearLocator, FixedLocator, FormatStrFormatter
34
import matplotlib.pyplot as plt
45
import numpy as np
56

@@ -10,7 +11,14 @@
1011
X, Y = np.meshgrid(X, Y)
1112
R = np.sqrt(X**2 + Y**2)
1213
Z = np.sin(R)
13-
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet)
14+
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet,
15+
linewidth=0, antialiased=False)
16+
ax.set_zlim3d(-1.01, 1.01)
17+
18+
ax.w_zaxis.set_major_locator(LinearLocator(10))
19+
ax.w_zaxis.set_major_formatter(FormatStrFormatter('%.03f'))
20+
21+
fig.colorbar(surf, shrink=0.5, aspect=5)
1422

1523
plt.show()
1624

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from mpl_toolkits.mplot3d import Axes3D
2+
from matplotlib import cm
3+
from matplotlib.ticker import LinearLocator, FixedLocator, FormatStrFormatter
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
7+
fig = plt.figure()
8+
ax = Axes3D(fig)
9+
X = np.arange(-5, 5, 0.25)
10+
xlen = len(X)
11+
Y = np.arange(-5, 5, 0.25)
12+
ylen = len(Y)
13+
X, Y = np.meshgrid(X, Y)
14+
R = np.sqrt(X**2 + Y**2)
15+
Z = np.sin(R)
16+
17+
colortuple = ('y', 'b')
18+
colors = np.empty(X.shape, dtype=str)
19+
for y in range(ylen):
20+
for x in range(xlen):
21+
colors[x, y] = colortuple[(x + y) % len(colortuple)]
22+
23+
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=colors,
24+
linewidth=0, antialiased=False)
25+
26+
ax.set_zlim3d(-1.01, 1.01)
27+
ax.w_zaxis.set_major_locator(LinearLocator(10))
28+
ax.w_zaxis.set_major_formatter(FormatStrFormatter('%.03f'))
29+
30+
plt.show()
31+
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from mpl_toolkits.mplot3d import axes3d
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
import time
5+
6+
def generate(X, Y, phi):
7+
R = 1 - np.sqrt(X**2 + Y**2)
8+
return np.cos(2 * np.pi * X + phi) * R
9+
10+
plt.ion()
11+
fig = plt.figure()
12+
ax = axes3d.Axes3D(fig)
13+
14+
xs = np.linspace(-1, 1, 50)
15+
ys = np.linspace(-1, 1, 50)
16+
X, Y = np.meshgrid(xs, ys)
17+
Z = generate(X, Y, 0.0)
18+
19+
wframe = None
20+
tstart = time.time()
21+
for phi in np.linspace(0, 360 / 2 / np.pi, 100):
22+
23+
oldcol = wframe
24+
25+
Z = generate(X, Y, phi)
26+
wframe = ax.plot_wireframe(X, Y, Z, rstride=2, cstride=2)
27+
28+
# Remove old line collection before drawing
29+
if oldcol is not None:
30+
ax.collections.remove(oldcol)
31+
32+
plt.draw()
33+
34+
print 'FPS: %f' % (100 / (time.time() - tstart))

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ class Patch3DCollection(PatchCollection):
274274

275275
def __init__(self, *args, **kwargs):
276276
PatchCollection.__init__(self, *args, **kwargs)
277+
self._old_draw = lambda x: PatchCollection.draw(self, x)
277278

278279
def set_3d_properties(self, zs, zdir):
279280
xs, ys = zip(*self.get_offsets())
@@ -293,10 +294,15 @@ def do_3d_projection(self, renderer):
293294
return min(vzs)
294295

295296
def draw(self, renderer):
296-
PatchCollection.draw(self, renderer)
297+
self._old_draw(renderer)
297298

298299
def patch_collection_2d_to_3d(col, zs=0, zdir='z'):
299300
"""Convert a PatchCollection to a Patch3DCollection object."""
301+
302+
# The tricky part here is that there are several classes that are
303+
# derived from PatchCollection. We need to use the right draw method.
304+
col._old_draw = col.draw
305+
300306
col.__class__ = Patch3DCollection
301307
col.set_3d_properties(zs, zdir)
302308

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 98 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from matplotlib.transforms import Bbox
1414
from matplotlib import collections
1515
import numpy as np
16-
from matplotlib.colors import Normalize, colorConverter
16+
from matplotlib.colors import Normalize, colorConverter, LightSource
1717

1818
import art3d
1919
import proj3d
@@ -37,6 +37,21 @@ class Axes3D(Axes):
3737
"""
3838

3939
def __init__(self, fig, rect=None, *args, **kwargs):
40+
'''
41+
Build an :class:`Axes3D` instance in
42+
:class:`~matplotlib.figure.Figure` *fig* with
43+
*rect=[left, bottom, width, height]* in
44+
:class:`~matplotlib.figure.Figure` coordinates
45+
46+
Optional keyword arguments:
47+
48+
================ =========================================
49+
Keyword Description
50+
================ =========================================
51+
*azim* Azimuthal viewing angle (default -60)
52+
*elev* Elevation viewing angle (default 30)
53+
'''
54+
4055
if rect is None:
4156
rect = [0.0, 0.0, 1.0, 1.0]
4257
self.fig = fig
@@ -146,9 +161,12 @@ def draw(self, renderer):
146161
for i, (z, patch) in enumerate(zlist):
147162
patch.zorder = i
148163

149-
self.w_xaxis.draw(renderer)
150-
self.w_yaxis.draw(renderer)
151-
self.w_zaxis.draw(renderer)
164+
axes = (self.w_xaxis, self.w_yaxis, self.w_zaxis)
165+
for ax in axes:
166+
ax.draw_pane(renderer)
167+
for ax in axes:
168+
ax.draw(renderer)
169+
152170
Axes.draw(self, renderer)
153171

154172
def get_axis_position(self):
@@ -322,8 +340,9 @@ def cla(self):
322340
self.grid(rcParams['axes3d.grid'])
323341

324342
def _button_press(self, event):
325-
self.button_pressed = event.button
326-
self.sx, self.sy = event.xdata, event.ydata
343+
if event.inaxes == self:
344+
self.button_pressed = event.button
345+
self.sx, self.sy = event.xdata, event.ydata
327346

328347
def _button_release(self, event):
329348
self.button_pressed = None
@@ -565,6 +584,12 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
565584
*cstride* Array column stride (step size)
566585
*color* Color of the surface patches
567586
*cmap* A colormap for the surface patches.
587+
*facecolors* Face colors for the individual patches
588+
*norm* An instance of Normalize to map values to colors
589+
*vmin* Minimum value to map
590+
*vmax* Maximum value to map
591+
*shade* Whether to shade the facecolors, default:
592+
false when cmap specified, true otherwise
568593
========== ================================================
569594
'''
570595

@@ -575,13 +600,28 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
575600
rstride = kwargs.pop('rstride', 10)
576601
cstride = kwargs.pop('cstride', 10)
577602

578-
color = kwargs.pop('color', 'b')
579-
color = np.array(colorConverter.to_rgba(color))
603+
if 'facecolors' in kwargs:
604+
fcolors = kwargs.pop('facecolors')
605+
else:
606+
color = np.array(colorConverter.to_rgba(kwargs.pop('color', 'b')))
607+
fcolors = None
608+
580609
cmap = kwargs.get('cmap', None)
610+
norm = kwargs.pop('norm', None)
611+
vmin = kwargs.pop('vmin', None)
612+
vmax = kwargs.pop('vmax', None)
613+
linewidth = kwargs.get('linewidth', None)
614+
shade = kwargs.pop('shade', cmap is None)
615+
lightsource = kwargs.pop('lightsource', None)
616+
617+
# Shade the data
618+
if shade and cmap is not None and fcolors is not None:
619+
fcolors = self._shade_colors_lightsource(Z, cmap, lightsource)
581620

582621
polys = []
583622
normals = []
584-
avgz = []
623+
#colset contains the data for coloring: either average z or the facecolor
624+
colset = []
585625
for rs in np.arange(0, rows-1, rstride):
586626
for cs in np.arange(0, cols-1, cstride):
587627
ps = []
@@ -609,19 +649,38 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
609649
lastp = p
610650
avgzsum += p[2]
611651
polys.append(ps2)
612-
avgz.append(avgzsum / len(ps2))
613652

614-
v1 = np.array(ps2[0]) - np.array(ps2[1])
615-
v2 = np.array(ps2[2]) - np.array(ps2[0])
616-
normals.append(np.cross(v1, v2))
653+
if fcolors is not None:
654+
colset.append(fcolors[rs][cs])
655+
else:
656+
colset.append(avgzsum / len(ps2))
657+
658+
# Only need vectors to shade if no cmap
659+
if cmap is None and shade:
660+
v1 = np.array(ps2[0]) - np.array(ps2[1])
661+
v2 = np.array(ps2[2]) - np.array(ps2[0])
662+
normals.append(np.cross(v1, v2))
617663

618664
polyc = art3d.Poly3DCollection(polys, *args, **kwargs)
619-
if cmap is not None:
620-
polyc.set_array(np.array(avgz))
621-
polyc.set_linewidth(0)
665+
666+
if fcolors is not None:
667+
if shade:
668+
colset = self._shade_colors(colset, normals)
669+
polyc.set_facecolors(colset)
670+
polyc.set_edgecolors(colset)
671+
elif cmap:
672+
colset = np.array(colset)
673+
polyc.set_array(colset)
674+
if vmin is not None or vmax is not None:
675+
polyc.set_clim(vmin, vmax)
676+
if norm is not None:
677+
polyc.set_norm(norm)
622678
else:
623-
colors = self._shade_colors(color, normals)
624-
polyc.set_facecolors(colors)
679+
if shade:
680+
colset = self._shade_colors(color, normals)
681+
else:
682+
colset = color
683+
polyc.set_facecolors(colset)
625684

626685
self.add_collection(polyc)
627686
self.auto_scale_xyz(X, Y, Z, had_data)
@@ -643,24 +702,39 @@ def _generate_normals(self, polygons):
643702
return normals
644703

645704
def _shade_colors(self, color, normals):
705+
'''
706+
Shade *color* using normal vectors given by *normals*.
707+
*color* can also be an array of the same length as *normals*.
708+
'''
709+
646710
shade = []
647711
for n in normals:
648-
n = n / proj3d.mod(n) * 5
712+
n = n / proj3d.mod(n)
649713
shade.append(np.dot(n, [-1, -1, 0.5]))
650714

651715
shade = np.array(shade)
652716
mask = ~np.isnan(shade)
653717

654718
if len(shade[mask]) > 0:
655-
norm = Normalize(min(shade[mask]), max(shade[mask]))
656-
color = color.copy()
657-
color[3] = 1
658-
colors = [color * (0.5 + norm(v) * 0.5) for v in shade]
719+
norm = Normalize(min(shade[mask]), max(shade[mask]))
720+
if art3d.iscolor(color):
721+
color = color.copy()
722+
color[3] = 1
723+
colors = [color * (0.5 + norm(v) * 0.5) for v in shade]
724+
else:
725+
colors = [np.array(colorConverter.to_rgba(c)) * \
726+
(0.5 + norm(v) * 0.5) \
727+
for c, v in zip(color, shade)]
659728
else:
660-
colors = color.copy()
729+
colors = color.copy()
661730

662731
return colors
663732

733+
def _shade_colors_lightsource(self, data, cmap, lightsource):
734+
if lightsource is None:
735+
lightsource = LightSource(azdeg=135, altdeg=55)
736+
return lightsource.shade(data, cmap)
737+
664738
def plot_wireframe(self, X, Y, Z, *args, **kwargs):
665739
'''
666740
Plot a 3D wireframe.

0 commit comments

Comments
 (0)