Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 98357ec

Browse files
jnothmanqinhanmin2014
authored andcommitted
FIX (0.21) make count_nonzero dtype invariant wrt axis (#12341)
1 parent 0ad8736 commit 98357ec

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

sklearn/utils/sparsefuncs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,8 @@ def count_nonzero(X, axis=None, sample_weight=None):
467467
elif axis == 1:
468468
out = np.diff(X.indptr)
469469
if sample_weight is None:
470-
return out
470+
# astype here is for consistency with axis=0 dtype
471+
return out.astype('intp')
471472
return out * sample_weight
472473
elif axis == 0:
473474
if sample_weight is None:

sklearn/utils/tests/test_sparsefuncs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,19 @@ def test_count_nonzero():
443443
assert_raises(TypeError, count_nonzero, X_csc)
444444
assert_raises(ValueError, count_nonzero, X_csr, axis=2)
445445

446+
assert (count_nonzero(X_csr, axis=0).dtype ==
447+
count_nonzero(X_csr, axis=1).dtype)
448+
assert (count_nonzero(X_csr, axis=0, sample_weight=sample_weight).dtype ==
449+
count_nonzero(X_csr, axis=1, sample_weight=sample_weight).dtype)
450+
451+
# Check dtypes with large sparse matrices too
452+
X_csr.indices = X_csr.indices.astype(np.int64)
453+
X_csr.indptr = X_csr.indptr.astype(np.int64)
454+
assert (count_nonzero(X_csr, axis=0).dtype ==
455+
count_nonzero(X_csr, axis=1).dtype)
456+
assert (count_nonzero(X_csr, axis=0, sample_weight=sample_weight).dtype ==
457+
count_nonzero(X_csr, axis=1, sample_weight=sample_weight).dtype)
458+
446459

447460
def test_csc_row_median():
448461
# Test csc_row_median actually calculates the median.

0 commit comments

Comments
 (0)