Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a40a891

Browse files
authored
Merge pull request #26906 from scottshambaugh/3d_masking
Fix masking for Axes3D.plot()
2 parents 5814cf3 + 0762613 commit a40a891

File tree

4 files changed

+62
-23
lines changed

4 files changed

+62
-23
lines changed

lib/matplotlib/cbook.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,44 @@ def _combine_masks(*args):
10881088
return margs
10891089

10901090

1091+
def _broadcast_with_masks(*args, compress=False):
1092+
"""
1093+
Broadcast inputs, combining all masked arrays.
1094+
1095+
Parameters
1096+
----------
1097+
*args : array-like
1098+
The inputs to broadcast.
1099+
compress : bool, default: False
1100+
Whether to compress the masked arrays. If False, the masked values
1101+
are replaced by NaNs.
1102+
1103+
Returns
1104+
-------
1105+
list of array-like
1106+
The broadcasted and masked inputs.
1107+
"""
1108+
# extract the masks, if any
1109+
masks = [k.mask for k in args if isinstance(k, np.ma.MaskedArray)]
1110+
# broadcast to match the shape
1111+
bcast = np.broadcast_arrays(*args, *masks)
1112+
inputs = bcast[:len(args)]
1113+
masks = bcast[len(args):]
1114+
if masks:
1115+
# combine the masks into one
1116+
mask = np.logical_or.reduce(masks)
1117+
# put mask on and compress
1118+
if compress:
1119+
inputs = [np.ma.array(k, mask=mask).compressed()
1120+
for k in inputs]
1121+
else:
1122+
inputs = [np.ma.array(k, mask=mask, dtype=float).filled(np.nan).ravel()
1123+
for k in inputs]
1124+
else:
1125+
inputs = [np.ravel(k) for k in inputs]
1126+
return inputs
1127+
1128+
10911129
def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None,
10921130
autorange=False):
10931131
r"""

lib/matplotlib/cbook.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ class GrouperView(Generic[_T]):
130130

131131
def simple_linear_interpolation(a: ArrayLike, steps: int) -> np.ndarray: ...
132132
def delete_masked_points(*args): ...
133+
def _broadcast_with_masks(*args: ArrayLike, compress: bool = ...) -> list[ArrayLike]: ...
133134
def boxplot_stats(
134135
X: ArrayLike,
135136
whis: float | tuple[float, float] = ...,

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
"""
1212

1313
from collections import defaultdict
14-
import functools
1514
import itertools
1615
import math
1716
import textwrap
@@ -1909,8 +1908,7 @@ def plot(self, xs, ys, *args, zdir='z', **kwargs):
19091908
else:
19101909
zs = kwargs.pop('zs', 0)
19111910

1912-
# Match length
1913-
zs = np.broadcast_to(zs, np.shape(xs))
1911+
xs, ys, zs = cbook._broadcast_with_masks(xs, ys, zs)
19141912

19151913
lines = super().plot(xs, ys, *args, **kwargs)
19161914
for line in lines:
@@ -2665,8 +2663,7 @@ def scatter(self, xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True,
26652663
had_data = self.has_data()
26662664
zs_orig = zs
26672665

2668-
xs, ys, zs = np.broadcast_arrays(
2669-
*[np.ravel(np.ma.filled(t, np.nan)) for t in [xs, ys, zs]])
2666+
xs, ys, zs = cbook._broadcast_with_masks(xs, ys, zs)
26702667
s = np.ma.ravel(s) # This doesn't have to match x, y in size.
26712668

26722669
xs, ys, zs, s, c, color = cbook.delete_masked_points(
@@ -2722,7 +2719,7 @@ def bar(self, left, height, zs=0, zdir='z', *args, **kwargs):
27222719

27232720
patches = super().bar(left, height, *args, **kwargs)
27242721

2725-
zs = np.broadcast_to(zs, len(left))
2722+
zs = np.broadcast_to(zs, len(left), subok=True)
27262723

27272724
verts = []
27282725
verts_zs = []
@@ -2988,23 +2985,8 @@ def calc_arrows(UVW):
29882985

29892986
had_data = self.has_data()
29902987

2991-
input_args = [X, Y, Z, U, V, W]
2992-
2993-
# extract the masks, if any
2994-
masks = [k.mask for k in input_args
2995-
if isinstance(k, np.ma.MaskedArray)]
2996-
# broadcast to match the shape
2997-
bcast = np.broadcast_arrays(*input_args, *masks)
2998-
input_args = bcast[:6]
2999-
masks = bcast[6:]
3000-
if masks:
3001-
# combine the masks into one
3002-
mask = functools.reduce(np.logical_or, masks)
3003-
# put mask on and compress
3004-
input_args = [np.ma.array(k, mask=mask).compressed()
3005-
for k in input_args]
3006-
else:
3007-
input_args = [np.ravel(k) for k in input_args]
2988+
input_args = cbook._broadcast_with_masks(X, Y, Z, U, V, W,
2989+
compress=True)
30082990

30092991
if any(len(v) == 0 for v in input_args):
30102992
# No quivers, so just make an empty collection and return early

lib/mpl_toolkits/mplot3d/tests/test_axes3d.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,24 @@ def test_surface3d_masked():
668668
ax.view_init(30, -80, 0)
669669

670670

671+
@check_figures_equal(extensions=["png"])
672+
def test_plot_scatter_masks(fig_test, fig_ref):
673+
x = np.linspace(0, 10, 100)
674+
y = np.linspace(0, 10, 100)
675+
z = np.sin(x) * np.cos(y)
676+
mask = z > 0
677+
678+
z_masked = np.ma.array(z, mask=mask)
679+
ax_test = fig_test.add_subplot(projection='3d')
680+
ax_test.scatter(x, y, z_masked)
681+
ax_test.plot(x, y, z_masked)
682+
683+
x[mask] = y[mask] = z[mask] = np.nan
684+
ax_ref = fig_ref.add_subplot(projection='3d')
685+
ax_ref.scatter(x, y, z)
686+
ax_ref.plot(x, y, z)
687+
688+
671689
@check_figures_equal(extensions=["png"])
672690
def test_plot_surface_None_arg(fig_test, fig_ref):
673691
x, y = np.meshgrid(np.arange(5), np.arange(5))

0 commit comments

Comments
 (0)