From 338d047c3d17b815b2a5ea22cc96bfb12a4c013d Mon Sep 17 00:00:00 2001 From: Marten van Kerkwijk Date: Sat, 12 Apr 2014 12:25:25 -0400 Subject: [PATCH] ENH: allow subclass overrides by removing explicit ndarray methods --- doc/release/1.10.0-notes.rst | 5 +-- numpy/ma/core.py | 66 ++++++++++++++---------------------- 2 files changed, 29 insertions(+), 42 deletions(-) diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst index cb78b4e71703..ce86224bbe8c 100644 --- a/doc/release/1.10.0-notes.rst +++ b/doc/release/1.10.0-notes.rst @@ -248,8 +248,9 @@ object arrays that were generated on Python 2. MaskedArray support for more complicated base classes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Built-in assumptions that the baseclass behaved like a plain array are being -removed. In particalur, setting and getting elements and ranges will respect -baseclass overrides of ``__setitem__`` and ``__getitem__``. +removed. In particular, setting and getting elements and ranges will respect +baseclass overrides of ``__setitem__`` and ``__getitem__``, and arithmetic +will respect overrides of ``__add__``, ``__sub__``, etc. Changes ======= diff --git a/numpy/ma/core.py b/numpy/ma/core.py index 5df928a6da5b..6a46f4757523 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -3179,8 +3179,8 @@ def __setmask__(self, mask, copy=False): """Set the mask. """ - idtype = ndarray.__getattribute__(self, 'dtype') - current_mask = ndarray.__getattribute__(self, '_mask') + idtype = self.dtype + current_mask = self._mask if mask is masked: mask = True # Make sure the mask is set @@ -3258,7 +3258,7 @@ def _get_recordmask(self): A record is masked when all the fields are masked. """ - _mask = ndarray.__getattribute__(self, '_mask').view(ndarray) + _mask = self._mask.view(ndarray) if _mask.dtype.names is None: return _mask return np.all(flatten_structured_array(_mask), axis= -1) @@ -3685,7 +3685,7 @@ def __eq__(self, other): return masked omask = getattr(other, '_mask', nomask) if omask is nomask: - check = ndarray.__eq__(self.filled(0), other) + check = self.filled(0).__eq__(other) try: check = check.view(type(self)) check._mask = self._mask @@ -3694,7 +3694,7 @@ def __eq__(self, other): return check else: odata = filled(other, 0) - check = ndarray.__eq__(self.filled(0), odata).view(type(self)) + check = self.filled(0).__eq__(odata).view(type(self)) if self._mask is nomask: check._mask = omask else: @@ -3718,7 +3718,7 @@ def __ne__(self, other): return masked omask = getattr(other, '_mask', nomask) if omask is nomask: - check = ndarray.__ne__(self.filled(0), other) + check = self.filled(0).__ne__(other) try: check = check.view(type(self)) check._mask = self._mask @@ -3727,7 +3727,7 @@ def __ne__(self, other): return check else: odata = filled(other, 0) - check = ndarray.__ne__(self.filled(0), odata).view(type(self)) + check = self.filled(0).__ne__(odata).view(type(self)) if self._mask is nomask: check._mask = omask else: @@ -3807,10 +3807,8 @@ def __iadd__(self, other): else: if m is not nomask: self._mask += m - ndarray.__iadd__( - self._data, - np.where(self._mask, self.dtype.type(0), getdata(other)) - ) + self._data.__iadd__(np.where(self._mask, self.dtype.type(0), + getdata(other))) return self #.... def __isub__(self, other): @@ -3822,10 +3820,8 @@ def __isub__(self, other): self._mask += m elif m is not nomask: self._mask += m - ndarray.__isub__( - self._data, - np.where(self._mask, self.dtype.type(0), getdata(other)) - ) + self._data.__isub__(np.where(self._mask, self.dtype.type(0), + getdata(other))) return self #.... def __imul__(self, other): @@ -3837,10 +3833,8 @@ def __imul__(self, other): self._mask += m elif m is not nomask: self._mask += m - ndarray.__imul__( - self._data, - np.where(self._mask, self.dtype.type(1), getdata(other)) - ) + self._data.__imul__(np.where(self._mask, self.dtype.type(1), + getdata(other))) return self #.... def __idiv__(self, other): @@ -3855,10 +3849,8 @@ def __idiv__(self, other): other_data = np.where(dom_mask, fval, other_data) # self._mask = mask_or(self._mask, new_mask) self._mask |= new_mask - ndarray.__idiv__( - self._data, - np.where(self._mask, self.dtype.type(1), other_data) - ) + self._data.__idiv__(np.where(self._mask, self.dtype.type(1), + other_data)) return self #.... def __ifloordiv__(self, other): @@ -3873,10 +3865,8 @@ def __ifloordiv__(self, other): other_data = np.where(dom_mask, fval, other_data) # self._mask = mask_or(self._mask, new_mask) self._mask |= new_mask - ndarray.__ifloordiv__( - self._data, - np.where(self._mask, self.dtype.type(1), other_data) - ) + self._data.__ifloordiv__(np.where(self._mask, self.dtype.type(1), + other_data)) return self #.... def __itruediv__(self, other): @@ -3891,10 +3881,8 @@ def __itruediv__(self, other): other_data = np.where(dom_mask, fval, other_data) # self._mask = mask_or(self._mask, new_mask) self._mask |= new_mask - ndarray.__itruediv__( - self._data, - np.where(self._mask, self.dtype.type(1), other_data) - ) + self._data.__itruediv__(np.where(self._mask, self.dtype.type(1), + other_data)) return self #... def __ipow__(self, other): @@ -3902,10 +3890,8 @@ def __ipow__(self, other): other_data = getdata(other) other_mask = getmask(other) with np.errstate(divide='ignore', invalid='ignore'): - ndarray.__ipow__( - self._data, - np.where(self._mask, self.dtype.type(1), other_data) - ) + self._data.__ipow__(np.where(self._mask, self.dtype.type(1), + other_data)) invalid = np.logical_not(np.isfinite(self._data)) if invalid.any(): if self._mask is not nomask: @@ -4590,7 +4576,7 @@ def sum(self, axis=None, dtype=None, out=None): """ - _mask = ndarray.__getattribute__(self, '_mask') + _mask = self._mask newmask = _check_mask_axis(_mask, axis) # No explicit output if out is None: @@ -4718,7 +4704,7 @@ def prod(self, axis=None, dtype=None, out=None): array([ 2., 12.]) """ - _mask = ndarray.__getattribute__(self, '_mask') + _mask = self._mask newmask = _check_mask_axis(_mask, axis) # No explicit output if out is None: @@ -5234,7 +5220,7 @@ def min(self, axis=None, out=None, fill_value=None): Returns the minimum filling value for a given datatype. """ - _mask = ndarray.__getattribute__(self, '_mask') + _mask = self._mask newmask = _check_mask_axis(_mask, axis) if fill_value is None: fill_value = minimum_fill_value(self) @@ -5333,7 +5319,7 @@ def max(self, axis=None, out=None, fill_value=None): Returns the maximum filling value for a given datatype. """ - _mask = ndarray.__getattribute__(self, '_mask') + _mask = self._mask newmask = _check_mask_axis(_mask, axis) if fill_value is None: fill_value = maximum_fill_value(self) @@ -5658,7 +5644,7 @@ def __setstate__(self, state): """ (_, shp, typ, isf, raw, msk, flv) = state - ndarray.__setstate__(self, (shp, typ, isf, raw)) + super(MaskedArray, self).__setstate__((shp, typ, isf, raw)) self._mask.__setstate__((shp, make_mask_descr(typ), isf, msk)) self.fill_value = flv #