Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 7141f40

Browse files
committed
Merge pull request #7001 from shoyer/NaT-comparison
API: make all comparisons with NaT false
2 parents 8fa6e3b + 53ad26a commit 7141f40

File tree

8 files changed

+135
-24
lines changed

8 files changed

+135
-24
lines changed

numpy/core/arrayprint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -739,8 +739,8 @@ def __call__(self, x):
739739
class TimedeltaFormat(object):
740740
def __init__(self, data):
741741
if data.dtype.kind == 'm':
742-
nat_value = array(['NaT'], dtype=data.dtype)[0]
743-
v = data[not_equal(data, nat_value)].view('i8')
742+
# select non-NaT elements
743+
v = data[data == data].view('i8')
744744
if len(v) > 0:
745745
# Max str length of non-NaT elements
746746
max_str_len = max(len(str(maximum.reduce(v))),
@@ -754,7 +754,7 @@ def __init__(self, data):
754754
self._nat = "'NaT'".rjust(max_str_len)
755755

756756
def __call__(self, x):
757-
if x + 1 == x:
757+
if x != x:
758758
return self._nat
759759
else:
760760
return self.format % x.astype('i8')

numpy/core/src/multiarray/scalartypes.c.src

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1673,7 +1673,7 @@ voidtype_setfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds)
16731673
* However, as a special case, void-scalar assignment broadcasts
16741674
* differently from ndarrays when assigning to an object field: Assignment
16751675
* to an ndarray object field broadcasts, but assignment to a void-scalar
1676-
* object-field should not, in order to allow nested ndarrays.
1676+
* object-field should not, in order to allow nested ndarrays.
16771677
* These lines should then behave identically:
16781678
*
16791679
* b = np.zeros(1, dtype=[('x', 'O')])

numpy/core/src/umath/loops.c.src

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,20 +1117,40 @@ NPY_NO_EXPORT void
11171117
}
11181118

11191119
/**begin repeat1
1120-
* #kind = equal, not_equal, greater, greater_equal, less, less_equal#
1121-
* #OP = ==, !=, >, >=, <, <=#
1120+
* #kind = equal, greater, greater_equal, less, less_equal#
1121+
* #OP = ==, >, >=, <, <=#
11221122
*/
11231123
NPY_NO_EXPORT void
11241124
@TYPE@_@kind@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
11251125
{
11261126
BINARY_LOOP {
11271127
const @type@ in1 = *(@type@ *)ip1;
11281128
const @type@ in2 = *(@type@ *)ip2;
1129-
*((npy_bool *)op1) = in1 @OP@ in2;
1129+
if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) {
1130+
*((npy_bool *)op1) = NPY_FALSE;
1131+
}
1132+
else {
1133+
*((npy_bool *)op1) = in1 @OP@ in2;
1134+
}
11301135
}
11311136
}
11321137
/**end repeat1**/
11331138

1139+
NPY_NO_EXPORT void
1140+
@TYPE@_not_equal(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
1141+
{
1142+
BINARY_LOOP {
1143+
const @type@ in1 = *(@type@ *)ip1;
1144+
const @type@ in2 = *(@type@ *)ip2;
1145+
if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) {
1146+
*((npy_bool *)op1) = NPY_TRUE;
1147+
}
1148+
else {
1149+
*((npy_bool *)op1) = in1 != in2;
1150+
}
1151+
}
1152+
}
1153+
11341154
/**begin repeat1
11351155
* #kind = maximum, minimum#
11361156
* #OP = >, <#

numpy/core/tests/test_datetime.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,11 @@ def test_compare_generic_nat(self):
130130
# regression tests for GH6452
131131
assert_equal(np.datetime64('NaT'),
132132
np.datetime64('2000') + np.timedelta64('NaT'))
133-
# nb. we may want to make NaT != NaT true in the future; this test
134-
# verifies the existing behavior (and that it should not warn)
135-
assert_(np.datetime64('NaT') == np.datetime64('NaT', 'us'))
136-
assert_(np.datetime64('NaT', 'us') == np.datetime64('NaT'))
133+
assert_equal(np.datetime64('NaT'), np.datetime64('NaT', 'us'))
134+
assert_equal(np.timedelta64('NaT'), np.timedelta64('NaT', 'us'))
135+
# neither of these should issue a warning
136+
assert_(np.datetime64('NaT') != np.datetime64('NaT', 'us'))
137+
assert_(np.datetime64('NaT', 'us') != np.datetime64('NaT'))
137138

138139
def test_datetime_scalar_construction(self):
139140
# Construct with different units
@@ -552,6 +553,9 @@ def test_datetime_array_str(self):
552553
"'%s'" % np.datetime_as_string(x, timezone='UTC')}),
553554
"['2011-03-16T13:55Z', '1920-01-01T03:12Z']")
554555

556+
a = np.array(['NaT', 'NaT'], dtype='datetime64[ns]')
557+
assert_equal(str(a), "['NaT' 'NaT']")
558+
555559
# Check that one NaT doesn't corrupt subsequent entries
556560
a = np.array(['2010', 'NaT', '2030']).astype('M')
557561
assert_equal(str(a), "['2010' 'NaT' '2030']")
@@ -658,7 +662,7 @@ def test_pyobject_roundtrip(self):
658662
b[8] = 'NaT'
659663

660664
assert_equal(b.astype(object).astype(unit), b,
661-
"Error roundtripping unit %s" % unit)
665+
"Error roundtripping unit %s" % unit)
662666
# With time units
663667
for unit in ['M8[as]', 'M8[16fs]', 'M8[ps]', 'M8[us]',
664668
'M8[300as]', 'M8[20us]']:
@@ -674,7 +678,7 @@ def test_pyobject_roundtrip(self):
674678
b[8] = 'NaT'
675679

676680
assert_equal(b.astype(object).astype(unit), b,
677-
"Error roundtripping unit %s" % unit)
681+
"Error roundtripping unit %s" % unit)
678682

679683
def test_month_truncation(self):
680684
# Make sure that months are truncating correctly
@@ -1081,6 +1085,26 @@ def test_datetime_compare(self):
10811085
assert_equal(np.greater(a, b), [0, 1, 0, 1, 0])
10821086
assert_equal(np.greater_equal(a, b), [1, 1, 0, 1, 0])
10831087

1088+
def test_datetime_compare_nat(self):
1089+
dt_nat = np.datetime64('NaT', 'D')
1090+
dt_other = np.datetime64('2000-01-01')
1091+
td_nat = np.timedelta64('NaT', 'h')
1092+
td_other = np.timedelta64(1, 'h')
1093+
for op in [np.equal, np.less, np.less_equal,
1094+
np.greater, np.greater_equal]:
1095+
assert_(not op(dt_nat, dt_nat))
1096+
assert_(not op(dt_nat, dt_other))
1097+
assert_(not op(dt_other, dt_nat))
1098+
assert_(not op(td_nat, td_nat))
1099+
assert_(not op(td_nat, td_other))
1100+
assert_(not op(td_other, td_nat))
1101+
assert_(np.not_equal(dt_nat, dt_nat))
1102+
assert_(np.not_equal(dt_nat, dt_other))
1103+
assert_(np.not_equal(dt_other, dt_nat))
1104+
assert_(np.not_equal(td_nat, td_nat))
1105+
assert_(np.not_equal(td_nat, td_other))
1106+
assert_(np.not_equal(td_other, td_nat))
1107+
10841108
def test_datetime_minmax(self):
10851109
# The metadata of the result should become the GCD
10861110
# of the operand metadata

numpy/ma/tests/test_extras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_testAverage1(self):
154154
ott = ott.reshape(2, 2)
155155
ott[:, 1] = masked
156156
assert_equal(average(ott, axis=0), [2.0, 0.0])
157-
assert_equal(average(ott, axis=1).mask[0], [True])
157+
assert_equal(average(ott, axis=1).mask[0], True)
158158
assert_equal([2., 0.], average(ott, axis=0))
159159
result, wts = average(ott, axis=0, returned=1)
160160
assert_equal(wts, [1., 0.])

numpy/ma/testutils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,7 @@ def assert_equal(actual, desired, err_msg=''):
125125
if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)):
126126
return _assert_equal_on_sequences(actual, desired, err_msg='')
127127
if not (isinstance(actual, ndarray) or isinstance(desired, ndarray)):
128-
msg = build_err_msg([actual, desired], err_msg,)
129-
if not desired == actual:
130-
raise AssertionError(msg)
131-
return
128+
return utils.assert_equal(actual, desired)
132129
# Case #4. arrays or equivalent
133130
if ((actual is masked) and not (desired is masked)) or \
134131
((desired is masked) and not (actual is masked)):

numpy/testing/tests/test_utils.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
assert_warns, assert_no_warnings, assert_allclose, assert_approx_equal,
1212
assert_array_almost_equal_nulp, assert_array_max_ulp,
1313
clear_and_catch_warnings, run_module_suite,
14-
assert_string_equal, assert_, tempdir, temppath,
14+
assert_string_equal, assert_, tempdir, temppath,
1515
)
1616
import unittest
1717

@@ -119,6 +119,25 @@ def test_nan_array(self):
119119
c = np.array([1, 2, 3])
120120
self._test_not_equal(c, b)
121121

122+
def test_nat_array_datetime(self):
123+
a = np.array([np.datetime64('2000-01'), np.datetime64('NaT')])
124+
b = np.array([np.datetime64('2000-01'), np.datetime64('NaT')])
125+
self._test_equal(a, b)
126+
127+
c = np.array([np.datetime64('NaT'), np.datetime64('NaT')])
128+
self._test_not_equal(c, b)
129+
130+
def test_nat_array_timedelta(self):
131+
a = np.array([np.timedelta64(1, 'h'), np.timedelta64('NaT')])
132+
b = np.array([np.timedelta64(1, 'h'), np.timedelta64('NaT')])
133+
self._test_equal(a, b)
134+
135+
c = np.array([np.timedelta64('NaT'), np.timedelta64('NaT')])
136+
self._test_not_equal(c, b)
137+
138+
d = np.array([np.datetime64('NaT'), np.datetime64('NaT')])
139+
self._test_not_equal(c, d)
140+
122141
def test_string_arrays(self):
123142
"""Test two arrays with different shapes are found not equal."""
124143
a = np.array(['floupi', 'floupa'])
@@ -227,6 +246,16 @@ def test_complex(self):
227246
self._assert_func(x, x)
228247
self._test_not_equal(x, y)
229248

249+
def test_nat(self):
250+
dt = np.datetime64('2000-01-01')
251+
dt_nat = np.datetime64('NaT')
252+
td_nat = np.timedelta64('NaT')
253+
self._assert_func(dt_nat, dt_nat)
254+
self._assert_func(td_nat, td_nat)
255+
self._test_not_equal(dt_nat, td_nat)
256+
self._test_not_equal(dt, td_nat)
257+
self._test_not_equal(dt, dt_nat)
258+
230259

231260
class TestArrayAlmostEqual(_GenericTest, unittest.TestCase):
232261

@@ -457,7 +486,7 @@ def f():
457486

458487

459488
class TestAssertAllclose(unittest.TestCase):
460-
489+
461490
def test_simple(self):
462491
x = 1e-3
463492
y = 1e-9

numpy/testing/utils.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from tempfile import mkdtemp, mkstemp
1616

1717
from .nosetester import import_nose
18-
from numpy.core import float32, empty, arange, array_repr, ndarray
18+
from numpy.core import float32, empty, arange, array_repr, ndarray, dtype
1919
from numpy.lib.utils import deprecate
2020

2121
if sys.version_info[0] >= 3:
@@ -343,16 +343,31 @@ def assert_equal(actual,desired,err_msg='',verbose=True):
343343
except AssertionError:
344344
raise AssertionError(msg)
345345

346+
def isnat(x):
347+
return (hasattr(x, 'dtype')
348+
and getattr(x.dtype, 'kind', '_') in 'mM'
349+
and x != x)
350+
346351
# Inf/nan/negative zero handling
347352
try:
348353
# isscalar test to check cases such as [np.nan] != np.nan
349-
if isscalar(desired) != isscalar(actual):
354+
# dtypes compare equal to strings, but unlike strings aren't scalars,
355+
# so we need to exclude them from this check
356+
if (isscalar(desired) != isscalar(actual)
357+
and not (isinstance(desired, dtype)
358+
or isinstance(actual, dtype))):
350359
raise AssertionError(msg)
351360

361+
# check NaT before NaN, because isfinite errors on datetime dtypes
362+
if isnat(desired) and isnat(actual):
363+
if desired.dtype.kind != actual.dtype.kind:
364+
# datetime64 and timedelta64 NaT should not be comparable
365+
raise AssertionError(msg)
366+
return
352367
# If one of desired/actual is not finite, handle it specially here:
353368
# check that both are nan if any is a nan, and test for equality
354369
# otherwise
355-
if not (gisfinite(desired) and gisfinite(actual)):
370+
elif not (gisfinite(desired) and gisfinite(actual)):
356371
isdesnan = gisnan(desired)
357372
isactnan = gisnan(actual)
358373
if isdesnan or isactnan:
@@ -663,6 +678,9 @@ def safe_comparison(*args, **kwargs):
663678
def isnumber(x):
664679
return x.dtype.char in '?bhilqpBHILQPefdgFDG'
665680

681+
def isdatetime(x):
682+
return x.dtype.char in 'mM'
683+
666684
def chk_same_position(x_id, y_id, hasval='nan'):
667685
"""Handling nan/inf: check that x and y have the nan/inf at the same
668686
locations."""
@@ -675,6 +693,15 @@ def chk_same_position(x_id, y_id, hasval='nan'):
675693
names=('x', 'y'), precision=precision)
676694
raise AssertionError(msg)
677695

696+
def chk_same_dtype(x_dt, y_dt):
697+
try:
698+
assert_equal(x_dt, y_dt)
699+
except AssertionError:
700+
msg = build_err_msg([x, y], err_msg + '\nx and y dtype mismatch',
701+
verbose=verbose, header=header,
702+
names=('x', 'y'), precision=precision)
703+
raise AssertionError(msg)
704+
678705
try:
679706
cond = (x.shape == () or y.shape == ()) or x.shape == y.shape
680707
if not cond:
@@ -712,6 +739,20 @@ def chk_same_position(x_id, y_id, hasval='nan'):
712739
val = safe_comparison(x[~x_id], y[~y_id])
713740
else:
714741
val = safe_comparison(x, y)
742+
elif isdatetime(x) and isdatetime(y):
743+
x_isnat, y_isnat = (x != x), (y != y)
744+
745+
if any(x_isnat) or any(y_isnat):
746+
# cannot mix timedelta64/datetime64 NaT
747+
chk_same_dtype(x.dtype, y.dtype)
748+
chk_same_position(x_isnat, y_isnat, hasval='nat')
749+
750+
if all(x_isnat):
751+
return
752+
if any(x_isnat):
753+
val = safe_comparison(x[~x_isnat], y[~y_isnat])
754+
else:
755+
val = safe_comparison(x, y)
715756
else:
716757
val = safe_comparison(x, y)
717758

@@ -1826,7 +1867,7 @@ def temppath(*args, **kwargs):
18261867
parameters are the same as for tempfile.mkstemp and are passed directly
18271868
to that function. The underlying file is removed when the context is
18281869
exited, so it should be closed at that time.
1829-
1870+
18301871
Windows does not allow a temporary file to be opened if it is already
18311872
open, so the underlying file must be closed after opening before it
18321873
can be opened again.

0 commit comments

Comments
 (0)