
[BUG] Binary tensor spec is not compatible with shape=() #3074

@LCarmi

Describe the bug

#2979 introduced the possibility of passing shape=() to the Binary spec to model scalar binary values. However, the Binary.clone() method was not updated accordingly, so cloning such a spec fails.

The same applies to other methods, such as ._unsqueeze().

To Reproduce

import pytest
from torchrl.data import Binary


def test__Binary__can_handle_empty_spec() -> None:
    # Passes on affected versions: clone() reads self.shape[-1],
    # which does not exist when shape=().
    with pytest.raises(IndexError, match="tuple index out of range"):
        Binary(shape=()).clone()

Reason and Possible fixes

The root cause is that whenever a Binary method builds a new spec from the current one, it returns self.__class__(n=self.shape[-1], shape=shape, ... ). When shape=(), self.shape is an empty tuple, so self.shape[-1] raises IndexError.
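To make the failure mode concrete, here is a minimal sketch using a hypothetical ToyBinarySpec class (not torchrl's Binary; its constructor and attributes are assumptions for illustration) that copies the n=self.shape[-1] pattern. The value n=1 for the scalar spec is also an assumption.

```python
class ToyBinarySpec:
    """Hypothetical stand-in for a Binary spec; NOT the torchrl class."""

    def __init__(self, n, shape):
        self.n = n
        self.shape = tuple(shape)

    def clone(self):
        # Mirrors the pattern described above: n is re-derived from the last
        # dimension of the shape, which does not exist when shape == ().
        return self.__class__(n=self.shape[-1], shape=self.shape)


vector = ToyBinarySpec(n=3, shape=(3,))
print(vector.clone().n)  # 3

scalar = ToyBinarySpec(n=1, shape=())  # n=1 for a scalar spec is assumed
try:
    scalar.clone()
except IndexError as err:
    print(f"IndexError: {err}")  # IndexError: tuple index out of range
```

Indexing the last element of an empty tuple is what produces the "tuple index out of range" message matched in the reproduction test.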

I suggest replacing every n=self.shape[-1] occurrence with n=self.n wherever the shape is unchanged (e.g., in .clone()). I have not yet worked out what to do in the other cases, such as ._unsqueeze(), but I am available to open a PR and will look into this later.
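A minimal sketch of the suggested clone() fix, again on a hypothetical ToyBinarySpec class rather than torchrl's actual Binary (its constructor and attributes are assumptions): because clone() keeps the shape, it can reuse the stored n instead of reading it from the shape. This does not address shape-changing methods such as ._unsqueeze(), which remain open.

```python
class ToyBinarySpec:
    """Hypothetical stand-in for a Binary spec; NOT the torchrl class."""

    def __init__(self, n, shape):
        self.n = n
        self.shape = tuple(shape)

    def clone(self):
        # Suggested fix: the shape is unchanged, so reuse the stored n
        # rather than re-deriving it from self.shape[-1].
        return self.__class__(n=self.n, shape=self.shape)


scalar = ToyBinarySpec(n=1, shape=())  # n=1 for a scalar spec is assumed
copy_of_scalar = scalar.clone()
print(copy_of_scalar.shape, copy_of_scalar.n)  # () 1
```

Non-scalar specs behave exactly as before, since for them self.n and self.shape[-1] are expected to agree.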

Checklist

  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal working example to reproduce the bug

Labels: bug (Something isn't working)
