-
Notifications
You must be signed in to change notification settings - Fork 539
AnyValue Torch Performance Improvement #10647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes me wonder if we shouldn't do this in more places? Could also be worth checking if it works for Jax and TF in that case. If we make it work for all of those and can make it work more broadly that would be very worth documenting and bragging about!
By I can manually check if this works for JAX and TF (it would be weird if it didn't) but adding the unit test requires a bit more infra since we don't have a separate test env and those are heavy packages to import |
Yes, I mean have the ability to convert anything with a If I don't miss-remember there should also be some way to appease mypy and friends by annotating a type hint of the style "anything that has this attribute" but not sure.
Yeah manual check seems fine. |
Confirmed it works for jax: elif version == "raw_jax":
timing = simple_loop(shape, raw_any, partial(jax.random.uniform, jax.random.key(123))) >>>python sample.py raw_jax
Loop 100
1
raw_jax:
0.0025s |
IIUC we can convert any 1d object to a pyarrow array with this. Anything greater than 1d either needs to be flattened which could still be fast or made a nested list of lists as a pyarrow limitation
Yes there is. We should be able to use protocols for that |
Tensorflow doesn't work yet because they haven't fully committed to the standard yet But on main it appears they might support |
Related
Closes #10556 and maybe #10344
What
Adds support for the
__dlpack__
interface so we can optimally convert into 1d arrays (ND-Array support still not available and probably can't do it no copy AnyValues should support ND arraysΒ #4572). This is mostly to supporttorch
but should work for a variety of things.Catches the edge cases:
Performance comparison
This is not a strict benchmark but should be sufficiently convincing things work nice.