Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Re-initialize autograd engine in child processes#4158

Merged
soumith merged 2 commits into
pytorch:masterfrom
colesbury:fork_backwards
Dec 18, 2017
Merged

Re-initialize autograd engine in child processes#4158
soumith merged 2 commits into
pytorch:masterfrom
colesbury:fork_backwards

Conversation

@colesbury
Copy link
Copy Markdown
Member

The autograd engine uses threads for backwards. These don't exist after
forks and they were not being re-initialized because the
Engine::start_threads_flag was already set. This re-initializes the
engine in child processes, which will cause it to re-create threads when
backwards() is called in the child process.

Note that we only attempt to handle the common case where fork() is
called while the backwards threads are idle.

Fixes #3966

Comment thread torch/csrc/autograd/python_engine.cpp Outdated

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

The autograd engine uses threads for backwards. These don't exist after
forks and they were not being re-initialized because the
Engine::start_threads_flag was already set. This re-initializes the
engine in child processes, which will cause it to re-create threads when
backwards() is called in the child process.

Note that we only attempt to handle the common case where fork() is
called while the backwards threads are idle.

Fixes pytorch#3966
static bool _reinitialize_engine = false;

static void _maybe_reinitialize_engine_after_fork() {
// This is "probably" thread-safe because the flag is set in a fork handler

This comment was marked as off-topic.

@soumith soumith merged commit b79d74a into pytorch:master Dec 18, 2017
@colesbury colesbury deleted the fork_backwards branch December 18, 2017 18:05
@soumith soumith added the 0.3.1 label Feb 4, 2018
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
* Re-initialize autograd engine in child processes

The autograd engine uses threads for backwards. These don't exist after
forks and they were not being re-initialized because the
Engine::start_threads_flag was already set. This re-initializes the
engine in child processes, which will cause it to re-create threads when
backwards() is called in the child process.

Note that we only attempt to handle the common case where fork() is
called while the backwards threads are idle.

Fixes pytorch#3966

* Avoid non-async-signal-safe functions in fork handler
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants