Status: Closed
Labels: bug
Description
Describe the bug
#2979 made it possible to pass shape=() to the Binary spec to model scalar binary values. However, Binary.clone() has not been updated accordingly and fails for such specs.
The same applies to other methods, such as ._unsqueeze().
To Reproduce
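import pytest
from torchrl.data import Binary

def test__Binary__can_handle_empty_spec() -> None:
    # This test currently passes because clone() raises on a scalar spec;
    # the bug is that clone() should not raise here at all.
    with pytest.raises(IndexError, match="tuple index out of range"):
        Binary(shape=()).clone()

Reason and Possible fixes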
The root cause is that whenever a Binary method builds a new spec from the current one, it returns self.__class__(n=self.shape[-1], shape=shape, ... ). When shape=(), self.shape is an empty tuple, so self.shape[-1] raises IndexError.
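The failure mode can be reproduced without torchrl. The sketch below uses a hypothetical `ToyBinaryBuggy` class (not torchrl's real `Binary`) that only mimics the pattern described above: `clone()` re-derives `n` from `self.shape[-1]`, which indexes into an empty tuple for a scalar spec. Storing `n` as a separate attribute is an assumption made for illustration.

```python
class ToyBinaryBuggy:
    """Hypothetical stand-in for a Binary spec; NOT torchrl's implementation."""

    def __init__(self, n=None, shape=()):
        self.shape = tuple(shape)
        # Assumption: n is kept as its own attribute, so a scalar spec can exist.
        self.n = n

    def clone(self):
        # Mirrors the pattern from the issue: n is re-derived from shape[-1],
        # which raises IndexError when shape == ().
        return self.__class__(n=self.shape[-1], shape=self.shape)


def clone_error_for_scalar():
    """Return the IndexError message raised when cloning a scalar spec, else None."""
    try:
        ToyBinaryBuggy(shape=()).clone()
    except IndexError as err:
        return str(err)
    return None
```

Non-scalar specs clone fine because `shape[-1]` exists; only `shape=()` hits the empty-tuple index.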
I suggest replacing all n=self.shape[-1] occurrences with n=self.n wherever the shape is unchanged (e.g., in .clone()). I have not yet thought through the other cases, such as ._unsqueeze(), but I am available to open a PR and will look into them later.
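A minimal sketch of the suggested fix for the unchanged-shape case, again on a hypothetical `ToyBinary` class rather than torchrl's real `Binary`: `clone()` passes the stored `n` through instead of reading `self.shape[-1]`. Shape-changing methods such as `._unsqueeze()` are deliberately left out, since the issue leaves their handling open.

```python
class ToyBinary:
    """Hypothetical stand-in for a Binary spec; NOT torchrl's implementation."""

    def __init__(self, n=None, shape=()):
        self.shape = tuple(shape)
        # Assumption: n is kept as its own attribute (None allowed for scalars).
        self.n = n

    def clone(self):
        # The shape is unchanged, so reuse self.n instead of self.shape[-1];
        # this works for scalar specs (shape == ()) as well as non-scalar ones.
        return self.__class__(n=self.n, shape=self.shape)
```

Because `clone()` never indexes into `self.shape`, a scalar spec round-trips without error, and non-scalar specs keep the same `n` they had before.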
Checklist
- I have checked that there is no similar issue in the repo
- I have read the documentation
- I have provided a minimal working example to reproduce the bug