
mlr3torch

Lifecycle: experimental

The goal of {mlr3torch} is to connect {mlr3} with {torch}.

It is in the very early stages of development, and its future and scope have yet to be determined.

Installation

remotes::install_github("mlr-org/mlr3torch")
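A minimal setup sketch, assuming a fresh R installation with CRAN access. {torch} needs the LibTorch backend, which torch::install_torch() downloads. The extra packages are assumptions for the tabnet example: the learner is assumed to rely on the {tabnet} package, and {mlr3viz} draws ROC curves via {precrec}.

```r
# Install {remotes} only if it is not available yet
if (!requireNamespace("remotes", quietly = TRUE)) {
  install.packages("remotes")
}
remotes::install_github("mlr-org/mlr3torch")

# {torch} needs the LibTorch backend; download it only if it is missing
if (!torch::torch_is_installed()) {
  torch::install_torch()
}

# Assumed extras for the tabnet example:
# {tabnet} for the learner, {mlr3viz} and {precrec} for autoplot() ROC curves
install.packages(c("tabnet", "mlr3viz", "precrec"))
```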

Credit

This API is heavily inspired by:

  • Keras

tabnet Example

Using the {tabnet} learner for classification:

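library(mlr3)
library(mlr3viz)   # provides autoplot(); ROC curves additionally need {precrec}
library(mlr3torch)

task <- tsk("german_credit")  # German credit data: 1000 rows, binary target

# Set up the learner
lrn_tabnet <- lrn("classif.tabnet", epochs = 5)

# Train on the first 900 rows and predict the remaining 100
lrn_tabnet$train(task, row_ids = 1:900)

preds <- lrn_tabnet$predict(task, row_ids = 901:1000)

# Investigate predictions: confusion matrix and accuracy on the held-out rows
preds$confusion
preds$score(msr("classif.acc"))

# Predict probabilities instead, again on the held-out rows only
# (scoring the training rows would give an overly optimistic ROC curve)
lrn_tabnet$predict_type <- "prob"
preds_prob <- lrn_tabnet$predict(task, row_ids = 901:1000)

autoplot(preds_prob, type = "roc")

# Examine variable importance scores
lrn_tabnet$importance()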
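The example hard-codes a 900/100 row split. A sketch of the same evaluation with mlr3's own helpers, assuming the classif.tabnet learner and german_credit task from the example: partition() draws a stratified holdout split, and resample() with rsmp("cv", folds = 3) estimates accuracy by 3-fold cross-validation. The seed, ratio, and fold count are arbitrary illustrative choices.

```r
library(mlr3)
library(mlr3torch)

set.seed(1)  # fix the R RNG so the row splits are reproducible

task <- tsk("german_credit")
lrn_tabnet <- lrn("classif.tabnet", epochs = 5)

# Stratified 90/10 holdout split instead of hard-coded row ids
split <- partition(task, ratio = 0.9)
lrn_tabnet$train(task, row_ids = split$train)
acc_holdout <- lrn_tabnet$predict(task, row_ids = split$test)$score(msr("classif.acc"))

# 3-fold cross-validation: every row is used for testing exactly once
rr <- resample(task, lrn_tabnet, rsmp("cv", folds = 3))
acc_cv <- rr$aggregate(msr("classif.acc"))

print(c(holdout = acc_holdout, cv = acc_cv))
```

resample() trains clones of the learner, so lrn_tabnet keeps the model fitted on the holdout split.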

About

Deep learning framework for the mlr3 ecosystem based on torch
