
Add DeepSpeed trainer for large-scale training #5856

Merged
sw005320 merged 25 commits into espnet:master from jctian98:deepspeed on Aug 26, 2024
Conversation

@jctian98 (Collaborator) commented Aug 2, 2024

What?

This PR adds another trainer object that wraps DeepSpeed so that it automatically handles many trainer-related concerns, especially some advanced features:

  • Tensor partitioning, with the ZeRO-{1,2,3} strategies
  • Activation checkpointing
  • Parameter offloading
  • Distributed checkpoint saving / loading
  • Full BF16 training
  • On-CPU optimizer
  • etc.

(1) This DeepSpeed trainer is based on data parallelism and will allow us to train models as large as ~13B parameters.
(2) It can be switched to smoothly from the previous ESPnet trainer.
(3) Unlike model parallelism (which will be needed above ~13B), this trainer doesn't impose any requirements on the model architecture.

To use it, simply add these lines to the training config:

```yaml
use_deepspeed: true
deepspeed_config: <path-to-config>.json
```

Most trainer-related options will be moved into deepspeed_config, so the training config will only need to define things like the model architecture and data loader.
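As an illustration (not necessarily the exact config from this PR), a minimal deepspeed_config JSON for ZeRO-2 training with BF16 might look like the following; the values here are placeholders:

```json
{
    "train_batch_size": 32,
    "gradient_clipping": 1.0,
    "bf16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true
    }
}
```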

Discussion:
What is a good place to add a README.md file? Or is one needed at all?

Will request @wanchichen to take a look.

Thanks

@sw005320 sw005320 added this to the v.202405 milestone Aug 2, 2024
@sw005320 (Contributor) commented Aug 2, 2024

@wanchichen, can you review this PR?

@codecov (bot) commented Aug 5, 2024

Codecov Report

Attention: Patch coverage is 5.02793% with 170 lines in your changes missing coverage. Please review.

Project coverage is 43.06%. Comparing base (5acc3c8) to head (7e5f289).
Report is 64 commits behind head on master.

Files Patch % Lines
espnet2/train/deepspeed_trainer.py 0.00% 145 Missing ⚠️
espnet2/train/distributed_utils.py 14.28% 12 Missing ⚠️
espnet2/tasks/abs_task.py 30.00% 7 Missing ⚠️
espnet2/torch_utils/device_funcs.py 16.66% 5 Missing ⚠️
espnet2/asr/ctc.py 50.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #5856       +/-   ##
===========================================
+ Coverage        0   43.06%   +43.06%     
===========================================
  Files           0      819      +819     
  Lines           0    75193    +75193     
===========================================
+ Hits            0    32384    +32384     
- Misses          0    42809    +42809     
Flag Coverage Δ
test_integration_espnet1 62.62% <ø> (?)
test_integration_espnet2 47.88% <26.47%> (?)
test_integration_espnetez 27.87% <26.47%> (?)
test_python_espnetez 13.43% <2.23%> (?)
test_utils 20.61% <ø> (?)


@wanchichen (Contributor) commented:

Great work! The implementation is very clean, so I do not have many comments. Only two questions:

  • Does full BF16 training work with models that use conv. layers?
  • Is ZeRO-3 tested? We might also be able to remove the logging all-reduce (`stats, weight = recursive_average(stats, weight, True)`), since I don't think all devices will have unique metrics in the ZeRO-3 case, which makes the operation very wasteful.
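For context, here is a simplified sketch (hypothetical, not ESPnet's actual implementation) of what a weighted stats all-reduce computes; note that if every rank already holds identical stats, the reduce just reproduces them at extra communication cost, which is the waste being pointed out:

```python
def weighted_average(per_rank_stats, per_rank_weights):
    """Simplified stand-in for a distributed weighted average of logging
    stats: after the reduce, every rank holds the same averaged values."""
    total_weight = sum(per_rank_weights)
    avg = {}
    for key in per_rank_stats[0]:
        avg[key] = sum(
            stats[key] * w
            for stats, w in zip(per_rank_stats, per_rank_weights)
        ) / total_weight
    return avg

# Two ranks with different local losses and batch sizes (weights):
print(weighted_average([{"loss": 2.0}, {"loss": 4.0}], [1, 3]))  # → {'loss': 3.5}
```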


```python
with reporter.measure_time("step_time"):
    # (0) ensure all ranks have not finished.
    dist.all_reduce(iterator_stop, ReduceOp.SUM)
```
Contributor:
This all-reduce is not necessary (same with the one in valid_one_epoch). We can remove lines 185-187

@xingchensong commented Aug 7, 2024:
I think that forced synchronization is necessary, as it functions like a join operation, allowing all ranks to start forwarding simultaneously and avoiding timeouts. Torch_DDP does not require this synchronization operation, partly because DDP has its own join function, and partly because, for distributed training like DeepSpeed, the communication between machines is more complex and timeouts are more likely to occur. Therefore, sacrificing some waiting time (in fact, I believe the proportion of waiting time will not be very high) in exchange for training stability is necessary. cc @jctian98
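The join-like behavior described above can be sketched as follows (a plain-Python simulation for illustration, not ESPnet or DeepSpeed code): because the per-rank stop flag is sum-reduced, every rank observes the same total and takes the same branch on the same step, so no rank blocks forever waiting on a collective that its peers never enter.

```python
def all_reduce_sum(values):
    """Stand-in for dist.all_reduce(..., ReduceOp.SUM): every rank
    ends up holding the sum of all ranks' values."""
    total = sum(values)
    return [total] * len(values)

def should_stop(local_stop_flags):
    # Each rank contributes 1 if its iterator is exhausted, else 0.
    # After the sum-reduce, all ranks see the same value, so they
    # all agree on whether to leave the training loop together.
    reduced = all_reduce_sum(local_stop_flags)
    return [flag > 0 for flag in reduced]

# Rank 1 ran out of data; after the reduce, all ranks agree to stop.
print(should_stop([0, 1, 0]))  # → [True, True, True]
```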

Collaborator (Author):
Thanks for both comments :)
Since ZeRO will always involve extensive communication, this forced sync is probably fine to keep. I'll post more observations if I find these lines become a performance bottleneck.

@jctian98 (Collaborator, Author) commented Aug 6, 2024

@wanchichen Thanks for the review!

  • With BF16 and the ZeRO-2 config I posted, it can train a conformer model without errors / warnings.
  • I have tried ZeRO-3 and it works, but it is slower than ZeRO-2, possibly due to heavier CPU-GPU communication. We can explore the ZeRO-3 config further, but we probably don't need to change the code.
  • Could you explain a bit why "I don't think all devices will have unique metrics in the zero3 case"? Is there anything different from standard DDP (in terms of computed results)?

Contributor:

Do we need this? You removed it in the other part of the latest commit.

Contributor:

Ditto: do we need this? You removed it in the other part of the latest commit.

@sw005320 (Contributor) commented Aug 6, 2024

  • Do you have some example usages (e.g., a config)?
  • It would be better to mention the DeepSpeed support in the top-level README.md.

@mergify mergify bot added the README label Aug 6, 2024
@sw005320 (Contributor) commented Aug 6, 2024

Sounds good.
Is it everything, @jctian98?
If it is ready, I'll merge this PR.

@jctian98 (Collaborator, Author) commented Aug 6, 2024

> Sounds good. Is it everything, @jctian98? If it is ready, I'll merge this PR.

We may still need to discuss with @wanchichen whether we should remove the all_reduces in the training loop to avoid communication overhead. But we can merge for now and revise later if necessary.

Comment on lines +8 to +23

```json
"zero_optimization": {
    "stage": 2,
    "contiguous_gradients": true,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "allgather_bucket_size": 5e8
},
"zero_optimization": {
    "stage": 2,
    "contiguous_gradients": true,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "allgather_bucket_size": 5e8
},
```


redundant configuration?

@@ -0,0 +1,47 @@
{
"train_batch_size": 32,


train_batch_size (32 in this case) might conflict with train.yaml::batch_size (64 in this case).
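For reference, DeepSpeed enforces a consistency relation among its batch-size settings (with world size meaning the number of data-parallel ranks), which is where such a conflict would surface; a quick sketch of the arithmetic:

```python
def effective_train_batch_size(micro_batch_per_gpu, grad_accum_steps, world_size):
    # DeepSpeed expects: train_batch_size ==
    #   train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
    return micro_batch_per_gpu * grad_accum_steps * world_size

# e.g. 4 samples per GPU, 2 accumulation steps, 4 GPUs -> 32
print(effective_train_batch_size(4, 2, 4))  # → 32
```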

@jctian98 (Collaborator, Author) commented Aug 9, 2024

@xingchensong Many thanks for the review! I have updated the example config.

The current example is just a toy. Later, when we run some experiments at large scale, we can share our DeepSpeed configs together with the recipes.

@sw005320 if CI is ok, I think this PR is ready to merge.

@sw005320 sw005320 merged commit b54ea65 into espnet:master Aug 26, 2024
@sw005320 (Contributor):
Thanks a lot, @jctian98!

@jctian98 jctian98 deleted the deepspeed branch November 3, 2025 00:56