diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 29667ff13922..c75ac103886b 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -545,9 +545,11 @@ def set_offsets(self, offsets): offsets = np.asanyarray(offsets) if offsets.shape == (2,): # Broadcast (2,) -> (1, 2) but nothing else. offsets = offsets[None, :] - self._offsets = np.column_stack( - (np.asarray(self.convert_xunits(offsets[:, 0]), float), - np.asarray(self.convert_yunits(offsets[:, 1]), float))) + cstack = (np.ma.column_stack if isinstance(offsets, np.ma.MaskedArray) + else np.column_stack) + self._offsets = cstack( + (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), + np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) self.stale = True def get_offsets(self): diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 782df21c5985..445249fae525 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -1149,3 +1149,36 @@ def test_check_masked_offsets(): fig, ax = plt.subplots() ax.scatter(unmasked_x, masked_y) + + +@check_figures_equal(extensions=["png"]) +def test_masked_set_offsets(fig_ref, fig_test): + x = np.ma.array([1, 2, 3, 4, 5], mask=[0, 0, 1, 1, 0]) + y = np.arange(1, 6) + + ax_test = fig_test.add_subplot() + scat = ax_test.scatter(x, y) + scat.set_offsets(np.ma.column_stack([x, y])) + ax_test.set_xticks([]) + ax_test.set_yticks([]) + + ax_ref = fig_ref.add_subplot() + ax_ref.scatter([1, 2, 5], [1, 2, 5]) + ax_ref.set_xticks([]) + ax_ref.set_yticks([]) + + +def test_check_offsets_dtype(): + # Check that setting offsets doesn't change dtype + x = np.ma.array([1, 2, 3, 4, 5], mask=[0, 0, 1, 1, 0]) + y = np.arange(1, 6) + + fig, ax = plt.subplots() + scat = ax.scatter(x, y) + masked_offsets = np.ma.column_stack([x, y]) + scat.set_offsets(masked_offsets) + assert isinstance(scat.get_offsets(), type(masked_offsets)) + + unmasked_offsets = np.column_stack([x, y]) + scat.set_offsets(unmasked_offsets) + assert isinstance(scat.get_offsets(), type(unmasked_offsets))