Conversation
Code Review
This pull request introduces a DeepSpeed trainer for SpeechLM, including the training script and the trainer implementation. The code is generally well-structured. However, I've identified critical issues concerning metrics calculation in a distributed training environment. Both training and validation statistics are not aggregated across all GPUs before logging, which can result in incorrect and misleading metrics. The validation metrics are particularly concerning as they are computed on only a fraction of the validation dataset. I have provided specific code suggestions to rectify these issues by properly aggregating statistics using torch.distributed.all_reduce.
all_stats = {}

with torch.no_grad():
    for batch in iterator:
        batch = to_device(batch, "cuda", dtype=self.dtype)
        out = self.model_engine(**batch)

        stats = {k: float(v) for k, v in out["stats"].items()}
        for key, value in stats.items():
            if key not in all_stats:
                all_stats[key] = []
            all_stats[key].append(value)

    # Compute averages and log (should be outside the batch loop)
    all_stats = {
        f"val/{name}/{key}": sum(value) / len(value)
        for key, value in all_stats.items()
    }
    wandb.log(all_stats, step=self.global_step)
The validation metrics are calculated incorrectly in a distributed setting. The current implementation computes the average statistics on each rank independently, based only on its local shard of the validation data. The results are not aggregated across all ranks. This means the logged validation metrics are based on a fraction of the full validation set, which is incorrect and misleading. You should aggregate the statistics (e.g., sums and counts) from all ranks before computing the final average on rank 0.
from collections import defaultdict

# Collect sums and counts for each metric on the local rank
local_stats_sum = defaultdict(float)
local_stats_count = defaultdict(int)
with torch.no_grad():
    for batch in iterator:
        batch = to_device(batch, "cuda", dtype=self.dtype)
        out = self.model_engine(**batch)
        stats = {k: v.item() for k, v in out["stats"].items()}
        for key, value in stats.items():
            local_stats_sum[key] += value
            local_stats_count[key] += 1
stat_keys = sorted(local_stats_sum.keys())
if not stat_keys:
    continue
# Aggregate stats across all GPUs
local_sums = torch.tensor(
    [local_stats_sum[k] for k in stat_keys], device="cuda"
)
local_counts = torch.tensor(
    [local_stats_count[k] for k in stat_keys],
    device="cuda",
    dtype=torch.long,
)
torch.distributed.all_reduce(local_sums, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(
    local_counts, op=torch.distributed.ReduceOp.SUM
)
if torch.distributed.get_rank() == 0:
    final_stats = {}
    for i, key in enumerate(stat_keys):
        total_sum = local_sums[i].item()
        total_count = local_counts[i].item()
        if total_count > 0:
            final_stats[f"val/{name}/{key}"] = total_sum / total_count
    if final_stats:
        wandb.log(final_stats, step=self.global_step)
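To make the bias concrete, here is a small standalone illustration (plain Python, no torch; the shard contents are made-up numbers) of why averaging per-rank averages differs from the sum/count aggregation the suggestion uses, whenever validation shards are uneven:

```python
# Two "ranks" hold differently sized shards of the validation set.
rank_losses = {0: [1.0, 1.0, 1.0], 1: [3.0]}  # rank 1 got a tiny shard

# Wrong: each rank averages locally, then the per-rank means are averaged.
per_rank_means = [sum(v) / len(v) for v in rank_losses.values()]
naive = sum(per_rank_means) / len(per_rank_means)  # 2.0

# Right: aggregate sums and counts (what all_reduce does), divide once.
total_sum = sum(sum(v) for v in rank_losses.values())    # 6.0
total_count = sum(len(v) for v in rank_losses.values())  # 4
correct = total_sum / total_count                        # 1.5

print(naive, correct)  # 2.0 1.5
```

With equal shard sizes the two agree, but distributed samplers pad or drop batches, so the sum/count form is the safe one.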
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@ Coverage Diff @@
##           master   #6278   +/-   ##
=======================================
  Coverage   56.49%   56.49%
=======================================
  Files         896      896
  Lines       84814    84814
=======================================
  Hits        47914    47914
  Misses      36900    36900
I've incorporated Gemini's review comments and added sync ops for the stats.
"""Execute one training epoch."""
self.model_engine.train()

iterator = self.train_data_factory.get_iterator(
Can we change this to build_iter? Also, I think it would be good to inherit from espnet2.iterators.abs_iter_factory.AbsIterFactory.
ESPnet-3 is based on espnet2's AbsIterFactory, so fixing this part would also be nice for ESPnet-3!
I changed this name, but for now I want to keep zero dependency on code outside espnet2/speechlm, so can we address this later?
Also change the "get_iterator" -> "build_iter" in #6280
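For readers unfamiliar with the interface being discussed, the following is a hypothetical sketch of a build_iter-style iterator factory. The AbsIterFactory base class here is a local stand-in written for illustration, not the actual espnet2.iterators.abs_iter_factory.AbsIterFactory import (the real class exposes an abstract build_iter method keyed on the epoch):

```python
from abc import ABC, abstractmethod

# Stand-in for espnet2's AbsIterFactory interface (illustrative only).
class AbsIterFactory(ABC):
    @abstractmethod
    def build_iter(self, epoch, shuffle=None):
        """Return a fresh data iterator for the given epoch."""
        raise NotImplementedError

class ListIterFactory(AbsIterFactory):
    """Toy factory that rebuilds an iterator over a fixed list each epoch."""

    def __init__(self, data):
        self.data = data

    def build_iter(self, epoch, shuffle=None):
        # A real factory would reseed its sampler from `epoch` so every
        # epoch sees a different (but reproducible) data order.
        return iter(self.data)

factory = ListIterFactory([1, 2, 3])
print(list(factory.build_iter(epoch=0)))  # [1, 2, 3]
```

Rebuilding the iterator per epoch (rather than holding one iterator object) is what makes epoch-dependent shuffling and resumption straightforward.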
This PR adds two files that implement the DeepSpeed trainer logic:
(1) train.py: the overall training launcher script, which initializes distributed training, logging, the model, the data loader, etc.
(2) deepspeed_trainer.py: the DeepSpeed trainer wrapper around the given model and data loader objects.
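As a rough orientation for reviewers, the distributed initialization and stat sync that the trainer relies on can be sketched in a single process. This is an illustrative stand-alone example (world_size=1, gloo backend, made-up env-var setup and tensor values), not the PR's actual launcher code:

```python
import os
import torch
import torch.distributed as dist

# Single-process bootstrap; a real launcher gets RANK/WORLD_SIZE from the
# environment set up by the distributed launcher (e.g. deepspeed/torchrun).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Illustrative per-rank stats: [metric sum, batch count].
stats = torch.tensor([6.0, 4.0])
dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # no-op with one rank
mean_loss = (stats[0] / stats[1]).item()
print(mean_loss)  # 1.5

dist.destroy_process_group()
```

With more ranks, the same all_reduce sums the per-rank sums and counts before the division, which is exactly the aggregation pattern discussed in the review above.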
Prior PRs: #6257 , #6258 , #6260
Request review: @Masao-Someki @wanchichen @siddhu001