Add native serialization support #148
Conversation
By the way, I didn't really implement or test …. Nevertheless, @edaxberger, feel free to implement serialization on ….
Though I must say that a quick test on ….
Thanks for tackling this, will be very useful!
I think the current assumption is that model, likelihood, subset_of_weights, and hessian_structure of la and la2 (in the README example) match. It might make sense to explicitly check for this in the load_state_dict method and throw descriptive errors in case of a mismatch. The necessary information also has to be stored for this (see my comment below).
Additionally, it would be nice to allow creating a new Laplace class instance based solely on the saved state dict. A classmethod would make sense for this, but Laplace is a function and not a class, making this option impossible (without other changes). So alternatively, we could add a function like LaplaceFromStateDict that takes the state dict as an argument and returns a Laplace class instance which is (hopefully completely) equivalent to the previously saved one. I could add this functionality in a follow-up PR, though.
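One possible shape for such mismatch checks, purely as an illustration (the class, its attributes, and the stored keys below are simplified stand-ins for the library's internals, not the actual implementation):

```python
class BaseLaplaceSketch:
    """Hypothetical, simplified stand-in for a Laplace object.

    Only illustrates the kind of descriptive mismatch errors
    suggested in the review: the configuration is stored in the
    state dict so load_state_dict can verify compatibility.
    """

    def __init__(self, likelihood, subset_of_weights, hessian_structure):
        self.likelihood = likelihood
        self.subset_of_weights = subset_of_weights
        self.hessian_structure = hessian_structure
        self.H = None  # the fitted Hessian approximation

    def state_dict(self):
        # Store the configuration alongside the Hessian so that a
        # later load can detect incompatible instances.
        return {
            "likelihood": self.likelihood,
            "subset_of_weights": self.subset_of_weights,
            "hessian_structure": self.hessian_structure,
            "H": self.H,
        }

    def load_state_dict(self, sd):
        # Throw a descriptive error on any configuration mismatch.
        for key in ("likelihood", "subset_of_weights", "hessian_structure"):
            if sd[key] != getattr(self, key):
                raise ValueError(
                    f"Mismatch in '{key}': saved {sd[key]!r}, "
                    f"but this instance uses {getattr(self, key)!r}."
                )
        self.H = sd["H"]
```

With this, loading into an instance constructed with a different likelihood (say) fails loudly instead of silently producing an inconsistent object.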
@runame updated. Please check and merge if everything's good. I decided not to include ….
I also ignore ….
Thanks, will review the update soon!
I should have elaborated on this, but there is a very concrete use case where the backend and backend_kwargs are needed, i.e., whenever you call fit(train_loader, override=False). For example, this is used to implement continual learning; see our experiment for the Redux paper. (I guess some mismatch between backends is fine, since only the shape of self.H matters, or both being Kron.)
I'm fine with not including model.state_dict() for now; this will only be important for the LaplaceFromStateDict functionality that I might add in a follow-up PR.
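The accumulation behind that use case can be sketched as follows (an illustrative toy function, not the library's fit method; real Hessian approximations are tensors or Kronecker factors, not floats):

```python
def fit_sketch(H_old, H_new, override=True):
    """Toy sketch of the override semantics of fit(..., override=False).

    With override=False the newly computed Hessian is added onto the
    stored one rather than replacing it, which is what enables
    continual learning: each task's curvature accumulates.
    Simplified to scalars for clarity.
    """
    if override or H_old is None:
        return H_new
    return H_old + H_new
```

Usage sketch: iterating tasks with `H = fit_sketch(H, task_H, override=False)` accumulates curvature across tasks, whereas `override=True` keeps only the last task's Hessian.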
|
I think the current checks are sufficient to handle continual learning. This will check whether the … (Laplace/laplace/baselaplace.py, lines 792 to 796 in 30e28f2).
This will check that the network (torch model) is correct (Laplace/laplace/baselaplace.py, lines 797 to 802 in 30e28f2).
As for ….
OK, makes sense; I only have one comment left.
Added a test to check ….
Sorry, last questions.
Thanks for incorporating all feedback!
Addressing #45. Very useful for large models like LLMs, where even doing forward passes over the training data (for fit()) is expensive. The API basically follows PyTorch.
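The PyTorch convention referred to here is the state_dict()/load_state_dict() round trip. A minimal self-contained sketch of the pattern (a plain-Python stand-in with pickle in place of torch.save/torch.load, so nothing here is the library's actual class):

```python
import pickle


class SerializableSketch:
    """Stand-in object showing the PyTorch-style round trip:
    serialize state_dict() after fitting, then load it into a
    freshly constructed instance. Attribute names are illustrative."""

    def __init__(self):
        self.H = None
        self.prior_precision = 1.0

    def state_dict(self):
        return {"H": self.H, "prior_precision": self.prior_precision}

    def load_state_dict(self, sd):
        self.H = sd["H"]
        self.prior_precision = sd["prior_precision"]


# Save after "fitting" (torch.save analogue)...
la = SerializableSketch()
la.H = [0.5, 0.25]
blob = pickle.dumps(la.state_dict())

# ...and restore into a new instance (torch.load analogue),
# skipping the expensive fit() entirely.
la2 = SerializableSketch()
la2.load_state_dict(pickle.loads(blob))
```

The payoff for large models is exactly the one stated above: the costly forward passes happen once, and later sessions only deserialize.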