Merged
Changes from all commits
slm_lab/agent/algorithm/policy_util.py (10 additions, 13 deletions)

@@ -70,35 +70,35 @@ def __init__(self, probs=None, logits=None, validate_args=None):

     @property
     def logits(self):
-        return torch.tensor([cat.logits for cat in self.categoricals])
+        return [cat.logits for cat in self.categoricals]

     @property
     def probs(self):
-        return torch.tensor([cat.probs for cat in self.categoricals])
+        return [cat.probs for cat in self.categoricals]

     @property
     def param_shape(self):
-        return torch.tensor([cat.param_shape for cat in self.categoricals])
+        return [cat.param_shape for cat in self.categoricals]

     @property
     def mean(self):
-        return torch.tensor([cat.mean for cat in self.categoricals])
+        return torch.stack([cat.mean for cat in self.categoricals])

     @property
     def variance(self):
-        return torch.tensor([cat.variance for cat in self.categoricals])
+        return torch.stack([cat.variance for cat in self.categoricals])

     def sample(self, sample_shape=torch.Size()):
-        return torch.tensor([cat.sample(sample_shape=sample_shape) for cat in self.categoricals])
+        return torch.stack([cat.sample(sample_shape=sample_shape) for cat in self.categoricals])

     def log_prob(self, value):
-        return torch.tensor([cat.log_prob(value[idx]) for idx, cat in enumerate(self.categoricals)])
+        return torch.stack([cat.log_prob(value[idx]) for idx, cat in enumerate(self.categoricals)])

     def entropy(self):
-        return torch.tensor([cat.entropy() for cat in self.categoricals])
+        return torch.stack([cat.entropy() for cat in self.categoricals])

     def enumerate_support(self):
-        return torch.tensor([cat.enumerate_support() for cat in self.categoricals])
+        return [cat.enumerate_support() for cat in self.categoricals]


 setattr(distributions, 'Argmax', Argmax)
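
The switch from torch.tensor(...) to torch.stack(...) in these properties matters for autograd: torch.tensor copies values into a new leaf tensor and detaches them from the graph, while torch.stack joins the component tensors along a new dimension and keeps their grad_fn, so gradients can still flow back to the network outputs. A minimal sketch of the difference (the logits and values below are made up, not SLM-Lab code):

```python
import torch
from torch.distributions import Categorical

# Hypothetical per-component logits with gradients enabled.
logits = torch.zeros(2, 3, requires_grad=True)
cats = [Categorical(logits=row) for row in logits]
values = torch.tensor([0, 1])

# torch.stack keeps the result attached to the autograd graph.
stacked = torch.stack([cat.log_prob(v) for cat, v in zip(cats, values)])
print(stacked.requires_grad)  # True: backward() can reach `logits`

# Rebuilding via torch.tensor copies the values and drops the graph.
copied = torch.tensor([cat.log_prob(v).item() for cat, v in zip(cats, values)])
print(copied.requires_grad)  # False: detached from `logits`
```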
@@ -383,11 +383,8 @@ def calc_log_probs(algorithm, net, body, batch):
         if not is_multi_action:  # already cloned for multi_action above
             pdparam = pdparam.clone()  # clone for grad safety
         _action, action_pd = sample_action_pd(ActionPD, pdparam, body)
-        log_probs.append(action_pd.log_prob(actions[idx]))
+        log_probs.append(action_pd.log_prob(actions[idx].float()).sum(dim=0))
     log_probs = torch.stack(log_probs)
-    if is_multi_action:
-        log_probs = log_probs.mean(dim=1)
-    log_probs = torch.tensor(log_probs, requires_grad=True)
     assert not torch.isnan(log_probs).any(), f'log_probs: {log_probs}, \npdparams: {pdparams} \nactions: {actions}'
     logger.debug(f'log_probs: {log_probs}')
     return log_probs
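
The added .sum(dim=0) on log_prob computes the joint log-probability of a multi-component action: the patched distribution above stacks one log-probability per component, and for independent components the joint log-probability is their sum. A small sketch with made-up probabilities:

```python
import torch
from torch.distributions import Categorical

# Hypothetical two-component discrete action (probabilities invented).
probs = torch.tensor([[0.2, 0.5, 0.3],
                      [0.1, 0.6, 0.3]])
cats = [Categorical(probs=p) for p in probs]
action = torch.tensor([1, 2])  # one sub-action per component

# One log-prob per component, stacked along dim 0 as in the patched log_prob().
per_component = torch.stack([cat.log_prob(a) for cat, a in zip(cats, action)])

# Independent components: joint log-probability is the sum over dim 0.
joint = per_component.sum(dim=0)
assert torch.isclose(joint, torch.log(torch.tensor(0.5 * 0.3)))
```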
slm_lab/agent/algorithm/reinforce.py (3 additions, 2 deletions)

@@ -122,8 +122,9 @@ def calc_pdparam(self, x, evaluate=True, net=None):
     @lab_api
     def body_act(self, body, state):
         action, action_pd = self.action_policy(state, self, body)
-        body.entropies.append(action_pd.entropy())
-        body.log_probs.append(action_pd.log_prob(action.float()))
+        # sum for single and multi-action
+        body.entropies.append(action_pd.entropy().sum(dim=0))
+        body.log_probs.append(action_pd.log_prob(action.float()).sum(dim=0))
         assert not torch.isnan(body.log_probs[-1])
         if len(action.shape) == 0:  # scalar
             return action.cpu().numpy().astype(body.action_space.dtype).item()
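
The entropy change follows the same logic: for independent action components the joint entropy is additive, and per the PR's comment the single .sum(dim=0) also covers the plain single-action case. A short sketch (numbers invented):

```python
import torch
from torch.distributions import Categorical

# Two made-up independent components; their joint entropy is additive.
probs = torch.tensor([[0.25, 0.75],
                      [0.50, 0.50]])
per_component = torch.stack([Categorical(probs=p).entropy() for p in probs])
joint_entropy = per_component.sum(dim=0)  # H(component 1) + H(component 2)
print(joint_entropy)
```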
slm_lab/agent/algorithm/sarsa.py (3 additions, 2 deletions)

@@ -110,8 +110,9 @@ def calc_pdparam(self, x, evaluate=True, net=None):
     def body_act(self, body, state):
         '''Note, SARSA is discrete-only'''
         action, action_pd = self.action_policy(state, self, body)
-        body.entropies.append(action_pd.entropy())
-        body.log_probs.append(action_pd.log_prob(action.float()))
+        # sum for single and multi-action
+        body.entropies.append(action_pd.entropy().sum(dim=0))
+        body.log_probs.append(action_pd.log_prob(action.float()).sum(dim=0))
         assert not torch.isnan(body.log_probs[-1])
         if len(action.shape) == 0:  # scalar
             return action.cpu().numpy().astype(body.action_space.dtype).item()
slm_lab/agent/net/convnet.py (1 addition, 1 deletion)

@@ -209,7 +209,7 @@ def training_step(self, x=None, y=None, loss=None, retain_graph=False):
         loss.backward(retain_graph=retain_graph)
         if self.clip_grad:
             logger.debug(f'Clipping gradient: {self.clip_grad_val}')
-            torch.nn.utils.clip_grad_norm(self.parameters(), self.clip_grad_val)
+            torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
         self.optim.step()
         if net_util.to_assert_trained():
             assert_trained(self.conv_model)
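
clip_grad_norm was deprecated in PyTorch in favor of clip_grad_norm_, whose trailing underscore follows the in-place naming convention; the call still rescales all gradients so their combined norm stays at or below the given max_norm. A standalone sketch of the updated call (the model and data are placeholders):

```python
import torch
import torch.nn as nn

# Placeholder model and loss, just to produce gradients.
model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# In-place clipping: gradients are rescaled so their total norm <= max_norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
print(total_norm)  # the pre-clipping norm returned by the call

torch.optim.SGD(model.parameters(), lr=0.01).step()
```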
slm_lab/agent/net/mlp.py (2 additions, 2 deletions)

@@ -127,7 +127,7 @@ def training_step(self, x=None, y=None, loss=None, retain_graph=False):
         loss.backward(retain_graph=retain_graph)
         if self.clip_grad:
             logger.debug(f'Clipping gradient: {self.clip_grad_val}')
-            torch.nn.utils.clip_grad_norm(self.parameters(), self.clip_grad_val)
+            torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
         self.optim.step()
         if net_util.to_assert_trained():
             model = getattr(self, 'model', None) or getattr(self, 'model_body')

@@ -409,7 +409,7 @@ def training_step(self, xs=None, ys=None, loss=None, retain_graph=False):
         loss.backward(retain_graph=retain_graph)
         if self.clip_grad:
             logger.debug(f'Clipping gradient: {self.clip_grad_val}')
-            torch.nn.utils.clip_grad_norm(self.parameters(), self.clip_grad_val)
+            torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
         self.optim.step()
         if net_util.to_assert_trained():
             assert_trained(self.model_body)
slm_lab/agent/net/net_util.py (2 additions, 0 deletions)

@@ -146,6 +146,8 @@ def gen_assert_trained(pre_model):
     def assert_trained(post_model):
         post_weights = [param.clone() for param in post_model.parameters()]
         assert not all(torch.equal(w1, w2) for w1, w2 in zip(pre_weights, post_weights)), 'Model parameter is not updated in training_step(), check if your tensor is detached from graph.'
+        assert all(param.grad.norm() < 100.0 for param in post_model.parameters()), 'Gradient norm is > 100, which is bad. Consider using the "clip_grad" and "clip_grad_val" net parameters'
+        logger.info('Passed network weight update assertion in dev lab_mode.')
     return assert_trained
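
For context on the new check: gen_assert_trained snapshots the weights before an update and returns a closure that runs after optim.step() in dev mode; this PR adds a second assertion that no gradient norm exceeds 100.0 and points the user to the clip_grad and clip_grad_val net settings when it trips. A hedged usage sketch (the toy network and loop are stand-ins, not SLM-Lab's actual training code):

```python
import torch
import torch.nn as nn
from slm_lab.agent.net import net_util

# Stand-in network and optimizer for illustration only.
net = nn.Linear(4, 2)
optim = torch.optim.SGD(net.parameters(), lr=0.1)

assert_trained = net_util.gen_assert_trained(net)  # snapshot pre-update weights
loss = net(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optim.step()
# Raises if no parameter changed, or if any gradient norm is >= 100.0.
assert_trained(net)
```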