[sparse] Add fp8 sparse gemm with rowwise scaling for activation sparsity#2242
Conversation
Summary: We already have this gemm in torchao, but for weight sparsity. For activation sparsity, we need the weights to be stored in column-major format so we can use the selective weight loading kernel for decode.
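A minimal sketch (not part of this PR) of why the storage order matters for decode, assuming the usual [out_features, in_features] weight shape: stored column-major, the weights needed for each non-zero activation element form one contiguous slice, so selective weight loading becomes a single contiguous read per surviving feature. All names and shapes below are hypothetical.

```cpp
#include <cstddef>

// Hypothetical decode-time GEMV with a sparse activation vector (not the PR's
// kernel). The weight W has shape [n_out, n_in]; stored column-major, the
// weights for input feature k occupy one contiguous slice of length n_out,
// so features whose activation is zero are never read from memory at all.
void sparse_decode_gemv(const float* w_colmajor,  // [n_out, n_in], column-major
                        const float* x,           // activation, length n_in, mostly zero
                        float* y,                 // output, length n_out
                        std::size_t n_in,
                        std::size_t n_out) {
    for (std::size_t j = 0; j < n_out; ++j) y[j] = 0.0f;
    for (std::size_t k = 0; k < n_in; ++k) {
        if (x[k] == 0.0f) continue;                 // skip this weight slice entirely
        const float* w_k = w_colmajor + k * n_out;  // contiguous column for feature k
        for (std::size_t j = 0; j < n_out; ++j) y[j] += x[k] * w_k[j];
    }
}
```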
🧪 Artifacts and rendered test results: hud.pytorch.org/pr/pytorch/ao/2242 — no failures, 4 pending as of commit e17ebfd (merge base f0f976c).
danielvegamyhre
left a comment
lgtm, left a couple minor comments
using ElementOut = cutlass::bfloat16_t;
using ElementAccumulator = float;

using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
how was this tile shape selected?
This is the default I copied over, planning on adding some tuning in a subsequent PR for more perf.
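For context, a sketch of what that follow-up tuning could look like: lift the hard-coded tile shape into a template parameter so several configs can be instantiated, benchmarked, and dispatched on problem size. This is an assumption about the subsequent PR, not code from this one; `SparseRowwiseKernel` here is only a stand-in for the kernel wrapper quoted above.

```cpp
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

// Stand-in for the kernel wrapper in this PR, with the tile shape lifted into
// a template parameter instead of the hard-coded 128x256x128 default.
template <class ElementA,
          class TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>>
struct SparseRowwiseKernel {
  // ... collective mainloop / epilogue builders would consume TileShape here ...
};

// Candidate configs a tuning pass could sweep over:
using DefaultTile = SparseRowwiseKernel<cutlass::float_e4m3_t>;
using SmallerTile = SparseRowwiseKernel<cutlass::float_e4m3_t,
                                        cute::Shape<cute::_64, cute::_128, cute::_128>>;
```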
cutlass::arch::OpClassSparseTensorOp,
ElementA,
cutlass::layout::RowMajor,
32,
nit: would help with readability to give these constant args variable names IMO
yeah good point, will address these nits when I add in the tile config tuning, just want to get unblocked for now.
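To make the nit concrete, a sketch of the kind of change being deferred: the bare `32` from the quoted snippet gets a name, so the builder arguments read as intent rather than magic numbers. Which operand the alignment belongs to, and the other names and values below, are assumptions for illustration only.

```cpp
// Named constants instead of positional magic numbers (illustrative only; the
// value 32 is taken from the quoted snippet, the rest is assumed).
static constexpr int AlignmentA = 32;  // elements per vectorized access of operand A
static constexpr int AlignmentB = 16;  // hypothetical value for operand B

// ...then pass AlignmentA / AlignmentB to the CollectiveBuilder instead of
// bare integer literals, e.g. `ElementA, cutlass::layout::RowMajor, AlignmentA,`.
```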
device_guard.emplace(tensor_a.device());
}

using K = SparseRowwiseKernel<cutlass::float_e4m3_t>;
nit: more descriptive variable name would be helpful
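As a concrete example of this nit (the chosen name is just an illustration, not what the PR uses):

```cpp
// More descriptive than the single-letter alias `K` in the quoted snippet.
using Fp8SparseRowwiseKernel = SparseRowwiseKernel<cutlass::float_e4m3_t>;
```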
[sparse] Add fp8 sparse gemm with rowwise scaling for activation sparsity (#2242)
* [sparse] Add fp8 sparse gemm with rowwise scaling for activation sparsity — Summary: We already have this gemm in torchao, but for weight sparsity. For activation sparsity, the weights need to be stored in column-major format so the selective weight loading kernel can be used for decode.
* remove cutlass compression
* ruff fix
* one more ruff fix
* don't build for CUDA 11.8
* fix formatting
* ifdef to avoid issues (see the guard sketch below)
Summary:
We already have this gemm in torchao, but for weight sparsity, which assumes the weights are stored in row-major format and are the sparse operand.
For activation sparsity, we need the weights to be stored in column-major format so we can use the selective weight loading kernel for decode.
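The "don't build for CUDA 11.8" and "ifdef to avoid issues" follow-ups in the merged commit suggest a toolkit-version guard around the kernel instantiation; a minimal sketch of that pattern is below. The exact macro and threshold used in the PR are not shown here and are assumptions.

```cpp
#include <cuda.h>

// Compile the fp8 sparse rowwise kernel only on toolkits new enough to support
// it; on older toolkits (e.g. CUDA 11.8) the kernel is compiled out so the
// build does not break. The 12.0 threshold here is an assumption.
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
// ... kernel instantiation and launch ...
#else
// ... fall back to raising a clear "not built for this toolkit" error ...
#endif
```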