Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d9afaf6

Browse files
Fix 3D axes to properly support non-linear scales (log, symlog, etc.)
This fixes a long-standing issue where 3D plots ignored scale transforms. The methods set_xscale(), set_yscale(), set_zscale() existed but the transforms were never applied to data or coordinates. Key changes: - axes3d.py: Add _get_scaled_limits() to transform axis limits through scale transforms. Modify get_proj() to use scaled limits for world transformation. Override _update_transScale() to use identity transforms since 3D projection handles scales internally. Update autoscale_view() and _set_lim3d() to apply margins in transformed space. - art3d.py: Add _apply_scale_transforms() utility function. Update all do_3d_projection() methods to apply scale transforms before projection. - axis3d.py: Update _get_coord_info() to return scaled-space bounds. Modify _draw_ticks() to transform tick locations to scaled space. Update draw() and draw_grid() for proper coordinate handling. The fix ensures that: - Data coordinates are transformed through scale transforms before projection (e.g., log10 for log scale) - World transformation matrix maps scaled coordinates to unit cube - Axis ticks, labels, and grid lines position correctly - 2D display transforms remain linear (no double-transformation) Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 29b3564 commit d9afaf6

3 files changed

Lines changed: 436 additions & 52 deletions

File tree

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,68 @@ def get_dir_vector(zdir):
7272
raise ValueError("'x', 'y', 'z', None or vector of length 3 expected")
7373

7474

75+
def _apply_scale_transforms(xs, ys, zs, axes):
76+
"""
77+
Apply scale transforms to 3D coordinates.
78+
79+
This transforms data coordinates through the axis scale transforms
80+
(log, symlog, etc.) to get scaled coordinates that can be properly
81+
projected through the world transformation.
82+
83+
Parameters
84+
----------
85+
xs, ys, zs : array-like
86+
The x, y, z coordinates in data space.
87+
axes : Axes3D
88+
The axes providing the scale transforms.
89+
90+
Returns
91+
-------
92+
xs_scaled, ys_scaled, zs_scaled : np.ndarray
93+
The coordinates in scaled space.
94+
"""
95+
xs = np.asarray(xs)
96+
ys = np.asarray(ys)
97+
zs = np.asarray(zs)
98+
99+
# Get the scale transforms for each axis
100+
x_trans = axes.xaxis.get_transform()
101+
y_trans = axes.yaxis.get_transform()
102+
z_trans = axes.zaxis.get_transform()
103+
104+
# Handle masked arrays by preserving the mask
105+
x_mask = np.ma.getmask(xs) if np.ma.isMA(xs) else False
106+
y_mask = np.ma.getmask(ys) if np.ma.isMA(ys) else False
107+
z_mask = np.ma.getmask(zs) if np.ma.isMA(zs) else False
108+
109+
# Get the data (without mask) for transformation
110+
xs_data = np.ma.getdata(xs).ravel() if np.ma.isMA(xs) else xs.ravel()
111+
ys_data = np.ma.getdata(ys).ravel() if np.ma.isMA(ys) else ys.ravel()
112+
zs_data = np.ma.getdata(zs).ravel() if np.ma.isMA(zs) else zs.ravel()
113+
114+
# Transform through scale
115+
xs_scaled = x_trans.transform(xs_data).reshape(xs.shape)
116+
ys_scaled = y_trans.transform(ys_data).reshape(ys.shape)
117+
zs_scaled = z_trans.transform(zs_data).reshape(zs.shape)
118+
119+
# Reapply masks if needed
120+
if x_mask is not False or y_mask is not False or z_mask is not False:
121+
combined_mask = x_mask | y_mask | z_mask
122+
xs_scaled = np.ma.array(xs_scaled, mask=combined_mask)
123+
ys_scaled = np.ma.array(ys_scaled, mask=combined_mask)
124+
zs_scaled = np.ma.array(zs_scaled, mask=combined_mask)
125+
126+
return xs_scaled, ys_scaled, zs_scaled
127+
128+
75129
def _viewlim_mask(xs, ys, zs, axes):
76130
"""
77131
Return the mask of the points outside the axes view limits.
78132
79133
Parameters
80134
----------
81135
xs, ys, zs : array-like
82-
The points to mask.
136+
The points to mask. These should be in data coordinates.
83137
axes : Axes3D
84138
The axes to use for the view limits.
85139
@@ -198,7 +252,16 @@ def draw(self, renderer):
198252
else:
199253
pos3d = np.array([self._x, self._y, self._z], dtype=float)
200254

201-
proj = proj3d._proj_trans_points([pos3d, pos3d + self._dir_vec], self.axes.M)
255+
# Apply scale transforms before projection
256+
pos3d_scaled = np.array(_apply_scale_transforms(
257+
pos3d[0], pos3d[1], pos3d[2], self.axes))
258+
# Also scale the direction vector endpoint
259+
dir_end = pos3d + self._dir_vec
260+
dir_end_scaled = np.array(_apply_scale_transforms(
261+
dir_end[0], dir_end[1], dir_end[2], self.axes))
262+
263+
proj = proj3d._proj_trans_points(
264+
[pos3d_scaled, dir_end_scaled], self.axes.M)
202265
dx = proj[0][1] - proj[0][0]
203266
dy = proj[1][1] - proj[1][0]
204267
angle = math.degrees(math.atan2(dy, dx))
@@ -334,6 +397,8 @@ def draw(self, renderer):
334397
dtype=float, mask=mask).filled(np.nan)
335398
else:
336399
xs3d, ys3d, zs3d = self._verts3d
400+
# Apply scale transforms before projection
401+
xs3d, ys3d, zs3d = _apply_scale_transforms(xs3d, ys3d, zs3d, self.axes)
337402
xs, ys, zs, tis = proj3d._proj_transform_clip(xs3d, ys3d, zs3d,
338403
self.axes.M,
339404
self.axes._focal_length)
@@ -427,7 +492,13 @@ def do_3d_projection(self):
427492
vs_list = [np.ma.array(vs, mask=np.broadcast_to(
428493
_viewlim_mask(*vs.T, self.axes), vs.shape))
429494
for vs in vs_list]
430-
xyzs_list = [proj3d.proj_transform(*vs.T, self.axes.M) for vs in vs_list]
495+
# Apply scale transforms before projection
496+
xyzs_list = []
497+
for vs in vs_list:
498+
xs_scaled, ys_scaled, zs_scaled = _apply_scale_transforms(
499+
vs[:, 0], vs[:, 1], vs[:, 2], self.axes)
500+
xyzs_list.append(proj3d.proj_transform(
501+
xs_scaled, ys_scaled, zs_scaled, self.axes.M))
431502
self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs)
432503
for (xs, ys, _), (_, cs) in zip(xyzs_list, self._3dverts_codes)]
433504
zs = np.concatenate([zs for _, _, zs in xyzs_list])
@@ -511,7 +582,13 @@ def do_3d_projection(self):
511582
viewlim_mask = np.broadcast_to(viewlim_mask[..., np.newaxis],
512583
(*viewlim_mask.shape, 3))
513584
mask = mask | viewlim_mask
514-
xyzs = np.ma.array(proj3d._proj_transform_vectors(segments, self.axes.M),
585+
586+
# Apply scale transforms before projection
587+
xs_scaled, ys_scaled, zs_scaled = _apply_scale_transforms(
588+
segments[..., 0], segments[..., 1], segments[..., 2], self.axes)
589+
segments_scaled = np.stack([xs_scaled, ys_scaled, zs_scaled], axis=-1)
590+
591+
xyzs = np.ma.array(proj3d._proj_transform_vectors(segments_scaled, self.axes.M),
515592
mask=mask)
516593
segments_2d = xyzs[..., 0:2]
517594
LineCollection.set_segments(self, segments_2d)
@@ -595,6 +672,8 @@ def do_3d_projection(self):
595672
dtype=float, mask=mask).filled(np.nan)
596673
else:
597674
xs, ys, zs = zip(*s)
675+
# Apply scale transforms before projection
676+
xs, ys, zs = _apply_scale_transforms(xs, ys, zs, self.axes)
598677
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
599678
self.axes.M,
600679
self.axes._focal_length)
@@ -657,6 +736,8 @@ def do_3d_projection(self):
657736
dtype=float, mask=mask).filled(np.nan)
658737
else:
659738
xs, ys, zs = zip(*s)
739+
# Apply scale transforms before projection
740+
xs, ys, zs = _apply_scale_transforms(xs, ys, zs, self.axes)
660741
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
661742
self.axes.M,
662743
self.axes._focal_length)
@@ -802,6 +883,8 @@ def do_3d_projection(self):
802883
xs, ys, zs = np.ma.array(self._offsets3d, mask=mask)
803884
else:
804885
xs, ys, zs = self._offsets3d
886+
# Apply scale transforms before projection
887+
xs, ys, zs = _apply_scale_transforms(xs, ys, zs, self.axes)
805888
vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs,
806889
self.axes.M,
807890
self.axes._focal_length)
@@ -1020,7 +1103,9 @@ def do_3d_projection(self):
10201103
xyzs = np.ma.array(self._offsets3d, mask=mask)
10211104
else:
10221105
xyzs = self._offsets3d
1023-
vxs, vys, vzs, vis = proj3d._proj_transform_clip(*xyzs,
1106+
# Apply scale transforms before projection
1107+
xyzs_scaled = _apply_scale_transforms(*xyzs, self.axes)
1108+
vxs, vys, vzs, vis = proj3d._proj_transform_clip(*xyzs_scaled,
10241109
self.axes.M,
10251110
self.axes._focal_length)
10261111
self._data_scale = _get_data_scale(vxs, vys, vzs)
@@ -1353,10 +1438,16 @@ def do_3d_projection(self):
13531438
num_faces = len(self._faces)
13541439
mask = self._invalid_vertices
13551440

1441+
# Apply scale transforms to faces before projection
1442+
xs_scaled, ys_scaled, zs_scaled = _apply_scale_transforms(
1443+
self._faces[..., 0], self._faces[..., 1], self._faces[..., 2],
1444+
self.axes)
1445+
faces_scaled = np.stack([xs_scaled, ys_scaled, zs_scaled], axis=-1)
1446+
13561447
# Some faces might contain masked vertices, so we want to ignore any
13571448
# errors that those might cause
13581449
with np.errstate(invalid='ignore', divide='ignore'):
1359-
pfaces = proj3d._proj_transform_vectors(self._faces, self.axes.M)
1450+
pfaces = proj3d._proj_transform_vectors(faces_scaled, self.axes.M)
13601451

13611452
if self._axlim_clip:
13621453
viewlim_mask = _viewlim_mask(self._faces[..., 0], self._faces[..., 1],

0 commit comments

Comments
 (0)