diff --git a/doc/users/next_whats_new/3d_plot_focal_length.rst b/doc/users/next_whats_new/3d_plot_focal_length.rst new file mode 100644 index 000000000000..9422faa71546 --- /dev/null +++ b/doc/users/next_whats_new/3d_plot_focal_length.rst @@ -0,0 +1,31 @@ +Give the 3D camera a custom focal length +---------------------------------------- + +Users can now better mimic real-world cameras by specifying the focal length of +the virtual camera in 3D plots. The default focal length of 1 corresponds to a +Field of View (FOV) of 90 deg, and is backwards-compatible with existing 3D +plots. An increased focal length between 1 and infinity "flattens" the image, +while a decreased focal length between 1 and 0 exaggerates the perspective and +gives the image more apparent depth. + +The focal length can be calculated from a desired FOV via the equation: + +.. mathmpl:: + + focal\_length = 1/\tan(FOV/2) + +.. plot:: + :include-source: true + + from mpl_toolkits.mplot3d import axes3d + import matplotlib.pyplot as plt + fig, axs = plt.subplots(1, 3, subplot_kw={'projection': '3d'}, + constrained_layout=True) + X, Y, Z = axes3d.get_test_data(0.05) + focal_lengths = [0.25, 1, 4] + for ax, fl in zip(axs, focal_lengths): + ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10) + ax.set_proj_type('persp', focal_length=fl) + ax.set_title(f"focal_length = {fl}") + plt.tight_layout() + plt.show() diff --git a/examples/mplot3d/projections.py b/examples/mplot3d/projections.py new file mode 100644 index 000000000000..0368ef68de9e --- /dev/null +++ b/examples/mplot3d/projections.py @@ -0,0 +1,55 @@ +""" +======================== +3D plot projection types +======================== + +Demonstrates the different camera projections for 3D plots, and the effects of +changing the focal length for a perspective projection. Note that Matplotlib +corrects for the 'zoom' effect of changing the focal length. + +The default focal length of 1 corresponds to a Field of View (FOV) of 90 deg. +An increased focal length between 1 and infinity "flattens" the image, while a +decreased focal length between 1 and 0 exaggerates the perspective and gives +the image more apparent depth. In the limiting case, a focal length of +infinity corresponds to an orthographic projection after correction of the +zoom effect. + +You can calculate focal length from a FOV via the equation: + +.. mathmpl:: + + 1 / \tan (FOV / 2) + +Or vice versa: + +.. mathmpl:: + + FOV = 2 * \atan (1 / focal length) + +""" + +from mpl_toolkits.mplot3d import axes3d +import matplotlib.pyplot as plt + + +fig, axs = plt.subplots(1, 3, subplot_kw={'projection': '3d'}) + +# Get the test data +X, Y, Z = axes3d.get_test_data(0.05) + +# Plot the data +for ax in axs: + ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10) + +# Set the orthographic projection. +axs[0].set_proj_type('ortho') # FOV = 0 deg +axs[0].set_title("'ortho'\nfocal_length = ∞", fontsize=10) + +# Set the perspective projections +axs[1].set_proj_type('persp') # FOV = 90 deg +axs[1].set_title("'persp'\nfocal_length = 1 (default)", fontsize=10) + +axs[2].set_proj_type('persp', focal_length=0.2) # FOV = 157.4 deg +axs[2].set_title("'persp'\nfocal_length = 0.2", fontsize=10) + +plt.show() diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 1fcbdedea174..35f81a4628d9 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -57,7 +57,7 @@ class Axes3D(Axes): def __init__( self, fig, rect=None, *args, elev=30, azim=-60, roll=0, sharez=None, proj_type='persp', - box_aspect=None, computed_zorder=True, + box_aspect=None, computed_zorder=True, focal_length=None, **kwargs): """ Parameters @@ -104,6 +104,13 @@ def __init__( This behavior is deprecated in 3.4, the default will change to False in 3.5. The keyword will be undocumented and a non-False value will be an error in 3.6. + focal_length : float, default: None + For a projection type of 'persp', the focal length of the virtual + camera. Must be > 0. If None, defaults to 1. + For a projection type of 'ortho', must be set to either None + or infinity (numpy.inf). If None, defaults to infinity. + The focal length can be computed from a desired Field Of View via + the equation: focal_length = 1/tan(FOV/2) **kwargs Other optional keyword arguments: @@ -117,7 +124,7 @@ def __init__( self.initial_azim = azim self.initial_elev = elev self.initial_roll = roll - self.set_proj_type(proj_type) + self.set_proj_type(proj_type, focal_length) self.computed_zorder = computed_zorder self.xy_viewLim = Bbox.unit() @@ -989,18 +996,33 @@ def view_init(self, elev=None, azim=None, roll=None, vertical_axis="z"): dict(x=0, y=1, z=2), vertical_axis=vertical_axis ) - def set_proj_type(self, proj_type): + def set_proj_type(self, proj_type, focal_length=None): """ Set the projection type. Parameters ---------- proj_type : {'persp', 'ortho'} - """ - self._projection = _api.check_getitem({ - 'persp': proj3d.persp_transformation, - 'ortho': proj3d.ortho_transformation, - }, proj_type=proj_type) + The projection type. + focal_length : float, default: None + For a projection type of 'persp', the focal length of the virtual + camera. Must be > 0. If None, defaults to 1. + The focal length can be computed from a desired Field Of View via + the equation: focal_length = 1/tan(FOV/2) + """ + _api.check_in_list(['persp', 'ortho'], proj_type=proj_type) + if proj_type == 'persp': + if focal_length is None: + focal_length = 1 + elif focal_length <= 0: + raise ValueError(f"focal_length = {focal_length} must be " + "greater than 0") + self._focal_length = focal_length + elif proj_type == 'ortho': + if focal_length not in (None, np.inf): + raise ValueError(f"focal_length = {focal_length} must be " + f"None for proj_type = {proj_type}") + self._focal_length = np.inf def _roll_to_vertical(self, arr): """Roll arrays to match the different vertical axis.""" @@ -1056,8 +1078,21 @@ def get_proj(self): V = np.zeros(3) V[self._vertical_axis] = -1 if abs(elev_rad) > 0.5 * np.pi else 1 - viewM = proj3d.view_transformation(eye, R, V, roll_rad) - projM = self._projection(-self._dist, self._dist) + # Generate the view and projection transformation matrices + if self._focal_length == np.inf: + # Orthographic projection + viewM = proj3d.view_transformation(eye, R, V, roll_rad) + projM = proj3d.ortho_transformation(-self._dist, self._dist) + else: + # Perspective projection + # Scale the eye dist to compensate for the focal length zoom effect + eye_focal = R + self._dist * ps * self._focal_length + viewM = proj3d.view_transformation(eye_focal, R, V, roll_rad) + projM = proj3d.persp_transformation(-self._dist, + self._dist, + self._focal_length) + + # Combine all the transformation matrices to get the final projection M0 = np.dot(viewM, worldM) M = np.dot(projM, M0) return M @@ -1120,7 +1155,7 @@ def cla(self): pass self._autoscaleZon = True - if self._projection is proj3d.ortho_transformation: + if self._focal_length == np.inf: self._zmargin = rcParams['axes.zmargin'] else: self._zmargin = 0. diff --git a/lib/mpl_toolkits/mplot3d/proj3d.py b/lib/mpl_toolkits/mplot3d/proj3d.py index c7c2f93230be..7aa088d57e35 100644 --- a/lib/mpl_toolkits/mplot3d/proj3d.py +++ b/lib/mpl_toolkits/mplot3d/proj3d.py @@ -90,23 +90,27 @@ def view_transformation(E, R, V, roll): return np.dot(Mr, Mt) -def persp_transformation(zfront, zback): - a = (zfront+zback)/(zfront-zback) - b = -2*(zfront*zback)/(zfront-zback) - return np.array([[1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, a, b], - [0, 0, -1, 0]]) +def persp_transformation(zfront, zback, focal_length): + e = focal_length + a = 1 # aspect ratio + b = (zfront+zback)/(zfront-zback) + c = -2*(zfront*zback)/(zfront-zback) + proj_matrix = np.array([[e, 0, 0, 0], + [0, e/a, 0, 0], + [0, 0, b, c], + [0, 0, -1, 0]]) + return proj_matrix def ortho_transformation(zfront, zback): # note: w component in the resulting vector will be (zback-zfront), not 1 a = -(zfront + zback) b = -(zfront - zback) - return np.array([[2, 0, 0, 0], - [0, 2, 0, 0], - [0, 0, -2, 0], - [0, 0, a, b]]) + proj_matrix = np.array([[2, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, -2, 0], + [0, 0, a, b]]) + return proj_matrix def _proj_transform_vec(vec, M): diff --git a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_focal_length.png b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_focal_length.png new file mode 100644 index 000000000000..1d61e0a0c0f6 Binary files /dev/null and b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/axes3d_focal_length.png differ diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index 273c9882921f..0c1682132380 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -871,7 +871,7 @@ def _test_proj_make_M(): V = np.array([0, 0, 1]) roll = 0 viewM = proj3d.view_transformation(E, R, V, roll) - perspM = proj3d.persp_transformation(100, -100) + perspM = proj3d.persp_transformation(100, -100, 1) M = np.dot(perspM, viewM) return M @@ -1036,6 +1036,22 @@ def test_unautoscale(axis, auto): np.testing.assert_array_equal(get_lim(), (-0.5, 0.5)) +def test_axes3d_focal_length_checks(): + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + with pytest.raises(ValueError): + ax.set_proj_type('persp', focal_length=0) + with pytest.raises(ValueError): + ax.set_proj_type('ortho', focal_length=1) + + +@mpl3d_image_comparison(['axes3d_focal_length.png'], remove_text=False) +def test_axes3d_focal_length(): + fig, axs = plt.subplots(1, 2, subplot_kw={'projection': '3d'}) + axs[0].set_proj_type('persp', focal_length=np.inf) + axs[1].set_proj_type('persp', focal_length=0.15) + + @mpl3d_image_comparison(['axes3d_ortho.png'], remove_text=False) def test_axes3d_ortho(): fig = plt.figure()