diff --git a/doc/release/1.9.0-notes.rst b/doc/release/1.9.0-notes.rst index 5849c21299ab..691aae23910c 100644 --- a/doc/release/1.9.0-notes.rst +++ b/doc/release/1.9.0-notes.rst @@ -194,6 +194,11 @@ Several more functions now release the Global Interpreter Lock allowing more efficient parallization using the ``threading`` module. Most notably the GIL is now released for fancy indexing and ``np.where``. +MaskedArray support for more complicated base classes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Built-in assumptions that the baseclass behaved like a plain array are being +removed. In particalur, ``repr`` and ``str`` should now work more reliably. + Changes ======= diff --git a/numpy/ma/core.py b/numpy/ma/core.py index faf9896a5201..d644f5b385a0 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -3585,9 +3585,8 @@ def __str__(self): if m.dtype.names: m = m.view((bool, len(m.dtype))) if m.any(): - r = np.array(self._data.tolist(), dtype=object) - np.copyto(r, f, where=m) - return str(tuple(r)) + return str(tuple((f if _m else _d) for _d, _m in + zip(self._data.tolist(), m))) else: return str(self._data) elif m: @@ -3598,7 +3597,7 @@ def __str__(self): names = self.dtype.names if names is None: res = self._data.astype("O") - res[m] = f + res.view(ndarray)[m] = f else: rdtype = _recursive_make_descr(self.dtype, "O") res = self._data.astype(rdtype) @@ -3612,19 +3611,22 @@ def __repr__(self): """ n = len(self.shape) - name = repr(self._data).split('(')[0] + if self._baseclass is np.ndarray: + name = 'array' + else: + name = self._baseclass.__name__ + parameters = dict(name=name, nlen=" " * len(name), - data=str(self), mask=str(self._mask), - fill=str(self.fill_value), dtype=str(self.dtype)) + data=str(self), mask=str(self._mask), + fill=str(self.fill_value), dtype=str(self.dtype)) if self.dtype.names: if n <= 1: return _print_templates['short_flx'] % parameters - return _print_templates['long_flx'] % parameters + return _print_templates['long_flx'] % parameters elif n <= 1: return _print_templates['short_std'] % parameters return _print_templates['long_std'] % parameters - def __eq__(self, other): "Check whether other equals self elementwise" if self is masked: diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index 6bb8d5b2ee1f..7eade7f20631 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -369,6 +369,13 @@ def test_deepcopy(self): assert_equal(copied.mask, [0, 0, 0]) assert_equal(a.mask, [0, 1, 0]) + def test_str_repr(self): + a = array([0, 1, 2], mask=[False, True, False]) + assert_equal(str(a), '[0 -- 2]') + assert_equal(repr(a), 'masked_array(data = [0 -- 2],\n' + ' mask = [False True False],\n' + ' fill_value = 999999)\n') + def test_pickling(self): # Tests pickling a = arange(10) @@ -3624,7 +3631,7 @@ def test_append_masked_array(): def test_append_masked_array_along_axis(): a = np.ma.masked_equal([1,2,3], value=2) b = np.ma.masked_values([[4, 5, 6], [7, 8, 9]], 7) - + # When `axis` is specified, `values` must have the correct shape. assert_raises(ValueError, np.ma.append, a, b, axis=0) @@ -3634,7 +3641,7 @@ def test_append_masked_array_along_axis(): expected = expected.reshape((3,3)) assert_array_equal(result.data, expected.data) assert_array_equal(result.mask, expected.mask) - + ############################################################################### if __name__ == "__main__": diff --git a/numpy/ma/tests/test_subclassing.py b/numpy/ma/tests/test_subclassing.py index c2c9b8ec9ca2..ade5c59daebf 100644 --- a/numpy/ma/tests/test_subclassing.py +++ b/numpy/ma/tests/test_subclassing.py @@ -82,6 +82,24 @@ def _get_series(self): mmatrix = MMatrix +# also a subclass that overrides __str__, __repr__ and __setitem__, disallowing +# setting to non-class values (and thus np.ma.core.masked_print_option) +class ComplicatedSubArray(SubArray): + def __str__(self): + return 'myprefix {0} mypostfix'.format( + super(ComplicatedSubArray, self).__str__()) + + def __repr__(self): + # Return a repr that does not start with 'name(' + return '<{0} {1}>'.format(self.__class__.__name__, self) + + def __setitem__(self, item, value): + # this ensures direct assignment to masked_print_option will fail + if not isinstance(value, ComplicatedSubArray): + raise ValueError("Can only set to MySubArray values") + super(ComplicatedSubArray, self).__setitem__(item, value) + + class TestSubclassing(TestCase): # Test suite for masked subclasses of ndarray. @@ -187,6 +205,31 @@ def test_subclasspreservation(self): assert_equal(mxsub.info, xsub.info) assert_equal(mxsub._mask, m) + def test_subclass_repr(self): + """test that repr uses the name of the subclass + and 'array' for np.ndarray""" + x = np.arange(5) + mx = masked_array(x, mask=[True, False, True, False, False]) + self.assertTrue(repr(mx).startswith('masked_array')) + xsub = SubArray(x) + mxsub = masked_array(xsub, mask=[True, False, True, False, False]) + self.assertTrue(repr(mxsub).startswith( + 'masked_{0}(data = [-- 1 -- 3 4]'.format(SubArray.__name__))) + + def test_subclass_str(self): + """test str with subclass that has overridden str, setitem""" + # first without override + x = np.arange(5) + xsub = SubArray(x) + mxsub = masked_array(xsub, mask=[True, False, True, False, False]) + self.assertTrue(str(mxsub) == '[-- 1 -- 3 4]') + + xcsub = ComplicatedSubArray(x) + assert_raises(ValueError, xcsub.__setitem__, 0, + np.ma.core.masked_print_option) + mxcsub = masked_array(xcsub, mask=[True, False, True, False, False]) + self.assertTrue(str(mxcsub) == 'myprefix [-- 1 -- 3 4] mypostfix') + ############################################################################### if __name__ == '__main__':