diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index b4d519230ca2d..6313bda07a160 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -27,6 +27,9 @@ random sampling procedures. - :class:`linear_model.LogisticRegression` and :class:`linear_model.LogisticRegressionCV` with 'saga' solver. |Fix| - :class:`ensemble.GradientBoostingClassifier` |Fix| +- :class:`sklearn.feature_extraction.text.HashingVectorizer`, + :class:`sklearn.feature_extraction.text.TfidfVectorizer`, and + :class:`sklearn.feature_extraction.text.CountVectorizer` |API| - :class:`neural_network.MLPClassifier` |Fix| - :func:`svm.SVC.decision_function` and :func:`multiclass.OneVsOneClassifier.decision_function`. |Fix| @@ -265,6 +268,17 @@ Support for Python 3.4 and below has been officially dropped. - |API| Deprecated :mod:`externals.six` since we have dropped support for Python 2.7. :issue:`12916` by :user:`Hanmin Qin `. +:mod:`sklearn.feature_extraction` +................................. + +- |API| If ``input='file'`` or ``input='filename'``, and a callable is given + as the ``analyzer``, :class:`sklearn.feature_extraction.text.HashingVectorizer`, + :class:`sklearn.feature_extraction.text.TfidfVectorizer`, and + :class:`sklearn.feature_extraction.text.CountVectorizer` now read the data + from the file(s) and then pass it to the given ``analyzer``, instead of + passing the file name(s) or the file object(s) to the analyzer. + :issue:`13641` by `Adrin Jalali`_. + :mod:`sklearn.impute` ..................... diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py index b02aa2aea46af..bfd9f5f2f4ffe 100644 --- a/sklearn/feature_extraction/tests/test_text.py +++ b/sklearn/feature_extraction/tests/test_text.py @@ -29,6 +29,7 @@ from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_equal from sklearn.utils import IS_PYPY +from sklearn.exceptions import ChangedBehaviorWarning from sklearn.utils.testing import (assert_equal, assert_not_equal, assert_almost_equal, assert_in, assert_less, assert_greater, @@ -1196,3 +1197,47 @@ def build_preprocessor(self): .findall(doc), stop_words=['and']) assert _check_stop_words_consistency(vec) is True + + +@pytest.mark.parametrize('Estimator', + [CountVectorizer, TfidfVectorizer, HashingVectorizer]) +@pytest.mark.parametrize( + 'input_type, err_type, err_msg', + [('filename', FileNotFoundError, ''), + ('file', AttributeError, "'str' object has no attribute 'read'")] +) +def test_callable_analyzer_error(Estimator, input_type, err_type, err_msg): + data = ['this is text, not file or filename'] + with pytest.raises(err_type, match=err_msg): + Estimator(analyzer=lambda x: x.split(), + input=input_type).fit_transform(data) + + +@pytest.mark.parametrize('Estimator', + [CountVectorizer, TfidfVectorizer, HashingVectorizer]) +@pytest.mark.parametrize( + 'analyzer', [lambda doc: open(doc, 'r'), lambda doc: doc.read()] +) +@pytest.mark.parametrize('input_type', ['file', 'filename']) +def test_callable_analyzer_change_behavior(Estimator, analyzer, input_type): + data = ['this is text, not file or filename'] + warn_msg = 'Since v0.21, vectorizer' + with pytest.raises((FileNotFoundError, AttributeError)): + with pytest.warns(ChangedBehaviorWarning, match=warn_msg) as records: + Estimator(analyzer=analyzer, input=input_type).fit_transform(data) + assert len(records) == 1 + assert warn_msg in str(records[0]) + + +@pytest.mark.parametrize('Estimator', + [CountVectorizer, TfidfVectorizer, HashingVectorizer]) +def test_callable_analyzer_reraise_error(tmpdir, Estimator): + # check if a custom exception from the analyzer is shown to the user + def analyzer(doc): + raise Exception("testing") + + f = tmpdir.join("file.txt") + f.write("sample content\n") + + with pytest.raises(Exception, match="testing"): + Estimator(analyzer=analyzer, input='file').fit_transform([f]) diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py index 9cdbace6224aa..007e158f3a449 100644 --- a/sklearn/feature_extraction/text.py +++ b/sklearn/feature_extraction/text.py @@ -31,6 +31,7 @@ from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES from ..utils import _IS_32BIT from ..utils.fixes import _astype_copy_false +from ..exceptions import ChangedBehaviorWarning __all__ = ['HashingVectorizer', @@ -304,10 +305,34 @@ def _check_stop_words_consistency(self, stop_words, preprocess, tokenize): self._stop_words_id = id(self.stop_words) return 'error' + def _validate_custom_analyzer(self): + # This is to check if the given custom analyzer expects file or a + # filename instead of data. + # Behavior changed in v0.21, function could be removed in v0.23 + import tempfile + with tempfile.NamedTemporaryFile() as f: + fname = f.name + # now we're sure fname doesn't exist + + msg = ("Since v0.21, vectorizers pass the data to the custom analyzer " + "and not the file names or the file objects. This warning " + "will be removed in v0.23.") + try: + self.analyzer(fname) + except FileNotFoundError: + warnings.warn(msg, ChangedBehaviorWarning) + except AttributeError as e: + if str(e) == "'str' object has no attribute 'read'": + warnings.warn(msg, ChangedBehaviorWarning) + except Exception: + pass + def build_analyzer(self): """Return a callable that handles preprocessing and tokenization""" if callable(self.analyzer): - return self.analyzer + if self.input in ['file', 'filename']: + self._validate_custom_analyzer() + return lambda doc: self.analyzer(self.decode(doc)) preprocess = self.build_preprocessor() @@ -490,6 +515,11 @@ class HashingVectorizer(BaseEstimator, VectorizerMixin, TransformerMixin): If a callable is passed it is used to extract the sequence of features out of the raw, unprocessed input. + .. versionchanged:: 0.21 + Since v0.21, if ``input`` is ``filename`` or ``file``, the data is + first read from the file and then passed to the given callable + analyzer. + n_features : integer, default=(2 ** 20) The number of features (columns) in the output matrices. Small numbers of features are likely to cause hash collisions, but large numbers @@ -745,6 +775,11 @@ class CountVectorizer(BaseEstimator, VectorizerMixin): If a callable is passed it is used to extract the sequence of features out of the raw, unprocessed input. + .. versionchanged:: 0.21 + Since v0.21, if ``input`` is ``filename`` or ``file``, the data is + first read from the file and then passed to the given callable + analyzer. + max_df : float in range [0.0, 1.0] or int, default=1.0 When building the vocabulary ignore terms that have a document frequency strictly higher than the given threshold (corpus-specific @@ -1369,6 +1404,11 @@ class TfidfVectorizer(CountVectorizer): If a callable is passed it is used to extract the sequence of features out of the raw, unprocessed input. + .. versionchanged:: 0.21 + Since v0.21, if ``input`` is ``filename`` or ``file``, the data is + first read from the file and then passed to the given callable + analyzer. + stop_words : string {'english'}, list, or None (default=None) If a string, it is passed to _check_stop_list and the appropriate stop list is returned. 'english' is currently the only supported string