Thanks to visit codestin.com
Credit goes to github.com

Skip to content

fix segfault when grad to a hook fn is None#12028

Closed
weiyangfb wants to merge 1 commit into
pytorch:masterfrom
weiyangfb:register_hook
Closed

fix segfault when grad to a hook fn is None#12028
weiyangfb wants to merge 1 commit into
pytorch:masterfrom
weiyangfb:register_hook

Conversation

@weiyangfb
Copy link
Copy Markdown
Contributor

@weiyangfb weiyangfb commented Sep 24, 2018

pre-fix

>>> a = torch.randn(5, requires_grad=True)
>>> a_list = a.unbind()

>>> a0 = a_list[0]
>>> @a0.register_hook
...:def hook(grad):
...:    print(grad)

>>> a_list[0].backward()
tensor(1.)

>>> print('a_list[0]', a_list[0].grad, a.grad)
('a_list[0]', None, tensor([1., 0., 0., 0., 0.]))

>>> a_list[1].backward() # segfault

post-fix

>>> a = torch.randn(5, requires_grad=True)
>>> a_list = a.unbind()

>>> a0 = a_list[0]
>>> @a0.register_hook
... :def hook(grad):
... :    print(grad)

>>> a_list[0].backward()
tensor(1.)

>>> print(a_list[0].grad, a.grad)
(None, tensor([1., 0., 0., 0., 0.]))

>>> a_list[1].backward()
None

>>> print(a_list[1].grad, a.grad)
(None, tensor([1., 1., 0., 0., 0.]))

@weiyangfb
Copy link
Copy Markdown
Contributor Author

@ezyang @apaszke Can I get a review on this? Not sure if this is a correct / complete fix.

@ezyang
Copy link
Copy Markdown
Contributor

ezyang commented Sep 25, 2018

Looks reasonable enough. Are there any other places where you think we might have made a similar mistake?

@weiyangfb
Copy link
Copy Markdown
Contributor Author

@ezyang Let me double check on files under torch/csrc/autograd

@weiyangfb
Copy link
Copy Markdown
Contributor Author

Other places look good to me

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Comment thread torch/csrc/autograd/python_hook.cpp Outdated

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

petrex pushed a commit to petrex/pytorch that referenced this pull request Sep 26, 2018
* upstream/master: (117 commits)
  Add full namespace resolution in CAFFE_DURATION (pytorch#12065)
  T33898723: Simple put operators for caffe2 stats (pytorch#12057)
  add narrow() support for sparse tensors re: pytorch#8853 (pytorch#11342)
  Fix ONNX bug, add symbolic for full
  Enable tracing of tensor factories with an out argument
  Fix warnings emitted when testing distributions (pytorch#12038)
  Unify versions across setup.py, libtorch, and libcaffe2 (pytorch#12053)
  add autodiff expressions for common operations (pytorch#11832)
  Blob doesn't allow access to destroyCall anymore (pytorch#11548)
  IValue can store Blob (pytorch#11414)
  Move Blob to ATen/core (pytorch#11924)
  Use tempfile during serialized test comparison (pytorch#12021)
  fix segfault when grad to a hook fn is None (pytorch#12028)
  Fallback CreateMutex/AtomicIter operators for mkl-dnn
  Unify all *_EXPORT and *_IMPORT macros across c++ backend (pytorch#12019)
  Add safety asserts for methods on TensorImpl which don't work on Variable. (pytorch#12058)
  Make USE_IDEEP work again (pytorch#12026)
  Fix "identifier following the 'template' keyword does not refer to a template" (pytorch#12037)
  Delete some unused variables. (pytorch#12059)
  Support TypeIdentifier::name() (pytorch#12036)
  ...
@ezyang ezyang added the merged label Jun 26, 2019
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
- fixes pytorch#11751 by checking if a grad is a Python None object before getting cdata from it
- behaviors:

pre-fix
```
>>> a = torch.randn(5, requires_grad=True)
>>> a_list = a.unbind()

>>> a0 = a_list[0]
>>> a0.register_hook
...:    def hook(grad):
...:        print(grad)

>>> a_list[0].backward()
tensor(1.)

>>> print('a_list[0]', a_list[0].grad, a.grad)
('a_list[0]', None, tensor([1., 0., 0., 0., 0.]))

>>> a_list[1].backward() # segfault
```

post-fix
```
>>> a = torch.randn(5, requires_grad=True)
>>> a_list = a.unbind()

>>> a0 = a_list[0]
>>> a0.register_hook
... :   def hook(grad):
... :       print(grad)

>>> a_list[0].backward()
tensor(1.)

>>> print(a_list[0].grad, a.grad)
(None, tensor([1., 0., 0., 0., 0.]))

>>> a_list[1].backward()
None

>>> print(a_list[1].grad, a.grad)
(None, tensor([1., 1., 0., 0., 0.]))
```
Pull Request resolved: pytorch#12028

Differential Revision: D10034094

Pulled By: weiyangfb

fbshipit-source-id: 3f2135325fa7d338b920f57752057e4f6a6c0b1d
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Segfault in autograd using hook

4 participants