Add Callbacks #6249
Conversation
Details on the test cases are listed here: #6180 (comment)
Code Review
This pull request introduces new callbacks for PyTorch Lightning, including a callback to average the best K checkpoints. The implementation is mostly solid, and it comes with a comprehensive test suite.
I've found a few issues that should be addressed:
- In `espnet3/trainer/callbacks.py`, there's a potential `TypeError` if no checkpoints are available for averaging. Also, the data type check for averaging is a bit fragile and could be made more robust.
- In `espnet3/trainer/trainer.py`, a change from `_del_config_key` to `pop` could cause a regression when using `argparse.Namespace` for configuration.
Additionally, it seems that the new callbacks are not actually passed to the `lightning.Trainer` instance in `espnet3/trainer/trainer.py`, as the `callbacks` argument is still commented out. This would prevent the new functionality from working.
Code Review
This pull request introduces new callback functionality for checkpoint averaging and standardizes callback creation. My review focuses on improving security, robustness, and correctness. I've identified a critical security vulnerability in how checkpoints are loaded, a potential crash when no checkpoints are available for averaging, and a bug in configuration handling that could affect different configuration object types. I've provided suggestions to fix these issues and also recommended adding a test case for an important edge case.
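For the checkpoint-loading vulnerability mentioned above, the usual mitigation in recent PyTorch versions is to pass `weights_only=True` to `torch.load`. A minimal, self-contained sketch (the checkpoint path and contents are purely illustrative):

```python
import os
import tempfile

import torch

# Create a small checkpoint purely for demonstration.
ckpt_path = os.path.join(tempfile.gettempdir(), "demo.ckpt")
torch.save({"w": torch.zeros(2)}, ckpt_path)

# weights_only=True restricts unpickling to tensors and plain containers,
# so a maliciously crafted checkpoint cannot execute arbitrary code on load.
state = torch.load(ckpt_path, map_location="cpu", weights_only=True)
```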
I think it is a great idea to implement checkpoint averaging (or averaging of other monitored values) via a callback. @Emrys365, can you check this PR?
```python
for callback in self.config.callbacks:
    callbacks.append(instantiate(callback))
```
Can it happen that the callbacks defined in `self.config` are duplicates of the default callbacks? In that case, would it be better to detect such cases and display warnings?
Thank you. The current implementation does not account for cases where callbacks defined in `get_default_callbacks` overlap with those specified in the config. For example, `LearningRateMonitor` may be registered twice, which can result in duplicated log outputs.
In Lightning's current behavior this is not a critical issue, so it does not immediately cause an error.
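A hedged sketch of how such overlaps could be detected and warned about when merging the defaults with config-specified callbacks. The `merge_callbacks` helper and the stub callback classes are illustrative, not the actual espnet3 API (in real code the classes would come from `lightning.pytorch.callbacks`):

```python
import warnings


class LearningRateMonitor:  # stand-in for the Lightning callback class
    pass


class ProgressBar:  # stand-in for another default callback
    pass


def merge_callbacks(defaults, extras):
    """Append config callbacks, warning on (and skipping) duplicate types."""
    merged = list(defaults)
    seen = {type(cb) for cb in defaults}
    for cb in extras:
        if type(cb) in seen:
            warnings.warn(
                f"{type(cb).__name__} is already registered by default; "
                "skipping the duplicate."
            )
            continue
        merged.append(cb)
        seen.add(type(cb))
    return merged
```

Deduplicating by callback type is a simplification; callbacks that are legitimately registered twice with different arguments would need a finer-grained check.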
```python
class AverageCheckpointsCallback(Callback):
    """
    A custom PyTorch Lightning callback that performs weight averaging over the top-K
    checkpoints (according to specified metrics) at the end of training.
    """
```
Does model averaging only happen at the end of training?
This would be an issue when people try to get the intermediate results with the checkpoint.
Thank you! Right now it's only hooked into `on_fit_end`, so it fires once at the very end. I will switch to `on_validation_end` so it triggers after every validation round.
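A minimal sketch of what that change could look like. The `Callback` stub and the `_average_and_save` helper are hypothetical stand-ins (in real code `Callback` would come from `lightning.pytorch.callbacks`, and the helper would perform the actual top-K averaging):

```python
class Callback:  # stub standing in for lightning.pytorch.callbacks.Callback
    pass


class AverageCheckpointsCallback(Callback):
    def __init__(self):
        self.n_triggers = 0

    def on_validation_end(self, trainer, pl_module):
        # Fires after every validation round, so averaged checkpoints are
        # available mid-training, not only once at the very end of fit.
        if getattr(trainer, "sanity_checking", False):
            return  # skip the pre-training sanity-check validation pass
        self._average_and_save(trainer)

    def _average_and_save(self, trainer):
        # Hypothetical helper: would average the top-K checkpoints and
        # save the result. Here it only records that it was triggered.
        self.n_triggers += 1
```

Guarding on `trainer.sanity_checking` matters because Lightning runs a sanity-check validation pass before training, when no checkpoints exist yet.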
Codecov Report ❌ Patch coverage is
Additional details and impacted files

```diff
@@ Coverage Diff @@
##           espnet3    #6249      +/- ##
===========================================
+ Coverage    68.96%   68.98%   +0.01%
===========================================
  Files          750      751       +1
  Lines        68915    68974      +59
===========================================
+ Hits         47530    47584      +54
- Misses       21385    21390       +5
```
Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
What did you change?
Added a new module, `espnet3/trainer/callbacks.py`, which includes:
- `AverageCheckpointsCallback`: a custom Lightning callback for averaging top-K model checkpoints.
- `get_default_callbacks()`: a utility to create a standard set of callbacks, including checkpointing, a progress bar, and LR monitoring.

Integrated `get_default_callbacks` into the training loop in `espnet3/trainer/trainer.py`.
Created a new test module, `test/espnet3/test_callback.py`.

Why did you make this change?
Is your PR small enough?
Yes. 3 files changed, ~500 additions.
Additional Context