Conversation

@mobicham (Contributor) commented on Jun 13, 2025:

Fixes generate.py benchmarking with gemlite.

Also, the current code gives OOM on smaller GPUs. By putting the weights on CPU first, we can avoid this issue.
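For context, a minimal sketch of the loading pattern the PR relies on (not the actual generate.py code; `load_then_move` and the model construction are placeholders): stage the checkpoint on CPU with mmap, then move the single materialized copy to the target device.

```python
import torch

def load_then_move(checkpoint_path, model, device="cuda", dtype=torch.float16):
    # Stage the weights on CPU; mmap avoids eagerly reading every storage,
    # so no full extra copy of the state dict ever lives on the GPU.
    state_dict = torch.load(
        str(checkpoint_path), mmap=True, weights_only=True, map_location="cpu"
    )
    # assign=True (PyTorch >= 2.1) reuses the loaded tensors directly
    # instead of copying them into pre-allocated parameters.
    model.load_state_dict(state_dict, assign=True)
    # Only now move the one materialized copy to the target device.
    return model.to(device=device, dtype=dtype)
```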

@pytorch-bot commented on Jun 13, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2372

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 862e0cb with merge base 6243040:

BROKEN TRUNK - The following job failed but was already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 13, 2025
@jerryzh168 added the topic: bug fix label on Jun 16, 2025
@jerryzh168 merged commit 7a846d5 into pytorch:main on Jun 24, 2025 (18 of 20 checks passed)
The review thread below is on this change in torchao/_models/llama/generate.py:

```diff
 def _load_model(checkpoint_path, device, precision):
     checkpoint = torch.load(
-        str(checkpoint_path), mmap=True, weights_only=True, map_location=device
+        str(checkpoint_path), mmap=True, weights_only=True, map_location="cpu"
```
@jerryzh168 (Contributor) commented:

Actually, is the map_location change correct?

@mobicham (Contributor, Author) commented:

Yes, without this you can't load a Llama3-8B fp16 checkpoint on a 24GB GPU; you just get OOM. This is the part I mentioned:

> Also, the current code gives OOM on smaller GPUs. By putting the weights on CPU first, we can avoid this issue.
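Rough back-of-the-envelope arithmetic (my own estimate, not stated in the thread) for why 24 GB is not enough when the checkpoint is loaded straight to the GPU:

```python
params = 8e9           # approximate Llama3-8B parameter count
bytes_per_param = 2    # fp16
one_copy_gb = params * bytes_per_param / 1e9   # ~16 GB for one copy of the weights
# map_location="cuda" puts one copy on the GPU; model.to(device) can briefly
# hold a second, so peak ~= 32 GB, which exceeds a 24 GB card.
# Staging on CPU first keeps only one GPU-resident copy (~16 GB).
print(f"one copy ~= {one_copy_gb:.0f} GB, two copies ~= {2 * one_copy_gb:.0f} GB")
```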

@jerryzh168 (Contributor) commented on Jun 24, 2025:

Oh OK, but do we need to change it somewhere, or does it have to run on CPU?

I think if the user requests to run on some device but it fails, we should probably not silently change the device, but instead ask the user to use a different device.

@mobicham (Contributor, Author) commented:

No, it still runs on GPU; there's a model.to(device): https://github.com/mobicham/ao/blob/862e0cbd90f8cc5f992ae66e779022912fa4d93a/torchao/_models/llama/generate.py#L245-L256

The issue is that loading the weights via map_location='cuda' and then doing model.to(device) for some reason uses more VRAM, leading to OOM. Maybe there's a cleaner way of doing it.
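If someone wants to verify the peak-memory difference themselves, a minimal sketch (my own, not part of the PR; `checkpoint_path`, `build_model`, and the dtype are placeholders) could compare the two loading orders:

```python
import torch

def peak_load_gb(checkpoint_path, build_model, device="cuda", map_location="cpu"):
    """Peak GPU memory (GB) for: torch.load -> load_state_dict -> .to(device)."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    state_dict = torch.load(
        checkpoint_path, mmap=True, weights_only=True, map_location=map_location
    )
    model = build_model()  # fresh, unquantized model on CPU
    model.load_state_dict(state_dict)
    model = model.to(device=device, dtype=torch.float16)
    return torch.cuda.max_memory_allocated(device) / 1e9

# Expectation: map_location="cuda" peaks at roughly two copies of the weights,
# map_location="cpu" at roughly one, which is the gap the PR exploits.
```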

@jerryzh168 (Contributor) commented on Jun 24, 2025:

Oh OK, makes sense. This should be OK: some quantized formats do not support loading on CPU, but for full/half-precision models it should be fine.
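If that caveat ever needs to be handled explicitly, one hypothetical guard (not in the PR; `_resolve_map_location` is an invented name, reusing the `precision` argument that _load_model already takes) would stage only full/half-precision checkpoints on CPU:

```python
import torch

def _resolve_map_location(precision, device):
    # Full/half-precision checkpoints can always be staged on CPU first
    # and moved to the target device afterwards.
    if precision in (torch.float32, torch.float16, torch.bfloat16):
        return "cpu"
    # Some quantized formats must be materialized directly on the target device.
    return device
```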

liangel-02 pushed a commit that referenced this pull request on Aug 25, 2025:

* fix get_plain() with FMA mode
* update
* fix in_features/out_feature meta-data mismatch
* update gemlite slice test
* add packing_bitwidth support
* add packing_bitwidth support and cleanup
* update default gemlite layout
* cleanup
* fix symmetric use-case and relax _same_meta_data
* _copy() meta data
* fix (4,) in autoquant
* Add dynamic mode in gemlite layout
* mode explanation
* use weights_only instead of static
* generate fix
* remove set_packing_bitwidth

Signed-off-by: mobicham <[email protected]>