[SpeechLM] Deepspeed trainer#6278

Merged
jctian98 merged 8 commits into espnet:master from jctian98:deepspeed_trainer
Oct 31, 2025

Conversation

@jctian98
Collaborator

This PR adds two files that implement the DeepSpeed trainer logic:
(1) train.py: the overall training launcher script, which initializes distributed training, logging, the model, the data loader, etc.
(2) deepspeed_trainer.py: the DeepSpeed trainer wrapper, which drives training with the given model and data loader objects.
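For readers unfamiliar with DeepSpeed: the trainer wrapper is typically driven by a JSON config passed at engine initialization. The fragment below is a hypothetical minimal sketch using standard DeepSpeed config keys; the values are illustrative and are not the settings used in this PR.

```json
{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 2,
  "bf16": { "enabled": true },
  "zero_optimization": { "stage": 2 },
  "optimizer": {
    "type": "AdamW",
    "params": { "lr": 1e-4, "weight_decay": 0.01 }
  }
}
```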

Prior PRs: #6257 , #6258 , #6260

Request review: @Masao-Someki @wanchichen @siddhu001

@dosubot dosubot bot added size:L This PR changes 100-499 lines, ignoring generated files. ESPnet2 New Features labels Oct 27, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request introduces a DeepSpeed trainer for SpeechLM, including the training script and the trainer implementation. The code is generally well-structured. However, I've identified critical issues concerning metrics calculation in a distributed training environment. Both training and validation statistics are not aggregated across all GPUs before logging, which can result in incorrect and misleading metrics. The validation metrics are particularly concerning as they are computed on only a fraction of the validation dataset. I have provided specific code suggestions to rectify these issues by properly aggregating statistics using torch.distributed.all_reduce.

Comment on lines +167 to +186
all_stats = {}

with torch.no_grad():
for batch in iterator:
batch = to_device(batch, "cuda", dtype=self.dtype)
out = self.model_engine(**batch)

stats = {k: float(v) for k, v in out["stats"].items()}
for key, value in stats.items():
if key not in all_stats:
all_stats[key] = []
all_stats[key].append(value)

# Compute averages and log (should be outside the batch loop)
all_stats = {
f"val/{name}/{key}": sum(value) / len(value)
for key, value in all_stats.items()
}
wandb.log(all_stats, step=self.global_step)

Contributor
critical

The validation metrics are calculated incorrectly in a distributed setting. The current implementation computes the average statistics on each rank independently, based only on its local shard of the validation data. The results are not aggregated across all ranks. This means the logged validation metrics are based on a fraction of the full validation set, which is incorrect and misleading. You should aggregate the statistics (e.g., sums and counts) from all ranks before computing the final average on rank 0.

            from collections import defaultdict

            # Collect sums and counts for each metric on the local rank
            local_stats_sum = defaultdict(float)
            local_stats_count = defaultdict(int)

            with torch.no_grad():
                for batch in iterator:
                    batch = to_device(batch, "cuda", dtype=self.dtype)
                    out = self.model_engine(**batch)

                    stats = {k: v.item() for k, v in out["stats"].items()}
                    for key, value in stats.items():
                        local_stats_sum[key] += value
                        local_stats_count[key] += 1

            stat_keys = sorted(local_stats_sum.keys())
            if not stat_keys:
                continue

            # Aggregate stats across all GPUs
            local_sums = torch.tensor(
                [local_stats_sum[k] for k in stat_keys], device="cuda"
            )
            local_counts = torch.tensor(
                [local_stats_count[k] for k in stat_keys],
                device="cuda",
                dtype=torch.long,
            )

            torch.distributed.all_reduce(local_sums, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(
                local_counts, op=torch.distributed.ReduceOp.SUM
            )

            if torch.distributed.get_rank() == 0:
                final_stats = {}
                for i, key in enumerate(stat_keys):
                    total_sum = local_sums[i].item()
                    total_count = local_counts[i].item()
                    if total_count > 0:
                        final_stats[f"val/{name}/{key}"] = total_sum / total_count
                if final_stats:
                    wandb.log(final_stats, step=self.global_step)
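As a sanity check on the suggested sum/count aggregation, here is a minimal pure-Python simulation (no torch.distributed required; rank data is made up for illustration). It shows why a mean computed from one rank's local shard, or an average of per-rank means, is biased when ranks see different numbers of batches, whereas all-reducing sums and counts and dividing once recovers the true global mean:

```python
# Simulate two ranks that saw different numbers of validation batches.
rank_batches = {
    0: [1.0, 3.0],  # rank 0: 2 batches, local mean = 2.0
    1: [10.0],      # rank 1: 1 batch,  local mean = 10.0
}

# Biased: average the per-rank means (ignores per-rank batch counts).
per_rank_means = [sum(v) / len(v) for v in rank_batches.values()]
naive = sum(per_rank_means) / len(per_rank_means)  # 6.0

# Suggested scheme: reduce sums and counts globally, then divide once.
# (torch.distributed.all_reduce with ReduceOp.SUM does this across ranks;
# plain sums stand in for it here.)
total_sum = sum(sum(v) for v in rank_batches.values())    # 14.0
total_count = sum(len(v) for v in rank_batches.values())  # 3
correct = total_sum / total_count                         # 4.666...

print(naive, correct)
```

The gap between the two numbers is exactly the bias the review comment describes: only the sum/count scheme weights every batch equally regardless of which rank processed it.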

@codecov

codecov bot commented Oct 27, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 56.49%. Comparing base (c3f4fe0) to head (5bbc8c7).
⚠️ Report is 43 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6278   +/-   ##
=======================================
  Coverage   56.49%   56.49%           
=======================================
  Files         896      896           
  Lines       84814    84814           
=======================================
  Hits        47914    47914           
  Misses      36900    36900           
Flag                        Coverage Δ
test_integration_espnet2    46.81% <ø> (ø)
test_integration_espnetez   36.93% <ø> (ø)
test_python_espnet2         50.93% <ø> (ø)
test_python_espnetez        12.73% <ø> (ø)
test_utils                  18.77% <ø> (ø)


@jctian98
Collaborator Author

I've addressed Gemini's review comments and added sync ops for the stats.
The PR is ready for review now.

"""Execute one training epoch."""
self.model_engine.train()

iterator = self.train_data_factory.get_iterator(
Contributor
Can we change this to build_iter? I also think it would be good to inherit from espnet2.iterators.abs_iter_factory.AbsIterFactory.
ESPnet-3 is based on espnet2's AbsIterFactory, so fixing this part would also be nice for ESPnet-3!

Collaborator Author

I've changed the name, but for now I want zero dependencies on code outside espnet2/speechlm, so can we address this later?

Collaborator Author

Also changed "get_iterator" -> "build_iter" in #6280

@Fhrozen Fhrozen added this to the v.202512 milestone Oct 29, 2025
@dosubot dosubot bot added size:XXL This PR changes 1000+ lines, ignoring generated files. and removed size:L This PR changes 100-499 lines, ignoring generated files. labels Oct 30, 2025
@dosubot dosubot bot added size:L This PR changes 100-499 lines, ignoring generated files. and removed size:XXL This PR changes 1000+ lines, ignoring generated files. labels Oct 30, 2025
@jctian98 jctian98 merged commit 07a304f into espnet:master Oct 31, 2025
32 checks passed
@jctian98 jctian98 deleted the deepspeed_trainer branch November 3, 2025 00:55
@Fhrozen Fhrozen modified the milestones: v.202512, v.202511 Nov 14, 2025