@@ -56,7 +56,7 @@ class Axes3D(Axes):
56
56
def __init__ (
57
57
self , fig , rect = None , * args ,
58
58
elev = 30 , azim = - 60 , roll = 0 , sharez = None , proj_type = 'persp' ,
59
- box_aspect = None , computed_zorder = True ,
59
+ box_aspect = None , computed_zorder = True , focal_length = None ,
60
60
** kwargs ):
61
61
"""
62
62
Parameters
@@ -103,6 +103,11 @@ def __init__(
103
103
This behavior is deprecated in 3.4, the default will
104
104
change to False in 3.5. The keyword will be undocumented
105
105
and a non-False value will be an error in 3.6.
106
+ focal_length : float, default: None
107
+ For a projection type of 'persp', the focal length of the virtual
108
+ camera. Must be > 0. If None, defaults to 1.
109
+ The focal length can be computed from a desired Field Of View via
110
+ the equation: focal_length = 1/tan(FOV/2)
106
111
107
112
**kwargs
108
113
Other optional keyword arguments:
@@ -116,7 +121,7 @@ def __init__(
116
121
self .initial_azim = azim
117
122
self .initial_elev = elev
118
123
self .initial_roll = roll
119
- self .set_proj_type (proj_type )
124
+ self .set_proj_type (proj_type , focal_length )
120
125
self .computed_zorder = computed_zorder
121
126
122
127
self .xy_viewLim = Bbox .unit ()
@@ -1027,18 +1032,36 @@ def view_init(self, elev=None, azim=None, roll=None, vertical_axis="z"):
1027
1032
dict (x = 0 , y = 1 , z = 2 ), vertical_axis = vertical_axis
1028
1033
)
1029
1034
1030
- def set_proj_type (self , proj_type ):
1035
+ def set_proj_type (self , proj_type , focal_length = None ):
1031
1036
"""
1032
1037
Set the projection type.
1033
1038
1034
1039
Parameters
1035
1040
----------
1036
1041
proj_type : {'persp', 'ortho'}
1037
- """
1038
- self ._projection = _api .check_getitem ({
1039
- 'persp' : proj3d .persp_transformation ,
1040
- 'ortho' : proj3d .ortho_transformation ,
1041
- }, proj_type = proj_type )
1042
+ The projection type.
1043
+ focal_length : float, default: None
1044
+ For a projection type of 'persp', the focal length of the virtual
1045
+ camera. Must be > 0. If None, defaults to 1.
1046
+ The focal length can be computed from a desired Field Of View via
1047
+ the equation: focal_length = 1/tan(FOV/2)
1048
+ """
1049
+ if proj_type == 'persp' :
1050
+ if focal_length is None :
1051
+ self .focal_length = 1
1052
+ else :
1053
+ if focal_length <= 0 :
1054
+ raise ValueError (f"focal_length = { focal_length } must be" +
1055
+ " greater than 0" )
1056
+ self .focal_length = focal_length
1057
+ elif proj_type == 'ortho' :
1058
+ if focal_length not in (None , np .inf ):
1059
+ raise ValueError (f"focal_length = { focal_length } must be" +
1060
+ f"None for proj_type = { proj_type } " )
1061
+ self .focal_length = np .inf
1062
+ else :
1063
+ raise ValueError (f"proj_type = { proj_type } must be in" +
1064
+ f"{ 'persp' , 'ortho' } " )
1042
1065
1043
1066
def _roll_to_vertical (self , arr ):
1044
1067
"""Roll arrays to match the different vertical axis."""
@@ -1094,8 +1117,21 @@ def get_proj(self):
1094
1117
V = np .zeros (3 )
1095
1118
V [self ._vertical_axis ] = - 1 if abs (elev_rad ) > 0.5 * np .pi else 1
1096
1119
1097
- viewM = proj3d .view_transformation (eye , R , V , roll_rad )
1098
- projM = self ._projection (- self .dist , self .dist )
1120
+ # Generate the view and projection transformation matrices
1121
+ if self .focal_length == np .inf :
1122
+ # Orthographic projection
1123
+ viewM = proj3d .view_transformation (eye , R , V , roll_rad )
1124
+ projM = proj3d .ortho_transformation (- self .dist , self .dist )
1125
+ else :
1126
+ # Perspective projection
1127
+ # Scale the eye dist to compensate for the focal length zoom effect
1128
+ eye_focal = R + self .dist * ps * self .focal_length
1129
+ viewM = proj3d .view_transformation (eye_focal , R , V , roll_rad )
1130
+ projM = proj3d .persp_transformation (- self .dist ,
1131
+ self .dist ,
1132
+ self .focal_length )
1133
+
1134
+ # Combine all the transformation matrices to get the final projection
1099
1135
M0 = np .dot (viewM , worldM )
1100
1136
M = np .dot (projM , M0 )
1101
1137
return M
@@ -1158,7 +1194,7 @@ def cla(self):
1158
1194
pass
1159
1195
1160
1196
self ._autoscaleZon = True
1161
- if self ._projection is proj3d . ortho_transformation :
1197
+ if self .focal_length == np . inf :
1162
1198
self ._zmargin = rcParams ['axes.zmargin' ]
1163
1199
else :
1164
1200
self ._zmargin = 0.
0 commit comments