1313from matplotlib .transforms import Bbox
1414from matplotlib import collections
1515import numpy as np
16- from matplotlib .colors import Normalize , colorConverter
16+ from matplotlib .colors import Normalize , colorConverter , LightSource
1717
1818import art3d
1919import proj3d
@@ -37,6 +37,21 @@ class Axes3D(Axes):
3737 """
3838
3939 def __init__ (self , fig , rect = None , * args , ** kwargs ):
40+ '''
41+ Build an :class:`Axes3D` instance in
42+ :class:`~matplotlib.figure.Figure` *fig* with
43+ *rect=[left, bottom, width, height]* in
44+ :class:`~matplotlib.figure.Figure` coordinates
45+
46+ Optional keyword arguments:
47+
48+ ================ =========================================
49+ Keyword Description
50+ ================ =========================================
51+ *azim* Azimuthal viewing angle (default -60)
52+ *elev* Elevation viewing angle (default 30)
53+ '''
54+
4055 if rect is None :
4156 rect = [0.0 , 0.0 , 1.0 , 1.0 ]
4257 self .fig = fig
@@ -146,9 +161,12 @@ def draw(self, renderer):
146161 for i , (z , patch ) in enumerate (zlist ):
147162 patch .zorder = i
148163
149- self .w_xaxis .draw (renderer )
150- self .w_yaxis .draw (renderer )
151- self .w_zaxis .draw (renderer )
164+ axes = (self .w_xaxis , self .w_yaxis , self .w_zaxis )
165+ for ax in axes :
166+ ax .draw_pane (renderer )
167+ for ax in axes :
168+ ax .draw (renderer )
169+
152170 Axes .draw (self , renderer )
153171
154172 def get_axis_position (self ):
@@ -322,8 +340,9 @@ def cla(self):
322340 self .grid (rcParams ['axes3d.grid' ])
323341
324342 def _button_press (self , event ):
325- self .button_pressed = event .button
326- self .sx , self .sy = event .xdata , event .ydata
343+ if event .inaxes == self :
344+ self .button_pressed = event .button
345+ self .sx , self .sy = event .xdata , event .ydata
327346
328347 def _button_release (self , event ):
329348 self .button_pressed = None
@@ -565,6 +584,12 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
565584 *cstride* Array column stride (step size)
566585 *color* Color of the surface patches
567586 *cmap* A colormap for the surface patches.
587+ *facecolors* Face colors for the individual patches
588+ *norm* An instance of Normalize to map values to colors
589+ *vmin* Minimum value to map
590+ *vmax* Maximum value to map
591+ *shade* Whether to shade the facecolors, default:
592+ false when cmap specified, true otherwise
568593 ========== ================================================
569594 '''
570595
@@ -575,13 +600,28 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
575600 rstride = kwargs .pop ('rstride' , 10 )
576601 cstride = kwargs .pop ('cstride' , 10 )
577602
578- color = kwargs .pop ('color' , 'b' )
579- color = np .array (colorConverter .to_rgba (color ))
603+ if 'facecolors' in kwargs :
604+ fcolors = kwargs .pop ('facecolors' )
605+ else :
606+ color = np .array (colorConverter .to_rgba (kwargs .pop ('color' , 'b' )))
607+ fcolors = None
608+
580609 cmap = kwargs .get ('cmap' , None )
610+ norm = kwargs .pop ('norm' , None )
611+ vmin = kwargs .pop ('vmin' , None )
612+ vmax = kwargs .pop ('vmax' , None )
613+ linewidth = kwargs .get ('linewidth' , None )
614+ shade = kwargs .pop ('shade' , cmap is None )
615+ lightsource = kwargs .pop ('lightsource' , None )
616+
617+ # Shade the data
618+ if shade and cmap is not None and fcolors is not None :
619+ fcolors = self ._shade_colors_lightsource (Z , cmap , lightsource )
581620
582621 polys = []
583622 normals = []
584- avgz = []
623+ #colset contains the data for coloring: either average z or the facecolor
624+ colset = []
585625 for rs in np .arange (0 , rows - 1 , rstride ):
586626 for cs in np .arange (0 , cols - 1 , cstride ):
587627 ps = []
@@ -609,19 +649,38 @@ def plot_surface(self, X, Y, Z, *args, **kwargs):
609649 lastp = p
610650 avgzsum += p [2 ]
611651 polys .append (ps2 )
612- avgz .append (avgzsum / len (ps2 ))
613652
614- v1 = np .array (ps2 [0 ]) - np .array (ps2 [1 ])
615- v2 = np .array (ps2 [2 ]) - np .array (ps2 [0 ])
616- normals .append (np .cross (v1 , v2 ))
653+ if fcolors is not None :
654+ colset .append (fcolors [rs ][cs ])
655+ else :
656+ colset .append (avgzsum / len (ps2 ))
657+
658+ # Only need vectors to shade if no cmap
659+ if cmap is None and shade :
660+ v1 = np .array (ps2 [0 ]) - np .array (ps2 [1 ])
661+ v2 = np .array (ps2 [2 ]) - np .array (ps2 [0 ])
662+ normals .append (np .cross (v1 , v2 ))
617663
618664 polyc = art3d .Poly3DCollection (polys , * args , ** kwargs )
619- if cmap is not None :
620- polyc .set_array (np .array (avgz ))
621- polyc .set_linewidth (0 )
665+
666+ if fcolors is not None :
667+ if shade :
668+ colset = self ._shade_colors (colset , normals )
669+ polyc .set_facecolors (colset )
670+ polyc .set_edgecolors (colset )
671+ elif cmap :
672+ colset = np .array (colset )
673+ polyc .set_array (colset )
674+ if vmin is not None or vmax is not None :
675+ polyc .set_clim (vmin , vmax )
676+ if norm is not None :
677+ polyc .set_norm (norm )
622678 else :
623- colors = self ._shade_colors (color , normals )
624- polyc .set_facecolors (colors )
679+ if shade :
680+ colset = self ._shade_colors (color , normals )
681+ else :
682+ colset = color
683+ polyc .set_facecolors (colset )
625684
626685 self .add_collection (polyc )
627686 self .auto_scale_xyz (X , Y , Z , had_data )
@@ -643,24 +702,39 @@ def _generate_normals(self, polygons):
643702 return normals
644703
645704 def _shade_colors (self , color , normals ):
705+ '''
706+ Shade *color* using normal vectors given by *normals*.
707+ *color* can also be an array of the same length as *normals*.
708+ '''
709+
646710 shade = []
647711 for n in normals :
648- n = n / proj3d .mod (n ) * 5
712+ n = n / proj3d .mod (n )
649713 shade .append (np .dot (n , [- 1 , - 1 , 0.5 ]))
650714
651715 shade = np .array (shade )
652716 mask = ~ np .isnan (shade )
653717
654718 if len (shade [mask ]) > 0 :
655- norm = Normalize (min (shade [mask ]), max (shade [mask ]))
656- color = color .copy ()
657- color [3 ] = 1
658- colors = [color * (0.5 + norm (v ) * 0.5 ) for v in shade ]
719+ norm = Normalize (min (shade [mask ]), max (shade [mask ]))
720+ if art3d .iscolor (color ):
721+ color = color .copy ()
722+ color [3 ] = 1
723+ colors = [color * (0.5 + norm (v ) * 0.5 ) for v in shade ]
724+ else :
725+ colors = [np .array (colorConverter .to_rgba (c )) * \
726+ (0.5 + norm (v ) * 0.5 ) \
727+ for c , v in zip (color , shade )]
659728 else :
660- colors = color .copy ()
729+ colors = color .copy ()
661730
662731 return colors
663732
733+ def _shade_colors_lightsource (self , data , cmap , lightsource ):
734+ if lightsource is None :
735+ lightsource = LightSource (azdeg = 135 , altdeg = 55 )
736+ return lightsource .shade (data , cmap )
737+
664738 def plot_wireframe (self , X , Y , Z , * args , ** kwargs ):
665739 '''
666740 Plot a 3D wireframe.
0 commit comments