Throwing more specific errors for CrossEntropyLoss weights being on a different device than the input/target #122757
Labels
module: nn
Related to torch.nn
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🚀 The feature, motivation and pitch
While calculating CrossEntropyLoss, both the model_output and target were on the same device (cuda), but the weights were on the CPU:
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fn(input=model_output, target=labels)
Currently, we would get the following:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
But how? Weren't both the input and target for the loss_fn on the same device (cuda)?
Answer: because someone forgot to put the weights on cuda.
This problem becomes much harder to diagnose if the loss_fn is defined somewhere else rather than just before the loss is computed, since the mismatched tensor is then not apparent from the error site.
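For reference, a minimal sketch of the workaround on the user's side: moving the weight tensor to the input's device before (or while) constructing the loss. The tensor shapes and values below are hypothetical stand-ins for the ones in the report; on a GPU machine model_output and labels would live on cuda while weights starts on the CPU.

```python
import torch

# Hypothetical stand-ins for the tensors in the report.
model_output = torch.randn(4, 3)       # (batch, num_classes) logits
labels = torch.tensor([0, 2, 1, 0])    # class indices
weights = torch.tensor([1.0, 2.0, 0.5])

# Moving the weights to the input's device avoids the mismatch,
# whatever device the model actually runs on.
loss_fn = torch.nn.CrossEntropyLoss(weight=weights.to(model_output.device))
loss = loss_fn(input=model_output, target=labels)
```

This keeps working unchanged whether the model runs on CPU or GPU, which is why `.to(model_output.device)` is preferable to hard-coding `"cuda"`.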
Can we please have an explicit error stating that the weights (an instance attribute) are on a different device than the input and target (function arguments)?
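A sketch of what such a check could look like. `check_weight_device` is a hypothetical helper, not an existing PyTorch function; the wording of the error message is illustrative only.

```python
import torch

def check_weight_device(weight, input):
    # Hypothetical helper: raise a descriptive error when the class
    # weights live on a different device than the loss input.
    if weight is not None and weight.device != input.device:
        raise RuntimeError(
            f"CrossEntropyLoss weight is on {weight.device}, but input is on "
            f"{input.device}. Move the weight with weight.to(input.device)."
        )
```

Called at the top of the loss's forward, this would turn the generic "Expected all tensors to be on the same device" message into one that names the weight tensor explicitly.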
Alternatives
No response
Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki