@@ -3,17 +3,18 @@
 Reinforcement Learning (DQN) Tutorial
 =====================================
 **Author**: `Adam Paszke <https://github.com/apaszke>`_
+`Mark Towers <https://github.com/pseudo-rnd-thoughts>`_
 
 
 This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
-on the CartPole-v1 task from the `OpenAI Gym <https://www.gymlibrary.dev/>`__.
+on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.
 
 **Task**
 
 The agent has to decide between two actions - moving the cart left or
-right - so that the pole attached to it stays upright. You can find an
-official leaderboard with various algorithms and visualizations at the
-`Gym website <https://www.gymlibrary.dev/environments/classic_control/cart_pole>`__.
+right - so that the pole attached to it stays upright. You can find more
+information about the environment and other more challenging environments at
+`Gymnasium's website <https://gymnasium.farama.org/environments/classic_control/cart_pole/>`__.
 
 .. figure:: /_static/img/cartpole.gif
    :alt: cartpole
@@ -24,7 +25,7 @@
 an action, the environment *transitions* to a new state, and also
 returns a reward that indicates the consequences of the action. In this
 task, rewards are +1 for every incremental timestep and the environment
-terminates if the pole falls over too far or the cart moves more then 2.4
+terminates if the pole falls over too far or the cart moves more than 2.4
 units away from center. This means better performing scenarios will run
 for longer duration, accumulating larger return.
 
@@ -41,13 +42,15 @@
 
 
 First, let's import needed packages. Firstly, we need
-`gym <https://github.com/openai/gym>`__ for the environment
-Install by using `pip`. If you are running this in Google colab, run:
+`gymnasium <https://gymnasium.farama.org/>`__ for the environment,
+installed by using `pip`. This is a fork of the original OpenAI
+Gym project and maintained by the same team since Gym v0.19.
+If you are running this in Google Colab, run:
 
 .. code-block:: bash
 
    %%bash
-   pip3 install gym[classic_control]
+   pip3 install gymnasium[classic_control]
 
 We'll also use the following from PyTorch:
 
@@ -57,10 +60,9 @@
 
 """
 
-import gym
+import gymnasium as gym
 import math
 import random
-import numpy as np
 import matplotlib
 import matplotlib.pyplot as plt
 from collections import namedtuple, deque
@@ -71,12 +73,7 @@
 import torch.optim as optim
 import torch.nn.functional as F
 
-if gym.__version__[:4] == '0.26':
-    env = gym.make('CartPole-v1')
-elif gym.__version__[:4] == '0.25':
-    env = gym.make('CartPole-v1', new_step_api=True)
-else:
-    raise ImportError(f"Requires gym v25 or v26, actual version: {gym.__version__}")
+env = gym.make("CartPole-v1")
 
 # set up matplotlib
 is_ipython = 'inline' in matplotlib.get_backend()
@@ -117,7 +114,7 @@
 class ReplayMemory(object):
 
     def __init__(self, capacity):
-        self.memory = deque([],maxlen=capacity)
+        self.memory = deque([], maxlen=capacity)
 
     def push(self, *args):
         """Save a transition"""
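For context on the class this hunk touches: the diff only shows `__init__` and the start of `push`, but the tutorial's `ReplayMemory` also has `sample` and `__len__` methods. A self-contained sketch of the full class (the usage at the bottom is illustrative, not from the tutorial), showing how `deque(maxlen=capacity)` silently evicts the oldest transitions:

```python
import random
from collections import deque, namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):

    def __init__(self, capacity):
        # bounded buffer: appending past capacity drops the oldest entry
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        # uniform random minibatch for experience replay
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

memory = ReplayMemory(capacity=3)
for i in range(5):
    memory.push(i, 0, i + 1, 1.0)

# only the 3 most recent transitions survive
assert len(memory) == 3
batch = memory.sample(2)
assert isinstance(batch[0], Transition)
```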
@@ -261,10 +258,7 @@ def forward(self, x):
 # Get number of actions from gym action space
 n_actions = env.action_space.n
 # Get the number of state observations
-if gym.__version__[:4] == '0.26':
-    state, _ = env.reset()
-elif gym.__version__[:4] == '0.25':
-    state, _ = env.reset(return_info=True)
+state, info = env.reset()
 n_observations = len(state)
 
 policy_net = DQN(n_observations, n_actions).to(device)
@@ -286,7 +280,7 @@ def select_action(state):
     steps_done += 1
     if sample > eps_threshold:
         with torch.no_grad():
-            # t.max(1) will return largest column value of each row.
+            # t.max(1) will return the largest column value of each row.
             # second column on max result is index of where max element was
             # found, so we pick action with the larger expected reward.
             return policy_net(state).max(1)[1].view(1, 1)
@@ -410,10 +404,7 @@ def optimize_model():
 
 for i_episode in range(num_episodes):
     # Initialize the environment and get its state
-    if gym.__version__[:4] == '0.26':
-        state, _ = env.reset()
-    elif gym.__version__[:4] == '0.25':
-        state, _ = env.reset(return_info=True)
+    state, info = env.reset()
     state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
     for t in count():
         action = select_action(state)