|
24 | 24 | an action, the environment *transitions* to a new state, and also |
25 | 25 | returns a reward that indicates the consequences of the action. In this |
26 | 26 | task, rewards are +1 for every incremental timestep and the environment |
27 | | -terminates if the pole falls over too far or the crat mover more then 2.4 |
| 27 | +terminates if the pole falls over too far or the cart moves more than 2.4 |
28 | 28 | units away from center. This means better performing scenarios will run |
29 | 29 | for longer duration, accumulating larger return. |
30 | 30 |
|
@@ -249,14 +249,15 @@ def forward(self, x): |
249 | 249 | T.Resize(40, interpolation=Image.CUBIC), |
250 | 250 | T.ToTensor()]) |
251 | 251 |
|
| 252 | + |
252 | 253 | def get_cart_location(screen_width): |
253 | 254 | world_width = env.x_threshold * 2 |
254 | 255 | scale = screen_width / world_width |
255 | 256 | return int(env.state[0] * scale + screen_width / 2.0) # MIDDLE OF CART |
256 | 257 |
|
257 | 258 | def get_screen(): |
258 | | - # Returned requested by gym is 400x600x3, but is sometimes larger such as |
259 | | - # as 800x1200x3. Transpose into torch order (CHW). |
| 259 | + # The screen returned by gym is 400x600x3, but is sometimes larger |
| 260 | + # such as 800x1200x3. Transpose it into torch order (CHW). |
260 | 261 | screen = env.render(mode='rgb_array').transpose((2, 0, 1)) |
261 | 262 | # Cart is in the lower half, so strip off the top and bottom of the screen |
262 | 263 | _, screen_height, screen_width = screen.shape |
@@ -310,20 +311,18 @@ def get_screen(): |
310 | 311 | # episode. |
311 | 312 | # |
312 | 313 |
|
313 | | -BATCH_SIZE = 196 #128 |
| 314 | +BATCH_SIZE = 128 |
314 | 315 | GAMMA = 0.999 |
315 | 316 | EPS_START = 0.9 |
316 | | -EPS_END = 0.07 |
317 | | -EPS_DECAY = 300 |
| 317 | +EPS_END = 0.05 |
| 318 | +EPS_DECAY = 200 |
318 | 319 | TARGET_UPDATE = 10 |
319 | 320 |
|
320 | 321 | # Get screen size so that we can initialize layers correctly based on shape |
321 | | -# returned from AI gym. Typical dimentions at this pont are close to 3x40x90 |
322 | | -# which is the result of a clamped and down-scaled buffer in get_screen() |
| 322 | +# returned from AI gym. Typical dimensions at this point are close to 3x40x90 |
| 323 | +# which is the result of a clamped and down-scaled render buffer in get_screen() |
323 | 324 | init_screen = get_screen() |
324 | 325 | _, _, screen_height, screen_width = init_screen.shape |
325 | | -#screen_height = init_screen.shape[2] |
326 | | -#print("Screen size w,h:", screen_width, " ", screen_height) |
327 | 326 |
|
328 | 327 | policy_net = DQN(screen_height, screen_width).to(device) |
329 | 328 | target_net = DQN(screen_height, screen_width).to(device) |
@@ -452,7 +451,7 @@ def optimize_model(): |
452 | 451 | # duration improvements. |
453 | 452 | # |
454 | 453 |
|
455 | | -num_episodes = 500 |
| 454 | +num_episodes = 50 |
456 | 455 | for i_episode in range(num_episodes): |
457 | 456 | # Initialize the environment and state |
458 | 457 | env.reset() |
@@ -496,14 +495,14 @@ def optimize_model(): |
496 | 495 | plt.show() |
497 | 496 |
|
498 | 497 | ###################################################################### |
499 | | -# Here is the diagram that illustrates the overall resulting flow. |
| 498 | +# Here is the diagram that illustrates the overall resulting data flow. |
500 | 499 | # |
501 | 500 | # .. figure:: /_static/img/reinforcement_learning_diagram.jpg |
502 | 501 | # |
503 | 502 | # Actions are chosen either randomly or based on a policy, getting the next |
504 | | -# step sample for the gym environment. We record the results in the |
505 | | -# replay memory and also perform optimization step on every iteration. |
| 503 | +# step sample from the gym environment. We record the results in the |
| 504 | +# replay memory and also run an optimization step on every iteration. |
506 | 505 | # Optimization picks a random batch from the replay memory to do training of the |
507 | | -# new policy. "Older" target_net, used in optimization to computed expected |
508 | | -# Q values is updated occasionally to keep it current. |
| 506 | +# new policy. "Older" target_net is also used in optimization to compute the |
| 507 | +# expected Q values; it is updated occasionally to keep it current. |
509 | 508 | # |
0 commit comments