Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Fix masking for Axes3D.plot() #26906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions lib/matplotlib/cbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,44 @@ def _combine_masks(*args):
return margs


def _broadcast_with_masks(*args, compress=False):
"""
Broadcast inputs, combining all masked arrays.

Parameters
----------
*args : array-like
The inputs to broadcast.
compress : bool, default: False
Whether to compress the masked arrays. If False, the masked values
are replaced by NaNs.

Returns
-------
list of array-like
The broadcasted and masked inputs.
"""
# extract the masks, if any
masks = [k.mask for k in args if isinstance(k, np.ma.MaskedArray)]
# broadcast to match the shape
bcast = np.broadcast_arrays(*args, *masks)
inputs = bcast[:len(args)]
masks = bcast[len(args):]
if masks:
# combine the masks into one
mask = np.logical_or.reduce(masks)
# put mask on and compress
if compress:
inputs = [np.ma.array(k, mask=mask).compressed()
for k in inputs]
else:
inputs = [np.ma.array(k, mask=mask, dtype=float).filled(np.nan).ravel()
for k in inputs]
else:
inputs = [np.ravel(k) for k in inputs]
return inputs


def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None,
autorange=False):
r"""
Expand Down
1 change: 1 addition & 0 deletions lib/matplotlib/cbook.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class GrouperView(Generic[_T]):

def simple_linear_interpolation(a: ArrayLike, steps: int) -> np.ndarray: ...
def delete_masked_points(*args): ...
def _broadcast_with_masks(*args: ArrayLike, compress: bool = ...) -> list[ArrayLike]: ...
def boxplot_stats(
X: ArrayLike,
whis: float | tuple[float, float] = ...,
Expand Down
28 changes: 5 additions & 23 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"""

from collections import defaultdict
import functools
import itertools
import math
import textwrap
Expand Down Expand Up @@ -1909,8 +1908,7 @@ def plot(self, xs, ys, *args, zdir='z', **kwargs):
else:
zs = kwargs.pop('zs', 0)

# Match length
zs = np.broadcast_to(zs, np.shape(xs))
xs, ys, zs = cbook._broadcast_with_masks(xs, ys, zs)

lines = super().plot(xs, ys, *args, **kwargs)
for line in lines:
Expand Down Expand Up @@ -2665,8 +2663,7 @@ def scatter(self, xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True,
had_data = self.has_data()
zs_orig = zs

xs, ys, zs = np.broadcast_arrays(
*[np.ravel(np.ma.filled(t, np.nan)) for t in [xs, ys, zs]])
xs, ys, zs = cbook._broadcast_with_masks(xs, ys, zs)
s = np.ma.ravel(s) # This doesn't have to match x, y in size.

xs, ys, zs, s, c, color = cbook.delete_masked_points(
Expand Down Expand Up @@ -2722,7 +2719,7 @@ def bar(self, left, height, zs=0, zdir='z', *args, **kwargs):

patches = super().bar(left, height, *args, **kwargs)

zs = np.broadcast_to(zs, len(left))
zs = np.broadcast_to(zs, len(left), subok=True)

verts = []
verts_zs = []
Expand Down Expand Up @@ -2988,23 +2985,8 @@ def calc_arrows(UVW):

had_data = self.has_data()

input_args = [X, Y, Z, U, V, W]

# extract the masks, if any
masks = [k.mask for k in input_args
if isinstance(k, np.ma.MaskedArray)]
# broadcast to match the shape
bcast = np.broadcast_arrays(*input_args, *masks)
input_args = bcast[:6]
masks = bcast[6:]
if masks:
# combine the masks into one
mask = functools.reduce(np.logical_or, masks)
# put mask on and compress
input_args = [np.ma.array(k, mask=mask).compressed()
for k in input_args]
else:
input_args = [np.ravel(k) for k in input_args]
input_args = cbook._broadcast_with_masks(X, Y, Z, U, V, W,
compress=True)

if any(len(v) == 0 for v in input_args):
# No quivers, so just make an empty collection and return early
Expand Down
18 changes: 18 additions & 0 deletions lib/mpl_toolkits/mplot3d/tests/test_axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,24 @@ def test_surface3d_masked():
ax.view_init(30, -80, 0)


@check_figures_equal(extensions=["png"])
def test_plot_scatter_masks(fig_test, fig_ref):
x = np.linspace(0, 10, 100)
y = np.linspace(0, 10, 100)
z = np.sin(x) * np.cos(y)
mask = z > 0

z_masked = np.ma.array(z, mask=mask)
ax_test = fig_test.add_subplot(projection='3d')
ax_test.scatter(x, y, z_masked)
ax_test.plot(x, y, z_masked)

x[mask] = y[mask] = z[mask] = np.nan
ax_ref = fig_ref.add_subplot(projection='3d')
ax_ref.scatter(x, y, z)
ax_ref.plot(x, y, z)


@check_figures_equal(extensions=["png"])
def test_plot_surface_None_arg(fig_test, fig_ref):
x, y = np.meshgrid(np.arange(5), np.arange(5))
Expand Down