pow scalar exponent / base autodiff, fusion#19324
Conversation
|
So one of the problems seems to be from Scalar<->float magic in symbolic_script, which seems to break when Scalar is an int. Unfortunately, the only thing I could think of is a gross hack adding torch._float to make a float from int/float IValues passed into a scalar. |
| exponent: float): | ||
| def backward(grad_output): | ||
| grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1)) | ||
| if torch._float(exponent) == 0.0: |
There was a problem hiding this comment.
What's preventing float(exponent) instead?
There was a problem hiding this comment.
So what I think is going on: When creating the pow_0 operator for Tensor self, Scalar exponent, symbolic_script replaces Scalar by float. The float(exponent) cast gets eliminated because the JIT "knows" it is a float. What then happens is that unpacking Scalar IValues that the JIT thinks must be float but that are, in fact, int fails.
In a way it comes down to
// 2. to make sure the input of any graph node does not contain scalar type
// in its argument, all scalar arg should already be passed with float
// value since scalar/int aren't differentiable either way.
not being the complete picture because (as here) the scalar might not be the thing we want to differentiate for, but a parameter (in the mathematical sense as opposed to the variable) of the function we want to differentiate.
So I'm changing the patch to do the following:
If Scalar -> float conversion happened, I change back the input type of the graph to Scalar, and insert a conversion (prim::Float) as the first thing.
It'll be troublesome for use when we get operations that actually rely on the difference between float and int in Scalar ops, but currently we don't as far as I know.
I think this is a clean-up of the Scalar->float conversion as it ensures that the graph inputs actually match the schema. @ailzhang does that seem reasonable?
There was a problem hiding this comment.
So it turns out that re-Scalarizing breaks something (the double backward?). 😓
There was a problem hiding this comment.
cc: @wanchaol added the scalar to float conversion.
There was a problem hiding this comment.
Also, the scalar to float conversion only happens on the second pass, when allow_conversions is true, if there is an op defined for Scalar than that will be matched and there won't be a conversion.
There was a problem hiding this comment.
@t-vi Actually as @eellison pointed out offline, we actually can get rid of the scalar -> float conversion entirely.
For example this works for me.
def pow_0(self,
exponent: number):
Note that we don't expose number in torchscript but we CAN compile it! :D With this we can easily get rid of current symbolic_variable.h and c10::ReplaceAll(schema_str, "Scalar", "float");.
Huge thanks to @eellison who pointed this out!
Let us know if this fixes your problem :D
There was a problem hiding this comment.
Ha. That works (I think)! And I've been poking around much too long. Awesome.
Thanks so much @ailzhang and @eellison !
So my understanding is that with this, one would only have definitions that match the operator schemas. Could we actually check that? For now I left a fallback in but converted those replacements that were hit by a test_jit.py run.
|
|
||
| auto sym_script_it = schema_to_graphs.find(schema_str); | ||
|
|
||
| if (sym_script_it == schema_to_graphs.end()) { |
There was a problem hiding this comment.
Yea this if should be dropped before merging.
There was a problem hiding this comment.
So I removed it after double checking by means of calling sig(schema_string) and seeing whether it throws here:
pytorch/torch/csrc/jit/symbolic_script.cpp
Lines 1351 to 1353 in 3e0b46b
ailzhang
left a comment
There was a problem hiding this comment.
LGTM! Let me know when you are done with changing and want to merge it.
|
I think its good to merge.
Am 18. April 2019 20:27:33 MESZ schrieb Ailing <[email protected]>:
…ailzhang approved this pull request.
LGTM! Let me know when you are done with changing and want to merge it.
--
You are receiving this because you were mentioned.
Reply to this email directly or view it on GitHub:
#19324 (review)
|
facebook-github-bot
left a comment
There was a problem hiding this comment.
@ailzhang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Fixes: pytorch#19253 Fixing pow(Tensor, float) is straightforward. The breakage for pow(float, Tensor) is a bit more subtle to trigger, and fixing needs `torch.log` (`math.log` didn't work) from the newly merged pytorch#19115 (Thanks ngimel for pointing out this has landed.) Pull Request resolved: pytorch#19324 Differential Revision: D15003531 Pulled By: ailzhang fbshipit-source-id: 8b22138fa27a43806b82886fb3a7b557bbb5a865
Summary: Fixes: pytorch#19253 Fixing pow(Tensor, float) is straightforward. The breakage for pow(float, Tensor) is a bit more subtle to trigger, and fixing needs `torch.log` (`math.log` didn't work) from the newly merged pytorch#19115 (Thanks ngimel for pointing out this has landed.) Pull Request resolved: pytorch#19324 Differential Revision: D15003531 Pulled By: ailzhang fbshipit-source-id: 8b22138fa27a43806b82886fb3a7b557bbb5a865
Fixes: #19253
Fixing pow(Tensor, float) is straightforward.
The breakage for pow(float, Tensor) is a bit more subtle to trigger, and fixing needs
torch.log(math.logdidn't work) from the newly merged #19115 (Thanks @ngimel for pointing out this has landed.)