
Conversation


@danielvegamyhre danielvegamyhre commented Dec 9, 2025

Summary

  • Migrate from PyBind to the TORCH_LIBRARY API, which is Python version agnostic
  • Update setup.py to use the Python limited API (version agnostic) wheel tag
  • Follow the pattern used in torchao/ops.py for the other CUDA C++ extensions (see the sketch after this list):
    • Build a torchao._C_mxfp8 .so file (lands in the build/ dir) instead of a separate torchao.prototype.mxfp8_cuda extension (which landed in torchao/prototype)
    • Define new op schemas for the 2d quantization and 3d quantization kernels
    • Update custom ops, meta functions, and custom sharding registrations to wrap the new custom ops
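
For reference, here is a minimal sketch of the limited-API build pattern described above. The source path is hypothetical, and it assumes a recent PyTorch where torch.utils.cpp_extension accepts py_limited_api; the actual sources and flags are in this PR's diff:

```python
# setup.py sketch: build the extension against the CPython limited API so one
# wheel works across Python versions. The source path here is hypothetical.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="torchao",
    ext_modules=[
        CUDAExtension(
            name="torchao._C_mxfp8",
            sources=["torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp"],
            # Restrict binding code to the stable CPython ABI; symbols tied
            # to one specific minor version are then off limits.
            py_limited_api=True,
            extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    # Tag the wheel as limited-API (abi3-style) from CPython 3.9 onward,
    # which is what makes the built artifact Python version agnostic.
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
```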

Context

While doing the 0.15.0 torchao release and testing the test build for CUDA 12.8, I found that the "torchao.prototype.mxfp8_cuda" C++ extension cannot be imported (ModuleNotFoundError). We only build the extension for CUDA 12.8+, so I checked the logs, and they indicate it was built: https://github.com/pytorch/ao/actions/runs/20046209265/job/57498462190

I then checked the local installation itself, and I do see a .so file for the extension in the torchao/prototype dir, so it is definitely being built.

I tried asking Claude about this, and it said the extension was built against Python 3.10 and must match the Python version in the conda env due to ABI incompatibility (I'm using Python 3.12). As a test, I tried a fresh conda env with Python 3.10; instead of module not found, I got an undefined symbol error, which does seem to indicate a Python ABI issue.

I asked @drisspg, and he said we should be building with a Python-version-agnostic flag. I looked into this: we do this for the other C++ extensions but not for mxfp8_cuda, so I am fairly certain this is the root cause and this PR will fix the issue.


pytorch-bot bot commented Dec 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3471

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 39db553 with merge base 08e5e20:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Dec 9, 2025
@danielvegamyhre danielvegamyhre added the topic: bug fix label and removed the CLA Signed label Dec 9, 2025

@atalman atalman left a comment


lgtm


drisspg commented Dec 9, 2025

spoke offline, we need to be using TORCH_LIBRARY and not pybind

@meta-cla meta-cla bot added the CLA Signed label Dec 9, 2025
@danielvegamyhre danielvegamyhre changed the title use python version agnostic c++ extension for mxfp8_cuda use python version agnostic c++ binding for mxfp8 cuda kernels Dec 10, 2025
@danielvegamyhre danielvegamyhre changed the title use python version agnostic c++ binding for mxfp8 cuda kernels use python version agnostic python binding for mxfp8 cuda kernels Dec 10, 2025
@danielvegamyhre danielvegamyhre changed the title use python version agnostic python binding for mxfp8 cuda kernels use python version agnostic binding for mxfp8 cuda kernels Dec 10, 2025
@pytest.mark.skipif(
    not is_cuda_version_at_least(12, 8),
    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
)
def test_cuda_mx_dim1_invalid_block_size():
Contributor Author


Deleting this test, since block_size=32 is now hard-coded in the Python wrapper for the kernel; we always use this block size for mxfp8.
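
For illustration, a minimal sketch of that kind of wrapper, with hypothetical op and function names. In the PR itself the schema is declared on the C++ side via TORCH_LIBRARY; it is defined in a scratch Python namespace here only so the sketch is self-contained:

```python
import torch

# In the real PR the schema lives in C++ (TORCH_LIBRARY); defining it here
# in a scratch namespace just makes the sketch runnable on its own.
torch.library.define(
    "torchao_sketch::mxfp8_quantize_dim1",
    "(Tensor x, int block_size) -> (Tensor, Tensor)",
)

@torch.library.register_fake("torchao_sketch::mxfp8_quantize_dim1")
def _(x, block_size):
    # Meta/fake impl so torch.compile and FakeTensor can trace the op.
    # Output shapes here are illustrative: quantized data plus block scales.
    data = torch.empty_like(x, dtype=torch.uint8)
    scales = x.new_empty((x.shape[0] // block_size, x.shape[1]), dtype=torch.uint8)
    return data, scales

def mxfp8_quantize_dim1(x: torch.Tensor):
    # block_size=32 is hard-coded: mxfp8 always uses 32-element blocks, so an
    # invalid-block-size test no longer has anything to exercise. The real
    # CUDA impl is registered in C++; this call only traces under fake tensors.
    return torch.ops.torchao_sketch.mxfp8_quantize_dim1(x, block_size=32)
```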


drisspg commented Dec 10, 2025

you should also be able to use nm -D to inspect the exported symbols and ensure there are none from Python
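
For example, a small sketch of that check; the .so path is hypothetical, and any dynamic symbol matching Py*/_Py* would indicate a dependency on a specific CPython ABI:

```python
# check_symbols.py sketch: run `nm -D` on the built .so and flag any CPython
# symbols in its dynamic symbol table. The artifact path is hypothetical.
import re
import subprocess

out = subprocess.run(
    ["nm", "-D", "build/torchao/_C_mxfp8.so"],
    capture_output=True, text=True, check=True,
).stdout

# nm -D lines look like "<addr> U PyErr_SetString"; match Py*/_Py* names.
py_syms = [line for line in out.splitlines() if re.search(r"\b_?Py[A-Za-z_]", line)]
for line in py_syms:
    print(line)
print(f"{len(py_syms)} Python symbols found")
```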

@danielvegamyhre
Contributor Author

current CI failures will be resolved once this rollback in upstream pytorch is included in the next torch nightly: pytorch/pytorch#169985

@danielvegamyhre
Contributor Author

The other upstream pytorch issue causing CI failures has now been resolved: pytorch/pytorch#170184

@danielvegamyhre
Contributor Author

The macOS test failure is unrelated.

@danielvegamyhre danielvegamyhre merged commit b9e5780 into main Dec 16, 2025
38 of 39 checks passed
danielvegamyhre added a commit that referenced this pull request Dec 16, 2025
* use py agnostic c++ extension for mxfp8_cuda

* refactor mxfp8 cuda from pybind to torch_library api

* put schema def inside guard
namgyu-youn pushed a commit to namgyu-youn/ao that referenced this pull request Dec 19, 2025