Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a24160f

Browse files
Factor repeated code into proj3d functions
1 parent 24ab58b commit a24160f

2 files changed

Lines changed: 93 additions & 82 deletions

File tree

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 19 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -72,33 +72,6 @@ def get_dir_vector(zdir):
7272
raise ValueError("'x', 'y', 'z', None or vector of length 3 expected")
7373

7474

75-
def _apply_scale_transforms(xs, ys, zs, axes):
76-
"""
77-
Apply axis scale transforms to 3D coordinates.
78-
79-
Transforms data coordinates to transformed coordinates (applying log,
80-
symlog, etc.) for 3D projection. Preserves masked arrays.
81-
"""
82-
def transform_coord(coord, axis):
83-
coord = np.asanyarray(coord)
84-
data = np.ma.getdata(coord).ravel()
85-
return axis.get_transform().transform(data).reshape(coord.shape)
86-
87-
xs_scaled = transform_coord(xs, axes.xaxis)
88-
ys_scaled = transform_coord(ys, axes.yaxis)
89-
zs_scaled = transform_coord(zs, axes.zaxis)
90-
91-
# Preserve combined mask from any masked input
92-
masks = [np.ma.getmask(a) for a in [xs, ys, zs]]
93-
if any(m is not np.ma.nomask for m in masks):
94-
combined = np.ma.mask_or(np.ma.mask_or(masks[0], masks[1]), masks[2])
95-
xs_scaled = np.ma.array(xs_scaled, mask=combined)
96-
ys_scaled = np.ma.array(ys_scaled, mask=combined)
97-
zs_scaled = np.ma.array(zs_scaled, mask=combined)
98-
99-
return xs_scaled, ys_scaled, zs_scaled
100-
101-
10275
def _viewlim_mask(xs, ys, zs, axes):
10376
"""
10477
Return the mask of the points outside the axes view limits.
@@ -225,16 +198,9 @@ def draw(self, renderer):
225198
else:
226199
pos3d = np.array([self._x, self._y, self._z], dtype=float)
227200

228-
# Apply scale transforms before projection
229-
pos3d_scaled = np.array(_apply_scale_transforms(
230-
pos3d[0], pos3d[1], pos3d[2], self.axes))
231-
# Also scale the direction vector endpoint
232201
dir_end = pos3d + self._dir_vec
233-
dir_end_scaled = np.array(_apply_scale_transforms(
234-
dir_end[0], dir_end[1], dir_end[2], self.axes))
235-
236-
proj = proj3d._proj_trans_points(
237-
[pos3d_scaled, dir_end_scaled], self.axes.M)
202+
proj = proj3d._proj_trans_points_scaled(
203+
[pos3d, dir_end], self.axes)
238204
dx = proj[0][1] - proj[0][0]
239205
dy = proj[1][1] - proj[1][0]
240206
angle = math.degrees(math.atan2(dy, dx))
@@ -370,11 +336,8 @@ def draw(self, renderer):
370336
dtype=float, mask=mask).filled(np.nan)
371337
else:
372338
xs3d, ys3d, zs3d = self._verts3d
373-
# Apply scale transforms before projection
374-
xs3d, ys3d, zs3d = _apply_scale_transforms(xs3d, ys3d, zs3d, self.axes)
375-
xs, ys, zs, tis = proj3d._proj_transform_clip(xs3d, ys3d, zs3d,
376-
self.axes.M,
377-
self.axes._focal_length)
339+
xs, ys, zs, tis = proj3d._proj_transform_clip_scaled(
340+
xs3d, ys3d, zs3d, self.axes)
378341
self.set_data(xs, ys)
379342
super().draw(renderer)
380343
self.stale = False
@@ -465,13 +428,8 @@ def do_3d_projection(self):
465428
vs_list = [np.ma.array(vs, mask=np.broadcast_to(
466429
_viewlim_mask(*vs.T, self.axes), vs.shape))
467430
for vs in vs_list]
468-
# Apply scale transforms before projection
469-
xyzs_list = []
470-
for vs in vs_list:
471-
xs_scaled, ys_scaled, zs_scaled = _apply_scale_transforms(
472-
vs[:, 0], vs[:, 1], vs[:, 2], self.axes)
473-
xyzs_list.append(proj3d.proj_transform(
474-
xs_scaled, ys_scaled, zs_scaled, self.axes.M))
431+
xyzs_list = [proj3d._proj_transform_scaled(
432+
vs[:, 0], vs[:, 1], vs[:, 2], self.axes) for vs in vs_list]
475433
self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs)
476434
for (xs, ys, _), (_, cs) in zip(xyzs_list, self._3dverts_codes)]
477435
zs = np.concatenate([zs for _, _, zs in xyzs_list])
@@ -561,13 +519,9 @@ def do_3d_projection(self):
561519
(*viewlim_mask.shape, 3))
562520
mask = mask | viewlim_mask
563521

564-
# Apply scale transforms before projection
565-
xs_scaled, ys_scaled, zs_scaled = _apply_scale_transforms(
566-
segments[..., 0], segments[..., 1], segments[..., 2], self.axes)
567-
segments_scaled = np.stack([xs_scaled, ys_scaled, zs_scaled], axis=-1)
568-
569-
xyzs = np.ma.array(proj3d._proj_transform_vectors(segments_scaled, self.axes.M),
570-
mask=mask)
522+
xyzs = np.ma.array(
523+
proj3d._proj_transform_vectors_scaled(segments, self.axes),
524+
mask=mask)
571525
segments_2d = xyzs[..., 0:2]
572526
LineCollection.set_segments(self, segments_2d)
573527

@@ -650,11 +604,8 @@ def do_3d_projection(self):
650604
dtype=float, mask=mask).filled(np.nan)
651605
else:
652606
xs, ys, zs = zip(*s)
653-
# Apply scale transforms before projection
654-
xs, ys, zs = _apply_scale_transforms(xs, ys, zs, self.axes)
655-
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
656-
self.axes.M,
657-
self.axes._focal_length)
607+
vxs, vys, vzs, vis = proj3d._proj_transform_clip_scaled(
608+
xs, ys, zs, self.axes)
658609
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]))
659610
return min(vzs)
660611

@@ -714,11 +665,8 @@ def do_3d_projection(self):
714665
dtype=float, mask=mask).filled(np.nan)
715666
else:
716667
xs, ys, zs = zip(*s)
717-
# Apply scale transforms before projection
718-
xs, ys, zs = _apply_scale_transforms(xs, ys, zs, self.axes)
719-
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
720-
self.axes.M,
721-
self.axes._focal_length)
668+
vxs, vys, vzs, vis = proj3d._proj_transform_clip_scaled(
669+
xs, ys, zs, self.axes)
722670
self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]), self._code3d)
723671
return min(vzs)
724672

@@ -861,11 +809,8 @@ def do_3d_projection(self):
861809
xs, ys, zs = np.ma.array(self._offsets3d, mask=mask)
862810
else:
863811
xs, ys, zs = self._offsets3d
864-
# Apply scale transforms before projection
865-
xs, ys, zs = _apply_scale_transforms(xs, ys, zs, self.axes)
866-
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
867-
self.axes.M,
868-
self.axes._focal_length)
812+
vxs, vys, vzs, vis = proj3d._proj_transform_clip_scaled(
813+
xs, ys, zs, self.axes)
869814
self._vzs = vzs
870815
if np.ma.isMA(vxs):
871816
super().set_offsets(np.ma.column_stack([vxs, vys]))
@@ -1081,11 +1026,8 @@ def do_3d_projection(self):
10811026
xyzs = np.ma.array(self._offsets3d, mask=mask)
10821027
else:
10831028
xyzs = self._offsets3d
1084-
# Apply scale transforms before projection
1085-
xyzs_scaled = _apply_scale_transforms(*xyzs, self.axes)
1086-
vxs, vys, vzs, vis = proj3d._proj_transform_clip(*xyzs_scaled,
1087-
self.axes.M,
1088-
self.axes._focal_length)
1029+
vxs, vys, vzs, vis = proj3d._proj_transform_clip_scaled(
1030+
*xyzs, self.axes)
10891031
self._data_scale = _get_data_scale(vxs, vys, vzs)
10901032
# Sort the points based on z coordinates
10911033
# Performance optimization: Create a sorted index array and reorder
@@ -1416,16 +1358,11 @@ def do_3d_projection(self):
14161358
num_faces = len(self._faces)
14171359
mask = self._invalid_vertices
14181360

1419-
# Apply scale transforms to faces before projection
1420-
xs_scaled, ys_scaled, zs_scaled = _apply_scale_transforms(
1421-
self._faces[..., 0], self._faces[..., 1], self._faces[..., 2],
1422-
self.axes)
1423-
faces_scaled = np.stack([xs_scaled, ys_scaled, zs_scaled], axis=-1)
1424-
14251361
# Some faces might contain masked vertices, so we want to ignore any
14261362
# errors that those might cause
14271363
with np.errstate(invalid='ignore', divide='ignore'):
1428-
pfaces = proj3d._proj_transform_vectors(faces_scaled, self.axes.M)
1364+
pfaces = proj3d._proj_transform_vectors_scaled(
1365+
self._faces, self.axes)
14291366

14301367
if self._axlim_clip:
14311368
viewlim_mask = _viewlim_mask(self._faces[..., 0], self._faces[..., 1],

lib/mpl_toolkits/mplot3d/proj3d.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,33 @@ def _ortho_transformation(zfront, zback):
131131
return proj_matrix
132132

133133

134+
def _apply_scale_transforms(xs, ys, zs, axes):
135+
"""
136+
Apply axis scale transforms to 3D coordinates.
137+
138+
Transforms data coordinates to transformed coordinates (applying log,
139+
symlog, etc.) for 3D projection. Preserves masked arrays.
140+
"""
141+
def transform_coord(coord, axis):
142+
coord = np.asanyarray(coord)
143+
data = np.ma.getdata(coord).ravel()
144+
return axis.get_transform().transform(data).reshape(coord.shape)
145+
146+
xs_scaled = transform_coord(xs, axes.xaxis)
147+
ys_scaled = transform_coord(ys, axes.yaxis)
148+
zs_scaled = transform_coord(zs, axes.zaxis)
149+
150+
# Preserve combined mask from any masked input
151+
masks = [np.ma.getmask(a) for a in [xs, ys, zs]]
152+
if any(m is not np.ma.nomask for m in masks):
153+
combined = np.ma.mask_or(np.ma.mask_or(masks[0], masks[1]), masks[2])
154+
xs_scaled = np.ma.array(xs_scaled, mask=combined)
155+
ys_scaled = np.ma.array(ys_scaled, mask=combined)
156+
zs_scaled = np.ma.array(zs_scaled, mask=combined)
157+
158+
return xs_scaled, ys_scaled, zs_scaled
159+
160+
134161
def _proj_transform_vec(vec, M):
135162
vecw = np.dot(M, vec.data)
136163
ts = vecw[0:3]/vecw[3]
@@ -234,3 +261,50 @@ def _proj_trans_points(points, M):
234261
points = np.asanyarray(points)
235262
xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
236263
return proj_transform(xs, ys, zs, M)
264+
265+
266+
def _proj_transform_clip_scaled(xs, ys, zs, axes):
267+
"""
268+
Apply scale transforms and project with clipping.
269+
270+
Combines `_apply_scale_transforms` and `_proj_transform_clip` into a
271+
single call. Returns txs, tys, tzs, tis.
272+
"""
273+
xs, ys, zs = _apply_scale_transforms(xs, ys, zs, axes)
274+
return _proj_transform_clip(xs, ys, zs, axes.M, axes._focal_length)
275+
276+
277+
def _proj_transform_vectors_scaled(vecs, axes):
278+
"""
279+
Apply scale transforms and project vectors.
280+
281+
Combines `_apply_scale_transforms` and `_proj_transform_vectors` into a
282+
single call. *vecs* has shape ``(..., 3)``.
283+
"""
284+
xs, ys, zs = _apply_scale_transforms(
285+
vecs[..., 0], vecs[..., 1], vecs[..., 2], axes)
286+
vecs = np.stack([xs, ys, zs], axis=-1)
287+
return _proj_transform_vectors(vecs, axes.M)
288+
289+
290+
def _proj_transform_scaled(xs, ys, zs, axes):
291+
"""
292+
Apply scale transforms and project.
293+
294+
Combines `_apply_scale_transforms` and `proj_transform` into a single
295+
call. Returns txs, tys, tzs.
296+
"""
297+
xs, ys, zs = _apply_scale_transforms(xs, ys, zs, axes)
298+
return proj_transform(xs, ys, zs, axes.M)
299+
300+
301+
def _proj_trans_points_scaled(points, axes):
302+
"""
303+
Apply scale transforms and project points.
304+
305+
Combines `_apply_scale_transforms` and `_proj_trans_points` into a single
306+
call.
307+
"""
308+
points = np.asanyarray(points)
309+
xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
310+
return _proj_transform_scaled(xs, ys, zs, axes)

0 commit comments

Comments
 (0)