@@ -1666,23 +1666,21 @@ def fit(self, X, y=None):
16661666 )
16671667 if not sp .issparse (X ):
16681668 X = sp .csr_matrix (X )
1669- dtype = X .dtype if X .dtype in FLOAT_DTYPES else np .float64
1669+ dtype = X .dtype if X .dtype in ( np . float64 , np . float32 ) else np .float64
16701670
16711671 if self .use_idf :
1672- n_samples , n_features = X .shape
1672+ n_samples , _ = X .shape
16731673 df = _document_frequency (X )
16741674 df = df .astype (dtype , copy = False )
16751675
16761676 # perform idf smoothing if required
1677- df += int (self .smooth_idf )
1677+ df += float (self .smooth_idf )
16781678 n_samples += int (self .smooth_idf )
16791679
16801680 # log+1 instead of log makes sure terms with zero idf don't get
16811681 # suppressed entirely.
1682+ # `np.log` preserves the dtype of `df` and thus `dtype`.
16821683 self .idf_ = np .log (n_samples / df ) + 1.0
1683- # FIXME: for backward compatibility, we force idf_ to be np.float64
1684- # In the future, we should preserve the `dtype` of `X`.
1685- self .idf_ = self .idf_ .astype (np .float64 , copy = False )
16861684
16871685 return self
16881686
@@ -1705,14 +1703,18 @@ def transform(self, X, copy=True):
17051703 """
17061704 check_is_fitted (self )
17071705 X = self ._validate_data (
1708- X , accept_sparse = "csr" , dtype = FLOAT_DTYPES , copy = copy , reset = False
1706+ X ,
1707+ accept_sparse = "csr" ,
1708+ dtype = [np .float64 , np .float32 ],
1709+ copy = copy ,
1710+ reset = False ,
17091711 )
17101712 if not sp .issparse (X ):
1711- X = sp .csr_matrix (X , dtype = np . float64 )
1713+ X = sp .csr_matrix (X , dtype = X . dtype )
17121714
17131715 if self .sublinear_tf :
17141716 np .log (X .data , X .data )
1715- X .data += 1
1717+ X .data += 1.0
17161718
17171719 if hasattr (self , "idf_" ):
17181720 # the columns of X (CSR matrix) can be accessed with `X.indices `and
@@ -1725,7 +1727,12 @@ def transform(self, X, copy=True):
17251727 return X
17261728
17271729 def _more_tags (self ):
1728- return {"X_types" : ["2darray" , "sparse" ]}
1730+ return {
1731+ "X_types" : ["2darray" , "sparse" ],
1732+ # FIXME: np.float16 could be preserved if _inplace_csr_row_normalize_l2
1733+ # accepted it.
1734+ "preserves_dtype" : [np .float64 , np .float32 ],
1735+ }
17291736
17301737
17311738class TfidfVectorizer (CountVectorizer ):
0 commit comments