Closed
Description
The user-passed n_components parameter on the LDA object may be silently overwritten. Provide a warning for a better user experience. #6355
Steps/Code to Reproduce
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
n = ... # some int
X = [....] # q x r
Y = [....] # 1 x q, such that num_classes is less than n
model = LinearDiscriminantAnalysis(n_components=n)
Xp = model.fit(X, Y).transform(X)
# Try to use Xp, and what!?!? Xp is of shape q x (num_classes - 1) not q x n as expected
raise ConfusionError()
Fix in discriminant_analysis.py
import warnings
from .exceptions import ChangedBehaviorWarning
...
# https://github.com/scikit-learn/scikit-learn/blob/45dc891c96eebdb3b81bf14c2737d8f6540fabfe/sklearn/discriminant_analysis.py#L447
self._max_components = self.n_components
potential_components = len(self.classes_) - 1
if self._max_components is None:
    self._max_components = potential_components
elif self._max_components > potential_components:
    # The requested n_components exceeds what LDA can provide; warn before overriding it.
    warnings.warn("Using a component size of %d due to invalid n_components" % potential_components, ChangedBehaviorWarning)
    self._max_components = potential_components
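For anyone who wants to try the proposed behavior outside of scikit-learn, here is a minimal standalone sketch of the same clamp-and-warn logic. The helper name resolve_n_components is hypothetical and not part of scikit-learn. It uses the built-in UserWarning in place of ChangedBehaviorWarning so that it runs with the standard library alone. It also only considers the n_classes - 1 bound and ignores the number of features, which is a simplification.

```python
import warnings


def resolve_n_components(n_components, n_classes):
    """Hypothetical helper: clamp n_components to n_classes - 1, warning if it changes."""
    max_components = n_classes - 1  # upper bound used in the snippet above
    if n_components is None:
        # No request from the user, so fall back to the maximum.
        return max_components
    if n_components > max_components:
        # The user asked for more components than are available: warn, then clamp.
        warnings.warn(
            "n_components=%d exceeds n_classes - 1 = %d; using %d instead"
            % (n_components, max_components, max_components),
            UserWarning,
        )
        return max_components
    return n_components


if __name__ == "__main__":
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        used = resolve_n_components(5, n_classes=3)
    print(used, len(caught))  # the clamped value and the number of warnings raised
```

With this in place, the caller sees a warning at fit time instead of discovering the smaller output shape later on.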