Description
According to the standard, the documentation of `sum` states for the `dtype` parameter:

> If `None`, the returned array must have the same data type as `x`, unless `x` has an integer data type supporting a smaller range of values than the default integer data type... In those latter cases: ... if `x` has an unsigned integer data type, the returned array must have an unsigned integer data type having the same number of bits as the default integer data type.
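
To check my reading of the rule, here is a small pure-Python sketch of the `dtype=None` case, with dtypes represented by their names as strings. The function name `default_sum_dtype` is hypothetical. It only covers the integer cases quoted above and returns every other dtype unchanged. The default integer dtype is assumed to be `int64`, as in PyTorch.

```python
# Hypothetical sketch of the array API rule for the result dtype of
# sum(x, dtype=None). Dtypes are represented by their names as strings.
SIGNED_BITS = {"int8": 8, "int16": 16, "int32": 32, "int64": 64}
UNSIGNED_BITS = {"uint8": 8, "uint16": 16, "uint32": 32, "uint64": 64}


def default_sum_dtype(x_dtype: str, default_int: str = "int64") -> str:
    """Return the dtype the standard requires for sum(x) when dtype is None."""
    default_bits = SIGNED_BITS[default_int]
    if x_dtype in SIGNED_BITS and SIGNED_BITS[x_dtype] < default_bits:
        # Smaller signed integer: use the default integer dtype.
        return default_int
    if x_dtype in UNSIGNED_BITS and UNSIGNED_BITS[x_dtype] < default_bits:
        # Smaller unsigned integer: unsigned dtype with the same number of
        # bits as the default integer dtype.
        return f"uint{default_bits}"
    # Everything else (e.g. int64, uint64, floats, complex) keeps its dtype.
    return x_dtype


for name in ["int8", "int64", "uint8", "uint32", "uint64", "float32"]:
    print(name, "->", default_sum_dtype(name))
```

With `int64` as the default, every unsigned input maps to `uint64`.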
If I understand correctly, the sums for the unsigned dtypes below should therefore have dtype `uint64` (PyTorch's default integer dtype is `int64`):
```python
from array_api_compat import torch as xp

for dtype in [xp.int8, xp.int16, xp.int32, xp.int64,
              xp.uint8, xp.uint16, xp.uint32, xp.uint64,
              xp.float32, xp.float64, xp.complex32, xp.complex64]:
    x = xp.asarray([1, 2, 3], dtype=dtype)
    try:
        print(xp.sum(x).dtype, dtype)
    except RuntimeError as e:
        print(e)
```
But the output is:
```
torch.int64 torch.int8
torch.int64 torch.int16
torch.int64 torch.int32
torch.int64 torch.int64
torch.int64 torch.uint8
torch.int64 torch.uint16
torch.int64 torch.uint32
torch.int64 torch.uint64
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64
```
I think this is at least partially fixable within `array-api-compat`.
Also, `torch` doesn't seem to natively support `sum` for the unsigned dtypes wider than `uint8` or for `complex32`. If we change `xp.sum(x)` to `xp.sum(x, dtype=dtype)` in the code above, the output is:
```
torch.int8 torch.int8
torch.int16 torch.int16
torch.int32 torch.int32
torch.int64 torch.int64
torch.uint8 torch.uint8
"sum_cpu" not implemented for 'UInt16'
"sum_cpu" not implemented for 'UInt32'
"sum_cpu" not implemented for 'UInt64'
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64
```
It would be helpful if `array-api-compat` implemented `sum` for these types, even if that means upcasting to a supported type before summing and then downcasting. (Overflow is slightly more likely when accumulating in `int64` than in `uint64`, and the conversion may not always be safe, so what should happen in those cases is up for discussion.)
Does `array-api-compat` have a mechanism for reporting to the underlying libraries the shortcomings it has to patch around? If not, should I report this to PyTorch (if it has not already been reported)?