diff --git a/test/test_transforms.py b/test/test_transforms.py
index 875debcc9e4..f3fdbc88853 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -9674,6 +9674,7 @@ def _test_vecnorm_subproc_auto(
     def rename_t(self):
         return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])

+    @retry(AssertionError, tries=10, delay=0)
     @pytest.mark.parametrize("nprc", [2, 5])
     def test_vecnorm_parallel_auto(self, nprc):
         queues = []
@@ -10619,6 +10620,38 @@ def test_compose(self, keys, batch, device, nchannels=1, N=4):
             [nchannels * N, 16, 16]
         )

+    def test_compose_pop(self):
+        t1 = CatFrames(in_keys=["a", "b"], N=2, dim=-1)
+        t2 = FiniteTensorDictCheck()
+        t3 = ExcludeTransform()
+        compose = Compose(t1, t2, t3)
+        assert len(compose.transforms) == 3
+        p = compose.pop()
+        assert p is t3
+        assert len(compose.transforms) == 2
+        p = compose.pop(0)
+        assert p is t1
+        assert len(compose.transforms) == 1
+        p = compose.pop()
+        assert p is t2
+        assert len(compose.transforms) == 0
+        with pytest.raises(IndexError, match="index -1 is out of range"):
+            compose.pop()
+
+    def test_compose_pop_parent_modification(self):
+        t1 = CatFrames(in_keys=["a", "b"], N=2, dim=-1)
+        t2 = FiniteTensorDictCheck()
+        t3 = ExcludeTransform()
+        compose = Compose(t1, t2, t3)
+        env = TransformedEnv(ContinuousActionVecMockEnv(), compose)
+        p = t2.parent
+        assert isinstance(p.transform[0], CatFrames)
+        env.transform.pop(0)
+        assert env.transform[0] is t2
+        new_p = t2.parent
+        assert new_p is not p
+        assert len(new_p.transform) == 0
+
     def test_lambda_functions(self):
         def trsf(data):
             if "y" in data.keys():
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index b21c1c2c8bd..a00747d7e02 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -738,7 +738,7 @@ def __setstate__(self, state):
         self.__dict__.update(state)

     @property
-    def parent(self) -> EnvBase | None:
+    def parent(self) -> TransformedEnv | None:
         """Returns the parent env of the transform.

         The parent env is the env that contains all the transforms up until the current one.
@@ -1249,6 +1249,7 @@ def close(self, *, raise_if_closed: bool = True):
     def empty_cache(self):
         self.__dict__["_output_spec"] = None
         self.__dict__["_input_spec"] = None
+        self.transform.empty_cache()
         super().empty_cache()

     def append_transform(
@@ -1429,6 +1430,50 @@ def map_transform(trsf):
         for t in transforms:
             t.set_container(self)

+    def pop(self, index: int | None = None) -> Transform:
+        """Pop a transform from the chain.
+
+        Args:
+            index (int, optional): The index of the transform to pop. If None, the last transform is popped.
+
+        Returns:
+            The popped transform.
+        """
+        if index is None:
+            index = len(self.transforms) - 1
+        result = self.transforms.pop(index)
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+        return result
+
+    def __delitem__(self, index: int | slice | list):
+        """Delete a transform from the chain.
+
+        ``int``, ``slice`` and ``list`` indices are accepted.
+        """
+        del self.transforms[index]
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+
+    def __setitem__(
+        self,
+        index: int | slice | list,
+        value: Transform | Callable[[TensorDictBase], TensorDictBase],
+    ):
+        """Set a transform in the chain.
+
+        :class:`~torchrl.envs.transforms.Transform` instances or callables are accepted.
+        """
+        self.transforms[index] = value
+        parent = self.parent
+        self.empty_cache()
+        if parent is not None:
+            parent.empty_cache()
+
     def close(self):
         """Close the transform."""
         for t in self.transforms:
@@ -1594,6 +1639,9 @@ def append(
         else:
             self.transforms.append(transform)
             transform.set_container(self)
+        parent = self.parent
+        if parent is not None:
+            parent.empty_cache()

     def set_container(self, container: Transform | EnvBase) -> None:
         self.reset_parent()
@@ -1626,6 +1674,9 @@ def insert(

         # empty cache of all transforms to reset parents and specs
         self.empty_cache()
+        parent = self.parent
+        if parent is not None:
+            parent.empty_cache()
         if index < 0:
             index = index + len(self.transforms)
         transform.eval()
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 1e8f33268ed..272093eb9b1 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -752,10 +752,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:

         explained_variance = None
         if self.log_explained_variance:
-            with torch.no_grad():  # <‑‑ break grad‐flow
-                tgt = target_return.detach()
-                pred = state_value.detach()
-                eps = torch.finfo(tgt.dtype).eps
+            with torch.no_grad():  # <-- break grad-flow
+                tgt = target_return.detach()
+                pred = state_value.detach()
+                eps = torch.finfo(tgt.dtype).eps
                 resid = torch.var(tgt - pred, unbiased=False, dim=0)
                 total = torch.var(tgt, unbiased=False, dim=0)
                 explained_variance = 1.0 - resid / (total + eps)
@@ -819,7 +819,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1189,7 +1191,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)
@@ -1537,7 +1541,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", self._weighted_loss_entropy(entropy))
         if self._has_critic:
-            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(tensordict_copy)
+            loss_critic, value_clip_fraction, explained_variance = self.loss_critic(
+                tensordict_copy
+            )
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
                 td_out.set("value_clip_fraction", value_clip_fraction)