Quick sketch of bfloat16 support #1
Conversation
if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable, Iterator


try:
Minor query
I wasn't sure where to import this. Given that it hooks into numpy to extend numpy's dtype support, the import probably belongs in the same location where numpy is first imported by zarr.
Where is numpy first imported? 😓
good question, I have no idea.
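For what it's worth, the import mainly needs to happen before zarr first resolves dtype names: importing ml_dtypes registers its types with numpy as a side effect. A minimal sketch, assuming ml_dtypes is installed:

```python
import numpy as np
import ml_dtypes  # noqa: F401 -- registration with numpy happens on import

# Once ml_dtypes has been imported anywhere in the process, numpy can
# resolve the new dtypes by name.
bf16 = np.dtype("bfloat16")
print(bf16.name, bf16.itemsize)  # "bfloat16", 2 bytes
```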
    DataType.float16: "f2",
    DataType.float32: "f4",
    DataType.float64: "f8",
    DataType.bfloat16: "bfloat16",
This will require some input.
The numpy kind codes are not extensible. In particular, many of the dtypes added by ml_dtypes will not have unique kind codes and/or codes that numpy can interpret. See jax-ml/ml_dtypes#41 for an MRE.
Is there any special need to tie the logic here to numpy kind codes? As an example, numpy will also recognise np.dtype('int16') in the same way as np.dtype('i2').
I think this could be the biggest decision required before support for other dtypes is possible.
I don't think we are tied to the numpy kind codes here. That's a numpy thing, not a zarr thing. We just need to ensure that the zarr dtypes can be unambiguously interpreted by zarr-python to make numpy / cupy arrays that can represent the underlying data. But anchoring our string representations to numpy is useful, because it ties us to something ~standard-ish. So I agree that we should tread carefully here, and this will definitely require input from the broader community... but that shouldn't stop the implementation.
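To make the kind-code concern above concrete, here is a quick inspection sketch (assuming ml_dtypes is installed; the exact kind/str values are whatever numpy reports for the registered type):

```python
import numpy as np
import ml_dtypes

# Names and aliases resolve consistently: 'int16' and 'i2' are the same dtype.
assert np.dtype("int16") == np.dtype("i2")

# The ml_dtypes additions fall outside numpy's fixed kind table, so .kind
# and .str are not reliable discriminators for them; .name is.
for dt in (np.dtype("float16"), np.dtype(ml_dtypes.bfloat16)):
    print(dt.name, dt.kind, dt.str)
```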
gpu = [
    "cupy-cuda12x",
]
ml-dtypes = [
A couple of comments.
ml_dtypes makes a bunch of extra types available: https://github.com/jax-ml/ml_dtypes#ml_dtypes.
I think some of these like bfloat16 are already widely used and probably fine to adopt immediately.
Others like int2 and int4 are subject to certain limitations with numpy. Details are available in the README, but roughly, some or all of the bits are padded with zeros to make them a byte in memory. In practice, I don't think this will be a useful representation for a decent chunk of the community, and so I wouldn't necessarily want to make it the default representation.
Should just a subset of these dtypes be integrated in zarr-python?
Although we are pretty numpy-centric, zarr-python has been designed to be device agnostic: we have infrastructure in place for returning gpu-backed arrays via cupy, for example. So if these narrow ints are represented sub-optimally in numpy, but they have a performant representation on the gpu, I think that's actually OK for us. In other words, "kind of useless in CPU memory, but correct" seems OK to me. We just need to make sure that the stored representation (what zarr-python stores) is what people will expect. So I would say let's be greedy and try to support everything.
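To make the padding point above concrete, a sketch (assuming a recent ml_dtypes that exposes int4):

```python
import numpy as np
import ml_dtypes

# int4 values occupy a full byte each in numpy memory (padded, not packed),
# which is why the in-memory CPU representation is wasteful.
arr = np.zeros(8, dtype=ml_dtypes.int4)
print(arr.dtype.name, arr.dtype.itemsize, arr.nbytes)  # each element is 1 byte
```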
| "<c16": "complex128", | ||
| } | ||
| return DataType[dtype_to_data_type[dtype.str]] | ||
| elif 'v' not in dtype.str.lower(): |
Same as previous comment.
This is a very ugly proxy for the new dtypes. If we weren't tied to the numpy kind codes, then I think all of this would be a lot cleaner.
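One hypothetical alternative to scanning dtype.str for 'v' would be an explicit allow-list keyed on dtype names (a sketch, not the PR's implementation; the set contents are an assumption):

```python
import numpy as np
import ml_dtypes

# Explicit allow-list of the ml_dtypes names zarr chooses to support;
# extend as further dtypes are adopted.
_ML_DTYPE_NAMES = {"bfloat16"}

def is_extended_dtype(dtype: np.dtype) -> bool:
    """Return True if ``dtype`` is one of the explicitly supported ml_dtypes."""
    return dtype.name in _ML_DTYPE_NAMES

print(is_extended_dtype(np.dtype(ml_dtypes.bfloat16)))  # True
print(is_extended_dtype(np.dtype("float32")))           # False
```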
| "<c16": "complex128", | ||
| } | ||
| return DataType[dtype_to_data_type[dtype.str]] | ||
| return DataType[dtype.name] |
Since the new dtypes don't have unique kind codes, I am currently using their name (which numpy recognises) to instantiate the dtype.
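A small sketch of that name-based lookup (DataType here is a minimal stand-in for the enum touched in this diff, with the bfloat16 member assumed from this PR):

```python
import enum
import numpy as np
import ml_dtypes

# Minimal stand-in for zarr-python's DataType enum, for illustration only.
class DataType(enum.Enum):
    float16 = "float16"
    bfloat16 = "bfloat16"

dt = np.dtype(ml_dtypes.bfloat16)
# dt.str has no unique kind code, but dt.name is "bfloat16", which matches
# the enum member name, so lookup by name works.
print(DataType[dt.name])  # -> DataType.bfloat16
```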
[Description of PR]
TODO: