Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2489035

Browse files
glemaitrerth
authored andcommitted
TST: replace ignore_warnings with specific filterwarning in SAG (#11606)
1 parent 4d0a262 commit 2489035

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

sklearn/linear_model/tests/test_sag.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# License: BSD 3 clause
55

66
import math
7+
import pytest
78
import numpy as np
89
import scipy.sparse as sp
910

@@ -20,7 +21,6 @@
2021
from sklearn.utils.testing import assert_allclose
2122
from sklearn.utils.testing import assert_greater
2223
from sklearn.utils.testing import assert_raise_message
23-
from sklearn.utils.testing import ignore_warnings
2424
from sklearn.utils import compute_class_weight
2525
from sklearn.utils import check_random_state
2626
from sklearn.preprocessing import LabelEncoder, LabelBinarizer
@@ -231,7 +231,6 @@ def get_step_size(X, alpha, fit_intercept, classification=True):
231231
return 1.0 / (np.max(np.sum(X * X, axis=1)) + fit_intercept + alpha)
232232

233233

234-
@ignore_warnings
235234
def test_classifier_matching():
236235
n_samples = 20
237236
X, y = make_blobs(n_samples=n_samples, centers=2, random_state=0,
@@ -301,7 +300,7 @@ def test_regressor_matching():
301300
assert_allclose(intercept2, clf.intercept_)
302301

303302

304-
@ignore_warnings
303+
@pytest.mark.filterwarnings('ignore:The max_iter was reached')
305304
def test_sag_pobj_matches_logistic_regression():
306305
"""tests if the sag pobj matches log reg"""
307306
n_samples = 100
@@ -331,7 +330,7 @@ def test_sag_pobj_matches_logistic_regression():
331330
assert_array_almost_equal(pobj3, pobj1, decimal=4)
332331

333332

334-
@ignore_warnings
333+
@pytest.mark.filterwarnings('ignore:The max_iter was reached')
335334
def test_sag_pobj_matches_ridge_regression():
336335
"""tests if the sag pobj matches ridge reg"""
337336
n_samples = 100
@@ -363,7 +362,7 @@ def test_sag_pobj_matches_ridge_regression():
363362
assert_array_almost_equal(pobj3, pobj2, decimal=4)
364363

365364

366-
@ignore_warnings
365+
@pytest.mark.filterwarnings('ignore:The max_iter was reached')
367366
def test_sag_regressor_computed_correctly():
368367
"""tests if the sag regressor is computed correctly"""
369368
alpha = .1
@@ -407,7 +406,6 @@ def test_sag_regressor_computed_correctly():
407406
# assert_almost_equal(clf2.intercept_, spintercept2, decimal=1)'''
408407

409408

410-
@ignore_warnings
411409
def test_get_auto_step_size():
412410
X = np.array([[1, 2, 3], [2, 3, 4], [2, 3, 2]], dtype=np.float64)
413411
alpha = 1.2
@@ -452,7 +450,7 @@ def test_get_auto_step_size():
452450
max_squared_sum_, alpha, "wrong", fit_intercept)
453451

454452

455-
@ignore_warnings
453+
@pytest.mark.filterwarnings('ignore:The max_iter was reached')
456454
def test_sag_regressor():
457455
"""tests if the sag regressor performs well"""
458456
xmin, xmax = -5, 5
@@ -491,7 +489,7 @@ def test_sag_regressor():
491489
assert_greater(score2, 0.5)
492490

493491

494-
@ignore_warnings
492+
@pytest.mark.filterwarnings('ignore:The max_iter was reached')
495493
def test_sag_classifier_computed_correctly():
496494
"""tests if the binary classifier is computed correctly"""
497495
alpha = .1
@@ -534,7 +532,7 @@ def test_sag_classifier_computed_correctly():
534532
assert_almost_equal(clf2.intercept_, spintercept2, decimal=1)
535533

536534

537-
@ignore_warnings
535+
@pytest.mark.filterwarnings('ignore:The max_iter was reached')
538536
def test_sag_multiclass_computed_correctly():
539537
"""tests if the multiclass classifier is computed correctly"""
540538
alpha = .1
@@ -593,7 +591,6 @@ def test_sag_multiclass_computed_correctly():
593591
assert_almost_equal(clf2.intercept_[i], intercept2[i], decimal=1)
594592

595593

596-
@ignore_warnings
597594
def test_classifier_results():
598595
"""tests if classifier results match target"""
599596
alpha = .1
@@ -618,7 +615,7 @@ def test_classifier_results():
618615
assert_almost_equal(pred2, y, decimal=12)
619616

620617

621-
@ignore_warnings
618+
@pytest.mark.filterwarnings('ignore:The max_iter was reached')
622619
def test_binary_classifier_class_weight():
623620
"""tests binary classifier with classweights for each class"""
624621
alpha = .1
@@ -668,7 +665,7 @@ def test_binary_classifier_class_weight():
668665
assert_almost_equal(clf2.intercept_, spintercept2, decimal=1)
669666

670667

671-
@ignore_warnings
668+
@pytest.mark.filterwarnings('ignore:The max_iter was reached')
672669
def test_multiclass_classifier_class_weight():
673670
"""tests multiclass with classweights for each class"""
674671
alpha = .1

0 commit comments

Comments
 (0)