test/test_transforms.py (33 additions, 0 deletions)

@@ -9674,6 +9674,7 @@ def _test_vecnorm_subproc_auto(
    def rename_t(self):
        return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])

    @retry(AssertionError, tries=10, delay=0)
    @pytest.mark.parametrize("nprc", [2, 5])
    def test_vecnorm_parallel_auto(self, nprc):
        queues = []
@@ -10619,6 +10620,38 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4):
            [nchannels * N, 16, 16]
        )

    def test_compose_pop(self):
        t1 = CatFrames(in_keys=["a", "b"], N=2, dim=-1)
        t2 = FiniteTensorDictCheck()
        t3 = ExcludeTransform()
        compose = Compose(t1, t2, t3)
        assert len(compose.transforms) == 3
        p = compose.pop()
        assert p is t3
        assert len(compose.transforms) == 2
        p = compose.pop(0)
        assert p is t1
        assert len(compose.transforms) == 1
        p = compose.pop()
        assert p is t2
        assert len(compose.transforms) == 0
        with pytest.raises(IndexError, match="index -1 is out of range"):
            compose.pop()

    def test_compose_pop_parent_modification(self):
        t1 = CatFrames(in_keys=["a", "b"], N=2, dim=-1)
        t2 = FiniteTensorDictCheck()
        t3 = ExcludeTransform()
        compose = Compose(t1, t2, t3)
        env = TransformedEnv(ContinuousActionVecMockEnv(), compose)
        p = t2.parent
        assert isinstance(p.transform[0], CatFrames)
        env.transform.pop(0)
        assert env.transform[0] is t2
        new_p = t2.parent
        assert new_p is not p
        assert len(new_p.transform) == 0

    def test_lambda_functions(self):
        def trsf(data):
            if "y" in data.keys():
torchrl/envs/transforms/transforms.py (52 additions, 1 deletion)

@@ -738,7 +738,7 @@ def __setstate__(self, state):
        self.__dict__.update(state)

    @property
-    def parent(self) -> EnvBase | None:
+    def parent(self) -> TransformedEnv | None:
        """Returns the parent env of the transform.

        The parent env is the env that contains all the transforms up until the current one.
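
To make the updated annotation concrete, here is a minimal sketch of what `parent` returns. It borrows the transforms and the `ContinuousActionVecMockEnv` mock from the tests above (the mock lives in the test suite's mocking classes; any `EnvBase` works):

    from torchrl.envs import TransformedEnv
    from torchrl.envs.transforms import CatFrames, Compose, FiniteTensorDictCheck

    t1 = CatFrames(in_keys=["a", "b"], N=2, dim=-1)
    t2 = FiniteTensorDictCheck()
    env = TransformedEnv(ContinuousActionVecMockEnv(), Compose(t1, t2))
    p = t2.parent  # a TransformedEnv, not a bare EnvBase
    assert isinstance(p, TransformedEnv)
    assert isinstance(p.transform[0], CatFrames)  # holds only the transforms before t2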
@@ -1249,6 +1249,7 @@ def close(self, *, raise_if_closed: bool = True):
    def empty_cache(self):
        self.__dict__["_output_spec"] = None
        self.__dict__["_input_spec"] = None
        self.transform.empty_cache()
        super().empty_cache()

    def append_transform(
@@ -1429,6 +1430,50 @@ def map_transform(trsf):
        for t in transforms:
            t.set_container(self)

    def pop(self, index: int | None = None) -> Transform:
        """Pop a transform from the chain.

        Args:
            index (int, optional): The index of the transform to pop. If None, the last transform is popped.

        Returns:
            The popped transform.
        """
        if index is None:
            index = len(self.transforms) - 1
        result = self.transforms.pop(index)
        parent = self.parent
        self.empty_cache()
        if parent is not None:
            parent.empty_cache()
        return result

    def __delitem__(self, index: int | slice | list):
        """Delete a transform in the chain.

        Integer indices, slices and lists of indices are accepted.
        """
        del self.transforms[index]
        parent = self.parent
        self.empty_cache()
        if parent is not None:
            parent.empty_cache()

    def __setitem__(
        self,
        index: int | slice | list,
        value: Transform | Callable[[TensorDictBase], TensorDictBase],
    ):
        """Set a transform in the chain.

        :class:`~torchrl.envs.transforms.Transform` instances or callables are accepted.
        """
        self.transforms[index] = value
        parent = self.parent
        self.empty_cache()
        if parent is not None:
            parent.empty_cache()

    def close(self):
        """Close the transform."""
        for t in self.transforms:
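
Taken together, `pop`, `__delitem__` and `__setitem__` give the chain list-like editing, with the spec caches cleared on every mutation. A short usage sketch, reusing the transforms from `test_compose_pop` above:

    from torchrl.envs.transforms import (
        CatFrames,
        Compose,
        ExcludeTransform,
        FiniteTensorDictCheck,
    )

    compose = Compose(
        CatFrames(in_keys=["a", "b"], N=2, dim=-1),
        FiniteTensorDictCheck(),
        ExcludeTransform(),
    )
    last = compose.pop()                  # no index: removes and returns the last transform
    first = compose.pop(0)                # removes and returns the first transform
    compose[0] = FiniteTensorDictCheck()  # __setitem__ swaps the remaining transform in place
    del compose[0]                        # __delitem__ drops it again
    assert len(compose.transforms) == 0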
@@ -1594,6 +1639,9 @@ def append(
        else:
            self.transforms.append(transform)
            transform.set_container(self)
        parent = self.parent
        if parent is not None:
            parent.empty_cache()

    def set_container(self, container: Transform | EnvBase) -> None:
        self.reset_parent()
@@ -1626,6 +1674,9 @@ def insert(

        # empty cache of all transforms to reset parents and specs
        self.empty_cache()
        parent = self.parent
        if parent is not None:
            parent.empty_cache()
        if index < 0:
            index = index + len(self.transforms)
        transform.eval()
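
`append` and `insert` now invalidate the parent's cached specs as well, matching `pop`, `__delitem__` and `__setitem__` above. A sketch of why this matters, again leaning on the test-suite mock env (substitute any base env; the exact spec values are beside the point):

    from torchrl.envs import TransformedEnv
    from torchrl.envs.transforms import CatFrames, Compose

    env = TransformedEnv(ContinuousActionVecMockEnv(), Compose())
    spec_before = env.observation_spec  # computed once, then cached
    env.transform.append(CatFrames(in_keys=["observation"], N=2, dim=-1))
    # Without the parent.empty_cache() calls added above, env.observation_spec
    # could keep serving the stale cached spec; with them, it is recomputed.
    spec_after = env.observation_spec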
torchrl/objectives/ppo.py (13 additions, 7 deletions)

@@ -752,10 +752,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:

        explained_variance = None
        if self.log_explained_variance:
-            with torch.no_grad():  # <-- break grad-flow
-               tgt = target_return.detach()
-               pred = state_value.detach()
-               eps = torch.finfo(tgt.dtype).eps
+            with torch.no_grad():  # <-- break grad-flow
+                tgt = target_return.detach()
+                pred = state_value.detach()
+                eps = torch.finfo(tgt.dtype).eps
                resid = torch.var(tgt - pred, unbiased=False, dim=0)
                total = torch.var(tgt, unbiased=False, dim=0)
                explained_variance = 1.0 - resid / (total + eps)
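
For reference, the quantity logged above is the classic explained variance, EV = 1 - Var(target - prediction) / Var(target). A standalone numeric check of the same formula, with made-up values:

    import torch

    target = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    pred = torch.tensor([[1.1], [1.9], [3.2], [3.8]])
    eps = torch.finfo(target.dtype).eps
    resid = torch.var(target - pred, unbiased=False, dim=0)
    total = torch.var(target, unbiased=False, dim=0)
    ev = 1.0 - resid / (total + eps)  # here ev is approximately 0.98
    # ev close to 1.0: the critic explains most of the variance in the returns;
    # ev near 0.0: predicting the mean return would do just as well.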
@@ -819,7 +819,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
            td_out.set("entropy", entropy.detach().mean())  # for logging
            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
        if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
            td_out.set("loss_critic", loss_critic)
            if value_clip_fraction is not None:
                td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1189,7 +1191,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
            td_out.set("entropy", entropy.detach().mean())  # for logging
            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
        if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
            td_out.set("loss_critic", loss_critic)
            if value_clip_fraction is not None:
                td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1537,7 +1541,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
            td_out.set("entropy", entropy.detach().mean())  # for logging
            td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
        if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict_copy)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict_copy
+            )
            td_out.set("loss_critic", loss_critic)
            if value_clip_fraction is not None:
                td_out.set("value_clip_fraction", value_clip_fraction)