As the agent observes the current state of the environment and chooses
an action, the environment *transitions* to a new state, and also
returns a reward that indicates the consequences of the action. In this
- task, the environment terminates if the pole falls over too far.
+ task, rewards are +1 for every incremental timestep and the environment
+ terminates if the pole falls over too far or the cart moves more than 2.4
+ units away from center. This means better performing scenarios will run
+ for a longer duration, accumulating a larger return.

The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
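
For a feel of the raw environment interface, a bare interaction loop with
the classic Gym API (assuming ``env.step`` returns an ``(observation,
reward, done, info)`` tuple) looks roughly like this sketch::

    import gym

    env = gym.make('CartPole-v0')
    state = env.reset()     # 4 floats: cart position/velocity, pole angle/velocity
    for t in range(200):
        action = env.action_space.sample()          # 0 = push left, 1 = push right
        state, reward, done, _ = env.step(action)   # reward is +1 for every step taken
        if done:                                    # pole fell over or cart left the track
            break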
# For this, we're going to need two classes:
#
# - ``Transition`` - a named tuple representing a single transition in
- #   our environment
+ #   our environment. It essentially maps (state, action) pairs
+ #   to their (next_state, reward) result, with the state being the
+ #   screen difference image as described later on.
# - ``ReplayMemory`` - a cyclic buffer of bounded size that holds the
#   transitions observed recently. It also implements a ``.sample()``
#   method for selecting a random batch of transitions for training
#   (a minimal sketch of both follows below).
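#
# A minimal sketch of what these two pieces can look like (assuming the
# ``namedtuple`` and ``random`` imports from the top of the file, and a plain
# Python list with a wrap-around write index for the cyclic buffer)::
#
#     Transition = namedtuple('Transition',
#                             ('state', 'action', 'next_state', 'reward'))
#
#     class ReplayMemory(object):
#         def __init__(self, capacity):
#             self.capacity = capacity
#             self.memory = []
#             self.position = 0
#
#         def push(self, *args):
#             """Saves a transition."""
#             if len(self.memory) < self.capacity:
#                 self.memory.append(None)
#             self.memory[self.position] = Transition(*args)
#             self.position = (self.position + 1) % self.capacity
#
#         def sample(self, batch_size):
#             return random.sample(self.memory, batch_size)
#
#         def __len__(self):
#             return len(self.memory)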
@@ -197,22 +202,32 @@ def __len__(self):
# difference between the current and previous screen patches. It has two
# outputs, representing :math:`Q(s, \mathrm{left})` and
# :math:`Q(s, \mathrm{right})` (where :math:`s` is the input to the
- # network). In effect, the network is trying to predict the *quality* of
+ # network). In effect, the network is trying to predict the *expected return* of
# taking each action given the current input.
#

class DQN(nn.Module):

-     def __init__(self):
+     def __init__(self, h, w):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
-         self.head = nn.Linear(448, 2)

+         # Number of Linear input connections depends on output of conv2d layers
+         # and therefore the input image size, so compute it.
+         def conv2d_size_out(size, kernel_size=5, stride=2):
+             return (size - (kernel_size - 1) - 1) // stride + 1
+         convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
+         convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
+         linear_input_size = convw * convh * 32
+         self.head = nn.Linear(linear_input_size, 2)  # 448 or 512
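        # For example, with the ~3x40x90 screens produced by get_screen() below,
        # conv2d_size_out applied three times maps 90 -> 43 -> 20 -> 8 (convw)
        # and 40 -> 18 -> 7 -> 2 (convh), so linear_input_size = 8 * 2 * 32 = 512.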
+
+     # Called with either one element to determine next action, or a batch
+     # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
@@ -234,23 +249,20 @@ def forward(self, x):
                    T.Resize(40, interpolation=Image.CUBIC),
                    T.ToTensor()])

- # This is based on the code from gym.
- screen_width = 600
-
-
- def get_cart_location():
+ def get_cart_location(screen_width):
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART

-
def get_screen():
-     screen = env.render(mode='rgb_array').transpose(
-         (2, 0, 1))  # transpose into torch order (CHW)
-     # Strip off the top and bottom of the screen
-     screen = screen[:, 160:320]
-     view_width = 320
-     cart_location = get_cart_location()
+     # Returned screen requested by gym is 400x600x3, but is sometimes larger,
+     # such as 800x1200x3. Transpose it into torch order (CHW).
+     screen = env.render(mode='rgb_array').transpose((2, 0, 1))
+     # Cart is in the lower half, so strip off the top and bottom of the screen
+     _, screen_height, screen_width = screen.shape
+     screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
+     view_width = int(screen_width * 0.6)
+     cart_location = get_cart_location(screen_width)
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
@@ -298,15 +310,23 @@ def get_screen():
# episode.
#

- BATCH_SIZE = 128
+ BATCH_SIZE = 196  # 128
GAMMA = 0.999
EPS_START = 0.9
- EPS_END = 0.05
- EPS_DECAY = 200
+ EPS_END = 0.07
+ EPS_DECAY = 300
TARGET_UPDATE = 10

- policy_net = DQN().to(device)
- target_net = DQN().to(device)
+ # Get screen size so that we can initialize layers correctly based on shape
+ # returned from AI gym. Typical dimensions at this point are close to 3x40x90,
+ # which is the result of a clamped and down-scaled buffer in get_screen().
+ init_screen = get_screen()
+ _, _, screen_height, screen_width = init_screen.shape
+
+ policy_net = DQN(screen_height, screen_width).to(device)
+ target_net = DQN(screen_height, screen_width).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

@@ -325,6 +345,9 @@ def select_action(state):
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
+             # t.max(1) will return the largest column value of each row.
+             # The second element of the max result is the index of where the max
+             # element was found, so we pick the action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)
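
######################################################################
# The ``eps_threshold`` compared against ``sample`` above typically follows
# an exponential decay built from the constants defined earlier; a sketch of
# such a schedule (the exact expression in this file may differ slightly)::
#
#     eps_threshold = EPS_END + (EPS_START - EPS_END) * \
#         math.exp(-1. * steps_done / EPS_DECAY)
#
# With EPS_START=0.9, EPS_END=0.07 and EPS_DECAY=300, about 90% of actions
# are random at the very start, decaying towards 7% as steps_done grows.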
@@ -376,10 +399,12 @@ def optimize_model():
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
-     # detailed explanation).
+     # detailed explanation). This converts batch-array of Transitions
+     # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))
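    # For illustration, with hypothetical transitions (s0, a0, s0', r0) and
    # (s1, a1, s1', r1), zip(*transitions) regroups them into (s0, s1),
    # (a0, a1), (s0', s1'), (r0, r1), so batch.state is (s0, s1), and so on.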

    # Compute a mask of non-final states and concatenate the batch elements
+     # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state
@@ -389,10 +414,15 @@ def optimize_model():
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
-     # columns of actions taken
+     # columns of actions taken. These are the actions which would've been taken
+     # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
+     # Expected values of actions for non_final_next_states are computed based
+     # on the "older" target_net; selecting their best reward with max(1)[0].
+     # This is merged based on the mask, such that we'll have either the expected
+     # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values
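    # (These follow the standard DQN / Bellman target, roughly
    #      Q(s, a) = r + GAMMA * max_a Q_target(s', a),
    #  with next_state_values already holding the max over the target net,
    #  or 0 for final states.)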
@@ -418,10 +448,11 @@ def optimize_model():
# fails), we restart the loop.
#
# Below, `num_episodes` is set small. You should download
- # the notebook and run lot more epsiodes.
+ # the notebook and run a lot more episodes, such as 300+ for meaningful
+ # duration improvements.
#

- num_episodes = 50
+ num_episodes = 500
for i_episode in range(num_episodes):
    # Initialize the environment and state
    env.reset()
@@ -454,7 +485,7 @@ def optimize_model():
            episode_durations.append(t + 1)
            plot_durations()
            break
-     # Update the target network
+     # Update the target network, copying all weights and biases in DQN
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

@@ -463,3 +494,16 @@ def optimize_model():
env.close()
plt.ioff()
plt.show()
+
+ ######################################################################
+ # Here is the diagram that illustrates the overall resulting flow.
+ #
+ # .. figure:: /_static/img/reinforcement_learning_diagram.jpg
+ #
+ # Actions are chosen either randomly or based on a policy, getting the next
+ # step sample from the gym environment. We record the results in the
+ # replay memory and also perform an optimization step on every iteration.
+ # Optimization picks a random batch from the replay memory to train the
+ # new policy. The "older" target_net, used in optimization to compute the
+ # expected Q values, is updated occasionally to keep it current.
+ #