Handle torch ver in flexattn #37400
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
```diff
@@ -66,7 +66,7 @@ def __init__(self, training):
         # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
         # see https://github.com/pytorch/pytorch/issues/146260 for training
         self.training = training
-        if _torch_version.split("+")[0] == "2.6.0" and training:
+        if is_torch_greater_or_equal("2.6.0") and training:
```
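For context, a minimal sketch of what the two guards actually match; the version string below is illustrative, not taken from the PR:

```python
# Illustrative comparison of the old string check and the new helper.
from packaging import version

_torch_version = "2.7.0+cu126"  # example build string with a local "+cu" suffix

# Old guard: matches torch 2.6.0 exactly, nothing newer.
old_check = _torch_version.split("+")[0] == "2.6.0"  # False for 2.7.0

# New guard, roughly what is_torch_greater_or_equal("2.6.0") does:
# it matches 2.6.0 and every later release.
new_check = version.parse(_torch_version) >= version.parse("2.6.0")  # True for 2.7.0
```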
pytorch/pytorch#143299 should've fixed this issue, so it makes more sense to directly look for 2.6.0 and not 2.6.0 <=.
I think it's fine to use `from packaging import version` ... instead of creating another function in the utils 👀 but up to debate
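A minimal sketch of the packaging-based alternative floated here (where it would live in the utils is left open, as in the comment):

```python
# Sketch: exact-version guard via packaging instead of manual string splitting.
import torch
from packaging import version

# base_version drops local suffixes such as "+cu124", so no .split("+") is needed.
is_torch_2_6_0 = version.parse(torch.__version__).base_version == "2.6.0"
```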
@vasqu not sure what you mean here 🤔 this will check ">= 2.6.0", and is_torch_greater_or_equal is already part of the package.
I think it makes the check more future-proof.
Am I missing something?
The torch version guard was introduced for torch==2.6.0 explicitly.
The PR I linked fixed some issues which should remove the need for this check, i.e. we don't need to compile with "max-autotune-no-cudagraphs". This means that future versions should also not need it, which is why I suggested an == and not a 2.6.0 <=.
Edit: the wording before was probably less than ideal :D
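To make that reasoning concrete, a hedged sketch of the guard it implies; the function name and the plain-compile fallback are assumptions, not the exact transformers code:

```python
import torch
from packaging import version
from torch.nn.attention.flex_attention import flex_attention

def compile_flex_attention(training: bool):
    # Only torch 2.6.0 hits the cudagraphs error during training; later
    # releases ship the fix from pytorch/pytorch#143299, so they fall through.
    if version.parse(torch.__version__).base_version == "2.6.0" and training:
        # Workaround for https://github.com/pytorch/pytorch/issues/146260
        return torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")
    return torch.compile(flex_attention)
```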
Ok, my bad, got it!
Updated the PR to use `version` instead of manual str checking, lmk how it looks to you.
LGTM
LGTM
* Handle torch ver in flexattn
* update
Follow-up to #37399
@ArthurZucker