
❗ This is a read-only mirror of the CRAN R package repository. tabnet: Fit 'TabNet' Models for Classification and Regression. Homepage: https://mlverse.github.io/tabnet/ and https://github.com/mlverse/tabnet. Report bugs for this package: https://github.com/mlverse/tabnet/issues


cran/tabnet


tabnet

Lifecycle: experimental

An R implementation of TabNet: Attentive Interpretable Tabular Learning. The code in this repository is an R port of the PyTorch implementation in dreamquark-ai/tabnet, built on the torch package.

Installation

You can install the development version from GitHub with:

```r
# install.packages("devtools")
devtools::install_github("mlverse/tabnet")
```

Example

```r
library(tabnet)
library(recipes)
#> Loading required package: dplyr
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step
library(yardstick)
#> For binary classification, the first factor level is assumed to be the event.
#> Use the argument `event_level = "second"` to alter this as needed.
set.seed(1)

# Load the employee attrition data and hold out 20% of the rows for testing
data("attrition", package = "modeldata")
test_idx <- sample.int(nrow(attrition), size = 0.2 * nrow(attrition))

train <- attrition[-test_idx,]
test <- attrition[test_idx,]

# Preprocessing: centre and scale every numeric predictor
rec <- recipe(Attrition ~ ., data = train) %>% 
  step_normalize(all_numeric(), -all_outcomes())

# Fit a TabNet classifier for 30 epochs
fit <- tabnet_fit(rec, train, epochs = 30)

# Evaluate the hard class predictions on the test set
metrics <- metric_set(accuracy, precision, recall)
cbind(test, predict(fit, test)) %>% 
  metrics(Attrition, estimate = .pred_class)
#> # A tibble: 3 x 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 accuracy  binary         0.847
#> 2 precision binary         0.848
#> 3 recall    binary         0.996

# Evaluate the class probabilities with ROC AUC ("No" is the event level)
cbind(test, predict(fit, test, type = "prob")) %>% 
  roc_auc(Attrition, .pred_No)
#> # A tibble: 1 x 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 roc_auc binary         0.705
```
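yardstick treats the first factor level as the event. For `attrition$Attrition` the levels are `No`, `Yes`, so `No` is the event, which is why the example passes `.pred_No` to `roc_auc()`. To make that convention concrete, here is a base-R sketch that computes accuracy, precision, recall and ROC AUC by hand. The toy vectors and the helper `roc_auc_rank()` are hypothetical and are not part of tabnet or yardstick. The AUC uses the rank (Mann-Whitney) formulation, which should match a trapezoidal ROC AUC, but treat it as an illustration rather than yardstick's implementation.

```r
# Hypothetical toy data; levels ordered c("No", "Yes") as in attrition$Attrition
truth    <- factor(c("No", "No", "No", "Yes", "Yes", "No"), levels = c("No", "Yes"))
estimate <- factor(c("No", "No", "Yes", "No", "Yes", "No"), levels = c("No", "Yes"))
prob_no  <- c(0.9, 0.8, 0.4, 0.7, 0.2, 0.6)  # hypothetical predicted P(Attrition == "No")

# yardstick's default: the first factor level is the event
event <- levels(truth)[1]

# Confusion-matrix counts relative to the event level
tp <- sum(estimate == event & truth == event)
fp <- sum(estimate == event & truth != event)
fn <- sum(estimate != event & truth == event)

# Names avoid masking yardstick's accuracy(), precision() and recall()
acc_manual  <- mean(estimate == truth)
prec_manual <- tp / (tp + fp)  # share of predicted events that are real events
rec_manual  <- tp / (tp + fn)  # share of real events that were predicted

# Hypothetical helper: ROC AUC as the probability that a random event case
# receives a higher score than a random non-event case (ties get average ranks)
roc_auc_rank <- function(truth, score, event) {
  is_event <- truth == event
  r <- rank(score)
  n_pos <- sum(is_event)
  n_neg <- sum(!is_event)
  (sum(r[is_event]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}
auc_manual <- roc_auc_rank(truth, prob_no, event)
```

Because `No` is the event, the near-perfect recall in the example says the model almost never misses an employee who stayed; it says nothing directly about how well it catches those who left. Passing `event_level = "second"` to the yardstick metrics would make `Yes` the event instead.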
