[BugFix] make "_reset", "step_count", and other done_based keys follow done_spec #981
Conversation
you can do
Init tracker will also create a key that follows done_spec.
torchrl/envs/common.py
Outdated
| ) | ||
| if not break_when_any_done and done.any(): | ||
| _reset = done.view(tensordict.shape) | ||
| _reset = done |
I guess we'd like to clone it to be 100% sure that any downstream modification won't have surprising effects
Makes sense. I didn't add it because it was not there before, but it seems needed.
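The aliasing concern raised above can be illustrated with plain tensors. This is a minimal sketch, not the code from the PR: the tensor values and the names `reset_alias` and `reset_copy` are made up for illustration.

```python
import torch

# Hypothetical done flags for three environments, shaped like a done entry.
done = torch.tensor([[True], [False], [True]])

reset_alias = done         # shares storage with `done`
reset_copy = done.clone()  # independent copy

# A downstream in-place modification, e.g. clearing the done flags.
done.fill_(False)

# The alias sees the change; the clone keeps the original flags.
print(reset_alias.any().item())
print(reset_copy.any().item())
```

Without the clone, whatever later writes into `done` in place would also rewrite `_reset`.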
| ) | ||
| input_spec["step_count"] = UnboundedDiscreteTensorSpec( | ||
| shape=input_spec.shape, | ||
| shape=self.parent.done_spec.shape, |
In the tests, there is no parent env so
AttributeError: 'NoneType' object has no attribute 'done_spec'
I think we should default to the tensordict shape if there is no parent env.
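One way to express that fallback, as a rough sketch only: `step_count_shape` is a hypothetical helper (not a torchrl function), and the parent env is faked with `SimpleNamespace` so the snippet runs on its own.

```python
from types import SimpleNamespace


def step_count_shape(parent, tensordict_shape):
    """Return the shape to use for the "step_count" spec.

    Hypothetical helper: prefers the parent's done_spec shape and falls
    back to the tensordict shape when there is no parent env.
    """
    if parent is None:
        return tuple(tensordict_shape)
    return tuple(parent.done_spec.shape)


# Stand-in for an env whose done_spec has a trailing singleton dimension.
fake_parent = SimpleNamespace(done_spec=SimpleNamespace(shape=(4, 1)))

print(step_count_shape(fake_parent, (4,)))  # uses the done_spec shape
print(step_count_shape(None, (4,)))         # no parent: tensordict shape
```

Checking for `None` explicitly avoids the `AttributeError` reported in the tests.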
I think we should be ready. All in all, I think this also acted as a cleanup, as I got to remove some squeezes and views. I hope to have found all occurrences of done-related keys. Let's hope I didn't miss any.
| # collectors do not support passing other tensors than `"_reset"` | ||
| # to `reset()`. | ||
| if len(self.env.batch_size): | ||
| self._tensordict.masked_fill_(done_or_terminated, 0) |
What was the utility of this?
It was assuming that the done_spec is expandable to the other key specs
vmoens
left a comment
LGTM
The tests with vmas are broken and it seems to be related to this PR :)
There are many places where _reset is just instantiated using the batch_size, while in reality it needs to follow the done_spec.
In particular, in transforms, how do we enforce this?
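To make the mismatch concrete, here is a small sketch with made-up shapes. `make_reset` is a hypothetical helper, and the `(4, 3, 1)` done shape is only an assumption (for instance, one flag per agent), not something taken from an actual env.

```python
import torch


def make_reset(done_shape):
    """Build an all-False "_reset" mask shaped like the done entry.

    Hypothetical helper; `done_shape` stands for whatever shape the
    done_spec reports.
    """
    return torch.zeros(done_shape, dtype=torch.bool)


batch_size = torch.Size([4])
done_shape = torch.Size([4, 3, 1])  # assumed shape, for illustration only

reset_from_batch = torch.zeros(batch_size, dtype=torch.bool)
reset_from_done = make_reset(done_shape)

done = torch.zeros(done_shape, dtype=torch.bool)
print(reset_from_batch.shape == done.shape)  # built from batch_size
print(reset_from_done.shape == done.shape)   # built from the done shape
```

This only shows the shape problem; how transforms should obtain the done shape remains the open question.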