DOC Improve narrative of plot_roc_crossval example #24710


Merged (5 commits, Oct 24, 2022)
examples/model_selection/plot_roc_crossval.py (81 changes: 47 additions & 34 deletions)
@@ -3,77 +3,88 @@
Receiver Operating Characteristic (ROC) with cross validation
=============================================================

-Example of Receiver Operating Characteristic (ROC) metric to evaluate
-classifier output quality using cross-validation.
+This example presents how to estimate and visualize the variance of the Receiver
+Operating Characteristic (ROC) metric using cross-validation.

-ROC curves typically feature true positive rate on the Y axis, and false
-positive rate on the X axis. This means that the top left corner of the plot is
-the "ideal" point - a false positive rate of zero, and a true positive rate of
-one. This is not very realistic, but it does mean that a larger area under the
-curve (AUC) is usually better.
-
-The "steepness" of ROC curves is also important, since it is ideal to maximize
-the true positive rate while minimizing the false positive rate.
+ROC curves typically feature true positive rate (TPR) on the Y axis, and false
+positive rate (FPR) on the X axis. This means that the top left corner of the
+plot is the "ideal" point - a FPR of zero, and a TPR of one. This is not very
+realistic, but it does mean that a larger Area Under the Curve (AUC) is usually
+better. The "steepness" of ROC curves is also important, since it is ideal to
+maximize the TPR while minimizing the FPR.

This example shows the ROC response of different datasets, created from K-fold
cross-validation. Taking all of these curves, it is possible to calculate the
-mean area under curve, and see the variance of the curve when the
+mean AUC, and see the variance of the curve when the
training set is split into different subsets. This roughly shows how the
-classifier output is affected by changes in the training data, and how
-different the splits generated by K-fold cross-validation are from one another.
+classifier output is affected by changes in the training data, and how different
+the splits generated by K-fold cross-validation are from one another.

.. note::

-    See also :func:`sklearn.metrics.roc_auc_score`,
-    :func:`sklearn.model_selection.cross_val_score`,
-    :ref:`sphx_glr_auto_examples_model_selection_plot_roc.py`,
-
+    See :ref:`sphx_glr_auto_examples_model_selection_plot_roc.py` for a
+    complement of the present example explaining the averaging strategies to
+    generalize the metrics for multiclass classifiers.
"""

# %%
-# Data IO and generation
-# ----------------------
-import numpy as np
+# Load and prepare data
+# =====================
+#
+# We import the :ref:`iris_dataset` which contains 3 classes, each one
+# corresponding to a type of iris plant. One class is linearly separable from
+# the other 2; the latter are **not** linearly separable from each other.
+#
+# In the following we binarize the dataset by dropping the "virginica" class
+# (`class_id=2`). This means that the "versicolor" class (`class_id=1`) is
+# regarded as the positive class and "setosa" as the negative class
+# (`class_id=0`).

-from sklearn import datasets
+import numpy as np
+from sklearn.datasets import load_iris

-# Import some data to play with
-iris = datasets.load_iris()
-X = iris.data
-y = iris.target
+iris = load_iris()
+target_names = iris.target_names
+X, y = iris.data, iris.target
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape

-# Add noisy features
+# %%
+# We also add noisy features to make the problem harder.
random_state = np.random.RandomState(0)
-X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
+X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
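(Illustrative aside: assuming the new version of the script has run up to this point, a quick sanity check on the prepared data. The shapes follow from 150 iris samples, 50 per class, with 4 original features plus 200 * 4 noisy ones.)

```python
print(X.shape)  # (100, 804): 100 samples left after dropping "virginica"
print(np.bincount(y))  # [50 50]: balanced "setosa" (0) vs "versicolor" (1)
```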

# %%
# Classification and ROC analysis
# -------------------------------
+#
+# Here we run a :class:`~sklearn.svm.SVC` classifier with cross-validation and
+# plot the ROC curves fold-wise. Notice that the baseline to define the chance
+# level (dashed ROC curve) is a classifier that would always predict the most
+# frequent class.
+
import matplotlib.pyplot as plt

from sklearn import svm
from sklearn.metrics import auc
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import StratifiedKFold

-# Run classifier with cross-validation and plot ROC curves
cv = StratifiedKFold(n_splits=6)
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)

tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100)

-fig, ax = plt.subplots()
-for i, (train, test) in enumerate(cv.split(X, y)):
+fig, ax = plt.subplots(figsize=(6, 6))
+for fold, (train, test) in enumerate(cv.split(X, y)):
    classifier.fit(X[train], y[train])
    viz = RocCurveDisplay.from_estimator(
        classifier,
        X[test],
        y[test],
-        name="ROC fold {}".format(i),
+        name=f"ROC fold {fold}",
        alpha=0.3,
        lw=1,
        ax=ax,
@@ -82,8 +93,7 @@
    interp_tpr[0] = 0.0
    tprs.append(interp_tpr)
    aucs.append(viz.roc_auc)
-
-ax.plot([0, 1], [0, 1], linestyle="--", lw=2, color="r", label="Chance", alpha=0.8)
+ax.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
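(Illustrative aside: the `interp_tpr` in the hunk above comes from interpolating each fold's TPR onto the shared `mean_fpr` grid, so that curves evaluated at different thresholds can be averaged point-wise. A sketch with made-up fold curves:)

```python
import numpy as np

mean_fpr = np.linspace(0, 1, 100)
# (fpr, tpr) pairs for two made-up folds
fold_curves = [
    (np.array([0.0, 0.2, 1.0]), np.array([0.0, 0.7, 1.0])),
    (np.array([0.0, 0.4, 1.0]), np.array([0.0, 0.9, 1.0])),
]
tprs = [np.interp(mean_fpr, fpr, tpr) for fpr, tpr in fold_curves]
mean_tpr = np.mean(tprs, axis=0)  # point-wise average across folds
```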
@@ -113,7 +123,10 @@
ax.set(
    xlim=[-0.05, 1.05],
    ylim=[-0.05, 1.05],
-    title="Receiver operating characteristic example",
+    xlabel="False Positive Rate",
+    ylabel="True Positive Rate",
+    title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
)
+ax.axis("square")
ax.legend(loc="lower right")
plt.show()
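(Illustrative aside: the dashed chance-level line added in this PR corresponds to a classifier that ignores its input, for instance a `DummyClassifier` always predicting the most frequent class; its constant scores cannot rank positives above negatives and so give an ROC AUC of 0.5. A sketch, assuming `X` and `y` from the script above:)

```python
from sklearn.dummy import DummyClassifier
from sklearn.metrics import roc_auc_score

dummy = DummyClassifier(strategy="most_frequent").fit(X, y)
chance_auc = roc_auc_score(y, dummy.predict_proba(X)[:, 1])
print(chance_auc)  # 0.5: constant scores match the dashed baseline
```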