From 096ed4ca455bb65a36e6a41a6948b2c7677d950e Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 7 Feb 2025 12:24:01 +0100 Subject: [PATCH 1/4] fix(learner): `LearnerTorchModel` now works in parallel settings Because `LearnerTorchModel` stored an instantiated `nn_module`, transferring the learner to the worker caused issues, e.g. when using callr encapsulation. This is now solved by serializing the network and unserializing it during training. --- NEWS.md | 3 +++ R/LearnerTorchModel.R | 15 +++------------ R/PipeOpTorchModel.R | 2 +- tests/testthat/test_LearnerTorchModel.R | 18 ++++++++++++++++++ 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/NEWS.md b/NEWS.md index c17af6761..f71e8a786 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # mlr3torch dev +* fix: `LearnerTorchModel` can now be parallelized and trained with + encapsulation activated. + # mlr3torch 0.2.0 ## Breaking Changes diff --git a/R/LearnerTorchModel.R b/R/LearnerTorchModel.R index 6dd43032a..dd5f36436 100644 --- a/R/LearnerTorchModel.R +++ b/R/LearnerTorchModel.R @@ -60,8 +60,8 @@ LearnerTorchModel = R6Class("LearnerTorchModel", #' Creates a new instance of this [R6][R6::R6Class] class. initialize = function(network = NULL, ingress_tokens = NULL, task_type, properties = NULL, optimizer = NULL, loss = NULL, callbacks = list(), packages = character(0), feature_types = NULL) { - # TODO: What about the learner properties? 
- if (!is.null(network)) self$network_stored = network + # we need to serialize here as otherwise encapsulation and parallelization fails + if (!is.null(network)) private$.network_stored = torch_serialize(assert_class(network, "nn_module")) if (!is.null(ingress_tokens)) self$ingress_tokens = ingress_tokens if (is.null(feature_types)) { feature_types = unname(mlr_reflections$task_feature_types) @@ -89,15 +89,6 @@ LearnerTorchModel = R6Class("LearnerTorchModel", } ), active = list( - #' @field network_stored (`nn_module` or `NULL`)\cr - #' The network that will be trained. - #' After calling `$train()`, this is `NULL`. - network_stored = function(rhs) { - if (!missing(rhs)) { - private$.network_stored = assert_class(rhs, "nn_module") - } - private$.network_stored - }, #' @field ingress_tokens (named `list()` with `TorchIngressToken` or `NULL`)\cr #' The ingress tokens. Must be non-`NULL` when calling `$train()`. ingress_tokens = function(rhs) { @@ -121,7 +112,7 @@ LearnerTorchModel = R6Class("LearnerTorchModel", if (is.null(private$.network_stored)) { stopf("No network stored, did you already train learner '%s' or did not specify a model?", self$id) } - network = private$.network_stored + network = torch_load(private$.network_stored) private$.network_stored = NULL network }, diff --git a/R/PipeOpTorchModel.R b/R/PipeOpTorchModel.R index 0c5828e9d..bd4d09390 100644 --- a/R/PipeOpTorchModel.R +++ b/R/PipeOpTorchModel.R @@ -69,7 +69,7 @@ PipeOpTorchModel = R6Class("PipeOpTorchModel", output_pointers = list(md$pointer), list_output = FALSE ) - private$.learner$network_stored = network + get_private(private$.learner, ".network_stored") = torch_serialize(network) private$.learner$ingress_tokens = md$ingress if (is.null(md$loss)) { diff --git a/tests/testthat/test_LearnerTorchModel.R b/tests/testthat/test_LearnerTorchModel.R index ac04062eb..2354142ea 100644 --- a/tests/testthat/test_LearnerTorchModel.R +++ b/tests/testthat/test_LearnerTorchModel.R @@ -66,3 +66,21 @@ 
test_that("marshaling works for graph learner", { learner$unmarshal() expect_class(learner$predict(task), "Prediction") }) + +test_that("LearnerTorchModel and marshaling", { + # there used to be a marshaling bug resulting from the fact that composed network + # is stored in the learner (not part of the model) + task = tsk("iris") + learner = LearnerTorchModel$new( + task_type = "classif", + network = testmodule_linear(task), + ingress_tokens = list(x = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4L))), + packages = "data.table", + ) + learner$configure( + batch_size = 50, + epochs = 1 + ) + learner$train(task) + expect_class(learner$model, "learner_torch_model") +}) From 9e07d252799e453b187d68ff21d2795f69881977 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 7 Feb 2025 12:34:06 +0100 Subject: [PATCH 2/4] fix test, optimization --- NEWS.md | 4 ++-- R/LearnerTorchModel.R | 7 ++++++- R/PipeOpTorchModel.R | 4 +++- tests/testthat/test_LearnerTorchModel.R | 1 + 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/NEWS.md b/NEWS.md index f71e8a786..c287c99bc 100644 --- a/NEWS.md +++ b/NEWS.md @@ -18,9 +18,9 @@ * Optimizers now use the faster ('ignite') version of the optimizers, which leads to considerable speed improvements. -* The `jit_trace` parameter was added to `LearnerTorch`, which when set to +* The `jit_trace` parameter was added to `LearnerTorch`, which when set to `TRUE` can lead to significant speedups. - This should only be enabled for 'static' models, see the + This should only be enabled for 'static' models, see the [torch tutorial](https://torch.mlverse.org/docs/articles/torchscript) for more information. * Added parameter `num_interop_threads` to `LearnerTorch`. 
diff --git a/R/LearnerTorchModel.R b/R/LearnerTorchModel.R index dd5f36436..84d4e2917 100644 --- a/R/LearnerTorchModel.R +++ b/R/LearnerTorchModel.R @@ -112,7 +112,12 @@ LearnerTorchModel = R6Class("LearnerTorchModel", if (is.null(private$.network_stored)) { stopf("No network stored, did you already train learner '%s' or did not specify a model?", self$id) } - network = torch_load(private$.network_stored) + network = if (test_clrss(private$.network_stored, "nn_module")) { + # optimization for PipeOpTorchModel, where we control the construction of LearnerTorchModel + private$.network_stored + } else { + torch_load(private$.network_stored) + } private$.network_stored = NULL network }, diff --git a/R/PipeOpTorchModel.R b/R/PipeOpTorchModel.R index bd4d09390..5042cc126 100644 --- a/R/PipeOpTorchModel.R +++ b/R/PipeOpTorchModel.R @@ -69,7 +69,9 @@ PipeOpTorchModel = R6Class("PipeOpTorchModel", output_pointers = list(md$pointer), list_output = FALSE ) - get_private(private$.learner, ".network_stored") = torch_serialize(network) + # Because we control the creation of the LearnerTorchModel, we know that it's fitted in the same + # process as the current .train function, hence, we can avoid the serialization round-trip + get_private(private$.learner, ".network_stored") = network private$.learner$ingress_tokens = md$ingress if (is.null(md$loss)) { diff --git a/tests/testthat/test_LearnerTorchModel.R b/tests/testthat/test_LearnerTorchModel.R index 2354142ea..40b7a0b6f 100644 --- a/tests/testthat/test_LearnerTorchModel.R +++ b/tests/testthat/test_LearnerTorchModel.R @@ -77,6 +77,7 @@ test_that("LearnerTorchModel and marshaling", { ingress_tokens = list(x = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4L))), packages = "data.table", ) + learner$encapsulate("callr", lrn("classif.featureless")) learner$configure( batch_size = 50, epochs = 1 From 1a1de759fc691a7e8169135515a3adbd77484dcf Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 7 Feb 2025 12:36:33 
+0100 Subject: [PATCH 3/4] typo --- R/LearnerTorchModel.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/LearnerTorchModel.R b/R/LearnerTorchModel.R index 84d4e2917..b64b0ac47 100644 --- a/R/LearnerTorchModel.R +++ b/R/LearnerTorchModel.R @@ -112,7 +112,7 @@ LearnerTorchModel = R6Class("LearnerTorchModel", if (is.null(private$.network_stored)) { stopf("No network stored, did you already train learner '%s' or did not specify a model?", self$id) } - network = if (test_clrss(private$.network_stored, "nn_module")) { + network = if (test_class(private$.network_stored, "nn_module")) { # optimization for PipeOpTorchModel, where we control the construction of LearnerTorchModel private$.network_stored } else { From 344e77bbfd21caa25048f42bafa3361f00c85db3 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 7 Feb 2025 12:37:36 +0100 Subject: [PATCH 4/4] document [skip ci] --- man/mlr_learners.mlp.Rd | 2 ++ man/mlr_learners.tab_resnet.Rd | 2 ++ man/mlr_learners.torch_featureless.Rd | 4 +++- man/mlr_learners.torchvision.Rd | 2 ++ man/mlr_learners_torch.Rd | 2 ++ man/mlr_learners_torch_image.Rd | 2 ++ man/mlr_learners_torch_model.Rd | 6 ++---- 7 files changed, 15 insertions(+), 5 deletions(-) diff --git a/man/mlr_learners.mlp.Rd b/man/mlr_learners.mlp.Rd index 7f03f5c7c..362d1a0dd 100644 --- a/man/mlr_learners.mlp.Rd +++ b/man/mlr_learners.mlp.Rd @@ -107,11 +107,13 @@ Other Learner:
Inherited methods