diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 4c307d9e54250..0240ad706b6ca 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -503,6 +503,9 @@ Changelog messages when optimizers produce non-finite parameter weights. :pr:`22150` by :user:`Christian Ritter ` and :user:`Norbert Preining `. +- |Enhancement| Adds :term:`get_feature_names_out` to + :class:`neural_network.BernoulliRBM`. :pr:`22248` by `Thomas Fan`_. + :mod:`sklearn.pipeline` ....................... diff --git a/sklearn/neural_network/_rbm.py b/sklearn/neural_network/_rbm.py index 6a6cb67f17de0..aac92c3108787 100644 --- a/sklearn/neural_network/_rbm.py +++ b/sklearn/neural_network/_rbm.py @@ -15,6 +15,7 @@ from ..base import BaseEstimator from ..base import TransformerMixin +from ..base import _ClassNamePrefixFeaturesOutMixin from ..utils import check_random_state from ..utils import gen_even_slices from ..utils.extmath import safe_sparse_dot @@ -22,7 +23,7 @@ from ..utils.validation import check_is_fitted -class BernoulliRBM(TransformerMixin, BaseEstimator): +class BernoulliRBM(_ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator): """Bernoulli Restricted Boltzmann Machine (RBM). A Restricted Boltzmann Machine with binary visible units and @@ -284,6 +285,7 @@ def partial_fit(self, X, y=None): self.random_state_.normal(0, 0.01, (self.n_components, X.shape[1])), order="F", ) + self._n_features_out = self.components_.shape[0] if not hasattr(self, "intercept_hidden_"): self.intercept_hidden_ = np.zeros( self.n_components, @@ -389,6 +391,7 @@ def fit(self, X, y=None): order="F", dtype=X.dtype, ) + self._n_features_out = self.components_.shape[0] self.intercept_hidden_ = np.zeros(self.n_components, dtype=X.dtype) self.intercept_visible_ = np.zeros(X.shape[1], dtype=X.dtype) self.h_samples_ = np.zeros((self.batch_size, self.n_components), dtype=X.dtype) diff --git a/sklearn/neural_network/tests/test_rbm.py b/sklearn/neural_network/tests/test_rbm.py index aadae44479ad5..d36fa6b0bd11f 100644 --- a/sklearn/neural_network/tests/test_rbm.py +++ b/sklearn/neural_network/tests/test_rbm.py @@ -238,3 +238,15 @@ def test_convergence_dtype_consistency(): ) assert_allclose(rbm_64.components_, rbm_32.components_, rtol=1e-03, atol=0) assert_allclose(rbm_64.h_samples_, rbm_32.h_samples_) + + +@pytest.mark.parametrize("method", ["fit", "partial_fit"]) +def test_feature_names_out(method): + """Check `get_feature_names_out` for `BernoulliRBM`.""" + n_components = 10 + rbm = BernoulliRBM(n_components=n_components) + getattr(rbm, method)(Xdigits) + + names = rbm.get_feature_names_out() + expected_names = [f"bernoullirbm{i}" for i in range(n_components)] + assert_array_equal(expected_names, names) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index fb4de8942f131..be26202d458d1 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -383,7 +383,6 @@ def test_pandas_column_name_consistency(estimator): "ensemble", "kernel_approximation", "preprocessing", - "neural_network", ]