Thanks to visit codestin.com
Credit goes to github.com

Skip to content

support broadcasting in _kl_categorical_categorical#10533

Closed
joh4n wants to merge 5 commits into
pytorch:masterfrom
joh4n:broadcast_in_kl_categorical_categorical
Closed

support broadcasting in _kl_categorical_categorical#10533
joh4n wants to merge 5 commits into
pytorch:masterfrom
joh4n:broadcast_in_kl_categorical_categorical

Conversation

@joh4n
Copy link
Copy Markdown
Contributor

@joh4n joh4n commented Aug 15, 2018

Support broadcasting in _kl_categorical_categorical

this makes it possible to do:

import torch.distributions as dist
import torch
p_dist = dist.Categorical(torch.ones(1,10))
q_dist = dist.Categorical(torch.ones(100,10))
dist.kl_divergence(p_dist, q_dist)

@vishwakftw
Copy link
Copy Markdown
Contributor

Could you please add a test in test_distributions.py?

@joh4n
Copy link
Copy Markdown
Contributor Author

joh4n commented Aug 15, 2018

I added a basic test for it

Comment thread test/test_distributions.py Outdated
(binomial30, binomial30),
(binomial_vectorized_count, binomial_vectorized_count),
(categorical, categorical),
(Categorical(torch.ones(1, 10)), Categorical(torch.ones(3, 10))),

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Copy Markdown
Contributor

@vishwakftw vishwakftw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you for fixing this.

Copy link
Copy Markdown
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Support broadcasting in _kl_categorical_categorical

this makes it possible to do:
```
import torch.distributions as dist
import torch
p_dist = dist.Categorical(torch.ones(1,10))
q_dist = dist.Categorical(torch.ones(100,10))
dist.kl_divergence(p_dist, q_dist)
```
Pull Request resolved: pytorch#10533

Differential Revision: D9341252

Pulled By: soumith

fbshipit-source-id: 34575b30160b43b6c9e4c3070dd7ef07c00ff5d7
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants