Gradient can be backpropagated through only certain distributions #152703
Labels

- module: autograd (Related to torch.autograd, and the autograd engine in general)
- module: distributions (Related to torch.distributions)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
🐛 Describe the bug
Using Normal, I can avoid having to preserve gradients:
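The original snippet was not preserved in this report; below is a minimal sketch of the working case, assuming the usual reparameterized setup (the names loc, scale and the quadratic loss are illustrative, not from the issue):

```python
import torch
from torch.distributions import Normal

# Hypothetical learnable parameters; rsample() uses the reparameterization
# trick, so gradients flow from the samples back to loc and scale.
loc = torch.zeros(3, requires_grad=True)
scale = torch.ones(3, requires_grad=True)

samples = Normal(loc, scale).rsample((10,))  # differentiable w.r.t. loc and scale
loss = samples.pow(2).mean()
loss.backward()

print(loc.grad is not None, scale.grad is not None)  # True True
```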
Using Categorical, this is not the case:
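Again, the original snippet is missing; a rough sketch of the failure mode, assuming a Categorical analogue of the code above (the probs tensor is illustrative):

```python
import torch
from torch.distributions import Categorical

probs = torch.tensor([0.1, 0.2, 0.7], requires_grad=True)
dist = Categorical(probs=probs)

samples = dist.sample((10,))  # integer indices, detached from the autograd graph
print(samples.requires_grad)  # False -- no path back to probs
# dist.rsample((10,)) raises NotImplementedError: Categorical has no
# reparameterized sampler, so gradients cannot flow through sample().
```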
I'm not sure why this is the case; the log-probs could be computed as something like

torch.log(probs[one_hot(samples)].mean(axis=-1))

and that is completely differentiable, so it should be possible to do this without gradient support in the distribution's sampling step itself.
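A sketch of the differentiable path this alludes to: Categorical.log_prob() is a differentiable function of probs even though sample() is not, which is what score-function (REINFORCE) style estimators rely on (the probs tensor below is illustrative):

```python
import torch
from torch.distributions import Categorical

probs = torch.tensor([0.1, 0.2, 0.7], requires_grad=True)
dist = Categorical(probs=probs)

samples = dist.sample((10,))        # non-differentiable integer samples
log_probs = dist.log_prob(samples)  # differentiable w.r.t. probs
log_probs.mean().backward()

print(probs.grad)                   # gradient is populated
```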
Versions

cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @albanD @gqchen @soulitzer @Varal7 @xmfan