diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index bb245aa466152..f5cee5c28258f 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -499,6 +499,15 @@ Changelog
   during `transform` with no prior call to `fit` or `fit_transform`.
   :pr:`25190` by :user:`Vincent Maladière `.
 
+- |Enhancement| :class:`preprocessing.PolynomialFeatures` now calculates the
+  number of expanded terms a priori when dealing with sparse `csr` matrices
+  in order to optimize the choice of `dtype` for `indices` and `indptr`. It
+  can now output `csr` matrices with `np.int32` `indices/indptr` components
+  when there are few enough elements, and will automatically use `np.int64`
+  for sufficiently large matrices.
+  :pr:`20524` by :user:`niuk-a ` and
+  :pr:`23731` by :user:`Meekail Zain `.
+
 - |API| A `FutureWarning` is now raised when instantiating a class which
   inherits from a deprecated base class (i.e. decorated by
   :class:`utils.deprecated`) and which overrides the `__init__` method.
diff --git a/setup.py b/setup.py
index 89227169a05e9..d1e191190fd09 100755
--- a/setup.py
+++ b/setup.py
@@ -293,7 +293,7 @@ def check_package_status(package, min_version):
         },
     ],
     "preprocessing": [
-        {"sources": ["_csr_polynomial_expansion.pyx"], "include_np": True},
+        {"sources": ["_csr_polynomial_expansion.pyx"]},
         {
             "sources": ["_target_encoder_fast.pyx"],
             "include_np": True,
diff --git a/sklearn/preprocessing/_csr_polynomial_expansion.pyx b/sklearn/preprocessing/_csr_polynomial_expansion.pyx
index e2cff65f07972..90f81c0399a6e 100644
--- a/sklearn/preprocessing/_csr_polynomial_expansion.pyx
+++ b/sklearn/preprocessing/_csr_polynomial_expansion.pyx
@@ -1,73 +1,178 @@
-# Author: Andrew nystrom 
+# Authors: Andrew nystrom 
+#          Meekail Zain 
+from ..utils._typedefs cimport uint8_t, int64_t, intp_t
 
-from scipy.sparse import csr_matrix
-cimport numpy as cnp
-import numpy as np
+ctypedef uint8_t FLAG_t
+
+# We use the following verbatim block to determine whether the current
+# platform's compiler supports 128-bit integer values intrinsically.
+# This should work for GCC and CLANG on 64-bit architectures, but doesn't for
+# MSVC on any architecture. We prefer to use 128-bit integers when possible
+# because the intermediate calculations have a non-trivial risk of overflow. It
+# is, however, very unlikely to come up in an average use case, hence 64-bit
+# integers (i.e. `long long`) are "good enough" for most common cases. There is
+# not much we can do to efficiently mitigate the overflow risk on the Windows
+# platform at this time. Consider this a "best effort" design decision that
+# could be revisited later in case someone comes up with a safer option that
+# does not hurt the performance of the common cases.
+# See `test_sizeof_LARGEST_INT_t()` for more information on exact type
+# expectations.
+cdef extern from *:
+    """
+    #ifdef __SIZEOF_INT128__
+        typedef __int128 LARGEST_INT_t;
+    #elif (__clang__ || __EMSCRIPTEN__) && !__i386__
+        typedef _BitInt(128) LARGEST_INT_t;
+    #else
+        typedef long long LARGEST_INT_t;
+    #endif
+    """
+    ctypedef long long LARGEST_INT_t
+
+
+# Determine the size of `LARGEST_INT_t` at runtime.
+# Used in `test_sizeof_LARGEST_INT_t`.
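+# As a rough illustration (hypothetical numbers, not part of the build):
+# when this helper reports 8 bytes, intermediate products such as
+# `n_features * i` in `_deg2_column` wrap around once they exceed
+# `2**63 - 1` (e.g. for `n_features = i = 4 * 10**9`), whereas with
+# 16 bytes the same products are evaluated exactly in 128-bit arithmetic.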
+def _get_sizeof_LARGEST_INT_t():
+    return sizeof(LARGEST_INT_t)
 
-cnp.import_array()
 
-# TODO: use `cnp.{int,float}{32,64}` when cython#5230 is resolved:
+# TODO: use `{int,float}{32,64}_t` when cython#5230 is resolved:
 # https://github.com/cython/cython/issues/5230
-ctypedef fused DATA_T:
+ctypedef fused DATA_t:
     float
     double
     int
-    long
+    long long
 
+# INDEX_{A,B}_t are defined to generate a proper Cartesian product
+# of types through Cython fused-type expansion.
+ctypedef fused INDEX_A_t:
+    signed int
+    signed long long
+ctypedef fused INDEX_B_t:
+    signed int
+    signed long long
 
-
-cdef inline cnp.int32_t _deg2_column(
-    cnp.int32_t d,
-    cnp.int32_t i,
-    cnp.int32_t j,
-    cnp.int32_t interaction_only,
-) noexcept nogil:
+cdef inline int64_t _deg2_column(
+    LARGEST_INT_t n_features,
+    LARGEST_INT_t i,
+    LARGEST_INT_t j,
+    FLAG_t interaction_only
+) nogil:
     """Compute the index of the column for a degree 2 expansion
 
-    d is the dimensionality of the input data, i and j are the indices
+    n_features is the dimensionality of the input data, i and j are the indices
     for the columns involved in the expansion.
     """
     if interaction_only:
-        return d * i - (i**2 + 3 * i) / 2 - 1 + j
+        return n_features * i - i * (i + 3) / 2 - 1 + j
     else:
-        return d * i - (i**2 + i) / 2 + j
+        return n_features * i - i * (i + 1) / 2 + j
 
 
-cdef inline cnp.int32_t _deg3_column(
-    cnp.int32_t d,
-    cnp.int32_t i,
-    cnp.int32_t j,
-    cnp.int32_t k,
-    cnp.int32_t interaction_only
-) noexcept nogil:
+cdef inline int64_t _deg3_column(
+    LARGEST_INT_t n_features,
+    LARGEST_INT_t i,
+    LARGEST_INT_t j,
+    LARGEST_INT_t k,
+    FLAG_t interaction_only
+) nogil:
     """Compute the index of the column for a degree 3 expansion
 
-    d is the dimensionality of the input data, i, j and k are the indices
+    n_features is the dimensionality of the input data, i, j and k are the indices
     for the columns involved in the expansion.
     """
     if interaction_only:
-        return ((3 * d**2 * i - 3 * d * i**2 + i**3
-                 + 11 * i - 3 * j**2 - 9 * j) / 6
-                + i**2 - 2 * d * i + d * j - d + k)
+        return (
+            (
+                (3 * n_features) * (n_features * i - i**2)
+                + i * (i**2 + 11) - (3 * j) * (j + 3)
+            ) / 6 + i**2 + n_features * (j - 1 - 2 * i) + k
+        )
+    else:
+        return (
+            (
+                (3 * n_features) * (n_features * i - i**2)
+                + i ** 3 - i - (3 * j) * (j + 1)
+            ) / 6 + n_features * j + k
+        )
+
+
+def py_calc_expanded_nnz_deg2(n, interaction_only):
+    return n * (n + 1) // 2 - interaction_only * n
+
+
+def py_calc_expanded_nnz_deg3(n, interaction_only):
+    return n * (n**2 + 3 * n + 2) // 6 - interaction_only * n**2
+
+
+cpdef int64_t _calc_expanded_nnz(
+    LARGEST_INT_t n,
+    FLAG_t interaction_only,
+    LARGEST_INT_t degree
+):
+    """
+    Calculates the number of non-zero interaction terms generated by the
+    non-zero elements of a single row.
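+
+    As a small worked example (illustrative only, derived from the formulas
+    below): a row with ``n = 3`` non-zero elements expands, for
+    ``degree = 2``, to ``3 * 4 / 2 = 6`` products when ``interaction_only``
+    is 0, and to ``6 - 3 = 3`` products when ``interaction_only`` is 1.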
+ """ + # This is the maximum value before the intermediate computation + # d**2 + d overflows + # Solution to d**2 + d = maxint64 + # SymPy: solve(x**2 + x - int64_max, x) + cdef int64_t MAX_SAFE_INDEX_CALC_DEG2 = 3037000499 + + # This is the maximum value before the intermediate computation + # d**3 + 3 * d**2 + 2*d overflows + # Solution to d**3 + 3 * d**2 + 2*d = maxint64 + # SymPy: solve(x * (x**2 + 3 * x + 2) - int64_max, x) + cdef int64_t MAX_SAFE_INDEX_CALC_DEG3 = 2097151 + + if degree == 2: + # Only need to check when not using 128-bit integers + if sizeof(LARGEST_INT_t) < 16 and n <= MAX_SAFE_INDEX_CALC_DEG2: + return n * (n + 1) / 2 - interaction_only * n + return py_calc_expanded_nnz_deg2(n, interaction_only) else: - return ((3 * d**2 * i - 3 * d * i**2 + i ** 3 - i - - 3 * j**2 - 3 * j) / 6 - + d * j + k) - - -def _csr_polynomial_expansion( - const DATA_T[:] data, - const cnp.int32_t[:] indices, - const cnp.int32_t[:] indptr, - cnp.int32_t d, - cnp.int32_t interaction_only, - cnp.int32_t degree + # Only need to check when not using 128-bit integers + if sizeof(LARGEST_INT_t) < 16 and n <= MAX_SAFE_INDEX_CALC_DEG3: + return n * (n**2 + 3 * n + 2) / 6 - interaction_only * n**2 + return py_calc_expanded_nnz_deg3(n, interaction_only) + +cpdef int64_t _calc_total_nnz( + INDEX_A_t[:] indptr, + FLAG_t interaction_only, + int64_t degree, ): """ - Perform a second-degree polynomial or interaction expansion on a scipy + Calculates the number of non-zero interaction terms generated by the + non-zero elements across all rows for a single degree. + """ + cdef int64_t total_nnz=0 + cdef intp_t row_idx + for row_idx in range(len(indptr) - 1): + total_nnz += _calc_expanded_nnz( + indptr[row_idx + 1] - indptr[row_idx], + interaction_only, + degree + ) + return total_nnz + + +cpdef void _csr_polynomial_expansion( + const DATA_t[:] data, # IN READ-ONLY + const INDEX_A_t[:] indices, # IN READ-ONLY + const INDEX_A_t[:] indptr, # IN READ-ONLY + INDEX_A_t n_features, + DATA_t[:] result_data, # OUT + INDEX_B_t[:] result_indices, # OUT + INDEX_B_t[:] result_indptr, # OUT + FLAG_t interaction_only, + FLAG_t degree +) nogil: + """ + Perform a second or third degree polynomial or interaction expansion on a compressed sparse row (CSR) matrix. The method used only takes products of - non-zero features. For a matrix with density d, this results in a speedup - on the order of d^k where k is the degree of the expansion, assuming all - rows are of similar density. + non-zero features. For a matrix with density :math:`d`, this results in a + speedup on the order of :math:`(1/d)^k` where :math:`k` is the degree of + the expansion, assuming all rows are of similar density. Parameters ---------- @@ -80,9 +185,21 @@ def _csr_polynomial_expansion( indptr : memory view on nd-array The "indptr" attribute of the input CSR matrix. - d : int + n_features : int The dimensionality of the input CSR matrix. + result_data : nd-array + The output CSR matrix's "data" attribute. + It is modified by this routine. + + result_indices : nd-array + The output CSR matrix's "indices" attribute. + It is modified by this routine. + + result_indptr : nd-array + The output CSR matrix's "indptr" attribute. + It is modified by this routine. + interaction_only : int 0 for a polynomial expansion, 1 for an interaction expansion. @@ -95,47 +212,11 @@ def _csr_polynomial_expansion( Matrices Using K-Simplex Numbers" by Andrew Nystrom and John Hughes. 
""" - assert degree in (2, 3) - - if degree == 2: - expanded_dimensionality = int((d**2 + d) / 2 - interaction_only*d) - else: - expanded_dimensionality = int((d**3 + 3*d**2 + 2*d) / 6 - - interaction_only*d**2) - if expanded_dimensionality == 0: - return None - assert expanded_dimensionality > 0 - - cdef cnp.int32_t total_nnz = 0, row_i, nnz - - # Count how many nonzero elements the expanded matrix will contain. - for row_i in range(indptr.shape[0]-1): - # nnz is the number of nonzero elements in this row. - nnz = indptr[row_i + 1] - indptr[row_i] - if degree == 2: - total_nnz += (nnz ** 2 + nnz) / 2 - interaction_only * nnz - else: - total_nnz += ((nnz ** 3 + 3 * nnz ** 2 + 2 * nnz) / 6 - - interaction_only * nnz ** 2) - # Make the arrays that will form the CSR matrix of the expansion. - cdef: - DATA_T[:] expanded_data = np.empty( - shape=total_nnz, dtype=data.base.dtype - ) - cnp.int32_t[:] expanded_indices = np.empty( - shape=total_nnz, dtype=np.int32 - ) - cnp.int32_t num_rows = indptr.shape[0] - 1 - cnp.int32_t[:] expanded_indptr = np.empty( - shape=num_rows + 1, dtype=np.int32 - ) - - cnp.int32_t expanded_index = 0, row_starts, row_ends - cnp.int32_t i, j, k, i_ptr, j_ptr, k_ptr, num_cols_in_row - + cdef INDEX_A_t row_i, row_starts, row_ends, i, j, k, i_ptr, j_ptr, k_ptr + cdef INDEX_B_t expanded_index=0, num_cols_in_row, col with nogil: - expanded_indptr[0] = indptr[0] + result_indptr[0] = indptr[0] for row_i in range(indptr.shape[0]-1): row_starts = indptr[row_i] row_ends = indptr[row_i + 1] @@ -145,24 +226,32 @@ def _csr_polynomial_expansion( for j_ptr in range(i_ptr + interaction_only, row_ends): j = indices[j_ptr] if degree == 2: - col = _deg2_column(d, i, j, interaction_only) - expanded_indices[expanded_index] = col - expanded_data[expanded_index] = ( - data[i_ptr] * data[j_ptr]) + col = _deg2_column( + n_features, + i, j, + interaction_only + ) + result_indices[expanded_index] = col + result_data[expanded_index] = ( + data[i_ptr] * data[j_ptr] + ) expanded_index += 1 num_cols_in_row += 1 else: # degree == 3 for k_ptr in range(j_ptr + interaction_only, row_ends): k = indices[k_ptr] - col = _deg3_column(d, i, j, k, interaction_only) - expanded_indices[expanded_index] = col - expanded_data[expanded_index] = ( - data[i_ptr] * data[j_ptr] * data[k_ptr]) + col = _deg3_column( + n_features, + i, j, k, + interaction_only + ) + result_indices[expanded_index] = col + result_data[expanded_index] = ( + data[i_ptr] * data[j_ptr] * data[k_ptr] + ) expanded_index += 1 num_cols_in_row += 1 - expanded_indptr[row_i+1] = expanded_indptr[row_i] + num_cols_in_row - - return csr_matrix((expanded_data, expanded_indices, expanded_indptr), - shape=(num_rows, expanded_dimensionality)) + result_indptr[row_i+1] = result_indptr[row_i] + num_cols_in_row + return diff --git a/sklearn/preprocessing/_polynomial.py b/sklearn/preprocessing/_polynomial.py index 83ebbf786d8fc..64ecb9864fae0 100644 --- a/sklearn/preprocessing/_polynomial.py +++ b/sklearn/preprocessing/_polynomial.py @@ -17,8 +17,13 @@ from ..utils.validation import _check_feature_names_in from ..utils._param_validation import Interval, StrOptions from ..utils.stats import _weighted_percentile +from ..utils.fixes import sp_version, parse_version -from ._csr_polynomial_expansion import _csr_polynomial_expansion +from ._csr_polynomial_expansion import ( + _csr_polynomial_expansion, + _calc_expanded_nnz, + _calc_total_nnz, +) __all__ = [ @@ -27,6 +32,67 @@ ] +def _create_expansion(X, interaction_only, deg, n_features, cumulative_size=0): + """Helper 
+
+    total_nnz = _calc_total_nnz(X.indptr, interaction_only, deg)
+    expanded_col = _calc_expanded_nnz(n_features, interaction_only, deg)
+
+    if expanded_col == 0:
+        return None
+    # This only checks whether each block needs 64bit integers upon
+    # expansion. We prefer to keep int32 indexing where we can,
+    # since SciPy's CSR construction currently downcasts when possible,
+    # and we want to avoid an unnecessary cast. The dtype may still
+    # change in the concatenation process if needed.
+    # See: https://github.com/scipy/scipy/issues/16569
+    max_indices = expanded_col - 1
+    max_indptr = total_nnz
+    max_int32 = np.iinfo(np.int32).max
+    needs_int64 = max(max_indices, max_indptr) > max_int32
+    index_dtype = np.int64 if needs_int64 else np.int32
+
+    # This is a pretty specific bug that is hard for a user to work around,
+    # hence we do not detail the entire bug and all possible avoidance
+    # mechanisms. Instead we recommend upgrading scipy or shrinking the data.
+    cumulative_size += expanded_col
+    if (
+        sp_version < parse_version("1.8.0")
+        and cumulative_size - 1 > max_int32
+        and not needs_int64
+    ):
+        raise ValueError(
+            "In scipy versions `<1.8.0`, the function `scipy.sparse.hstack`"
+            " sometimes produces negative columns when the output shape contains"
+            " `n_cols` too large to be represented by a 32bit signed"
+            " integer. To avoid this error, either use a version"
+            " of scipy `>=1.8.0` or alter the `PolynomialFeatures`"
+            " transformer to produce fewer than 2^31 output features."
+        )
+
+    # Result of the expansion, modified in place by the
+    # `_csr_polynomial_expansion` routine.
+    expanded_data = np.empty(shape=total_nnz, dtype=X.data.dtype)
+    expanded_indices = np.empty(shape=total_nnz, dtype=index_dtype)
+    expanded_indptr = np.empty(shape=X.indptr.shape[0], dtype=index_dtype)
+    _csr_polynomial_expansion(
+        X.data,
+        X.indices,
+        X.indptr,
+        X.shape[1],
+        expanded_data,
+        expanded_indices,
+        expanded_indptr,
+        interaction_only,
+        deg,
+    )
+    return sparse.csr_matrix(
+        (expanded_data, expanded_indices, expanded_indptr),
+        shape=(X.indptr.shape[0] - 1, expanded_col),
+        dtype=X.dtype,
+    )
+
+
 class PolynomialFeatures(TransformerMixin, BaseEstimator):
     """Generate polynomial and interaction features.
 
@@ -297,6 +363,27 @@ def fit(self, X, y=None):
             interaction_only=self.interaction_only,
             include_bias=self.include_bias,
         )
+        if self.n_output_features_ > np.iinfo(np.intp).max:
+            msg = (
+                "The output that would result from the current configuration would"
+                f" have {self.n_output_features_} features which is too large to be"
+                f" indexed by {np.intp().dtype.name}. Please change some or all of the"
+                " following:\n- The number of features in the input, currently"
+                f" {n_features=}\n- The range of degrees to calculate, currently"
+                f" [{self._min_degree}, {self._max_degree}]\n- Whether to include only"
+                f" interaction terms, currently {self.interaction_only}\n- Whether to"
+                f" include a bias term, currently {self.include_bias}."
+            )
+            if (
+                np.intp == np.int32
+                and self.n_output_features_ <= np.iinfo(np.int64).max
+            ):  # pragma: nocover
+                msg += (
+                    "\nNote that the current Python runtime has a limited 32 bit "
+                    "address space and that this configuration would have been "
+                    "admissible if run on a 64 bit Python runtime."
+                )
+            raise ValueError(msg)
         # We also record the number of output features for
         # _max_degree = 0
         self._n_out_full = self._num_combinations(
@@ -345,29 +432,52 @@ def transform(self, X):
         )
         n_samples, n_features = X.shape
-
+        max_int32 = np.iinfo(np.int32).max
         if sparse.isspmatrix_csr(X):
             if self._max_degree > 3:
                 return self.transform(X.tocsc()).tocsr()
             to_stack = []
             if self.include_bias:
                 to_stack.append(
-                    sparse.csc_matrix(np.ones(shape=(n_samples, 1), dtype=X.dtype))
+                    sparse.csr_matrix(np.ones(shape=(n_samples, 1), dtype=X.dtype))
                 )
             if self._min_degree <= 1 and self._max_degree > 0:
                 to_stack.append(X)
+
+            cumulative_size = sum(mat.shape[1] for mat in to_stack)
             for deg in range(max(2, self._min_degree), self._max_degree + 1):
-                Xp_next = _csr_polynomial_expansion(
-                    X.data, X.indices, X.indptr, X.shape[1], self.interaction_only, deg
+                expanded = _create_expansion(
+                    X=X,
+                    interaction_only=self.interaction_only,
+                    deg=deg,
+                    n_features=n_features,
+                    cumulative_size=cumulative_size,
                 )
-                if Xp_next is None:
-                    break
-                to_stack.append(Xp_next)
+                if expanded is not None:
+                    to_stack.append(expanded)
+                    cumulative_size += expanded.shape[1]
             if len(to_stack) == 0:
                 # edge case: deal with empty matrix
                 XP = sparse.csr_matrix((n_samples, 0), dtype=X.dtype)
             else:
-                XP = sparse.hstack(to_stack, format="csr")
+                # `scipy.sparse.hstack` breaks in scipy<1.9.2
+                # when `n_output_features_ > max_int32`
+                all_int32 = all(mat.indices.dtype == np.int32 for mat in to_stack)
+                if (
+                    sp_version < parse_version("1.9.2")
+                    and self.n_output_features_ > max_int32
+                    and all_int32
+                ):
+                    raise ValueError(  # pragma: no cover
+                        "In scipy versions `<1.9.2`, the function `scipy.sparse.hstack`"
+                        " produces negative columns when:\n1. The output shape contains"
+                        " `n_cols` too large to be represented by a 32bit signed"
+                        " integer.\n2. All sub-matrices to be stacked have indices of"
+                        " dtype `np.int32`.\nTo avoid this error, either use a version"
+                        " of scipy `>=1.9.2` or alter the `PolynomialFeatures`"
+                        " transformer to produce fewer than 2^31 output features"
+                    )
+                XP = sparse.hstack(to_stack, dtype=X.dtype, format="csr")
         elif sparse.isspmatrix_csc(X) and self._max_degree < 4:
             return self.transform(X.tocsr()).tocsc()
         elif sparse.isspmatrix(X):
diff --git a/sklearn/preprocessing/tests/test_polynomial.py b/sklearn/preprocessing/tests/test_polynomial.py
index f21c37fb694fa..727b31b793b1d 100644
--- a/sklearn/preprocessing/tests/test_polynomial.py
+++ b/sklearn/preprocessing/tests/test_polynomial.py
@@ -1,8 +1,10 @@
 import numpy as np
 import pytest
+import sys
 from scipy import sparse
 from scipy.sparse import random as sparse_random
 from sklearn.utils._testing import assert_array_almost_equal
+from sklearn.utils.fixes import sp_version, parse_version
 from numpy.testing import assert_allclose, assert_array_equal
 
 from scipy.interpolate import BSpline
@@ -13,6 +15,11 @@
     PolynomialFeatures,
     SplineTransformer,
 )
+from sklearn.preprocessing._csr_polynomial_expansion import (
+    _calc_total_nnz,
+    _calc_expanded_nnz,
+    _get_sizeof_LARGEST_INT_t,
+)
 
 
 @pytest.mark.parametrize("est", (PolynomialFeatures, SplineTransformer))
@@ -789,6 +796,262 @@ def test_polynomial_features_csr_X_dim_edges(deg, dim, interaction_only):
     assert_array_almost_equal(Xt_csr.A, Xt_dense)
 
 
+@pytest.mark.parametrize("interaction_only", [True, False])
+@pytest.mark.parametrize("include_bias", [True, False])
+def test_csr_polynomial_expansion_index_overflow_non_regression(
+    interaction_only, include_bias
+):
+    """Check the automatic index dtype promotion to `np.int64` when needed.
+
+    This ensures that sufficiently large input configurations get
+    properly promoted to use `np.int64` for index and indptr representation
+    while preserving data integrity. Non-regression test for gh-16803.
+
+    Note that this is only possible for Python runtimes with a 64 bit address
+    space. On 32 bit platforms, a `ValueError` is raised instead.
+    """
+
+    def degree_2_calc(d, i, j):
+        if interaction_only:
+            return d * i - (i**2 + 3 * i) // 2 - 1 + j
+        else:
+            return d * i - (i**2 + i) // 2 + j
+
+    n_samples = 13
+    n_features = 120001
+    data_dtype = np.float32
+    data = np.arange(1, 5, dtype=np.int64)
+    row = np.array([n_samples - 2, n_samples - 2, n_samples - 1, n_samples - 1])
+    # An int64 dtype is required to avoid an overflow error on Windows within
+    # the `degree_2_calc` function.
+    col = np.array(
+        [n_features - 2, n_features - 1, n_features - 2, n_features - 1], dtype=np.int64
+    )
+    X = sparse.csr_matrix(
+        (data, (row, col)),
+        shape=(n_samples, n_features),
+        dtype=data_dtype,
+    )
+    pf = PolynomialFeatures(
+        interaction_only=interaction_only, include_bias=include_bias, degree=2
+    )
+
+    # Calculate the number of combinations a priori, and if needed check for
+    # the correct ValueError and terminate the test early.
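+    # For context (back-of-envelope numbers, not computed by the test): with
+    # n_features = 120001 and degree = 2 the output has roughly
+    # n_features**2 / 2 ~= 7.2e9 columns, which exceeds np.iinfo(np.int32).max
+    # (~2.1e9) but not np.iinfo(np.int64).max, so the expansion is indexable
+    # only on 64-bit Python runtimes.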
+    num_combinations = pf._num_combinations(
+        n_features=n_features,
+        min_degree=0,
+        max_degree=2,
+        interaction_only=pf.interaction_only,
+        include_bias=pf.include_bias,
+    )
+    if num_combinations > np.iinfo(np.intp).max:
+        msg = (
+            r"The output that would result from the current configuration would have"
+            r" \d* features which is too large to be indexed"
+        )
+        with pytest.raises(ValueError, match=msg):
+            pf.fit(X)
+        return
+    X_trans = pf.fit_transform(X)
+    row_nonzero, col_nonzero = X_trans.nonzero()
+    n_degree_1_features_out = n_features + include_bias
+    max_degree_2_idx = (
+        degree_2_calc(n_features, col[int(not interaction_only)], col[1])
+        + n_degree_1_features_out
+    )
+
+    # Account for the bias of all samples except the last two, which are
+    # handled separately since they contain distinct data values.
+    data_target = [1] * (n_samples - 2) if include_bias else []
+    col_nonzero_target = [0] * (n_samples - 2) if include_bias else []
+
+    for i in range(2):
+        x = data[2 * i]
+        y = data[2 * i + 1]
+        x_idx = col[2 * i]
+        y_idx = col[2 * i + 1]
+        if include_bias:
+            data_target.append(1)
+            col_nonzero_target.append(0)
+        data_target.extend([x, y])
+        col_nonzero_target.extend(
+            [x_idx + int(include_bias), y_idx + int(include_bias)]
+        )
+        if not interaction_only:
+            data_target.extend([x * x, x * y, y * y])
+            col_nonzero_target.extend(
+                [
+                    degree_2_calc(n_features, x_idx, x_idx) + n_degree_1_features_out,
+                    degree_2_calc(n_features, x_idx, y_idx) + n_degree_1_features_out,
+                    degree_2_calc(n_features, y_idx, y_idx) + n_degree_1_features_out,
+                ]
+            )
+        else:
+            data_target.extend([x * y])
+            col_nonzero_target.append(
+                degree_2_calc(n_features, x_idx, y_idx) + n_degree_1_features_out
+            )
+
+    nnz_per_row = int(include_bias) + 3 + 2 * int(not interaction_only)
+
+    assert pf.n_output_features_ == max_degree_2_idx + 1
+    assert X_trans.dtype == data_dtype
+    assert X_trans.shape == (n_samples, max_degree_2_idx + 1)
+    assert X_trans.indptr.dtype == X_trans.indices.dtype == np.int64
+    # Ensure that dtype promotion was actually required:
+    assert X_trans.indices.max() > np.iinfo(np.int32).max
+
+    row_nonzero_target = list(range(n_samples - 2)) if include_bias else []
+    row_nonzero_target.extend(
+        [n_samples - 2] * nnz_per_row + [n_samples - 1] * nnz_per_row
+    )
+
+    assert_allclose(X_trans.data, data_target)
+    assert_array_equal(row_nonzero, row_nonzero_target)
+    assert_array_equal(col_nonzero, col_nonzero_target)
+
+
+@pytest.mark.parametrize(
+    "degree, n_features",
+    [
+        # Needs promotion to int64 when interaction_only=False
+        (2, 65535),
+        (3, 2344),
+        # This guarantees that the intermediate operation when calculating
+        # output columns would overflow a C-long, hence checks that Python's
+        # arbitrary-precision integers are being used.
+        (2, int(np.sqrt(np.iinfo(np.int64).max) + 1)),
+        (3, 65535),
+        # This case tests the second clause of the overflow check which
+        # takes into account the value of `n_features` itself.
+        (2, int(np.sqrt(np.iinfo(np.int64).max))),
+    ],
+)
+@pytest.mark.parametrize("interaction_only", [True, False])
+@pytest.mark.parametrize("include_bias", [True, False])
+def test_csr_polynomial_expansion_index_overflow(
+    degree, n_features, interaction_only, include_bias
+):
+    """Tests known edge cases of the dtype promotion strategy and custom
+    Cython code, including a current bug in the upstream
+    `scipy.sparse.hstack`.
+ """ + data = [1.0] + row = [0] + col = [n_features - 1] + + # First degree index + expected_indices = [ + n_features - 1 + int(include_bias), + ] + # Second degree index + expected_indices.append(n_features * (n_features + 1) // 2 + expected_indices[0]) + # Third degree index + expected_indices.append( + n_features * (n_features + 1) * (n_features + 2) // 6 + expected_indices[1] + ) + + X = sparse.csr_matrix((data, (row, col))) + pf = PolynomialFeatures( + interaction_only=interaction_only, include_bias=include_bias, degree=degree + ) + + # Calculate the number of combinations a-priori, and if needed check for + # the correct ValueError and terminate the test early. + num_combinations = pf._num_combinations( + n_features=n_features, + min_degree=0, + max_degree=degree, + interaction_only=pf.interaction_only, + include_bias=pf.include_bias, + ) + if num_combinations > np.iinfo(np.intp).max: + msg = ( + r"The output that would result from the current configuration would have" + r" \d* features which is too large to be indexed" + ) + with pytest.raises(ValueError, match=msg): + pf.fit(X) + return + + # In SciPy < 1.8, a bug occurs when an intermediate matrix in + # `to_stack` in `hstack` fits within int32 however would require int64 when + # combined with all previous matrices in `to_stack`. + if sp_version < parse_version("1.8.0"): + has_bug = False + max_int32 = np.iinfo(np.int32).max + cumulative_size = n_features + include_bias + for deg in range(2, degree + 1): + max_indptr = _calc_total_nnz(X.indptr, interaction_only, deg) + max_indices = _calc_expanded_nnz(n_features, interaction_only, deg) - 1 + cumulative_size += max_indices + 1 + needs_int64 = max(max_indices, max_indptr) > max_int32 + has_bug |= not needs_int64 and cumulative_size > max_int32 + if has_bug: + msg = r"In scipy versions `<1.8.0`, the function `scipy.sparse.hstack`" + with pytest.raises(ValueError, match=msg): + X_trans = pf.fit_transform(X) + return + + # When `n_features>=65535`, `scipy.sparse.hstack` may not use the right + # dtype for representing indices and indptr if `n_features` is still + # small enough so that each block matrix's indices and indptr arrays + # can be represented with `np.int32`. We test `n_features==65535` + # since it is guaranteed to run into this bug. 
+    if (
+        sp_version < parse_version("1.9.2")
+        and n_features == 65535
+        and degree == 2
+        and not interaction_only
+    ):  # pragma: no cover
+        msg = r"In scipy versions `<1.9.2`, the function `scipy.sparse.hstack`"
+        with pytest.raises(ValueError, match=msg):
+            X_trans = pf.fit_transform(X)
+        return
+    X_trans = pf.fit_transform(X)
+
+    expected_dtype = np.int64 if num_combinations > np.iinfo(np.int32).max else np.int32
+    # Terms higher than first degree
+    non_bias_terms = 1 + (degree - 1) * int(not interaction_only)
+    expected_nnz = int(include_bias) + non_bias_terms
+    assert X_trans.dtype == X.dtype
+    assert X_trans.shape == (1, pf.n_output_features_)
+    assert X_trans.indptr.dtype == X_trans.indices.dtype == expected_dtype
+    assert X_trans.nnz == expected_nnz
+
+    if include_bias:
+        assert X_trans[0, 0] == pytest.approx(1.0)
+    for idx in range(non_bias_terms):
+        assert X_trans[0, expected_indices[idx]] == pytest.approx(1.0)
+
+    offset = interaction_only * n_features
+    if degree == 3:
+        offset *= 1 + n_features
+    assert pf.n_output_features_ == expected_indices[degree - 1] + 1 - offset
+
+
+@pytest.mark.parametrize("interaction_only", [True, False])
+@pytest.mark.parametrize("include_bias", [True, False])
+def test_csr_polynomial_expansion_too_large_to_index(interaction_only, include_bias):
+    n_features = np.iinfo(np.int64).max // 2
+    data = [1.0]
+    row = [0]
+    col = [n_features - 1]
+    X = sparse.csr_matrix((data, (row, col)))
+    pf = PolynomialFeatures(
+        interaction_only=interaction_only, include_bias=include_bias, degree=(2, 2)
+    )
+    msg = (
+        r"The output that would result from the current configuration would have \d*"
+        r" features which is too large to be indexed"
+    )
+    with pytest.raises(ValueError, match=msg):
+        pf.fit(X)
+    with pytest.raises(ValueError, match=msg):
+        pf.fit_transform(X)
+
+
 def test_polynomial_features_behaviour_on_zero_degree():
     """Check that PolynomialFeatures raises error when degree=0 and
     include_bias=False, and output a single constant column when include_bias=True
@@ -817,3 +1080,62 @@ def test_polynomial_features_behaviour_on_zero_degree():
     if sparse.issparse(output):
         output = output.toarray()
     assert_array_equal(output, np.ones((X.shape[0], 1)))
+
+
+def test_sizeof_LARGEST_INT_t():
+    # On Windows, scikit-learn is typically compiled with MSVC, which
+    # does not support int128 arithmetic (at the time of writing):
+    # https://stackoverflow.com/a/6761962/163740
+    if sys.platform == "win32" or (
+        sys.maxsize <= 2**32 and sys.platform != "emscripten"
+    ):
+        expected_size = 8
+    else:
+        expected_size = 16
+
+    assert _get_sizeof_LARGEST_INT_t() == expected_size
+
+
+@pytest.mark.xfail(
+    sys.platform == "win32",
+    reason=(
+        "On Windows, scikit-learn is typically compiled with MSVC, which does"
+        " not support int128 arithmetic (at the time of writing)"
+    ),
+    run=True,
+)
+def test_csr_polynomial_expansion_windows_fail():
+    # Minimum needed to ensure integer overflow occurs while guaranteeing an
+    # int64-indexable output.
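+    # Concretely (derived from the expression below): np.iinfo(np.int64).max
+    # is 2**63 - 1, whose cube root is roughly 2**21 = 2_097_152, so
+    # n_features is about 2_097_155. The degree-3 intermediate terms then
+    # exceed 2**63 - 1, while the largest final column index, roughly
+    # n_features**3 / 6 ~= 1.5e18, still fits comfortably in an int64.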
+    n_features = int(np.iinfo(np.int64).max ** (1 / 3) + 3)
+    data = [1.0]
+    row = [0]
+    col = [n_features - 1]
+
+    # First degree index
+    expected_indices = [
+        n_features - 1,
+    ]
+    # Second degree index
+    expected_indices.append(
+        int(n_features * (n_features + 1) // 2 + expected_indices[0])
+    )
+    # Third degree index
+    expected_indices.append(
+        int(n_features * (n_features + 1) * (n_features + 2) // 6 + expected_indices[1])
+    )
+
+    X = sparse.csr_matrix((data, (row, col)))
+    pf = PolynomialFeatures(interaction_only=False, include_bias=False, degree=3)
+    if sys.maxsize <= 2**32:
+        msg = (
+            r"The output that would result from the current configuration would"
+            r" have \d*"
+            r" features which is too large to be indexed"
+        )
+        with pytest.raises(ValueError, match=msg):
+            pf.fit_transform(X)
+    else:
+        X_trans = pf.fit_transform(X)
+        for idx in range(3):
+            assert X_trans[0, expected_indices[idx]] == pytest.approx(1.0)