diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 7d0a5ae009c4..01fe7af5696b 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -501,6 +501,9 @@ def draw(self, renderer): # Then gridlines for axis in self._axis_map.values(): axis.draw_grid(renderer) + # Then minor gridlines + for axis in self._axis_map.values(): + axis.draw_minor_grid(renderer) # Then axes, labels, text, and ticks for axis in self._axis_map.values(): axis.draw(renderer) @@ -1562,6 +1565,9 @@ def clear(self): self._view_margin = 1/48 # default value to match mpl3.8 self.autoscale_view() + self._draw_minor_grid = False + self._minor_grid_kwargs = {} + self.grid(mpl.rcParams['axes3d.grid']) self.grid(mpl.rcParams['axes3d.grid']) def _button_press(self, event): @@ -2055,21 +2061,31 @@ def get_zlabel(self): get_frame_on = None set_frame_on = None - def grid(self, visible=True, **kwargs): - """ - Set / unset 3D grid. - - .. note:: - - Currently, this function does not behave the same as - `.axes.Axes.grid`, but it is intended to eventually support that - behavior. - """ - # TODO: Operate on each axes separately - if len(kwargs): - visible = True - self._draw_grid = visible - self.stale = True + def grid(self, visible=True, which='major', axis='both', **kwargs): + """ + Set / unset 3D grid. + + .. note:: + Currently, this function does not behave the same as + :meth:`matplotlib.axes.Axes.grid`, but it is intended to + eventually support that behavior. + + Parameters + ---------- + visible : bool + which : {'major', 'minor', 'both'} + axis : {'both', 'x', 'y', 'z'} + **kwargs : Line properties forwarded to the grid lines. + """ + if len(kwargs): + visible = True + if which in ('major', 'both'): + self._draw_grid = bool(visible) + self._major_grid_kwargs = kwargs + if which in ('minor', 'both'): + self._draw_minor_grid = bool(visible) + self._minor_grid_kwargs = kwargs + self.stale = True def tick_params(self, axis='both', **kwargs): """ diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py index 0ac2e50b1a1a..e3f5f8f52141 100644 --- a/lib/mpl_toolkits/mplot3d/axis3d.py +++ b/lib/mpl_toolkits/mplot3d/axis3d.py @@ -160,6 +160,8 @@ def _init3d(self): self.axes._set_artist_props(self.pane) self.gridlines = art3d.Line3DCollection([]) self.axes._set_artist_props(self.gridlines) + self.minor_gridlines = art3d.Line3DCollection([]) + self.axes._set_artist_props(self.minor_gridlines) self.axes._set_artist_props(self.label) self.axes._set_artist_props(self.offsetText) # Need to be able to place the label at the correct location @@ -678,6 +680,59 @@ def draw_grid(self, renderer): renderer.close_group('grid3d') + @artist.allow_rasterization + def draw_minor_grid(self, renderer): + if not getattr(self.axes, '_draw_minor_grid', False): + return + + renderer.open_group("minor_grid3d", gid=self.get_gid()) + + minor_locs = self.get_minorticklocs() + if len(minor_locs): + info = self._axinfo + index = info["i"] + + mins, maxs, tc, highs = self._get_coord_info() + xlim, ylim, zlim = (self.axes.get_xbound(), + self.axes.get_ybound(), + self.axes.get_zbound()) + data_mins = np.array([xlim[0], ylim[0], zlim[0]]) + data_maxs = np.array([xlim[1], ylim[1], zlim[1]]) + minmax = np.where(highs, data_maxs, data_mins) + maxmin = np.where(~highs, data_maxs, data_mins) + + # Filter to ticks within view limits + vmin, vmax = self.get_view_interval() + minor_locs = [t for t in minor_locs if vmin <= t <= vmax] + + if len(minor_locs): + xyz0 = np.tile(minmax, (len(minor_locs), 1)) + xyz0[:, index] = minor_locs + + lines = np.stack([xyz0, xyz0, xyz0], axis=1) + lines[:, 0, index - 2] = maxmin[index - 2] + lines[:, 2, index - 1] = maxmin[index - 1] + + self.minor_gridlines.set_segments(lines) + + # Default minor style: thinner and more transparent than major + gridinfo = info['grid'] + minor_kw = { + 'color': gridinfo['color'], + 'linewidth': gridinfo['linewidth'] * 0.5, + 'linestyle': gridinfo['linestyle'], + } + # Apply any user overrides from ax.grid(which='minor', ...) + minor_kw.update(getattr(self.axes, '_minor_grid_kwargs', {})) + + self.minor_gridlines.set_color(minor_kw['color']) + self.minor_gridlines.set_linewidth(minor_kw['linewidth']) + self.minor_gridlines.set_linestyle(minor_kw['linestyle']) + self.minor_gridlines.do_3d_projection() + self.minor_gridlines.draw(renderer) + + renderer.close_group('minor_grid3d') + # TODO: Get this to work (more) properly when mplot3d supports the # transforms framework. def get_tightbbox(self, renderer=None, *, for_layout_only=False): diff --git a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py index ac0168ce775e..ea0098dcb5e9 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py @@ -50,6 +50,25 @@ def test_grid_off(): ax.grid(False) +def test_minor_grid_3d(): + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.scatter([1, 2, 3], [1, 2, 3], [1, 2, 3]) + + # Minor grid off by default + assert not ax._draw_minor_grid + + # Enabling sets the flag and stores kwargs + ax.grid(which='minor', color='lightgrey', linewidth=0.3) + assert ax._draw_minor_grid + assert ax._minor_grid_kwargs['linewidth'] == 0.3 + + # Major grid unaffected + assert ax._draw_grid + + plt.close(fig) + + @mpl3d_image_comparison(['invisible_ticks_axis.png'], style='mpl20') def test_invisible_ticks_axis(): fig = plt.figure()