File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 99This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
1010on the CartPole-v1 task from `Gymnasium <https://gymnasium.farama.org>`__.
1111
12+ You might find it helpful to read the original `Deep Q Learning (DQN) <https://arxiv.org/abs/1312.5602>`__ paper
13+
1214**Task**
1315
1416The agent has to decide between two actions - moving the cart left or
8385plt .ion ()
8486
8587# if GPU is to be used
86- device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
88+ device = torch .device (
89+ "cuda" if torch .cuda .is_available () else
90+ "mps" if torch .backends .mps .is_available () else
91+ "cpu"
92+ )
8793
8894
8995######################################################################
@@ -397,7 +403,7 @@ def optimize_model():
397403# can produce better results if convergence is not observed.
398404#
399405
400- if torch .cuda .is_available ():
406+ if torch .cuda .is_available () or torch . backends . mps . is_available () :
401407 num_episodes = 600
402408else :
403409 num_episodes = 50
You can’t perform that action at this time.
0 commit comments