Commit f2aee23

Yi Wang and holly1238 authored
Replace RRef with RemoteModule (#1513)
* Replace RRef with RemoteModule

  Replace `emb_rref` with `remote_emb_module`. Copied from:
  https://github.com/pytorch/examples/blob/fba20d59beabc144666123db10c907f88e4744b8/distributed/rpc/ddp_rpc/main.py

* Update main.py

  Add some marks needed by https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html. For example:

  ```
  .. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py
     :language: py
     :start-after: BEGIN hybrid_model
     :end-before: END hybrid_model
  ```

  This requires the comments of `BEGIN hybrid_model` and `END hybrid_model`.

Co-authored-by: Holly Sweeney <[email protected]>
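For context, the core of the change is how trainers obtain the embedding table hosted on the `"ps"` worker: the old code created it with `rpc.remote()` and passed an `RRef` around, while the new code wraps it in a `RemoteModule`. Below is a minimal sketch of the two patterns side by side, using the `NUM_EMBEDDINGS = 100` and `EMBEDDING_DIM = 16` values from the file; it assumes an already-initialized RPC world containing a worker named `"ps"` and is illustrative, not a copy of the example.

```python
import torch
import torch.distributed.rpc as rpc
from torch.distributed.nn import RemoteModule

# Dummy EmbeddingBag inputs for illustration.
indices = torch.LongTensor([1, 2, 4, 5])
offsets = torch.LongTensor([0, 2])

# Old pattern (removed by this commit): an RRef to an nn.EmbeddingBag created
# remotely on the "ps" worker; every call goes through the RRef's RPC helpers.
emb_rref = rpc.remote(
    "ps", torch.nn.EmbeddingBag, args=(100, 16), kwargs={"mode": "sum"}
)
emb_old = emb_rref.rpc_sync().forward(indices, offsets)

# New pattern (introduced by this commit): a RemoteModule that owns the
# nn.EmbeddingBag on "ps", exposes forward() directly, and provides
# remote_parameters() for use with DistributedOptimizer.
remote_emb_module = RemoteModule(
    "ps", torch.nn.EmbeddingBag, args=(100, 16), kwargs={"mode": "sum"}
)
emb_new = remote_emb_module.forward(indices, offsets)
param_rrefs = remote_emb_module.remote_parameters()
```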
1 parent 6efe1c5 commit f2aee23

1 file changed

Lines changed: 50 additions & 56 deletions

File tree

  • advanced_source/rpc_ddp_tutorial
advanced_source/rpc_ddp_tutorial/main.py: 50 additions & 56 deletions
```diff
@@ -1,16 +1,15 @@
-import os
-from functools import wraps
-
 import random
+
 import torch
 import torch.distributed as dist
 import torch.distributed.autograd as dist_autograd
 import torch.distributed.rpc as rpc
-from torch.distributed.rpc import TensorPipeRpcBackendOptions
 import torch.multiprocessing as mp
 import torch.optim as optim
+from torch.distributed.nn import RemoteModule
 from torch.distributed.optim import DistributedOptimizer
 from torch.distributed.rpc import RRef
+from torch.distributed.rpc import TensorPipeRpcBackendOptions
 from torch.nn.parallel import DistributedDataParallel as DDP

 NUM_EMBEDDINGS = 100
```
```diff
@@ -19,35 +18,25 @@
 # BEGIN hybrid_model
 class HybridModel(torch.nn.Module):
     r"""
-    The model consists of a sparse part and a dense part. The dense part is an
-    nn.Linear module that is replicated across all trainers using
-    DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
-    stored on the parameter server.
-
-    The model holds a Remote Reference to the embedding table on the parameter
-    server.
+    The model consists of a sparse part and a dense part.
+    1) The dense part is an nn.Linear module that is replicated across all trainers using DistributedDataParallel.
+    2) The sparse part is a Remote Module that holds an nn.EmbeddingBag on the parameter server.
+    This remote model can get a Remote Reference to the embedding table on the parameter server.
     """

-    def __init__(self, emb_rref, device):
+    def __init__(self, remote_emb_module, device):
         super(HybridModel, self).__init__()
-        self.emb_rref = emb_rref
+        self.remote_emb_module = remote_emb_module
         self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
         self.device = device

     def forward(self, indices, offsets):
-        emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
+        emb_lookup = self.remote_emb_module.forward(indices, offsets)
         return self.fc(emb_lookup.cuda(self.device))
 # END hybrid_model

 # BEGIN setup_trainer
-def _retrieve_embedding_parameters(emb_rref):
-    param_rrefs = []
-    for param in emb_rref.local_value().parameters():
-        param_rrefs.append(RRef(param))
-    return param_rrefs
-
-
-def _run_trainer(emb_rref, rank):
+def _run_trainer(remote_emb_module, rank):
     r"""
     Each trainer runs a forward pass which involves an embedding lookup on the
     parameter server and running nn.Linear locally. During the backward pass,
```
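One practical consequence of the `HybridModel.forward` change above: with an `RRef`, the lookup had to be written `emb_rref.rpc_sync().forward(...)`, whereas `RemoteModule` generates a plain `forward()` (and, per the `torch.distributed.nn.RemoteModule` docs, a `forward_async()` that returns a Future). A hedged sketch, assuming an initialized RPC world and a `remote_emb_module` built as shown later in `run_worker`:

```python
import torch

indices = torch.LongTensor([10, 20, 30])
offsets = torch.LongTensor([0, 2])

# Synchronous call, exactly as HybridModel.forward uses it: the lookup runs
# on the parameter server and the result tensor is returned to the trainer.
emb_lookup = remote_emb_module.forward(indices, offsets)

# Asynchronous variant generated by RemoteModule: returns a Future, so a
# trainer could overlap the remote lookup with other local work.
fut = remote_emb_module.forward_async(indices, offsets)
emb_lookup = fut.wait()
```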
```diff
@@ -57,16 +46,18 @@ def _run_trainer(emb_rref, rank):
     """

     # Setup the model.
-    model = HybridModel(emb_rref, rank)
+    model = HybridModel(remote_emb_module, rank)

     # Retrieve all model parameters as rrefs for DistributedOptimizer.

     # Retrieve parameters for embedding table.
-    model_parameter_rrefs = rpc.rpc_sync(
-        "ps", _retrieve_embedding_parameters, args=(emb_rref,))
+    model_parameter_rrefs = model.remote_emb_module.remote_parameters()

-    # model.parameters() only includes local parameters.
-    for param in model.parameters():
+    # model.fc.parameters() only includes local parameters.
+    # NOTE: Cannot call model.parameters() here,
+    # because this will call remote_emb_module.parameters(),
+    # which supports remote_parameters() but not parameters().
+    for param in model.fc.parameters():
         model_parameter_rrefs.append(RRef(param))

     # Setup distributed optimizer
```
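This hunk ends right at the `# Setup distributed optimizer` comment: the RRefs collected above (from `remote_parameters()` plus the local `model.fc` parameters) are what the `DistributedOptimizer` in the unchanged lines consumes. A rough sketch of that pattern follows; the optimizer class and learning rate here are illustrative choices, not quotes from the file.

```python
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer

# model_parameter_rrefs mixes RRefs to the EmbeddingBag parameters living on
# the parameter server with RRefs wrapping the trainer-local nn.Linear
# parameters; one DistributedOptimizer steps all of them inside each
# distributed autograd context.
opt = DistributedOptimizer(
    optim.SGD,              # local optimizer instantiated on each parameter's owner
    model_parameter_rrefs,  # list of RRefs built in _run_trainer above
    lr=0.05,                # illustrative hyperparameter
)
```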
```diff
@@ -115,43 +106,43 @@ def get_next_batch(rank):
                 # Not necessary to zero grads as each iteration creates a different
                 # distributed autograd context which hosts different grads
         print("Training done for epoch {}".format(epoch))
-# END run_trainer
-
+# END run_trainer

 # BEGIN run_worker
 def run_worker(rank, world_size):
     r"""
     A wrapper function that initializes RPC, calls the function, and shuts down
     RPC.
     """
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '29500'
-

+    # We need to use different port numbers in TCP init_method for init_rpc and
+    # init_process_group to avoid port conflicts.
     rpc_backend_options = TensorPipeRpcBackendOptions()
-    rpc_backend_options.init_method='tcp://localhost:29501'
+    rpc_backend_options.init_method = "tcp://localhost:29501"

     # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
     if rank == 2:
         rpc.init_rpc(
-            "master",
-            rank=rank,
-            world_size=world_size,
-            rpc_backend_options=rpc_backend_options)
-
-        # Build the embedding table on the ps.
-        emb_rref = rpc.remote(
-            "ps",
-            torch.nn.EmbeddingBag,
-            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
-            kwargs={"mode": "sum"})
+            "master",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
+
+        remote_emb_module = RemoteModule(
+            "ps",
+            torch.nn.EmbeddingBag,
+            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
+            kwargs={"mode": "sum"},
+        )

         # Run the training loop on trainers.
         futs = []
         for trainer_rank in [0, 1]:
             trainer_name = "trainer{}".format(trainer_rank)
             fut = rpc.rpc_async(
-                trainer_name, _run_trainer, args=(emb_rref, rank))
+                trainer_name, _run_trainer, args=(remote_emb_module, rank)
+            )
             futs.append(fut)

         # Wait for all training to finish.
```
```diff
@@ -160,32 +151,35 @@ def run_worker(rank, world_size):
     elif rank <= 1:
         # Initialize process group for Distributed DataParallel on trainers.
         dist.init_process_group(
-            backend="gloo", rank=rank, world_size=2)
+            backend="gloo", rank=rank, world_size=2, init_method="tcp://localhost:29500"
+        )

         # Initialize RPC.
         trainer_name = "trainer{}".format(rank)
         rpc.init_rpc(
-            trainer_name,
-            rank=rank,
-            world_size=world_size,
-            rpc_backend_options=rpc_backend_options)
+            trainer_name,
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )

         # Trainer just waits for RPCs from master.
     else:
         rpc.init_rpc(
-            "ps",
-            rank=rank,
-            world_size=world_size,
-            rpc_backend_options=rpc_backend_options)
+            "ps",
+            rank=rank,
+            world_size=world_size,
+            rpc_backend_options=rpc_backend_options,
+        )
         # parameter server do nothing
         pass

     # block until all rpcs finish
     rpc.shutdown()


-if __name__=="__main__":
+if __name__ == "__main__":
     # 2 trainers, 1 parameter server, 1 master.
     world_size = 4
-    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
+    mp.spawn(run_worker, args=(world_size,), nprocs=world_size, join=True)
 # END run_worker
```

0 commit comments
