
Merged
7 changes: 5 additions & 2 deletions NEWS.md
@@ -1,5 +1,8 @@
# mlr3torch dev

+* fix: `LearnerTorchModel` can now be parallelized and trained with
+encapsulation activated.
+
# mlr3torch 0.2.0

## Breaking Changes
@@ -15,9 +18,9 @@

* Optimizers now use the faster ('ignite') version of the optimizers,
which leads to considerable speed improvements.
* The `jit_trace` parameter was added to `LearnerTorch`, which when set to
`TRUE` can lead to significant speedups.
This should only be enabled for 'static' models, see the
[torch tutorial](https://torch.mlverse.org/docs/articles/torchscript)
for more information.
* Added parameter `num_interop_threads` to `LearnerTorch`.
20 changes: 8 additions & 12 deletions R/LearnerTorchModel.R
@@ -60,8 +60,8 @@ LearnerTorchModel = R6Class("LearnerTorchModel",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(network = NULL, ingress_tokens = NULL, task_type, properties = NULL, optimizer = NULL, loss = NULL,
callbacks = list(), packages = character(0), feature_types = NULL) {
-# TODO: What about the learner properties?
-if (!is.null(network)) self$network_stored = network
+# we need to serialize here as otherwise encapsulation and parallelization fails
+if (!is.null(network)) private$.network_stored = torch_serialize(assert_class(network, "nn_module"))
if (!is.null(ingress_tokens)) self$ingress_tokens = ingress_tokens
if (is.null(feature_types)) {
feature_types = unname(mlr_reflections$task_feature_types)
@@ -89,15 +89,6 @@ LearnerTorchModel = R6Class("LearnerTorchModel",
}
),
active = list(
-#' @field network_stored (`nn_module` or `NULL`)\cr
-#' The network that will be trained.
-#' After calling `$train()`, this is `NULL`.
-network_stored = function(rhs) {
-if (!missing(rhs)) {
-private$.network_stored = assert_class(rhs, "nn_module")
-}
-private$.network_stored
-},
#' @field ingress_tokens (named `list()` with `TorchIngressToken` or `NULL`)\cr
#' The ingress tokens. Must be non-`NULL` when calling `$train()`.
ingress_tokens = function(rhs) {
@@ -121,7 +112,12 @@ LearnerTorchModel = R6Class("LearnerTorchModel",
if (is.null(private$.network_stored)) {
stopf("No network stored, did you already train learner '%s' or did not specify a model?", self$id)
}
-network = private$.network_stored
+network = if (test_class(private$.network_stored, "nn_module")) {
+# optimization for PipeOpTorchModel, where we control the construction of LearnerTorchModel
+private$.network_stored
+} else {
+torch_load(private$.network_stored)
+}
private$.network_stored = NULL
network
},
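The change above stores the network as serialized bytes in the constructor and restores it only once, when the getter is first called, clearing the slot afterwards. A torch module wraps external pointers to C++ objects, which base R serialization does not preserve, which is presumably why a learner holding a live `nn_module` broke when shipped to another process. The following is a minimal base-R sketch of that control flow only, under stated assumptions: `serialize()`/`unserialize()` stand in for `torch_serialize()`/`torch_load()`, an `is.raw()` check replaces `test_class(..., "nn_module")`, and `make_holder`, `set_live` and `take` are hypothetical names, not part of mlr3torch.

```r
# Hypothetical sketch of the storage pattern used by LearnerTorchModel.
# Base R serialize()/unserialize() stand in for torch_serialize()/torch_load().
make_holder = function(network = NULL) {
  store = new.env()
  # Keep raw bytes so the holder can be copied to another process (e.g. a worker).
  store$network_stored = if (!is.null(network)) serialize(network, NULL) else NULL
  list(
    # Skip the round-trip when the caller knows training stays in this process
    # (the role of get_private(...) = network in PipeOpTorchModel).
    set_live = function(x) {
      store$network_stored = x
      invisible(NULL)
    },
    # Hand out the network once: restore it from bytes if needed, then clear the slot.
    take = function() {
      x = store$network_stored
      if (is.null(x)) stop("No network stored")
      network = if (is.raw(x)) unserialize(x) else x
      store$network_stored = NULL
      network
    }
  )
}

holder = make_holder(list(weights = 1:3))
str(holder$take())
```

Because the slot is cleared after the first `take()`, a second call fails, matching the "No network stored" error the learner raises when it has already been trained.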
4 changes: 3 additions & 1 deletion R/PipeOpTorchModel.R
@@ -69,7 +69,9 @@ PipeOpTorchModel = R6Class("PipeOpTorchModel",
output_pointers = list(md$pointer),
list_output = FALSE
)
-private$.learner$network_stored = network
+# Because we control the creation of the LearnerTorchModel, we know that it's fitted in the same
+# process as the current .train function, hence, we can avoid the serialization round-trip
+get_private(private$.learner, ".network_stored") = network
private$.learner$ingress_tokens = md$ingress

if (is.null(md$loss)) {
2 changes: 2 additions & 0 deletions man/mlr_learners.mlp.Rd

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions man/mlr_learners.tab_resnet.Rd

4 changes: 3 additions & 1 deletion man/mlr_learners.torch_featureless.Rd

2 changes: 2 additions & 0 deletions man/mlr_learners.torchvision.Rd

2 changes: 2 additions & 0 deletions man/mlr_learners_torch.Rd

2 changes: 2 additions & 0 deletions man/mlr_learners_torch_image.Rd
