From d864048367057dae5f22c3adc4b7845821e3fb87 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 10 Oct 2018 13:27:56 +1100 Subject: [PATCH 1/2] FIX make count_nonzero dtype invariant wrt axis --- sklearn/utils/sparsefuncs.py | 3 ++- sklearn/utils/tests/test_sparsefuncs.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/sparsefuncs.py b/sklearn/utils/sparsefuncs.py index ccaa6eeb28e65..918f32e6da3e5 100644 --- a/sklearn/utils/sparsefuncs.py +++ b/sklearn/utils/sparsefuncs.py @@ -467,7 +467,8 @@ def count_nonzero(X, axis=None, sample_weight=None): elif axis == 1: out = np.diff(X.indptr) if sample_weight is None: - return out + # astype here is for consistency with axis=0 dtype + return out.astype('intp') return out * sample_weight elif axis == 0: if sample_weight is None: diff --git a/sklearn/utils/tests/test_sparsefuncs.py b/sklearn/utils/tests/test_sparsefuncs.py index 838435a0deab9..862b1e6db8e0b 100644 --- a/sklearn/utils/tests/test_sparsefuncs.py +++ b/sklearn/utils/tests/test_sparsefuncs.py @@ -443,6 +443,19 @@ def test_count_nonzero(): assert_raises(TypeError, count_nonzero, X_csc) assert_raises(ValueError, count_nonzero, X_csr, axis=2) + assert (count_nonzero(X_csr, axis=0).dtype == + count_nonzero(X_csr, axis=1).dtype) + assert (count_nonzero(X_csr, axis=0, sample_weight=sample_weight).dtype == + count_nonzero(X_csr, axis=1, sample_weight=sample_weight).dtype) + + # Check dtypes with large sparse matrices too + X_csr = sp.csr_matrix((X_csr.data, X_csr.indices.astype(np.int64), + X_csr.indptr.astype(np.int64)), shape=X_csr.shape) + assert (count_nonzero(X_csr, axis=0).dtype == + count_nonzero(X_csr, axis=1).dtype) + assert (count_nonzero(X_csr, axis=0, sample_weight=sample_weight).dtype == + count_nonzero(X_csr, axis=1, sample_weight=sample_weight).dtype) + def test_csc_row_median(): # Test csc_row_median actually calculates the median. From e2c0bf723a3ee9d012e82b7aad40ad51fa449910 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Sun, 14 Oct 2018 22:59:47 +1100 Subject: [PATCH 2/2] Correctly test scipy int64 sparse indices --- sklearn/utils/tests/test_sparsefuncs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_sparsefuncs.py b/sklearn/utils/tests/test_sparsefuncs.py index 862b1e6db8e0b..03c0c717d3174 100644 --- a/sklearn/utils/tests/test_sparsefuncs.py +++ b/sklearn/utils/tests/test_sparsefuncs.py @@ -449,8 +449,8 @@ def test_count_nonzero(): count_nonzero(X_csr, axis=1, sample_weight=sample_weight).dtype) # Check dtypes with large sparse matrices too - X_csr = sp.csr_matrix((X_csr.data, X_csr.indices.astype(np.int64), - X_csr.indptr.astype(np.int64)), shape=X_csr.shape) + X_csr.indices = X_csr.indices.astype(np.int64) + X_csr.indptr = X_csr.indptr.astype(np.int64) assert (count_nonzero(X_csr, axis=0).dtype == count_nonzero(X_csr, axis=1).dtype) assert (count_nonzero(X_csr, axis=0, sample_weight=sample_weight).dtype ==