Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d70af67

Browse files
committed
Merge pull request numpy#6653 from charris/fix-ma-dot
Fix ma dot
2 parents 694f628 + 3e82108 commit d70af67

File tree

4 files changed

+262
-184
lines changed

4 files changed

+262
-184
lines changed

numpy/ma/core.py

Lines changed: 223 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,12 @@
3232
from numpy import ndarray, amax, amin, iscomplexobj, bool_, _NoValue
3333
from numpy import array as narray
3434
from numpy.lib.function_base import angle
35-
from numpy.compat import getargspec, formatargspec, long, basestring, unicode, bytes, sixu
35+
from numpy.compat import (
36+
getargspec, formatargspec, long, basestring, unicode, bytes, sixu
37+
)
3638
from numpy import expand_dims as n_expand_dims
3739

40+
3841
if sys.version_info[0] >= 3:
3942
import pickle
4043
else:
@@ -4651,24 +4654,44 @@ def trace(self, offset=0, axis1=0, axis2=1, dtype=None, out=None):
46514654
return D.astype(dtype).filled(0).sum(axis=None, out=out)
46524655
trace.__doc__ = ndarray.trace.__doc__
46534656

4654-
def dot(self, other, out=None):
4655-
am = ~getmaskarray(self)
4656-
bm = ~getmaskarray(other)
4657-
if out is None:
4658-
d = np.dot(filled(self, 0), filled(other, 0))
4659-
m = ~np.dot(am, bm)
4660-
if d.ndim == 0:
4661-
d = np.asarray(d)
4662-
r = d.view(get_masked_subclass(self, other))
4663-
r.__setmask__(m)
4664-
return r
4665-
d = self.filled(0).dot(other.filled(0), out._data)
4666-
if out.mask.shape != d.shape:
4667-
out._mask = np.empty(d.shape, MaskType)
4668-
np.dot(am, bm, out._mask)
4669-
np.logical_not(out._mask, out._mask)
4670-
return out
4671-
dot.__doc__ = ndarray.dot.__doc__
4657+
def dot(self, b, out=None, strict=False):
4658+
"""
4659+
a.dot(b, out=None)
4660+
4661+
Masked dot product of two arrays. Note that `out` and `strict` are
4662+
located in different positions than in `ma.dot`. In order to
4663+
maintain compatibility with the functional version, it is
4664+
recommended that the optional arguments be treated as keyword only.
4665+
At some point that may be mandatory.
4666+
4667+
.. versionadded:: 1.10.0
4668+
4669+
Parameters
4670+
----------
4671+
b : masked_array_like
4672+
Inputs array.
4673+
out : masked_array, optional
4674+
Output argument. This must have the exact kind that would be
4675+
returned if it was not used. In particular, it must have the
4676+
right type, must be C-contiguous, and its dtype must be the
4677+
dtype that would be returned for `ma.dot(a,b)`. This is a
4678+
performance feature. Therefore, if these conditions are not
4679+
met, an exception is raised, instead of attempting to be
4680+
flexible.
4681+
strict : bool, optional
4682+
Whether masked data are propagated (True) or set to 0 (False)
4683+
for the computation. Default is False. Propagating the mask
4684+
means that if a masked value appears in a row or column, the
4685+
whole row or column is considered masked.
4686+
4687+
.. versionadded:: 1.10.2
4688+
4689+
See Also
4690+
--------
4691+
numpy.ma.dot : equivalent function
4692+
4693+
"""
4694+
return dot(self, b, out=out, strict=strict)
46724695

46734696
def sum(self, axis=None, dtype=None, out=None):
46744697
"""
@@ -5884,7 +5907,7 @@ def filled(self, fill_value=None):
58845907
--------
58855908
MaskedArray.filled
58865909
5887-
"""
5910+
"""
58885911
return asarray(self).filled(fill_value)[()]
58895912

58905913
def tolist(self):
@@ -7021,6 +7044,186 @@ def round_(a, decimals=0, out=None):
70217044
round = round_
70227045

70237046

7047+
# Needed by dot, so move here from extras.py. It will still be exported
7048+
# from extras.py for compatibility.
7049+
def mask_rowcols(a, axis=None):
7050+
"""
7051+
Mask rows and/or columns of a 2D array that contain masked values.
7052+
7053+
Mask whole rows and/or columns of a 2D array that contain
7054+
masked values. The masking behavior is selected using the
7055+
`axis` parameter.
7056+
7057+
- If `axis` is None, rows *and* columns are masked.
7058+
- If `axis` is 0, only rows are masked.
7059+
- If `axis` is 1 or -1, only columns are masked.
7060+
7061+
Parameters
7062+
----------
7063+
a : array_like, MaskedArray
7064+
The array to mask. If not a MaskedArray instance (or if no array
7065+
elements are masked). The result is a MaskedArray with `mask` set
7066+
to `nomask` (False). Must be a 2D array.
7067+
axis : int, optional
7068+
Axis along which to perform the operation. If None, applies to a
7069+
flattened version of the array.
7070+
7071+
Returns
7072+
-------
7073+
a : MaskedArray
7074+
A modified version of the input array, masked depending on the value
7075+
of the `axis` parameter.
7076+
7077+
Raises
7078+
------
7079+
NotImplementedError
7080+
If input array `a` is not 2D.
7081+
7082+
See Also
7083+
--------
7084+
mask_rows : Mask rows of a 2D array that contain masked values.
7085+
mask_cols : Mask cols of a 2D array that contain masked values.
7086+
masked_where : Mask where a condition is met.
7087+
7088+
Notes
7089+
-----
7090+
The input array's mask is modified by this function.
7091+
7092+
Examples
7093+
--------
7094+
>>> import numpy.ma as ma
7095+
>>> a = np.zeros((3, 3), dtype=np.int)
7096+
>>> a[1, 1] = 1
7097+
>>> a
7098+
array([[0, 0, 0],
7099+
[0, 1, 0],
7100+
[0, 0, 0]])
7101+
>>> a = ma.masked_equal(a, 1)
7102+
>>> a
7103+
masked_array(data =
7104+
[[0 0 0]
7105+
[0 -- 0]
7106+
[0 0 0]],
7107+
mask =
7108+
[[False False False]
7109+
[False True False]
7110+
[False False False]],
7111+
fill_value=999999)
7112+
>>> ma.mask_rowcols(a)
7113+
masked_array(data =
7114+
[[0 -- 0]
7115+
[-- -- --]
7116+
[0 -- 0]],
7117+
mask =
7118+
[[False True False]
7119+
[ True True True]
7120+
[False True False]],
7121+
fill_value=999999)
7122+
7123+
"""
7124+
a = array(a, subok=False)
7125+
if a.ndim != 2:
7126+
raise NotImplementedError("mask_rowcols works for 2D arrays only.")
7127+
m = getmask(a)
7128+
# Nothing is masked: return a
7129+
if m is nomask or not m.any():
7130+
return a
7131+
maskedval = m.nonzero()
7132+
a._mask = a._mask.copy()
7133+
if not axis:
7134+
a[np.unique(maskedval[0])] = masked
7135+
if axis in [None, 1, -1]:
7136+
a[:, np.unique(maskedval[1])] = masked
7137+
return a
7138+
7139+
7140+
# Include masked dot here to avoid import problems in getting it from
7141+
# extras.py. Note that it is not included in __all__, but rather exported
7142+
# from extras in order to avoid backward compatibility problems.
7143+
def dot(a, b, strict=False, out=None):
7144+
"""
7145+
Return the dot product of two arrays.
7146+
7147+
This function is the equivalent of `numpy.dot` that takes masked values
7148+
into account. Note that `strict` and `out` are in different position
7149+
than in the method version. In order to maintain compatibility with the
7150+
corresponding method, it is recommended that the optional arguments be
7151+
treated as keyword only. At some point that may be mandatory.
7152+
7153+
.. note::
7154+
Works only with 2-D arrays at the moment.
7155+
7156+
7157+
Parameters
7158+
----------
7159+
a, b : masked_array_like
7160+
Inputs arrays.
7161+
strict : bool, optional
7162+
Whether masked data are propagated (True) or set to 0 (False) for
7163+
the computation. Default is False. Propagating the mask means that
7164+
if a masked value appears in a row or column, the whole row or
7165+
column is considered masked.
7166+
out : masked_array, optional
7167+
Output argument. This must have the exact kind that would be returned
7168+
if it was not used. In particular, it must have the right type, must be
7169+
C-contiguous, and its dtype must be the dtype that would be returned
7170+
for `dot(a,b)`. This is a performance feature. Therefore, if these
7171+
conditions are not met, an exception is raised, instead of attempting
7172+
to be flexible.
7173+
7174+
.. versionadded:: 1.10.2
7175+
7176+
See Also
7177+
--------
7178+
numpy.dot : Equivalent function for ndarrays.
7179+
7180+
Examples
7181+
--------
7182+
>>> a = ma.array([[1, 2, 3], [4, 5, 6]], mask=[[1, 0, 0], [0, 0, 0]])
7183+
>>> b = ma.array([[1, 2], [3, 4], [5, 6]], mask=[[1, 0], [0, 0], [0, 0]])
7184+
>>> np.ma.dot(a, b)
7185+
masked_array(data =
7186+
[[21 26]
7187+
[45 64]],
7188+
mask =
7189+
[[False False]
7190+
[False False]],
7191+
fill_value = 999999)
7192+
>>> np.ma.dot(a, b, strict=True)
7193+
masked_array(data =
7194+
[[-- --]
7195+
[-- 64]],
7196+
mask =
7197+
[[ True True]
7198+
[ True False]],
7199+
fill_value = 999999)
7200+
7201+
"""
7202+
# !!!: Works only with 2D arrays. There should be a way to get it to run
7203+
# with higher dimension
7204+
if strict and (a.ndim == 2) and (b.ndim == 2):
7205+
a = mask_rowcols(a, 0)
7206+
b = mask_rowcols(b, 1)
7207+
am = ~getmaskarray(a)
7208+
bm = ~getmaskarray(b)
7209+
7210+
if out is None:
7211+
d = np.dot(filled(a, 0), filled(b, 0))
7212+
m = ~np.dot(am, bm)
7213+
if d.ndim == 0:
7214+
d = np.asarray(d)
7215+
r = d.view(get_masked_subclass(a, b))
7216+
r.__setmask__(m)
7217+
return r
7218+
else:
7219+
d = np.dot(filled(a, 0), filled(b, 0), out._data)
7220+
if out.mask.shape != d.shape:
7221+
out._mask = np.empty(d.shape, MaskType)
7222+
np.dot(am, bm, out._mask)
7223+
np.logical_not(out._mask, out._mask)
7224+
return out
7225+
7226+
70247227
def inner(a, b):
70257228
"""
70267229
Returns the inner product of a and b for arrays of floating point types.

0 commit comments

Comments
 (0)