---
description: Interpret the results of your classification using Receiver Operating Characteristics (ROC) and Precision-Recall (PR) Curves in R with Plotly.
display_as: ai_ml
language: r
layout: base
name: ROC and PR Curves
order: 3
output:
  html_document:
    keep_md: true
permalink: r/roc-and-pr-curves/
thumbnail: thumbnail/ml-roc-pr.png
---

```{r, echo = FALSE, message = FALSE}
knitr::opts_chunk$set(message = FALSE, warning = FALSE)
```

## ROC and PR Curves in R

Interpret the results of your classification using Receiver Operating Characteristics (ROC) and Precision-Recall (PR) Curves in R with Plotly.

## Preliminary plots

Before diving into the receiver operating characteristic (ROC) curve, we will look at two plots that give some context to the threshold mechanism behind the ROC and PR curves.

In the histogram, we observe that the scores are spread such that most of the positive labels are binned near 1, and many of the negative labels are close to 0. When we set a threshold on the score, all of the bins to its left will be classified as 0's, and everything to the right will be 1's. There are obviously a few outliers, such as **negative** samples that our model gave a high score, and *positive* samples with a low score. If we set a threshold right in the middle, those outliers will respectively become **false positives** and *false negatives*.

As we adjust the threshold, the number of false positives will increase or decrease, and at the same time the number of true positives will also change; this is shown in the second plot. As you can see, the model seems to perform fairly well, because the true positive rate and the false positive rate decrease sharply as we increase the threshold. Those two lines each represent one dimension of the ROC curve.
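
To make the threshold mechanism concrete, here is a toy sketch with hypothetical labels and scores (`y_true` and `score` are made up; the chunk below builds the real ones from a fitted model). Classifying everything above the threshold as a 1 yields one (TPR, FPR) pair, i.e. one point of the ROC curve.

```{r}
# Toy labels and scores (hypothetical); one threshold gives one point of the ROC curve
y_true <- c(0, 0, 1, 1, 0, 1)
score  <- c(0.1, 0.4, 0.35, 0.8, 0.7, 0.9)

threshold <- 0.5
pred <- as.integer(score >= threshold)  # scores to the right of the threshold become 1's

tpr <- sum(pred == 1 & y_true == 1) / sum(y_true == 1)  # true positive rate at this threshold
fpr <- sum(pred == 1 & y_true == 0) / sum(y_true == 0)  # false positive rate at this threshold
c(tpr = tpr, fpr = fpr)
```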

```{r}
library(plotly)
library(tidymodels)

# Simulate 500 observations with 20 random features and random binary labels
set.seed(0)
X <- matrix(rnorm(10000), nrow = 500)
y <- sample(0:1, 500, replace = TRUE)
data <- data.frame(X, y)
data$y <- as.factor(data$y)
X <- subset(data, select = -c(y))

# Fit a logistic regression model with tidymodels
logistic_glm <-
  logistic_reg() %>%
  set_engine("glm") %>%
  set_mode("classification") %>%
  fit(y ~ ., data = data)

# Predicted probability of the positive class for every observation
y_scores <- logistic_glm %>%
  predict(X, type = 'prob')

y_score <- y_scores$.pred_1
db <- data.frame(data$y, y_score)

# yardstick treats the first factor level ('0') as the event; flipping specificity
# and renaming the columns converts the output to the TPR and FPR for class '1'
z <- roc_curve(data = db, 'data.y', 'y_score')
z$specificity <- 1 - z$specificity
colnames(z) <- c('threshold', 'tpr', 'fpr')

# Histogram of scores, colored by the true label
fig1 <- plot_ly(x = y_score, color = data$y, colors = c('blue', 'red'), type = 'histogram', alpha = 0.5, nbinsx = 50) %>%
  layout(barmode = "overlay")
fig1

# TPR and FPR as a function of the threshold
fig2 <- plot_ly(data = z, x = ~threshold) %>%
  add_trace(y = ~fpr, mode = 'lines', name = 'False Positive Rate', type = 'scatter') %>%
  add_trace(y = ~tpr, mode = 'lines', name = 'True Positive Rate', type = 'scatter') %>%
  layout(title = 'TPR and FPR at every threshold')
fig2 <- fig2 %>% layout(legend = list(title = list(text = '<b> Rate </b>')))
fig2
```
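
Since `z` already holds the TPR and FPR at every threshold, the binary ROC curve itself is one more plot away. A minimal sketch, assuming the chunk above has been run (the `1 - roc_auc(...)` flip compensates for the same first-level event default as the curve):

```{r}
# ROC curve for the binary example, with the chance-level diagonal as reference
auc <- 1 - roc_auc(data = db, 'data.y', 'y_score')$.estimate

fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 0, yend = 1,
               line = list(dash = "dash", color = 'black'), showlegend = FALSE) %>%
  add_trace(data = z, x = ~fpr, y = ~tpr, mode = 'lines',
            name = paste('ROC (AUC=', round(auc, 2), ')', sep = ''), type = 'scatter') %>%
  layout(xaxis = list(title = "False Positive Rate"),
         yaxis = list(title = "True Positive Rate"))
fig
```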

## Multiclass ROC Curve

When you have more than two classes, you will need to plot the ROC curve for each class separately. Make sure that you use a [one-versus-rest](https://cran.r-project.org/web/packages/multiclassPairs/vignettes/Tutorial.html) model, or that your problem has a multi-label format; otherwise, your ROC curve might not return the expected results.

```{r}
library(plotly)
library(tidymodels)
library(fastDummies)

# Artificially add noise to make the task harder
data(iris)
ind <- sample.int(150, 50)
samples <- sample(x = iris$Species, size = 50)
iris[ind, 'Species'] <- samples

# Define the inputs and outputs
X <- subset(iris, select = -c(Species))
iris$Species <- as.factor(iris$Species)

# Fit a multinomial logistic regression model
logistic <-
  multinom_reg() %>%
  set_engine("nnet") %>%
  set_mode("classification") %>%
  fit(Species ~ ., data = iris)

y_scores <- logistic %>%
  predict(X, type = 'prob')

# One-hot encode the labels in order to plot them
y_onehot <- dummy_cols(iris$Species)
colnames(y_onehot) <- c('drop', 'setosa', 'versicolor', 'virginica')
y_onehot <- subset(y_onehot, select = -c(drop))

z <- cbind(y_scores, y_onehot)

# One-vs-rest ROC curve and AUC for each class; as above, the flipped specificity
# and 1 - AUC account for yardstick treating the first level ('0') as the event
z$setosa <- as.factor(z$setosa)
roc_setosa <- roc_curve(data = z, setosa, .pred_setosa)
roc_setosa$specificity <- 1 - roc_setosa$specificity
colnames(roc_setosa) <- c('threshold', 'tpr', 'fpr')
auc_setosa <- roc_auc(data = z, setosa, .pred_setosa)
auc_setosa <- auc_setosa$.estimate
setosa <- paste('setosa (AUC=', toString(round(1 - auc_setosa, 2)), ')', sep = '')

z$versicolor <- as.factor(z$versicolor)
roc_versicolor <- roc_curve(data = z, versicolor, .pred_versicolor)
roc_versicolor$specificity <- 1 - roc_versicolor$specificity
colnames(roc_versicolor) <- c('threshold', 'tpr', 'fpr')
auc_versicolor <- roc_auc(data = z, versicolor, .pred_versicolor)
auc_versicolor <- auc_versicolor$.estimate
versicolor <- paste('versicolor (AUC=', toString(round(1 - auc_versicolor, 2)), ')', sep = '')

z$virginica <- as.factor(z$virginica)
roc_virginica <- roc_curve(data = z, virginica, .pred_virginica)
roc_virginica$specificity <- 1 - roc_virginica$specificity
colnames(roc_virginica) <- c('threshold', 'tpr', 'fpr')
auc_virginica <- roc_auc(data = z, virginica, .pred_virginica)
auc_virginica <- auc_virginica$.estimate
virginica <- paste('virginica (AUC=', toString(round(1 - auc_virginica, 2)), ')', sep = '')

# Create an empty figure, and iteratively add a line for each class
fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 0, yend = 1, line = list(dash = "dash", color = 'black'), showlegend = FALSE) %>%
  add_trace(data = roc_setosa, x = ~fpr, y = ~tpr, mode = 'lines', name = setosa, type = 'scatter') %>%
  add_trace(data = roc_versicolor, x = ~fpr, y = ~tpr, mode = 'lines', name = versicolor, type = 'scatter') %>%
  add_trace(data = roc_virginica, x = ~fpr, y = ~tpr, mode = 'lines', name = virginica, type = 'scatter') %>%
  layout(xaxis = list(title = "False Positive Rate"),
         yaxis = list(title = "True Positive Rate"),
         legend = list(x = 100, y = 0.5))
fig
```
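
The per-class block above is repeated three times. As a design note, the same computation can be wrapped in a small helper and applied in a loop over the class names; the sketch below is one way to do this (the helper `ovr_roc()` is hypothetical, and it assumes the `y_scores` and `y_onehot` objects from the chunk above).

```{r}
# Hypothetical helper: one-vs-rest ROC curve and AUC label for a single class
ovr_roc <- function(scores, onehot, class) {
  # Build a two-column frame with fixed names so roc_curve() can be called uniformly
  d <- data.frame(truth = as.factor(onehot[[class]]),
                  score = scores[[paste0('.pred_', class)]])
  roc <- roc_curve(data = d, 'truth', 'score')
  roc$specificity <- 1 - roc$specificity
  colnames(roc) <- c('threshold', 'tpr', 'fpr')
  auc <- 1 - roc_auc(data = d, 'truth', 'score')$.estimate
  list(roc = roc, label = paste0(class, ' (AUC=', round(auc, 2), ')'))
}

fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 0, yend = 1,
               line = list(dash = "dash", color = 'black'), showlegend = FALSE)
for (cls in c('setosa', 'versicolor', 'virginica')) {
  res <- ovr_roc(y_scores, y_onehot, cls)
  fig <- fig %>% add_trace(data = res$roc, x = ~fpr, y = ~tpr,
                           mode = 'lines', name = res$label, type = 'scatter')
}
fig <- fig %>% layout(xaxis = list(title = "False Positive Rate"),
                      yaxis = list(title = "True Positive Rate"),
                      legend = list(x = 100, y = 0.5))
fig
```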

## Precision-Recall Curves

In this example, we plot a one-vs-rest precision-recall (PR) curve for each class and label it with the average precision metric, which is an alternative summary to the area under the PR curve. Here the average precision is approximated by taking the mean of the precision values along each curve.

```{r}
library(plotly)
library(tidymodels)
library(fastDummies)

# Artificially add noise to make the task harder
data(iris)
ind <- sample.int(150, 50)
samples <- sample(x = iris$Species, size = 50)
iris[ind, 'Species'] <- samples

# Define the inputs and outputs
X <- subset(iris, select = -c(Species))
iris$Species <- as.factor(iris$Species)

# Fit a multinomial logistic regression model
logistic <-
  multinom_reg() %>%
  set_engine("nnet") %>%
  set_mode("classification") %>%
  fit(Species ~ ., data = iris)

y_scores <- logistic %>%
  predict(X, type = 'prob')

# One-hot encode the labels in order to plot them
y_onehot <- dummy_cols(iris$Species)
colnames(y_onehot) <- c('drop', 'setosa', 'versicolor', 'virginica')
y_onehot <- subset(y_onehot, select = -c(drop))

z <- cbind(y_scores, y_onehot)

# One-vs-rest PR curve and average precision for each class; event_level = 'second'
# makes the '1' level of each dummy column the event, so each curve describes its own class
z$setosa <- as.factor(z$setosa)
pr_setosa <- pr_curve(data = z, setosa, .pred_setosa, event_level = 'second')
aps_setosa <- mean(pr_setosa$precision)
setosa <- paste('setosa (AP = ', toString(round(aps_setosa, 2)), ')', sep = '')

z$versicolor <- as.factor(z$versicolor)
pr_versicolor <- pr_curve(data = z, versicolor, .pred_versicolor, event_level = 'second')
aps_versicolor <- mean(pr_versicolor$precision)
versicolor <- paste('versicolor (AP = ', toString(round(aps_versicolor, 2)), ')', sep = '')

z$virginica <- as.factor(z$virginica)
pr_virginica <- pr_curve(data = z, virginica, .pred_virginica, event_level = 'second')
aps_virginica <- mean(pr_virginica$precision)
virginica <- paste('virginica (AP = ', toString(round(aps_virginica, 2)), ')', sep = '')

# Create an empty figure, and add a new line for each class
fig <- plot_ly() %>%
  add_segments(x = 0, xend = 1, y = 1, yend = 0, line = list(dash = "dash", color = 'black'), showlegend = FALSE) %>%
  add_trace(data = pr_setosa, x = ~recall, y = ~precision, mode = 'lines', name = setosa, type = 'scatter') %>%
  add_trace(data = pr_versicolor, x = ~recall, y = ~precision, mode = 'lines', name = versicolor, type = 'scatter') %>%
  add_trace(data = pr_virginica, x = ~recall, y = ~precision, mode = 'lines', name = virginica, type = 'scatter') %>%
  layout(xaxis = list(title = "Recall"),
         yaxis = list(title = "Precision"),
         legend = list(x = 100, y = 0.5))
fig
```
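
The labels above use the plain mean of the precision values as the average precision. A more common definition weights each precision value by the corresponding increase in recall; a sketch of that computation for one class, reusing `pr_setosa` from the chunk above, could look like this:

```{r}
# Recall-weighted average precision for the 'setosa' curve (AP = sum of precision * delta recall)
pr <- pr_setosa[order(pr_setosa$recall), ]  # order points by increasing recall
pr <- pr[!is.na(pr$precision), ]            # drop any undefined precision values
recall_steps <- diff(c(0, pr$recall))       # increase in recall contributed by each point
ap_setosa <- sum(pr$precision * recall_steps)
ap_setosa
```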

## References

Learn more about histograms, filled area plots and line charts:

* https://plot.ly/r/histograms/
* https://plot.ly/r/filled-area-plots/
* https://plot.ly/r/line-charts/
