test/test_collector.py: 3 changes (2 additions, 1 deletion)
@@ -162,7 +162,7 @@ def forward(self, observation):
        output = self.linear(observation)
        if self.multiple_outputs:
            return output, output.sum(), output.min(), output.max()
-       return self.linear(observation)
+       return output


class UnwrappablePolicy(nn.Module):
@@ -1512,6 +1512,7 @@ def create_env():
            cudagraph_policy=cudagraph,
            weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
        )
+       assert "policy" in collector._weight_senders, collector._weight_senders.keys()
        try:
            # collect state_dict
            state_dict = collector.state_dict()
test/test_env.py: 2 changes (2 additions, 0 deletions)
@@ -3836,6 +3836,8 @@ def test_parallel(self, bwad, use_buffers, maybe_fork_ParallelEnv):
        finally:
            env.close(raise_if_closed=False)
            del env
+           time.sleep(0.1)
+           gc.collect()

class AddString(Transform):
    def __init__(self):
torchrl/collectors/collectors.py: 28 changes (28 additions, 0 deletions)
@@ -307,6 +307,19 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any:
            else None
        )

+       # If no weights were provided and a sync scheme exists, extract the latest
+       # weights from the current model using the scheme strategy (state_dict or tensordict).
+       # This ensures we don't return stale cached weights.
+       if weights is None and scheme is not None:
+           from torchrl.weight_update.weight_sync_schemes import (
+               _resolve_model,
+               WeightStrategy,
+           )
+
+           strategy = WeightStrategy(extract_as=scheme.strategy)
+           model = _resolve_model(self, model_id)
+           return strategy.extract_weights(model)
+
        if weights is None:
            if model_id == "policy" and hasattr(self, "policy_weights"):
                return self.policy_weights
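
Note: the fallback above always re-extracts from the live model instead of returning a cached copy. Below is a minimal sketch of that behavior, using only the calls visible in the diff (WeightStrategy(extract_as=...) and extract_weights()); the nn.Linear model and the loop over the two format names mentioned in the comment are illustrative assumptions, not TorchRL internals.

from torch import nn
from torchrl.weight_update.weight_sync_schemes import WeightStrategy

model = nn.Linear(3, 3)  # stand-in for whatever _resolve_model(self, model_id) returns

# The scheme's strategy picks the extraction format, "state_dict" or "tensordict".
for fmt in ("state_dict", "tensordict"):
    strategy = WeightStrategy(extract_as=fmt)
    weights = strategy.extract_weights(model)  # reflects the model's current parameters
    print(fmt, type(weights))
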
@@ -462,6 +475,21 @@ def update_policy_weights_(
        # Apply to local policy
        if hasattr(self, "policy") and isinstance(self.policy, nn.Module):
            strategy.apply_weights(self.policy, weights)
+       elif (
+           hasattr(self, "_original_policy")
+           and isinstance(self._original_policy, nn.Module)
+           and hasattr(self, "policy")
+           and isinstance(self.policy, nn.Module)
+       ):
+           # If no weights were provided, mirror weights from the original (trainer) policy
+           from torchrl.weight_update.weight_sync_schemes import WeightStrategy
+
+           strategy = WeightStrategy(extract_as="tensordict")
+           weights = strategy.extract_weights(self._original_policy)
+           # Cast weights to the policy device before applying
+           if self.policy_device is not None:
+               weights = weights.to(self.policy_device)
+           strategy.apply_weights(self.policy, weights)
        # Otherwise, no action needed - policy is local and changes are immediately visible

    def __iter__(self) -> Iterator[TensorDictBase]:
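
The second hunk mirrors the trainer-side policy onto the collector's local copy when update_policy_weights_() is called with no explicit weights, casting to the policy device first. A rough standalone sketch of that path, assuming apply_weights() copies parameters in place; the modules and device below are placeholders for self._original_policy, self.policy and self.policy_device, not the collector's real attributes.

import torch
from torch import nn
from torchrl.weight_update.weight_sync_schemes import WeightStrategy

policy_device = torch.device("cpu")               # stand-in for self.policy_device
trainer_policy = nn.Linear(4, 2)                  # stand-in for self._original_policy
local_policy = nn.Linear(4, 2).to(policy_device)  # stand-in for self.policy

strategy = WeightStrategy(extract_as="tensordict")
weights = strategy.extract_weights(trainer_policy)
if policy_device is not None:
    weights = weights.to(policy_device)           # cast before applying, as in the diff
strategy.apply_weights(local_policy, weights)

assert torch.equal(local_policy.weight, trainer_policy.weight)
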
torchrl/envs/batched_envs.py: 14 changes (8 additions, 6 deletions)
@@ -2489,14 +2489,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
            # Make sure the root is updated
            root_shared_tensordict.update_(env._step_mdp(input))

-           # Set event before sending non-tensor data so parent knows worker is done
-           # The recv() call itself will provide synchronization for the pipe
-           mp_event.set()

            if _non_tensor_keys:
                child_pipe.send(
                    ("non_tensor", next_td.select(*_non_tensor_keys, strict=False))
                )

+           # Set event only after non-tensor data is sent to avoid race condition
+           mp_event.set()

            del next_td

        elif cmd == "step_and_maybe_reset":
@@ -2530,14 +2531,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
            event.record()
            event.synchronize()

-           # Set event before sending non-tensor data so parent knows worker is done
-           # The recv() call itself will provide synchronization for the pipe
-           mp_event.set()

            if _non_tensor_keys:
                ntd = root_next_td.select(*_non_tensor_keys)
                ntd.set("next", td_next.select(*_non_tensor_keys))
                child_pipe.send(("non_tensor", ntd))

+           # Set event only after non-tensor data is sent to avoid race condition
+           mp_event.set()

            del td, root_next_td

        elif cmd == "close":
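
Both hunks above apply the same reordering: the worker sets mp_event only after the optional non-tensor payload has been pushed into the pipe. Below is a standalone sketch of that handshake in plain multiprocessing (not TorchRL code), showing why the parent can wait on the event and then recv() without racing the send.

import multiprocessing as mp


def worker(child_pipe, done_event):
    # ... tensor results would first be written into shared memory here ...
    child_pipe.send(("non_tensor", {"info": "ok"}))  # optional extra payload
    done_event.set()  # signal completion only after the send, matching the new ordering
    child_pipe.close()


if __name__ == "__main__":
    parent_end, child_end = mp.Pipe()
    done = mp.Event()
    proc = mp.Process(target=worker, args=(child_end, done))
    proc.start()
    done.wait()               # by the time this returns, the payload is already in the pipe
    print(parent_end.recv())  # ("non_tensor", {"info": "ok"})
    proc.join()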