
Conversation


@tyler-romero tyler-romero commented Dec 15, 2025

Fused linear cross-entropy (CE) layer that avoids ever materializing the logits.

The forward pass uses a tiled matmul and an online softmax to compute the logsumexp (lse) and the target logits; the per-token loss is then lse minus the target logit. The lse values are stored for the backward pass.
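
For reference, here is a minimal eager-PyTorch sketch of that forward pass. This is not the Helion kernel itself: it assumes an LM-head weight of shape [V, H], the function/argument names are mine, and a running logaddexp stands in for the kernel's online softmax.

```python
import torch

def flce_forward_reference(x, weight, targets, tile_v=4096):
    # x: [N, H] hidden states, weight: [V, H] LM-head weight, targets: [N] token ids.
    # Walk the vocab in tiles, keeping a running logsumexp and the target logits,
    # so the full [N, V] logit matrix is never materialized.
    N, H = x.shape
    V = weight.shape[0]
    lse = torch.full((N,), float("-inf"), device=x.device, dtype=torch.float32)
    target_logit = torch.empty(N, device=x.device, dtype=torch.float32)
    for v0 in range(0, V, tile_v):
        v1 = min(v0 + tile_v, V)
        logits_tile = (x @ weight[v0:v1].T).float()                 # [N, tile_v]
        # Online update of the running logsumexp (the kernel does this with a
        # running max + rescale; logaddexp/logsumexp is the stable equivalent).
        lse = torch.logaddexp(lse, torch.logsumexp(logits_tile, dim=-1))
        # Grab the target logits that fall inside this vocab tile.
        idx = ((targets >= v0) & (targets < v1)).nonzero(as_tuple=True)[0]
        target_logit[idx] = logits_tile[idx, targets[idx] - v0]
    loss = lse - target_logit          # per-token CE: -log softmax(logits)[target]
    return loss.mean(), lse            # lse is saved for the backward pass
```

Only an [N, tile_v] slab of logits exists at any point, which is what keeps peak memory low.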

The backward pass uses a similar tiled matmul and online softmax to recompute the logits (which is the expensive part). grad_input is computed in the same online stream, and grad_weight is accumulated with atomic adds.
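
And a matching sketch of the backward pass, again an eager-PyTorch reference under the same assumed shapes, with the kernel's atomic adds replaced by in-place accumulation:

```python
import torch

def flce_backward_reference(x, weight, targets, lse, grad_loss=1.0, tile_v=4096):
    # Recompute the logits tile by tile (the expensive part), turn them into
    # d_logits = softmax - one_hot(target), and accumulate grad_input / grad_weight
    # in the same pass. The kernel accumulates grad_weight with atomic adds
    # because multiple input blocks touch the same weight rows.
    N, H = x.shape
    V = weight.shape[0]
    grad_input = torch.zeros(N, H, device=x.device, dtype=torch.float32)
    grad_weight = torch.zeros(V, H, device=x.device, dtype=torch.float32)
    scale = grad_loss / N                                   # mean reduction over N tokens
    for v0 in range(0, V, tile_v):
        v1 = min(v0 + tile_v, V)
        logits_tile = (x @ weight[v0:v1].T).float()         # recompute [N, tile_v]
        probs = torch.exp(logits_tile - lse[:, None])       # softmax via the stored lse
        idx = ((targets >= v0) & (targets < v1)).nonzero(as_tuple=True)[0]
        probs[idx, targets[idx] - v0] -= 1.0                # subtract the one-hot target
        d_logits = probs * scale
        grad_input += d_logits @ weight[v0:v1].float()      # [N, H]
        grad_weight[v0:v1] += d_logits.T @ x.float()        # [tile_v, H]
    return grad_input, grad_weight
```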

I noticed after writing this that there is an existing FLCE example in an old open PR (#342). This one is interesting/different because it doesn't follow Liger's chunking and gradient-precomputation strategy, so it has a different performance profile (e.g. lower memory usage, faster fwd+bwd in many scenarios, and a very fast forward pass).


meta-cla bot commented Dec 15, 2025

Hi @tyler-romero!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!


meta-cla bot commented Dec 15, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

meta-cla bot added the CLA Signed label on Dec 15, 2025

Author

tyler-romero commented Dec 15, 2025

Timings on an RTX 4090, "quick" autotuning, fp32 inputs, "fwd" mode.

uv run python benchmarks/run.py --metrics latency,gpu_peak_mem,speedup,accuracy --kernel fused_linear_cross_entropy

Column key: torch = torch_lm_head_ce, liger = liger_lm_head_ce, compile = torch_compile_fused_linear_cross_entropy, helion = helion_helion_fused_linear_cross_entropy_tritonbench.

| (B*T, H) | torch latency | torch peak mem | liger latency | liger peak mem | liger speedup | liger accuracy | compile latency | compile peak mem | compile speedup | compile accuracy | helion latency | helion peak mem | helion speedup | helion accuracy |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| (4096, 4096) | 96.116737 (±0.00%) | 6.3797 | 490.204163 (±0.00%) | 6.5125 | 0.196075 | 1 | 91.594749 (±0.00%) | 4.27839 | 1.04937 | 1 | 76.863487 (±0.00%) | 2.17703 | 1.25049 | 1 |
| (8192, 4096) | 197.682175 (±0.00%) | 10.6495 | 692.483093 (±0.00%) | 6.71243 | 0.285469 | 1 | 191.182846 (±0.00%) | 6.44691 | 1.034 | 1 | 115.140610 (±0.00%) | 2.24418 | 1.71688 | 1 |
| (16384, 4096) | 397.231110 (±0.00%) | 19.1892 | 1166.612427 (±0.00%) | 7.11231 | 0.3405 | 1 | 377.675781 (±0.00%) | 10.784 | 1.05178 | 0 | 211.524612 (±0.00%) | 2.3785 | 1.87794 | 1 |
| average | 230.3433405558268 | 12.0728 | 783.0998942057291 | 6.77908 | 0.274014 | 1 | 220.1511255900065 | 7.16975 | 1.04505 | 0.666667 | 134.5095698038737 | 2.26657 | 1.6151 | 1 |

Tritonbench defaults to fp32 inputs/weights, which is unrealistic. When running in bf16, the other implementations don't follow the common practice (at least for LLM pretraining) of upcasting logits from bf16 to fp32. Tritonbench also defaults to "fwd" mode, which hurts Liger, since it precomputes gradients in the forward pass.

Updating the benchmark args to better defaults.
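
For concreteness, the upcast convention referred to above looks like this in eager PyTorch (shapes and vocab size here are arbitrary placeholders):

```python
import torch
import torch.nn.functional as F

# Illustrative only: batch, hidden size, and vocab size are placeholders.
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
lm_head = torch.nn.Linear(4096, 32000, bias=False, device="cuda", dtype=torch.bfloat16)
targets = torch.randint(0, 32000, (4096,), device="cuda")

logits = lm_head(x).float()              # upcast bf16 logits to fp32 before the CE
loss = F.cross_entropy(logits, targets)  # softmax / cross-entropy computed in fp32
```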

Contributor

@yf225 yf225 left a comment


@tyler-romero thanks a lot for the PR! It looks great overall; only some nit comments below:

@yf225 yf225 marked this pull request as ready for review December 15, 2025 22:00
Contributor

@yf225 yf225 left a comment


@tyler-romero thanks!

Contributor

@yf225 yf225 left a comment


Btw, we might still need to fix the lint errors (please let me know if you need help with them), thanks!

Author

Thanks! Lint issues should be addressed now. Waiting on final benchmark.
