Commit c36e668

wz337 and Svetlana Karslioglu authored
[DCP] Add Distributed Checkpoint tutorial (#2565)
* add DCP tutorial --------------- Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 8e37b67 commit c36e668

5 files changed

Lines changed: 290 additions & 2 deletions

File tree: two image files (22.6 KB and 34.9 KB) were also added.

index.rst

Lines changed: 2 additions & 2 deletions (trailing-whitespace cleanup only)

@@ -265,7 +265,7 @@ What's new in PyTorch tutorials?
    :tags: Text

 .. customcarditem::
-   :header: Pre-process custom text dataset using Torchtext
+   :header: Pre-process custom text dataset using Torchtext
    :card_description: Learn how to use torchtext to prepare a custom dataset
    :image: _static/img/thumbnails/cropped/torch_text_logo.png
    :link: beginner/torchtext_custom_dataset_tutorial.html
@@ -592,7 +592,7 @@ What's new in PyTorch tutorials?
    :image: _static/img/thumbnails/cropped/pytorch-logo.png
    :link: intermediate/scaled_dot_product_attention_tutorial.html
    :tags: Model-Optimization,Attention,Transformer
-
+
 .. customcarditem::
    :header: Knowledge Distillation in Convolutional Neural Networks
    :card_description: Learn how to improve the accuracy of lightweight models using more powerful models as teachers.
recipes_source/distributed_checkpoint_recipe.rst

Lines changed: 281 additions & 0 deletions

Getting Started with Distributed Checkpoint (DCP)
=====================================================

**Author**: `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__.


Prerequisites:

- `FullyShardedDataParallel API documents <https://pytorch.org/docs/master/fsdp.html>`__
- `torch.load API documents <https://pytorch.org/docs/stable/generated/torch.load.html>`__


Checkpointing AI models during distributed training can be challenging because parameters and gradients are partitioned across trainers, and the number of trainers available may change when you resume training.
PyTorch Distributed Checkpoint (DCP) helps make this process easier.

In this tutorial, we show how to use the DCP APIs with a simple FSDP-wrapped model.


How DCP works
--------------

:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel.
In addition, checkpointing automatically handles fully-qualified-name (FQN) mappings across models and optimizers, enabling load-time resharding across differing cluster topologies.
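The load-time resharding mentioned above can be illustrated with a small plain-Python sketch (no ``torch`` involved; the even-split ``shard`` helper is hypothetical and far simpler than DCP's real sharding logic). Because each parameter is addressed by its FQN rather than by a rank-local offset, shards written at one world size can be reassembled and re-chunked at another:

```python
# Hypothetical sketch of load-time resharding: parameters are keyed by
# fully qualified name (FQN), so shard boundaries are not baked into the
# checkpoint's addressing scheme.

def shard(flat_param, world_size):
    # even split across ranks (assumes divisibility, unlike real DCP)
    n = len(flat_param) // world_size
    return {rank: flat_param[rank * n:(rank + 1) * n] for rank in range(world_size)}

# "net1.weight" saved by 4 ranks ...
saved = {"net1.weight": shard(list(range(16)), 4)}

# ... reassembled and resharded for a world size of 2 at load time
flat = [v for rank in sorted(saved["net1.weight"]) for v in saved["net1.weight"][rank]]
resharded = {"net1.weight": shard(flat, 2)}

print(resharded["net1.weight"][0])  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Real DCP also stores sharding metadata alongside the data so that each rank can read just the slice it needs, but FQN-keyed addressing is what makes the re-chunking possible.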

DCP is different from :func:`torch.save` and :func:`torch.load` in a few significant ways:

* It produces multiple files per checkpoint, with at least one per rank.
* It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
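The first point can be sketched in plain Python as well (the ``shard_rank{N}.json`` naming is invented for illustration and does not match DCP's actual file layout):

```python
# Plain-Python sketch: each rank writes only its own shard, so a
# checkpoint is a directory of per-rank files rather than a single file.
import json
import os
import tempfile

def save_shard(checkpoint_dir, rank, shard):
    # hypothetical naming; real DCP uses its own layout plus a metadata file
    path = os.path.join(checkpoint_dir, f"shard_rank{rank}.json")
    with open(path, "w") as f:
        json.dump(shard, f)

ckpt_dir = tempfile.mkdtemp()
for rank, shard in enumerate([[0.1, 0.2], [0.3, 0.4]]):
    save_shard(ckpt_dir, rank, shard)

print(sorted(os.listdir(ckpt_dir)))  # ['shard_rank0.json', 'shard_rank1.json']
```

Because every rank writes its own file, saving scales with the number of ranks instead of funneling all data through rank 0 as a single :func:`torch.save` call would.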

.. note::
   The code in this tutorial runs on an 8-GPU server, but it can be easily
   generalized to other environments.

How to use DCP
--------------

Here we use a toy model wrapped with FSDP for demonstration purposes. The same APIs and logic can be applied to larger models for checkpointing.

Saving
~~~~~~

Now, let's create a toy module, wrap it with FSDP, feed it some dummy input data, and save it.

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as DCP
    import torch.multiprocessing as mp
    import torch.nn as nn

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

    CHECKPOINT_DIR = "checkpoint"


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(16, 16)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def setup(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)


    def cleanup():
        dist.destroy_process_group()


    def run_fsdp_checkpoint_save_example(rank, world_size):
        print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
        setup(rank, world_size)

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
        model = FSDP(model)

        loss_fn = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

        # take one optimization step so there is real state to checkpoint
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        # set FSDP StateDictType to SHARDED_STATE_DICT so we can use DCP to
        # checkpoint a sharded model state dict
        # note that we do not support FSDP StateDictType.LOCAL_STATE_DICT
        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        state_dict = {
            "model": model.state_dict(),
        }

        DCP.save_state_dict(
            state_dict=state_dict,
            storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
        )

        cleanup()


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        print(f"Running FSDP checkpoint saving example on {world_size} devices.")
        mp.spawn(
            run_fsdp_checkpoint_save_example,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )

Now go ahead and check the ``checkpoint`` directory. You should see 8 checkpoint files, as shown below.

.. figure:: /_static/img/distributed/distributed_checkpoint_generated_files.png
   :width: 100%
   :align: center
   :alt: Distributed Checkpoint

Loading
~~~~~~~

After saving, let's create the same FSDP-wrapped model and load the saved state dict from storage into the model. You can load it back with the same world size or a different world size.

Please note that you will have to call :func:`model.state_dict` prior to loading and pass it to DCP's :func:`load_state_dict` API.
This is fundamentally different from :func:`torch.load`, which only requires the path to the checkpoint.
The reasons we need the ``state_dict`` prior to loading are:

* DCP uses the pre-allocated storage from the model state_dict to load from the checkpoint directory. During loading, the state_dict passed in will be updated in place.
* DCP requires the sharding information from the model prior to loading to support resharding.
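The in-place contract in the first bullet can be illustrated with a plain-Python sketch (``dcp_style_load`` is a made-up stand-in for DCP's :func:`load_state_dict`, with lists standing in for pre-allocated tensors):

```python
# Sketch: the caller allocates the state_dict first; loading then fills it
# in place instead of returning a new object the way torch.load does.
def dcp_style_load(state_dict, storage):
    for fqn in state_dict:
        # real DCP copies checkpoint shards into pre-allocated tensor storage
        state_dict[fqn] = storage[fqn]

storage = {"net1.weight": [1.0, 2.0]}      # contents of the checkpoint
state_dict = {"net1.weight": [0.0, 0.0]}   # pre-allocated by the model
dcp_style_load(state_dict, storage)
print(state_dict["net1.weight"])  # [1.0, 2.0]
```

This is why ``model.state_dict()`` must be produced (and, for FSDP, configured with ``SHARDED_STATE_DICT``) before loading, as the full example below shows.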
.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as DCP
    import torch.multiprocessing as mp
    import torch.nn as nn

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

    CHECKPOINT_DIR = "checkpoint"


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(16, 16)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def setup(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

        # initialize the process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)


    def cleanup():
        dist.destroy_process_group()


    def run_fsdp_checkpoint_load_example(rank, world_size):
        print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
        setup(rank, world_size)

        # create a model and move it to GPU with id rank
        model = ToyModel().to(rank)
        model = FSDP(model)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
        # different from ``torch.load()``, DCP requires the model state_dict
        # prior to loading to get the allocated storage and sharding information
        state_dict = {
            "model": model.state_dict(),
        }

        DCP.load_state_dict(
            state_dict=state_dict,
            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
        )
        model.load_state_dict(state_dict["model"])

        cleanup()


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        print(f"Running FSDP checkpoint loading example on {world_size} devices.")
        mp.spawn(
            run_fsdp_checkpoint_load_example,
            args=(world_size,),
            nprocs=world_size,
            join=True,
        )

If you would like to load the saved checkpoint into a non-FSDP-wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP.
By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data (SPMD) style. To load without a distributed setup, set ``no_dist`` to ``True`` when loading with DCP.

.. note::
   Distributed checkpoint support for Multi-Program Multi-Data is still under development.

.. code-block:: python

    import os

    import torch
    import torch.distributed.checkpoint as DCP
    import torch.nn as nn


    CHECKPOINT_DIR = "checkpoint"


    class ToyModel(nn.Module):
        def __init__(self):
            super(ToyModel, self).__init__()
            self.net1 = nn.Linear(16, 16)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    def run_checkpoint_load_example():
        # create the non-FSDP-wrapped toy model
        model = ToyModel()
        state_dict = {
            "model": model.state_dict(),
        }

        # set no_dist to True to load in a non-distributed setting
        DCP.load_state_dict(
            state_dict=state_dict,
            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
            no_dist=True,
        )
        model.load_state_dict(state_dict["model"])


    if __name__ == "__main__":
        print("Running basic DCP checkpoint loading example.")
        run_checkpoint_load_example()


Conclusion
----------
In conclusion, we have learned how to use DCP's :func:`save_state_dict` and :func:`load_state_dict` APIs, as well as how they differ from :func:`torch.save` and :func:`torch.load`.

For more information, please see the following:

- `Saving and loading models tutorial <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`__
- `Getting started with FullyShardedDataParallel tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__

recipes_source/recipes_index.rst

Lines changed: 7 additions & 0 deletions

@@ -310,6 +310,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/distributed_optim_torchscript.html
    :tags: Distributed-Training,TorchScript

+.. customcarditem::
+   :header: Getting Started with Distributed Checkpoint (DCP)
+   :card_description: Learn how to checkpoint distributed models with Distributed Checkpoint package.
+   :image: ../_static/img/thumbnails/cropped/Getting-Started-with-DCP.png
+   :link: ../recipes/DCP_tutorial.html
+   :tags: Distributed-Training
+
 .. End of tutorial card section

 .. raw:: html
