[FEATURE] Support Merging LoRA Weights Into Base Model (Issue-3603) #3649
Conversation
…to be able to use the base model as the standalone model, rather than having to first load the base model and then load the fine-tuned weights.
Unit Test Results: 6 files ±0, 6 suites ±0, 58m 24s ⏱️ +4m 34s. For more details on these failures, see this check. Results for commit c4f2185. ± Comparison against base commit 1286123. ♻️ This comment has been updated with latest results.
Hi @alexsherstinsky! Great work on this PR and thanks for your contribution - I will try and review by EOW!
Hi @arnavgarg1! Thank you so much for your help and support. Could you please wait a little bit -- I want to try to add another test, specific to this particular new feature -- I will let you know once it is ready (hopefully later today). Thanks a lot!
Nice! This is an awesome addition, and I think the way it's integrated into the config is quite clean.
```diff
  # For a full explanation of this 8-bit workaround, see https://github.com/ludwig-ai/ludwig/pull/3606
- def filter_for_weight_format(i):
-     """Remove bitsandbytes metadata keys added on state dict creation.
-
-     8-bit quantized models that have been put on gpu will have a set of `weight_format` keys in their state dict.
-     These contain strings that are used to reshape quantized tensors, however these have no impact until the state
-     dict is loaded into a model. These keys were causing `torch.equal` to raise an exception, so we skip them in the
-     evaluation.
-     """
-     return "weight_format" not in i[0]
+ # def filter_for_weight_format(i):
+ #     """Remove bitsandbytes metadata keys added on state dict creation.
+ #
+ #     8-bit quantized models that have been put on gpu will have a set of `weight_format` keys in their state dict.
+ #     These contain strings that are used to reshape quantized tensors, however these have no impact until the state
+ #     dict is loaded into a model. These keys were causing `torch.equal` to raise an exception, so we skip them in
+ #     the evaluation.
+ #     """
+ #     return "weight_format" not in i[0]
  # Source: https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351/6

- model_1_filtered_state_dict = filter(filter_for_weight_format, model_1.state_dict().items())
- model_2_filtered_state_dict = filter(filter_for_weight_format, model_2.state_dict().items())
+ # model_1_filtered_state_dict = filter(filter_for_weight_format, model_1.state_dict().items())
+ # model_2_filtered_state_dict = filter(filter_for_weight_format, model_2.state_dict().items())
  if model_1.__class__.__name__ != model_2.__class__.__name__:
      return False
```
I'm okay with commenting out some of these lines of code for now since we still don't have GPU tests set up for Ludwig (currently a work in progress), but this check for filtering the state dict for weight formats with 8-bit quantization is necessary to make sure the tests comparing models work correctly when we test fine-tuning with 8-bit quantization!
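For context, a hedged sketch of what the filtered comparison looks like once re-enabled; the wrapper name `models_have_same_weights` and the key-set check are assumptions, while the filter and the class-name check come from the diff above:

```python
import torch


def filter_for_weight_format(item):
    # Skip bitsandbytes "weight_format" metadata keys; they hold reshape
    # strings rather than tensors, and they make torch.equal raise
    # (see https://github.com/ludwig-ai/ludwig/pull/3606).
    return "weight_format" not in item[0]


def models_have_same_weights(model_1, model_2) -> bool:
    # Models of different classes are considered different outright.
    if model_1.__class__.__name__ != model_2.__class__.__name__:
        return False
    state_1 = dict(filter(filter_for_weight_format, model_1.state_dict().items()))
    state_2 = dict(filter(filter_for_weight_format, model_2.state_dict().items()))
    if state_1.keys() != state_2.keys():
        return False
    return all(torch.equal(state_1[key], state_2[key]) for key in state_1)
```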
I'd maybe create a comment with a TODO here to re-enable those lines of code when GPU tests are enabled? That should be fine for now.
@arnavgarg1 The issue is that the filter_for_weight_format() method (which focuses on 8-bit quantization) is not currently used, and it was causing linter errors. Could you please suggest what we should do? Should we keep it commented out for now (I added a TODO in the updated PR), or try to re-enable it and use it? Thank you.
I think it should be okay to keep it commented out for now since we have the TODO - maybe I can take a look at that in a follow-up PR, but for now I don't want to block merging this awesome change into Ludwig because of it since it's causing no harm!
cc: @jeffkinnison just a quick FYI so you're not surprised by this change in the test. I will work on adding it back when GPU tests get enabled
Really nice work on this! I especially love the way you've written out your tests for making sure it is working as expected - thanks for being so thorough and for this awesome contribution!
I left some minor comments, but none of them are blocking merging this PR!
…prove language of raised exceptions for LoRA model usage (based on PR feedback).
Description
This contribution provides an implementation of `LLM.merge_and_unload()`, which is triggered from `LudwigModel.train()` based on configuration. The "merge and unload" behavior merges the fine-tuned LoRA weights into the base model so that users can load one complete model (e.g., from HuggingFace) in a single `AutoModelForCausalLM.from_pretrained()` call, rather than using two calls (first loading the base model with `AutoModelForCausalLM.from_pretrained()` and then loading the fine-tuned weights with `PeftModelForCausalLM.from_pretrained()`). This capability facilitates portability of the inference function between Ludwig (`LudwigModel.load()` followed by `model.predict()` for inference) and other frameworks (e.g., HuggingFace, using `AutoModelForCausalLM.from_pretrained()` followed by `transformers.pipeline()` for inference).
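For illustration, a minimal sketch of the two inference paths; all model names and paths below are placeholders, not from the PR:

```python
from peft import PeftModelForCausalLM
from transformers import AutoModelForCausalLM

# Without merging: inference outside Ludwig requires two loads,
# first the base model and then the fine-tuned LoRA weights.
base_model = AutoModelForCausalLM.from_pretrained("base-model-name")  # placeholder
model = PeftModelForCausalLM.from_pretrained(base_model, "path/to/lora_adapter")  # placeholder

# With "merge and unload" enabled: the exported model is self-contained,
# so a single call suffices.
model = AutoModelForCausalLM.from_pretrained("path/to/merged_model")  # placeholder
```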
The configuration consists of extending the `adapter` section with the optional `postprocessor` section as follows.
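The PR's YAML snippet did not survive extraction; a minimal sketch, assuming a LoRA adapter and that the two options named below live under `adapter.postprocessor`:

```yaml
adapter:
  type: lora
  postprocessor:
    # Merge the fine-tuned LoRA weights into the base model on save.
    merge_adapter_into_base_model: true
    # Show a progress bar while merging.
    progressbar: true
```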
(If `merge_adapter_into_base_model` is kept and set to `false`, then the `progressbar` directive can be omitted.)

Code Pull Requests
Please provide the following:
Documentation Pull Requests
Note that the documentation HTML files are in `docs/` while the Markdown sources are in `mkdocs/docs`. If you are proposing a modification to the documentation, you should change only the Markdown files.

`api.md` is automatically generated from the docstrings in the code, so if you want to change something in that file, first modify the `ludwig/api.py` docstring, then run `mkdocs/code_docs_autogen.py`, which will create `mkdocs/docs/api.md`.