@@ -76,39 +76,13 @@ usages.
             self.dropout = nn.Dropout(p=0.6)
             self.affine2 = nn.Linear(128, 2)
 
-            self.saved_log_probs = []
-            self.rewards = []
-
         def forward(self, x):
             x = self.affine1(x)
             x = self.dropout(x)
             x = F.relu(x)
             action_scores = self.affine2(x)
             return F.softmax(action_scores, dim=1)
 
-Let's first prepare a helper to run functions remotely on the owner worker of an
-``RRef``. You will find this function being used in several places in this
-tutorial's examples. Ideally, the ``torch.distributed.rpc`` package should provide
-these helper functions out of the box. For example, it will be easier if
-applications can directly call ``RRef.some_func(*arg)``, which will then
-translate to an RPC to the ``RRef`` owner. The progress on this API is tracked in
-`pytorch/pytorch#31743 <https://github.com/pytorch/pytorch/issues/31743>`__.
-
-.. code:: python
-
-    from torch.distributed.rpc import rpc_sync
-
-    def _call_method(method, rref, *args, **kwargs):
-        return method(rref.local_value(), *args, **kwargs)
-
-
-    def _remote_method(method, rref, *args, **kwargs):
-        args = [method, rref] + list(args)
-        return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
-
-    # to call a function on an rref, we could do the following
-    # _remote_method(some_func, rref, *args)
-
 
 We are ready to present the observer. In this example, each observer creates its
 own environment, and waits for the agent's command to run an episode. In each
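
The deleted helper above has since been superseded by the ``RRef`` helper methods
(``rpc_sync()``, ``rpc_async()``, and ``remote()``) added for
`pytorch/pytorch#31743 <https://github.com/pytorch/pytorch/issues/31743>`__; the
rest of this commit switches call sites over to them. A minimal sketch of the
correspondence (``some_func``, ``rref``, and the argument list are placeholders,
not names from this tutorial):

.. code:: python

    # before this commit: route the call through the hand-rolled helper
    # _remote_method(SomeClass.some_func, rref, *args)

    # after this commit: use the proxies built into RRef itself
    result = rref.rpc_sync().some_func(*args)    # blocking RPC to rref's owner
    future = rref.rpc_async().some_func(*args)   # non-blocking; call future.wait() later
    out_rref = rref.remote().some_func(*args)    # returns an RRef to the remote result
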
@@ -134,10 +108,14 @@ simple and the two steps explicit in this example.
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    parser.add_argument('--world_size', default=2, help='Number of workers')
-    parser.add_argument('--log_interval', default=1, help='Log every log_interval episodes')
-    parser.add_argument('--gamma', default=0.1, help='how much to value future rewards')
-    parser.add_argument('--seed', default=1, help='random seed for reproducibility')
+    parser.add_argument('--world_size', default=2, type=int, metavar='W',
+                        help='number of workers')
+    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
+                        help='interval between training status logs')
+    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
+                        help='how much to value future rewards')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed for reproducibility')
     args = parser.parse_args()
 
     class Observer:
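
One change worth calling out above: ``--gamma`` now defaults to 0.99 instead of
0.1. Gamma is the discount factor REINFORCE applies when turning per-step rewards
into returns, so a value near 1 makes the agent weigh future reward almost as
heavily as immediate reward. A small illustration of the discounting (the reward
values are made up, and the variable names are ours, not the tutorial's):

.. code:: python

    gamma = 0.99
    rewards = [1.0, 1.0, 1.0]   # per-step rewards from one hypothetical episode

    # accumulate discounted returns back-to-front: G_t = r_t + gamma * G_{t+1}
    R, returns = 0.0, []
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)

    print(returns)  # [2.9701, 1.99, 1.0]
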
@@ -147,18 +125,19 @@ simple and the two steps explicit in this example.
             self.env = gym.make('CartPole-v1')
             self.env.seed(args.seed)
 
-        def run_episode(self, agent_rref, n_steps):
+        def run_episode(self, agent_rref):
             state, ep_reward = self.env.reset(), 0
-            for step in range(n_steps):
+            for _ in range(10000):
                 # send the state to the agent to get an action
-                action = _remote_method(Agent.select_action, agent_rref, self.id, state)
+                action = agent_rref.rpc_sync().select_action(self.id, state)
 
                 # apply the action to the environment, and get the reward
                 state, reward, done, _ = self.env.step(action)
 
                 # report the reward to the agent for training purposes
-                _remote_method(Agent.report_reward, agent_rref, self.id, reward)
+                agent_rref.rpc_sync().report_reward(self.id, reward)
 
+                # the episode finishes after at most self.env._max_episode_steps steps
                 if done:
                     break
 
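
The hard-coded ``range(10000)`` above is safe because Gym wraps ``CartPole-v1``
in a ``TimeLimit`` that forces ``done`` after ``_max_episode_steps`` steps (500
for ``CartPole-v1`` in Gym versions contemporary with this commit), which is
what the new inline comment refers to. A quick way to check, assuming a classic
Gym install:

.. code:: python

    import gym

    env = gym.make('CartPole-v1')
    # the TimeLimit wrapper stores the cap; CartPole-v1 uses 500
    print(env._max_episode_steps)
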
@@ -242,15 +221,15 @@ contain the recorded action probs and rewards.
 
     class Agent:
         ...
-        def run_episode(self, n_steps=0):
+        def run_episode(self):
             futs = []
             for ob_rref in self.ob_rrefs:
                 # make async RPC to kick off an episode on all observers
                 futs.append(
                     rpc_async(
                         ob_rref.owner(),
-                        _call_method,
-                        args=(Observer.run_episode, ob_rref, self.agent_rref, n_steps)
+                        ob_rref.rpc_sync().run_episode,
+                        args=(self.agent_rref,)
                     )
                 )
 
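
The ``rpc_async`` calls above return ``torch.futures.Future`` objects, one per
observer, so the agent can kick off every episode before blocking. The hunk ends
before showing it, but the natural continuation (and what the published tutorial
does next) is to wait on each future:

.. code:: python

    # block until all observers have finished their episodes
    # (continuation of Agent.run_episode; `futs` is the list built above)
    for fut in futs:
        fut.wait()
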
@@ -324,8 +303,7 @@ available in the `API page <https://pytorch.org/docs/master/rpc.html>`__.
     import torch.multiprocessing as mp
 
     AGENT_NAME = "agent"
-    OBSERVER_NAME = "obs"
-    TOTAL_EPISODE_STEP = 100
+    OBSERVER_NAME = "obs{}"
 
     def run_worker(rank, world_size):
         os.environ['MASTER_ADDR'] = 'localhost'
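
``OBSERVER_NAME`` is now a format template rather than a fixed string, so each
observer rank can register under a distinct worker name. A sketch of how the
observer branch presumably fills it in (mirroring the agent branch shown below):

.. code:: python

    # on a non-zero rank, register this process as "obs1", "obs2", ...
    rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
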
@@ -335,17 +313,17 @@ available in the `API page <https://pytorch.org/docs/master/rpc.html>`__.
             rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)
 
             agent = Agent(world_size)
+            print(f"This will run until reward threshold of {agent.reward_threshold}"
+                  " is reached. Ctrl+C to exit.")
             for i_episode in count(1):
-                n_steps = int(TOTAL_EPISODE_STEP / (args.world_size - 1))
-                agent.run_episode(n_steps=n_steps)
+                agent.run_episode()
                 last_reward = agent.finish_episode()
 
                 if i_episode % args.log_interval == 0:
-                    print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
-                        i_episode, last_reward, agent.running_reward))
-
+                    print(f"Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: "
+                          f"{agent.running_reward:.2f}")
                 if agent.running_reward > agent.reward_threshold:
-                    print("Solved! Running reward is now {}!".format(agent.running_reward))
+                    print(f"Solved! Running reward is now {agent.running_reward}!")
                     break
         else:
             # other ranks are the observer
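
The ``torch.multiprocessing`` import earlier suggests the workers are launched
with ``mp.spawn``, which is also how the published tutorial drives
``run_worker``. A minimal sketch of that entry point:

.. code:: python

    def main():
        # one process per worker: rank 0 becomes the agent, the rest observers
        mp.spawn(
            run_worker,
            args=(args.world_size, ),
            nprocs=args.world_size,
            join=True
        )

    if __name__ == '__main__':
        main()
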
@@ -367,6 +345,7 @@ Below are some sample outputs when training with `world_size=2`.
 
 ::
 
+    This will run until reward threshold of 475.0 is reached. Ctrl+C to exit.
     Episode 10  Last reward: 26.00  Average reward: 10.01
     Episode 20  Last reward: 16.00  Average reward: 11.27
     Episode 30  Last reward: 49.00  Average reward: 18.62