
Output of torch.sum with unsigned input should be unsigned #242

Open
@mdhaber

Description

According to the array API standard, the documentation of sum says the following about the dtype parameter:

If None, the returned array must have the same data type as x, unless x has an integer data type supporting a smaller range of values than the default integer data type... In those latter cases: ... if x has an unsigned integer data type, the returned array must have an unsigned integer data type having the same number of bits as the default integer data type.
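The quoted rule can be written as a small lookup. The sketch below is minimal and makes some assumptions: dtypes are plain strings, and the default integer data type is int64, as it is for PyTorch. The helper name expected_sum_dtype is hypothetical and is not part of array-api-compat. For these dtypes, bit width stands in for "supports a smaller range of values".

```python
# Bit widths of the integer dtypes covered by the rule, keyed by dtype name.
SIGNED_BITS = {"int8": 8, "int16": 16, "int32": 32, "int64": 64}
UNSIGNED_BITS = {"uint8": 8, "uint16": 16, "uint32": 32, "uint64": 64}


def expected_sum_dtype(dtype: str, default_int: str = "int64") -> str:
    """Dtype that sum(x) must return when dtype=None, per the quoted rule.

    Hypothetical helper: dtypes are strings and default_int is assumed to be
    a signed integer dtype name such as "int64".
    """
    default_bits = SIGNED_BITS[default_int]
    if dtype in SIGNED_BITS and SIGNED_BITS[dtype] < default_bits:
        # Narrower signed integers promote to the default integer dtype.
        return default_int
    if dtype in UNSIGNED_BITS and UNSIGNED_BITS[dtype] < default_bits:
        # Narrower unsigned integers promote to the unsigned dtype with
        # as many bits as the default integer dtype.
        return f"uint{default_bits}"
    # Everything else (e.g. int64, uint64, floating-point dtypes) is unchanged.
    return dtype
```

With the int64 default, every unsigned input maps to "uint64". For example, expected_sum_dtype("uint8") and expected_sum_dtype("uint64") both return "uint64".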

If I understand correctly, the default integer data type for torch is int64, so the sums of the unsigned dtypes below should have dtype uint64:

```python
from array_api_compat import torch as xp

for dtype in [xp.int8, xp.int16, xp.int32, xp.int64,
              xp.uint8, xp.uint16, xp.uint32, xp.uint64,
              xp.float32, xp.float64, xp.complex32, xp.complex64]:
    x = xp.asarray([1, 2, 3], dtype=dtype)
    try:
        print(xp.sum(x).dtype, dtype)
    except RuntimeError as e:
        print(e)
```

But the output is:

```
torch.int64 torch.int8
torch.int64 torch.int16
torch.int64 torch.int32
torch.int64 torch.int64
torch.int64 torch.uint8
torch.int64 torch.uint16
torch.int64 torch.uint32
torch.int64 torch.uint64
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64
```

I think this is at least partially fixable within array-api-compat.

Also, torch does not seem to natively support sum for uint16, uint32, uint64, or complex32 (at least on CPU). If we change xp.sum(x) to xp.sum(x, dtype=dtype) in the code above, the output is:

```
torch.int8 torch.int8
torch.int16 torch.int16
torch.int32 torch.int32
torch.int64 torch.int64
torch.uint8 torch.uint8
"sum_cpu" not implemented for 'UInt16'
"sum_cpu" not implemented for 'UInt32'
"sum_cpu" not implemented for 'UInt64'
torch.float32 torch.float32
torch.float64 torch.float64
"sum_cpu" not implemented for 'ComplexHalf'
torch.complex64 torch.complex64
```

It would be helpful if array-api-compat implemented sum for these types, even if that means upcasting to a supported type before summing and then downcasting the result. (Summing in int64 instead of uint64 carries a slightly larger chance of overflow. The conversions may also be unsafe: uint64 values above 2**63 - 1 do not fit in int64, and the sum may not fit back into the original dtype. What should happen in those cases is up for discussion.)
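The unsafe cases of the upcast-then-downcast approach can be made concrete. The sketch below is minimal, pure Python, and makes several assumptions: the inputs are non-negative integers, int64 is the accumulation dtype, and dtype names are plain strings. The helper names fits and upcast_sum_problems are hypothetical and are not part of array-api-compat or PyTorch.

```python
# Inclusive value ranges of the integer dtypes involved, keyed by dtype name.
RANGES = {
    "uint16": (0, 2**16 - 1),
    "uint32": (0, 2**32 - 1),
    "uint64": (0, 2**64 - 1),
    "int64": (-(2**63), 2**63 - 1),
}


def fits(value: int, dtype: str) -> bool:
    """True if value is representable in dtype."""
    low, high = RANGES[dtype]
    return low <= value <= high


def upcast_sum_problems(values, in_dtype: str, acc_dtype: str = "int64") -> list:
    """List the unsafe steps of "cast to acc_dtype, sum, cast back to in_dtype"
    for the given non-negative integers (hypothetical helper)."""
    problems = []
    # Step 1: every input must be representable in the accumulation dtype.
    if not all(fits(v, acc_dtype) for v in values):
        problems.append(f"input does not fit in {acc_dtype}")
    # Python ints are arbitrary precision, so this is the exact sum. Because
    # the inputs are non-negative, no partial sum exceeds the final total.
    total = sum(values)
    # Step 2: the sum must not overflow the accumulation dtype.
    if not fits(total, acc_dtype):
        problems.append(f"sum overflows {acc_dtype}")
    # Step 3: the result must fit back into the requested dtype.
    if not fits(total, in_dtype):
        problems.append(f"sum overflows {in_dtype}")
    return problems
```

For example, upcast_sum_problems([2**62, 2**62], "uint64") reports only that the sum overflows int64, even though the exact result (2**63) fits in uint64. This is the extra overflow risk of summing in int64.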

Does array-api-compat have a mechanism for reporting to the underlying libraries the shortcomings it has to patch? If not, should I report this to PyTorch (if it has not been reported already)?
