The spec requires that `matmul` follow the type promotion rules for its arguments, but PyTorch requires that the dtypes match:
```
In [3]: import array_api_strict as xp

In [5]: xp.ones(3, dtype=xp.float32) @ xp.ones(3, dtype=xp.float64)
Out[5]: Array(3., dtype=array_api_strict.float64)

In [6]: torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)

RuntimeError: dot : expected both vectors to have same dtype, but found Float and Double
```
It's not immediately clear to me whether we want to paper over this in array-api-compat or leave the conversion to end users: it's easy to imagine a use case where the copying overhead is significant.
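If compat did paper over this, one possible shape is a thin wrapper that promotes both dtypes before calling `torch.matmul`. This is only a sketch: `matmul_promoting` is a hypothetical name, not part of array-api-compat, and it assumes that `torch.promote_types` agrees with the spec's promotion table for the dtype pairs that matter here (for mixed-kind pairs the spec leaves the result unspecified anyway). It also shows where the copy comes from, since the lower-precision operand has to be cast.

```python
import torch


def matmul_promoting(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper: cast both operands to their promoted dtype, then matmul."""
    # promote_types works on dtypes directly, so it avoids torch's special
    # handling of 0-d tensors in torch.result_type.
    dtype = torch.promote_types(x1.dtype, x2.dtype)
    # Tensor.to returns the tensor itself when the dtype already matches,
    # so only the mismatched operand is copied.
    return torch.matmul(x1.to(dtype), x2.to(dtype))


a = torch.ones(3, dtype=torch.float32)
b = torch.ones(3, dtype=torch.float64)
result = matmul_promoting(a, b)  # 0-d tensor with dtype torch.float64
```

The end-user alternative is the same cast written by hand, e.g. `a.to(torch.float64) @ b`, which makes the copy explicit at the call site.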