[espnet3-7](2) Add Callbacks #6249

Merged
Masao-Someki merged 11 commits into espnet:espnet3 from
Masao-Someki:espnet3/callback-2
Oct 2, 2025

Conversation

@Masao-Someki
Contributor

What did you change?

  • Added a new module: espnet3/trainer/callbacks.py, which includes:

    • AverageCheckpointsCallback: a custom Lightning callback for averaging top-K model checkpoints.
    • get_default_callbacks(): a utility to create a standard set of callbacks including checkpointing, progress bar, and LR monitoring.
  • Integrated get_default_callbacks into the training loop in espnet3/trainer/trainer.py.

  • Created a new test module: test/espnet3/test_callback.py.


Why did you make this change?

  • To support checkpoint ensembling by averaging top-K models.
  • To standardize callback configuration in the trainer, ensuring consistency and reusability.
  • To ensure correctness and robustness of the new logic via a comprehensive test suite.
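As a rough illustration of the top-K averaging described above, here is a minimal sketch in plain Python. It is not the code in espnet3/trainer/callbacks.py: the function names `select_top_k` and `average_state_dicts` are hypothetical, parameters are represented as lists of floats rather than tensors, and summing integer entries instead of averaging them is an assumed convention.

```python
def select_top_k(scores, k, mode="min"):
    """Return the names of the k best checkpoints (hypothetical helper)."""
    ranked = sorted(scores, key=scores.get, reverse=(mode == "max"))
    return ranked[:k]


def average_state_dicts(state_dicts):
    """Average parameters element-wise across checkpoints (hypothetical helper)."""
    if not state_dicts:
        # Fail early with a clear message instead of dividing by zero later.
        raise ValueError("no checkpoints available for averaging")
    keys = state_dicts[0].keys()
    if any(sd.keys() != keys for sd in state_dicts[1:]):
        raise KeyError("checkpoints do not share the same parameter names")
    n = len(state_dicts)
    averaged = {}
    for key in keys:
        values = [sd[key] for sd in state_dicts]
        if isinstance(values[0], int):
            # Assumption: integer counters are accumulated, not averaged.
            averaged[key] = sum(values)
        else:
            averaged[key] = [sum(column) / n for column in zip(*values)]
    return averaged


# Toy data: validation loss per checkpoint (lower is better).
scores = {"epoch1": 0.9, "epoch2": 0.5, "epoch3": 0.7}
checkpoints = {
    "epoch1": {"weight": [1.0, 2.0], "num_batches": 10},
    "epoch2": {"weight": [3.0, 5.0], "num_batches": 20},
    "epoch3": {"weight": [5.0, 7.0], "num_batches": 30},
}
best = select_top_k(scores, k=2, mode="min")
averaged = average_state_dicts([checkpoints[name] for name in best])
```

With real checkpoints the values would be tensors loaded from the saved files, but the selection and averaging steps would follow the same shape.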

Is your PR small enough?

Yes. 3 files changed, ~500 additions.


Additional Context

@dosubot dosubot bot added the size:L (This PR changes 100-499 lines, ignoring generated files.), ESPnet3, and New Features labels Sep 26, 2025
@Masao-Someki
Contributor Author

Details on the test cases are listed here: #6180 (comment)

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces new callbacks for PyTorch Lightning, including a callback to average the best K checkpoints. The implementation is mostly solid, and it comes with a comprehensive test suite.

I've found a few issues that should be addressed:

  • In espnet3/trainer/callbacks.py, there's a potential TypeError if no checkpoints are available for averaging. Also, the data type check for averaging is a bit fragile and could be made more robust.
  • In espnet3/trainer/trainer.py, a change from _del_config_key to pop could cause a regression when using argparse.Namespace for configuration.
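The second point can be illustrated with a small sketch: `argparse.Namespace` has no `pop` method, so code that calls `config.pop(key)` only works for dict-like configs. The helper `remove_key` below is hypothetical and is not the repository's `_del_config_key`; it only shows one way to handle both kinds of object.

```python
import argparse
from collections.abc import MutableMapping


def remove_key(config, key):
    """Remove `key` from a dict-like or attribute-style config (hypothetical helper)."""
    if isinstance(config, MutableMapping):
        # Dict-like configs support pop(); the default avoids a KeyError.
        config.pop(key, None)
    elif hasattr(config, key):
        # argparse.Namespace stores options as attributes.
        delattr(config, key)


namespace_config = argparse.Namespace(callbacks=["a"], max_epochs=3)
dict_config = {"callbacks": ["a"], "max_epochs": 3}

remove_key(namespace_config, "callbacks")
remove_key(dict_config, "callbacks")
remove_key(dict_config, "not_present")  # silently ignored
```

Whether the trainer needs to accept `argparse.Namespace` at all is a design decision for the maintainers; the sketch only shows why a bare `pop` call is not interchangeable across config types.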

Additionally, it seems that the new callbacks are not actually passed to the lightning.Trainer instance in espnet3/trainer/trainer.py, as the callbacks argument is still commented out. This would prevent the new functionality from working.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces new callback functionality for checkpoint averaging and standardizes callback creation. My review focuses on improving security, robustness, and correctness. I've identified a critical security vulnerability in how checkpoints are loaded, a potential crash when no checkpoints are available for averaging, and a bug in configuration handling that could affect different configuration object types. I've provided suggestions to fix these issues and also recommended adding a test case for an important edge case.

@sw005320
Contributor

I think implementing checkpoint averaging (or other monitoring values) via a callback is a great idea.

@Emrys365, can you check this PR?

Collaborator

@Emrys365 Emrys365 left a comment

LGTM in general.

Comment on lines +98 to +99
    for callback in self.config.callbacks:
        callbacks.append(instantiate(callback))
Collaborator

Can the callbacks defined in self.config duplicate the default callbacks? If so, would it be better to detect such cases and display a warning?

Contributor Author

Thank you. The current implementation does not account for cases where callbacks defined in get_default_callbacks overlap with those specified in the config. For example, LearningRateMonitor may be registered twice, which can result in duplicated log outputs.

With Lightning's current behavior this is not a critical issue, since it does not immediately cause an error.
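One possible way to surface the overlap discussed in this thread is to warn when a user-defined callback has the same type as a default one. The sketch below is an assumption, not something the PR implements: `merge_callbacks` is a hypothetical helper, and the two classes are empty stand-ins rather than Lightning's real callbacks.

```python
import warnings


class LearningRateMonitor:  # stand-in for the Lightning callback of the same name
    pass


class ModelCheckpoint:  # stand-in for the Lightning callback of the same name
    pass


def merge_callbacks(defaults, user_defined):
    """Append user callbacks to the defaults, warning on duplicated types."""
    merged = list(defaults)
    default_types = {type(cb) for cb in defaults}
    for cb in user_defined:
        if type(cb) in default_types:
            warnings.warn(
                f"{type(cb).__name__} is already a default callback "
                "and will be registered twice.",
                UserWarning,
            )
        merged.append(cb)
    return merged


# Capture warnings so the effect is visible without printing.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    merged = merge_callbacks(
        defaults=[LearningRateMonitor(), ModelCheckpoint()],
        user_defined=[LearningRateMonitor()],
    )
```

The sketch keeps both instances and only warns, which matches the observation that duplicates are not an error. Matching on exact type would also flag intentional duplicates, such as several checkpoint callbacks monitoring different metrics, so a real check might need to be more selective.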

@dosubot dosubot bot added the lgtm (This PR has been approved by a maintainer) label Sep 26, 2025
@dosubot dosubot bot added the size:XL (This PR changes 500-999 lines, ignoring generated files.) label and removed the size:L (This PR changes 100-499 lines, ignoring generated files.) label Sep 26, 2025
class AverageCheckpointsCallback(Callback):
    """
    A custom PyTorch Lightning callback that performs weight averaging over the top-K
    checkpoints (according to specified metrics) at the end of training.
Contributor

Does model averaging only happen at the end of training?
This would be an issue when people try to get the intermediate results with the checkpoint.

Contributor Author

Thank you! Right now it's only hooked into on_fit_end, so it fires once at the very end. I will switch to on_validation_end so that it triggers after every validation round.
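A minimal sketch of what moving the logic to on_validation_end could look like is shown below. It is illustrative only and is not the merged implementation: the class name `AverageEveryValidation`, its `average_fn` argument, and the choice to skip Lightning's sanity-check pass are all assumptions. The fallback base class exists only so the sketch runs without Lightning installed.

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:  # fallback so the sketch runs without Lightning
    class Callback:
        pass


class AverageEveryValidation(Callback):
    """Run an averaging function after each validation round (hypothetical)."""

    def __init__(self, average_fn):
        self.average_fn = average_fn
        self.num_runs = 0

    def on_validation_end(self, trainer, pl_module):
        # Assumption: skip the sanity-check validation run before training.
        if getattr(trainer, "sanity_checking", False):
            return
        self.average_fn(trainer, pl_module)
        self.num_runs += 1


class FakeTrainer:  # stand-in object, used only for this demonstration
    sanity_checking = False


calls = []
callback = AverageEveryValidation(lambda trainer, module: calls.append("averaged"))
callback.on_validation_end(FakeTrainer(), None)
callback.on_validation_end(FakeTrainer(), None)
```

Compared with on_fit_end, this makes an averaged model available after every validation round, at the cost of repeating the averaging work each time.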

Contributor

OK!

@codecov

codecov bot commented Sep 26, 2025

Codecov Report

❌ Patch coverage is 91.66667% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 68.98%. Comparing base (f5f1fd1) to head (2ae769e).
⚠️ Report is 13 commits behind head on espnet3.

Files with missing lines        Patch %   Lines
espnet3/trainer/trainer.py      50.00%    4 Missing ⚠️
espnet3/trainer/callbacks.py    98.07%    1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           espnet3    #6249      +/-   ##
===========================================
+ Coverage    68.96%   68.98%   +0.01%     
===========================================
  Files          750      751       +1     
  Lines        68915    68974      +59     
===========================================
+ Hits         47530    47584      +54     
- Misses       21385    21390       +5     
Flag                       Coverage Δ
test_integration_espnet2   47.23% <ø> (ø)
test_python_espnet2        61.95% <0.00%> (-0.06%) ⬇️
test_python_espnet3        16.00% <91.66%> (+0.06%) ⬆️
test_utils                 61.95% <0.00%> (-0.06%) ⬇️

@Fhrozen Fhrozen added this to the v.202512 milestone Sep 28, 2025
@Masao-Someki Masao-Someki merged commit 4531bcc into espnet:espnet3 Oct 2, 2025
98 of 131 checks passed
@Masao-Someki Masao-Someki mentioned this pull request Oct 9, 2025
52 tasks
@Fhrozen Fhrozen modified the milestones: v.202512, v.202511 Nov 14, 2025
@Masao-Someki Masao-Someki deleted the espnet3/callback-2 branch November 26, 2025 18:25

Labels

ESPnet3, lgtm (This PR has been approved by a maintainer), New Features, size:XL (This PR changes 500-999 lines, ignoring generated files.)

4 participants