@@ -530,6 +530,11 @@ def select_action(self, state):
         actions, noises = self.act.get_action(states)  # plan to be get_action_a_noise
         return actions[0].detach().cpu().numpy(), noises[0].detach().cpu().numpy()
 
+    def select_actions(self, states):
+        # states = torch.as_tensor((state,), dtype=torch.float32, device=self.device)
+        actions, noises = self.act.get_action(states)  # plan to be get_action_a_noise
+        return actions, noises
+
     def explore_env(self, env, target_step, reward_scale, gamma):
         trajectory_list = list()
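The new select_actions is the batched counterpart of select_action: it expects states already stacked as a (env_num, state_dim) tensor on self.device and returns the action and noise tensors directly, without the index-0/numpy conversion. A minimal sketch of that shape contract, using a stand-in stochastic actor (TinyStochasticActor below is illustrative, not the repository's ActorPPO):

import torch
import torch.nn as nn

class TinyStochasticActor(nn.Module):  # stand-in for the real actor network
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Linear(state_dim, action_dim)
        self.a_std_log = nn.Parameter(torch.zeros(1, action_dim) - 0.5)

    def get_action(self, state):  # returns (action, noise), both shaped (batch, action_dim)
        a_avg = self.net(state)
        a_std = self.a_std_log.exp()
        noise = torch.randn_like(a_avg)
        return a_avg + noise * a_std, noise

state_dim, action_dim, env_num = 8, 2, 4
act = TinyStochasticActor(state_dim, action_dim)
states = torch.rand((env_num, state_dim))      # one state per parallel env
actions, noises = act.get_action(states)       # select_actions would return these as-is
print(actions.shape, noises.shape)             # torch.Size([4, 2]) torch.Size([4, 2])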
@@ -544,20 +549,97 @@ def explore_env(self, env, target_step, reward_scale, gamma):
         self.state = state
         return trajectory_list
 
+    def explore_envs(self, env, target_step, reward_scale, gamma):
+        state = self.state
+        env_num = env.env_num
+
+        buf_step = target_step // env_num
+        states = torch.empty((buf_step, env_num, env.state_dim), dtype=torch.float32, device=self.device)
+        actions = torch.empty((buf_step, env_num, env.action_dim), dtype=torch.float32, device=self.device)
+        noises = torch.empty((buf_step, env_num, env.action_dim), dtype=torch.float32, device=self.device)
+        rewards = torch.empty((buf_step, env_num), dtype=torch.float32, device=self.device)
+        dones = torch.empty((buf_step, env_num), dtype=torch.float32, device=self.device)
+        for i in range(buf_step):
+            action, noise = self.select_actions(state)
+            next_s, reward, done, _ = env.step(action.tanh())
+            # other = (reward * reward_scale, 0.0 if done else gamma, *action, *noise)
+            # trajectory_list.append((state, other))
+
+            states[i] = state
+            actions[i] = action
+            noises[i] = noise
+            rewards[i] = reward
+            dones[i] = done
+
+            # state = env.reset() if done else next_s
+            state = next_s
+        self.state = state
+        rewards = rewards * reward_scale
+        masks = (1 - dones) * gamma
+        return states, rewards, masks, actions, noises
+
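explore_envs replaces explore_env's per-step Python list with tensors preallocated as (buf_step, env_num, ...) and filled in place, then converts rewards and dones into scaled rewards and discount masks in one vectorized pass at the end. A small self-contained sketch of that storage pattern, with RandomVecEnv as a purely illustrative stand-in for a vectorized environment:

import torch

class RandomVecEnv:  # toy stand-in for a vectorized environment
    def __init__(self, env_num=4, state_dim=8, action_dim=2):
        self.env_num, self.state_dim, self.action_dim = env_num, state_dim, action_dim

    def reset(self):
        return torch.rand((self.env_num, self.state_dim))

    def step(self, action):  # returns batched (next_state, reward, done, info)
        next_s = torch.rand((self.env_num, self.state_dim))
        reward = torch.rand(self.env_num)
        done = (torch.rand(self.env_num) < 0.01).float()
        return next_s, reward, done, None

env = RandomVecEnv()
buf_step, reward_scale, gamma = 16, 1.0, 0.99
states = torch.empty((buf_step, env.env_num, env.state_dim))
rewards = torch.empty((buf_step, env.env_num))
dones = torch.empty((buf_step, env.env_num))

state = env.reset()
for i in range(buf_step):
    action = torch.rand((env.env_num, env.action_dim)) * 2 - 1  # placeholder policy
    next_s, reward, done, _ = env.step(action)
    states[i], rewards[i], dones[i] = state, reward, done
    state = next_s

masks = (1 - dones) * gamma                       # 0 where an episode ended, gamma elsewhere
rewards = rewards * reward_scale
print(states.shape, rewards.shape, masks.shape)   # (16, 4, 8) (16, 4) (16, 4)

The real explore_envs also stores actions and noises so that prepare_buffers can recompute the old log-probabilities later without another rollout.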
+    def prepare_buffer(self, buffer):
+        buffer.update_now_len()
+        buf_len = buffer.now_len
+        with torch.no_grad():  # compute reverse reward
+            reward, mask, action, a_noise, state = buffer.sample_all()
+
+            # print(';', [t.shape for t in (reward, mask, action, a_noise, state)])
+            bs = 2 ** 10  # set a smaller 'BatchSize' when out of GPU memory.
+            value = torch.cat([self.cri_target(state[i:i + bs]) for i in range(0, state.size(0), bs)], dim=0).squeeze(1)
+            logprob = self.act.get_old_logprob(action, a_noise)
+
+            pre_state = torch.as_tensor((self.state,), dtype=torch.float32, device=self.device)
+            pre_r_sum = self.cri_target(pre_state).detach()
+            r_sum, advantage = self.get_reward_sum(buf_len, reward, mask, value, pre_r_sum)
+        buffer.empty_buffer()
+        return state, action, r_sum, logprob, advantage
+
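The bs = 2 ** 10 chunking keeps peak GPU memory bounded when the critic evaluates the whole rollout at once: states are sliced into batches of at most 1024 rows and the per-chunk outputs are stitched back together with torch.cat. A minimal standalone sketch of the pattern (the Linear critic is a placeholder for self.cri_target):

import torch
import torch.nn as nn

critic = nn.Linear(8, 1)       # placeholder critic
state = torch.rand((5000, 8))  # whole-rollout states
bs = 2 ** 10                   # chunk size; lower it when GPU memory runs out

with torch.no_grad():
    value = torch.cat([critic(state[i:i + bs]) for i in range(0, state.size(0), bs)], dim=0).squeeze(1)

print(value.shape)             # torch.Size([5000]); same values as critic(state).squeeze(1)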
+    def prepare_buffers(self, buffer):
+        with torch.no_grad():  # compute reverse reward
+            states, rewards, masks, actions, noises = buffer
+            buf_len = states.size(0)
+            env_num = states.size(1)
+
+            values = torch.empty_like(rewards)
+            logprobs = torch.empty_like(rewards)
+            bs = 2 ** 10  # set a smaller 'BatchSize' when out of GPU memory.
+            for j in range(env_num):
+                for i in range(0, buf_len, bs):
+                    values[i:i + bs, j] = self.cri_target(states[i:i + bs, j]).squeeze(1)
+                logprobs[:, j] = self.act.get_old_logprob(actions[:, j], noises[:, j]).squeeze(1)
+
+            pre_states = torch.as_tensor(self.state, dtype=torch.float32, device=self.device)
+            pre_r_sums = self.cri_target(pre_states).detach().squeeze(1)
+
+            r_sums, advantages = self.get_reward_sum((buf_len, env_num), rewards, masks, values, pre_r_sums)
+
+        buf_len_vec = buf_len * env_num
+
+        states = states.view((buf_len_vec, -1))
+        actions = actions.view((buf_len_vec, -1))
+        r_sums = r_sums.view(buf_len_vec)
+        logprobs = logprobs.view(buf_len_vec)
+        advantages = advantages.view(buf_len_vec)
+        return states, actions, r_sums, logprobs, advantages
+
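Once the per-environment returns and advantages are computed, the (buf_step, env_num, ...) tensors are flattened so update_net can sample minibatches uniformly over both time steps and environments. A short sketch of those reshapes with made-up sizes:

import torch

buf_step, env_num, state_dim, action_dim = 16, 4, 8, 2
states = torch.rand((buf_step, env_num, state_dim))
actions = torch.rand((buf_step, env_num, action_dim))
r_sums = torch.rand((buf_step, env_num))

buf_len_vec = buf_step * env_num
states = states.view((buf_len_vec, -1))    # (64, 8): one row per (step, env) pair
actions = actions.view((buf_len_vec, -1))  # (64, 2)
r_sums = r_sums.view(buf_len_vec)          # (64,)
print(states.shape, actions.shape, r_sums.shape)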
 
     def update_net(self, buffer, batch_size, repeat_times, soft_update_tau):
         if isinstance(buffer, list):
             buffer_tuple = list(map(list, zip(*buffer)))  # 2D-list transpose
             (buf_state, buf_action, buf_r_sum, buf_logprob, buf_advantage
              ) = [torch.cat(tensor_list, dim=0).to(self.device)
                   for tensor_list in buffer_tuple]
-
+        elif isinstance(buffer, tuple):
+            (buf_state, buf_action, buf_r_sum, buf_logprob, buf_advantage
+             ) = buffer
         else:
             (buf_state, buf_action, buf_r_sum, buf_logprob, buf_advantage
-             ) = self.prepare_buffer(buffer, self.state)
+             ) = self.prepare_buffer(buffer)
         buf_len = buf_state.size(0)
 
         '''PPO: Surrogate objective of Trust Region'''
         obj_critic = obj_actor = old_logprob = None
+        r_sum_std = 1  # todo buf_r_sum.std() + 1e-6
         for _ in range(int(buf_len / batch_size * repeat_times)):
             indices = torch.randint(buf_len, size=(batch_size,), requires_grad=False, device=self.device)
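The isinstance(buffer, list) branch handles results gathered from several workers: each worker contributes a (state, action, r_sum, logprob, advantage) tuple, zip(*buffer) transposes the list of tuples into per-field lists, and torch.cat joins each field along dim 0. A tiny standalone illustration of that transpose-and-concatenate step (worker count and tensor sizes are made up):

import torch

# two hypothetical workers, each returning (state, action) tensors of different lengths
worker_0 = (torch.rand(3, 8), torch.rand(3, 2))
worker_1 = (torch.rand(5, 8), torch.rand(5, 2))
buffer = [worker_0, worker_1]

buffer_tuple = list(map(list, zip(*buffer)))  # 2D-list transpose: fields become rows
state, action = [torch.cat(tensor_list, dim=0) for tensor_list in buffer_tuple]
print(state.shape, action.shape)              # torch.Size([8, 8]) torch.Size([8, 2])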
@@ -576,60 +658,47 @@ def update_net(self, buffer, batch_size, repeat_times, soft_update_tau):
             self.optim_update(self.act_optim, obj_actor)
 
             value = self.cri(state).squeeze(1)  # critic network predicts the reward_sum (Q value) of state
-            obj_critic = self.criterion(value, r_sum)  # / (r_sum.std() + 1e-6)
+            obj_critic = self.criterion(value, r_sum) / r_sum_std
             self.optim_update(self.cri_optim, obj_critic)
             self.soft_update(self.cri_target, self.cri, soft_update_tau) if self.cri_target is not self.cri else None
 
         return obj_critic.item(), obj_actor.item(), old_logprob.mean().item()  # logging_tuple
 
-    def prepare_buffer(self, buffer, state_ary):
-        buffer.update_now_len()
-        buf_len = buffer.now_len
-        with torch.no_grad():  # compute reverse reward
-            reward, mask, action, a_noise, state = buffer.sample_all()
-
-            # print(';', [t.shape for t in (reward, mask, action, a_noise, state)])
-            bs = 2 ** 10  # set a smaller 'BatchSize' when out of GPU memory.
-            value = torch.cat([self.cri_target(state[i:i + bs]) for i in range(0, state.size(0), bs)], dim=0)
-            logprob = self.act.get_old_logprob(action, a_noise)
-
-            pre_state = torch.as_tensor((state_ary,), dtype=torch.float32, device=self.device)
-            pre_r_sum = self.cri(pre_state).detach()
-            r_sum, advantage = self.get_reward_sum(buf_len, reward, mask, value, pre_r_sum)
-        buffer.empty_buffer()
-        return state, action, r_sum, logprob, advantage
-
 
     def get_reward_sum_raw(self, buf_len, buf_reward, buf_mask, buf_value, pre_r_sum) -> (torch.Tensor, torch.Tensor):
         """compute the expected discounted episode return
 
         :int buf_len: the length of ReplayBuffer
-        :torch.Tensor buf_reward: buf_reward.shape==(buf_len, 1)
-        :torch.Tensor buf_mask: buf_mask.shape ==(buf_len, 1)
-        :torch.Tensor buf_value: buf_value.shape ==(buf_len, 1)
-        :return torch.Tensor buf_r_sum: buf_r_sum.shape ==(buf_len, 1)
+        :torch.Tensor buf_reward: buf_reward.shape==(buf_len, )
+        :torch.Tensor buf_mask: buf_mask.shape ==(buf_len, )
+        :torch.Tensor buf_value: buf_value.shape ==(buf_len, )
+        :torch.Tensor pre_r_sum: pre_r_sum.shape ==(1, 1)
+        :return torch.Tensor buf_r_sum: buf_r_sum.shape ==(buf_len, 1)
         :return torch.Tensor buf_advantage: buf_advantage.shape ==(buf_len, 1)
         """
         buf_r_sum = torch.empty(buf_len, dtype=torch.float32, device=self.device)  # reward sum
-        for i in range(buf_len - 1, -1, -1):
+        the_len = buf_len[0] if isinstance(buf_len, tuple) else buf_len
+        for i in range(the_len - 1, -1, -1):
             buf_r_sum[i] = buf_reward[i] + buf_mask[i] * pre_r_sum
             pre_r_sum = buf_r_sum[i]
-        buf_advantage = buf_r_sum - (buf_mask * buf_value.squeeze(1))
-        buf_advantage = (buf_advantage - buf_advantage.mean()) / (buf_advantage.std() + 1e-5)
+        buf_advantage = buf_r_sum - buf_mask * buf_value
+        buf_advantage = (buf_advantage - buf_advantage.mean())  # todo / (buf_advantage.std() + 1e-5)
         return buf_r_sum, buf_advantage
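The backward loop computes r_sum[i] = reward[i] + mask[i] * r_sum[i + 1], bootstrapped from pre_r_sum, the critic's estimate for the state reached after the rollout. A tiny worked example with three steps, gamma = 0.9 folded into the masks and a bootstrap value of 10:

import torch

buf_reward = torch.tensor([1.0, 2.0, 3.0])
buf_mask = torch.tensor([0.9, 0.9, 0.9])   # (1 - done) * gamma; no episode ends here
pre_r_sum = torch.tensor(10.0)             # critic estimate of the state after the last step

buf_r_sum = torch.empty(3)
for i in range(2, -1, -1):                 # backward over the rollout
    buf_r_sum[i] = buf_reward[i] + buf_mask[i] * pre_r_sum
    pre_r_sum = buf_r_sum[i]

print(buf_r_sum)                           # tensor([12.5200, 12.8000, 12.0000])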
     def get_reward_sum_gae(self, buf_len, buf_reward, buf_mask, buf_value, pre_r_sum) -> (torch.Tensor, torch.Tensor):
         buf_r_sum = torch.empty(buf_len, dtype=torch.float32, device=self.device)  # old policy value
         buf_advantage = torch.empty(buf_len, dtype=torch.float32, device=self.device)  # advantage value
 
-        pre_advantage = pre_r_sum * (np.exp(self.lambda_gae_adv - 0.4) - 1)  # advantage value of previous step
-        for i in range(buf_len - 1, -1, -1):
+        pre_advantage = pre_r_sum * (np.exp(self.lambda_gae_adv - 0.5) - 1)  # advantage value of previous step
+
+        the_len = buf_len[0] if isinstance(buf_len, tuple) else buf_len
+        for i in range(the_len - 1, -1, -1):
             buf_r_sum[i] = buf_reward[i] + buf_mask[i] * pre_r_sum
             pre_r_sum = buf_r_sum[i]
 
             buf_advantage[i] = buf_reward[i] + buf_mask[i] * (pre_advantage - buf_value[i])  # fix a bug here
             pre_advantage = buf_value[i] + buf_advantage[i] * self.lambda_gae_adv
 
-        buf_advantage = (buf_advantage - buf_advantage.mean()) / (buf_advantage.std() + 1e-5)
+        buf_advantage = (buf_advantage - buf_advantage.mean())  # todo / (buf_advantage.std() + 1e-5)
         return buf_r_sum, buf_advantage
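For reference, the textbook GAE recursion is advantage[i] = delta[i] + gamma * lambda * (1 - done[i]) * advantage[i + 1] with delta[i] = reward[i] + gamma * (1 - done[i]) * V(s[i + 1]) - V(s[i]). The method above keeps its own bookkeeping (the masks already fold gamma and termination together, and the "fix a bug here" line differs from the standard form), so the two are not line-for-line identical. A minimal standalone version of the standard recursion, with placeholder numbers:

import torch

gamma, lambda_gae = 0.99, 0.95
reward = torch.tensor([1.0, 1.0, 1.0])
value = torch.tensor([0.5, 0.6, 0.7])  # V(s_0..s_2), placeholder critic outputs
next_value = 0.8                       # bootstrap V(s_3)
not_done = torch.ones(3)               # 1.0 while the episode continues

advantage = torch.empty(3)
next_adv = 0.0
for i in range(2, -1, -1):             # backward over the rollout
    delta = reward[i] + gamma * not_done[i] * next_value - value[i]
    next_adv = delta + gamma * lambda_gae * not_done[i] * next_adv
    advantage[i] = next_adv
    next_value = value[i]

r_sum = advantage + value              # returns used as the critic regression target
print(advantage, r_sum)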