-import os
-from functools import wraps
-
 import random
+
 import torch
 import torch.distributed as dist
 import torch.distributed.autograd as dist_autograd
 import torch.distributed.rpc as rpc
-from torch.distributed.rpc import TensorPipeRpcBackendOptions
 import torch.multiprocessing as mp
 import torch.optim as optim
+from torch.distributed.nn import RemoteModule
 from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef
+from torch.distributed.rpc import TensorPipeRpcBackendOptions
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 NUM_EMBEDDINGS = 100
 EMBEDDING_DIM = 16
 
 # BEGIN hybrid_model
 class HybridModel(torch.nn.Module):
     r"""
-    The model consists of a sparse part and a dense part. The dense part is an
-    nn.Linear module that is replicated across all trainers using
-    DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
-    stored on the parameter server.
-
-    The model holds a Remote Reference to the embedding table on the parameter
-    server.
+    The model consists of a sparse part and a dense part.
+    1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
+    2) The sparse part is a RemoteModule that holds an nn.EmbeddingBag on the parameter server.
+    This remote module holds a Remote Reference to the embedding table on the parameter server.
     """
 
-    def __init__(self, emb_rref, device):
+    def __init__(self, remote_emb_module, device):
         super(HybridModel, self).__init__()
-        self.emb_rref = emb_rref
+        self.remote_emb_module = remote_emb_module
         self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
         self.device = device
 
     def forward(self, indices, offsets):
-        emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
+        emb_lookup = self.remote_emb_module.forward(indices, offsets)
         return self.fc(emb_lookup.cuda(self.device))
 # END hybrid_model
 
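For context on the change above: a RemoteModule's generated forward issues an RPC to the worker that owns the module (the "ps" worker here) and returns the result, so the explicit emb_rref.rpc_sync().forward(...) call is no longer needed. Below is a minimal smoke-test sketch of the model, assuming RPC, the trainer's process group, and remote_emb_module have already been set up as in run_worker further down; the device and tensor shapes are illustrative only.

# Hypothetical usage sketch, not part of the change above.
model = HybridModel(remote_emb_module, device=0)
indices = torch.randint(0, NUM_EMBEDDINGS, (32,))  # 32 lookups in total
offsets = torch.arange(0, 32, 4)                   # 8 bags of 4 indices each
out = model(indices, offsets)                      # EmbeddingBag on "ps", then nn.Linear on cuda:0
print(out.shape)                                   # torch.Size([8, 8])
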
 # BEGIN setup_trainer
-def _retrieve_embedding_parameters(emb_rref):
-    param_rrefs = []
-    for param in emb_rref.local_value().parameters():
-        param_rrefs.append(RRef(param))
-    return param_rrefs
-
-
-def _run_trainer(emb_rref, rank):
+def _run_trainer(remote_emb_module, rank):
     r"""
     Each trainer runs a forward pass which involves an embedding lookup on the
     parameter server and running nn.Linear locally. During the backward pass,
@@ -57,16 +46,18 @@ def _run_trainer(emb_rref, rank):
     """
 
     # Setup the model.
-    model = HybridModel(emb_rref, rank)
+    model = HybridModel(remote_emb_module, rank)
 
     # Retrieve all model parameters as rrefs for DistributedOptimizer.
 
     # Retrieve parameters for embedding table.
-    model_parameter_rrefs = rpc.rpc_sync(
-        "ps", _retrieve_embedding_parameters, args=(emb_rref,))
+    model_parameter_rrefs = model.remote_emb_module.remote_parameters()
 
-    # model.parameters() only includes local parameters.
-    for param in model.parameters():
+    # model.fc.parameters() only includes local parameters.
+    # NOTE: Cannot call model.parameters() here,
+    # because it would recurse into remote_emb_module, and RemoteModule
+    # supports remote_parameters() but not parameters().
+    for param in model.fc.parameters():
         model_parameter_rrefs.append(RRef(param))
 
     # Setup distributed optimizer
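The RRef list built above (remote embedding parameters plus the local fc parameters) is what the distributed optimizer consumes in the lines the diff skips over here. A rough sketch of that construction, with the learning rate as an assumed placeholder:

    # Sketch only: DistributedOptimizer creates an optim.SGD instance on every
    # worker that owns one of the passed parameter RRefs (the ps for the
    # embedding table, this trainer for the fc layer) and steps them over RPC.
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs,
        lr=0.05,  # assumed value; the real lr lives in the elided code
    )
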
@@ -115,43 +106,43 @@ def get_next_batch(rank):
                 # Not necessary to zero grads as each iteration creates a different
                 # distributed autograd context which hosts different grads
         print("Training done for epoch {}".format(epoch))
-# END run_trainer
-
+# END run_trainer
 
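The per-iteration distributed autograd context mentioned in the comment above is why no explicit zero_grad() is needed: gradients live in the context, not in param.grad. The elided body of the training loop follows the standard pattern, roughly as below; names such as criterion, target, indices and offsets come from the elided code and are assumptions here.

# Rough shape of one training step (sketch, not the verbatim elided code):
with dist_autograd.context() as context_id:
    output = model(indices, offsets)
    loss = criterion(output, target)
    # Backward pass across the RPC boundary; grads accumulate in this context.
    dist_autograd.backward(context_id, [loss])
    # Apply the grads on whichever workers own the parameters.
    opt.step(context_id)
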
 # BEGIN run_worker
 def run_worker(rank, world_size):
     r"""
     A wrapper function that initializes RPC, calls the function, and shuts down
     RPC.
     """
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '29500'
-
 
+    # We need to use different port numbers in TCP init_method for init_rpc and
+    # init_process_group to avoid port conflicts.
     rpc_backend_options = TensorPipeRpcBackendOptions()
-    rpc_backend_options.init_method = 'tcp://localhost:29501'
+    rpc_backend_options.init_method = "tcp://localhost:29501"
 
     # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
     if rank == 2:
         rpc.init_rpc(
-                "master",
-                rank=rank,
-                world_size=world_size,
-                rpc_backend_options=rpc_backend_options)
-
-        # Build the embedding table on the ps.
-        emb_rref = rpc.remote(
-                "ps",
-                torch.nn.EmbeddingBag,
-                args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
-                kwargs={"mode": "sum"})
+            "master",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        remote_emb_module = RemoteModule(
+            "ps",
+            torch.nn.EmbeddingBag,
+            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
+            kwargs={"mode": "sum"},
+        )
 
         # Run the training loop on trainers.
         futs = []
         for trainer_rank in [0, 1]:
             trainer_name = "trainer{}".format(trainer_rank)
             fut = rpc.rpc_async(
-                trainer_name, _run_trainer, args=(emb_rref, rank))
+                trainer_name, _run_trainer, args=(remote_emb_module, trainer_rank)
+            )
             futs.append(fut)
 
         # Wait for all training to finish.
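The context the diff skips after this line presumably just blocks on the futures collected above, along the lines of:

        # Sketch of the elided wait: block until both trainers finish training.
        for fut in futs:
            fut.wait()
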
@@ -160,32 +151,35 @@ def run_worker(rank, world_size):
     elif rank <= 1:
         # Initialize process group for Distributed DataParallel on trainers.
         dist.init_process_group(
-                backend="gloo", rank=rank, world_size=2)
+            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
+        )
 
         # Initialize RPC.
         trainer_name = "trainer{}".format(rank)
         rpc.init_rpc(
-                trainer_name,
-                rank=rank,
-                world_size=world_size,
-                rpc_backend_options=rpc_backend_options)
+            trainer_name,
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
 
         # Trainer just waits for RPCs from master.
     else:
         rpc.init_rpc(
-                "ps",
-                rank=rank,
-                world_size=world_size,
-                rpc_backend_options=rpc_backend_options)
+            "ps",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
         # The parameter server does nothing here; it just serves RPC requests.
         pass
 
     # block until all rpcs finish
     rpc.shutdown()
 
 
-if __name__ == "__main__":
+if __name__ == "__main__":
     # 2 trainers, 1 parameter server, 1 master.
     world_size = 4
-    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
+    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
 # END run_worker
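A note on running the example: mp.spawn launches all four roles (two trainers, one parameter server, one master) from a single entry point, the trainers place their dense nn.Linear replica on cuda:0 and cuda:1 respectively, and the gloo process group used by DDP (port 29500) is kept separate from the TensorPipe RPC endpoint (port 29501), per the comment in run_worker. Assuming the file is saved as main.py, running "python main.py" on a single host with at least two GPUs should be sufficient.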