diff --git a/doc/users/next_whats_new/3d_plot_aspects.rst b/doc/users/next_whats_new/3d_plot_aspects.rst new file mode 100644 index 000000000000..ebd7afae92ef --- /dev/null +++ b/doc/users/next_whats_new/3d_plot_aspects.rst @@ -0,0 +1,32 @@ +Set equal aspect ratio for 3D plots +----------------------------------- + +Users can set the aspect ratio for the X, Y, Z axes of a 3D plot to be 'equal', +'equalxy', 'equalxz', or 'equalyz' rather than the default of 'auto'. + +.. plot:: + :include-source: true + + import matplotlib.pyplot as plt + import numpy as np + from itertools import combinations, product + + aspects = ('auto', 'equal', 'equalxy', 'equalyz', 'equalxz') + fig, axs = plt.subplots(1, len(aspects), subplot_kw={'projection': '3d'}) + + # Draw rectangular cuboid with side lengths [1, 1, 5] + r = [0, 1] + scale = np.array([1, 1, 5]) + pts = combinations(np.array(list(product(r, r, r))), 2) + for start, end in pts: + if np.sum(np.abs(start - end)) == r[1] - r[0]: + for ax in axs: + ax.plot3D(*zip(start*scale, end*scale), color='C0') + + # Set the aspect ratios + for i, ax in enumerate(axs): + ax.set_box_aspect((3, 4, 5)) + ax.set_aspect(aspects[i]) + ax.set_title("set_aspect('{aspects[i]}')") + + plt.show() diff --git a/examples/mplot3d/surface3d_2.py b/examples/mplot3d/surface3d_2.py index d5cfca1905a3..bca9f1ca62e8 100644 --- a/examples/mplot3d/surface3d_2.py +++ b/examples/mplot3d/surface3d_2.py @@ -23,4 +23,7 @@ # Plot the surface ax.plot_surface(x, y, z) +# Set an equal aspect ratio +ax.set_aspect('equal') + plt.show() diff --git a/examples/mplot3d/voxels_numpy_logo.py b/examples/mplot3d/voxels_numpy_logo.py index 9aa72b31e290..8b790d073988 100644 --- a/examples/mplot3d/voxels_numpy_logo.py +++ b/examples/mplot3d/voxels_numpy_logo.py @@ -42,5 +42,6 @@ def explode(data): ax = plt.figure().add_subplot(projection='3d') ax.voxels(x, y, z, filled_2, facecolors=fcolors_2, edgecolors=ecolors_2) +ax.set_aspect('equal') plt.show() diff --git a/examples/mplot3d/voxels_rgb.py b/examples/mplot3d/voxels_rgb.py index a06cc10a5cc1..31bfcbca15a9 100644 --- a/examples/mplot3d/voxels_rgb.py +++ b/examples/mplot3d/voxels_rgb.py @@ -39,5 +39,6 @@ def midpoints(x): edgecolors=np.clip(2*colors - 0.5, 0, 1), # brighter linewidth=0.5) ax.set(xlabel='r', ylabel='g', zlabel='b') +ax.set_aspect('equal') plt.show() diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 54a944bf1726..44a7f16bfa18 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -272,22 +272,19 @@ def set_aspect(self, aspect, adjustable=None, anchor=None, share=False): """ Set the aspect ratios. - Axes 3D does not current support any aspect but 'auto' which fills - the Axes with the data limits. - - To simulate having equal aspect in data space, set the ratio - of your data limits to match the value of `.get_box_aspect`. - To control box aspect ratios use `~.Axes3D.set_box_aspect`. - Parameters ---------- - aspect : {'auto'} + aspect : {'auto', 'equal', 'equalxy', 'equalxz', 'equalyz'} Possible values: ========= ================================================== value description ========= ================================================== 'auto' automatic; fill the position rectangle with data. + 'equal' adapt all the axes to have equal aspect ratios. + 'equalxy' adapt the x and y axes to have equal aspect ratios. + 'equalxz' adapt the x and z axes to have equal aspect ratios. + 'equalyz' adapt the y and z axes to have equal aspect ratios. ========= ================================================== adjustable : None @@ -321,13 +318,36 @@ def set_aspect(self, aspect, adjustable=None, anchor=None, share=False): -------- mpl_toolkits.mplot3d.axes3d.Axes3D.set_box_aspect """ - if aspect != 'auto': - raise NotImplementedError( - "Axes3D currently only supports the aspect argument " - f"'auto'. You passed in {aspect!r}." - ) + _api.check_in_list(('auto', 'equal', 'equalxy', 'equalyz', 'equalxz'), + aspect=aspect) super().set_aspect( - aspect, adjustable=adjustable, anchor=anchor, share=share) + aspect='auto', adjustable=adjustable, anchor=anchor, share=share) + self._aspect = aspect + + if aspect in ('equal', 'equalxy', 'equalxz', 'equalyz'): + if aspect == 'equal': + ax_indices = [0, 1, 2] + elif aspect == 'equalxy': + ax_indices = [0, 1] + elif aspect == 'equalxz': + ax_indices = [0, 2] + elif aspect == 'equalyz': + ax_indices = [1, 2] + + view_intervals = np.array([self.xaxis.get_view_interval(), + self.yaxis.get_view_interval(), + self.zaxis.get_view_interval()]) + mean = np.mean(view_intervals, axis=1) + ptp = np.ptp(view_intervals, axis=1) + delta = max(ptp[ax_indices]) + scale = self._box_aspect[ptp == delta][0] + deltas = delta * self._box_aspect / scale + + for i, set_lim in enumerate((self.set_xlim3d, + self.set_ylim3d, + self.set_zlim3d)): + if i in ax_indices: + set_lim(mean[i] - deltas[i]/2., mean[i] + deltas[i]/2.) def set_box_aspect(self, aspect, *, zoom=1): """ diff --git a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/aspects.png b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/aspects.png new file mode 100644 index 000000000000..3bb088e2d131 Binary files /dev/null and b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/aspects.png differ diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index 53a5ff91f271..29f9e6620360 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -29,11 +29,22 @@ def test_invisible_axes(fig_test, fig_ref): ax.set_visible(False) -def test_aspect_equal_error(): - fig = plt.figure() - ax = fig.add_subplot(projection='3d') - with pytest.raises(NotImplementedError): - ax.set_aspect('equal') +@mpl3d_image_comparison(['aspects.png'], remove_text=False) +def test_aspects(): + aspects = ('auto', 'equal', 'equalxy', 'equalyz', 'equalxz') + fig, axs = plt.subplots(1, len(aspects), subplot_kw={'projection': '3d'}) + + # Draw rectangular cuboid with side lengths [1, 1, 5] + r = [0, 1] + scale = np.array([1, 1, 5]) + pts = itertools.combinations(np.array(list(itertools.product(r, r, r))), 2) + for start, end in pts: + if np.sum(np.abs(start - end)) == r[1] - r[0]: + for ax in axs: + ax.plot3D(*zip(start*scale, end*scale)) + for i, ax in enumerate(axs): + ax.set_box_aspect((3, 4, 5)) + ax.set_aspect(aspects[i]) def test_axes3d_repr():