Fix precision issue with expansion that prefers 'probs' over 'logits'#18614
ahmadsalim wants to merge 3 commits into
    if 'logits' in self.__dict__:
        new.logits = self.logits.expand(batch_shape)
        new._param = new.logits
    else:
If both parametrizations have already been computed, we should copy both of them rather than recompute them later, so I would suggest changing this to:

    if 'probs' in self.__dict__:
        ...  # expand and set
    if 'logits' in self.__dict__:
        ...  # expand and set

That will also address your original issue, since it will use the initially specified logits for any log pdf computation, in addition to caching the logit-to-prob or prob-to-logit computation if it's already done.
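The suggestion above can be sketched without PyTorch. The following is only an illustration under stated assumptions: `ToyBernoulli` is a hypothetical stand-in class, plain Python lists stand in for tensors, and `__getattr__` stands in for whatever lazy-caching mechanism the real distributions use. It is not the code from this PR.

```python
import math


class ToyBernoulli:
    """Hypothetical stand-in for a distribution with lazily computed parameters."""

    def __init__(self, probs=None, logits=None):
        if (probs is None) == (logits is None):
            raise ValueError("Specify exactly one of probs or logits.")
        if probs is not None:
            self.probs = list(probs)
        else:
            self.logits = list(logits)

    def __getattr__(self, name):
        # Called only when `name` is missing from __dict__:
        # derive it from the other parameterization and cache the result.
        if name == "probs" and "logits" in self.__dict__:
            value = [1.0 / (1.0 + math.exp(-x)) for x in self.__dict__["logits"]]
        elif name == "logits" and "probs" in self.__dict__:
            value = [math.log(p) - math.log1p(-p) for p in self.__dict__["probs"]]
        else:
            raise AttributeError(name)
        self.__dict__[name] = value
        return value

    def expand(self, repeats):
        """Return a copy whose parameters are repeated `repeats` times."""
        new = ToyBernoulli.__new__(ToyBernoulli)
        # Copy every parameterization that is already cached,
        # instead of choosing one and recomputing the other later.
        if "probs" in self.__dict__:
            new.probs = self.probs * repeats
        if "logits" in self.__dict__:
            new.logits = self.logits * repeats
        return new


dist = ToyBernoulli(logits=[40.0])
_ = dist.probs  # triggers and caches the logits -> probs conversion
expanded = dist.expand(3)
```

With two independent `if` checks, `expanded` carries the original logits as well as the cached probabilities, so nothing has to be reconstructed from the less precise `probs`.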
neerajprad
left a comment
Thanks for the PR! Could you also change this for the remaining distributions - binomial, multinomial, relaxed_bernoulli, negative_binomial, geometric?
Thanks for the comments! I will make the adjustments you suggested for all the distributions 😄

LGTM! The lint check failures are not from this PR.

The CI is broken today, but I am putting it in the land queue.
soumith
left a comment
@pytorchbot merge this please
facebook-github-bot
left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…pytorch#18614)
Summary: I have experienced that sometimes both were in `__dict__`, but it chose to copy `probs` which loses precision over `logits`. This is especially important when training (bayesian) neural networks or doing other type of optimization, since the loss is heavily affected.
Pull Request resolved: pytorch#18614
Differential Revision: D14793486
Pulled By: ezyang
fbshipit-source-id: d4ff5e34fbb4021ea9de9f58af09a7de00d80a63
I have experienced that sometimes both were in `__dict__`, but it chose to copy `probs`, which loses precision compared to `logits`. This is especially important when training (Bayesian) neural networks or doing other types of optimization, since the loss is heavily affected.
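To make the precision loss concrete, here is a small pure-Python sketch using only the standard `math` module. It is an illustration, not the library's implementation: the helper names are hypothetical, and clamping the probability by machine epsilon before inverting the sigmoid is an assumption about how such a conversion typically avoids `log(0)`.

```python
import math
import sys

EPS = sys.float_info.epsilon  # float64 machine epsilon, about 2.2e-16


def logits_to_probs(logit):
    """Sigmoid: map a log-odds value to a probability."""
    return 1.0 / (1.0 + math.exp(-logit))


def probs_to_logits(prob):
    """Inverse sigmoid, clamping prob away from 0 and 1 to avoid log(0)."""
    clamped = min(max(prob, EPS), 1.0 - EPS)
    return math.log(clamped) - math.log1p(-clamped)


def log_prob_zero(logit):
    """Bernoulli log-probability of outcome 0, i.e. -softplus(logit), computed stably."""
    return -(max(logit, 0.0) + math.log1p(math.exp(-abs(logit))))


original_logit = 40.0
prob = logits_to_probs(original_logit)   # rounds to exactly 1.0 in float64
recovered_logit = probs_to_logits(prob)  # limited by the clamp, roughly 36

exact_loss = -log_prob_zero(original_logit)       # uses the original logits
roundtrip_loss = -log_prob_zero(recovered_logit)  # uses logits rebuilt from probs

print(original_logit, recovered_logit)
print(exact_loss, roundtrip_loss)
```

Once the probability saturates to 1.0, the original logit can no longer be recovered, so a loss computed from the round-tripped value differs from the one computed directly from the logits by several nats in this example.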