Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Conversation

ntjohnson1
Copy link
Member

Related

Closes #10556 and maybe #10344

What

  • Adds support for the __dlpack__ interface so we can optimally convert into 1d arrays (ND-Array support still not available and probably can't do it no copy AnyValues should support ND arraysΒ #4572). This is mostly to support torch but should work for a variety of things.

  • Catches the edge cases:

    • Calling AnyValues without a kwarg should say SOMETHING, since the positional only value doesn't generate content for the AnyValue
    • Clarify the error message when we hit the fall through for numpy. Raises from the exception from arrow so we have the full traceback but tries to provide more guidance when we expect a value but get another.

Performance comparison

This is not a strict benchmark but should be sufficiently convincing things work nice.

>>> pixi run -e py python sample.py raw
Loop 100
1
raw:
       0.0024s
>>> pixi run -e py python sample.py to_numpy
Loop 100
1
to_numpy:
       0.0024s
>>> pixi run -e py python sample.py raw_numpy
Loop 100
1
to_numpy:
       0.0025s

@ntjohnson1 ntjohnson1 changed the title Nick/anyvalue AnyValue Torch Performance Improvement Jul 15, 2025
Copy link
Member

@nikolausWest nikolausWest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Copy link
Member

@nikolausWest nikolausWest left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes me wonder if we shouldn't do this in more places? Could also be worth checking if it works for Jax and TF in that case. If we make it work for all of those and can make it work more broadly that would be very worth documenting and bragging about!

@ntjohnson1
Copy link
Member Author

Makes me wonder if we shouldn't do this in more places? Could also be worth checking if it works for Jax and TF in that case. If we make it work for all of those and can make it work more broadly that would be very worth documenting and bragging about!

By do this you mean leverage dlpack?

I can manually check if this works for JAX and TF (it would be weird if it didn't) but adding the unit test requires a bit more infra since we don't have a separate test env and those are heavy packages to import

@nikolausWest
Copy link
Member

By do this you mean leverage dlpack?

Yes, I mean have the ability to convert anything with a __dlpack__ attribute to an arrow array. I think I may have added some numpy conversion logic that was supposed to work with pytorch earler but it may have been removed since.

If I don't miss-remember there should also be some way to appease mypy and friends by annotating a type hint of the style "anything that has this attribute" but not sure.

I can manually check if this works for JAX and TF (it would be weird if it didn't) but adding the unit test requires a bit more infra since we don't have a separate test env and those are heavy packages to import

Yeah manual check seems fine.

@ntjohnson1
Copy link
Member Author

Confirmed it works for jax:

    elif version == "raw_jax":
        timing = simple_loop(shape, raw_any, partial(jax.random.uniform, jax.random.key(123)))
>>>python sample.py raw_jax
Loop 100
1
raw_jax:
        0.0025s

@ntjohnson1
Copy link
Member Author

Yes, I mean have the ability to convert anything with a dlpack attribute to an arrow array. I think I may have added some numpy conversion logic that was supposed to work with pytorch earler but it may have been removed since.

IIUC we can convert any 1d object to a pyarrow array with this. Anything greater than 1d either needs to be flattened which could still be fast or made a nested list of lists as a pyarrow limitation

If I don't miss-remember there should also be some way to appease mypy and friends by annotating a type hint of the style "anything that has this attribute" but not sure.

Yes there is. We should be able to use protocols for that

@ntjohnson1
Copy link
Member Author

ntjohnson1 commented Jul 15, 2025

Tensorflow doesn't work yet because they haven't fully committed to the standard yet
https://www.tensorflow.org/api_docs/python/tf/experimental/dlpack/to_dlpack

But on main it appears they might support __dlpack__ which landed 2 days after the next last release. So on the next release it hopefully works. tensorflow/tensorflow#89079

@ntjohnson1 ntjohnson1 merged commit dedaf08 into main Jul 15, 2025
39 checks passed
@ntjohnson1 ntjohnson1 deleted the nick/anyvalue branch July 15, 2025 14:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

AnyValues should handle torch tensors with optimal performance
2 participants