@@ -442,10 +442,10 @@ local and remote parameters. The helper function is pretty simple, just call
442442.. code :: python
443443
444444 def _parameter_rrefs (module ):
445- param_rrefs = []
446- for param in module.parameters():
447- param_rrefs.append(RRef(param))
448- return param_rrefs
445+ param_rrefs = []
446+ for param in module.parameters():
447+ param_rrefs.append(RRef(param))
448+ return param_rrefs
449449
450450
451451 Then, as the ``RNNModel `` contains three sub-modules, we need to call
@@ -535,24 +535,21 @@ processes.
535535
536536.. code :: python
537537
538- def run_ps ():
539- pass
540-
541- def run_worker (name , rank , func , world_size ):
538+ def run_worker (rank , world_size ):
542539 os.environ[' MASTER_ADDR' ] = ' localhost'
543540 os.environ[' MASTER_PORT' ] = ' 29500'
544- rpc.init_rpc(name, rank = rank, world_size = world_size)
545-
546- func()
541+ if rank == 1 :
542+ rpc.init_rpc(" trainer" , rank = rank, world_size = world_size)
543+ _run_trainer()
544+ else :
545+ rpc.init_rpc(" ps" , rank = rank, world_size = world_size)
546+ # parameter server do nothing
547+ pass
547548
548- # block until all rpcs finish, and shutdown the RPC instance
549+ # block until all rpcs finish
549550 rpc.shutdown()
550551
551- mp.set_start_method(' spawn' )
552- ps = mp.Process(target = run_worker, args = (" ps" , 0 , run_ps, 2 ))
553- ps.start()
554552
555- trainer = mp.Process(target = run_worker, args = (" trainer" , 1 , run_trainer, 2 ))
556- trainer.start()
557- ps.join()
558- trainer.join()
553+ if __name__ == " __main__" :
554+ world_size = 2
555+ mp.spawn(run_worker, args = (world_size, ), nprocs = world_size, join = True )
0 commit comments