The goal of {mlr3torch} is to connect {mlr3} with {torch}.
It is in the very early stages of development, and its future and scope are yet to be determined.
You can install the development version from GitHub with:
remotes::install_github("mlr-org/mlr3torch")

Using the {tabnet} learner for classification:
This API is heavily inspired by:
- Keras
library(mlr3)
library(mlr3viz)
library(mlr3torch)
# Example: train a TabNet classifier on the German credit task, evaluate it
# on held-out rows, then inspect the ROC curve and variable importance.
task <- tsk("german_credit")

# Deterministic train/test split: the first 90% of the task's rows are used
# for training and the remaining rows are held out for evaluation. The split
# is derived from the task's own row ids instead of hard-coded ranges, so it
# is defined in one place and reused by every prediction below.
row_ids <- task$row_ids
n_train <- floor(0.9 * length(row_ids))
train_ids <- row_ids[seq_len(n_train)]
test_ids <- setdiff(row_ids, train_ids)

# Set up the learner
lrn_tabnet <- lrn("classif.tabnet", epochs = 5)

# Train and predict
lrn_tabnet$train(task, row_ids = train_ids)
preds <- lrn_tabnet$predict(task, row_ids = test_ids)

# Investigate predictions
preds$confusion
preds$score(msr("classif.acc"))

# Predict probabilities instead. Only the held-out rows are scored:
# predicting on the full task would include the training rows and yield an
# overly optimistic ROC curve.
lrn_tabnet$predict_type <- "prob"
preds_prob <- lrn_tabnet$predict(task, row_ids = test_ids)
autoplot(preds_prob, type = "roc")

# Examine variable importance scores
lrn_tabnet$importance()