[example] fused_linear_cross_entropy kernel #1268
Conversation
Hi @tyler-romero! Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Timings on an RTX 4090, "quick" autotuning, fp32 inputs, "fwd" mode. Tritonbench defaults to fp32 inputs/weights, which is unrealistic. When running in bf16, the other implementations don't follow the common practice (at least for LLM pretraining) of upcasting logits from bf16 to fp32. Tritonbench also defaults to "fwd" mode, which hurts Liger (it precomputes gradients in the forward pass). Updating the benchmark args to better defaults.
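For reference, a minimal sketch of the upcasting convention mentioned above (the reference function, names, and shapes here are illustrative, not part of this PR or Tritonbench):

```python
# Minimal sketch (illustrative, not from this PR): the common LLM-pretraining
# practice of upcasting bf16 logits to fp32 before the cross-entropy loss.
import torch
import torch.nn.functional as F

def linear_cross_entropy_reference(x, weight, targets):
    # x: (N, H) bf16 activations, weight: (V, H) bf16 lm-head, targets: (N,) int64
    logits = x @ weight.T    # (N, V), still bf16
    logits = logits.float()  # upcast to fp32 for a numerically stable loss
    return F.cross_entropy(logits, targets)
```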
@tyler-romero thanks a lot for the PR! it looks great overall, only some nit comments below:
yf225 left a comment
@tyler-romero thanks!
yf225 left a comment
btw we might still need to fix the lint errors (please let me know if you need help with them), thanks!
Thanks! Lint issues should be addressed now. Waiting on final benchmark.
Fused linear cross-entropy (CE) layer that avoids ever materializing the logits.
The forward pass uses a tiled matmul and an online softmax to compute the logsumexp (lse) and the target logits. The lse values are stored for the backward pass.
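A minimal PyTorch sketch of the forward idea (the actual kernel is written in Triton; the function name, tile size, and shapes below are illustrative assumptions, not the kernel's interface):

```python
# Illustrative sketch only: tile over the vocab dimension, keep a running
# logsumexp, and pick out the target logit, so the full (N, V) logits matrix
# is never materialized.
import torch

def flce_forward_sketch(x, weight, targets, tile_v=4096):
    # x: (N, H), weight: (V, H), targets: (N,)
    N, V = x.shape[0], weight.shape[0]
    lse = torch.full((N,), float("-inf"), device=x.device, dtype=torch.float32)
    target_logit = torch.empty(N, device=x.device, dtype=torch.float32)
    for v0 in range(0, V, tile_v):
        w_tile = weight[v0 : v0 + tile_v]             # (tile_v, H)
        logits_tile = (x @ w_tile.T).float()          # (N, tile_v), upcast to fp32
        # online logsumexp update over this vocab tile
        lse = torch.logaddexp(lse, torch.logsumexp(logits_tile, dim=-1))
        # gather the target logits that fall inside this tile
        in_tile = (targets >= v0) & (targets < v0 + tile_v)
        target_logit[in_tile] = logits_tile[in_tile, targets[in_tile] - v0]
    loss = (lse - target_logit).mean()                # mean NLL
    return loss, lse                                  # lse is saved for backward
```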
The backward pass uses a similar tiled matmul and online softmax to recompute the logits (expensive). In the same online stream, `grad_input` is computed. An atomic add is used to accumulate `grad_weight`. A sketch of this is included below, after the note on prior work.

I noticed after writing this that there is an existing FLCE example in an old open PR (#342). This one is interesting/different because it doesn't follow Liger's chunking and gradient-precomputation strategy, and so has a different performance profile (e.g. lower memory usage, faster fwd+bwd in many scenarios, and a very fast forward pass).
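A minimal PyTorch sketch of the backward idea (again illustrative; in the Triton kernel it is the `grad_weight` accumulation across program blocks that requires the atomic add):

```python
# Illustrative sketch only: recompute logits tile-by-tile from the saved lse,
# form dL/dlogits = softmax - one_hot(targets), and accumulate grad_input and
# grad_weight in the same online stream.
import torch

def flce_backward_sketch(x, weight, targets, lse, grad_out, tile_v=4096):
    # x: (N, H), weight: (V, H), targets: (N,), lse: (N,) saved from forward,
    # grad_out: scalar gradient of the (mean-reduced) loss
    N, V = x.shape[0], weight.shape[0]
    grad_x = torch.zeros_like(x, dtype=torch.float32)
    grad_w = torch.zeros_like(weight, dtype=torch.float32)
    scale = grad_out / N                                   # mean reduction
    for v0 in range(0, V, tile_v):
        w_tile = weight[v0 : v0 + tile_v]                  # (tile_v, H)
        logits_tile = (x @ w_tile.T).float()               # recompute logits (expensive)
        probs = torch.exp(logits_tile - lse[:, None])      # softmax via the saved lse
        # dL/dlogits = softmax - one_hot(targets), restricted to this tile
        in_tile = (targets >= v0) & (targets < v0 + tile_v)
        probs[in_tile, targets[in_tile] - v0] -= 1.0
        dlogits = probs * scale
        grad_x += dlogits @ w_tile.float()                 # accumulate grad_input online
        grad_w[v0 : v0 + tile_v] += dlogits.T @ x.float()  # atomic add in the real kernel
    return grad_x.to(x.dtype), grad_w.to(weight.dtype)
```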