marcosgalleterobbva
Contributor

Description

When loading a Minari dataset, MinariExperienceReplay calls its _download_and_preproc method. Inside it, _patch_info aligns keys in observations, state, or infos that may have one fewer element than the others.

This patching should not be applied to observations that are not a TensorDict but a NonTensorData wrapping a single observation from the environment. This PR skips the patching step for such NonTensorData observations, which removes the limitation that prevented Atari datasets from being loaded with Minari.
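The failure mode and the fix can be illustrated with a simplified, pure-Python sketch. It is not torchrl's implementation: plain lists stand in for tensors, NonTensorStandIn is a hypothetical stand-in for tensordict's NonTensorData, and patch_info / maybe_patch are hypothetical names. For illustration it assumes that entries with one extra element lose their first element. An object with no sub-entries produces an empty shape map, which reproduces the RuntimeError message in the traceback further down; the guard returns such observations unchanged.

```python
from collections import defaultdict


class NonTensorStandIn:
    """Hypothetical stand-in for tensordict's NonTensorData.

    It wraps a single value and, unlike a TensorDict, exposes no sub-entries.
    """

    def __init__(self, data):
        self.data = data

    def items(self):
        # A leaf has nothing to iterate over.
        return iter(())


def patch_info(info):
    """Simplified sketch of the alignment step (not torchrl's code).

    Entries are grouped by length; entries that are one element longer
    than the rest drop their first element (an assumption of this sketch).
    """
    unique_shapes = defaultdict(list)
    for key, seq in info.items():
        unique_shapes[len(seq)].append(key)
    if len(unique_shapes) == 1:
        # Already aligned: register an empty "longer" group so two groups exist.
        (only_len,) = unique_shapes
        unique_shapes[only_len + 1] = []
    if len(unique_shapes) != 2:
        # An object with no entries ends up here with an empty mapping.
        raise RuntimeError(
            "Unique shapes in a sub-tensordict can only be of length 2, "
            f"got shapes {unique_shapes}."
        )
    min_len, max_len = min(unique_shapes), max(unique_shapes)
    if max_len != min_len + 1:
        raise RuntimeError(f"Lengths must differ by one, got {min_len} and {max_len}.")
    patched = {key: list(info[key]) for key in unique_shapes[min_len]}
    for key in unique_shapes[max_len]:
        patched[key] = list(info[key][1:])
    return patched


def maybe_patch(val):
    """The guard described in this PR, in sketch form."""
    if isinstance(val, NonTensorStandIn):
        # A single non-tensor observation has nothing to align.
        return val
    return patch_info(val)


aligned = maybe_patch({"obs": [1, 2, 3], "reward": [0.5, 1.0]})
print(aligned)  # {'reward': [0.5, 1.0], 'obs': [2, 3]}
```

Without the isinstance guard, passing the NonTensorStandIn leaf straight to patch_info raises the same "got shapes defaultdict(<class 'list'>, {})" message seen in the traceback.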

Motivation and Context

Take the following script, for example. It simply loads the atari/skiing/expert-v0 dataset:

from torchrl.data.datasets.minari_data import MinariExperienceReplay

BATCH_SIZE = 1
SAVE_ROOT = None


def download_minari_datasets(dataset_id):
    _ = MinariExperienceReplay(
        dataset_id=dataset_id,
        batch_size=BATCH_SIZE,
        root=SAVE_ROOT,
    )
    print(f"✓ Successfully downloaded {dataset_id}")


if __name__ == "__main__":
    download_minari_datasets("atari/skiing/expert-v0")

Executing this code results in the following error:

Downloading atari/skiing/expert-v0 from Farama servers...
Fetching 2 files: 100%|██████████| 2/2 [00:30<00:00, 15.41s/it]

Dataset atari/skiing/expert-v0 downloaded to /var/folders/jm/mc7b8v9x7w912wk1kmk0z7b00000gp/T/tmpne_qc6vp/atari/skiing/expert-v0
2025-07-24 11:09:27,027 [torchrl][INFO]    first read through data to create data structure... [END]
Traceback (most recent call last):
  File "/Users/O000142/Library/Application Support/JetBrains/PyCharmCE2023.3/scratches/load_atari_dataset.py", line 17, in <module>
    download_minari_datasets("atari/skiing/expert-v0")
  File "/Users/O000142/Library/Application Support/JetBrains/PyCharmCE2023.3/scratches/load_atari_dataset.py", line 8, in download_minari_datasets
    _ = MinariExperienceReplay(
  File "/Users/O000142/Projects/rl/torchrl/data/datasets/minari_data.py", line 187, in __init__
    storage = self._download_and_preproc()
  File "/Users/O000142/Projects/rl/torchrl/data/datasets/minari_data.py", line 268, in _download_and_preproc
    val = _patch_info(val)
  File "/Users/O000142/Projects/rl/torchrl/data/datasets/minari_data.py", line 456, in _patch_info
    raise RuntimeError(
RuntimeError: Unique shapes in a sub-tensordict can only be of length 2, got shapes defaultdict(<class 'list'>, {}).

This PR solves that.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@pytorch-bot

pytorch-bot bot commented Jul 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3091

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Jul 24, 2025
Collaborator

@vmoens vmoens left a comment

LGTM thanks

… feature/fix_atari_minariRB

# Conflicts:
#	torchrl/data/datasets/minari_data.py
@marcosgalleterobbva
Contributor Author

Resolved conflicts

@vmoens vmoens added the Environments, Data, and bug labels Jul 28, 2025
@vmoens vmoens merged commit 93fcb02 into pytorch:main Jul 28, 2025
1 check passed

Labels

bug, CLA Signed, Data, Environments

3 participants