Throwing more specific errors for CrossEntropyLoss weights being on a different device than the input/target #122757


Closed
saurabhmahra91 opened this issue Mar 27, 2024 · 0 comments
Labels
module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


saurabhmahra91 commented Mar 27, 2024

🚀 The feature, motivation and pitch

While computing CrossEntropyLoss, both the model_output and target were on the same device (cuda), but the weights were on the cpu:

# model_output and labels live on cuda; weights was accidentally left on the cpu
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fn(input=model_output, target=labels)

Currently, we would get the following:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

But how? Weren't both the input and target for the loss_fn on the same device (cuda)?
Answer: because someone forgot to put the weights on cuda.

This problem becomes significant when the loss_fn is defined somewhere else rather than just before computing the loss, since the device mismatch is then not so apparent (a typical workaround is sketched below).
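A minimal sketch of the usual workaround, assuming a CUDA device is available; the shapes and variable setup here are illustrative, not from the original report:

```python
import torch

# Illustrative setup: inputs on cuda, class weights accidentally on the cpu.
model_output = torch.randn(8, 3, device="cuda")
labels = torch.randint(0, 3, (8,), device="cuda")
weights = torch.rand(3)  # created on the cpu by default

# Move the weights to the same device as the inputs before building the loss.
weights = weights.to(model_output.device)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fn(input=model_output, target=labels)
```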

Can we please have an explicit error stating that the weights (an instance attribute) are on a different device than the input and target (function arguments)?
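To make the request concrete, here is a hypothetical sketch of such a check; the helper name and message are illustrative, not PyTorch's actual implementation:

```python
# Hypothetical helper, for illustration only: name the offending argument
# and its device instead of a generic "found at least two devices" message.
def _check_weight_device(input, target, weight):
    if weight is not None and weight.device != input.device:
        raise RuntimeError(
            f"Expected all tensors to be on the same device, but got weight "
            f"on {weight.device}, different from input on {input.device}"
        )
```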

Alternatives

No response

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

@cpuhrsch cpuhrsch added module: nn Related to torch.nn triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Mar 28, 2024
@github-project-automation github-project-automation bot moved this to To pick up in torch.nn/optim Apr 4, 2024
@github-project-automation github-project-automation bot moved this from To pick up to Done in torch.nn/optim May 6, 2025
pytorchmergebot pushed a commit that referenced this issue May 8, 2025
Fixes #122757

## Test Result

```python
import torch

model_output = torch.randn(10, 5).cuda()
labels = torch.randint(0, 5, (10,)).cuda()
weights = torch.randn(5)

loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fn(input=model_output, target=labels)
print(loss)
```

```
Traceback (most recent call last):
  File "/home/zong/code/pytorch/../loss2.py", line 17, in <module>
    loss = loss_fn(input=model_output, target=labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/modules/loss.py", line 1297, in forward
    return F.cross_entropy(
           ^^^^^^^^^^^^^^^^
  File "/home/zong/code/pytorch/torch/nn/functional.py", line 3494, in cross_entropy
    return torch._C._nn.cross_entropy_loss(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but got weight is on cpu, different from other tensors on cuda:0 (when checking argument in method wrapper_CUDA_nll_loss_forward)
```
Pull Request resolved: #150750
Approved by: https://github.com/malfet