24
24
from .base import TransformerMixin
25
25
from .base import ClassNamePrefixFeaturesOutMixin
26
26
from .utils import check_random_state
27
+ from .utils import deprecated
27
28
from .utils .extmath import safe_sparse_dot
28
29
from .utils .validation import check_is_fitted
29
30
from .utils .validation import _check_feature_names_in
@@ -600,6 +601,9 @@ class AdditiveChi2Sampler(TransformerMixin, BaseEstimator):
600
601
Stored sampling interval. Specified as a parameter if `sample_steps`
601
602
not in {1,2,3}.
602
603
604
+ .. deprecated:: 1.3
605
+ `sample_interval_` serves internal purposes only and will be removed in 1.5.
606
+
603
607
n_features_in_ : int
604
608
Number of features seen during :term:`fit`.
605
609
@@ -626,6 +630,10 @@ class AdditiveChi2Sampler(TransformerMixin, BaseEstimator):
626
630
This estimator approximates a slightly different version of the additive
627
631
chi squared kernel than ``metric.additive_chi2`` computes.
628
632
633
+ This estimator is stateless and does not need to be fitted. However, we
634
+ recommend calling :meth:`fit_transform` instead of :meth:`transform`, as
635
+ parameter validation is only performed in :meth:`fit`.
636
+
629
637
References
630
638
----------
631
639
See `"Efficient additive kernels via explicit feature maps"
@@ -658,7 +666,10 @@ def __init__(self, *, sample_steps=2, sample_interval=None):
658
666
self .sample_interval = sample_interval
659
667
660
668
def fit (self , X , y = None ):
661
- """Set the parameters.
669
+ """Only validates estimator's parameters.
670
+
671
+ This method serves to: (i) validate the estimator's parameters and
672
+ (ii) be consistent with the scikit-learn transformer API.
662
673
663
674
Parameters
664
675
----------
@@ -676,27 +687,40 @@ def fit(self, X, y=None):
676
687
Returns the transformer.
677
688
"""
678
689
self ._validate_params ()
679
-
680
690
X = self ._validate_data (X , accept_sparse = "csr" )
681
691
check_non_negative (X , "X in AdditiveChi2Sampler.fit" )
682
692
693
+ # TODO(1.5): remove the setting of _sample_interval from fit
683
694
if self .sample_interval is None :
684
- # See reference, figure 2 c)
695
+ # See figure 2 c) of "Efficient additive kernels via explicit feature maps"
696
+ # <http://www.robots.ox.ac.uk/~vedaldi/assets/pubs/vedaldi11efficient.pdf>
697
+ # A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence,
698
+ # 2011
685
699
if self .sample_steps == 1 :
686
- self .sample_interval_ = 0.8
700
+ self ._sample_interval = 0.8
687
701
elif self .sample_steps == 2 :
688
- self .sample_interval_ = 0.5
702
+ self ._sample_interval = 0.5
689
703
elif self .sample_steps == 3 :
690
- self .sample_interval_ = 0.4
704
+ self ._sample_interval = 0.4
691
705
else :
692
706
raise ValueError (
693
707
"If sample_steps is not in [1, 2, 3],"
694
708
" you need to provide sample_interval"
695
709
)
696
710
else :
697
- self .sample_interval_ = self .sample_interval
711
+ self ._sample_interval = self .sample_interval
712
+
698
713
return self
699
714
715
+ # TODO(1.5): remove
716
+ @deprecated ( # type: ignore
717
+ "The ``sample_interval_`` attribute was deprecated in version 1.3 and "
718
+ "will be removed 1.5."
719
+ )
720
+ @property
721
+ def sample_interval_ (self ):
722
+ return self ._sample_interval
723
+
700
724
def transform (self , X ):
701
725
"""Apply approximate feature map to X.
702
726
@@ -713,22 +737,39 @@ def transform(self, X):
713
737
Whether the return value is an array or sparse matrix depends on
714
738
the type of the input X.
715
739
"""
716
- msg = (
717
- "%(name)s is not fitted. Call fit to set the parameters before"
718
- " calling transform"
719
- )
720
- check_is_fitted (self , msg = msg )
721
-
722
740
X = self ._validate_data (X , accept_sparse = "csr" , reset = False )
723
741
check_non_negative (X , "X in AdditiveChi2Sampler.transform" )
724
742
sparse = sp .issparse (X )
725
743
744
+ if hasattr (self , "_sample_interval" ):
745
+ # TODO(1.5): remove this branch
746
+ sample_interval = self ._sample_interval
747
+
748
+ else :
749
+ if self .sample_interval is None :
750
+ # See figure 2 c) of "Efficient additive kernels via explicit feature maps" # noqa
751
+ # <http://www.robots.ox.ac.uk/~vedaldi/assets/pubs/vedaldi11efficient.pdf>
752
+ # A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence, # noqa
753
+ # 2011
754
+ if self .sample_steps == 1 :
755
+ sample_interval = 0.8
756
+ elif self .sample_steps == 2 :
757
+ sample_interval = 0.5
758
+ elif self .sample_steps == 3 :
759
+ sample_interval = 0.4
760
+ else :
761
+ raise ValueError (
762
+ "If sample_steps is not in [1, 2, 3],"
763
+ " you need to provide sample_interval"
764
+ )
765
+ else :
766
+ sample_interval = self .sample_interval
767
+
726
768
# zeroth component
727
769
# 1/cosh = sech
728
770
# cosh(0) = 1.0
729
-
730
771
transf = self ._transform_sparse if sparse else self ._transform_dense
731
- return transf (X )
772
+ return transf (X , self . sample_steps , sample_interval )
732
773
733
774
def get_feature_names_out (self , input_features = None ):
734
775
"""Get output feature names for transformation.
@@ -758,20 +799,21 @@ def get_feature_names_out(self, input_features=None):
758
799
759
800
return np .asarray (names_list , dtype = object )
760
801
761
- def _transform_dense (self , X ):
802
+ @staticmethod
803
+ def _transform_dense (X , sample_steps , sample_interval ):
762
804
non_zero = X != 0.0
763
805
X_nz = X [non_zero ]
764
806
765
807
X_step = np .zeros_like (X )
766
- X_step [non_zero ] = np .sqrt (X_nz * self . sample_interval_ )
808
+ X_step [non_zero ] = np .sqrt (X_nz * sample_interval )
767
809
768
810
X_new = [X_step ]
769
811
770
- log_step_nz = self . sample_interval_ * np .log (X_nz )
771
- step_nz = 2 * X_nz * self . sample_interval_
812
+ log_step_nz = sample_interval * np .log (X_nz )
813
+ step_nz = 2 * X_nz * sample_interval
772
814
773
- for j in range (1 , self . sample_steps ):
774
- factor_nz = np .sqrt (step_nz / np .cosh (np .pi * j * self . sample_interval_ ))
815
+ for j in range (1 , sample_steps ):
816
+ factor_nz = np .sqrt (step_nz / np .cosh (np .pi * j * sample_interval ))
775
817
776
818
X_step = np .zeros_like (X )
777
819
X_step [non_zero ] = factor_nz * np .cos (j * log_step_nz )
@@ -783,21 +825,22 @@ def _transform_dense(self, X):
783
825
784
826
return np .hstack (X_new )
785
827
786
- def _transform_sparse (self , X ):
828
+ @staticmethod
829
+ def _transform_sparse (X , sample_steps , sample_interval ):
787
830
indices = X .indices .copy ()
788
831
indptr = X .indptr .copy ()
789
832
790
- data_step = np .sqrt (X .data * self . sample_interval_ )
833
+ data_step = np .sqrt (X .data * sample_interval )
791
834
X_step = sp .csr_matrix (
792
835
(data_step , indices , indptr ), shape = X .shape , dtype = X .dtype , copy = False
793
836
)
794
837
X_new = [X_step ]
795
838
796
- log_step_nz = self . sample_interval_ * np .log (X .data )
797
- step_nz = 2 * X .data * self . sample_interval_
839
+ log_step_nz = sample_interval * np .log (X .data )
840
+ step_nz = 2 * X .data * sample_interval
798
841
799
- for j in range (1 , self . sample_steps ):
800
- factor_nz = np .sqrt (step_nz / np .cosh (np .pi * j * self . sample_interval_ ))
842
+ for j in range (1 , sample_steps ):
843
+ factor_nz = np .sqrt (step_nz / np .cosh (np .pi * j * sample_interval ))
801
844
802
845
data_step = factor_nz * np .cos (j * log_step_nz )
803
846
X_step = sp .csr_matrix (
0 commit comments