Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 49 additions & 22 deletions sklearn/tests/test_discriminant_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@
from scipy import linalg

from sklearn.utils import check_random_state
from sklearn.utils._testing import assert_array_equal, assert_no_warnings
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_raises
from sklearn.utils._testing import assert_raise_message
from sklearn.utils._testing import assert_warns
from sklearn.utils._testing import ignore_warnings

from sklearn.datasets import make_blobs
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
Expand Down Expand Up @@ -89,15 +85,22 @@ def test_lda_predict():

# Test invalid shrinkages
clf = LinearDiscriminantAnalysis(solver="lsqr", shrinkage=-0.2231)
assert_raises(ValueError, clf.fit, X, y)
with pytest.raises(ValueError):
clf.fit(X, y)

clf = LinearDiscriminantAnalysis(solver="eigen", shrinkage="dummy")
assert_raises(ValueError, clf.fit, X, y)
with pytest.raises(ValueError):
clf.fit(X, y)

clf = LinearDiscriminantAnalysis(solver="svd", shrinkage="auto")
assert_raises(NotImplementedError, clf.fit, X, y)
with pytest.raises(NotImplementedError):
clf.fit(X, y)

clf = LinearDiscriminantAnalysis(solver="lsqr", shrinkage=np.array([1, 2]))
with pytest.raises(TypeError,
match="shrinkage must be a float or a string"):
clf.fit(X, y)

clf = LinearDiscriminantAnalysis(solver="lsqr",
shrinkage=0.1,
covariance_estimator=ShrunkCovariance())
Expand All @@ -106,9 +109,11 @@ def test_lda_predict():
"parameters are not None. "
"Only one of the two can be set.")):
clf.fit(X, y)

# Test unknown solver
clf = LinearDiscriminantAnalysis(solver="dummy")
assert_raises(ValueError, clf.fit, X, y)
with pytest.raises(ValueError):
clf.fit(X, y)

# test bad solver with covariance_estimator
clf = LinearDiscriminantAnalysis(solver="svd",
Expand Down Expand Up @@ -199,7 +204,9 @@ def test_lda_priors():
priors = np.array([0.5, -0.5])
clf = LinearDiscriminantAnalysis(priors=priors)
msg = "priors must be non-negative"
assert_raise_message(ValueError, msg, clf.fit, X, y)

with pytest.raises(ValueError, match=msg):
clf.fit(X, y)

# Test that priors passed as a list are correctly handled (run to see if
# failure)
Expand All @@ -210,7 +217,10 @@ def test_lda_priors():
priors = np.array([0.5, 0.6])
prior_norm = np.array([0.45, 0.55])
clf = LinearDiscriminantAnalysis(priors=priors)
assert_warns(UserWarning, clf.fit, X, y)

with pytest.warns(UserWarning):
clf.fit(X, y)

assert_array_almost_equal(clf.priors_, prior_norm, 2)


Expand Down Expand Up @@ -247,7 +257,9 @@ def test_lda_transform():
clf = LinearDiscriminantAnalysis(solver="lsqr", n_components=1)
clf.fit(X, y)
msg = "transform not implemented for 'lsqr'"
assert_raise_message(NotImplementedError, msg, clf.transform, X)

with pytest.raises(NotImplementedError, match=msg):
clf.transform(X)


def test_lda_explained_variance_ratio():
Expand Down Expand Up @@ -424,7 +436,8 @@ def test_lda_dimension_warning(n_classes, n_features):
for n_components in [max_components - 1, None, max_components]:
# if n_components <= min(n_classes - 1, n_features), no warning
lda = LinearDiscriminantAnalysis(n_components=n_components)
assert_no_warnings(lda.fit, X, y)
with pytest.warns(None):
lda.fit(X, y)

for n_components in [max_components + 1,
max(n_features, n_classes - 1) + 1]:
Expand Down Expand Up @@ -486,7 +499,8 @@ def test_qda():
assert np.any(y_pred3 != y7)

# Classes should have at least 2 elements
assert_raises(ValueError, clf.fit, X6, y4)
with pytest.raises(ValueError):
clf.fit(X6, y4)


def test_qda_priors():
Expand Down Expand Up @@ -523,23 +537,36 @@ def test_qda_store_covariance():


def test_qda_regularization():
# the default is reg_param=0. and will cause issues
# when there is a constant variable
# The default is reg_param=0. and will cause issues when there is a
# constant variable.

# Fitting on data with constant variable triggers an UserWarning.
collinear_msg = "Variables are collinear"
clf = QuadraticDiscriminantAnalysis()
with ignore_warnings():
y_pred = clf.fit(X2, y6).predict(X2)
with pytest.warns(UserWarning, match=collinear_msg):
y_pred = clf.fit(X2, y6)

# XXX: RuntimeWarning is also raised at predict time because of divisions
# by zero when the model is fit with a constant feature and without
# regularization: should this be considered a bug? Either by the fit-time
# message more informative, raising and exception instead of a warning in
# this case or somehow changing predict to avoid division by zero.
with pytest.warns(RuntimeWarning, match="divide by zero"):
y_pred = clf.predict(X2)
assert np.any(y_pred != y6)

# adding a little regularization fixes the problem
# Adding a little regularization fixes the division by zero at predict
# time. But UserWarning will persist at fit time.
clf = QuadraticDiscriminantAnalysis(reg_param=0.01)
with ignore_warnings():
with pytest.warns(UserWarning, match=collinear_msg):
clf.fit(X2, y6)
y_pred = clf.predict(X2)
assert_array_equal(y_pred, y6)

# Case n_samples_in_a_class < n_features
# UserWarning should also be there for the n_samples_in_a_class <
# n_features case.
clf = QuadraticDiscriminantAnalysis(reg_param=0.1)
with ignore_warnings():
with pytest.warns(UserWarning, match=collinear_msg):
clf.fit(X5, y5)
y_pred5 = clf.predict(X5)
assert_array_equal(y_pred5, y5)
Expand Down