Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e010e4f

Browse files
MAINT make AdditiveChi2Sampler stateless and check that stateless Transformers don't raise NotFittedError (#25190)
* MAINT make AdditiveChi2Sampler stateless * apply feedbacks * fix tests * apply new suggestions * typo * iterate on tests * typo in docstring * fix docstring * update changelog * improve test coverage * add test_common * improve coverage by removing from test_common * move stateless check into _yield_transformer_checks * apply suggestions * Apply suggestions from code review Co-authored-by: Julien Jerphanion <[email protected]> * apply suggestions * Update doc/whats_new/v1.3.rst Co-authored-by: Julien Jerphanion <[email protected]> Co-authored-by: Julien Jerphanion <[email protected]>
1 parent 134c492 commit e010e4f

File tree

4 files changed

+137
-50
lines changed

4 files changed

+137
-50
lines changed

doc/whats_new/v1.3.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,18 @@ Changelog
174174
:pr:`24935` by :user:`Seladus <seladus>`, :user:`Guillaume Lemaitre <glemaitre>`, and
175175
:user:`Dea María Léon <deamarialeon>`, :pr:`25257` by :user:`Gleb Levitski <glevv>`.
176176

177+
- |Fix| :class:`AdditiveChi2Sampler` is now stateless.
178+
The `sample_interval_` attribute is deprecated and will be removed in 1.5.
179+
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.
180+
181+
:mod:`sklearn.utils`
182+
....................
183+
184+
- |API| :func:`estimator_checks.check_transformers_unfitted_stateless` has been
185+
introduced to ensure stateless transformers don't raise `NotFittedError`
186+
during `transform` with no prior call to `fit` or `fit_transform`.
187+
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.
188+
177189
Code and Documentation Contributors
178190
-----------------------------------
179191

sklearn/kernel_approximation.py

Lines changed: 70 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .base import TransformerMixin
2525
from .base import ClassNamePrefixFeaturesOutMixin
2626
from .utils import check_random_state
27+
from .utils import deprecated
2728
from .utils.extmath import safe_sparse_dot
2829
from .utils.validation import check_is_fitted
2930
from .utils.validation import _check_feature_names_in
@@ -600,6 +601,9 @@ class AdditiveChi2Sampler(TransformerMixin, BaseEstimator):
600601
Stored sampling interval. Specified as a parameter if `sample_steps`
601602
not in {1,2,3}.
602603
604+
.. deprecated:: 1.3
605+
`sample_interval_` serves internal purposes only and will be removed in 1.5.
606+
603607
n_features_in_ : int
604608
Number of features seen during :term:`fit`.
605609
@@ -626,6 +630,10 @@ class AdditiveChi2Sampler(TransformerMixin, BaseEstimator):
626630
This estimator approximates a slightly different version of the additive
627631
chi squared kernel than ``metric.additive_chi2`` computes.
628632
633+
This estimator is stateless and does not need to be fitted. However, we
634+
recommend calling :meth:`fit_transform` instead of :meth:`transform`, as
635+
parameter validation is only performed in :meth:`fit`.
636+
629637
References
630638
----------
631639
See `"Efficient additive kernels via explicit feature maps"
@@ -658,7 +666,10 @@ def __init__(self, *, sample_steps=2, sample_interval=None):
658666
self.sample_interval = sample_interval
659667

660668
def fit(self, X, y=None):
661-
"""Set the parameters.
669+
"""Only validates estimator's parameters.
670+
671+
This method exists to: (i) validate the estimator's parameters and
672+
(ii) be consistent with the scikit-learn transformer API.
662673
663674
Parameters
664675
----------
@@ -676,27 +687,40 @@ def fit(self, X, y=None):
676687
Returns the transformer.
677688
"""
678689
self._validate_params()
679-
680690
X = self._validate_data(X, accept_sparse="csr")
681691
check_non_negative(X, "X in AdditiveChi2Sampler.fit")
682692

693+
# TODO(1.5): remove the setting of _sample_interval from fit
683694
if self.sample_interval is None:
684-
# See reference, figure 2 c)
695+
# See figure 2 c) of "Efficient additive kernels via explicit feature maps"
696+
# <http://www.robots.ox.ac.uk/~vedaldi/assets/pubs/vedaldi11efficient.pdf>
697+
# A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence,
698+
# 2011
685699
if self.sample_steps == 1:
686-
self.sample_interval_ = 0.8
700+
self._sample_interval = 0.8
687701
elif self.sample_steps == 2:
688-
self.sample_interval_ = 0.5
702+
self._sample_interval = 0.5
689703
elif self.sample_steps == 3:
690-
self.sample_interval_ = 0.4
704+
self._sample_interval = 0.4
691705
else:
692706
raise ValueError(
693707
"If sample_steps is not in [1, 2, 3],"
694708
" you need to provide sample_interval"
695709
)
696710
else:
697-
self.sample_interval_ = self.sample_interval
711+
self._sample_interval = self.sample_interval
712+
698713
return self
699714

715+
# TODO(1.5): remove
716+
# Deprecated read-only accessor kept for backward compatibility: it exposes
# the private ``_sample_interval`` value, which is assigned by ``fit``.
# The ``deprecated`` wrapper attaches the message below to attribute access.
@deprecated( # type: ignore
    "The ``sample_interval_`` attribute was deprecated in version 1.3 and "
    "will be removed 1.5."
)
@property
def sample_interval_(self):
    # Plain pass-through; no recomputation happens here.
    return self._sample_interval
723+
700724
def transform(self, X):
701725
"""Apply approximate feature map to X.
702726
@@ -713,22 +737,39 @@ def transform(self, X):
713737
Whether the return value is an array or sparse matrix depends on
714738
the type of the input X.
715739
"""
716-
msg = (
717-
"%(name)s is not fitted. Call fit to set the parameters before"
718-
" calling transform"
719-
)
720-
check_is_fitted(self, msg=msg)
721-
722740
X = self._validate_data(X, accept_sparse="csr", reset=False)
723741
check_non_negative(X, "X in AdditiveChi2Sampler.transform")
724742
sparse = sp.issparse(X)
725743

744+
if hasattr(self, "_sample_interval"):
745+
# TODO(1.5): remove this branch
746+
sample_interval = self._sample_interval
747+
748+
else:
749+
if self.sample_interval is None:
750+
# See figure 2 c) of "Efficient additive kernels via explicit feature maps" # noqa
751+
# <http://www.robots.ox.ac.uk/~vedaldi/assets/pubs/vedaldi11efficient.pdf>
752+
# A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence, # noqa
753+
# 2011
754+
if self.sample_steps == 1:
755+
sample_interval = 0.8
756+
elif self.sample_steps == 2:
757+
sample_interval = 0.5
758+
elif self.sample_steps == 3:
759+
sample_interval = 0.4
760+
else:
761+
raise ValueError(
762+
"If sample_steps is not in [1, 2, 3],"
763+
" you need to provide sample_interval"
764+
)
765+
else:
766+
sample_interval = self.sample_interval
767+
726768
# zeroth component
727769
# 1/cosh = sech
728770
# cosh(0) = 1.0
729-
730771
transf = self._transform_sparse if sparse else self._transform_dense
731-
return transf(X)
772+
return transf(X, self.sample_steps, sample_interval)
732773

733774
def get_feature_names_out(self, input_features=None):
734775
"""Get output feature names for transformation.
@@ -758,20 +799,21 @@ def get_feature_names_out(self, input_features=None):
758799

759800
return np.asarray(names_list, dtype=object)
760801

761-
def _transform_dense(self, X):
802+
@staticmethod
803+
def _transform_dense(X, sample_steps, sample_interval):
762804
non_zero = X != 0.0
763805
X_nz = X[non_zero]
764806

765807
X_step = np.zeros_like(X)
766-
X_step[non_zero] = np.sqrt(X_nz * self.sample_interval_)
808+
X_step[non_zero] = np.sqrt(X_nz * sample_interval)
767809

768810
X_new = [X_step]
769811

770-
log_step_nz = self.sample_interval_ * np.log(X_nz)
771-
step_nz = 2 * X_nz * self.sample_interval_
812+
log_step_nz = sample_interval * np.log(X_nz)
813+
step_nz = 2 * X_nz * sample_interval
772814

773-
for j in range(1, self.sample_steps):
774-
factor_nz = np.sqrt(step_nz / np.cosh(np.pi * j * self.sample_interval_))
815+
for j in range(1, sample_steps):
816+
factor_nz = np.sqrt(step_nz / np.cosh(np.pi * j * sample_interval))
775817

776818
X_step = np.zeros_like(X)
777819
X_step[non_zero] = factor_nz * np.cos(j * log_step_nz)
@@ -783,21 +825,22 @@ def _transform_dense(self, X):
783825

784826
return np.hstack(X_new)
785827

786-
def _transform_sparse(self, X):
828+
@staticmethod
829+
def _transform_sparse(X, sample_steps, sample_interval):
787830
indices = X.indices.copy()
788831
indptr = X.indptr.copy()
789832

790-
data_step = np.sqrt(X.data * self.sample_interval_)
833+
data_step = np.sqrt(X.data * sample_interval)
791834
X_step = sp.csr_matrix(
792835
(data_step, indices, indptr), shape=X.shape, dtype=X.dtype, copy=False
793836
)
794837
X_new = [X_step]
795838

796-
log_step_nz = self.sample_interval_ * np.log(X.data)
797-
step_nz = 2 * X.data * self.sample_interval_
839+
log_step_nz = sample_interval * np.log(X.data)
840+
step_nz = 2 * X.data * sample_interval
798841

799-
for j in range(1, self.sample_steps):
800-
factor_nz = np.sqrt(step_nz / np.cosh(np.pi * j * self.sample_interval_))
842+
for j in range(1, sample_steps):
843+
factor_nz = np.sqrt(step_nz / np.cosh(np.pi * j * sample_interval))
801844

802845
data_step = factor_nz * np.cos(j * log_step_nz)
803846
X_step = sp.csr_matrix(

sklearn/tests/test_kernel_approximation.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -114,34 +114,49 @@ def test_additive_chi2_sampler():
114114
Y_neg[0, 0] = -1
115115
msg = "Negative values in data passed to"
116116
with pytest.raises(ValueError, match=msg):
117-
transform.transform(Y_neg)
117+
transform.fit(Y_neg)
118+
119+
120+
@pytest.mark.parametrize("method", ["fit", "fit_transform", "transform"])
@pytest.mark.parametrize("sample_steps", range(1, 4))
def test_additive_chi2_sampler_sample_steps(method, sample_steps):
    """Check that the input sample step doesn't raise an error
    and that sample interval doesn't change after fit.

    Each of `fit`, `fit_transform` and `transform` is exercised, first with
    the default `sample_interval=None` and then with an explicit value.
    """
    # Default interval: the call itself must simply succeed for steps 1..3.
    transformer = AdditiveChi2Sampler(sample_steps=sample_steps)
    getattr(transformer, method)(X)

    sample_interval = 0.5
    transformer = AdditiveChi2Sampler(
        sample_steps=sample_steps,
        sample_interval=sample_interval,
    )
    getattr(transformer, method)(X)
    # `sample_interval` is a constructor parameter: none of the methods may
    # overwrite it. This must be an `assert`; a bare comparison is evaluated
    # and discarded, so the test could never fail on it.
    assert transformer.sample_interval == sample_interval
136+
137+
138+
# TODO(1.5): remove
139+
def test_additive_chi2_sampler_future_warnings():
    """Check that reading `sample_interval_` emits a FutureWarning."""
    fitted = AdditiveChi2Sampler()
    fitted.fit(X)
    # The message is escaped because it contains regex metacharacters (".").
    expected_warning = re.escape(
        "The ``sample_interval_`` attribute was deprecated in version 1.3 and "
        "will be removed 1.5."
    )
    with pytest.warns(FutureWarning, match=expected_warning):
        assert fitted.sample_interval_ is not None
149+
118150

119-
# test error on invalid sample_steps
120-
transform = AdditiveChi2Sampler(sample_steps=4)
151+
@pytest.mark.parametrize("method", ["fit", "fit_transform", "transform"])
152+
def test_additive_chi2_sampler_wrong_sample_steps(method):
153+
"""Check that we raise a ValueError on invalid sample_steps"""
154+
transformer = AdditiveChi2Sampler(sample_steps=4)
121155
msg = re.escape(
122156
"If sample_steps is not in [1, 2, 3], you need to provide sample_interval"
123157
)
124158
with pytest.raises(ValueError, match=msg):
125-
transform.fit(X)
126-
127-
# test that the sample interval is set correctly
128-
sample_steps_available = [1, 2, 3]
129-
for sample_steps in sample_steps_available:
130-
131-
# test that the sample_interval is initialized correctly
132-
transform = AdditiveChi2Sampler(sample_steps=sample_steps)
133-
assert transform.sample_interval is None
134-
135-
# test that the sample_interval is changed in the fit method
136-
transform.fit(X)
137-
assert transform.sample_interval_ is not None
138-
139-
# test that the sample_interval is set correctly
140-
sample_interval = 0.3
141-
transform = AdditiveChi2Sampler(sample_steps=4, sample_interval=sample_interval)
142-
assert transform.sample_interval == sample_interval
143-
transform.fit(X)
144-
assert transform.sample_interval_ == sample_interval
159+
getattr(transformer, method)(X)
145160

146161

147162
def test_skewed_chi2_sampler():

sklearn/utils/estimator_checks.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ def _yield_transformer_checks(transformer):
243243
yield partial(check_transformer_general, readonly_memmap=True)
244244
if not _safe_tags(transformer, key="stateless"):
245245
yield check_transformers_unfitted
246+
else:
247+
yield check_transformers_unfitted_stateless
246248
# Dependent on external solvers and hence accessing the iter
247249
# param is non-trivial.
248250
external_solver = [
@@ -1554,6 +1556,21 @@ def check_transformers_unfitted(name, transformer):
15541556
transformer.transform(X)
15551557

15561558

1559+
@ignore_warnings(category=FutureWarning)
def check_transformers_unfitted_stateless(name, transformer):
    """Check that a stateless transformer can transform without being fitted.

    `transform` is called on a fresh clone with no prior call to `fit` or
    `fit_transform`; it must not raise and must keep the number of samples.
    """
    # Fixed seed so the check is reproducible; the data is then adapted to
    # the estimator's tags before use.
    data = _enforce_estimator_tags_X(
        transformer, np.random.RandomState(0).uniform(size=(20, 5))
    )

    # Work on a clone so the caller's instance is left untouched.
    unfitted = clone(transformer)
    transformed = unfitted.transform(data)

    assert transformed.shape[0] == data.shape[0]
1572+
1573+
15571574
def _check_transformer(name, transformer_orig, X, y):
15581575
n_samples, n_features = np.asarray(X).shape
15591576
transformer = clone(transformer_orig)

0 commit comments

Comments
 (0)