
@tdhock (Contributor) commented Jul 29, 2025

closes #417

On current master, the code below gives an error. On this branch, it works (no error).

remotes::install_github("tdhock/mlr3torch@fix-sampler")
# Custom sampler class that yields one index per call, in reverse order
rev_sampler_class <- torch::sampler(
  "MySampler",
  initialize = function(data_source) {
    print('init')  # shows when the sampler gets instantiated
    self$data_source <- data_source
  },
  .iter = function() {
    count <- 0L  # local counter, so each iterator keeps its own state
    function() {
      if (count < length(self$data_source)) {
        idx <- length(self$data_source) - count
        count <<- count + 1L
        return(idx)
      }
      coro::exhausted()  # signals that all indices have been returned
    }
  },
  .length = function() {
    length(self$data_source)
  }
)
sonar_task <- mlr3::tsk("sonar")
mlp_learner <- mlr3torch::LearnerTorchMLP$new(task_type="classif")
mlp_learner$param_set$set_values(
  epochs=10,
  batch_size=20,
  sampler=rev_sampler_class)  # the sampler class itself, not an instance
mlp_learner$train(sonar_task)
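To see the order in which a sampler like the one above delivers observations, it can be checked outside mlr3torch with a plain torch dataloader. This is a hypothetical sketch, not part of the PR: ToyDataset and CompactRevSampler are made-up names, and the sampler is a more compact variant of the reverse sampler above that builds its iterator with coro::as_iterator() instead of a hand-written closure. Since it returns one index per call, it is combined with batch_size and the dataloader does the grouping.

```r
library(torch)

# Hypothetical toy dataset of n items; item i holds the value i
toy_dataset <- dataset(
  "ToyDataset",
  initialize = function(n) {
    self$n <- n
  },
  .getitem = function(i) {
    list(x = torch_tensor(i, dtype = torch_float()))
  },
  .length = function() {
    self$n
  }
)

# Hypothetical reverse-order sampler: .iter() returns an iterator that yields
# one index per call; coro::as_iterator() returns coro::exhausted() at the end
compact_rev_sampler_class <- sampler(
  "CompactRevSampler",
  initialize = function(data_source) {
    self$data_source <- data_source
  },
  .iter = function() {
    coro::as_iterator(rev(seq_len(length(self$data_source))))
  },
  .length = function() {
    length(self$data_source)
  }
)

ds <- toy_dataset(5)
# Per-item sampler + batch_size: the dataloader groups indices into batches
dl <- dataloader(ds, sampler = compact_rev_sampler_class(ds), batch_size = 2)
got <- list()
coro::loop(for (b in dl) {
  got[[length(got) + 1L]] <- as.integer(as_array(b$x))
})
str(got)
```

Because the dataloader does the grouping, the same sampler works with any batch_size; with 5 items and batch_size = 2, the last batch holds a single index since drop_last defaults to FALSE.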

@sebffischer, can you please review and tell me whether this change would be acceptable?
If so, I can add documentation and tests.

I wonder if we should do the same for the batch_sampler param? What is the difference from the sampler param?

@tdhock (Contributor, Author) commented Jul 29, 2025

> I wonder if we should do the same for the batch_sampler param? What is the difference from the sampler param?

  • sampler: the .iter() method returns one index (one data point) per call, and can be used in combination with the batch_size arg. A sampler can also define .iter_batch(batch_size) for efficiency.
  • batch_sampler: the .iter() method returns a vector of indices (a batch of data) per call, and cannot be used in combination with the batch_size arg.
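The batch_sampler case can be illustrated with a small torch-only sketch (hypothetical names RevBatchSampler and ToyDataset; no mlr3torch involved). Here .iter() returns a function that yields a whole vector of indices per call, so the batch sampler alone decides the batch composition, and batch_size is left at its default when calling dataloader().

```r
library(torch)

# Hypothetical toy dataset of n items; item i holds the value i
toy_dataset <- dataset(
  "ToyDataset",
  initialize = function(n) {
    self$n <- n
  },
  .getitem = function(i) {
    list(x = torch_tensor(i, dtype = torch_float()))
  },
  .length = function() {
    self$n
  }
)

# Hypothetical batch sampler: each call of the iterator returns a vector of
# indices (batches of 2, in reverse order), then coro::exhausted()
rev_batch_sampler_class <- sampler(
  "RevBatchSampler",
  initialize = function(data_source) {
    self$data_source <- data_source
  },
  .iter = function() {
    idx <- rev(seq_len(length(self$data_source)))
    # for 5 items: list(c(5, 4), c(3, 2), 1)
    batches <- unname(split(idx, ceiling(seq_along(idx) / 2)))
    k <- 0L
    function() {
      if (k < length(batches)) {
        k <<- k + 1L
        return(batches[[k]])
      }
      coro::exhausted()
    }
  },
  .length = function() {
    ceiling(length(self$data_source) / 2)
  }
)

ds <- toy_dataset(5)
# No batch_size here: the batch sampler already returns whole batches
dl <- dataloader(ds, batch_sampler = rev_batch_sampler_class(ds))
got <- list()
coro::loop(for (b in dl) {
  got[[length(got) + 1L]] <- as.integer(as_array(b$x))
})
str(got)
```

torch documents batch_sampler as mutually exclusive with batch_size, shuffle, sampler, and drop_last, which matches the second bullet: passing batch_size = 2 together with batch_sampler here is expected to raise an error.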

@sebffischer (Member)

Hey Toby, thanks for opening this PR!
I think it's a good idea to do this for batch_sampler as well. If you could add the documentation and tests, I will merge it.

@sebffischer (Member) left a comment

Can you also add a test, please?

"worker_packages"
)
args = param_vals[names(param_vals) %in% dl_args]
for(param_name in c("sampler", "batch_sampler")){
@sebffischer (Member):

Check that only one of sampler and batch_sampler is present.
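A minimal sketch of the requested check, assuming args is the named list built by args = param_vals[names(param_vals) %in% dl_args] in the diff above. The helper name check_one_sampler is hypothetical, and the code that was actually merged may implement the check differently.

```r
# Hypothetical helper (name and placement are assumptions, not the merged code):
# fail early if both `sampler` and `batch_sampler` ended up in the dataloader args
check_one_sampler <- function(args) {
  present <- intersect(c("sampler", "batch_sampler"), names(args))
  if (length(present) > 1L) {
    stop("Only one of 'sampler' and 'batch_sampler' can be set.", call. = FALSE)
  }
  invisible(args)
}

# Usage: a single sampler passes through, both together raise an error
ok <- check_one_sampler(list(batch_size = 20, sampler = "my_sampler"))
res <- tryCatch(
  check_one_sampler(list(sampler = "a", batch_sampler = "b")),
  error = function(e) conditionMessage(e)
)
print(res)
```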

Co-authored-by: Sebastian Fischer <[email protected]>
@sebffischer merged commit 8a6ad29 into mlr-org:main on Aug 1, 2025
6 checks passed

Successfully merging this pull request may close these issues:
TorchLearner: support / example for custom sampler