diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py index 65e268461d98..24c595c04ca9 100644 --- a/lib/mpl_toolkits/mplot3d/axis3d.py +++ b/lib/mpl_toolkits/mplot3d/axis3d.py @@ -244,11 +244,23 @@ def _get_coord_info(self, renderer): bounds_proj = self.axes.tunit_cube(bounds, self.axes.M) # Determine which one of the parallel planes are higher up: - highs = np.zeros(3, dtype=bool) + means_z0 = np.zeros(3) + means_z1 = np.zeros(3) for i in range(3): - mean_z0 = np.mean(bounds_proj[self._PLANES[2 * i], 2]) - mean_z1 = np.mean(bounds_proj[self._PLANES[2 * i + 1], 2]) - highs[i] = mean_z0 < mean_z1 + means_z0[i] = np.mean(bounds_proj[self._PLANES[2 * i], 2]) + means_z1[i] = np.mean(bounds_proj[self._PLANES[2 * i + 1], 2]) + highs = means_z0 < means_z1 + + # Special handling for edge-on views + equals = np.abs(means_z0 - means_z1) <= np.finfo(float).eps + if np.sum(equals) == 2: + vertical = np.where(~equals)[0][0] + if vertical == 2: # looking at XY plane + highs = np.array([True, True, highs[2]]) + elif vertical == 1: # looking at XZ plane + highs = np.array([True, highs[1], False]) + elif vertical == 0: # looking at YZ plane + highs = np.array([highs[0], False, False]) return mins, maxs, centers, deltas, bounds_proj, highs diff --git a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_primary_views.png b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_primary_views.png new file mode 100644 index 000000000000..025156f34d39 Binary files /dev/null and b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_primary_views.png differ diff --git a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_rotated.png b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_rotated.png index 0c79fd32e42c..9e7193d6b326 100644 Binary files a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_rotated.png and b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_rotated.png differ diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index ea5a000f0d70..eb9578c57b2a 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -60,6 +60,27 @@ def test_axes3d_repr(): "title={'center': 'title'}, xlabel='x', ylabel='y', zlabel='z'>") +@mpl3d_image_comparison(['axes3d_primary_views.png']) +def test_axes3d_primary_views(): + # (elev, azim, roll) + views = [(90, -90, 0), # XY + (0, -90, 0), # XZ + (0, 0, 0), # YZ + (-90, 90, 0), # -XY + (0, 90, 0), # -XZ + (0, 180, 0)] # -YZ + # When viewing primary planes, draw the two visible axes so they intersect + # at their low values + fig, axs = plt.subplots(2, 3, subplot_kw={'projection': '3d'}) + for i, ax in enumerate(axs.flat): + ax.set_xlabel('x') + ax.set_ylabel('y') + ax.set_zlabel('z') + ax.set_proj_type('ortho') + ax.view_init(elev=views[i][0], azim=views[i][1], roll=views[i][2]) + plt.tight_layout() + + @mpl3d_image_comparison(['bar3d.png']) def test_bar3d(): fig = plt.figure() @@ -1839,9 +1860,9 @@ def test_scatter_spiral(): [0.0, 0.0, -1.142857, 10.571429], ], [ - ([0.06329114, -0.06329114], [-0.04746835, -0.04746835]), - ([-0.06329114, -0.06329114], [0.04746835, -0.04746835]), - ([0.05617978, 0.06329114], [-0.04213483, -0.04746835]), + ([-0.06329114, 0.06329114], [0.04746835, 0.04746835]), + ([0.06329114, 0.06329114], [-0.04746835, 0.04746835]), + ([-0.05617978, -0.06329114], [0.04213483, 0.04746835]), ], [2, 2, 0], ), @@ -1854,9 +1875,9 @@ def test_scatter_spiral(): [0.0, -1.142857, 0.0, 10.571429], ], [ - ([-0.06329114, -0.06329114], [-0.04746835, 0.04746835]), - ([0.06329114, 0.05617978], [-0.04746835, -0.04213483]), - ([0.06329114, -0.06329114], [-0.04746835, -0.04746835]), + ([-0.06329114, -0.06329114], [0.04746835, -0.04746835]), + ([0.06329114, 0.05617978], [0.04746835, 0.04213483]), + ([0.06329114, -0.06329114], [0.04746835, 0.04746835]), ], [1, 2, 1], ),