Should make the doc of nn.CrossEntropyLoss() more clear #134853


Closed

hyperkai opened this issue Aug 30, 2024 · 1 comment
Labels
  • module: docs (Related to our documentation, both in docs/ and docblocks)
  • module: loss (Problem is related to loss function)
  • module: nn (Related to torch.nn)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


hyperkai commented Aug 30, 2024

πŸ“š The doc issue

The doc of nn.CrossEntropyLoss() explains the target tensor in a complicated way, as shown below, and it's difficult to understand:

[Screenshot 2024-08-30 195003: the target description from the nn.CrossEntropyLoss() doc]

So, from my understanding and experiments, the simple explanations below should be added to the doc; they are easy to understand:

  • A target tensor whose size is different from the input tensor's is treated as class indices.
  • A target tensor whose size is the same as the input tensor's is treated as class probabilities, which should be in the range [0, 1] (see the sketch after this list).
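A minimal sketch of the two cases (shapes and values chosen just for illustration):

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
input = torch.randn(3, 5)  # (batch, num_classes) raw logits

# Target of shape (3,), i.e. a different size from input:
# treated as class indices (integer dtype required).
target_idx = torch.tensor([1, 0, 4])
print(loss_fn(input, target_idx))

# Target of shape (3, 5), i.e. the same size as input:
# treated as class probabilities; each row should lie in
# [0, 1] and sum to 1.
target_prob = torch.softmax(torch.randn(3, 5), dim=1)
print(loss_fn(input, target_prob))
```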

And, from what the doc says below and from my experiments, when the target tensor is treated as class indices, softmax() is used internally for both the input and target tensors:

The target that this criterion expects should contain either:

  • Class indices in the range ...
    ...
    Note that this case is equivalent to applying LogSoftmax on an input, followed by NLLLoss.

But when the target tensor is treated as class probabilities, softmax() is used internally only for the input tensor. That's why the doc's example with a class-indices target doesn't apply softmax() externally, while its example with a class-probabilities target does, as shown below:

[Screenshot 2024-08-30 201043: the two usage examples from the nn.CrossEntropyLoss() doc]
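For reference, the two examples from the doc read roughly as follows (reconstructed from the nn.CrossEntropyLoss() docs); note that softmax() appears only in the class-probabilities example, and only on the target:

```python
import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()

# Example of target with class indices -- no external softmax() at all.
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()

# Example of target with class probabilities -- softmax() is applied
# externally, and only to the target.
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
output.backward()
```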

So the doc should also say something like the points below (the words class indices mode and class probabilities mode could also be used):

  • softmax() is used internally on the input tensor both when the target tensor is treated as class indices and when it is treated as class probabilities, so you don't need to apply softmax() to the input externally (a quick check follows this list).
  • softmax() is used internally on the target tensor only when the target tensor is treated as class indices, so you should apply softmax() to the target externally when the target tensor is treated as class probabilities.
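The first point can be checked directly: cross_entropy() on raw logits matches nll_loss() on log_softmax() of those logits, which is what the doc's "equivalent to applying LogSoftmax on an input, followed by NLLLoss" note says. A quick sketch, assuming the shapes above:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(3, 5)
target = torch.tensor([0, 4, 2])

# cross_entropy() takes raw logits ...
a = F.cross_entropy(logits, target)
# ... and matches nll_loss() on log_softmax(logits), i.e. the
# (log-)softmax on the input happens internally, not in user code.
b = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(torch.allclose(a, b))  # True
```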

Suggest a potential alternative/fix

No response

cc @svekars @brycebortree @tstatler @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

@jbschlosser added the module: docs, module: nn, module: loss, and triaged labels Aug 30, 2024
mikaylagawarecki (Contributor) commented Aug 30, 2024

I don't think it is the case that softmax is applied internally to the target when the target is class indices. Please correct me if I missed this, though.
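One quick way to check this (a minimal sketch; shapes chosen for illustration): the class-indices loss matches the class-probabilities loss with a one-hot target, which suggests the indices are used directly and no softmax is applied to them:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(3, 5)
idx = torch.tensor([0, 4, 2])

# Class-indices mode ...
a = F.cross_entropy(logits, idx)
# ... matches class-probabilities mode with a one-hot target,
# so no softmax is applied to the target internally.
b = F.cross_entropy(logits, F.one_hot(idx, num_classes=5).float())
print(torch.allclose(a, b))  # True
```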

Secondly, I'm not sure it is the case that you should use softmax externally on the target tensor when the target is treated as class probabilities. The words class probabilities imply a probability distribution; softmax is one way to generate a probability distribution (and probably the most common one, indeed).

That said, I agree that the docs might not be that intuitive, and I'm happy to review attempts to improve them.

@svekars closed this as completed May 1, 2025