@@ -445,26 +445,6 @@ def draw(self, renderer):
445
445
self .offsetText .set_ha (align )
446
446
self .offsetText .draw (renderer )
447
447
448
- if self .axes ._draw_grid and len (ticks ):
449
- # Grid points where the planes meet
450
- xyz0 = np .tile (minmax , (len (ticks ), 1 ))
451
- xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
452
-
453
- # Grid lines go from the end of one plane through the plane
454
- # intersection (at xyz0) to the end of the other plane. The first
455
- # point (0) differs along dimension index-2 and the last (2) along
456
- # dimension index-1.
457
- lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
458
- lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
459
- lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
460
- self .gridlines .set_segments (lines )
461
- gridinfo = info ['grid' ]
462
- self .gridlines .set_color (gridinfo ['color' ])
463
- self .gridlines .set_linewidth (gridinfo ['linewidth' ])
464
- self .gridlines .set_linestyle (gridinfo ['linestyle' ])
465
- self .gridlines .do_3d_projection ()
466
- self .gridlines .draw (renderer )
467
-
468
448
# Draw ticks:
469
449
tickdir = self ._get_tickdir ()
470
450
tickdelta = deltas [tickdir ] if highs [tickdir ] else - deltas [tickdir ]
@@ -502,6 +482,45 @@ def draw(self, renderer):
502
482
renderer .close_group ('axis3d' )
503
483
self .stale = False
504
484
485
+ @artist .allow_rasterization
486
+ def draw_grid (self , renderer ):
487
+ if not self .axes ._draw_grid :
488
+ return
489
+
490
+ renderer .open_group ("grid3d" , gid = self .get_gid ())
491
+
492
+ ticks = self ._update_ticks ()
493
+ if len (ticks ):
494
+ # Get general axis information:
495
+ info = self ._axinfo
496
+ index = info ["i" ]
497
+
498
+ mins , maxs , _ , _ , _ , highs = self ._get_coord_info (renderer )
499
+
500
+ minmax = np .where (highs , maxs , mins )
501
+ maxmin = np .where (~ highs , maxs , mins )
502
+
503
+ # Grid points where the planes meet
504
+ xyz0 = np .tile (minmax , (len (ticks ), 1 ))
505
+ xyz0 [:, index ] = [tick .get_loc () for tick in ticks ]
506
+
507
+ # Grid lines go from the end of one plane through the plane
508
+ # intersection (at xyz0) to the end of the other plane. The first
509
+ # point (0) differs along dimension index-2 and the last (2) along
510
+ # dimension index-1.
511
+ lines = np .stack ([xyz0 , xyz0 , xyz0 ], axis = 1 )
512
+ lines [:, 0 , index - 2 ] = maxmin [index - 2 ]
513
+ lines [:, 2 , index - 1 ] = maxmin [index - 1 ]
514
+ self .gridlines .set_segments (lines )
515
+ gridinfo = info ['grid' ]
516
+ self .gridlines .set_color (gridinfo ['color' ])
517
+ self .gridlines .set_linewidth (gridinfo ['linewidth' ])
518
+ self .gridlines .set_linestyle (gridinfo ['linestyle' ])
519
+ self .gridlines .do_3d_projection ()
520
+ self .gridlines .draw (renderer )
521
+
522
+ renderer .close_group ('grid3d' )
523
+
505
524
# TODO: Get this to work (more) properly when mplot3d supports the
506
525
# transforms framework.
507
526
def get_tightbbox (self , renderer = None , * , for_layout_only = False ):
0 commit comments