dtype option for softmax#11719
Conversation
ssnl
left a comment
There was a problem hiding this comment.
Didn't review the kernels. But how about also adding the option to cross entropy loss? :)
|
cross_entropy calls soft_max https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L1645 so it would require couple line python change. |
|
Yep, but you are already changing |
|
@ssnl, Yes, I can do that. but I'm not sure what's the preferred fix should be. FWIW, some operators already have overloads that are not supported from python, e.g. so log_softmax erroring out with similar message is not necessarily a big problem (?) . |
|
cc @apaszke @jamesr66a on JIT test |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@ngimel Looking at this more closely I would advise updating the error message here. |
In the jit tests on in the |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
3e039a6 to
94c6400
Compare
apaszke
left a comment
There was a problem hiding this comment.
I'm not super happy with the upconvert flag. It doesn't really specify the destination type. Should it be float? Should it be double? The context is probably dependent on the device, and this seems to overfit the CUDA context. Can't we apply a simple modification to our kernels, or simply have a _log_softmax_half_to_float implemented only for CUDA, and dispatch to _log_softmax(...).to(dtype) otherwise?
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
That's precisely the problem. It's a very specific flag, with a very specific meaning, which it not at all implied by its name/function name/function signature. I don't understand why the dispatch would be a problem. Can't you just declare the derivatives for the top-level native function |
|
If derivatives are declared for |
56a9546 to
fe802a7
Compare
facebook-github-bot
left a comment
There was a problem hiding this comment.
apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
Test failures seem unrelated. |
Summary: Add dtype argument to softmax/log_softmax functions. Computing softmax in fp32 precision is necessary for mixed precision training, and converting output of the previous layer into fp32 and then reading it as fp32 in softmax is expensive, memory and perf-wise, this PR allows one to avoid it. For most input data/dtype combinations, input data is converted to dtype and then softmax is computed. If input data is half type and dtype is fp32, kernels with the corresponding template arguments are called. Pull Request resolved: pytorch/pytorch#11719 Reviewed By: ezyang Differential Revision: D10175514 Pulled By: zou3519 fbshipit-source-id: 06d285af91a0b659932236d41ad63b787eeed243
Summary: Add dtype argument to softmax/log_softmax functions. Computing softmax in fp32 precision is necessary for mixed precision training, and converting output of the previous layer into fp32 and then reading it as fp32 in softmax is expensive, memory and perf-wise, this PR allows one to avoid it. For most input data/dtype combinations, input data is converted to dtype and then softmax is computed. If input data is half type and dtype is fp32, kernels with the corresponding template arguments are called. Pull Request resolved: pytorch#11719 Reviewed By: ezyang Differential Revision: D10175514 Pulled By: zou3519 fbshipit-source-id: 06d285af91a0b659932236d41ad63b787eeed243
Add dtype argument to softmax/log_softmax functions.
Computing softmax in fp32 precision is necessary for mixed precision training, and converting output of the previous layer into fp32 and then reading it as fp32 in softmax is expensive, memory and perf-wise, this PR allows one to avoid it.
For most input data/dtype combinations, input data is converted to dtype and then softmax is computed. If input data is half type and dtype is fp32, kernels with the corresponding template arguments are called.