|
| 1 | +Getting Started with Distributed Checkpoint (DCP) |
| 2 | +===================================================== |
| 3 | + |
| 4 | +**Author**: `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__ |
| 5 | + |
| 6 | +.. note:: |
| 7 | + |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__. |
| 8 | + |
| 9 | + |
| 10 | +Prerequisites: |
| 11 | + |
| 12 | +- `FullyShardedDataParallel API documents <https://pytorch.org/docs/master/fsdp.html>`__ |
| 13 | +- `torch.load API documents <https://pytorch.org/docs/stable/generated/torch.load.html>`__ |
| 14 | + |
| 15 | + |
| 16 | +Checkpointing AI models during distributed training could be challenging, as parameters and gradients are partitioned across trainers and the number of trainers available could change when you resume training. |
| 17 | +Pytorch Distributed Checkpointing (DCP) can help make this process easier. |
| 18 | + |
| 19 | +In this tutorial, we show how to use DCP APIs with a simple FSDP wrapped model. |
| 20 | + |
| 21 | + |
| 22 | +How DCP works |
| 23 | +-------------- |
| 24 | + |
| 25 | +:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel. |
| 26 | +In addition, checkpointing automatically handles fully-qualified-name (FQN) mappings across models and optimizers, enabling load-time resharding across differing cluster topologies. |
| 27 | + |
| 28 | +DCP is different from :func:`torch.save` and :func:`torch.load` in a few significant ways: |
| 29 | + |
| 30 | +* It produces multiple files per checkpoint, with at least one per rank. |
| 31 | +* It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead. |
| 32 | + |
| 33 | +.. note:: |
| 34 | + The code in this tutorial runs on an 8-GPU server, but it can be easily |
| 35 | + generalized to other environments. |
| 36 | + |
| 37 | +How to use DCP |
| 38 | +-------------- |
| 39 | + |
| 40 | +Here we use a toy model wrapped with FSDP for demonstration purposes. Similarly, the APIs and logic can be applied to larger models for checkpointing. |
| 41 | + |
| 42 | +Saving |
| 43 | +~~~~~~ |
| 44 | + |
| 45 | +Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it. |
| 46 | + |
| 47 | +.. code-block:: python |
| 48 | +
|
| 49 | + import os |
| 50 | +
|
| 51 | + import torch |
| 52 | + import torch.distributed as dist |
| 53 | + import torch.distributed.checkpoint as DCP |
| 54 | + import torch.multiprocessing as mp |
| 55 | + import torch.nn as nn |
| 56 | +
|
| 57 | + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| 58 | + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType |
| 59 | +
|
| 60 | + CHECKPOINT_DIR = "checkpoint" |
| 61 | +
|
| 62 | +
|
| 63 | + class ToyModel(nn.Module): |
| 64 | + def __init__(self): |
| 65 | + super(ToyModel, self).__init__() |
| 66 | + self.net1 = nn.Linear(16, 16) |
| 67 | + self.relu = nn.ReLU() |
| 68 | + self.net2 = nn.Linear(16, 8) |
| 69 | +
|
| 70 | + def forward(self, x): |
| 71 | + return self.net2(self.relu(self.net1(x))) |
| 72 | +
|
| 73 | +
|
| 74 | + def setup(rank, world_size): |
| 75 | + os.environ["MASTER_ADDR"] = "localhost" |
| 76 | + os.environ["MASTER_PORT"] = "12355 " |
| 77 | +
|
| 78 | + # initialize the process group |
| 79 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 80 | + torch.cuda.set_device(rank) |
| 81 | +
|
| 82 | +
|
| 83 | + def cleanup(): |
| 84 | + dist.destroy_process_group() |
| 85 | +
|
| 86 | +
|
| 87 | + def run_fsdp_checkpoint_save_example(rank, world_size): |
| 88 | + print(f"Running basic FSDP checkpoint saving example on rank {rank}.") |
| 89 | + setup(rank, world_size) |
| 90 | +
|
| 91 | + # create a model and move it to GPU with id rank |
| 92 | + model = ToyModel().to(rank) |
| 93 | + model = FSDP(model) |
| 94 | +
|
| 95 | + loss_fn = nn.MSELoss() |
| 96 | + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) |
| 97 | +
|
| 98 | + optimizer.zero_grad() |
| 99 | + model(torch.rand(8, 16, device="cuda")).sum().backward() |
| 100 | + optimizer.step() |
| 101 | +
|
| 102 | + # set FSDP StateDictType to SHARDED_STATE_DICT so we can use DCP to checkpoint sharded model state dict |
| 103 | + # note that we do not support FSDP StateDictType.LOCAL_STATE_DICT |
| 104 | + FSDP.set_state_dict_type( |
| 105 | + model, |
| 106 | + StateDictType.SHARDED_STATE_DICT, |
| 107 | + ) |
| 108 | + state_dict = { |
| 109 | + "model": model.state_dict(), |
| 110 | + } |
| 111 | +
|
| 112 | + DCP.save_state_dict( |
| 113 | + state_dict=state_dict, |
| 114 | + storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR), |
| 115 | + ) |
| 116 | +
|
| 117 | + cleanup() |
| 118 | +
|
| 119 | +
|
| 120 | + if __name__ == "__main__": |
| 121 | + world_size = torch.cuda.device_count() |
| 122 | + print(f"Running fsdp checkpoint example on {world_size} devices.") |
| 123 | + mp.spawn( |
| 124 | + run_fsdp_checkpoint_save_example, |
| 125 | + args=(world_size,), |
| 126 | + nprocs=world_size, |
| 127 | + join=True, |
| 128 | + ) |
| 129 | +
|
| 130 | +Please go ahead and check the `checkpoint` directory. You should see 8 checkpoint files as shown below. |
| 131 | + |
| 132 | +.. figure:: /_static/img/distributed/distributed_checkpoint_generated_files.png |
| 133 | + :width: 100% |
| 134 | + :align: center |
| 135 | + :alt: Distributed Checkpoint |
| 136 | + |
| 137 | +Loading |
| 138 | +~~~~~~~ |
| 139 | + |
| 140 | +After saving, let’s create the same FSDP-wrapped model, and load the saved state dict from storage into the model. You can load in the same world size or different world size. |
| 141 | + |
| 142 | +Please note that you will have to call :func:`model.state_dict` prior to loading and pass it to DCP's :func:`load_state_dict` API. |
| 143 | +This is fundamentally different from :func:`torch.load`, as :func:`torch.load` simply requires the path to the checkpoint prior for loading. |
| 144 | +The reason that we need the ``state_dict`` prior to loading is: |
| 145 | + |
| 146 | +* DCP uses the pre-allocated storage from model state_dict to load from the checkpoint directory. During loading, the state_dict passed in will be updated in place. |
| 147 | +* DCP requires the sharding information from the model prior to loading to support resharding. |
| 148 | + |
| 149 | +.. code-block:: python |
| 150 | +
|
| 151 | + import os |
| 152 | +
|
| 153 | + import torch |
| 154 | + import torch.distributed as dist |
| 155 | + import torch.distributed.checkpoint as DCP |
| 156 | + import torch.multiprocessing as mp |
| 157 | + import torch.nn as nn |
| 158 | +
|
| 159 | + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| 160 | + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType |
| 161 | +
|
| 162 | + CHECKPOINT_DIR = "checkpoint" |
| 163 | +
|
| 164 | +
|
| 165 | + class ToyModel(nn.Module): |
| 166 | + def __init__(self): |
| 167 | + super(ToyModel, self).__init__() |
| 168 | + self.net1 = nn.Linear(16, 16) |
| 169 | + self.relu = nn.ReLU() |
| 170 | + self.net2 = nn.Linear(16, 8) |
| 171 | +
|
| 172 | + def forward(self, x): |
| 173 | + return self.net2(self.relu(self.net1(x))) |
| 174 | +
|
| 175 | +
|
| 176 | + def setup(rank, world_size): |
| 177 | + os.environ["MASTER_ADDR"] = "localhost" |
| 178 | + os.environ["MASTER_PORT"] = "12355 " |
| 179 | +
|
| 180 | + # initialize the process group |
| 181 | + dist.init_process_group("nccl", rank=rank, world_size=world_size) |
| 182 | + torch.cuda.set_device(rank) |
| 183 | +
|
| 184 | +
|
| 185 | + def cleanup(): |
| 186 | + dist.destroy_process_group() |
| 187 | +
|
| 188 | +
|
| 189 | + def run_fsdp_checkpoint_load_example(rank, world_size): |
| 190 | + print(f"Running basic FSDP checkpoint loading example on rank {rank}.") |
| 191 | + setup(rank, world_size) |
| 192 | +
|
| 193 | + # create a model and move it to GPU with id rank |
| 194 | + model = ToyModel().to(rank) |
| 195 | + model = FSDP(model) |
| 196 | +
|
| 197 | + FSDP.set_state_dict_type( |
| 198 | + model, |
| 199 | + StateDictType.SHARDED_STATE_DICT, |
| 200 | + ) |
| 201 | + # different from ``torch.load()``, DCP requires model state_dict prior to loading to get |
| 202 | + # the allocated storage and sharding information. |
| 203 | + state_dict = { |
| 204 | + "model": model.state_dict(), |
| 205 | + } |
| 206 | +
|
| 207 | + DCP.load_state_dict( |
| 208 | + state_dict=state_dict, |
| 209 | + storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR), |
| 210 | + ) |
| 211 | + model.load_state_dict(state_dict["model"]) |
| 212 | +
|
| 213 | + cleanup() |
| 214 | +
|
| 215 | +
|
| 216 | + if __name__ == "__main__": |
| 217 | + world_size = torch.cuda.device_count() |
| 218 | + print(f"Running fsdp checkpoint example on {world_size} devices.") |
| 219 | + mp.spawn( |
| 220 | + run_fsdp_checkpoint_load_example, |
| 221 | + args=(world_size,), |
| 222 | + nprocs=world_size, |
| 223 | + join=True, |
| 224 | + ) |
| 225 | +
|
| 226 | +If you would like to load the saved checkpoint into a non-FSDP wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP. |
| 227 | +By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data(SPMD) style. To load without a distributed setup, please set ``no_dist`` to ``True`` when loading with DCP. |
| 228 | + |
| 229 | +.. note:: |
| 230 | + Distributed checkpoint support for Multi-Program Multi-Data is still under development. |
| 231 | + |
| 232 | +.. code-block:: python |
| 233 | + import os |
| 234 | +
|
| 235 | + import torch |
| 236 | + import torch.distributed.checkpoint as DCP |
| 237 | + import torch.nn as nn |
| 238 | +
|
| 239 | +
|
| 240 | + CHECKPOINT_DIR = "checkpoint" |
| 241 | +
|
| 242 | +
|
| 243 | + class ToyModel(nn.Module): |
| 244 | + def __init__(self): |
| 245 | + super(ToyModel, self).__init__() |
| 246 | + self.net1 = nn.Linear(16, 16) |
| 247 | + self.relu = nn.ReLU() |
| 248 | + self.net2 = nn.Linear(16, 8) |
| 249 | +
|
| 250 | + def forward(self, x): |
| 251 | + return self.net2(self.relu(self.net1(x))) |
| 252 | +
|
| 253 | +
|
| 254 | + def run_checkpoint_load_example(): |
| 255 | + # create the non FSDP-wrapped toy model |
| 256 | + model = ToyModel() |
| 257 | + state_dict = { |
| 258 | + "model": model.state_dict(), |
| 259 | + } |
| 260 | +
|
| 261 | + # turn no_dist to be true to load in non-distributed setting |
| 262 | + DCP.load_state_dict( |
| 263 | + state_dict=state_dict, |
| 264 | + storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR), |
| 265 | + no_dist=True, |
| 266 | + ) |
| 267 | + model.load_state_dict(state_dict["model"]) |
| 268 | +
|
| 269 | + if __name__ == "__main__": |
| 270 | + print(f"Running basic DCP checkpoint loading example.") |
| 271 | + run_checkpoint_load_example() |
| 272 | +
|
| 273 | +
|
| 274 | +Conclusion |
| 275 | +---------- |
| 276 | +In conclusion, we have learned how to use DCP's :func:`save_state_dict` and :func:`load_state_dict` APIs, as well as how they are different form :func:`torch.save` and :func:`torch.load`. |
| 277 | + |
| 278 | +For more information, please see the following: |
| 279 | + |
| 280 | +- `Saving and loading models tutorial <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`__ |
| 281 | +- `Getting started with FullyShardedDataParallel tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ |
0 commit comments