
torch matmul does not handle different dtypes #245

Open
@ev-br

Description

The spec requires that matmul follow the type promotion rules for its arguments, but PyTorch requires that the dtypes match:

In [1]: import torch

In [3]: import array_api_strict as xp

In [5]: xp.ones(3, dtype=xp.float32) @ xp.ones(3, dtype=xp.float64)
Out[5]: Array(3., dtype=array_api_strict.float64)

In [6]: torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)

RuntimeError: dot : expected both vectors to have same dtype, but found Float and Double

It's not immediately clear to me whether we want to paper over this in the compat layer or leave the conversion to end users: it's easy to imagine a use case where the copying overhead is significant.
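For reference, a compat-layer workaround could look something like the sketch below (`matmul_promoting` is a hypothetical helper, not an existing array-api-compat function): compute the common dtype with `torch.promote_types` and cast both operands before calling `@`. This is exactly the copy overhead mentioned above — each mismatched operand gets materialized in the promoted dtype.

```python
import torch

def matmul_promoting(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical compat-layer shim: apply the spec's type promotion
    # rules before delegating to torch's matmul, which requires
    # matching dtypes.
    common = torch.promote_types(x.dtype, y.dtype)
    # .to() is a no-op (no copy) when the tensor already has the dtype.
    return x.to(common) @ y.to(common)

result = matmul_promoting(torch.ones(3, dtype=torch.float32),
                          torch.ones(3, dtype=torch.float64))
print(result, result.dtype)  # tensor(3., dtype=torch.float64) torch.float64
```

This mirrors what the array_api_strict example above produces (a float64 result of 3.0), at the cost of one extra copy of the float32 operand.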
