intermediate_source/rpc_param_server_tutorial.rst

49 additions & 53 deletions
@@ -113,12 +113,12 @@ Next, we'll define our forward pass. Note that regardless of the device of the m

    class ParameterServer(nn.Module):
        ...
        def forward(self, inp):
            inp = inp.to(self.input_device)
            out = self.model(inp)
            # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
            # Tensors must be moved in and out of GPU memory due to this.
            out = out.to("cpu")
            return out
Next, we'll define a few miscellaneous functions useful for training and verification purposes. The first, ``get_dist_gradients``\ , will take in a Distributed Autograd context ID and call into the ``dist_autograd.get_gradients`` API in order to retrieve gradients computed by distributed autograd. More information can be found in the `distributed autograd documentation <https://pytorch.org/docs/stable/rpc.html#distributed-autograd-framework>`_. Note that we also iterate through the resulting dictionary and convert each tensor to a CPU tensor, as the framework currently only supports sending CPU tensors over RPC. Next, ``get_param_rrefs`` will iterate through our model parameters and wrap each of them in a (local) `RRef <https://pytorch.org/docs/stable/rpc.html#torch.distributed.rpc.RRef>`_. This method will be invoked over RPC by trainer nodes and will return the list of parameters to be optimized. This is required as input to the `Distributed Optimizer <https://pytorch.org/docs/stable/rpc.html#module-torch.distributed.optim>`_\ , which requires that all parameters it optimizes be passed as a list of ``RRef``\ s.
.. code-block:: python

    # Use dist autograd to retrieve gradients accumulated for this model.
    # Primarily used for verification.
    def get_dist_gradients(self, cid):
        grads = dist_autograd.get_gradients(cid)
        # This output is forwarded over RPC, which as of 1.5.0 only accepts CPU tensors.
        # Tensors must be moved in and out of GPU memory due to this.
        cpu_grads = {}
        for k, v in grads.items():
            k_cpu, v_cpu = k.to("cpu"), v.to("cpu")
            cpu_grads[k_cpu] = v_cpu
        return cpu_grads

    # Wrap local parameters in a RRef. Needed for building the
    # DistributedOptimizer which optimizes parameters remotely.
    def get_param_rrefs(self):
        param_rrefs = [rpc.RRef(param) for param in self.model.parameters()]
        return param_rrefs
@@ -151,28 +151,28 @@ Finally, we'll create methods to initialize our parameter server. Note that ther

    def get_parameter_server(num_gpus=0):
        """
        Returns a singleton parameter server to all trainer processes
        """
        global param_server
        # Ensure that we get only one handle to the ParameterServer.
        with global_lock:
            if not param_server:
                # construct it once
                param_server = ParameterServer(num_gpus=num_gpus)
            return param_server
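The construction above, a module-level variable guarded by a lock so that every trainer RPC gets the same instance, is a thread-safe singleton. A minimal torch-free sketch of the same pattern (``FakeParameterServer`` is a hypothetical stand-in for the tutorial's ``ParameterServer``):

```python
import threading

param_server = None
global_lock = threading.Lock()

class FakeParameterServer:
    """Stand-in for the tutorial's ParameterServer (illustrative only)."""
    def __init__(self, num_gpus=0):
        self.num_gpus = num_gpus

def get_parameter_server(num_gpus=0):
    """Return the same instance to every caller, constructing it once."""
    global param_server
    with global_lock:
        if not param_server:
            param_server = FakeParameterServer(num_gpus=num_gpus)
        return param_server

a = get_parameter_server(num_gpus=2)
b = get_parameter_server()
print(a is b)  # True: both callers share one instance
```

Note that later calls ignore their arguments once the instance exists, which is why the tutorial constructs the server exactly once under the lock.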

    def run_parameter_server(rank, world_size):
        # The parameter server just acts as a host for the model and responds to
        # requests from trainers.
        # rpc.shutdown() will wait for all workers to complete by default, which
        # in this case means that the parameter server will wait for all trainers
Note that above, ``rpc.shutdown()`` will not immediately shut down the Parameter Server. Instead, it will wait for all workers (trainers in this case) to also call into ``rpc.shutdown()``. This guarantees that the parameter server will not go offline before all trainers (yet to be defined) have completed their training process.
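The wait-for-everyone behavior of ``rpc.shutdown()`` acts like a barrier: the server cannot exit until every trainer has also reached shutdown. A rough torch-free analogy using ``threading.Barrier`` (purely illustrative of the ordering guarantee, not how the RPC agent is implemented):

```python
import threading

WORLD_SIZE = 3  # 1 server + 2 trainers (illustrative sizes)
shutdown_barrier = threading.Barrier(WORLD_SIZE)
events = []

def worker(name):
    events.append(f"{name} done training")
    shutdown_barrier.wait()  # analogous to rpc.shutdown(): block until all arrive
    events.append(f"{name} exited")

threads = [threading.Thread(target=worker, args=(f"trainer{i}",)) for i in range(2)]
for t in threads:
    t.start()
worker("server")  # the server cannot exit before both trainers reach the barrier
for t in threads:
    t.join()
print(events[:3])  # all "done training" events precede any "exited" event
```

Because each participant appends its "done training" event before waiting, every exit event is guaranteed to come after all training events, mirroring the guarantee described above.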
Next, we'll define our ``TrainerNet`` class. This will also be a subclass of ``nn.Module``\ , and our ``__init__`` method will use the ``rpc.remote`` API to obtain an RRef, or Remote Reference, to our parameter server. Note that here we are not copying the parameter server to our local process; instead, we can think of ``self.param_server_rref`` as a distributed shared pointer to the parameter server that lives on a separate process.
@@ -232,25 +232,25 @@ As opposed to calling the typical ``loss.backward()`` which would kick off the b

    for i, (data, target) in enumerate(train_loader):
        with dist_autograd.context() as cid:
            model_output = net(data)
            target = target.to(model_output.device)
            loss = F.nll_loss(model_output, target)
            if i % 5 == 0:
                print(f"Rank {rank} training batch {i} loss {loss.item()}")
            dist_autograd.backward(cid, [loss])
            # Ensure that dist autograd ran successfully and gradients were
            # returned.
            assert remote_method(
                ParameterServer.get_dist_gradients,
                net.param_server_rref,
                cid) != {}
            opt.step(cid)

    print("Training complete!")
    print("Getting accuracy....")
    get_accuracy(test_loader, net)
The following simply computes the accuracy of our model after we're done training, much like a traditional local model. However, note that the ``net`` we pass into this function above is an instance of ``TrainerNet`` and therefore the forward pass invokes RPC in a transparent fashion.
.. code-block:: python
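The accuracy computation is the usual argmax-matches-label count over the test set. A minimal torch-free sketch of that calculation (``get_accuracy`` here takes hypothetical stand-ins for the DataLoader and ``TrainerNet``, not the tutorial's exact signature):

```python
def get_accuracy(test_batches, predict):
    """Fraction of examples whose predicted class matches the label.

    test_batches: iterable of (inputs, labels) pairs; predict: maps a
    batch of inputs to per-example class scores.
    """
    correct = total = 0
    for inputs, labels in test_batches:
        for scores, label in zip(predict(inputs), labels):
            # argmax over class scores, like output.max(dim=1)
            pred = max(range(len(scores)), key=scores.__getitem__)
            correct += (pred == label)
            total += 1
    return correct / total

# One batch of two examples with two classes; identity "model" for illustration.
batches = [([[0.1, 0.9], [0.8, 0.2]], [1, 1])]
print(get_accuracy(batches, lambda x: x))  # 0.5: one of two predictions correct
```

In the tutorial, the forward pass inside this loop goes over RPC to the parameter server, but the accuracy arithmetic itself is exactly this local computation.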
@@ -353,14 +353,10 @@ Now, we'll create a process corresponding to either a parameter server or traine

        datasets.MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=32,
        shuffle=True,
    )
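``transforms.Compose`` above simply chains callables left to right, and ``Normalize`` subtracts the per-channel mean and divides by the standard deviation (``(0.1307,)`` and ``(0.3081,)`` are MNIST's statistics). A torch-free sketch of both ideas (deliberately minimal stand-ins, not torchvision's implementations):

```python
class Compose:
    """Apply a list of transforms in order, like torchvision's Compose."""
    def __init__(self, transforms):
        self.transforms = transforms
    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

def normalize(mean, std):
    # Single-channel version of transforms.Normalize: (x - mean) / std
    return lambda pixels: [(p - mean) / std for p in pixels]

pipeline = Compose([
    lambda pixels: [p / 255.0 for p in pixels],  # ToTensor-style scaling to [0, 1]
    normalize(0.1307, 0.3081),                   # MNIST mean/std
])
print(pipeline([0, 255]))
```

The same normalization constants appear in many MNIST examples; using the dataset's own statistics keeps inputs roughly zero-mean and unit-variance for training.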
@@ -379,4 +375,4 @@ Now, we'll create a process corresponding to either a parameter server or traine

        p.join()
To run the example locally, run the following command for the server and for each worker you wish to spawn, in separate terminal windows: ``python rpc_parameter_server.py --world_size=WORLD_SIZE --rank=RANK``. For example, for a master node with a world size of 2, the command would be ``python rpc_parameter_server.py --world_size=2 --rank=0``. The trainer can then be launched with ``python rpc_parameter_server.py --world_size=2 --rank=1`` in a separate window, which will begin training with one server and a single trainer. Note that this tutorial assumes training occurs using between 0 and 2 GPUs; this can be configured by passing ``--num_gpus=N`` to the training script.
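The flags described here (``--world_size``, ``--rank``, ``--num_gpus``, and the optional master address and port) are the kind of thing ``argparse`` handles directly. A sketch of plausible argument handling for such a script, with defaults that are illustrative assumptions rather than the tutorial's exact values:

```python
import argparse

parser = argparse.ArgumentParser(
    description="Parameter-server RPC training example (argument sketch)")
parser.add_argument("--world_size", type=int, default=2,
                    help="total number of participating processes")
parser.add_argument("--rank", type=int, default=0,
                    help="0 for the parameter server, >=1 for trainers")
parser.add_argument("--num_gpus", type=int, default=0,
                    help="number of GPUs to use for training (0-2)")
parser.add_argument("--master_addr", type=str, default="localhost",
                    help="address the master (rank 0) listens on")
parser.add_argument("--master_port", type=str, default="29500",
                    help="port the master (rank 0) listens on")

# Simulate launching the trainer from the example command line.
args = parser.parse_args(["--world_size=2", "--rank=1"])
print(args.rank, args.master_addr)
```

Every process in the group must agree on the world size and the master address/port; only the rank differs between the server and trainer invocations.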
You can pass in the command line arguments ``--master_addr=ADDRESS`` and ``--master_port=PORT`` to indicate the address and port that the master worker is listening on, for example, to test functionality where trainers and master nodes run on different machines.