fix tensordot #72139
Your PR was submitted successfully. Thank you for contributing to the open-source project!
+         y = y.sum(dim_y, dtype=y.dtype).reshape(shape_y)
      elif sy == 1:
          shape_x[dim_x] = 1
-         x = x.sum(dim_x).reshape(shape_x)
+         x = x.sum(dim_x, dtype=x.dtype).reshape(shape_x)
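In the `sy == 1` branch, `x` is summed over `dim_x` and reshaped so that axis has size 1 before the contraction. Assuming the size-1 axis is meant to broadcast against the other operand, the pre-sum is valid because, for each output entry, sum_t x[i][t] * y[0][j] = (sum_t x[i][t]) * y[0][j]. The sketch below checks this identity with plain nested lists for a 2-D `x` and a 1 x n `y`; the function names are hypothetical and this is not Paddle's implementation.

```python
# Hypothetical 2-D illustration (not Paddle code): contract axis 1 of x
# (size k) with axis 0 of y (size 1, broadcast to k).

def tensordot_broadcast(x, y):
    """Direct contraction: y's size-1 axis is broadcast along x's axis."""
    m, k, n = len(x), len(x[0]), len(y[0])
    return [[sum(x[i][t] * y[0][j] for t in range(k)) for j in range(n)]
            for i in range(m)]

def tensordot_presummed(x, y):
    """Pre-sum x over the contracted axis, keeping it as size 1 (m x 1)."""
    x_sum = [[sum(row)] for row in x]
    # Both contracted axes now have size 1, so each entry is a single product.
    return [[x_sum[i][0] * y[0][j] for j in range(len(y[0]))]
            for i in range(len(x_sum))]

x = [[1, 2, 3], [4, 5, 6]]
y = [[10, -1]]
print(tensordot_broadcast(x, y))  # [[60, -6], [150, -15]]
print(tensordot_presummed(x, y))  # [[60, -6], [150, -15]]
```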
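Should the intermediate results here be stored as int64 with only the final result cast to int32, or should the intermediate results also stay int32? If the input contains values close to the int32 maximum, which approach is more appropriate? It would be worth comparing how numpy or torch handle numeric overflow in this scenario.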
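The intent of this change is to explicitly make the output dtype of the sum operation match its input dtype. Overflow is something the developer should handle; the framework should guarantee correctness (with no dtype specified, an int32 input yields an int32 output).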
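On the overflow question: assuming integer overflow wraps around in two's complement (typical hardware behavior, though C++ does not guarantee it for signed types), keeping int32 intermediates and accumulating in int64 then casting once to int32 give the same int32 result, as long as the int64 accumulation itself does not overflow. Addition modulo 2**32 does not depend on when the wrap happens. What actually differs near the int32 maximum is whether the output is int64 (exact) or int32 (wrapped). The sketch emulates both with plain Python ints; the helper names are hypothetical.

```python
# Hypothetical helpers; assumes two's-complement wrap-around on overflow.
INT32_MAX = 2**31 - 1

def wrap_int32(v):
    """Map an arbitrary Python int onto the int32 range, wrapping modulo 2**32."""
    return (v + 2**31) % 2**32 - 2**31

def sum_int32_accumulator(values):
    """Keep every intermediate result in int32 (wrap after each add)."""
    acc = 0
    for v in values:
        acc = wrap_int32(acc + v)
    return acc

def sum_int64_then_cast(values):
    """Accumulate exactly (a stand-in for int64) and cast to int32 once."""
    return wrap_int32(sum(values))

values = [INT32_MAX, 1, 5]
print(sum(values))                    # 2147483653 (what an int64 output would hold)
print(sum_int32_accumulator(values))  # -2147483643
print(sum_int64_then_cast(values))    # -2147483643
```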
HydrogenSulfate left a comment
LGTM
* update
* add test
PR Category
User Experience
PR Types
Bug fixes
Description
When both inputs are int32, the output is int64.
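One way to see how this symptom can follow from the pre-sum step: if a reduction sum with no `dtype` widens int32 to int64 (an assumed NumPy-like rule; `numpy.sum` sums small signed integer types in the platform integer), then the pre-summed operand becomes int64 and combining it with the int32 operand promotes the result to int64. Passing `dtype=x.dtype` keeps it int32. The model below uses strings for dtypes; `sum_dtype`, `promote`, and `tensordot_dtype` are hypothetical and do not describe Paddle's actual type-promotion code.

```python
# Hypothetical model, not Paddle's real code: dtypes are plain strings.
WIDENED_BY_SUM = {"bool", "int8", "int16", "int32"}  # assumed NumPy-like rule
RANK = {"bool": 0, "int8": 1, "int16": 2, "int32": 3, "int64": 4}

def sum_dtype(in_dtype, dtype=None):
    """Output dtype of a reduction sum under the assumed rule."""
    if dtype is not None:
        return dtype  # an explicit dtype is used as-is
    return "int64" if in_dtype in WIDENED_BY_SUM else in_dtype

def promote(a, b):
    """Binary-op result dtype: the wider of the two integer types."""
    return a if RANK[a] >= RANK[b] else b

def tensordot_dtype(x_dtype, y_dtype, explicit_dtype):
    # Mirrors the sy == 1 branch: x is pre-summed, then combined with y.
    x_after = sum_dtype(x_dtype, x_dtype if explicit_dtype else None)
    return promote(x_after, y_dtype)

print(tensordot_dtype("int32", "int32", explicit_dtype=False))  # int64 (before the fix)
print(tensordot_dtype("int32", "int32", explicit_dtype=True))   # int32 (after the fix)
```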