Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2e19d87

Browse files
Cleanup and consolidation
1 parent 280acb8 commit 2e19d87

3 files changed

Lines changed: 45 additions & 176 deletions

File tree

lib/mpl_toolkits/mplot3d/art3d.py

Lines changed: 19 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -74,55 +74,27 @@ def get_dir_vector(zdir):
7474

7575
def _apply_scale_transforms(xs, ys, zs, axes):
7676
"""
77-
Apply scale transforms to 3D coordinates.
77+
Apply axis scale transforms to 3D coordinates.
7878
79-
This transforms data coordinates through the axis scale transforms
80-
(log, symlog, etc.) to get scaled coordinates that can be properly
81-
projected through the world transformation.
82-
83-
Parameters
84-
----------
85-
xs, ys, zs : array-like
86-
The x, y, z coordinates in data space.
87-
axes : Axes3D
88-
The axes providing the scale transforms.
89-
90-
Returns
91-
-------
92-
xs_scaled, ys_scaled, zs_scaled : np.ndarray
93-
The coordinates in scaled space.
79+
Transforms data coordinates through scale transforms (log, symlog, etc.)
80+
to scaled coordinates for proper 3D projection. Preserves masked arrays.
9481
"""
95-
# Use asanyarray to preserve masked arrays
96-
xs = np.asanyarray(xs)
97-
ys = np.asanyarray(ys)
98-
zs = np.asanyarray(zs)
99-
100-
# Get the scale transforms for each axis
101-
x_trans = axes.xaxis.get_transform()
102-
y_trans = axes.yaxis.get_transform()
103-
z_trans = axes.zaxis.get_transform()
104-
105-
# Handle masked arrays by preserving the mask
106-
x_mask = np.ma.getmask(xs) if np.ma.isMA(xs) else False
107-
y_mask = np.ma.getmask(ys) if np.ma.isMA(ys) else False
108-
z_mask = np.ma.getmask(zs) if np.ma.isMA(zs) else False
109-
110-
# Get the data (without mask) for transformation
111-
xs_data = np.ma.getdata(xs).ravel() if np.ma.isMA(xs) else xs.ravel()
112-
ys_data = np.ma.getdata(ys).ravel() if np.ma.isMA(ys) else ys.ravel()
113-
zs_data = np.ma.getdata(zs).ravel() if np.ma.isMA(zs) else zs.ravel()
114-
115-
# Transform through scale
116-
xs_scaled = x_trans.transform(xs_data).reshape(xs.shape)
117-
ys_scaled = y_trans.transform(ys_data).reshape(ys.shape)
118-
zs_scaled = z_trans.transform(zs_data).reshape(zs.shape)
119-
120-
# Reapply masks if needed
121-
if x_mask is not False or y_mask is not False or z_mask is not False:
122-
combined_mask = x_mask | y_mask | z_mask
123-
xs_scaled = np.ma.array(xs_scaled, mask=combined_mask)
124-
ys_scaled = np.ma.array(ys_scaled, mask=combined_mask)
125-
zs_scaled = np.ma.array(zs_scaled, mask=combined_mask)
82+
def transform_coord(coord, axis):
83+
coord = np.asanyarray(coord)
84+
data = np.ma.getdata(coord).ravel()
85+
return axis.get_transform().transform(data).reshape(coord.shape)
86+
87+
xs_scaled = transform_coord(xs, axes.xaxis)
88+
ys_scaled = transform_coord(ys, axes.yaxis)
89+
zs_scaled = transform_coord(zs, axes.zaxis)
90+
91+
# Preserve combined mask from any masked input
92+
masks = [np.ma.getmask(a) for a in [xs, ys, zs]]
93+
if any(m is not np.ma.nomask for m in masks):
94+
combined = np.ma.mask_or(np.ma.mask_or(masks[0], masks[1]), masks[2])
95+
xs_scaled = np.ma.array(xs_scaled, mask=combined)
96+
ys_scaled = np.ma.array(ys_scaled, mask=combined)
97+
zs_scaled = np.ma.array(zs_scaled, mask=combined)
12698

12799
return xs_scaled, ys_scaled, zs_scaled
128100

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 8 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,35 +1329,6 @@ def _roll_to_vertical(
13291329
else:
13301330
return np.roll(arr, (self._vertical_axis - 2))
13311331

1332-
def _get_scale_transform(self, axis):
1333-
"""
1334-
Return the scale transform for the given axis.
1335-
1336-
For non-linear scales (log, symlog, etc.), this returns the
1337-
transform that maps data coordinates to scaled coordinates.
1338-
For linear scales, returns an IdentityTransform.
1339-
"""
1340-
return axis.get_transform()
1341-
1342-
def _transform_limit_to_scale(self, limit, transform):
1343-
"""
1344-
Transform a limit value through the scale transform.
1345-
1346-
Parameters
1347-
----------
1348-
limit : float
1349-
The limit value in data coordinates.
1350-
transform : Transform
1351-
The scale transform to apply.
1352-
1353-
Returns
1354-
-------
1355-
float
1356-
The limit value in scaled coordinates.
1357-
"""
1358-
# Transform the limit through the scale
1359-
return transform.transform([limit])[0]
1360-
13611332
def _get_scaled_limits(self):
13621333
"""
13631334
Get axis limits transformed through their respective scale transforms.
@@ -1368,59 +1339,20 @@ def _get_scaled_limits(self):
13681339
(xmin_scaled, xmax_scaled, ymin_scaled, ymax_scaled,
13691340
zmin_scaled, zmax_scaled)
13701341
"""
1371-
x_trans = self._get_scale_transform(self.xaxis)
1372-
y_trans = self._get_scale_transform(self.yaxis)
1373-
z_trans = self._get_scale_transform(self.zaxis)
1374-
1375-
xmin, xmax = self.get_xlim3d()
1376-
ymin, ymax = self.get_ylim3d()
1377-
zmin, zmax = self.get_zlim3d()
1378-
1379-
return (
1380-
self._transform_limit_to_scale(xmin, x_trans),
1381-
self._transform_limit_to_scale(xmax, x_trans),
1382-
self._transform_limit_to_scale(ymin, y_trans),
1383-
self._transform_limit_to_scale(ymax, y_trans),
1384-
self._transform_limit_to_scale(zmin, z_trans),
1385-
self._transform_limit_to_scale(zmax, z_trans),
1386-
)
1342+
xmin, xmax = self.xaxis.get_transform().transform(self.get_xlim3d())
1343+
ymin, ymax = self.yaxis.get_transform().transform(self.get_ylim3d())
1344+
zmin, zmax = self.zaxis.get_transform().transform(self.get_zlim3d())
1345+
return xmin, xmax, ymin, ymax, zmin, zmax
13871346

13881347
def _inverse_scale_transform(self, x, y, z):
13891348
"""
1390-
Apply inverse scale transforms to coordinates.
1349+
Apply inverse scale transforms to a point.
13911350
13921351
Converts from scaled space back to data space.
1393-
1394-
Parameters
1395-
----------
1396-
x, y, z : float or array-like
1397-
Coordinates in scaled space.
1398-
1399-
Returns
1400-
-------
1401-
x_data, y_data, z_data : float or ndarray
1402-
Coordinates in data space.
14031352
"""
1404-
x_inv = self.xaxis.get_transform().inverted()
1405-
y_inv = self.yaxis.get_transform().inverted()
1406-
z_inv = self.zaxis.get_transform().inverted()
1407-
1408-
x_arr = np.atleast_1d(x)
1409-
y_arr = np.atleast_1d(y)
1410-
z_arr = np.atleast_1d(z)
1411-
1412-
x_data = x_inv.transform(x_arr.ravel()).reshape(x_arr.shape)
1413-
y_data = y_inv.transform(y_arr.ravel()).reshape(y_arr.shape)
1414-
z_data = z_inv.transform(z_arr.ravel()).reshape(z_arr.shape)
1415-
1416-
# Return scalars if input was scalar (check original input)
1417-
if np.ndim(x) == 0:
1418-
x_data = float(x_data.flat[0])
1419-
if np.ndim(y) == 0:
1420-
y_data = float(y_data.flat[0])
1421-
if np.ndim(z) == 0:
1422-
z_data = float(z_data.flat[0])
1423-
1353+
x_data = self.xaxis.get_transform().inverted().transform([x])[0]
1354+
y_data = self.yaxis.get_transform().inverted().transform([y])[0]
1355+
z_data = self.zaxis.get_transform().inverted().transform([z])[0]
14241356
return x_data, y_data, z_data
14251357

14261358
def _set_lims_from_scaled(self, xmin_s, xmax_s, ymin_s, ymax_s,

lib/mpl_toolkits/mplot3d/axis3d.py

Lines changed: 18 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -268,36 +268,14 @@ def get_rotate_label(self, text):
268268
return len(text) > 4
269269

270270
def _get_coord_info(self):
271-
# Get the data-space bounds
272-
data_mins, data_maxs = np.array([
273-
self.axes.get_xbound(),
274-
self.axes.get_ybound(),
275-
self.axes.get_zbound(),
276-
]).T
277-
278-
# Transform to scaled space for proper positioning with non-linear scales
279-
x_trans = self.axes.xaxis.get_transform()
280-
y_trans = self.axes.yaxis.get_transform()
281-
z_trans = self.axes.zaxis.get_transform()
282-
283-
mins = np.array([
284-
x_trans.transform([data_mins[0]])[0],
285-
y_trans.transform([data_mins[1]])[0],
286-
z_trans.transform([data_mins[2]])[0],
287-
])
288-
maxs = np.array([
289-
x_trans.transform([data_maxs[0]])[0],
290-
y_trans.transform([data_maxs[1]])[0],
291-
z_trans.transform([data_maxs[2]])[0],
292-
])
293-
294-
# Project the bounds along the current position of the cube:
295-
# Note: _transformed_cube expects data-space bounds and transforms them
296-
# internally
297-
bounds = (data_mins[0], data_maxs[0],
298-
data_mins[1], data_maxs[1],
299-
data_mins[2], data_maxs[2]
300-
)
271+
# Get scaled limits directly from the axes helper
272+
xmin, xmax, ymin, ymax, zmin, zmax = self.axes._get_scaled_limits()
273+
mins = np.array([xmin, ymin, zmin])
274+
maxs = np.array([xmax, ymax, zmax])
275+
276+
# Get data-space bounds for _transformed_cube
277+
bounds = (*self.axes.get_xbound(), *self.axes.get_ybound(),
278+
*self.axes.get_zbound())
301279
bounds_proj = self.axes._transformed_cube(bounds)
302280

303281
# Determine which one of the parallel planes are higher up:
@@ -484,25 +462,20 @@ def _draw_ticks(self, renderer, edgep1, centers, deltas, highs,
484462

485463
default_label_offset = 8. # A rough estimate
486464
points = deltas_per_point * deltas
465+
# All coordinates below are in scaled space for proper projection
487466
for tick in ticks:
488467
# Get tick line positions
489-
# edgep1 is already in scaled space (from _get_coord_info)
490468
pos = edgep1.copy()
491-
# Transform tick location from data space to scaled space
492-
tick_loc_scaled = axis_trans.transform([tick.get_loc()])[0]
493-
pos[index] = tick_loc_scaled
469+
pos[index] = axis_trans.transform([tick.get_loc()])[0]
494470
pos[tickdir] = out_tickdir
495-
# pos is already in scaled space, project directly
496471
x1, y1, z1 = proj3d.proj_transform(*pos, self.axes.M)
497472
pos[tickdir] = in_tickdir
498473
x2, y2, z2 = proj3d.proj_transform(*pos, self.axes.M)
499474

500475
# Get position of label
501476
labeldeltas = (tick.get_pad() + default_label_offset) * points
502-
503477
pos[tickdir] = edgep1_tickdir
504478
pos = _move_from_center(pos, centers, labeldeltas, self._axmask())
505-
# pos is already in scaled space, project directly
506479
lx, ly, lz = proj3d.proj_transform(*pos, self.axes.M)
507480

508481
_tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly))
@@ -528,7 +501,6 @@ def _draw_offset_text(self, renderer, edgep1, edgep2, labeldeltas, centers,
528501

529502
pos = _move_from_center(outeredgep, centers, labeldeltas,
530503
self._axmask())
531-
# pos is already in scaled space, project directly
532504
olx, oly, olz = proj3d.proj_transform(*pos, self.axes.M)
533505
self.offsetText.set_text(self.major.formatter.get_offset())
534506
self.offsetText.set_position((olx, oly))
@@ -553,7 +525,6 @@ def _draw_offset_text(self, renderer, edgep1, edgep2, labeldeltas, centers,
553525
# Three-letters (e.g., TFT, FTT) are short-hand for the array of bools
554526
# from the variable 'highs'.
555527
# ---------------------------------------------------------------------
556-
# centers is already in scaled space (from _get_coord_info)
557528
centpt = proj3d.proj_transform(*centers, self.axes.M)
558529
if centpt[tickdir] > pep[tickdir, outerindex]:
559530
# if FT and if highs has an even number of Trues
@@ -586,10 +557,8 @@ def _draw_labels(self, renderer, edgep1, edgep2, labeldeltas, centers, dx, dy):
586557
label = self._axinfo["label"]
587558

588559
# Draw labels
589-
# edgep1, edgep2, and centers are already in scaled space
590560
lxyz = 0.5 * (edgep1 + edgep2)
591561
lxyz = _move_from_center(lxyz, centers, labeldeltas, self._axmask())
592-
# lxyz is in scaled space, project directly
593562
tlx, tly, tlz = proj3d.proj_transform(*lxyz, self.axes.M)
594563
self.label.set_position((tlx, tly))
595564
if self.get_rotate_label(self.label.get_text()):
@@ -625,7 +594,7 @@ def draw(self, renderer):
625594

626595
for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points(
627596
minmax, maxmin, self._tick_position)):
628-
# Edge points are already in scaled space (from _get_coord_info)
597+
# Project the edge points along the current position
629598
pep = proj3d._proj_trans_points([edgep1, edgep2], self.axes.M)
630599
pep = np.asarray(pep)
631600

@@ -653,7 +622,7 @@ def draw(self, renderer):
653622

654623
for edgep1, edgep2, pos in zip(*self._get_all_axis_line_edge_points(
655624
minmax, maxmin, self._label_position)):
656-
# Edge points are already in scaled space (from _get_coord_info)
625+
# See comments above
657626
pep = proj3d._proj_trans_points([edgep1, edgep2], self.axes.M)
658627
pep = np.asarray(pep)
659628
dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) -
@@ -678,17 +647,13 @@ def draw_grid(self, renderer):
678647
info = self._axinfo
679648
index = info["i"]
680649

681-
# For grid lines, we need data-space bounds since Line3DCollection
682-
# applies scale transforms in do_3d_projection
683-
data_mins, data_maxs = np.array([
684-
self.axes.get_xbound(),
685-
self.axes.get_ybound(),
686-
self.axes.get_zbound(),
687-
]).T
688-
689-
# Get highs from the scaled-space projection
650+
# Grid lines use data-space bounds (Line3DCollection applies transforms)
690651
mins, maxs, tc, highs = self._get_coord_info()
691-
652+
xlim, ylim, zlim = (self.axes.get_xbound(),
653+
self.axes.get_ybound(),
654+
self.axes.get_zbound())
655+
data_mins = np.array([xlim[0], ylim[0], zlim[0]])
656+
data_maxs = np.array([xlim[1], ylim[1], zlim[1]])
692657
minmax = np.where(highs, data_maxs, data_mins)
693658
maxmin = np.where(~highs, data_maxs, data_mins)
694659

0 commit comments

Comments
 (0)