@@ -336,7 +336,10 @@ def draw_pane(self, renderer):
336336 renderer .close_group ('pane3d' )
337337
338338 @artist .allow_rasterization
339- def draw (self , renderer ):
339+ def draw (self , renderer , grid = True ):
340+ if grid and self .axes ._draw_grid :
341+ self .draw_grid (renderer )
342+
340343 self .label ._transform = self .axes .transData
341344 renderer .open_group ("axis3d" , gid = self .get_gid ())
342345
@@ -462,26 +465,6 @@ def draw(self, renderer):
462465 self .offsetText .set_ha (align )
463466 self .offsetText .draw (renderer )
464467
465- if self .axes ._draw_grid and len (ticks ):
466- # Grid points where the planes meet
467- xyz0 = np .tile (minmax , (len (ticks ), 1 ))
468- xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
469-
470- # Grid lines go from the end of one plane through the plane
471- # intersection (at xyz0) to the end of the other plane. The first
472- # point (0) differs along dimension index-2 and the last (2) along
473- # dimension index-1.
474- lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
475- lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
476- lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
477- self .gridlines .set_segments (lines )
478- gridinfo = info ['grid' ]
479- self .gridlines .set_color (gridinfo ['color' ])
480- self .gridlines .set_linewidth (gridinfo ['linewidth' ])
481- self .gridlines .set_linestyle (gridinfo ['linestyle' ])
482- self .gridlines .do_3d_projection ()
483- self .gridlines .draw (renderer )
484-
485468 # Draw ticks:
486469 tickdir = self ._get_tickdir ()
487470 tickdelta = deltas [tickdir ] if highs [tickdir ] else - deltas [tickdir ]
@@ -519,6 +502,45 @@ def draw(self, renderer):
519502 renderer .close_group ('axis3d' )
520503 self .stale = False
521504
505+ @artist .allow_rasterization
506+ def draw_grid (self , renderer ):
507+ self .label ._transform = self .axes .transData
508+ renderer .open_group ("grid3d" , gid = self .get_gid ())
509+
510+ ticks = self ._update_ticks ()
511+
512+ # Get general axis information:
513+ info = self ._axinfo
514+ index = info ["i" ]
515+
516+ mins , maxs , tc , highs = self ._get_coord_info ()
517+
518+ minmax = np .where (highs , maxs , mins )
519+ maxmin = np .where (~ highs , maxs , mins )
520+
521+ if self .axes ._draw_grid and len (ticks ):
522+ # Grid points where the planes meet
523+ xyz0 = np .tile (minmax , (len (ticks ), 1 ))
524+ xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
525+
526+ # Grid lines go from the end of one plane through the plane
527+ # intersection (at xyz0) to the end of the other plane. The first
528+ # point (0) differs along dimension index-2 and the last (2) along
529+ # dimension index-1.
530+ lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
531+ lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
532+ lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
533+ self .gridlines .set_segments (lines )
534+ gridinfo = info ['grid' ]
535+ self .gridlines .set_color (gridinfo ['color' ])
536+ self .gridlines .set_linewidth (gridinfo ['linewidth' ])
537+ self .gridlines .set_linestyle (gridinfo ['linestyle' ])
538+ self .gridlines .do_3d_projection ()
539+ self .gridlines .draw (renderer )
540+
541+ renderer .close_group ('grid3d' )
542+
543+
522544 # TODO: Get this to work (more) properly when mplot3d supports the
523545 # transforms framework.
524546 def get_tightbbox (self , renderer = None , * , for_layout_only = False ):
0 commit comments