test/test_cost.py: 28 changes (25 additions, 3 deletions)
@@ -8783,6 +8783,7 @@ def test_ppo(
             value,
             loss_critic_type="l2",
             functional=functional,
+            device=device,
         )
         if composite_action_dist:
             loss_fn.set_keys(
@@ -8883,6 +8884,7 @@ def test_ppo_composite_no_aggregate(
             value,
             loss_critic_type="l2",
             functional=functional,
+            device=device,
         )
         loss_fn.set_keys(
             action=("action", "action1"),
@@ -8943,9 +8945,19 @@ def test_ppo_state_dict(
             device=device, composite_action_dist=composite_action_dist
         )
         value = self._create_mock_value(device=device)
-        loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            device=device,
+        )
         sd = loss_fn.state_dict()
-        loss_fn2 = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn2 = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            device=device,
+        )
         loss_fn2.load_state_dict(sd)

     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@@ -8993,6 +9005,7 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
             value,
             loss_critic_type="l2",
             separate_losses=True,
+            device=device,
         )

         if advantage is not None:
@@ -9100,6 +9113,7 @@ def test_ppo_shared_seq(
             loss_critic_type="l2",
             separate_losses=separate_losses,
             entropy_coef=0.0,
+            device=device,
         )

         loss_fn2 = loss_class(
@@ -9108,6 +9122,7 @@
             loss_critic_type="l2",
             separate_losses=separate_losses,
             entropy_coef=0.0,
+            device=device,
         )

         if advantage is not None:
@@ -9202,7 +9217,12 @@ def test_ppo_diff(
         else:
             raise NotImplementedError

-        loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        loss_fn = loss_class(
+            actor,
+            value,
+            loss_critic_type="l2",
+            device=device,
+        )

         params = TensorDict.from_module(loss_fn, as_module=True)

@@ -9595,6 +9615,7 @@ def test_ppo_value_clipping(
                 value,
                 loss_critic_type="l2",
                 clip_value=clip_value,
+                device=device,
             )

         else:
@@ -9603,6 +9624,7 @@
                 value,
                 loss_critic_type="l2",
                 clip_value=clip_value,
+                device=device,
             )
         advantage(td)
         if composite_action_dist:
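For context, a minimal sketch of the pattern these tests exercise: constructing a PPO loss with an explicit `device` keyword. The actor/critic setup below follows the torchrl docstring-style examples rather than the repo's `_create_mock_actor`/`_create_mock_value` helpers, and the `device` keyword is assumed to be the one this PR threads into the loss constructor; treat it as an illustration, not the test code itself.

```python
# Hedged sketch: building a ClipPPOLoss with an explicit device, mirroring the
# `device=device` argument the tests above pass. Shapes and key names are
# illustrative assumptions, not taken from test_cost.py.
import torch
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss

device = "cpu"
n_obs, n_act = 4, 2

# Actor: observation -> (loc, scale) -> TanhNormal policy
actor_net = torch.nn.Sequential(
    torch.nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()
)
actor = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
# Critic: observation -> state_value
value = ValueOperator(torch.nn.Linear(n_obs, 1), in_keys=["observation"])

loss_fn = ClipPPOLoss(
    actor,
    value,
    loss_critic_type="l2",
    clip_value=0.2,
    device=device,  # assumption: keyword added by this PR
)
```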
torchrl/objectives/ppo.py: 4 changes (3 additions, 1 deletion)
@@ -440,7 +440,9 @@ def __init__(
                 raise ValueError(
                     f"clip_value must be a float or a scalar tensor, got {clip_value}."
                 )
-            self.register_buffer("clip_value", clip_value)
+            self.register_buffer("clip_value", clip_value.to(device))
+        else:
+            self.clip_value = None
         try:
             log_prob_keys = self.actor_network.log_prob_keys
             action_keys = self.actor_network.dist_sample_keys
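To make the ppo.py change concrete, here is a minimal, self-contained sketch in plain PyTorch (an assumption-labeled illustration, not the torchrl class) of the new `clip_value` handling: when a clip value is given it is stored as a buffer on the requested device, and when it is absent the attribute is now explicitly set to None.

```python
# Minimal standalone sketch of the clip_value handling after this change
# (illustrative class, not torchrl.objectives.PPOLoss).
import torch


class ClipValueHolder(torch.nn.Module):
    def __init__(self, clip_value=None, device=None):
        super().__init__()
        if clip_value is not None:
            if isinstance(clip_value, float):
                clip_value = torch.tensor(clip_value)
            elif not (isinstance(clip_value, torch.Tensor) and clip_value.numel() == 1):
                raise ValueError(
                    f"clip_value must be a float or a scalar tensor, got {clip_value}."
                )
            # New behavior: the buffer is created directly on the requested device.
            self.register_buffer("clip_value", clip_value.to(device))
        else:
            # New behavior: the attribute always exists, even without clipping.
            self.clip_value = None


holder = ClipValueHolder(clip_value=0.2, device="cpu")
assert holder.clip_value.device.type == "cpu"

holder_no_clip = ClipValueHolder()
assert holder_no_clip.clip_value is None
```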