

This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
- on the CartPole-v0 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
+ on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.

**Task**


The CartPole task is designed so that the inputs to the agent are 4 real
values representing the environment state (position, velocity, etc.).
- However, neural networks can solve the task purely by looking at the
- scene, so we'll use a patch of the screen centered on the cart as an
- input. Because of this, our results aren't directly comparable to the
- ones from the official leaderboard - our task is much harder.
- Unfortunately this does slow down the training, because we have to
- render all the frames.
+ We take these 4 inputs without any scaling and pass them through a
+ small fully-connected network with 2 outputs, one for each action.
+ The network is trained to predict the expected value for each action,
+ given the input state. The action with the highest expected value is
+ then chosen.

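As a quick illustrative sketch (not part of the tutorial code, and assuming
the Gym 0.26 ``reset`` API used below), a single state and the network's two
action-values look like this:

.. code-block:: python

   # hypothetical snippet -- ``env``, ``device`` and ``policy_net`` are
   # created later in this tutorial
   state, _ = env.reset()   # 4 floats: cart position, cart velocity,
                            # pole angle, pole angular velocity
   action_values = policy_net(
       torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0))
   # action_values has shape (1, 2): one expected value per action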
- Strictly speaking, we will present the state as the difference between
- the current screen patch and the previous one. This will allow the agent
- to take the velocity of the pole into account from one image.

**Packages**


First, let's import the needed packages. We need
`gym <https://github.com/openai/gym>`__ for the environment
+ (installed via `pip`). If you are running this in Google Colab, run:

.. code-block:: bash

- neural networks (``torch.nn``)
- optimization (``torch.optim``)
- automatic differentiation (``torch.autograd``)
- -  utilities for vision tasks (``torchvision`` - `a separate
-    package <https://github.com/pytorch/vision>`__).

"""

import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
- from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
- import torchvision.transforms as T

-
- if gym.__version__ < '0.26':
-     env = gym.make('CartPole-v0', new_step_api=True, render_mode='single_rgb_array').unwrapped
+ if gym.__version__[:4] == '0.26':
+     env = gym.make('CartPole-v1')
+ elif gym.__version__[:4] == '0.25':
+     env = gym.make('CartPole-v1', new_step_api=True)
else:
-     env = gym.make('CartPole-v0', render_mode='rgb_array').unwrapped
+     raise ImportError(f"Requires gym v0.25 or v0.26, actual version: {gym.__version__}")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
@@ -152,9 +146,11 @@ def __len__(self):
# :math:`R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t - t_0} r_t`, where
# :math:`R_{t_0}` is also known as the *return*. The discount,
# :math:`\gamma`, should be a constant between :math:`0` and :math:`1`
- # that ensures the sum converges. It makes rewards from the uncertain far
- # future less important for our agent than the ones in the near future
- # that it can be fairly confident about.
+ # that ensures the sum converges. A lower :math:`\gamma` makes
+ # rewards from the uncertain far future less important for our agent
+ # than the ones in the near future that it can be fairly confident
+ # about. It also encourages agents to collect reward closer in time
+ # than equivalent rewards that are temporally farther away.
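# As a rough numerical illustration (example values only): with
# :math:`\gamma = 0.99`, a reward received 100 steps in the future is weighted
# by :math:`0.99^{100} \approx 0.37`, while the same reward received 10 steps
# ahead is weighted by :math:`0.99^{10} \approx 0.90`.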
#
# The main idea behind Q-learning is that if we had a function
# :math:`Q^*: State \times Action \rightarrow \mathbb{R}`, that could tell
@@ -177,7 +173,7 @@ def __len__(self):
# The difference between the two sides of the equality is known as the
# temporal difference error, :math:`\delta`:
#
- # .. math:: \delta = Q(s, a) - (r + \gamma \max_a Q(s', a))
+ # .. math:: \delta = Q(s, a) - (r + \gamma \max_{a'} Q(s', a'))
#
# To minimise this error, we will use the `Huber
# loss <https://en.wikipedia.org/wiki/Huber_loss>`__. The Huber loss acts
@@ -211,86 +207,18 @@ def __len__(self):

class DQN(nn.Module):

-     def __init__(self, h, w, outputs):
+     def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
-         self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
-         self.bn1 = nn.BatchNorm2d(16)
-         self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
-         self.bn2 = nn.BatchNorm2d(32)
-         self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
-         self.bn3 = nn.BatchNorm2d(32)
-
-         # Number of Linear input connections depends on output of conv2d layers
-         # and therefore the input image size, so compute it.
-         def conv2d_size_out(size, kernel_size=5, stride=2):
-             return (size - (kernel_size - 1) - 1) // stride + 1
-         convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
-         convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
-         linear_input_size = convw * convh * 32
-         self.head = nn.Linear(linear_input_size, outputs)
+         self.layer1 = nn.Linear(n_observations, 128)
+         self.layer2 = nn.Linear(128, 128)
+         self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
-         x = x.to(device)
-         x = F.relu(self.bn1(self.conv1(x)))
-         x = F.relu(self.bn2(self.conv2(x)))
-         x = F.relu(self.bn3(self.conv3(x)))
-         return self.head(x.view(x.size(0), -1))
-
-
- ######################################################################
- # Input extraction
- # ^^^^^^^^^^^^^^^^
- #
- # The code below are utilities for extracting and processing rendered
- # images from the environment. It uses the ``torchvision`` package, which
- # makes it easy to compose image transforms. Once you run the cell it will
- # display an example patch that it extracted.
- #
-
- resize = T.Compose([T.ToPILImage(),
-                     T.Resize(40, interpolation=Image.CUBIC),
-                     T.ToTensor()])
-
-
- def get_cart_location(screen_width):
-     world_width = env.x_threshold * 2
-     scale = screen_width / world_width
-     return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART
-
- def get_screen():
-     # Returned screen requested by gym is 400x600x3, but is sometimes larger
-     # such as 800x1200x3. Transpose it into torch order (CHW).
-     screen = env.render().transpose((2, 0, 1))
-     # Cart is in the lower half, so strip off the top and bottom of the screen
-     _, screen_height, screen_width = screen.shape
-     screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
-     view_width = int(screen_width * 0.6)
-     cart_location = get_cart_location(screen_width)
-     if cart_location < view_width // 2:
-         slice_range = slice(view_width)
-     elif cart_location > (screen_width - view_width // 2):
-         slice_range = slice(-view_width, None)
-     else:
-         slice_range = slice(cart_location - view_width // 2,
-                             cart_location + view_width // 2)
-     # Strip off the edges, so that we have a square image centered on a cart
-     screen = screen[:, :, slice_range]
-     # Convert to float, rescale, convert to torch tensor
-     # (this doesn't require a copy)
-     screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
-     screen = torch.from_numpy(screen)
-     # Resize, and add a batch dimension (BCHW)
-     return resize(screen).unsqueeze(0)
-
-
- env.reset()
- plt.figure()
- plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
-            interpolation='none')
- plt.title('Example extracted screen')
- plt.show()
+         x = F.relu(self.layer1(x))
+         x = F.relu(self.layer2(x))
+         return self.layer3(x)

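######################################################################
# As a quick, optional sanity check (illustrative only; the real
# ``policy_net`` and ``target_net`` are created in the next section):
# CartPole has a 4-dimensional observation and 2 actions, so a forward
# pass on a dummy batch of states should return one value per action.

_check_net = DQN(4, 2)
print(_check_net(torch.zeros(1, 4)).shape)  # expected: torch.Size([1, 2])
del _check_net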

######################################################################
@@ -315,28 +243,35 @@ def get_screen():
# episode.
#

+ # BATCH_SIZE is the number of transitions sampled from the replay buffer
+ # GAMMA is the discount factor as mentioned in the previous section
+ # EPS_START is the starting value of epsilon
+ # EPS_END is the final value of epsilon
+ # EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
+ # TAU is the update rate of the target network
+ # LR is the learning rate of the AdamW optimizer
BATCH_SIZE = 128
- GAMMA = 0.999
+ GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
- EPS_DECAY = 200
- TARGET_UPDATE = 10
-
- # Get screen size so that we can initialize layers correctly based on shape
- # returned from AI gym. Typical dimensions at this point are close to 3x40x90
- # which is the result of a clamped and down-scaled render buffer in get_screen()
- init_screen = get_screen()
- _, _, screen_height, screen_width = init_screen.shape
+ EPS_DECAY = 1000
+ TAU = 0.005
+ LR = 1e-4

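######################################################################
# To get a feel for these exploration constants, the short sketch below
# (illustrative only, and not used by the training code) evaluates an
# exponentially decaying epsilon schedule of the form
# ``EPS_END + (EPS_START - EPS_END) * exp(-t / EPS_DECAY)`` at a few
# step counts; the ``select_action`` helper later in the tutorial relies
# on ``EPS_DECAY`` in the same spirit.

import math  # standard library; imported here so the snippet is self-contained

for _t in (0, 500, 1000, 5000):
    _eps = EPS_END + (EPS_START - EPS_END) * math.exp(-_t / EPS_DECAY)
    print(f"step {_t}: epsilon ~ {_eps:.3f}")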
# Get number of actions from gym action space
n_actions = env.action_space.n
-
- policy_net = DQN(screen_height, screen_width, n_actions).to(device)
- target_net = DQN(screen_height, screen_width, n_actions).to(device)
+ # Get the number of state observations
+ if gym.__version__[:4] == '0.26':
+     state, _ = env.reset()
+ elif gym.__version__[:4] == '0.25':
+     state, _ = env.reset(return_info=True)
+ n_observations = len(state)
+
+ policy_net = DQN(n_observations, n_actions).to(device)
+ target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
- target_net.eval()

- optimizer = optim.RMSprop(policy_net.parameters())
+ optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


@@ -356,14 +291,14 @@ def select_action(state):
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
-         return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
+         return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations():
-     plt.figure(2)
+     plt.figure(1)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
@@ -394,10 +329,9 @@ def plot_durations():
# :math:`V(s_{t+1}) = \max_a Q(s_{t+1}, a)`, and combines them into our
# loss. By definition we set :math:`V(s) = 0` if :math:`s` is a terminal
# state. We also use a target network to compute :math:`V(s_{t+1})` for
- # added stability. The target network has its weights kept frozen most of
- # the time, but is updated with the policy network's weights every so often.
- # This is usually a set number of steps but we shall use episodes for
- # simplicity.
+ # added stability. The target network is updated at every step with a
+ # `soft update <https://arxiv.org/pdf/1509.02971.pdf>`__ controlled by
+ # the hyperparameter ``TAU``, which was previously defined.
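#
# In code, a minimal sketch of one such soft update (a parameter-wise
# equivalent of the state-dict based version used in the training loop
# below) could look like:
#
# .. code-block:: python
#
#    with torch.no_grad():
#        for target_p, policy_p in zip(target_net.parameters(),
#                                      policy_net.parameters()):
#            target_p.mul_(1 - TAU).add_(TAU * policy_p)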
#

def optimize_model():
@@ -430,7 +364,8 @@ def optimize_model():
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
-     next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
+     with torch.no_grad():
+         next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

@@ -441,44 +376,49 @@ def optimize_model():
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
-     for param in policy_net.parameters():
-         param.grad.data.clamp_(-1, 1)
+     # In-place gradient clipping
+     torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
- # the environment and initialize the ``state`` Tensor. Then, we sample
- # an action, execute it, observe the next screen and the reward (always
+ # the environment and obtain the initial ``state`` Tensor. Then, we sample
+ # an action, execute it, observe the next state and the reward (always
# 1), and optimize our model once. When the episode ends (our model
# fails), we restart the loop.
#
- # Below, `num_episodes` is set small. You should download
- # the notebook and run lot more epsiodes, such as 300+ for meaningful
- # duration improvements.
+ # Below, `num_episodes` is set to 600 if a GPU is available, otherwise 50
+ # episodes are scheduled so training does not take too long. However, 50
+ # episodes is insufficient to observe good performance on CartPole.
+ # You should see the model consistently achieve 500 steps within 600 training
+ # episodes. Training RL agents can be a noisy process, so restarting training
+ # can produce better results if convergence is not observed.
#

- num_episodes = 50
+ if torch.cuda.is_available():
+     num_episodes = 600
+ else:
+     num_episodes = 50
+
for i_episode in range(num_episodes):
-     # Initialize the environment and state
-     env.reset()
-     last_screen = get_screen()
-     current_screen = get_screen()
-     state = current_screen - last_screen
+     # Initialize the environment and get its state
+     if gym.__version__[:4] == '0.26':
+         state, _ = env.reset()
+     elif gym.__version__[:4] == '0.25':
+         state, _ = env.reset(return_info=True)
+     state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
-         # Select and perform an action
        action = select_action(state)
-         _, reward, done, _, _ = env.step(action.item())
+         observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
+         done = terminated or truncated

-         # Observe new state
-         last_screen = current_screen
-         current_screen = get_screen()
-         if not done:
-             next_state = current_screen - last_screen
-         else:
+         if terminated:
            next_state = None
+         else:
+             next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)
@@ -488,18 +428,21 @@ def optimize_model():

        # Perform one step of the optimization (on the policy network)
        optimize_model()
+
+         # Soft update of the target network's weights
+         # θ′ ← τ θ + (1 − τ)θ′
+         target_net_state_dict = target_net.state_dict()
+         policy_net_state_dict = policy_net.state_dict()
+         for key in policy_net_state_dict:
+             target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
+         target_net.load_state_dict(target_net_state_dict)
+
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

-     # Update the target network, copying all weights and biases in DQN
-     if t % TARGET_UPDATE == 0:
-         target_net.load_state_dict(policy_net.state_dict())
-
print('Complete')
- env.render()
- env.close()
plt.ioff()
plt.show()

@@ -512,6 +455,6 @@ def optimize_model():
# step sample from the gym environment. We record the results in the
# replay memory and also run optimization step on every iteration.
# Optimization picks a random batch from the replay memory to do training of the
- # new policy. "Older" target_net is also used in optimization to compute the
- # expected Q values; it is updated occasionally to keep it current.
+ # new policy. The "older" target_net is also used in optimization to compute the
+ # expected Q values. A soft update of its weights is performed at every step.
#