Description
On import, ml_dtypes adds new entries to np.sctypeDict so that, e.g., np.dtype("int4") returns an int4 dtype defined outside NumPy.
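To make the mechanism concrete, here is a minimal sketch of the same kind of injection using only NumPy. It assumes that np.dtype, after its built-in string parsing fails, falls back to looking the string up in the dictionary exposed as np.sctypeDict. The alias name demo_alias_int8 is hypothetical and points at NumPy's own int8 scalar type, so the example runs without ml_dtypes installed.

```python
import numpy as np

# Hypothetical alias name for this demo; it is not a real NumPy dtype name.
ALIAS = "demo_alias_int8"

try:
    np.dtype(ALIAS)
except TypeError:
    print(f"{ALIAS!r} is not understood before injection")

# Inject an entry the way ml_dtypes does for its own types. Here it points
# at NumPy's own int8 scalar type so the example needs no third-party package.
np.sctypeDict[ALIAS] = np.int8
print(np.dtype(ALIAS))  # now resolved through the injected np.sctypeDict entry

# Remove the entry so the demo leaves NumPy's global state unchanged.
del np.sctypeDict[ALIAS]
```

Because np.sctypeDict is a single process-wide dict, any library can add names this way, and every later np.dtype(...) call in the process sees them.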
Since jax currently documents this behavior for its users and relies on it internally, I don’t think we can reasonably break it without both a deprecation story and a migration story.
To deprecate it, we would keep a list of all the strings that NumPy accepts out of the box; if any other string is passed in and we somehow get back a valid dtype, we should raise a deprecation warning. I don’t know whether there are other ways to inject a string dtype name into NumPy’s internals without manipulating np.sctypeDict, but this check would catch any such shenanigans too.
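A rough Python-level approximation of that check is sketched below. It is not how NumPy would implement it (in NumPy the check would naturally sit where string parsing falls back to the type dictionary); the wrapper dtype_with_deprecation and the snapshot _BUILTIN_SCTYPE_NAMES are hypothetical names. The sketch assumes the module is imported before any library mutates np.sctypeDict, so the snapshot stands in for "the strings NumPy accepts out of the box". Typecode strings such as "f8" or "<i4" are parsed directly by NumPy, so the sketch only warns for names that resolve through injected np.sctypeDict entries.

```python
import warnings

import numpy as np

# Names present in np.sctypeDict when this module is imported.
# Assumption: nothing has injected extra names yet at this point.
_BUILTIN_SCTYPE_NAMES = frozenset(np.sctypeDict)


def dtype_with_deprecation(spec):
    """Hypothetical np.dtype wrapper that warns on externally injected names."""
    result = np.dtype(spec)
    if (
        isinstance(spec, str)
        and spec not in _BUILTIN_SCTYPE_NAMES
        and spec in np.sctypeDict
    ):
        # The string resolved only because someone added it to np.sctypeDict.
        warnings.warn(
            f"dtype name {spec!r} was registered outside NumPy; "
            "resolving it with np.dtype is deprecated",
            DeprecationWarning,
            stacklevel=2,
        )
    return result
```

Built-in names such as "int8" resolve silently, while a name added later by a third-party import resolves as before but emits a DeprecationWarning.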
Should we also deprecate np.sctypeDict?
A few releases after adding the deprecation, we could make np.dtype return only dtype instances whose string names are either defined out of the box in NumPy or registered through some as-yet-unwritten mechanism for associating string names with dtypes, probably with some kind of support for namespacing.
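Since that mechanism does not exist yet, the following is a purely hypothetical sketch of what namespaced registration might look like. DTypeNameRegistry, register, and resolve are invented names, not a NumPy API, and np.int8 stands in for an externally defined type such as an int4. The design choice illustrated is that registered names never enter np.sctypeDict and must be spelled with a namespace prefix, so two libraries cannot silently collide on a bare name.

```python
import numpy as np


class DTypeNameRegistry:
    """Hypothetical registry mapping namespaced string names to dtypes."""

    def __init__(self):
        self._entries = {}  # (namespace, name) -> np.dtype

    def register(self, namespace, name, dtype):
        key = (namespace, name)
        if key in self._entries:
            raise ValueError(f"{namespace}.{name} is already registered")
        self._entries[key] = np.dtype(dtype)

    def resolve(self, spec):
        # "somelib.int4" splits into namespace "somelib" and name "int4".
        namespace, sep, name = spec.rpartition(".")
        if sep and (namespace, name) in self._entries:
            return self._entries[(namespace, name)]
        # Anything else goes through NumPy's built-in string parsing.
        return np.dtype(spec)
```

For example, after registry.register("demo_lib", "small_int", np.int8), resolve("demo_lib.small_int") returns the int8 dtype, while resolve("float32") is handled by np.dtype as usual.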
As far as I know, jax is the only downstream library that injects dtype names into the np.dtype("dtype_name") mechanism.
The deprecation should not be added until we have a clear migration story for the jax library and have considered any possible impact on jax users.
xref #24376 (comment) and the discussion that follows for context