File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -283,7 +283,7 @@ def select_action(state):
283283 # t.max(1) returns the largest value in each row together with
284284 # its index. The ``indices`` field of the result is where each max
285285 # element was found, so we pick the action with the largest expected reward.
286- return policy_net (state ).max (1 )[ 1 ] .view (1 , 1 )
286+ return policy_net (state ).max (1 ). indices .view (1 , 1 )
287287 else :
288288 return torch .tensor ([[env .action_space .sample ()]], device = device , dtype = torch .long )
289289
@@ -360,12 +360,12 @@ def optimize_model():
360360
361361 # Compute V(s_{t+1}) for all next states.
362362 # Expected values of actions for non_final_next_states are computed based
363- # on the "older" target_net; selecting their best reward with max(1)[0].
363+ # on the "older" target_net; selecting their best reward with max(1).values.
364364 # This is merged based on the mask, such that we'll have either the expected
365365 # state value or 0 in case the state was final.
366366 next_state_values = torch .zeros (BATCH_SIZE , device = device )
367367 with torch .no_grad ():
368- next_state_values [non_final_mask ] = target_net (non_final_next_states ).max (1 )[ 0 ]
368+ next_state_values [non_final_mask ] = target_net (non_final_next_states ).max (1 ). values
369369 # Compute the expected Q values
370370 expected_state_action_values = (next_state_values * GAMMA ) + reward_batch
371371
You can’t perform that action at this time.
0 commit comments