@@ -336,7 +336,10 @@ def draw_pane(self, renderer):
336
336
renderer .close_group ('pane3d' )
337
337
338
338
@artist .allow_rasterization
339
- def draw (self , renderer ):
339
+ def draw (self , renderer , grid = True ):
340
+ if grid and self .axes ._draw_grid :
341
+ self .draw_grid (renderer )
342
+
340
343
self .label ._transform = self .axes .transData
341
344
renderer .open_group ("axis3d" , gid = self .get_gid ())
342
345
@@ -462,26 +465,6 @@ def draw(self, renderer):
462
465
self .offsetText .set_ha (align )
463
466
self .offsetText .draw (renderer )
464
467
465
- if self .axes ._draw_grid and len (ticks ):
466
- # Grid points where the planes meet
467
- xyz0 = np .tile (minmax , (len (ticks ), 1 ))
468
- xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
469
-
470
- # Grid lines go from the end of one plane through the plane
471
- # intersection (at xyz0) to the end of the other plane. The first
472
- # point (0) differs along dimension index-2 and the last (2) along
473
- # dimension index-1.
474
- lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
475
- lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
476
- lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
477
- self .gridlines .set_segments (lines )
478
- gridinfo = info ['grid' ]
479
- self .gridlines .set_color (gridinfo ['color' ])
480
- self .gridlines .set_linewidth (gridinfo ['linewidth' ])
481
- self .gridlines .set_linestyle (gridinfo ['linestyle' ])
482
- self .gridlines .do_3d_projection ()
483
- self .gridlines .draw (renderer )
484
-
485
468
# Draw ticks:
486
469
tickdir = self ._get_tickdir ()
487
470
tickdelta = deltas [tickdir ] if highs [tickdir ] else - deltas [tickdir ]
@@ -519,6 +502,45 @@ def draw(self, renderer):
519
502
renderer .close_group ('axis3d' )
520
503
self .stale = False
521
504
505
+ @artist .allow_rasterization
506
+ def draw_grid (self , renderer ):
507
+ self .label ._transform = self .axes .transData
508
+ renderer .open_group ("grid3d" , gid = self .get_gid ())
509
+
510
+ ticks = self ._update_ticks ()
511
+
512
+ # Get general axis information:
513
+ info = self ._axinfo
514
+ index = info ["i" ]
515
+
516
+ mins , maxs , tc , highs = self ._get_coord_info ()
517
+
518
+ minmax = np .where (highs , maxs , mins )
519
+ maxmin = np .where (~ highs , maxs , mins )
520
+
521
+ if self .axes ._draw_grid and len (ticks ):
522
+ # Grid points where the planes meet
523
+ xyz0 = np .tile (minmax , (len (ticks ), 1 ))
524
+ xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
525
+
526
+ # Grid lines go from the end of one plane through the plane
527
+ # intersection (at xyz0) to the end of the other plane. The first
528
+ # point (0) differs along dimension index-2 and the last (2) along
529
+ # dimension index-1.
530
+ lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
531
+ lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
532
+ lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
533
+ self .gridlines .set_segments (lines )
534
+ gridinfo = info ['grid' ]
535
+ self .gridlines .set_color (gridinfo ['color' ])
536
+ self .gridlines .set_linewidth (gridinfo ['linewidth' ])
537
+ self .gridlines .set_linestyle (gridinfo ['linestyle' ])
538
+ self .gridlines .do_3d_projection ()
539
+ self .gridlines .draw (renderer )
540
+
541
+ renderer .close_group ('grid3d' )
542
+
543
+
522
544
# TODO: Get this to work (more) properly when mplot3d supports the
523
545
# transforms framework.
524
546
def get_tightbbox (self , renderer = None , * , for_layout_only = False ):
0 commit comments