From e51b0da65e768a38cb675b43e31ad491c4dfc966 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Wed, 13 Sep 2017 14:34:04 +0200 Subject: [PATCH 01/16] sgd for one class svm --- benchmarks/bench_online_ocsvm.py | 279 +++++++++++ doc/modules/classes.rst | 1 + doc/modules/outlier_detection.rst | 32 +- doc/modules/sgd.rst | 51 ++ .../linear_model/plot_sgdocsvm_vs_ocsvm.py | 135 +++++ .../miscellaneous/plot_anomaly_comparison.py | 29 +- sklearn/linear_model/__init__.py | 3 +- sklearn/linear_model/_sgd_fast.pyx | 16 +- sklearn/linear_model/_stochastic_gradient.py | 472 +++++++++++++++++- sklearn/linear_model/tests/test_sgd.py | 374 +++++++++++++- 10 files changed, 1351 insertions(+), 41 deletions(-) create mode 100644 benchmarks/bench_online_ocsvm.py create mode 100644 examples/linear_model/plot_sgdocsvm_vs_ocsvm.py diff --git a/benchmarks/bench_online_ocsvm.py b/benchmarks/bench_online_ocsvm.py new file mode 100644 index 0000000000000..33262e8fcb690 --- /dev/null +++ b/benchmarks/bench_online_ocsvm.py @@ -0,0 +1,279 @@ +""" +===================================== +SGDOneClassSVM benchmark +===================================== +This benchmark compares the :class:`SGDOneClassSVM` with :class:`OneClassSVM`. +The former is an online One-Class SVM implemented with a Stochastic Gradient +Descent (SGD). The latter is based on the LibSVM implementation. The +complexity of :class:`SGDOneClassSVM` is linear in the number of samples +whereas the one of :class:`OneClassSVM` is at best quadratic in the number of +samples. We here compare the performance in terms of AUC and training time on +classical anomaly detection datasets. + +The :class:`OneClassSVM` is applied with a Gaussian kernel and we therefore +use a kernel approximation prior to the application of :class:`SGDOneClassSVM`. +""" + +from time import time +import numpy as np + +from scipy.interpolate import interp1d + +from sklearn.metrics import roc_curve, auc +from sklearn.datasets import fetch_kddcup99, fetch_covtype +from sklearn.preprocessing import LabelBinarizer, StandardScaler +from sklearn.pipeline import make_pipeline +from sklearn.utils import shuffle +from sklearn.kernel_approximation import Nystroem +from sklearn.svm import OneClassSVM +from sklearn.linear_model import SGDOneClassSVM + +import matplotlib.pyplot as plt +import matplotlib + +font = {'weight': 'normal', + 'size': 15} + +matplotlib.rc('font', **font) + +print(__doc__) + + +def print_outlier_ratio(y): + """ + Helper function to show the distinct value count of element in the target. + Useful indicator for the datasets used in bench_isolation_forest.py. 
+ """ + uniq, cnt = np.unique(y, return_counts=True) + print("----- Target count values: ") + for u, c in zip(uniq, cnt): + print("------ %s -> %d occurrences" % (str(u), c)) + print("----- Outlier ratio: %.5f" % (np.min(cnt) / len(y))) + + +# for roc curve computation +n_axis = 1000 +x_axis = np.linspace(0, 1, n_axis) + +datasets = ['http', 'smtp', 'SA', 'SF', 'forestcover'] + +novelty_detection = False # if False, training set polluted by outliers + +random_states = [42] +nu = 0.05 + +results_libsvm = np.empty((len(datasets), n_axis + 5)) +results_online = np.empty((len(datasets), n_axis + 5)) + +for dat, dataset_name in enumerate(datasets): + + print(dataset_name) + + # Loading datasets + if dataset_name in ['http', 'smtp', 'SA', 'SF']: + dataset = fetch_kddcup99(subset=dataset_name, shuffle=False, + percent10=False, random_state=88) + X = dataset.data + y = dataset.target + + if dataset_name == 'forestcover': + dataset = fetch_covtype(shuffle=False) + X = dataset.data + y = dataset.target + # normal data are those with attribute 2 + # abnormal those with attribute 4 + s = (y == 2) + (y == 4) + X = X[s, :] + y = y[s] + y = (y != 2).astype(int) + + # Vectorizing data + if dataset_name == 'SF': + # Casting type of X (object) as string is needed for string categorical + # features to apply LabelBinarizer + lb = LabelBinarizer() + x1 = lb.fit_transform(X[:, 1].astype(str)) + X = np.c_[X[:, :1], x1, X[:, 2:]] + y = (y != b'normal.').astype(int) + + if dataset_name == 'SA': + lb = LabelBinarizer() + # Casting type of X (object) as string is needed for string categorical + # features to apply LabelBinarizer + x1 = lb.fit_transform(X[:, 1].astype(str)) + x2 = lb.fit_transform(X[:, 2].astype(str)) + x3 = lb.fit_transform(X[:, 3].astype(str)) + X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]] + y = (y != b'normal.').astype(int) + + if dataset_name in ['http', 'smtp']: + y = (y != b'normal.').astype(int) + + print_outlier_ratio(y) + + n_samples, n_features = np.shape(X) + if dataset_name == 'SA': # LibSVM too long with n_samples // 2 + n_samples_train = n_samples // 20 + else: + n_samples_train = n_samples // 2 + + n_samples_test = n_samples - n_samples_train + print('n_train: ', n_samples_train) + print('n_features: ', n_features) + + tpr_libsvm = np.zeros(n_axis) + tpr_online = np.zeros(n_axis) + fit_time_libsvm = 0 + fit_time_online = 0 + predict_time_libsvm = 0 + predict_time_online = 0 + + X = X.astype(float) + + gamma = 1 / n_features # OCSVM default parameter + + for random_state in random_states: + + print('random state: %s' % random_state) + + X, y = shuffle(X, y, random_state=random_state) + X_train = X[:n_samples_train] + X_test = X[n_samples_train:] + y_train = y[:n_samples_train] + y_test = y[n_samples_train:] + + if novelty_detection: + X_train = X_train[y_train == 0] + y_train = y_train[y_train == 0] + + std = StandardScaler() + + print('----------- LibSVM OCSVM ------------') + ocsvm = OneClassSVM(kernel='rbf', gamma=gamma, nu=nu) + pipe_libsvm = make_pipeline(std, ocsvm) + + tstart = time() + pipe_libsvm.fit(X_train) + fit_time_libsvm += time() - tstart + + tstart = time() + # scoring such that the lower, the more normal + scoring = -pipe_libsvm.decision_function(X_test) + predict_time_libsvm += time() - tstart + fpr_libsvm_, tpr_libsvm_, _ = roc_curve(y_test, scoring) + + f_libsvm = interp1d(fpr_libsvm_, tpr_libsvm_) + tpr_libsvm += f_libsvm(x_axis) + + print('----------- Online OCSVM ------------') + nystroem = Nystroem(gamma=gamma, random_state=random_state) + online_ocsvm = 
SGDOneClassSVM(nu=nu, random_state=random_state) + pipe_online = make_pipeline(std, nystroem, online_ocsvm) + + tstart = time() + pipe_online.fit(X_train) + fit_time_online += time() - tstart + + tstart = time() + # scoring such that the lower, the more normal + scoring = -pipe_online.decision_function(X_test) + predict_time_online += time() - tstart + fpr_online_, tpr_online_, _ = roc_curve(y_test, scoring) + + f_online = interp1d(fpr_online_, tpr_online_) + tpr_online += f_online(x_axis) + + tpr_libsvm /= len(random_states) + tpr_libsvm[0] = 0. + fit_time_libsvm /= len(random_states) + predict_time_libsvm /= len(random_states) + auc_libsvm = auc(x_axis, tpr_libsvm) + + results_libsvm[dat] = ([fit_time_libsvm, predict_time_libsvm, + auc_libsvm, n_samples_train, + n_features] + list(tpr_libsvm)) + + tpr_online /= len(random_states) + tpr_online[0] = 0. + fit_time_online /= len(random_states) + predict_time_online /= len(random_states) + auc_online = auc(x_axis, tpr_online) + + results_online[dat] = ([fit_time_online, predict_time_online, + auc_online, n_samples_train, + n_features] + list(tpr_libsvm)) + + +# -------- Plotting bar charts ------------- +fit_time_libsvm_all = results_libsvm[:, 0] +predict_time_libsvm_all = results_libsvm[:, 1] +auc_libsvm_all = results_libsvm[:, 2] +n_train_all = results_libsvm[:, 3] +n_features_all = results_libsvm[:, 4] + +fit_time_online_all = results_online[:, 0] +predict_time_online_all = results_online[:, 1] +auc_online_all = results_online[:, 2] + + +width = 0.7 +ind = 2 * np.arange(len(datasets)) +x_tickslabels = [(name + '\n' + r'$n={:,d}$' + '\n' + r'$d={:d}$') + .format(int(n), int(d)) + for name, n, d in zip(datasets, n_train_all, n_features_all)] + + +def autolabel_auc(rects, ax): + """Attach a text label above each bar displaying its height.""" + for rect in rects: + height = rect.get_height() + ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, + '%.3f' % height, ha='center', va='bottom') + + +def autolabel_time(rects, ax): + """Attach a text label above each bar displaying its height.""" + for rect in rects: + height = rect.get_height() + ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height, + '%.1f' % height, ha='center', va='bottom') + + +fig, ax = plt.subplots(figsize=(15, 8)) +ax.set_ylabel('AUC') +ax.set_ylim((0, 1.3)) +rect_libsvm = ax.bar(ind, auc_libsvm_all, width=width, color='r') +rect_online = ax.bar(ind + width, auc_online_all, width=width, color='y') +ax.legend((rect_libsvm[0], rect_online[0]), ('LibSVM', 'Online SVM')) +ax.set_xticks(ind + width / 2) +ax.set_xticklabels(x_tickslabels) +autolabel_auc(rect_libsvm, ax) +autolabel_auc(rect_online, ax) +plt.show() + + +fig, ax = plt.subplots(figsize=(15, 8)) +ax.set_ylabel('Training time (sec) - Log scale') +ax.set_yscale('log') +rect_libsvm = ax.bar(ind, fit_time_libsvm_all, color='r', width=width) +rect_online = ax.bar(ind + width, fit_time_online_all, color='y', width=width) +ax.legend((rect_libsvm[0], rect_online[0]), ('LibSVM', 'Online SVM')) +ax.set_xticks(ind + width / 2) +ax.set_xticklabels(x_tickslabels) +autolabel_time(rect_libsvm, ax) +autolabel_time(rect_online, ax) +plt.show() + + +fig, ax = plt.subplots(figsize=(15, 8)) +ax.set_ylabel('Testing time (sec) - Log scale') +ax.set_yscale('log') +rect_libsvm = ax.bar(ind, predict_time_libsvm_all, color='r', width=width) +rect_online = ax.bar(ind + width, predict_time_online_all, + color='y', width=width) +ax.legend((rect_libsvm[0], rect_online[0]), ('LibSVM', 'Online SVM')) +ax.set_xticks(ind + width / 2) 
+ax.set_xticklabels(x_tickslabels) +autolabel_time(rect_libsvm, ax) +autolabel_time(rect_online, ax) +plt.show() diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 84f8097cbbe9d..18cb97224285d 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -762,6 +762,7 @@ Linear classifiers linear_model.RidgeClassifier linear_model.RidgeClassifierCV linear_model.SGDClassifier + linear_model.SGDOneClassSVM Classical linear regressors --------------------------- diff --git a/doc/modules/outlier_detection.rst b/doc/modules/outlier_detection.rst index 5d2008f3c3f58..14495bc558dab 100644 --- a/doc/modules/outlier_detection.rst +++ b/doc/modules/outlier_detection.rst @@ -110,9 +110,14 @@ does not perform very well for outlier detection. That being said, outlier detection in high-dimension, or without any assumptions on the distribution of the inlying data is very challenging. :class:`svm.OneClassSVM` may still be used with outlier detection but requires fine-tuning of its hyperparameter -`nu` to handle outliers and prevent overfitting. Finally, -:class:`covariance.EllipticEnvelope` assumes the data is Gaussian and learns -an ellipse. For more details on the different estimators refer to the example +`nu` to handle outliers and prevent overfitting. +:class:`linear_model.SGDOneClassSVM` provides an implementation of a +linear One-Class SVM with a linear complexity in the number of samples. This +implementation is here used with a kernel approximation technique to obtain +results similar to :class:`svm.OneClassSVM` which uses a Gaussian kernel +by default. Finally, :class:`covariance.EllipticEnvelope` assumes the data is +Gaussian and learns an ellipse. For more details on the different estimators +refer to the example :ref:`sphx_glr_auto_examples_miscellaneous_plot_anomaly_comparison.py` and the sections hereunder. @@ -173,6 +178,23 @@ but regular, observation outside the frontier. :scale: 75% +Scaling up the One-Class SVM +---------------------------- + +An online linear version of the One-Class SVM is implemented in +:class:`linear_model.SGDOneClassSVM`. This implementation scales linearly with +the number of samples and can be used with a kernel approximation to +approximate the solution of a kernelized :class:`svm.OneClassSVM` whose +complexity is at best quadratic in the number of samples. See section +:ref:`sgd_online_one_class_svm` for more details. + +.. topic:: Examples: + + * See :ref:`sphx_glr_auto_examples_linear_model_plot_sgdocsvm_vs_ocsvm.py` + for an illustration of the approximation of a kernelized One-Class SVM + with the `linear_model.SGDOneClassSVM` combined with kernel approximation. + + Outlier Detection ================= @@ -278,8 +300,8 @@ allows you to add more trees to an already fitted model:: for a comparison of :class:`ensemble.IsolationForest` with :class:`neighbors.LocalOutlierFactor`, :class:`svm.OneClassSVM` (tuned to perform like an outlier detection - method) and a covariance-based outlier detection with - :class:`covariance.EllipticEnvelope`. + method), :class:`linear_model.SGDOneClassSVM`, and a covariance-based + outlier detection with :class:`covariance.EllipticEnvelope`. .. 
topic:: References: diff --git a/doc/modules/sgd.rst b/doc/modules/sgd.rst index 95a5111747509..df6eab3acb783 100644 --- a/doc/modules/sgd.rst +++ b/doc/modules/sgd.rst @@ -232,6 +232,57 @@ For regression with a squared loss and a l2 penalty, another variant of SGD with an averaging strategy is available with Stochastic Average Gradient (SAG) algorithm, available as a solver in :class:`Ridge`. +.. _sgd_online_one_class_svm: + +Online One-Class SVM +==================== + +The class :class:`sklearn.linear_model.SGDOneClassSVM` implements an online +linear version of the One-Class SVM using a stochastic gradient descent. +Combined with kernel approximation techniques, +:class:`sklearn.linear_model.SGDOneClassSVM` can be used to approximate the +solution of a kernelized One-Class SVM, implemented in +:class:`sklearn.svm.OneClassSVM`, with a linear complexity in the number of +samples. Note that the complexity of a kernelized One-Class SVM is at best +quadratic in the number of samples. +:class:`sklearn.linear_model.SGDOneClassSVM` is thus well suited for datasets +with a large number of training samples (> 10.000). + +Its implementation is based on the implementation of the stochastic +gradient descent. Indeed, the original optimization problem of the One-Class +SVM is given by + +.. math:: + + \begin{aligned} + \min_{w, \rho, \xi} & \quad \frac{1}{2}\Vert w \Vert^2 - \rho + \frac{1}{\nu n} \sum_{i=1}^n \xi_i \\ + \text{s.t.} & \quad \langle w, x_i \rangle \geq \rho - \xi_i \quad 1 \leq i \leq n \\ + & \quad \xi_i \geq 0 \quad 1 \leq i \leq n + \end{aligned} + +where :math:`\nu \in (0, 1]` is the user-specified parameter controlling the +proportion of outliers and the proportion of support vectors. Getting rid of +the slack variables :math:`\xi_i` this problem is equivalent to + +.. math:: + + \min_{w, \rho} \frac{1}{2}\Vert w \Vert^2 - \rho + \frac{1}{\nu n} \sum_{i=1}^n \max(0, \rho - \langle w, x_i \rangle) \, . + +Multiplying by the constant :math:`\nu` and introducing the intercept +:math:`b = 1 - \rho` we obtain the following equivalent optimization problem + +.. math:: + + \min_{w, b} \frac{\nu}{2}\Vert w \Vert^2 + b\nu + \frac{1}{n} \sum_{i=1}^n \max(0, 1 - (\langle w, x_i \rangle + b)) \, . + +This is similar to the optimization problems studied in section +:ref:`sgd_mathematical_formulation` with :math:`y_i = 1, 1 \leq i \leq n` and +:math:`\alpha = \nu/2`, :math:`L` being the hinge loss function and :math:`R` +being the L2 norm. We just need to add the term :math:`b\nu` in the +optimization loop. + +As :class:`SGDClassifier` and :class:`SGDRegressor`, :class:`SGDOneClassSVM` +supports averaged SGD. Averaging can be enabled by setting ``average=True``. Stochastic Gradient Descent for sparse data =========================================== diff --git a/examples/linear_model/plot_sgdocsvm_vs_ocsvm.py b/examples/linear_model/plot_sgdocsvm_vs_ocsvm.py new file mode 100644 index 0000000000000..e70694cdb1c1b --- /dev/null +++ b/examples/linear_model/plot_sgdocsvm_vs_ocsvm.py @@ -0,0 +1,135 @@ +""" +==================================================================== +One-Class SVM versus One-Class SVM using Stochastic Gradient Descent +==================================================================== + +This example shows how to approximate the solution of +:class:`sklearn.svm.OneClassSVM` in the case of an RBF kernel with +:class:`sklearn.linear_model.SGDOneClassSVM`, a Stochastic Gradient Descent +(SGD) version of the One-Class SVM. 
A kernel approximation is first used in +order to apply :class:`sklearn.linear_model.SGDOneClassSVM` which implements a +linear One-Class SVM using SGD. + +Note that :class:`sklearn.linear_model.SGDOneClassSVM` scales linearly with +the number of samples whereas the complexity of a kernelized +:class:`sklearn.svm.OneClassSVM` is at best quadratic with respect to the +number of samples. It is not the purpose of this example to illustrate the +benefits of such an approximation in terms of computation time but rather to +show that we obtain similar results on a toy dataset. +""" +print(__doc__) # noqa + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib +from sklearn.svm import OneClassSVM +from sklearn.linear_model import SGDOneClassSVM +from sklearn.kernel_approximation import Nystroem +from sklearn.pipeline import make_pipeline + +font = {'weight': 'normal', + 'size': 15} + +matplotlib.rc('font', **font) + +random_state = 42 +rng = np.random.RandomState(random_state) + +# Generate train data +X = 0.3 * rng.randn(500, 2) +X_train = np.r_[X + 2, X - 2] +# Generate some regular novel observations +X = 0.3 * rng.randn(20, 2) +X_test = np.r_[X + 2, X - 2] +# Generate some abnormal novel observations +X_outliers = rng.uniform(low=-4, high=4, size=(20, 2)) + +xx, yy = np.meshgrid(np.linspace(-4.5, 4.5, 50), np.linspace(-4.5, 4.5, 50)) + +# OCSVM hyperparameters +nu = 0.05 +gamma = 2. + +# Fit the One-Class SVM +clf = OneClassSVM(gamma=gamma, kernel='rbf', nu=nu) +clf.fit(X_train) +y_pred_train = clf.predict(X_train) +y_pred_test = clf.predict(X_test) +y_pred_outliers = clf.predict(X_outliers) +n_error_train = y_pred_train[y_pred_train == -1].size +n_error_test = y_pred_test[y_pred_test == -1].size +n_error_outliers = y_pred_outliers[y_pred_outliers == 1].size + +Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) +Z = Z.reshape(xx.shape) + + +# Fit the One-Class SVM using a kernel approximation and SGD +transform = Nystroem(gamma=gamma, random_state=random_state) +clf_sgd = SGDOneClassSVM(nu=nu, shuffle=True, fit_intercept=True, + random_state=random_state, tol=1e-4) +pipe_sgd = make_pipeline(transform, clf_sgd) +pipe_sgd.fit(X_train) +y_pred_train_sgd = pipe_sgd.predict(X_train) +y_pred_test_sgd = pipe_sgd.predict(X_test) +y_pred_outliers_sgd = pipe_sgd.predict(X_outliers) +n_error_train_sgd = y_pred_train_sgd[y_pred_train_sgd == -1].size +n_error_test_sgd = y_pred_test_sgd[y_pred_test_sgd == -1].size +n_error_outliers_sgd = y_pred_outliers_sgd[y_pred_outliers_sgd == 1].size + +Z_sgd = pipe_sgd.decision_function(np.c_[xx.ravel(), yy.ravel()]) +Z_sgd = Z_sgd.reshape(xx.shape) + +# plot the level sets of the decision function +plt.figure(figsize=(9, 6)) +plt.title('One Class SVM') +plt.contourf(xx, yy, Z, levels=np.linspace(Z.min(), 0, 7), cmap=plt.cm.PuBu) +a = plt.contour(xx, yy, Z, levels=[0], linewidths=2, colors='darkred') +plt.contourf(xx, yy, Z, levels=[0, Z.max()], colors='palevioletred') + +s = 20 +b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c='white', s=s, edgecolors='k') +b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c='blueviolet', s=s, + edgecolors='k') +c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c='gold', s=s, + edgecolors='k') +plt.axis('tight') +plt.xlim((-4.5, 4.5)) +plt.ylim((-4.5, 4.5)) +plt.legend([a.collections[0], b1, b2, c], + ["learned frontier", "training observations", + "new regular observations", "new abnormal observations"], + loc="upper left") +plt.xlabel( + "error train: %d/%d; errors novel regular: %d/%d; " + "errors novel 
abnormal: %d/%d" + % (n_error_train, X_train.shape[0], n_error_test, X_test.shape[0], + n_error_outliers, X_outliers.shape[0])) +plt.show() + +plt.figure(figsize=(9, 6)) +plt.title('Online One-Class SVM') +plt.contourf(xx, yy, Z_sgd, levels=np.linspace(Z_sgd.min(), 0, 7), + cmap=plt.cm.PuBu) +a = plt.contour(xx, yy, Z_sgd, levels=[0], linewidths=2, colors='darkred') +plt.contourf(xx, yy, Z_sgd, levels=[0, Z_sgd.max()], colors='palevioletred') + +s = 20 +b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c='white', s=s, edgecolors='k') +b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c='blueviolet', s=s, + edgecolors='k') +c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c='gold', s=s, + edgecolors='k') +plt.axis('tight') +plt.xlim((-4.5, 4.5)) +plt.ylim((-4.5, 4.5)) +plt.legend([a.collections[0], b1, b2, c], + ["learned frontier", "training observations", + "new regular observations", "new abnormal observations"], + loc="upper left") +plt.xlabel( + "error train: %d/%d; errors novel regular: %d/%d; " + "errors novel abnormal: %d/%d" + % (n_error_train_sgd, X_train.shape[0], n_error_test_sgd, X_test.shape[0], + n_error_outliers_sgd, X_outliers.shape[0])) +plt.show() diff --git a/examples/miscellaneous/plot_anomaly_comparison.py b/examples/miscellaneous/plot_anomaly_comparison.py index b5ebd96bd8815..dc0d9f23f8315 100644 --- a/examples/miscellaneous/plot_anomaly_comparison.py +++ b/examples/miscellaneous/plot_anomaly_comparison.py @@ -22,7 +22,17 @@ One-class SVM might give useful results in these situations depending on the value of its hyperparameters. -:class:`~sklearn.covariance.EllipticEnvelope` assumes the data is Gaussian and +The :class:`sklearn.linear_model.SGDOneClassSVM` is an implementation of the +One-Class SVM based on stochastic gradient descent (SGD). Combined with kernel +approximation, this estimator can be used to approximate the solution +of a kernelized :class:`sklearn.svm.OneClassSVM`. We note that, although not +identical, the decision boundaries of the +:class:`sklearn.linear_model.SGDOneClassSVM` and the ones of +:class:`sklearn.svm.OneClassSVM` are very similar. The main advantage of using +:class:`sklearn.linear_model.SGDOneClassSVM` is that it scales linearly with +the number of samples. + +:class:`sklearn.covariance.EllipticEnvelope` assumes the data is Gaussian and learns an ellipse. It thus degrades when the data is not unimodal. Notice however that this estimator is robust to outliers. 
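Illustrative sketch (not taken from this patch) of the combination described above: a Nystroem
kernel map followed by SGDOneClassSVM, fitted alongside the exact OneClassSVM on toy data.
The dataset and the values nu=0.05 and gamma=2.0 are assumptions made only for the sketch.

import numpy as np
from sklearn.kernel_approximation import Nystroem
from sklearn.linear_model import SGDOneClassSVM
from sklearn.pipeline import make_pipeline
from sklearn.svm import OneClassSVM

rng = np.random.RandomState(42)
X_train = 0.3 * rng.randn(200, 2) + 2              # training inliers
X_test = np.r_[0.3 * rng.randn(20, 2) + 2,         # new regular observations
               rng.uniform(-4, 4, (20, 2))]        # new abnormal observations

# Exact kernelized One-Class SVM: at best quadratic in the number of samples.
ocsvm = OneClassSVM(kernel="rbf", gamma=2.0, nu=0.05).fit(X_train)

# Linear-time approximation: explicit kernel map + linear One-Class SVM fit with SGD.
pipe = make_pipeline(
    Nystroem(gamma=2.0, random_state=42),
    SGDOneClassSVM(nu=0.05, random_state=42),
).fit(X_train)

# Both estimators predict +1 for inliers and -1 for outliers.
print(ocsvm.predict(X_test))
print(pipe.predict(X_test))
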
@@ -66,6 +76,9 @@ from sklearn.covariance import EllipticEnvelope from sklearn.ensemble import IsolationForest from sklearn.neighbors import LocalOutlierFactor +from sklearn.linear_model import SGDOneClassSVM +from sklearn.kernel_approximation import Nystroem +from sklearn.pipeline import make_pipeline print(__doc__) @@ -86,6 +99,14 @@ random_state=42)), ("Local Outlier Factor", LocalOutlierFactor( n_neighbors=35, contamination=outliers_fraction))] +# SGDOneClassSVM must be used with a kernel approximation to give similar +# results to the OneClassSVM +transform = Nystroem(gamma=0.1, random_state=42, n_components=150) +clf_sgd = SGDOneClassSVM(nu=outliers_fraction, shuffle=True, + fit_intercept=True, random_state=42, tol=1e-6) +pipe_sgd = make_pipeline(transform, clf_sgd) +anomaly_algorithms.insert(2, ("One-Class SVM (SGD)", pipe_sgd)) + # Define datasets blobs_params = dict(random_state=0, n_samples=n_inliers, n_features=2) @@ -104,7 +125,7 @@ xx, yy = np.meshgrid(np.linspace(-7, 7, 150), np.linspace(-7, 7, 150)) -plt.figure(figsize=(len(anomaly_algorithms) * 2 + 3, 12.5)) +plt.figure(figsize=(len(anomaly_algorithms) * 2 + 4, 12.5)) plt.subplots_adjust(left=.02, right=.98, bottom=.001, top=.96, wspace=.05, hspace=.01) @@ -113,8 +134,8 @@ for i_dataset, X in enumerate(datasets): # Add outliers - X = np.concatenate([X, rng.uniform(low=-6, high=6, - size=(n_outliers, 2))], axis=0) + X = np.concatenate([X, rng.uniform(low=-6, high=6, size=(n_outliers, 2))], + axis=0) for name, algorithm in anomaly_algorithms: t0 = time.time() diff --git a/sklearn/linear_model/__init__.py b/sklearn/linear_model/__init__.py index 110e0008bccc9..f715e30795961 100644 --- a/sklearn/linear_model/__init__.py +++ b/sklearn/linear_model/__init__.py @@ -18,7 +18,7 @@ GammaRegressor, TweedieRegressor) from ._huber import HuberRegressor from ._sgd_fast import Hinge, Log, ModifiedHuber, SquaredLoss, Huber -from ._stochastic_gradient import SGDClassifier, SGDRegressor +from ._stochastic_gradient import SGDClassifier, SGDRegressor, SGDOneClassSVM from ._ridge import (Ridge, RidgeCV, RidgeClassifier, RidgeClassifierCV, ridge_regression) from ._logistic import LogisticRegression, LogisticRegressionCV @@ -65,6 +65,7 @@ 'RidgeClassifierCV', 'SGDClassifier', 'SGDRegressor', + 'SGDOneClassSVM', 'SquaredLoss', 'TheilSenRegressor', 'enet_path', diff --git a/sklearn/linear_model/_sgd_fast.pyx b/sklearn/linear_model/_sgd_fast.pyx index 3940e5d873669..dab7b36b14d0e 100644 --- a/sklearn/linear_model/_sgd_fast.pyx +++ b/sklearn/linear_model/_sgd_fast.pyx @@ -55,7 +55,7 @@ cdef class LossFunction: Parameters ---------- p : double - The prediction, p = w^T x + The prediction, p = w^T x + intercept y : double The true value (aka target) @@ -358,6 +358,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, double weight_pos, double weight_neg, int learning_rate, double eta0, double power_t, + bint one_class, double t=1.0, double intercept_decay=1.0, int average=0): @@ -427,6 +428,8 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, The initial learning rate. power_t : double The exponent for inverse scaling learning rate. + one_class : boolean + Whether to solve the One-Class SVM optimization problem. t : double Initial state of the learning rate. This value is equal to the iteration count except when the learning rate is set to `optimal`. @@ -435,6 +438,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, The number of iterations before averaging starts. 
average=1 is equivalent to averaging for all iterations. + Returns ------- weights : array, shape=[n_features] @@ -468,6 +472,7 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, cdef double eta = 0.0 cdef double p = 0.0 cdef double update = 0.0 + cdef double intercept_update = 0.0 cdef double sumloss = 0.0 cdef double score = 0.0 cdef double best_loss = INFINITY @@ -574,10 +579,15 @@ def _plain_sgd(np.ndarray[double, ndim=1, mode='c'] weights, # do not scale to negative values when eta or alpha are too # big: instead set the weights to zero w.scale(max(0, 1.0 - ((1.0 - l1_ratio) * eta * alpha))) + if update != 0.0: w.add(x_data_ptr, x_ind_ptr, xnnz, update) - if fit_intercept == 1: - intercept += update * intercept_decay + if fit_intercept == 1: + intercept_update = update + if one_class: # specific for One-Class SVM + intercept_update -= 2. * eta * alpha + if intercept_update != 0: + intercept += intercept_update * intercept_decay if 0 < average <= t: # compute the average for the intercept and update the diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index 948910e61b51c..aa9df78dda6cf 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -2,7 +2,9 @@ # Mathieu Blondel (partial_fit support) # # License: BSD 3 clause -"""Classification and regression using Stochastic Gradient Descent (SGD).""" +"""Classification, regression and One-Class SVM using Stochastic Gradient +Descent (SGD). +""" import numpy as np import warnings @@ -14,7 +16,7 @@ from ..base import clone, is_classifier from ._base import LinearClassifierMixin, SparseCoefMixin from ._base import make_dataset -from ..base import BaseEstimator, RegressorMixin +from ..base import BaseEstimator, RegressorMixin, OutlierMixin from ..utils import check_random_state from ..utils.extmath import safe_sparse_dot from ..utils.multiclass import _check_partial_fit_first_call @@ -134,7 +136,7 @@ def _validate_params(self, for_partial_fit=False): raise ValueError("max_iter must be > zero. Got %f" % self.max_iter) if not (0.0 <= self.l1_ratio <= 1.0): raise ValueError("l1_ratio must be in [0, 1]") - if self.alpha < 0.0: + if not isinstance(self, SGDOneClassSVM) and self.alpha < 0.0: raise ValueError("alpha must be >= 0") if self.n_iter_no_change < 1: raise ValueError("n_iter_no_change must be >= 1") @@ -182,7 +184,7 @@ def _get_penalty_type(self, penalty): raise ValueError("Penalty %s is not supported. 
" % penalty) from e def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None, - intercept_init=None): + intercept_init=None, one_class=0): """Allocate mem for parameters; initialize if provided.""" if n_classes > 2: # allocate coef_ for multi-class @@ -207,7 +209,7 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None, self.intercept_ = np.zeros(n_classes, dtype=np.float64, order="C") else: - # allocate coef_ for binary problem + # allocate coef_ if coef_init is not None: coef_init = np.asarray(coef_init, dtype=np.float64, order="C") @@ -221,26 +223,36 @@ def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None, dtype=np.float64, order="C") - # allocate intercept_ for binary problem + # allocate intercept_ if intercept_init is not None: intercept_init = np.asarray(intercept_init, dtype=np.float64) if intercept_init.shape != (1,) and intercept_init.shape != (): raise ValueError("Provided intercept_init " "does not match dataset.") - self.intercept_ = intercept_init.reshape(1,) + if one_class: + self.offset_ = intercept_init.reshape(1,) + else: + self.intercept_ = intercept_init.reshape(1,) else: - self.intercept_ = np.zeros(1, dtype=np.float64, order="C") + if one_class: + self.offset_ = np.zeros(1, dtype=np.float64, order="C") + else: + self.intercept_ = np.zeros(1, dtype=np.float64, order="C") # initialize average parameters if self.average > 0: self._standard_coef = self.coef_ - self._standard_intercept = self.intercept_ self._average_coef = np.zeros(self.coef_.shape, dtype=np.float64, order="C") - self._average_intercept = np.zeros(self._standard_intercept.shape, - dtype=np.float64, - order="C") + if one_class: + self._standard_intercept = 1 - self.offset_ + else: + self._standard_intercept = self.intercept_ + + self._average_intercept = np.zeros( + self._standard_intercept.shape, dtype=np.float64, + order="C") def _make_validation_split(self, y): """Split the dataset between training set and validation set. @@ -439,7 +451,7 @@ def fit_binary(est, i, X, y, alpha, C, learning_rate, max_iter, est.early_stopping, validation_score_cb, int(est.n_iter_no_change), max_iter, tol, int(est.fit_intercept), int(est.verbose), int(est.shuffle), seed, pos_weight, neg_weight, learning_rate_type, - est.eta0, est.power_t, est.t_, intercept_decay, est.average) + est.eta0, est.power_t, 0, est.t_, intercept_decay, est.average) if est.average: if len(est.classes_) == 2: @@ -1342,7 +1354,7 @@ def _fit_regressor(self, X, y, alpha, C, loss, learning_rate, seed, 1.0, 1.0, learning_rate_type, - self.eta0, self.power_t, self.t_, + self.eta0, self.power_t, 0, self.t_, intercept_decay, self.average) self.t_ += self.n_iter_ * X.shape[0] @@ -1596,3 +1608,435 @@ def _more_tags(self): 'zero sample_weight is not equivalent to removing samples', } } + + +class SGDOneClassSVM(BaseSGD, OutlierMixin): + """Solves linear One-Class SVM using Stochastic Gradient Descent. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.0 + + Parameters + ---------- + nu : float, optional + The nu parameter of the One Class SVM: an upper bound on the + fraction of training errors and a lower bound of the fraction of + support vectors. Should be in the interval (0, 1]. By default 0.5 + will be taken. + + fit_intercept : bool + Whether the intercept should be estimated or not. Defaults to True. + + max_iter : int, optional + The maximum number of passes over the training data (aka epochs). + It only impacts the behavior in the ``fit`` method, and not the + `partial_fit`. 
Defaults to 1000. + + tol : float or None, optional + The stopping criterion. If it is not None, the iterations will stop + when (loss > previous_loss - tol). Defaults to 1e-3. + + shuffle : bool, optional + Whether or not the training data should be shuffled after each epoch. + Defaults to True. + + verbose : integer, optional + The verbosity level + + random_state : int, RandomState instance or None, optional (default=None) + The seed of the pseudo random number generator to use when shuffling + the data. If int, random_state is the seed used by the random number + generator; If RandomState instance, random_state is the random number + generator; If None, the random number generator is the RandomState + instance used by `np.random`. + + learning_rate : string, optional + The learning rate schedule: + + 'constant': + eta = eta0 + 'optimal': [default] + eta = 1.0 / (alpha * (t + t0)) + where t0 is chosen by a heuristic proposed by Leon Bottou. + 'invscaling': + eta = eta0 / pow(t, power_t) + 'adaptive': + eta = eta0, as long as the training keeps decreasing. + Each time n_iter_no_change consecutive epochs fail to decrease the + training loss by tol or fail to increase validation score by tol if + early_stopping is True, the current learning rate is divided by 5. + + eta0 : double + The initial learning rate for the 'constant', 'invscaling' or + 'adaptive' schedules. The default value is 0.0 as eta0 is not used by + the default schedule 'optimal'. + + power_t : double + The exponent for inverse scaling learning rate [default 0.5]. + + warm_start : bool, optional + When set to True, reuse the solution of the previous call to fit as + initialization, otherwise, just erase the previous solution. + See :term:`the Glossary `. + + Repeatedly calling fit or partial_fit when warm_start is True can + result in a different solution than when calling fit a single time + because of the way the data is shuffled. + If a dynamic learning rate is used, the learning rate is adapted + depending on the number of samples already seen. Calling ``fit`` resets + this counter, while ``partial_fit`` will result in increasing the + existing counter. + + average : bool or int, optional + When set to True, computes the averaged SGD weights and stores the + result in the ``coef_`` attribute. If set to an int greater than 1, + averaging will begin once the total number of samples seen reaches + average. So ``average=10`` will begin averaging after seeing 10 + samples. + + Attributes + ---------- + coef_ : array, shape (1, n_features) + Weights assigned to the features. + + offset_ : array, shape (1,) + Offset used to define the decision function from the raw scores. + We have the relation: decision_function = score_samples - offset. + + n_iter_ : int + The actual number of iterations to reach the stopping criterion. + + t_ : int + Number of weight updates performed during training. + Same as ``(n_iter_ * n_samples)``. 
+ + loss_function_ : concrete ``LossFunction`` + + Examples + -------- + >>> import numpy as np + >>> from sklearn import linear_model + >>> X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]]) + >>> clf = linear_model.SGDOneClassSVM(random_state=42) + >>> clf.fit(X) + SGDOneClassSVM(random_state=42) + + >>> print(clf.predict([[4, 4]])) + [1] + + See also + -------- + sklearn.svm.OneClassSVM + + """ + + loss_functions = {"hinge": (Hinge, 1.0)} + + def __init__(self, nu=0.5, fit_intercept=True, max_iter=1000, tol=1e-3, + shuffle=True, verbose=0, random_state=None, + learning_rate="optimal", eta0=0.0, power_t=0.5, + warm_start=False, average=False): + + alpha = nu / 2 + self.nu = nu + super(SGDOneClassSVM, self).__init__( + loss="hinge", penalty='l2', alpha=alpha, C=1.0, l1_ratio=0, + fit_intercept=fit_intercept, max_iter=max_iter, tol=tol, + shuffle=shuffle, verbose=verbose, epsilon=DEFAULT_EPSILON, + random_state=random_state, learning_rate=learning_rate, + eta0=eta0, power_t=power_t, early_stopping=False, + validation_fraction=0.1, n_iter_no_change=5, + warm_start=warm_start, average=average) + + def _validate_params(self, for_partial_fit=False): + """Validate input params. """ + if not(0 < self.nu <= 1): + raise ValueError("nu must be in (0, 1], got nu=%f" % self.nu) + + super(SGDOneClassSVM, self)._validate_params( + for_partial_fit=for_partial_fit) + + def _fit_one_class(self, X, alpha, C, sample_weight, + learning_rate, max_iter): + """Uses SGD implementation with X and y=np.ones(n_samples).""" + + # The One-Class SVM uses the SGD implementation with + # y=np.ones(n_samples). + n_samples = X.shape[0] + y = np.ones(n_samples, dtype=np.float64, order="C") + + dataset, offset_decay = make_dataset(X, y, sample_weight) + + penalty_type = self._get_penalty_type(self.penalty) + learning_rate_type = self._get_learning_rate_type(learning_rate) + + # early stopping is set to False for the One-Class SVM. thus + # validation_mask and validation_score_cb will be set to values + # associated to early_stopping=False in _make_validation_split and + # _make_validation_score_cb respectively. + validation_mask = self._make_validation_split(y) + validation_score_cb = self._make_validation_score_cb( + validation_mask, X, y, sample_weight) + + random_state = check_random_state(self.random_state) + # numpy mtrand expects a C long which is a signed 32 bit integer under + # Windows + seed = random_state.randint(0, np.iinfo(np.int32).max) + + tol = self.tol if self.tol is not None else -np.inf + + one_class = 1 + # There are no class weights for the One-Class SVM and they are + # therefore set to 1. 
+ pos_weight = 1 + neg_weight = 1 + + if self.average: + coef = self._standard_coef + intercept = self._standard_intercept + average_coef = self._average_coef + average_intercept = self._average_intercept + else: + coef = self.coef_ + intercept = 1 - self.offset_ + average_coef = None # Not used + average_intercept = [0] # Not used + + coef, intercept, average_coef, average_intercept, self.n_iter_ = \ + _plain_sgd(coef, + intercept[0], + average_coef, + average_intercept[0], + self.loss_function_, + penalty_type, + alpha, C, + self.l1_ratio, + dataset, + validation_mask, self.early_stopping, + validation_score_cb, + int(self.n_iter_no_change), + max_iter, tol, + int(self.fit_intercept), + int(self.verbose), + int(self.shuffle), + seed, + neg_weight, pos_weight, + learning_rate_type, + self.eta0, self.power_t, + one_class, self.t_, + offset_decay, self.average) + + self.t_ += self.n_iter_ * n_samples + + if self.average > 0: + + self._average_intercept = np.atleast_1d(average_intercept) + self._standard_intercept = np.atleast_1d(intercept) + + if self.average <= self.t_ - 1.0: + # made enough updates for averaging to be taken into account + self.coef_ = average_coef + self.offset_ = 1 - np.atleast_1d(average_intercept) + else: + self.coef_ = coef + self.offset_ = 1 - np.atleast_1d(intercept) + + else: + self.offset_ = 1 - np.atleast_1d(intercept) + + def _partial_fit(self, X, alpha, C, loss, learning_rate, max_iter, + sample_weight, coef_init, offset_init): + first_call = getattr(self, "coef_", None) is None + X = self._validate_data( + X, None, accept_sparse='csr', dtype=np.float64, + order="C", accept_large_sparse=False, + reset=first_call) + + n_features = X.shape[1] + + # Allocate datastructures from input arguments + sample_weight = _check_sample_weight(sample_weight, X) + + # We use intercept = 1 - offset where intercept is the intercept of + # the SGD implementation and offset is the offset of the One-Class SVM + # optimization problem. + if getattr(self, "coef_", None) is None or coef_init is not None: + self._allocate_parameter_mem(1, n_features, + coef_init, offset_init, 1) + elif n_features != self.coef_.shape[-1]: + raise ValueError("Number of features %d does not match previous " + "data %d." % (n_features, self.coef_.shape[-1])) + + if self.average and getattr(self, "_average_coef", None) is None: + self._average_coef = np.zeros(n_features, dtype=np.float64, + order="C") + self._average_intercept = np.zeros(1, dtype=np.float64, order="C") + + self.loss_function_ = self._get_loss_function(loss) + if not hasattr(self, "t_"): + self.t_ = 1.0 + + # delegate to concrete training procedure + self._fit_one_class(X, alpha=alpha, C=C, + learning_rate=learning_rate, + sample_weight=sample_weight, + max_iter=max_iter) + + return self + + def partial_fit(self, X, y=None, sample_weight=None): + """Fit linear One-Class SVM with Stochastic Gradient Descent. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Subset of the training data + + sample_weight : array-like, shape (n_samples,), optional + Weights applied to individual samples. + If not provided, uniform weights are assumed. + + Returns + ------- + self : returns an instance of self. 
+ """ + + alpha = self.nu / 2 + self._validate_params(for_partial_fit=True) + + return self._partial_fit(X, alpha, C=1.0, loss=self.loss, + learning_rate=self.learning_rate, + max_iter=1, + sample_weight=sample_weight, + coef_init=None, offset_init=None) + + def _fit(self, X, alpha, C, loss, learning_rate, coef_init=None, + offset_init=None, sample_weight=None): + self._validate_params() + + if self.warm_start and hasattr(self, "coef_"): + if coef_init is None: + coef_init = self.coef_ + if offset_init is None: + offset_init = self.offset_ + else: + self.coef_ = None + self.offset_ = None + + # Clear iteration count for multiple call to fit. + self.t_ = 1.0 + + self._partial_fit(X, alpha, C, loss, learning_rate, self.max_iter, + sample_weight, coef_init, offset_init) + + if (self.tol is not None and self.tol > -np.inf + and self.n_iter_ == self.max_iter): + warnings.warn("Maximum number of iteration reached before " + "convergence. Consider increasing max_iter to " + "improve the fit.", + ConvergenceWarning) + + return self + + def fit(self, X, y=None, coef_init=None, offset_init=None, + sample_weight=None): + """Fit linear One-Class SVM with Stochastic Gradient Descent. + + This solves an equivalent optimization problem of the + One-Class SVM primal optimization problem and returns a weight vector + w and an offset rho such that the decision function is given by + - rho. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Training data + + coef_init : array, shape (n_classes, n_features) + The initial coefficients to warm-start the optimization. + + offset_init : array, shape (n_classes,) + The initial offset to warm-start the optimization + + sample_weight : array-like, shape (n_samples,), optional + Weights applied to individual samples. + If not provided, uniform weights are assumed. These weights will + be multiplied with class_weight (passed through the + constructor) if class_weight is specified. + + Returns + ------- + self : returns an instance of self. + """ + + alpha = self.nu / 2 + self._fit(X, alpha=alpha, C=1.0, + loss=self.loss, learning_rate=self.learning_rate, + coef_init=coef_init, offset_init=offset_init, + sample_weight=sample_weight) + + return self + + def decision_function(self, X): + """Signed distance to the separating hyperplane. + + Signed distance is positive for an inlier and negative for an + outlier. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + Returns + ------- + dec : array-like, shape (n_samples,) + Decision function values of the samples. + """ + + check_is_fitted(self, "coef_") + + X = self._validate_data(X, accept_sparse='csr', reset=False) + decisions = safe_sparse_dot(X, self.coef_.T, + dense_output=True) - self.offset_ + + return decisions.ravel() + + def score_samples(self, X): + """Raw scoring function of the samples + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + Returns + ------- + score_samples : array-like, shape (n_samples,) + Unshiffted scoring function values of the samples. + """ + score_samples = self.decision_function(X) + self.offset_ + return score_samples + + def predict(self, X): + """Return labels (1 inlier, -1 outlier) of the samples + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + + Returns + ------- + y : array, shape (n_samples,) + Labels of the samples. 
+ """ + y = (self.decision_function(X) >= 0).astype(int) + y[y == 0] = -1 # for consistency with outlier detectors + return y + + def _more_tags(self): + return { + '_xfail_checks': { + 'check_sample_weights_invariance': + 'zero sample_weight is not equivalent to removing samples', + } + } diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index d5063981ff9aa..ddffa9db563e9 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -16,8 +16,11 @@ from sklearn import linear_model, datasets, metrics from sklearn.base import clone, is_classifier +from sklearn.svm import OneClassSVM from sklearn.preprocessing import LabelEncoder, scale, MinMaxScaler from sklearn.preprocessing import StandardScaler +from sklearn.kernel_approximation import Nystroem +from sklearn.pipeline import make_pipeline from sklearn.exceptions import ConvergenceWarning from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit from sklearn.linear_model import _sgd_fast as sgd_fast @@ -68,6 +71,21 @@ def decision_function(self, X, *args, **kw): **kw) +class _SparseSGDOneClassSVM(linear_model.SGDOneClassSVM): + def fit(self, X, *args, **kw): + X = sp.csr_matrix(X) + return linear_model.SGDOneClassSVM.fit(self, X, *args, **kw) + + def partial_fit(self, X, *args, **kw): + X = sp.csr_matrix(X) + return linear_model.SGDOneClassSVM.partial_fit(self, X, *args, **kw) + + def decision_function(self, X, *args, **kw): + X = sp.csr_matrix(X) + return linear_model.SGDOneClassSVM.decision_function(self, X, *args, + **kw) + + def SGDClassifier(**kwargs): _update_kwargs(kwargs) return linear_model.SGDClassifier(**kwargs) @@ -78,6 +96,11 @@ def SGDRegressor(**kwargs): return linear_model.SGDRegressor(**kwargs) +def SGDOneClassSVM(**kwargs): + _update_kwargs(kwargs) + return linear_model.SGDOneClassSVM(**kwargs) + + def SparseSGDClassifier(**kwargs): _update_kwargs(kwargs) return _SparseSGDClassifier(**kwargs) @@ -88,6 +111,11 @@ def SparseSGDRegressor(**kwargs): return _SparseSGDRegressor(**kwargs) +def SparseSGDOneClassSVM(**kwargs): + _update_kwargs(kwargs) + return _SparseSGDOneClassSVM(**kwargs) + + # Test Data # test sample 1 @@ -250,7 +278,8 @@ def test_clone(klass): @pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, - SGDRegressor, SparseSGDRegressor]) + SGDRegressor, SparseSGDRegressor, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_plain_has_no_average_attr(klass): clf = klass(average=True, eta0=.01) clf.fit(X, Y) @@ -283,7 +312,8 @@ def test_sgd_deprecated_attr(klass): @pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, - SGDRegressor, SparseSGDRegressor]) + SGDRegressor, SparseSGDRegressor, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_late_onset_averaging_not_reached(klass): clf1 = klass(average=600) clf2 = klass() @@ -296,7 +326,11 @@ def test_late_onset_averaging_not_reached(klass): clf2.partial_fit(X, Y) assert_array_almost_equal(clf1.coef_, clf2.coef_, decimal=16) - assert_almost_equal(clf1.intercept_, clf2.intercept_, decimal=16) + if klass in [SGDClassifier, SparseSGDClassifier, SGDRegressor, + SparseSGDRegressor]: + assert_almost_equal(clf1.intercept_, clf2.intercept_, decimal=16) + elif klass in [SGDOneClassSVM, SparseSGDOneClassSVM]: + assert_almost_equal(clf1.offset_, clf2.offset_, decimal=16) @pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, @@ -441,26 +475,30 @@ def test_sgd_bad_l1_ratio(klass): assert_raises(ValueError, klass, l1_ratio=1.1) 
-@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) +@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_sgd_bad_learning_rate_schedule(klass): # Check whether expected ValueError on bad learning_rate assert_raises(ValueError, klass, learning_rate="") -@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) +@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_sgd_bad_eta0(klass): # Check whether expected ValueError on bad eta0 assert_raises(ValueError, klass, eta0=0, learning_rate="constant") -@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) +@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_sgd_max_iter_param(klass): # Test parameter validity check assert_raises(ValueError, klass, max_iter=-10000) -@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) +@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_sgd_shuffle_param(klass): # Test parameter validity check assert_raises(ValueError, klass, shuffle="false") @@ -484,14 +522,16 @@ def test_sgd_n_iter_no_change(klass): assert_raises(ValueError, klass, n_iter_no_change=0) -@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) +@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_argument_coef(klass): # Checks coef_init not allowed as model argument (only fit) # Provided coef_ does not match dataset assert_raises(TypeError, klass, coef_init=np.zeros((3,))) -@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) +@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_provide_coef(klass): # Checks coef_init shape for the warm starts # Provided coef_ does not match dataset. @@ -499,12 +539,17 @@ def test_provide_coef(klass): X, Y, coef_init=np.zeros((3,))) -@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) +@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_set_intercept(klass): # Checks intercept_ shape for the warm starts # Provided intercept_ does not match dataset. 
- assert_raises(ValueError, klass().fit, - X, Y, intercept_init=np.zeros((3,))) + if klass in [SGDClassifier, SparseSGDClassifier]: + assert_raises(ValueError, klass().fit, + X, Y, intercept_init=np.zeros((3,))) + elif klass in [SGDOneClassSVM, SparseSGDOneClassSVM]: + assert_raises(ValueError, klass().fit, + X, Y, offset_init=np.zeros((3,))) @pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) @@ -931,10 +976,14 @@ def test_sample_weights(klass): assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([-1])) -@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) +@pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier, + SGDOneClassSVM, SparseSGDOneClassSVM]) def test_wrong_sample_weights(klass): # Test if ValueError is raised if sample_weight has wrong shape - clf = klass(alpha=0.1, max_iter=1000, fit_intercept=False) + if klass in [SGDClassifier, SparseSGDClassifier]: + clf = klass(alpha=0.1, max_iter=1000, fit_intercept=False) + elif klass in [SGDOneClassSVM, SparseSGDOneClassSVM]: + clf = klass(nu=0.1, max_iter=1000, fit_intercept=False) # provided sample_weight too long assert_raises(ValueError, clf.fit, X, Y, sample_weight=np.arange(7)) @@ -1323,6 +1372,303 @@ def test_loss_function_epsilon(klass): assert clf.loss_functions['huber'][1] == 0.1 +############################################################################### +# SGD One Class SVM Test Case + +# a simple implementation of ASGD to use for testing SGDOneClassSVM +def asgd_oneclass(klass, X, eta, nu, coef_init=None, offset_init=0.0): + if coef_init is None: + coef = np.zeros(X.shape[1]) + else: + coef = coef_init + + average_coef = np.zeros(X.shape[1]) + offset = offset_init + intercept = 1 - offset + average_intercept = 0.0 + decay = 1.0 + + # sparse data has a fixed decay of .01 + if klass == SparseSGDOneClassSVM: + decay = .01 + + for i, entry in enumerate(X): + p = np.dot(entry, coef) + p += intercept + if p <= 1.0: + gradient = -1 + else: + gradient = 0 + coef *= max(0, 1.0 - (eta * nu / 2)) + coef += -(eta * gradient * entry) + intercept += -(eta * (nu + gradient)) * decay + + average_coef *= i + average_coef += coef + average_coef /= i + 1.0 + + average_intercept *= i + average_intercept += intercept + average_intercept /= i + 1.0 + + return average_coef, 1 - average_intercept + + +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +@pytest.mark.parametrize('nu', [-0.5, 2]) +def test_bad_nu_values(klass, nu): + msg = r"nu must be in \(0, 1]" + assert_raises_regexp(ValueError, msg, klass, nu=nu) + + clf = klass(nu=0.05) + clf2 = clone(clf) + assert_raises_regexp(ValueError, msg, clf2.set_params, nu=nu) + + +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +def _test_warm_start_oneclass(klass, X, lr): + # Test that explicit warm restart... + clf = klass(nu=0.5, eta0=0.01, shuffle=False, + learning_rate=lr) + clf.fit(X) + + clf2 = klass(nu=0.1, eta0=0.01, shuffle=False, + learning_rate=lr) + clf2.fit(X, coef_init=clf.coef_.copy(), + offset_init=clf.offset_.copy()) + + # ... and implicit warm restart are equivalent. 
+ clf3 = klass(nu=0.5, eta0=0.01, shuffle=False, + warm_start=True, learning_rate=lr) + clf3.fit(X) + + assert clf3.t_ == clf.t_ + assert_array_almost_equal(clf3.coef_, clf.coef_) + + clf3.set_params(nu=0.1) + clf3.fit(X) + + assert clf3.t_ == clf2.t_ + assert_array_almost_equal(clf3.coef_, clf2.coef_) + + +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +@pytest.mark.parametrize('lr', + ["constant", "optimal", "invscaling", "adaptive"]) +def test_warm_start_oneclass(klass, lr): + _test_warm_start_oneclass(klass, X, lr) + + +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +def test_clone_oneclass(klass): + # Test whether clone works ok. + clf = klass(nu=0.5) + clf = clone(clf) + clf.set_params(nu=0.1) + clf.fit(X) + + clf2 = klass(nu=0.1) + clf2.fit(X) + + assert_array_equal(clf.coef_, clf2.coef_) + + +@ignore_warnings +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +def test_partial_fit_oneclass(klass): + third = X.shape[0] // 3 + clf = klass(nu=0.1) + + clf.partial_fit(X[:third]) + assert clf.coef_.shape == (X.shape[1], ) + assert clf.offset_.shape == (1,) + assert clf.predict([[0, 0]]).shape == (1, ) + id1 = id(clf.coef_.data) + + clf.partial_fit(X[third:]) + id2 = id(clf.coef_.data) + # check that coef_ haven't been re-allocated + assert id1 == id2 + + # raises ValueError if number of features does not match previous data + assert_raises(ValueError, clf.partial_fit, X[:, 1]) + + +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +@pytest.mark.parametrize('lr', + ["constant", "optimal", "invscaling", "adaptive"]) +def test_partial_fit_equal_fit_oneclass(klass, lr): + clf = klass(nu=0.05, max_iter=2, eta0=0.01, + learning_rate=lr, shuffle=False) + clf.fit(X) + y_scores = clf.decision_function(T) + t = clf.t_ + coef = clf.coef_ + offset = clf.offset_ + + clf = klass(nu=0.05, eta0=0.01, max_iter=1, + learning_rate=lr, shuffle=False) + for _ in range(2): + clf.partial_fit(X) + y_scores2 = clf.decision_function(T) + + assert clf.t_ == t + assert_array_almost_equal(y_scores, y_scores2) + assert_array_almost_equal(clf.coef_, coef) + assert_array_almost_equal(clf.offset_, offset) + + +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +def test_late_onset_averaging_reached_oneclass(klass): + # Test average + eta0 = .001 + nu = .05 + + # 2 passes over the training set but average only at second pass + clf1 = klass(average=7, learning_rate="constant", eta0=eta0, + nu=nu, max_iter=2, shuffle=False) + # 1 pass over the training set with no averaging + clf2 = klass(average=0, learning_rate="constant", eta0=eta0, + nu=nu, max_iter=1, shuffle=False) + + clf1.fit(X) + clf2.fit(X) + + # Start from clf2 solution, compute averaging using asgd function and + # compare with clf1 solution + average_coef, average_offset = \ + asgd_oneclass(klass, X, eta0, nu, + coef_init=clf2.coef_.ravel(), + offset_init=clf2.offset_) + + assert_array_almost_equal(clf1.coef_.ravel(), + average_coef.ravel(), + decimal=16) + assert_almost_equal(clf1.offset_, average_offset, decimal=15) + + +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +def test_sgd_averaged_computed_correctly_oneclass(klass): + # Tests the average SGD One-Class SVM matches the naive implementation + eta = .001 + nu = .05 + n_samples = 20 + n_features = 10 + rng = np.random.RandomState(0) + X = rng.normal(size=(n_samples, n_features)) + + clf = klass(learning_rate='constant', + eta0=eta, nu=nu, + 
fit_intercept=True, + max_iter=1, average=True, shuffle=False) + + clf.fit(X) + average_coef, average_offset = asgd_oneclass(klass, X, eta, nu) + + assert_array_almost_equal(clf.coef_, average_coef, decimal=16) + assert_almost_equal(clf.offset_, average_offset, decimal=15) + + +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +def test_sgd_averaged_partial_fit_oneclass(klass): + # Tests whether the partial fit yields the same average as the fit + eta = .001 + nu = .05 + n_samples = 20 + n_features = 10 + rng = np.random.RandomState(0) + X = rng.normal(size=(n_samples, n_features)) + + clf = klass(learning_rate='constant', + eta0=eta, nu=nu, + fit_intercept=True, + max_iter=1, average=True, shuffle=False) + + clf.partial_fit(X[:int(n_samples / 2)][:]) + clf.partial_fit(X[int(n_samples / 2):][:]) + average_coef, average_offset = asgd_oneclass(klass, X, eta, nu) + + assert_array_almost_equal(clf.coef_, average_coef, decimal=16) + assert_almost_equal(clf.offset_, average_offset, decimal=15) + + +@pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) +def test_average_sparse_oneclass(klass): + # Checks the average coef on data with 0s + eta = .001 + nu = .01 + clf = klass(learning_rate='constant', + eta0=eta, nu=nu, + fit_intercept=True, + max_iter=1, average=True, shuffle=False) + + n_samples = X3.shape[0] + + clf.partial_fit(X3[:int(n_samples / 2)]) + clf.partial_fit(X3[int(n_samples / 2):]) + average_coef, average_offset = asgd_oneclass(klass, X3, eta, nu) + + assert_array_almost_equal(clf.coef_, average_coef, decimal=16) + assert_almost_equal(clf.offset_, average_offset, decimal=15) + + +def test_sgd_oneclass(): + # Test fit, decision_function, predict and score_samples on a toy + # dataset + X_train = np.array([[-2, -1], [-1, -1], [1, 1]]) + X_test = np.array([[0.5, -2], [2, 2]]) + clf = SGDOneClassSVM(nu=0.5, eta0=1, learning_rate='constant', + shuffle=False, max_iter=1) + clf.fit(X_train) + assert_array_equal(clf.coef_, np.array([-0.125, 0.4375])) + assert clf.offset_[0] == -0.5 + + scores = clf.score_samples(X_test) + assert_array_equal(scores, np.array([-0.9375, 0.625])) + + dec = clf.score_samples(X_test) - clf.offset_ + assert_array_equal(clf.decision_function(X_test), dec) + + pred = clf.predict(X_test) + assert_array_equal(pred, np.array([-1, 1])) + + +def test_ocsvm_vs_sgdocsvm(): + # Checks SGDOneClass SVM gives a good approximation of kernelized + # One-Class SVM + nu = 0.05 + gamma = 2. 
+ random_state = 42 + + # Generate train and test data + rng = np.random.RandomState(random_state) + X = 0.3 * rng.randn(500, 2) + X_train = np.r_[X + 2, X - 2] + X = 0.3 * rng.randn(100, 2) + X_test = np.r_[X + 2, X - 2] + + # One-Class SVM + clf = OneClassSVM(gamma=gamma, kernel='rbf', nu=nu) + clf.fit(X_train) + y_pred_ocsvm = clf.predict(X_test) + dec_ocsvm = clf.decision_function(X_test).reshape(1, -1) + + # SGDOneClassSVM using kernel approximation + max_iter = 15 + transform = Nystroem(gamma=gamma, random_state=random_state) + clf_sgd = SGDOneClassSVM(nu=nu, shuffle=True, fit_intercept=True, + max_iter=max_iter, random_state=random_state, + tol=-np.inf) + pipe_sgd = make_pipeline(transform, clf_sgd) + pipe_sgd.fit(X_train) + y_pred_sgdocsvm = pipe_sgd.predict(X_test) + dec_sgdocsvm = pipe_sgd.decision_function(X_test).reshape(1, -1) + + corrcoef = np.corrcoef(np.concatenate((dec_ocsvm, dec_sgdocsvm)))[0, 1] + assert np.mean(y_pred_sgdocsvm == y_pred_ocsvm) >= 0.99 + assert corrcoef >= 0.9 + + def test_l1_ratio(): # Test if l1 ratio extremes match L1 and L2 penalty settings. X, y = datasets.make_classification(n_samples=1000, From 7ca673e244e28e4e655ab33294b30f24434818f6 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Thu, 7 Jan 2021 22:39:44 +0100 Subject: [PATCH 02/16] trying test with almost equal --- sklearn/linear_model/tests/test_sgd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index ddffa9db563e9..fd9b5a6a5e7ce 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -1620,14 +1620,14 @@ def test_sgd_oneclass(): clf = SGDOneClassSVM(nu=0.5, eta0=1, learning_rate='constant', shuffle=False, max_iter=1) clf.fit(X_train) - assert_array_equal(clf.coef_, np.array([-0.125, 0.4375])) + assert_array_almost_equal(clf.coef_, np.array([-0.125, 0.4375])) assert clf.offset_[0] == -0.5 scores = clf.score_samples(X_test) - assert_array_equal(scores, np.array([-0.9375, 0.625])) + assert_array_almost_equal(scores, np.array([-0.9375, 0.625])) dec = clf.score_samples(X_test) - clf.offset_ - assert_array_equal(clf.decision_function(X_test), dec) + assert_array_almost_equal(clf.decision_function(X_test), dec) pred = clf.predict(X_test) assert_array_equal(pred, np.array([-1, 1])) From e1b0ffcf3dde365978c2f8be59353d4e8bc61268 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Fri, 8 Jan 2021 21:58:46 +0100 Subject: [PATCH 03/16] suggestions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tom Dupré la Tour --- sklearn/linear_model/_stochastic_gradient.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index aa9df78dda6cf..8ce1c584f23a2 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -1891,7 +1891,7 @@ def partial_fit(self, X, y=None, sample_weight=None): Parameters ---------- X : {array-like, sparse matrix}, shape (n_samples, n_features) - Subset of the training data + Subset of the training data. sample_weight : array-like, shape (n_samples,), optional Weights applied to individual samples. @@ -1951,13 +1951,13 @@ def fit(self, X, y=None, coef_init=None, offset_init=None, Parameters ---------- X : {array-like, sparse matrix}, shape (n_samples, n_features) - Training data + Training data. 
coef_init : array, shape (n_classes, n_features) The initial coefficients to warm-start the optimization. offset_init : array, shape (n_classes,) - The initial offset to warm-start the optimization + The initial offset to warm-start the optimization. sample_weight : array-like, shape (n_samples,), optional Weights applied to individual samples. @@ -1986,7 +1986,8 @@ def decision_function(self, X): Parameters ---------- - X : array-like, shape (n_samples, n_features) + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Testing data. Returns ------- @@ -2003,11 +2004,12 @@ def decision_function(self, X): return decisions.ravel() def score_samples(self, X): - """Raw scoring function of the samples + """Raw scoring function of the samples. Parameters ---------- - X : array-like, shape (n_samples, n_features) + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Testing data. Returns ------- @@ -2018,11 +2020,12 @@ def score_samples(self, X): return score_samples def predict(self, X): - """Return labels (1 inlier, -1 outlier) of the samples + """Return labels (1 inlier, -1 outlier) of the samples. Parameters ---------- - X : array-like, shape (n_samples, n_features) + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Testing data. Returns ------- From 98fd4b3ed0117f3fc534e0ef7daf792279368028 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Fri, 8 Jan 2021 22:20:15 +0100 Subject: [PATCH 04/16] review --- sklearn/linear_model/_stochastic_gradient.py | 4 ++++ sklearn/svm/_classes.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index 8ce1c584f23a2..3b97e88884270 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -1613,6 +1613,10 @@ def _more_tags(self): class SGDOneClassSVM(BaseSGD, OutlierMixin): """Solves linear One-Class SVM using Stochastic Gradient Descent. + This implementation is meant to be used with a kernel approximation + technique to obtain results similar to `sklearn.svm.OneClassSVM` which uses + a Gaussian kernel by default. + Read more in the :ref:`User Guide `. .. 
versionadded:: 1.0 diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index 908ece408bb1d..c402779f4eeb6 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -1334,6 +1334,10 @@ class OneClassSVM(OutlierMixin, BaseLibSVM): array([-1, 1, 1, 1, -1]) >>> clf.score_samples(X) array([1.7798..., 2.0547..., 2.0556..., 2.0561..., 1.7332...]) + + See also + -------- + sklearn.linear_model.SGDOneClassSVM """ _impl = 'one_class' From 3d03381685df87b2d5b0f0b2d5868ff26c4fc9a4 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 16 Mar 2021 20:00:47 +0100 Subject: [PATCH 05/16] Fix assert_raises => pytest.raises --- sklearn/linear_model/tests/test_sgd.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index d4ce88f7fd270..9ce500369d659 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -8,7 +8,6 @@ from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_almost_equal -from sklearn.utils._testing import assert_raises_regexp from sklearn.utils._testing import ignore_warnings from sklearn.utils.fixes import parse_version @@ -634,10 +633,8 @@ def test_partial_fit_weight_class_balanced(klass): r"estimate the class frequency distributions\. " r"Pass the resulting weights as the class_weight " r"parameter\.") - assert_raises_regexp(ValueError, - regex, - klass(class_weight='balanced').partial_fit, - X, Y, classes=np.unique(Y)) + with pytest.raises(ValueError, match=regex): + klass(class_weight='balanced').partial_fit(X, Y, classes=np.unique(Y)) @pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier]) @@ -1433,11 +1430,13 @@ def asgd_oneclass(klass, X, eta, nu, coef_init=None, offset_init=0.0): @pytest.mark.parametrize('nu', [-0.5, 2]) def test_bad_nu_values(klass, nu): msg = r"nu must be in \(0, 1]" - assert_raises_regexp(ValueError, msg, klass, nu=nu) + with pytest.raises(ValueError, match=msg): + klass(nu=nu) clf = klass(nu=0.05) clf2 = clone(clf) - assert_raises_regexp(ValueError, msg, clf2.set_params, nu=nu) + with pytest.raises(ValueError, match=msg): + clf2.set_params(nu=nu) @pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) @@ -1506,7 +1505,8 @@ def test_partial_fit_oneclass(klass): assert id1 == id2 # raises ValueError if number of features does not match previous data - assert_raises(ValueError, clf.partial_fit, X[:, 1]) + with pytest.raises(ValueError): + clf.partial_fit(X[:, 1]) @pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) @@ -1739,7 +1739,8 @@ def test_underflow_or_overlow(): msg_regxp = (r"Floating-point under-/overflow occurred at epoch #.*" " Scaling input data with StandardScaler or MinMaxScaler" " might help.") - assert_raises_regexp(ValueError, msg_regxp, model.fit, X, y) + with pytest.raises(ValueError, match=msg_regxp): + model.fit(X, y) def test_numerical_stability_large_gradient(): From 6c4d42b6be5a29c7ae9305587c6e06898c24ce41 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Thu, 18 Mar 2021 22:41:26 +0100 Subject: [PATCH 06/16] Update doc/modules/sgd.rst Co-authored-by: Olivier Grisel --- doc/modules/sgd.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/modules/sgd.rst b/doc/modules/sgd.rst index df6eab3acb783..7050192b986ec 100644 --- a/doc/modules/sgd.rst +++ b/doc/modules/sgd.rst @@ -246,7 +246,8 
@@ solution of a kernelized One-Class SVM, implemented in samples. Note that the complexity of a kernelized One-Class SVM is at best quadratic in the number of samples. :class:`sklearn.linear_model.SGDOneClassSVM` is thus well suited for datasets -with a large number of training samples (> 10.000). +with a large number of training samples (> 10,000) for which the SGD +variant can be several orders of magnitude faster. Its implementation is based on the implementation of the stochastic gradient descent. Indeed, the original optimization problem of the One-Class From 4ef8462a744562b6d4cf36488e2a9723f4f0d561 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Thu, 18 Mar 2021 22:47:54 +0100 Subject: [PATCH 07/16] Update sklearn/linear_model/tests/test_sgd.py Co-authored-by: Olivier Grisel --- sklearn/linear_model/tests/test_sgd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index 9ce500369d659..bb35072d109ef 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -1679,8 +1679,8 @@ def test_ocsvm_vs_sgdocsvm(): y_pred_sgdocsvm = pipe_sgd.predict(X_test) dec_sgdocsvm = pipe_sgd.decision_function(X_test).reshape(1, -1) - corrcoef = np.corrcoef(np.concatenate((dec_ocsvm, dec_sgdocsvm)))[0, 1] assert np.mean(y_pred_sgdocsvm == y_pred_ocsvm) >= 0.99 + corrcoef = np.corrcoef(np.concatenate((dec_ocsvm, dec_sgdocsvm)))[0, 1] assert corrcoef >= 0.9 From d8038967735a29abb45de376e895e2bdc67a0830 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Thu, 18 Mar 2021 23:20:32 +0100 Subject: [PATCH 08/16] avoid insert in example --- .../miscellaneous/plot_anomaly_comparison.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/miscellaneous/plot_anomaly_comparison.py b/examples/miscellaneous/plot_anomaly_comparison.py index dc0d9f23f8315..c0c3a4f890923 100644 --- a/examples/miscellaneous/plot_anomaly_comparison.py +++ b/examples/miscellaneous/plot_anomaly_comparison.py @@ -90,23 +90,22 @@ n_outliers = int(outliers_fraction * n_samples) n_inliers = n_samples - n_outliers -# define outlier/anomaly detection methods to be compared +# define outlier/anomaly detection methods to be compared. 
+# the SGDOneClassSVM must be used in a pipeline with a kernel approximation +# to give similar results to the OneClassSVM anomaly_algorithms = [ ("Robust covariance", EllipticEnvelope(contamination=outliers_fraction)), ("One-Class SVM", svm.OneClassSVM(nu=outliers_fraction, kernel="rbf", gamma=0.1)), + ("One-Class SVM (SGD)", make_pipeline( + Nystroem(gamma=0.1, random_state=42, n_components=150), + SGDOneClassSVM(nu=outliers_fraction, shuffle=True, + fit_intercept=True, random_state=42, tol=1e-6) + )), ("Isolation Forest", IsolationForest(contamination=outliers_fraction, random_state=42)), ("Local Outlier Factor", LocalOutlierFactor( n_neighbors=35, contamination=outliers_fraction))] -# SGDOneClassSVM must be used with a kernel approximation to give similar -# results to the OneClassSVM -transform = Nystroem(gamma=0.1, random_state=42, n_components=150) -clf_sgd = SGDOneClassSVM(nu=outliers_fraction, shuffle=True, - fit_intercept=True, random_state=42, tol=1e-6) -pipe_sgd = make_pipeline(transform, clf_sgd) -anomaly_algorithms.insert(2, ("One-Class SVM (SGD)", pipe_sgd)) - # Define datasets blobs_params = dict(random_state=0, n_samples=n_inliers, n_features=2) From 731b6867fc59e812f41e2c460435ca66909ea75c Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Fri, 19 Mar 2021 09:28:47 +0100 Subject: [PATCH 09/16] whats_new entry --- doc/whats_new/v1.0.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 89280c7f01d0d..87a3951d990ab 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -113,6 +113,13 @@ Changelog :mod:`sklearn.linear_model` ........................... +- |Feature| The new :class:`linear_model.SGDOneClassSVM` provides an SGD + implementation of the linear One-Class SVM. Combined with kernel + approximation techniques, this implementation approximates the solution of + a kernelized One Class SVM while benefitting from a linear + complexity in the number of samples. + :pr:`10027` by :user:`Albert Thomas `. + - |Efficiency| The implementation of :class:`linear_model.LogisticRegression` has been optimised for dense matrices when using `solver='newton-cg'` and `multi_class!='multinomial'`. From 08d2721627af31c6131b1f744767a12839ee0705 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Sat, 20 Mar 2021 00:21:19 +0100 Subject: [PATCH 10/16] add note --- sklearn/linear_model/_stochastic_gradient.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index 47358dbd37c07..f2143f7531ad2 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -1743,6 +1743,12 @@ class SGDOneClassSVM(BaseSGD, OutlierMixin): -------- sklearn.svm.OneClassSVM + Notes + ----- + This estimator has a linear complexity in the number of training samples + and is thus better suited than the `sklearn.svm.OneClassSVM` + implementation for datasets with a large number of training samples (say + > 10,000). 
""" loss_functions = {"hinge": (Hinge, 1.0)} From d3f1a7d8dc4ef91ece38bd97e86c09de72c4f59a Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Sat, 20 Mar 2021 00:26:18 +0100 Subject: [PATCH 11/16] int32 --- sklearn/linear_model/_stochastic_gradient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index f2143f7531ad2..fb13abf222595 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -2052,7 +2052,7 @@ def predict(self, X): y : array, shape (n_samples,) Labels of the samples. """ - y = (self.decision_function(X) >= 0).astype(int) + y = (self.decision_function(X) >= 0).astype(np.int32) y[y == 0] = -1 # for consistency with outlier detectors return y From 715e6b6228c22c117fdae80593a9c2bb88638acd Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Sun, 21 Mar 2021 22:51:07 +0100 Subject: [PATCH 12/16] explicit ddocstring --- sklearn/linear_model/_stochastic_gradient.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index fb13abf222595..981e48d2d160e 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -1624,8 +1624,9 @@ class SGDOneClassSVM(BaseSGD, OutlierMixin): """Solves linear One-Class SVM using Stochastic Gradient Descent. This implementation is meant to be used with a kernel approximation - technique to obtain results similar to `sklearn.svm.OneClassSVM` which uses - a Gaussian kernel by default. + technique (e.g. `sklearn.kernel_approximation.Nystroem`) to obtain results + similar to `sklearn.svm.OneClassSVM` which uses a Gaussian kernel by + default. Read more in the :ref:`User Guide `. 
From ba53ef236c8cbec46273a9e7b30e243ec01e813c Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Sun, 21 Mar 2021 23:11:45 +0100 Subject: [PATCH 13/16] rm useless ignore warnings in tests --- sklearn/linear_model/tests/test_sgd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index bb35072d109ef..e5d1ca15bfd7e 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -1487,7 +1487,6 @@ def test_clone_oneclass(klass): assert_array_equal(clf.coef_, clf2.coef_) -@ignore_warnings @pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) def test_partial_fit_oneclass(klass): third = X.shape[0] // 3 From 47e5978397254e3dfb389bc56c130112ee224ae4 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Mon, 22 Mar 2021 14:48:10 +0100 Subject: [PATCH 14/16] use assert_allclose --- sklearn/linear_model/tests/test_sgd.py | 35 +++++++++++++------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index e5d1ca15bfd7e..5e0f778096aa2 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -5,6 +5,7 @@ import scipy.sparse as sp import joblib +from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_almost_equal @@ -1457,13 +1458,13 @@ def _test_warm_start_oneclass(klass, X, lr): clf3.fit(X) assert clf3.t_ == clf.t_ - assert_array_almost_equal(clf3.coef_, clf.coef_) + assert_allclose(clf3.coef_, clf.coef_) clf3.set_params(nu=0.1) clf3.fit(X) assert clf3.t_ == clf2.t_ - assert_array_almost_equal(clf3.coef_, clf2.coef_) + assert_allclose(clf3.coef_, clf2.coef_) @pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) @@ -1527,9 +1528,9 @@ def test_partial_fit_equal_fit_oneclass(klass, lr): y_scores2 = clf.decision_function(T) assert clf.t_ == t - assert_array_almost_equal(y_scores, y_scores2) - assert_array_almost_equal(clf.coef_, coef) - assert_array_almost_equal(clf.offset_, offset) + assert_allclose(y_scores, y_scores2) + assert_allclose(clf.coef_, coef) + assert_allclose(clf.offset_, offset) @pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) @@ -1555,10 +1556,8 @@ def test_late_onset_averaging_reached_oneclass(klass): coef_init=clf2.coef_.ravel(), offset_init=clf2.offset_) - assert_array_almost_equal(clf1.coef_.ravel(), - average_coef.ravel(), - decimal=16) - assert_almost_equal(clf1.offset_, average_offset, decimal=15) + assert_allclose(clf1.coef_.ravel(), average_coef.ravel()) + assert_allclose(clf1.offset_, average_offset) @pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) @@ -1579,8 +1578,8 @@ def test_sgd_averaged_computed_correctly_oneclass(klass): clf.fit(X) average_coef, average_offset = asgd_oneclass(klass, X, eta, nu) - assert_array_almost_equal(clf.coef_, average_coef, decimal=16) - assert_almost_equal(clf.offset_, average_offset, decimal=15) + assert_allclose(clf.coef_, average_coef) + assert_allclose(clf.offset_, average_offset) @pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) @@ -1602,8 +1601,8 @@ def test_sgd_averaged_partial_fit_oneclass(klass): clf.partial_fit(X[int(n_samples / 2):][:]) average_coef, average_offset = asgd_oneclass(klass, X, eta, nu) - 
assert_array_almost_equal(clf.coef_, average_coef, decimal=16) - assert_almost_equal(clf.offset_, average_offset, decimal=15) + assert_allclose(clf.coef_, average_coef) + assert_allclose(clf.offset_, average_offset) @pytest.mark.parametrize('klass', [SGDOneClassSVM, SparseSGDOneClassSVM]) @@ -1622,8 +1621,8 @@ def test_average_sparse_oneclass(klass): clf.partial_fit(X3[int(n_samples / 2):]) average_coef, average_offset = asgd_oneclass(klass, X3, eta, nu) - assert_array_almost_equal(clf.coef_, average_coef, decimal=16) - assert_almost_equal(clf.offset_, average_offset, decimal=15) + assert_allclose(clf.coef_, average_coef) + assert_allclose(clf.offset_, average_offset) def test_sgd_oneclass(): @@ -1634,14 +1633,14 @@ def test_sgd_oneclass(): clf = SGDOneClassSVM(nu=0.5, eta0=1, learning_rate='constant', shuffle=False, max_iter=1) clf.fit(X_train) - assert_array_almost_equal(clf.coef_, np.array([-0.125, 0.4375])) + assert_allclose(clf.coef_, np.array([-0.125, 0.4375])) assert clf.offset_[0] == -0.5 scores = clf.score_samples(X_test) - assert_array_almost_equal(scores, np.array([-0.9375, 0.625])) + assert_allclose(scores, np.array([-0.9375, 0.625])) dec = clf.score_samples(X_test) - clf.offset_ - assert_array_almost_equal(clf.decision_function(X_test), dec) + assert_allclose(clf.decision_function(X_test), dec) pred = clf.predict(X_test) assert_array_equal(pred, np.array([-1, 1])) From a663430b934b7a696906295cd899fe1e66f8b9fb Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Mon, 22 Mar 2021 15:19:53 +0100 Subject: [PATCH 15/16] remove assert_allclose import --- sklearn/linear_model/tests/test_sgd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index 728fb9656ff8f..ee23445ff2e6c 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -6,7 +6,6 @@ import scipy.sparse as sp import joblib -from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_almost_equal From a1a70f729a05a6c23034a24795ce6fd06ae4cd70 Mon Sep 17 00:00:00 2001 From: Albert Thomas Date: Mon, 22 Mar 2021 15:24:17 +0100 Subject: [PATCH 16/16] remaining almost_equal --- sklearn/linear_model/tests/test_sgd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index ee23445ff2e6c..f943592c02005 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -331,7 +331,7 @@ def test_late_onset_averaging_not_reached(klass): SparseSGDRegressor]: assert_almost_equal(clf1.intercept_, clf2.intercept_, decimal=16) elif klass in [SGDOneClassSVM, SparseSGDOneClassSVM]: - assert_almost_equal(clf1.offset_, clf2.offset_, decimal=16) + assert_allclose(clf1.offset_, clf2.offset_) @pytest.mark.parametrize('klass', [SGDClassifier, SparseSGDClassifier,
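Illustrative usage sketch (not part of the patches above). The series adds a linear One-Class SVM trained with SGD and, as its benchmark and example scripts do, pairs it with a Nystroem kernel approximation so that it behaves like the kernelized `sklearn.svm.OneClassSVM` at a cost linear in the number of samples. The snippet below is a minimal, self-contained sketch of that pattern; the dataset construction and the hyperparameter values (`nu`, `gamma`, `n_components`, `tol`) are arbitrary choices made for the illustration, not values prescribed by the patch.

import numpy as np

from sklearn.kernel_approximation import Nystroem
from sklearn.linear_model import SGDOneClassSVM
from sklearn.pipeline import make_pipeline
from sklearn.svm import OneClassSVM

rng = np.random.RandomState(42)

# Two Gaussian blobs as inliers for training/testing, uniform noise as outliers.
X_train = np.r_[0.3 * rng.randn(500, 2) + 2, 0.3 * rng.randn(500, 2) - 2]
X_test = np.r_[0.3 * rng.randn(50, 2) + 2, 0.3 * rng.randn(50, 2) - 2]
X_outliers = rng.uniform(low=-4, high=4, size=(20, 2))

nu = 0.05     # expected fraction of outliers / margin errors
gamma = 2.0   # RBF kernel bandwidth

# Exact kernelized One-Class SVM (LibSVM based), at best quadratic in n_samples.
ocsvm = OneClassSVM(kernel="rbf", gamma=gamma, nu=nu).fit(X_train)

# Linear One-Class SVM fitted by SGD on approximate kernel features,
# linear in n_samples.
sgd_ocsvm = make_pipeline(
    Nystroem(gamma=gamma, n_components=100, random_state=42),
    SGDOneClassSVM(nu=nu, shuffle=True, fit_intercept=True,
                   random_state=42, tol=1e-4),
).fit(X_train)

# Both estimators label inliers as +1 and outliers as -1; on such data their
# predictions are expected to agree on most test points.
print("inlier agreement:",
      np.mean(sgd_ocsvm.predict(X_test) == ocsvm.predict(X_test)))
print("outliers flagged by SGD variant:",
      np.mean(sgd_ocsvm.predict(X_outliers) == -1))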