From c0971258ea5bccd90123887cd341740b382878c0 Mon Sep 17 00:00:00 2001
From: AlexandreBrown
Date: Tue, 4 Mar 2025 10:43:35 -0500
Subject: [PATCH] Fixed VideoRecorder crash when passing fps

VideoRecorder.__init__ stored the fps entry into self.video_kwargs, but
that attribute is only assigned on the following line, so constructing
a VideoRecorder with fps crashed. Store fps into the local video_kwargs
dict instead, before it is assigned to self.video_kwargs.

Add a regression test checking that fps ends up in
recorder.video_kwargs.
---
 test/test_transforms.py    | 9 +++++++++
 torchrl/record/recorder.py | 2 +-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index 38ecc81a589..fd3d6676f65 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -131,6 +131,7 @@
 from torchrl.envs.utils import check_env_specs, MarlGroupMapType, step_mdp
 from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
 from torchrl.modules.utils import get_primers_from_module
+from torchrl.record.recorder import VideoRecorder
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import (  # noqa
@@ -13978,6 +13979,14 @@ def test_transform_inverse(self):
         raise pytest.skip("Tested elsewhere")
 
 
+class TestVideoRecorder:
+    # TODO: add more tests
+    def test_can_init_with_fps(self):
+        recorder = VideoRecorder(None, None, fps=30)
+
+        assert recorder.video_kwargs["fps"] == 30
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py
index c2cf93dd119..37d905c9a35 100644
--- a/torchrl/record/recorder.py
+++ b/torchrl/record/recorder.py
@@ -121,7 +121,7 @@ def __init__(
             video_kwargs = {}
         video_kwargs.update(kwargs)
         if fps is not None:
-            self.video_kwargs["fps"] = fps
+            video_kwargs["fps"] = fps
         self.video_kwargs = video_kwargs
         self.iter = 0
         self.skip = skip