Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 56f5c0e

Browse files
Change draw order so that the 3D axis spines are not blocked by gridlines
1 parent 8b58763 commit 56f5c0e

File tree

3 files changed

+49
-24
lines changed

3 files changed

+49
-24
lines changed

lib/matplotlib/artist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def allow_rasterization(draw):
5252
"""
5353

5454
@wraps(draw)
55-
def draw_wrapper(artist, renderer):
55+
def draw_wrapper(artist, renderer, *args, **kwargs):
5656
try:
5757
if artist.get_rasterized():
5858
if renderer._raster_depth == 0 and not renderer._rasterizing:
@@ -69,7 +69,7 @@ def draw_wrapper(artist, renderer):
6969
if artist.get_agg_filter() is not None:
7070
renderer.start_filter()
7171

72-
return draw(artist, renderer)
72+
return draw(artist, renderer, *args, **kwargs)
7373
finally:
7474
if artist.get_agg_filter() is not None:
7575
renderer.stop_filter(artist.get_agg_filter())

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,9 +489,12 @@ def draw(self, renderer):
489489
# Draw panes first
490490
for axis in self._axis_map.values():
491491
axis.draw_pane(renderer)
492+
# Then gridlines
493+
for axis in self._axis_map.values():
494+
axis.draw_grid(renderer)
492495
# Then axes
493496
for axis in self._axis_map.values():
494-
axis.draw(renderer)
497+
axis.draw(renderer, grid=False)
495498

496499
# Then rest
497500
super().draw(renderer)

lib/mpl_toolkits/mplot3d/axis3d.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,10 @@ def draw_pane(self, renderer):
336336
renderer.close_group('pane3d')
337337

338338
@artist.allow_rasterization
339-
def draw(self, renderer):
339+
def draw(self, renderer, grid=True):
340+
if grid and self.axes._draw_grid:
341+
self.draw_grid(renderer)
342+
340343
self.label._transform = self.axes.transData
341344
renderer.open_group("axis3d", gid=self.get_gid())
342345

@@ -462,26 +465,6 @@ def draw(self, renderer):
462465
self.offsetText.set_ha(align)
463466
self.offsetText.draw(renderer)
464467

465-
if self.axes._draw_grid and len(ticks):
466-
# Grid points where the planes meet
467-
xyz0 = np.tile(minmax, (len(ticks), 1))
468-
xyz0[:, index] = [tick.get_loc() for tick in ticks]
469-
470-
# Grid lines go from the end of one plane through the plane
471-
# intersection (at xyz0) to the end of the other plane. The first
472-
# point (0) differs along dimension index-2 and the last (2) along
473-
# dimension index-1.
474-
lines = np.stack([xyz0, xyz0, xyz0], axis=1)
475-
lines[:, 0, index - 2] = maxmin[index - 2]
476-
lines[:, 2, index - 1] = maxmin[index - 1]
477-
self.gridlines.set_segments(lines)
478-
gridinfo = info['grid']
479-
self.gridlines.set_color(gridinfo['color'])
480-
self.gridlines.set_linewidth(gridinfo['linewidth'])
481-
self.gridlines.set_linestyle(gridinfo['linestyle'])
482-
self.gridlines.do_3d_projection()
483-
self.gridlines.draw(renderer)
484-
485468
# Draw ticks:
486469
tickdir = self._get_tickdir()
487470
tickdelta = deltas[tickdir] if highs[tickdir] else -deltas[tickdir]
@@ -519,6 +502,45 @@ def draw(self, renderer):
519502
renderer.close_group('axis3d')
520503
self.stale = False
521504

505+
@artist.allow_rasterization
506+
def draw_grid(self, renderer):
507+
self.label._transform = self.axes.transData
508+
renderer.open_group("grid3d", gid=self.get_gid())
509+
510+
ticks = self._update_ticks()
511+
512+
# Get general axis information:
513+
info = self._axinfo
514+
index = info["i"]
515+
516+
mins, maxs, tc, highs = self._get_coord_info()
517+
518+
minmax = np.where(highs, maxs, mins)
519+
maxmin = np.where(~highs, maxs, mins)
520+
521+
if self.axes._draw_grid and len(ticks):
522+
# Grid points where the planes meet
523+
xyz0 = np.tile(minmax, (len(ticks), 1))
524+
xyz0[:, index] = [tick.get_loc() for tick in ticks]
525+
526+
# Grid lines go from the end of one plane through the plane
527+
# intersection (at xyz0) to the end of the other plane. The first
528+
# point (0) differs along dimension index-2 and the last (2) along
529+
# dimension index-1.
530+
lines = np.stack([xyz0, xyz0, xyz0], axis=1)
531+
lines[:, 0, index - 2] = maxmin[index - 2]
532+
lines[:, 2, index - 1] = maxmin[index - 1]
533+
self.gridlines.set_segments(lines)
534+
gridinfo = info['grid']
535+
self.gridlines.set_color(gridinfo['color'])
536+
self.gridlines.set_linewidth(gridinfo['linewidth'])
537+
self.gridlines.set_linestyle(gridinfo['linestyle'])
538+
self.gridlines.do_3d_projection()
539+
self.gridlines.draw(renderer)
540+
541+
renderer.close_group('grid3d')
542+
543+
522544
# TODO: Get this to work (more) properly when mplot3d supports the
523545
# transforms framework.
524546
def get_tightbbox(self, renderer=None, *, for_layout_only=False):

0 commit comments

Comments
 (0)