diff --git a/test/mocking_classes.py b/test/mocking_classes.py index d8dfeb6c2d1..624a0f098e1 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -2242,3 +2242,29 @@ def _set_seed(self, seed: Optional[int]): random.seed(seed) torch.manual_seed(0) return seed + + +class EnvThatErrorsAfter10Iters(EnvBase): + def __init__(self): + self.action_spec = Composite(action=Unbounded((1,))) + self.reward_spec = Composite(reward=Unbounded((1,))) + self.done_spec = Composite(done=Unbounded((1,))) + self.observation_spec = Composite(observation=Unbounded((1,))) + self.counter = 0 + super().__init__() + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDict: + return self.full_observation_spec.zero().update(self.full_done_spec.zero()) + + def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict: + if self.counter >= 10: + raise RuntimeError("max steps!") + self.counter += 1 + return ( + self.full_observation_spec.zero() + .update(self.full_done_spec.zero()) + .update(self.full_reward_spec.zero()) + ) + + def _set_seed(self, seed: Optional[int]): + ... diff --git a/test/test_collector.py b/test/test_collector.py index d2f1c102416..423049d7add 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -9,6 +9,7 @@ import functools import gc import os +import subprocess import sys from typing import Optional @@ -77,6 +78,7 @@ from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule if os.getenv("PYTORCH_TEST_FBCODE"): + IS_FB = True from pytorch.rl.test._utils_internal import ( CARTPOLE_VERSIONED, check_rollout_consistency_multikey_env, @@ -98,6 +100,7 @@ DiscreteActionConvPolicy, DiscreteActionVecMockEnv, DiscreteActionVecPolicy, + EnvThatErrorsAfter10Iters, EnvWithDynamicSpec, HeterogeneousCountingEnv, HeterogeneousCountingEnvPolicy, @@ -107,6 +110,7 @@ NestedCountingEnv, ) else: + IS_FB = False from _utils_internal import ( CARTPOLE_VERSIONED, check_rollout_consistency_multikey_env, @@ -128,6 +132,7 @@ DiscreteActionConvPolicy, DiscreteActionVecMockEnv, DiscreteActionVecPolicy, + EnvThatErrorsAfter10Iters, EnvWithDynamicSpec, HeterogeneousCountingEnv, HeterogeneousCountingEnvPolicy, @@ -235,1396 +240,1875 @@ def make_policy(env): # return tensordict_device_type == storing_device_type -class TestCollectorDevices: - class DeviceLessEnv(EnvBase): - # receives data on cpu, outputs on gpu -- tensordict has no device - def __init__(self, default_device): - self.default_device = default_device - super().__init__(device=None) - self.observation_spec = Composite( - observation=Unbounded((), device=default_device) - ) - self.reward_spec = Unbounded(1, device=default_device) - self.full_done_spec = Composite( - done=Unbounded(1, dtype=torch.bool, device=self.default_device), - truncated=Unbounded(1, dtype=torch.bool, device=self.default_device), - terminated=Unbounded(1, dtype=torch.bool, device=self.default_device), +class TestCollectorGeneric: + @pytest.mark.parametrize("num_env", [1, 2]) + # 1226: for efficiency, we just test vec, not "conv" + @pytest.mark.parametrize("env_name", ["vec"]) + def test_collector_batch_size( + self, num_env, env_name, seed=100, num_workers=2, frames_per_batch=20 + ): + """Tests that there are 'frames_per_batch' frames in each batch of a collection.""" + if num_env == 3 and IS_WINDOWS: + pytest.skip( + "Test timeout (> 10 min) on CI pipeline Windows machine with GPU" ) - self.action_spec = Unbounded((), device=None) - assert self.device is None - assert self.full_observation_spec is not None - assert self.full_done_spec is not None - assert self.full_state_spec is not None - assert self.full_action_spec is not None - assert self.full_reward_spec is not None + if num_env == 1: - def _step(self, tensordict): - assert tensordict.device is None - with torch.device(self.default_device): - return TensorDict( - { - "observation": torch.zeros(()), - "reward": torch.zeros((1,)), - "done": torch.zeros((1,), dtype=torch.bool), - "terminated": torch.zeros((1,), dtype=torch.bool), - "truncated": torch.zeros((1,), dtype=torch.bool), - }, - batch_size=[], - device=None, - ) + def env_fn(): + env = make_make_env(env_name)() + return env - def _reset(self, tensordict=None): - with torch.device(self.default_device): - return TensorDict( - { - "observation": torch.zeros(()), - "done": torch.zeros((1,), dtype=torch.bool), - "terminated": torch.zeros((1,), dtype=torch.bool), - "truncated": torch.zeros((1,), dtype=torch.bool), - }, - batch_size=[], - device=None, + else: + + def env_fn(): + # 1226: For efficiency, we don't use Parallel but Serial + # env = ParallelEnv( + env = SerialEnv( + num_workers=num_env, create_env_fn=make_make_env(env_name) ) + return env - def _set_seed(self, seed: int | None = None): - return seed + policy = make_policy(env_name) - class EnvWithDevice(EnvBase): - def __init__(self, default_device): - self.default_device = default_device - super().__init__(device=self.default_device) - self.observation_spec = Composite( - observation=Unbounded((), device=self.default_device) - ) - self.reward_spec = Unbounded(1, device=self.default_device) - self.full_done_spec = Composite( - done=Unbounded(1, dtype=torch.bool, device=self.default_device), - truncated=Unbounded(1, dtype=torch.bool, device=self.default_device), - terminated=Unbounded(1, dtype=torch.bool, device=self.default_device), - device=self.default_device, - ) - self.action_spec = Unbounded((), device=self.default_device) - assert self.device == _make_ordinal_device( - torch.device(self.default_device) - ) - assert self.full_observation_spec is not None - assert self.full_done_spec is not None - assert self.full_state_spec is not None - assert self.full_action_spec is not None - assert self.full_reward_spec is not None + torch.manual_seed(0) + np.random.seed(0) - def _step(self, tensordict): - assert tensordict.device == _make_ordinal_device( - torch.device(self.default_device) - ) - with torch.device(self.default_device): - return TensorDict( - { - "observation": torch.zeros(()), - "reward": torch.zeros((1,)), - "done": torch.zeros((1,), dtype=torch.bool), - "terminated": torch.zeros((1,), dtype=torch.bool), - "truncated": torch.zeros((1,), dtype=torch.bool), - }, - batch_size=[], - device=self.default_device, - ) + ccollector = MultiaSyncDataCollector( + create_env_fn=[env_fn for _ in range(num_workers)], + policy=policy, + frames_per_batch=frames_per_batch, + max_frames_per_traj=1000, + total_frames=frames_per_batch * 100, + ) + try: + ccollector.set_seed(seed) + for i, b in enumerate(ccollector): + assert b.numel() == -(-frames_per_batch // num_env) * num_env + if i == 5: + break + assert b.names[-1] == "time" + finally: + ccollector.shutdown() - def _reset(self, tensordict=None): - with torch.device(self.default_device): - return TensorDict( - { - "observation": torch.zeros(()), - "done": torch.zeros((1,), dtype=torch.bool), - "terminated": torch.zeros((1,), dtype=torch.bool), - "truncated": torch.zeros((1,), dtype=torch.bool), - }, - batch_size=[], - device=self.default_device, + ccollector = MultiSyncDataCollector( + create_env_fn=[env_fn for _ in range(num_workers)], + policy=policy, + frames_per_batch=frames_per_batch, + max_frames_per_traj=1000, + total_frames=frames_per_batch * 100, + cat_results="stack", + ) + try: + ccollector.set_seed(seed) + for i, b in enumerate(ccollector): + assert ( + b.numel() + == -(-frames_per_batch // num_env // num_workers) + * num_env + * num_workers ) + if i == 5: + break + assert b.names[-1] == "time" + finally: + ccollector.shutdown() + del ccollector - def _set_seed(self, seed: int | None = None): - return seed + @pytest.mark.parametrize("num_env", [1, 2]) + @pytest.mark.parametrize("env_name", ["conv", "vec"]) + def test_collector_consistency(self, num_env, env_name, seed=100): + """Tests that a rollout gathered with env.rollout matches one gathered with the collector.""" + if num_env == 1: - class DeviceLessPolicy(TensorDictModuleBase): - in_keys = ["observation"] - out_keys = ["action"] + def env_fn(seed): + env = make_make_env(env_name)() + env.set_seed(seed) + return env - # receives data on gpu and outputs on cpu - def forward(self, tensordict): - assert tensordict.device is None - return tensordict.set("action", torch.zeros((), device="cpu")) + else: - class PolicyWithDevice(TensorDictModuleBase): - in_keys = ["observation"] - out_keys = ["action"] - # receives and sends data on gpu - default_device = "cuda:0" if torch.cuda.device_count() else "cpu" + def env_fn(seed): + env = ParallelEnv( + num_workers=num_env, + create_env_fn=make_make_env(env_name), + create_env_kwargs=[ + {"seed": s} for s in generate_seeds(seed, num_env) + ], + ) + return env - def forward(self, tensordict): - assert tensordict.device == _make_ordinal_device( - torch.device(self.default_device) - ) - return tensordict.set("action", torch.zeros((), device=self.default_device)) + policy = make_policy(env_name) - @pytest.mark.parametrize("main_device", get_default_devices()) - @pytest.mark.parametrize("storing_device", [None, *get_default_devices()]) - def test_output_device(self, main_device, storing_device): + torch.manual_seed(0) + np.random.seed(0) + + # Get a single rollout with dummypolicy + env = env_fn(seed) + env = TransformedEnv(env, StepCounter(20)) + rollout1a = env.rollout(policy=policy, max_steps=50, auto_reset=True) + env.set_seed(seed) + rollout1b = env.rollout(policy=policy, max_steps=50, auto_reset=True) + rollout2 = env.rollout(policy=policy, max_steps=50, auto_reset=True) + try: + assert_allclose_td(rollout1a, rollout1b) + with pytest.raises(AssertionError): + assert_allclose_td(rollout1a, rollout2) + finally: + env.close() - # env has no device, policy is strictly on GPU - device = None - env_device = None - policy_device = main_device - env = self.DeviceLessEnv(main_device) - policy = self.PolicyWithDevice() collector = SyncDataCollector( - env, - policy, - device=device, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - frames_per_batch=1, - total_frames=10, + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=20 * num_env, + max_frames_per_traj=20, + total_frames=200, + device="cpu", ) - for data in collector: # noqa: B007 - break + collector_iter = iter(collector) + b1 = next(collector_iter) + b2 = next(collector_iter) - assert data.device == storing_device + # if num_env == 1: + # # rollouts collected through DataCollector are padded using pad_sequence, which introduces a first dimension + # rollout1a = rollout1a.unsqueeze(0) + try: + with pytest.raises(AssertionError): + assert_allclose_td(b1, b2) + assert ( + rollout1a.batch_size == b1.batch_size + ), f"got batch_size {rollout1a.batch_size} and {b1.batch_size}" + assert_allclose_td(rollout1a, b1.select(*rollout1a.keys(True, True))) + finally: + collector.shutdown() - # env is on cuda, policy has no device - device = None - env_device = main_device - policy_device = None - env = self.EnvWithDevice(main_device) - policy = self.DeviceLessPolicy() - collector = SyncDataCollector( - env, - policy, - device=device, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - frames_per_batch=1, - total_frames=10, - ) - for data in collector: # noqa: B007 - break - assert data.device == storing_device + @pytest.mark.skipif(not _has_gym, reason="gym library is not installed") + @pytest.mark.parametrize("parallel", [False, True]) + @pytest.mark.parametrize( + "constr", + [ + functools.partial(split_trajectories, prefix="collector"), + functools.partial(split_trajectories), + functools.partial( + split_trajectories, trajectory_key=("collector", "traj_ids") + ), + ], + ) + def test_collector_env_reset(self, constr, parallel): + torch.manual_seed(0) - # env and policy are on device - device = main_device - env_device = None - policy_device = None - env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() - collector = SyncDataCollector( - env, - policy, - device=device, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - frames_per_batch=1, - total_frames=10, + def make_env(): + # This is currently necessary as the methods in GymWrapper may have mismatching backend + # versions. + with set_gym_backend(gym_backend()): + return TransformedEnv( + GymEnv(PONG_VERSIONED(), frame_skip=4), StepCounter() + ) + + if parallel: + env = ParallelEnv(2, make_env) + else: + env = SerialEnv(2, make_env) + try: + # env = SerialEnv(2, lambda: GymEnv("CartPole-v1", frame_skip=4)) + env.set_seed(0) + collector = SyncDataCollector( + env, + policy=None, + total_frames=10001, + frames_per_batch=10000, + split_trajs=False, + ) + for _data in collector: + break + steps = _data["next", "step_count"][..., 1:, :] + done = _data["next", "done"][..., :-1, :] + # we don't want just one done + assert done.sum() > 3 + # check that after a done, the next step count is always 1 + assert (steps[done] == 1).all() + # check that if the env is not done, the next step count is > 1 + assert (steps[~done] > 1).all() + # check that if step is 1, then the env was done before + assert (steps == 1)[done].all() + # check that split traj has a minimum total reward of -21 (for pong only) + _data = constr(_data) + assert _data["next", "reward"].sum(-2).min() == -21 + finally: + env.close() + del env + + @pytest.mark.parametrize( + "break_when_any_done,break_when_all_done", + [[True, False], [False, True], [False, False]], + ) + @pytest.mark.parametrize("n_envs", [1, 4]) + def test_collector_outplace_policy( + self, n_envs, break_when_any_done, break_when_all_done + ): + def policy_inplace(td): + td.set("action", torch.ones(td.shape + (1,))) + return td + + def policy_outplace(td): + return td.empty().set("action", torch.ones(td.shape + (1,))) + + if n_envs == 1: + env = CountingEnv(10) + else: + env = SerialEnv( + n_envs, + [functools.partial(CountingEnv, 10 + i) for i in range(n_envs)], + ) + env.reset() + c_inplace = SyncDataCollector( + env, policy_inplace, frames_per_batch=10, total_frames=100 ) - for data in collector: # noqa: B007 - break - assert data.device == main_device + d_inplace = torch.cat(list(c_inplace), dim=0) + env.reset() + c_outplace = SyncDataCollector( + env, policy_outplace, frames_per_batch=10, total_frames=100 + ) + d_outplace = torch.cat(list(c_outplace), dim=0) + assert_allclose_td(d_inplace, d_outplace) - # same but more specific - device = None - env_device = main_device - policy_device = main_device - env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() - collector = SyncDataCollector( - env, - policy, - device=device, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - frames_per_batch=1, - total_frames=10, + @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") + @pytest.mark.parametrize( + "collector_class", + [ + SyncDataCollector, + MultiaSyncDataCollector, + functools.partial(MultiSyncDataCollector, cat_results="stack"), + ], + ) + @pytest.mark.parametrize("init_random_frames", [0, 50]) # 1226: faster execution + @pytest.mark.parametrize( + "explicit_spec,split_trajs", [[True, True], [False, False]] + ) # 1226: faster execution + def test_collector_output_keys( + self, collector_class, init_random_frames, explicit_spec, split_trajs + ): + from torchrl.envs.libs.gym import GymEnv + + out_features = 1 + hidden_size = 12 + total_frames = 200 + frames_per_batch = 20 + num_envs = 3 + + net = LSTMNet( + out_features, + {"input_size": hidden_size, "hidden_size": hidden_size}, + {"out_features": hidden_size}, + ) + + policy_kwargs = { + "module": net, + "in_keys": ["observation", "hidden1", "hidden2"], + "out_keys": [ + "action", + "hidden1", + "hidden2", + ("next", "hidden1"), + ("next", "hidden2"), + ], + } + if explicit_spec: + hidden_spec = Unbounded((1, hidden_size)) + policy_kwargs["spec"] = Composite( + action=Unbounded(), + hidden1=hidden_spec, + hidden2=hidden_spec, + next=Composite(hidden1=hidden_spec, hidden2=hidden_spec), + ) + + policy = SafeModule(**policy_kwargs) + + env_maker = lambda: GymEnv(PENDULUM_VERSIONED()) + + policy(env_maker().reset()) + + collector_kwargs = { + "create_env_fn": env_maker, + "policy": policy, + "total_frames": total_frames, + "frames_per_batch": frames_per_batch, + "init_random_frames": init_random_frames, + "split_trajs": split_trajs, + } + + if collector_class is not SyncDataCollector: + collector_kwargs["create_env_fn"] = [ + collector_kwargs["create_env_fn"] for _ in range(num_envs) + ] + + collector = collector_class(**collector_kwargs) + + keys = { + "action", + "done", + "collector", + "hidden1", + "hidden2", + ("next", "hidden1"), + ("next", "hidden2"), + ("next", "observation"), + ("next", "done"), + ("next", "reward"), + "next", + "observation", + ("collector", "traj_ids"), + } + if split_trajs: + keys.add(("collector", "mask")) + + keys.add(("next", "terminated")) + keys.add("terminated") + keys.add(("next", "truncated")) + keys.add("truncated") + b = next(iter(collector)) + + assert set(b.keys(True)) == keys + collector.shutdown() + del collector + + @pytest.mark.parametrize( + "collector_class", + [ + functools.partial(MultiSyncDataCollector, cat_results="stack"), + MultiaSyncDataCollector, + SyncDataCollector, + ], + ) + def test_collector_reloading(self, collector_class): + def make_env(): + return ContinuousActionVecMockEnv() + + dummy_env = make_env() + obs_spec = dummy_env.observation_spec["observation"] + policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1]) + policy = Actor(policy_module, spec=dummy_env.action_spec) + policy_explore = TensorDictSequential( + policy, OrnsteinUhlenbeckProcessModule(spec=policy.spec) ) - for data in collector: # noqa: B007 - break - assert data.device == main_device - # none has a device - device = None - env_device = None - policy_device = None - env = self.DeviceLessEnv(main_device) - policy = self.DeviceLessPolicy() + collector_kwargs = { + "create_env_fn": make_env, + "policy": policy_explore, + "frames_per_batch": 30, + "total_frames": 90, + } + if collector_class is not SyncDataCollector: + collector_kwargs["create_env_fn"] = [ + collector_kwargs["create_env_fn"] for _ in range(3) + ] + + collector = collector_class(**collector_kwargs) + for i, _ in enumerate(collector): + if i == 3: + break + collector_frames = collector._frames + collector_iter = collector._iter + collector_state_dict = collector.state_dict() + collector.shutdown() + + collector = collector_class(**collector_kwargs) + collector.load_state_dict(collector_state_dict) + assert collector._frames == collector_frames + assert collector._iter == collector_iter + for _ in enumerate(collector): + raise AssertionError + collector.shutdown() + del collector + + @pytest.mark.skipif( + sys.version_info >= (3, 11), + reason="Nested spawned multiprocessed is currently failing in python 3.11. " + "See https://github.com/python/cpython/pull/108568 for info and fix.", + ) + @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") + @pytest.mark.parametrize("static_seed", [True, False]) + def test_collector_vecnorm_envcreator(self, static_seed): + """ + High level test of the following pipeline: + (1) Design a function that creates an environment with VecNorm + (2) Wrap that function in an EnvCreator to instantiate the shared tensordict + (3) Create a ParallelEnv that dispatches this env across workers + (4) Run several ParallelEnv synchronously + The function tests that the tensordict gathered from the workers match at certain moments in time, and that they + are modified after the collector is run for more steps. + + """ + from torchrl.envs.libs.gym import GymEnv + + num_envs = 4 + env_make = EnvCreator( + lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm()) + ) + env_make = ParallelEnv(num_envs, env_make) + + policy = RandomPolicy(env_make.action_spec) + num_data_collectors = 2 + c = MultiSyncDataCollector( + [env_make] * num_data_collectors, + policy=policy, + total_frames=int(1e6), + frames_per_batch=200, + cat_results="stack", + ) + + init_seed = 0 + new_seed = c.set_seed(init_seed, static_seed=static_seed) + if static_seed: + assert new_seed == init_seed + else: + assert new_seed != init_seed + + seed = init_seed + for _ in range(num_envs * num_data_collectors): + seed = seed_generator(seed) + if not static_seed: + assert new_seed == seed + else: + assert new_seed != seed + + c_iter = iter(c) + next(c_iter) + next(c_iter) + + s = c.state_dict() + + td1 = ( + TensorDict(s["worker0"]["env_state_dict"]["worker3"]["_extra_state"]) + .unflatten_keys(VecNorm.SEP) + .clone() + ) + td2 = ( + TensorDict(s["worker1"]["env_state_dict"]["worker0"]["_extra_state"]) + .unflatten_keys(VecNorm.SEP) + .clone() + ) + assert (td1 == td2).all() + + next(c_iter) + next(c_iter) + + s = c.state_dict() + + td3 = ( + TensorDict(s["worker0"]["env_state_dict"]["worker3"]["_extra_state"]) + .unflatten_keys(VecNorm.SEP) + .clone() + ) + td4 = ( + TensorDict(s["worker1"]["env_state_dict"]["worker0"]["_extra_state"]) + .unflatten_keys(VecNorm.SEP) + .clone() + ) + assert (td3 == td4).all() + assert (td1 != td4).any() + c.shutdown() + del c + + @pytest.mark.parametrize("num_env", [1, 2]) + @pytest.mark.parametrize("env_name", ["conv", "vec"]) + def test_concurrent_collector_consistency(self, num_env, env_name, seed=40): + if num_env == 1: + + def env_fn(seed): + env = make_make_env(env_name)() + env.set_seed(seed) + return env + + else: + + def env_fn(seed): + env = ParallelEnv( + num_workers=num_env, + create_env_fn=make_make_env(env_name), + create_env_kwargs=[ + {"seed": i} for i in range(seed, seed + num_env) + ], + ) + return env + + policy = make_policy(env_name) + collector = SyncDataCollector( - env, - policy, - device=device, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - frames_per_batch=1, - total_frames=10, + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + device="cpu", ) - for data in collector: # noqa: B007 - break - assert data.device == storing_device + try: + assert collector._use_buffers + for i, d in enumerate(collector): + if i == 0: + b1 = d + elif i == 1: + b2 = d + else: + break + assert d.names[-1] == "time" + with pytest.raises(AssertionError): + assert_allclose_td(b1, b2) + finally: + collector.shutdown() - class CudaPolicy(TensorDictSequential): - def __init__(self, n_obs): - module = torch.nn.Linear(n_obs, n_obs, device="cuda") - module.weight.data.copy_(torch.eye(n_obs)) - module.bias.data.fill_(0) - m0 = TensorDictModule(module, in_keys=["observation"], out_keys=["hidden"]) - m1 = TensorDictModule( - lambda a: a + 1, in_keys=["hidden"], out_keys=["action"] - ) - super().__init__(m0, m1) + ccollector = aSyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + ) + for i, d in enumerate(ccollector): + if i == 0: + b1c = d + elif i == 1: + b2c = d + else: + break - class GoesThroughEnv(EnvBase): - def __init__(self, n_obs, device): - self.observation_spec = Composite(observation=Unbounded(n_obs)) - self.action_spec = Unbounded(n_obs) - self.reward_spec = Unbounded(1) - self.full_done_specs = Composite(done=Unbounded(1, dtype=torch.bool)) - super().__init__(device=device) + try: + assert ccollector._use_buffers + assert d.names[-1] == "time" - def _step( - self, - tensordict: TensorDictBase, - ) -> TensorDictBase: - a = tensordict["action"] - if self.device is not None: - assert a.device == self.device - out = tensordict.empty() - out["observation"] = tensordict["observation"] + ( - a - tensordict["observation"] - ) - out["reward"] = torch.zeros((1,), device=self.device) - out["done"] = torch.zeros((1,), device=self.device, dtype=torch.bool) - return out + with pytest.raises(AssertionError): + assert_allclose_td(b1c, b2c) - def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: - return self.full_done_specs.zeros().update(self.observation_spec.zeros()) + assert_allclose_td(b1c, b1) + assert_allclose_td(b2c, b2) + finally: + ccollector.shutdown() + del ccollector - def _set_seed(self, seed: Optional[int]): - return seed + @pytest.mark.parametrize("num_env", [1, 2]) + @pytest.mark.parametrize("env_name", ["vec", "conv"]) + def test_concurrent_collector_seed(self, num_env, env_name, seed=100): + if num_env == 1: - @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device") - @pytest.mark.parametrize("env_device", ["cuda:0", "cpu"]) - @pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"]) - @pytest.mark.parametrize("no_cuda_sync", [True, False]) - def test_no_synchronize(self, env_device, storing_device, no_cuda_sync): - """Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted.""" - should_raise = not no_cuda_sync - should_raise = should_raise & ( - (env_device == "cpu") or (storing_device == "cpu") + def env_fn(): + env = make_make_env(env_name)() + return env + + else: + + def env_fn(): + env = ParallelEnv( + num_workers=num_env, create_env_fn=make_make_env(env_name) + ) + return env + + policy = make_policy(env_name) + + torch.manual_seed(0) + np.random.seed(0) + ccollector = aSyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={}, + policy=policy, + frames_per_batch=20, + max_frames_per_traj=20, + total_frames=300, ) - with patch("torch.cuda.synchronize") as mock_synchronize, pytest.raises( - AssertionError, match="Expected 'synchronize' to not have been called." - ) if should_raise else contextlib.nullcontext(): + try: + ccollector.set_seed(seed) + for i, data in enumerate(ccollector): + if i == 0: + b1 = data + ccollector.set_seed(seed) + elif i == 1: + b2 = data + elif i == 2: + b3 = data + else: + break + assert_allclose_td(b1, b2) + with pytest.raises(AssertionError): + assert_allclose_td(b1, b3) + finally: + ccollector.shutdown() + + @pytest.mark.parametrize( + "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] + ) + def test_env_that_errors(self, ctype): + make_env = EnvThatErrorsAfter10Iters + policy = RandomPolicy(make_env().action_spec) + if ctype is SyncDataCollector: collector = SyncDataCollector( - create_env_fn=functools.partial( - self.GoesThroughEnv, n_obs=1000, device=None - ), - policy=self.CudaPolicy(n_obs=1000), - frames_per_batch=100, - total_frames=1000, - env_device=env_device, - storing_device=storing_device, - policy_device="cuda:0", - no_cuda_sync=no_cuda_sync, + make_env, policy=policy, frames_per_batch=30, total_frames=60 ) - assert collector.env.device == torch.device(env_device) - i = 0 - for d in collector: - for _d in d.unbind(0): - u = _d["observation"].unique() - assert u.numel() == 1, i - assert u == i, i - i += 1 - u = _d["next", "observation"].unique() - assert u.numel() == 1, i - assert u == i, i - mock_synchronize.assert_not_called() + else: + collector = ctype( + [make_env, make_env], + policy=policy, + frames_per_batch=30, + total_frames=60, + ) + with pytest.raises(RuntimeError): + for _ in collector: + break + @retry(AssertionError, tries=10, delay=0) + @pytest.mark.skipif(IS_FB, reason="Not compatible with fbcode") + @pytest.mark.parametrize("to", [3, 10]) + @pytest.mark.parametrize( + "collector_cls", ["MultiSyncDataCollector", "MultiaSyncDataCollector"] + ) + def test_env_that_waits(self, to, collector_cls): + # Tests that the collector fails if the MAX_IDLE_COUNT TensorDict: + return self.full_observation_spec.zero().update(self.full_done_spec.zero()) + + def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict: + time.sleep(1) + return ( + self.full_observation_spec.zero() + .update(self.full_done_spec.zero()) + .update(self.full_reward_spec.zero()) + ) + + def _set_seed(self, seed: Optional[int]): + ... + +if __name__ == "__main__": + policy = RandomPolicy(EnvThatWaitsFor1Sec().action_spec) + c = {collector_cls}([EnvThatWaitsFor1Sec], policy=policy, total_frames=15, frames_per_batch=5) + for d in c: + break + c.shutdown() +""" + result = subprocess.run( + ["python", "-c", script], capture_output=True, text=True + ) + # This errors if the timeout is 5 secs, not 15 + assert result.returncode == int( + to == 3 + ), f"Test failed with output: {result.stdout}" + + @pytest.mark.parametrize( + "collector_class", + [ + functools.partial(MultiSyncDataCollector, cat_results="stack"), + MultiaSyncDataCollector, + SyncDataCollector, + ], + ) + @pytest.mark.parametrize("exclude", [True, False]) + @pytest.mark.parametrize( + "out_key", ["_dummy", ("out", "_dummy"), ("_out", "dummy")] + ) + def test_excluded_keys(self, collector_class, exclude, out_key): + if not exclude and collector_class is not SyncDataCollector: + pytest.skip("defining _exclude_private_keys is not possible") + + def make_env(): + return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker()) + + dummy_env = make_env() + obs_spec = dummy_env.observation_spec["observation"] + policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1]) + policy = TensorDictModule( + policy_module, in_keys=["observation"], out_keys=["action"] + ) + copier = TensorDictModule( + lambda x: x, in_keys=["observation"], out_keys=[out_key] + ) + policy_explore = TensorDictSequential( + policy, + copier, + OrnsteinUhlenbeckProcessModule( + spec=Composite({key: None for key in policy.out_keys}) + ), + ) + + collector_kwargs = { + "create_env_fn": make_env, + "policy": policy_explore, + "frames_per_batch": 30, + "total_frames": -1, + } + if collector_class is not SyncDataCollector: + collector_kwargs["create_env_fn"] = [ + collector_kwargs["create_env_fn"] for _ in range(3) + ] + + collector = collector_class(**collector_kwargs) + collector._exclude_private_keys = exclude + for b in collector: + keys = set(b.keys()) + if exclude: + assert not any(key.startswith("_") for key in keys) + assert out_key not in b.keys(True, True) + else: + assert any(key.startswith("_") for key in keys) + assert out_key in b.keys(True, True) + break + collector.shutdown() + dummy_env.close() + del collector + + @pytest.mark.parametrize("env_class", [CountingEnv, CountingBatchedEnv]) + def test_initial_obs_consistency(self, env_class, seed=1): + # non regression test on #938 + torch.manual_seed(seed) + start_val = 4 + if env_class == CountingEnv: + num_envs = 1 + env = CountingEnv(device="cpu", max_steps=8, start_val=start_val) + max_steps = 8 + elif env_class == CountingBatchedEnv: + num_envs = 2 + env = CountingBatchedEnv( + device="cpu", + batch_size=[num_envs], + max_steps=torch.arange(num_envs) + 17, + start_val=torch.ones([num_envs]) * start_val, + ) + max_steps = env.max_steps.max().item() + env.set_seed(seed) + policy = lambda tensordict: tensordict.set( + "action", torch.ones(tensordict.shape, dtype=torch.int) + ) + collector = SyncDataCollector( + create_env_fn=env, + policy=policy, + frames_per_batch=((max_steps - 3) * 2 + 2) + * num_envs, # at least two episodes + split_trajs=False, + total_frames=-1, + ) + for _d in collector: + break + obs = _d["observation"].squeeze() + if env_class == CountingEnv: + arange_0 = start_val + torch.arange(max_steps - 3) + arange = start_val + torch.arange(2) + expected = torch.cat([arange_0, arange_0, arange]) + else: + # the first env has a shorter horizon than the second + arange_0 = start_val + torch.arange(max_steps - 3 - 1) + arange = start_val + torch.arange(start_val) + expected_0 = torch.cat([arange_0, arange_0, arange]) + arange_0 = start_val + torch.arange(max_steps - 3) + arange = start_val + torch.arange(2) + expected_1 = torch.cat([arange_0, arange_0, arange]) + expected = torch.stack([expected_0, expected_1]) + assert torch.allclose(obs, expected.to(obs.dtype)) + collector.shutdown() + del collector + def test_maxframes_error(self): + env = TransformedEnv(CountingEnv(), StepCounter(2)) + _ = SyncDataCollector( + env, + RandomPolicy(env.action_spec), + total_frames=10_000, + frames_per_batch=1000, + ) + with pytest.raises(ValueError): + _ = SyncDataCollector( + env, + RandomPolicy(env.action_spec), + total_frames=10_000, + frames_per_batch=1000, + max_frames_per_traj=2, + ) + + @pytest.mark.filterwarnings( + "error::UserWarning", "ignore:Tensordict is registered in PyTree:UserWarning" + ) + @pytest.mark.parametrize( + "collector_type", + [ + SyncDataCollector, + MultiaSyncDataCollector, + functools.partial(MultiSyncDataCollector, cat_results="stack"), + ], + ) + def test_no_deepcopy_policy(self, collector_type): + # Tests that the collector instantiation does not make a deepcopy of the policy if not necessary. + # + # The only situation where we want to deepcopy the policy is when the policy_device differs from the actual device + # of the policy. This can only be checked if the policy is an nn.Module and any of the params is not on the desired + # device. + # + # If the policy is not a nn.Module or has no parameter, policy_device should warn (we don't know what to do but we + # can trust that the user knows what to do). + + # warnings.warn("Tensordict is registered in PyTree", category=UserWarning) + + shared_device = torch.device("cpu") + if torch.cuda.is_available(): + original_device = torch.device("cuda:0") + elif torch.mps.is_available(): + original_device = torch.device("mps") + else: + pytest.skip("No GPU or MPS device") + + def make_policy(device=None, nn_module=True): + if nn_module: + return TensorDictModule( + nn.Linear(7, 7, device=device), + in_keys=["observation"], + out_keys=["action"], + ) + policy = make_policy(device=device) + return CloudpickleWrapper(policy) -@pytest.mark.parametrize("num_env", [1, 2]) -@pytest.mark.parametrize("env_name", ["conv", "vec"]) -def test_concurrent_collector_consistency(num_env, env_name, seed=40): - if num_env == 1: + def make_and_test_policy( + policy, + policy_device=None, + env_device=None, + device=None, + trust_policy=None, + ): + # make sure policy errors when copied + + policy.__deepcopy__ = __deepcopy_error__ + envs = ContinuousActionVecMockEnv(device=env_device) + if collector_type is not SyncDataCollector: + envs = [envs, envs] + c = collector_type( + envs, + policy=policy, + total_frames=1000, + frames_per_batch=10, + policy_device=policy_device, + env_device=env_device, + device=device, + trust_policy=trust_policy, + ) + for _ in c: + return - def env_fn(seed): - env = make_make_env(env_name)() - env.set_seed(seed) - return env + # Simplest use cases + policy = make_policy() + make_and_test_policy(policy) - else: + if collector_type is SyncDataCollector or original_device.type != "mps": + # mps cannot be shared + policy = make_policy(device=original_device) + make_and_test_policy(policy, env_device=original_device) - def env_fn(seed): - env = ParallelEnv( - num_workers=num_env, - create_env_fn=make_make_env(env_name), - create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], + if collector_type is SyncDataCollector or original_device.type != "mps": + policy = make_policy(device=original_device) + make_and_test_policy( + policy, policy_device=original_device, env_device=original_device ) - return env - policy = make_policy(env_name) + # a deepcopy must occur when the policy_device differs from the actual device + with pytest.raises(RuntimeError, match="deepcopy not allowed"): + policy = make_policy(device=original_device) + make_and_test_policy( + policy, policy_device=shared_device, env_device=shared_device + ) - collector = SyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=20, - max_frames_per_traj=2000, - total_frames=20000, - device="cpu", - ) - try: - assert collector._use_buffers - for i, d in enumerate(collector): - if i == 0: - b1 = d - elif i == 1: - b2 = d - else: - break - assert d.names[-1] == "time" - with pytest.raises(AssertionError): - assert_allclose_td(b1, b2) - finally: - collector.shutdown() + # a deepcopy must occur when device differs from the actual device + with pytest.raises(RuntimeError, match="deepcopy not allowed"): + policy = make_policy(device=original_device) + make_and_test_policy(policy, device=shared_device) - ccollector = aSyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=20, - max_frames_per_traj=2000, - total_frames=20000, - ) - for i, d in enumerate(ccollector): - if i == 0: - b1c = d - elif i == 1: - b2c = d - else: - break + # If the policy is not an nn.Module, we can't cast it to device, so we assume that the policy device + # is there to inform us + substitute_device = ( + original_device if torch.cuda.is_available() else torch.device("cpu") + ) + policy = make_policy(substitute_device, nn_module=False) + with pytest.warns(UserWarning): + make_and_test_policy( + policy, policy_device=substitute_device, env_device=substitute_device + ) + # For instance, if the env is on CPU, knowing the policy device helps with casting stuff on the right device + with pytest.warns(UserWarning): + make_and_test_policy( + policy, policy_device=substitute_device, env_device=shared_device + ) + make_and_test_policy( + policy, + policy_device=substitute_device, + env_device=shared_device, + trust_policy=True, + ) - try: - assert ccollector._use_buffers - assert d.names[-1] == "time" + # If there is no policy_device, we assume that the user is doing things right too but don't warn + if collector_type is SyncDataCollector or original_device.type != "mps": + policy = make_policy(original_device, nn_module=False) + make_and_test_policy(policy, env_device=original_device) - with pytest.raises(AssertionError): - assert_allclose_td(b1c, b2c) + # If the policy is a CudaGraphModule, we know it's on cuda - no need to warn + if torch.cuda.is_available() and collector_type is SyncDataCollector: + policy = make_policy(original_device) + cudagraph_policy = CudaGraphModule(policy) + make_and_test_policy( + cudagraph_policy, + policy_device=original_device, + env_device=shared_device, + ) - assert_allclose_td(b1c, b1) - assert_allclose_td(b2c, b2) - finally: - ccollector.shutdown() - del ccollector + @pytest.mark.parametrize( + "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] + ) + def test_no_stopiteration(self, ctype): + # Tests that there is no StopIteration raised and that the length of the collector is properly set + if ctype is SyncDataCollector: + envs = SerialEnv(16, CountingEnv) + else: + envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)] + collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300) + try: + c_iter = iter(collector) + assert len(collector) == 2 + for i in range(len(collector)): # noqa: B007 + c = next(c_iter) + assert c is not None + assert i == 1 + finally: + collector.shutdown() + del collector -@pytest.mark.skipif(not _has_gym, reason="gym library is not installed") -@pytest.mark.parametrize("parallel", [False, True]) -@pytest.mark.parametrize( - "constr", - [ - functools.partial(split_trajectories, prefix="collector"), - functools.partial(split_trajectories), - functools.partial(split_trajectories, trajectory_key=("collector", "traj_ids")), - ], -) -def test_collector_env_reset(constr, parallel): - torch.manual_seed(0) + def test_policy_with_mask(self): + env = CountingBatchedEnv( + start_val=torch.tensor(10), max_steps=torch.tensor(1e5) + ) - def make_env(): - # This is currently necessary as the methods in GymWrapper may have mismatching backend - # versions. - with set_gym_backend(gym_backend()): - return TransformedEnv(GymEnv(PONG_VERSIONED(), frame_skip=4), StepCounter()) + def policy(td): + obs = td.get("observation") + # This policy cannot work with obs all 0s + if not obs.any(): + raise AssertionError + action = obs.clone() + td.set("action", action) + return td - if parallel: - env = ParallelEnv(2, make_env) - else: - env = SerialEnv(2, make_env) - try: - # env = SerialEnv(2, lambda: GymEnv("CartPole-v1", frame_skip=4)) - env.set_seed(0) collector = SyncDataCollector( - env, - policy=None, - total_frames=10001, - frames_per_batch=10000, - split_trajs=False, + env, policy=policy, frames_per_batch=10, total_frames=20 ) - for _data in collector: + for _ in collector: break - steps = _data["next", "step_count"][..., 1:, :] - done = _data["next", "done"][..., :-1, :] - # we don't want just one done - assert done.sum() > 3 - # check that after a done, the next step count is always 1 - assert (steps[done] == 1).all() - # check that if the env is not done, the next step count is > 1 - assert (steps[~done] > 1).all() - # check that if step is 1, then the env was done before - assert (steps == 1)[done].all() - # check that split traj has a minimum total reward of -21 (for pong only) - _data = constr(_data) - assert _data["next", "reward"].sum(-2).min() == -21 - finally: - env.close() - del env + collector.shutdown() + @retry(AssertionError, tries=10, delay=0) + @pytest.mark.parametrize("policy_device", [None, *get_available_devices()]) + @pytest.mark.parametrize("env_device", [None, *get_available_devices()]) + @pytest.mark.parametrize("storing_device", [None, *get_available_devices()]) + @pytest.mark.parametrize("parallel", [False, True]) + @pytest.mark.parametrize("share_individual_td", [False, True]) + def test_reset_heterogeneous_envs( + self, + policy_device: torch.device, + env_device: torch.device, + storing_device: torch.device, + parallel, + share_individual_td, + ): + if ( + policy_device is not None + and policy_device.type == "cuda" + and env_device is None + ): + env_device = torch.device("cpu") # explicit mapping + elif ( + env_device is not None + and env_device.type == "cuda" + and policy_device is None + ): + policy_device = torch.device("cpu") + env1 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(2)) + env2 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(3)) + if parallel: + cls = ParallelEnv + else: + cls = SerialEnv + env = cls( + 2, [env1, env2], device=env_device, share_individual_td=share_individual_td + ) + collector = SyncDataCollector( + env, + RandomPolicy(env.action_spec), + total_frames=10_000, + frames_per_batch=100, + policy_device=policy_device, + env_device=env_device, + storing_device=storing_device, + ) + try: + for data in collector: # noqa: B007 + break + data_device = storing_device if storing_device is not None else env_device + assert ( + data[0]["next", "truncated"].squeeze() + == torch.tensor([False, True], device=data_device).repeat(25)[:50] + ).all(), data[0]["next", "truncated"] + assert ( + data[1]["next", "truncated"].squeeze() + == torch.tensor([False, False, True], device=data_device).repeat(17)[ + :50 + ] + ).all(), data[1]["next", "truncated"][:10] + finally: + collector.shutdown() + del collector -@pytest.mark.parametrize( - "break_when_any_done,break_when_all_done", - [[True, False], [False, True], [False, False]], -) -@pytest.mark.parametrize("n_envs", [1, 4]) -def test_collector_outplace_policy(n_envs, break_when_any_done, break_when_all_done): - def policy_inplace(td): - td.set("action", torch.ones(td.shape + (1,))) - return td + @pytest.mark.parametrize( + "collector_cls", + [SyncDataCollector, MultiSyncDataCollector, MultiaSyncDataCollector], + ) + def test_set_truncated(self, collector_cls): + env_fn = lambda: TransformedEnv( + NestedCountingEnv(), InitTracker() + ).add_truncated_keys() + env = env_fn() + policy = CloudpickleWrapper(env.rand_action) + if collector_cls == SyncDataCollector: + collector = collector_cls( + env, + policy=policy, + frames_per_batch=20, + total_frames=-1, + set_truncated=True, + trust_policy=True, + ) + else: + collector = collector_cls( + [env_fn, env_fn], + policy=policy, + frames_per_batch=20, + total_frames=-1, + cat_results="stack", + set_truncated=True, + trust_policy=True, + ) + try: + for data in collector: + assert data[..., -1]["next", "data", "truncated"].all() + break + finally: + collector.shutdown() + del collector - def policy_outplace(td): - return td.empty().set("action", torch.ones(td.shape + (1,))) + @pytest.mark.parametrize("frames_per_batch", [200, 10]) + @pytest.mark.parametrize("num_env", [1, 2]) + @pytest.mark.parametrize("env_name", ["vec"]) + def test_split_trajs(self, num_env, env_name, frames_per_batch, seed=5): + if num_env == 1: - if n_envs == 1: - env = CountingEnv(10) - else: - env = SerialEnv( - n_envs, - [functools.partial(CountingEnv, 10 + i) for i in range(n_envs)], - ) - env.reset() - c_inplace = SyncDataCollector( - env, policy_inplace, frames_per_batch=10, total_frames=100 - ) - d_inplace = torch.cat(list(c_inplace), dim=0) - env.reset() - c_outplace = SyncDataCollector( - env, policy_outplace, frames_per_batch=10, total_frames=100 - ) - d_outplace = torch.cat(list(c_outplace), dim=0) - assert_allclose_td(d_inplace, d_outplace) + def env_fn(seed): + env = MockSerialEnv(device="cpu") + env.set_seed(seed) + return env + else: -# Deprecated reset_when_done -# @pytest.mark.parametrize("num_env", [1, 2]) -# @pytest.mark.parametrize("env_name", ["vec"]) -# def test_collector_done_persist(num_env, env_name, seed=5): -# if num_env == 1: -# -# def env_fn(seed): -# env = MockSerialEnv(device="cpu") -# env.set_seed(seed) -# return env -# -# else: -# -# def env_fn(seed): -# def make_env(seed): -# env = MockSerialEnv(device="cpu") -# env.set_seed(seed) -# return env -# -# env = ParallelEnv( -# num_workers=num_env, -# create_env_fn=make_env, -# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], -# ) -# env.set_seed(seed) -# return env -# -# policy = make_policy(env_name) -# -# collector = SyncDataCollector( -# create_env_fn=env_fn, -# create_env_kwargs={"seed": seed}, -# policy=policy, -# frames_per_batch=200 * num_env, -# max_frames_per_traj=2000, -# total_frames=20000, -# device="cpu", -# reset_when_done=False, -# ) -# for _, d in enumerate(collector): # noqa -# break -# -# assert (d["done"].sum(-2) >= 1).all() -# assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1 -# -# del collector + def env_fn(seed): + def make_env(seed): + env = MockSerialEnv(device="cpu") + env.set_seed(seed) + return env + + env = SerialEnv( + num_workers=num_env, + create_env_fn=make_env, + create_env_kwargs=[ + {"seed": i} for i in range(seed, seed + num_env) + ], + ) + env.set_seed(seed) + return env + policy = make_policy(env_name) -@pytest.mark.parametrize("frames_per_batch", [200, 10]) -@pytest.mark.parametrize("num_env", [1, 2]) -@pytest.mark.parametrize("env_name", ["vec"]) -def test_split_trajs(num_env, env_name, frames_per_batch, seed=5): - if num_env == 1: + collector = SyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=frames_per_batch * num_env, + max_frames_per_traj=2000, + total_frames=20000, + device="cpu", + reset_when_done=True, + split_trajs=True, + ) + for _, d in enumerate(collector): # noqa + break - def env_fn(seed): - env = MockSerialEnv(device="cpu") - env.set_seed(seed) - return env + assert d.ndimension() == 2 + assert d["collector", "mask"].shape == d.shape + assert d["next", "step_count"].shape == d["next", "done"].shape + assert d["collector", "traj_ids"].shape == d.shape + for traj in d.unbind(0): + assert traj["collector", "traj_ids"].unique().numel() == 1 + assert ( + traj["next", "step_count"][1:] - traj["next", "step_count"][:-1] == 1 + ).all() - else: + del collector - def env_fn(seed): - def make_env(seed): - env = MockSerialEnv(device="cpu") + @pytest.mark.parametrize("num_env", [1, 2]) + @pytest.mark.parametrize( + "collector_class", + [ + SyncDataCollector, + ], + ) # aSyncDataCollector]) + @pytest.mark.parametrize( + "env_name", ["vec"] + ) # 1226: removing "conv" for efficiency + def test_traj_len_consistency(self, num_env, env_name, collector_class, seed=100): + """Tests that various frames_per_batch lead to the same results.""" + + if num_env == 1: + + def env_fn(seed): + env = make_make_env(env_name)() env.set_seed(seed) return env - env = SerialEnv( - num_workers=num_env, - create_env_fn=make_env, - create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], - ) - env.set_seed(seed) - return env + else: - policy = make_policy(env_name) - - collector = SyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=frames_per_batch * num_env, - max_frames_per_traj=2000, - total_frames=20000, - device="cpu", - reset_when_done=True, - split_trajs=True, - ) - for _, d in enumerate(collector): # noqa - break + def env_fn(seed): + env = ParallelEnv( + num_workers=num_env, create_env_fn=make_make_env(env_name) + ) + env.set_seed(seed) + return env - assert d.ndimension() == 2 - assert d["collector", "mask"].shape == d.shape - assert d["next", "step_count"].shape == d["next", "done"].shape - assert d["collector", "traj_ids"].shape == d.shape - for traj in d.unbind(0): - assert traj["collector", "traj_ids"].unique().numel() == 1 - assert ( - traj["next", "step_count"][1:] - traj["next", "step_count"][:-1] == 1 - ).all() + max_frames_per_traj = 20 - del collector + policy = make_policy(env_name) + collector1 = collector_class( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=1 * num_env, + max_frames_per_traj=2000, + total_frames=2 * num_env * max_frames_per_traj, + device="cpu", + ) + collector1.set_seed(seed) + count = 0 + data1 = [] + for d in collector1: + data1.append(d) + count += d.shape[-1] + if count > max_frames_per_traj: + break -# TODO: design a test that ensures that collectors are interrupted even if __del__ is not called -# @pytest.mark.parametrize("should_shutdown", [True, False]) -# def test_shutdown_collector(should_shutdown, num_env=3, env_name="vec", seed=40): -# def env_fn(seed): -# env = ParallelEnv( -# num_workers=num_env, -# create_env_fn=make_make_env(env_name), -# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], -# ) -# return env -# -# policy = make_policy(env_name) -# -# ccollector = aSyncDataCollector( -# create_env_fn=env_fn, -# create_env_kwargs={"seed": seed}, -# policy=policy, -# frames_per_batch=20, -# max_frames_per_traj=2000, -# total_frames=20000, -# ) -# for i, d in enumerate(ccollector): -# if i == 0: -# b1c = d -# elif i == 1: -# b2c = d -# else: -# break -# with pytest.raises(AssertionError): -# assert_allclose_td(b1c, b2c) -# -# if should_shutdown: -# ccollector.shutdown() + data1 = torch.cat(data1, d.ndim - 1) + data1 = data1[..., :max_frames_per_traj] + collector1.shutdown() + del collector1 -@pytest.mark.parametrize("num_env", [1, 2]) -# 1226: for efficiency, we just test vec, not "conv" -@pytest.mark.parametrize("env_name", ["vec"]) -def test_collector_batch_size( - num_env, env_name, seed=100, num_workers=2, frames_per_batch=20 -): - """Tests that there are 'frames_per_batch' frames in each batch of a collection.""" - if num_env == 3 and IS_WINDOWS: - pytest.skip("Test timeout (> 10 min) on CI pipeline Windows machine with GPU") - if num_env == 1: + collector10 = collector_class( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=10 * num_env, + max_frames_per_traj=2000, + total_frames=2 * num_env * max_frames_per_traj, + device="cpu", + ) + collector10.set_seed(seed) + count = 0 + data10 = [] + for d in collector10: + data10.append(d) + count += d.shape[-1] + if count > max_frames_per_traj: + break - def env_fn(): - env = make_make_env(env_name)() - return env + data10 = torch.cat(data10, data1.ndim - 1) + data10 = data10[..., :max_frames_per_traj] - else: + collector10.shutdown() + del collector10 - def env_fn(): - # 1226: For efficiency, we don't use Parallel but Serial - # env = ParallelEnv( - env = SerialEnv(num_workers=num_env, create_env_fn=make_make_env(env_name)) - return env + collector20 = collector_class( + create_env_fn=env_fn, + create_env_kwargs={"seed": seed}, + policy=policy, + frames_per_batch=20 * num_env, + max_frames_per_traj=2000, + total_frames=2 * num_env * max_frames_per_traj, + device="cpu", + ) + collector20.set_seed(seed) + count = 0 + data20 = [] + for d in collector20: + data20.append(d) + count += d.shape[-1] + if count > max_frames_per_traj: + break - policy = make_policy(env_name) + collector20.shutdown() + del collector20 - torch.manual_seed(0) - np.random.seed(0) + data20 = torch.cat(data20, data1.ndim - 1) + data20 = data20[..., :max_frames_per_traj] - ccollector = MultiaSyncDataCollector( - create_env_fn=[env_fn for _ in range(num_workers)], - policy=policy, - frames_per_batch=frames_per_batch, - max_frames_per_traj=1000, - total_frames=frames_per_batch * 100, - ) - try: - ccollector.set_seed(seed) - for i, b in enumerate(ccollector): - assert b.numel() == -(-frames_per_batch // num_env) * num_env - if i == 5: - break - assert b.names[-1] == "time" - finally: - ccollector.shutdown() + assert_allclose_td(data1, data20) + assert_allclose_td(data10, data20) - ccollector = MultiSyncDataCollector( - create_env_fn=[env_fn for _ in range(num_workers)], - policy=policy, - frames_per_batch=frames_per_batch, - max_frames_per_traj=1000, - total_frames=frames_per_batch * 100, - cat_results="stack", - ) - try: - ccollector.set_seed(seed) - for i, b in enumerate(ccollector): - assert ( - b.numel() - == -(-frames_per_batch // num_env // num_workers) - * num_env - * num_workers - ) - if i == 5: - break - assert b.names[-1] == "time" - finally: - ccollector.shutdown() - del ccollector + @pytest.mark.parametrize("use_async", [False, True]) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") + def test_update_weights(self, use_async): + def create_env(): + return ContinuousActionVecMockEnv() + n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1] + policy = SafeModule( + torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"] + ) + policy(create_env().reset()) -@pytest.mark.parametrize("num_env", [1, 2]) -@pytest.mark.parametrize("env_name", ["vec", "conv"]) -def test_concurrent_collector_seed(num_env, env_name, seed=100): - if num_env == 1: + collector_class = ( + MultiSyncDataCollector if not use_async else MultiaSyncDataCollector + ) + collector = collector_class( + [create_env] * 3, + policy=policy, + device=[torch.device("cuda:0")] * 3, + storing_device=[torch.device("cuda:0")] * 3, + frames_per_batch=20, + cat_results="stack", + ) + # collect state_dict + state_dict = collector.state_dict() + policy_state_dict = policy.state_dict() + for worker in range(3): + for k in state_dict[f"worker{worker}"]["policy_state_dict"]: + torch.testing.assert_close( + state_dict[f"worker{worker}"]["policy_state_dict"][k], + policy_state_dict[k].cpu(), + ) - def env_fn(): - env = make_make_env(env_name)() - return env + # change policy weights + for p in policy.parameters(): + p.data += torch.randn_like(p) + + # collect state_dict + state_dict = collector.state_dict() + policy_state_dict = policy.state_dict() + # check they don't match + for worker in range(3): + for k in state_dict[f"worker{worker}"]["policy_state_dict"]: + with pytest.raises(AssertionError): + torch.testing.assert_close( + state_dict[f"worker{worker}"]["policy_state_dict"][k], + policy_state_dict[k].cpu(), + ) - else: + # update weights + collector.update_policy_weights_() - def env_fn(): - env = ParallelEnv( - num_workers=num_env, create_env_fn=make_make_env(env_name) - ) - return env + # collect state_dict + state_dict = collector.state_dict() + policy_state_dict = policy.state_dict() + for worker in range(3): + for k in state_dict[f"worker{worker}"]["policy_state_dict"]: + torch.testing.assert_close( + state_dict[f"worker{worker}"]["policy_state_dict"][k], + policy_state_dict[k].cpu(), + ) - policy = make_policy(env_name) + collector.shutdown() + del collector - torch.manual_seed(0) - np.random.seed(0) - ccollector = aSyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={}, - policy=policy, - frames_per_batch=20, - max_frames_per_traj=20, - total_frames=300, - ) - try: - ccollector.set_seed(seed) - for i, data in enumerate(ccollector): - if i == 0: - b1 = data - ccollector.set_seed(seed) - elif i == 1: - b2 = data - elif i == 2: - b3 = data - else: - break - assert_allclose_td(b1, b2) - with pytest.raises(AssertionError): - assert_allclose_td(b1, b3) - finally: - ccollector.shutdown() +class TestCollectorDevices: + class DeviceLessEnv(EnvBase): + # receives data on cpu, outputs on gpu -- tensordict has no device + def __init__(self, default_device): + self.default_device = default_device + super().__init__(device=None) + self.observation_spec = Composite( + observation=Unbounded((), device=default_device) + ) + self.reward_spec = Unbounded(1, device=default_device) + self.full_done_spec = Composite( + done=Unbounded(1, dtype=torch.bool, device=self.default_device), + truncated=Unbounded(1, dtype=torch.bool, device=self.default_device), + terminated=Unbounded(1, dtype=torch.bool, device=self.default_device), + ) + self.action_spec = Unbounded((), device=None) + assert self.device is None + assert self.full_observation_spec is not None + assert self.full_done_spec is not None + assert self.full_state_spec is not None + assert self.full_action_spec is not None + assert self.full_reward_spec is not None -@pytest.mark.parametrize("num_env", [1, 2]) -@pytest.mark.parametrize("env_name", ["conv", "vec"]) -def test_collector_consistency(num_env, env_name, seed=100): - """Tests that a rollout gathered with env.rollout matches one gathered with the collector.""" - if num_env == 1: + def _step(self, tensordict): + assert tensordict.device is None + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "reward": torch.zeros((1,)), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=None, + ) - def env_fn(seed): - env = make_make_env(env_name)() - env.set_seed(seed) - return env + def _reset(self, tensordict=None): + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=None, + ) - else: + def _set_seed(self, seed: int | None = None): + return seed - def env_fn(seed): - env = ParallelEnv( - num_workers=num_env, - create_env_fn=make_make_env(env_name), - create_env_kwargs=[{"seed": s} for s in generate_seeds(seed, num_env)], + class EnvWithDevice(EnvBase): + def __init__(self, default_device): + self.default_device = default_device + super().__init__(device=self.default_device) + self.observation_spec = Composite( + observation=Unbounded((), device=self.default_device) ) - return env + self.reward_spec = Unbounded(1, device=self.default_device) + self.full_done_spec = Composite( + done=Unbounded(1, dtype=torch.bool, device=self.default_device), + truncated=Unbounded(1, dtype=torch.bool, device=self.default_device), + terminated=Unbounded(1, dtype=torch.bool, device=self.default_device), + device=self.default_device, + ) + self.action_spec = Unbounded((), device=self.default_device) + assert self.device == _make_ordinal_device( + torch.device(self.default_device) + ) + assert self.full_observation_spec is not None + assert self.full_done_spec is not None + assert self.full_state_spec is not None + assert self.full_action_spec is not None + assert self.full_reward_spec is not None - policy = make_policy(env_name) + def _step(self, tensordict): + assert tensordict.device == _make_ordinal_device( + torch.device(self.default_device) + ) + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "reward": torch.zeros((1,)), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=self.default_device, + ) - torch.manual_seed(0) - np.random.seed(0) + def _reset(self, tensordict=None): + with torch.device(self.default_device): + return TensorDict( + { + "observation": torch.zeros(()), + "done": torch.zeros((1,), dtype=torch.bool), + "terminated": torch.zeros((1,), dtype=torch.bool), + "truncated": torch.zeros((1,), dtype=torch.bool), + }, + batch_size=[], + device=self.default_device, + ) - # Get a single rollout with dummypolicy - env = env_fn(seed) - env = TransformedEnv(env, StepCounter(20)) - rollout1a = env.rollout(policy=policy, max_steps=50, auto_reset=True) - env.set_seed(seed) - rollout1b = env.rollout(policy=policy, max_steps=50, auto_reset=True) - rollout2 = env.rollout(policy=policy, max_steps=50, auto_reset=True) - try: - assert_allclose_td(rollout1a, rollout1b) - with pytest.raises(AssertionError): - assert_allclose_td(rollout1a, rollout2) - finally: - env.close() - - collector = SyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=20 * num_env, - max_frames_per_traj=20, - total_frames=200, - device="cpu", - ) - collector_iter = iter(collector) - b1 = next(collector_iter) - b2 = next(collector_iter) - - # if num_env == 1: - # # rollouts collected through DataCollector are padded using pad_sequence, which introduces a first dimension - # rollout1a = rollout1a.unsqueeze(0) - try: - with pytest.raises(AssertionError): - assert_allclose_td(b1, b2) - assert ( - rollout1a.batch_size == b1.batch_size - ), f"got batch_size {rollout1a.batch_size} and {b1.batch_size}" - assert_allclose_td(rollout1a, b1.select(*rollout1a.keys(True, True))) - finally: - collector.shutdown() + def _set_seed(self, seed: int | None = None): + return seed + class DeviceLessPolicy(TensorDictModuleBase): + in_keys = ["observation"] + out_keys = ["action"] -@pytest.mark.parametrize("num_env", [1, 2]) -@pytest.mark.parametrize( - "collector_class", - [ - SyncDataCollector, - ], -) # aSyncDataCollector]) -@pytest.mark.parametrize("env_name", ["vec"]) # 1226: removing "conv" for efficiency -def test_traj_len_consistency(num_env, env_name, collector_class, seed=100): - """Tests that various frames_per_batch lead to the same results.""" + # receives data on gpu and outputs on cpu + def forward(self, tensordict): + assert tensordict.device is None + return tensordict.set("action", torch.zeros((), device="cpu")) - if num_env == 1: + class PolicyWithDevice(TensorDictModuleBase): + in_keys = ["observation"] + out_keys = ["action"] + # receives and sends data on gpu + default_device = "cuda:0" if torch.cuda.device_count() else "cpu" - def env_fn(seed): - env = make_make_env(env_name)() - env.set_seed(seed) - return env + def forward(self, tensordict): + assert tensordict.device == _make_ordinal_device( + torch.device(self.default_device) + ) + return tensordict.set("action", torch.zeros((), device=self.default_device)) - else: + @pytest.mark.parametrize("main_device", get_default_devices()) + @pytest.mark.parametrize("storing_device", [None, *get_default_devices()]) + def test_output_device(self, main_device, storing_device): - def env_fn(seed): - env = ParallelEnv( - num_workers=num_env, create_env_fn=make_make_env(env_name) - ) - env.set_seed(seed) - return env + # env has no device, policy is strictly on GPU + device = None + env_device = None + policy_device = main_device + env = self.DeviceLessEnv(main_device) + policy = self.PolicyWithDevice() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 + break - max_frames_per_traj = 20 + assert data.device == storing_device - policy = make_policy(env_name) + # env is on cuda, policy has no device + device = None + env_device = main_device + policy_device = None + env = self.EnvWithDevice(main_device) + policy = self.DeviceLessPolicy() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 + break + assert data.device == storing_device - collector1 = collector_class( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=1 * num_env, - max_frames_per_traj=2000, - total_frames=2 * num_env * max_frames_per_traj, - device="cpu", - ) - collector1.set_seed(seed) - count = 0 - data1 = [] - for d in collector1: - data1.append(d) - count += d.shape[-1] - if count > max_frames_per_traj: + # env and policy are on device + device = main_device + env_device = None + policy_device = None + env = self.EnvWithDevice(main_device) + policy = self.PolicyWithDevice() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 break + assert data.device == main_device - data1 = torch.cat(data1, d.ndim - 1) - data1 = data1[..., :max_frames_per_traj] - - collector1.shutdown() - del collector1 - - collector10 = collector_class( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=10 * num_env, - max_frames_per_traj=2000, - total_frames=2 * num_env * max_frames_per_traj, - device="cpu", - ) - collector10.set_seed(seed) - count = 0 - data10 = [] - for d in collector10: - data10.append(d) - count += d.shape[-1] - if count > max_frames_per_traj: + # same but more specific + device = None + env_device = main_device + policy_device = main_device + env = self.EnvWithDevice(main_device) + policy = self.PolicyWithDevice() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 break + assert data.device == main_device - data10 = torch.cat(data10, data1.ndim - 1) - data10 = data10[..., :max_frames_per_traj] - - collector10.shutdown() - del collector10 - - collector20 = collector_class( - create_env_fn=env_fn, - create_env_kwargs={"seed": seed}, - policy=policy, - frames_per_batch=20 * num_env, - max_frames_per_traj=2000, - total_frames=2 * num_env * max_frames_per_traj, - device="cpu", - ) - collector20.set_seed(seed) - count = 0 - data20 = [] - for d in collector20: - data20.append(d) - count += d.shape[-1] - if count > max_frames_per_traj: + # none has a device + device = None + env_device = None + policy_device = None + env = self.DeviceLessEnv(main_device) + policy = self.DeviceLessPolicy() + collector = SyncDataCollector( + env, + policy, + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + frames_per_batch=1, + total_frames=10, + ) + for data in collector: # noqa: B007 break + assert data.device == storing_device - collector20.shutdown() - del collector20 - - data20 = torch.cat(data20, data1.ndim - 1) - data20 = data20[..., :max_frames_per_traj] - - assert_allclose_td(data1, data20) - assert_allclose_td(data10, data20) - - -@pytest.mark.skipif( - sys.version_info >= (3, 11), - reason="Nested spawned multiprocessed is currently failing in python 3.11. " - "See https://github.com/python/cpython/pull/108568 for info and fix.", -) -@pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") -@pytest.mark.parametrize("static_seed", [True, False]) -def test_collector_vecnorm_envcreator(static_seed): - """ - High level test of the following pipeline: - (1) Design a function that creates an environment with VecNorm - (2) Wrap that function in an EnvCreator to instantiate the shared tensordict - (3) Create a ParallelEnv that dispatches this env across workers - (4) Run several ParallelEnv synchronously - The function tests that the tensordict gathered from the workers match at certain moments in time, and that they - are modified after the collector is run for more steps. - - """ - from torchrl.envs.libs.gym import GymEnv - - num_envs = 4 - env_make = EnvCreator( - lambda: TransformedEnv(GymEnv(PENDULUM_VERSIONED()), VecNorm()) - ) - env_make = ParallelEnv(num_envs, env_make) - - policy = RandomPolicy(env_make.action_spec) - num_data_collectors = 2 - c = MultiSyncDataCollector( - [env_make] * num_data_collectors, - policy=policy, - total_frames=int(1e6), - frames_per_batch=200, - cat_results="stack", - ) - - init_seed = 0 - new_seed = c.set_seed(init_seed, static_seed=static_seed) - if static_seed: - assert new_seed == init_seed - else: - assert new_seed != init_seed - - seed = init_seed - for _ in range(num_envs * num_data_collectors): - seed = seed_generator(seed) - if not static_seed: - assert new_seed == seed - else: - assert new_seed != seed - - c_iter = iter(c) - next(c_iter) - next(c_iter) - - s = c.state_dict() - - td1 = ( - TensorDict(s["worker0"]["env_state_dict"]["worker3"]["_extra_state"]) - .unflatten_keys(VecNorm.SEP) - .clone() - ) - td2 = ( - TensorDict(s["worker1"]["env_state_dict"]["worker0"]["_extra_state"]) - .unflatten_keys(VecNorm.SEP) - .clone() - ) - assert (td1 == td2).all() - - next(c_iter) - next(c_iter) - - s = c.state_dict() - - td3 = ( - TensorDict(s["worker0"]["env_state_dict"]["worker3"]["_extra_state"]) - .unflatten_keys(VecNorm.SEP) - .clone() - ) - td4 = ( - TensorDict(s["worker1"]["env_state_dict"]["worker0"]["_extra_state"]) - .unflatten_keys(VecNorm.SEP) - .clone() - ) - assert (td3 == td4).all() - assert (td1 != td4).any() - c.shutdown() - del c - - -@pytest.mark.parametrize("use_async", [False, True]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") -def test_update_weights(use_async): - def create_env(): - return ContinuousActionVecMockEnv() - - n_actions = ContinuousActionVecMockEnv().action_spec.shape[-1] - policy = SafeModule( - torch.nn.LazyLinear(n_actions), in_keys=["observation"], out_keys=["action"] - ) - policy(create_env().reset()) - - collector_class = ( - MultiSyncDataCollector if not use_async else MultiaSyncDataCollector - ) - collector = collector_class( - [create_env] * 3, - policy=policy, - device=[torch.device("cuda:0")] * 3, - storing_device=[torch.device("cuda:0")] * 3, - frames_per_batch=20, - cat_results="stack", - ) - # collect state_dict - state_dict = collector.state_dict() - policy_state_dict = policy.state_dict() - for worker in range(3): - for k in state_dict[f"worker{worker}"]["policy_state_dict"]: - torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], - policy_state_dict[k].cpu(), + class CudaPolicy(TensorDictSequential): + def __init__(self, n_obs): + module = torch.nn.Linear(n_obs, n_obs, device="cuda") + module.weight.data.copy_(torch.eye(n_obs)) + module.bias.data.fill_(0) + m0 = TensorDictModule(module, in_keys=["observation"], out_keys=["hidden"]) + m1 = TensorDictModule( + lambda a: a + 1, in_keys=["hidden"], out_keys=["action"] ) + super().__init__(m0, m1) - # change policy weights - for p in policy.parameters(): - p.data += torch.randn_like(p) - - # collect state_dict - state_dict = collector.state_dict() - policy_state_dict = policy.state_dict() - # check they don't match - for worker in range(3): - for k in state_dict[f"worker{worker}"]["policy_state_dict"]: - with pytest.raises(AssertionError): - torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], - policy_state_dict[k].cpu(), - ) - - # update weights - collector.update_policy_weights_() + class GoesThroughEnv(EnvBase): + def __init__(self, n_obs, device): + self.observation_spec = Composite(observation=Unbounded(n_obs)) + self.action_spec = Unbounded(n_obs) + self.reward_spec = Unbounded(1) + self.full_done_specs = Composite(done=Unbounded(1, dtype=torch.bool)) + super().__init__(device=device) - # collect state_dict - state_dict = collector.state_dict() - policy_state_dict = policy.state_dict() - for worker in range(3): - for k in state_dict[f"worker{worker}"]["policy_state_dict"]: - torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], - policy_state_dict[k].cpu(), + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + a = tensordict["action"] + if self.device is not None: + assert a.device == self.device + out = tensordict.empty() + out["observation"] = tensordict["observation"] + ( + a - tensordict["observation"] ) + out["reward"] = torch.zeros((1,), device=self.device) + out["done"] = torch.zeros((1,), device=self.device, dtype=torch.bool) + return out - collector.shutdown() - del collector + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + return self.full_done_specs.zeros().update(self.observation_spec.zeros()) + def _set_seed(self, seed: Optional[int]): + return seed -@pytest.mark.parametrize( - "collector_class", - [ - functools.partial(MultiSyncDataCollector, cat_results="stack"), - MultiaSyncDataCollector, - SyncDataCollector, - ], -) -@pytest.mark.parametrize("exclude", [True, False]) -@pytest.mark.parametrize("out_key", ["_dummy", ("out", "_dummy"), ("_out", "dummy")]) -def test_excluded_keys(collector_class, exclude, out_key): - if not exclude and collector_class is not SyncDataCollector: - pytest.skip("defining _exclude_private_keys is not possible") - - def make_env(): - return TransformedEnv(ContinuousActionVecMockEnv(), InitTracker()) - - dummy_env = make_env() - obs_spec = dummy_env.observation_spec["observation"] - policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1]) - policy = TensorDictModule( - policy_module, in_keys=["observation"], out_keys=["action"] - ) - copier = TensorDictModule(lambda x: x, in_keys=["observation"], out_keys=[out_key]) - policy_explore = TensorDictSequential( - policy, - copier, - OrnsteinUhlenbeckProcessModule( - spec=Composite({key: None for key in policy.out_keys}) - ), - ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device") + @pytest.mark.parametrize("env_device", ["cuda:0", "cpu"]) + @pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"]) + @pytest.mark.parametrize("no_cuda_sync", [True, False]) + def test_no_synchronize(self, env_device, storing_device, no_cuda_sync): + """Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted.""" + should_raise = not no_cuda_sync + should_raise = should_raise & ( + (env_device == "cpu") or (storing_device == "cpu") + ) + with patch("torch.cuda.synchronize") as mock_synchronize, pytest.raises( + AssertionError, match="Expected 'synchronize' to not have been called." + ) if should_raise else contextlib.nullcontext(): + collector = SyncDataCollector( + create_env_fn=functools.partial( + self.GoesThroughEnv, n_obs=1000, device=None + ), + policy=self.CudaPolicy(n_obs=1000), + frames_per_batch=100, + total_frames=1000, + env_device=env_device, + storing_device=storing_device, + policy_device="cuda:0", + no_cuda_sync=no_cuda_sync, + ) + assert collector.env.device == torch.device(env_device) + i = 0 + for d in collector: + for _d in d.unbind(0): + u = _d["observation"].unique() + assert u.numel() == 1, i + assert u == i, i + i += 1 + u = _d["next", "observation"].unique() + assert u.numel() == 1, i + assert u == i, i + mock_synchronize.assert_not_called() - collector_kwargs = { - "create_env_fn": make_env, - "policy": policy_explore, - "frames_per_batch": 30, - "total_frames": -1, - } - if collector_class is not SyncDataCollector: - collector_kwargs["create_env_fn"] = [ - collector_kwargs["create_env_fn"] for _ in range(3) - ] - - collector = collector_class(**collector_kwargs) - collector._exclude_private_keys = exclude - for b in collector: - keys = set(b.keys()) - if exclude: - assert not any(key.startswith("_") for key in keys) - assert out_key not in b.keys(True, True) - else: - assert any(key.startswith("_") for key in keys) - assert out_key in b.keys(True, True) - break - collector.shutdown() - dummy_env.close() - del collector + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + @pytest.mark.parametrize("storing_device", ["cuda", "cpu"]) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") + def test_collector_device_combinations(self, device, storing_device): + if IS_WINDOWS and PYTHON_3_10 and storing_device == "cuda" and device == "cuda": + pytest.skip("Windows fatal exception: access violation in torch.storage") + def env_fn(seed): + env = make_make_env("conv")() + env.set_seed(seed) + return env -@pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") -@pytest.mark.parametrize( - "collector_class", - [ - SyncDataCollector, - MultiaSyncDataCollector, - functools.partial(MultiSyncDataCollector, cat_results="stack"), - ], -) -@pytest.mark.parametrize("init_random_frames", [0, 50]) # 1226: faster execution -@pytest.mark.parametrize( - "explicit_spec,split_trajs", [[True, True], [False, False]] -) # 1226: faster execution -def test_collector_output_keys( - collector_class, init_random_frames, explicit_spec, split_trajs -): - from torchrl.envs.libs.gym import GymEnv - - out_features = 1 - hidden_size = 12 - total_frames = 200 - frames_per_batch = 20 - num_envs = 3 - - net = LSTMNet( - out_features, - {"input_size": hidden_size, "hidden_size": hidden_size}, - {"out_features": hidden_size}, - ) + policy = dummypolicy_conv() - policy_kwargs = { - "module": net, - "in_keys": ["observation", "hidden1", "hidden2"], - "out_keys": [ - "action", - "hidden1", - "hidden2", - ("next", "hidden1"), - ("next", "hidden2"), - ], - } - if explicit_spec: - hidden_spec = Unbounded((1, hidden_size)) - policy_kwargs["spec"] = Composite( - action=Unbounded(), - hidden1=hidden_spec, - hidden2=hidden_spec, - next=Composite(hidden1=hidden_spec, hidden2=hidden_spec), + collector = SyncDataCollector( + create_env_fn=env_fn, + create_env_kwargs={"seed": 0}, + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + device=device, + storing_device=storing_device, ) + assert collector._use_buffers + batch = next(collector.iterator()) + assert batch.device == _make_ordinal_device(torch.device(storing_device)) + collector.shutdown() - policy = SafeModule(**policy_kwargs) - - env_maker = lambda: GymEnv(PENDULUM_VERSIONED()) - - policy(env_maker().reset()) + collector = MultiSyncDataCollector( + create_env_fn=[ + env_fn, + ], + create_env_kwargs=[ + {"seed": 0}, + ], + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + device=[ + device, + ], + storing_device=[ + storing_device, + ], + cat_results="stack", + ) + batch = next(collector.iterator()) + assert batch.device == _make_ordinal_device(torch.device(storing_device)) + collector.shutdown() - collector_kwargs = { - "create_env_fn": env_maker, - "policy": policy, - "total_frames": total_frames, - "frames_per_batch": frames_per_batch, - "init_random_frames": init_random_frames, - "split_trajs": split_trajs, - } + collector = MultiaSyncDataCollector( + create_env_fn=[ + env_fn, + ], + create_env_kwargs=[ + {"seed": 0}, + ], + policy=policy, + frames_per_batch=20, + max_frames_per_traj=2000, + total_frames=20000, + device=[ + device, + ], + storing_device=[ + storing_device, + ], + ) + batch = next(collector.iterator()) + assert batch.device == _make_ordinal_device(torch.device(storing_device)) + collector.shutdown() + del collector - if collector_class is not SyncDataCollector: - collector_kwargs["create_env_fn"] = [ - collector_kwargs["create_env_fn"] for _ in range(num_envs) - ] - collector = collector_class(**collector_kwargs) +# @pytest.mark.skipif( +# IS_WINDOWS and PYTHON_3_10, +# reason="Windows Access Violation in torch.multiprocessing / BrokenPipeError in multiprocessing.connection", +# ) +# @pytest.mark.parametrize("num_env", [2]) +# @pytest.mark.parametrize("device", ["cuda", "cpu", None]) +# @pytest.mark.parametrize("policy_device", ["cuda", "cpu", None]) +# @pytest.mark.parametrize("storing_device", ["cuda", "cpu", None]) +# def test_output_device_consistency( +# num_env, device, policy_device, storing_device, seed=40 +# ): +# if ( +# device == "cuda" or policy_device == "cuda" or storing_device == "cuda" +# ) and not torch.cuda.is_available(): +# pytest.skip("cuda is not available") +# +# if IS_WINDOWS and PYTHON_3_7: +# if device == "cuda" and policy_device == "cuda" and device is None: +# pytest.skip( +# "BrokenPipeError in multiprocessing.connection with Python 3.7 on Windows" +# ) +# +# _device = "cuda:0" if device == "cuda" else device +# _policy_device = "cuda:0" if policy_device == "cuda" else policy_device +# _storing_device = "cuda:0" if storing_device == "cuda" else storing_device +# +# if num_env == 1: +# +# def env_fn(seed): +# env = make_make_env("vec")() +# env.set_seed(seed) +# return env +# +# else: +# +# def env_fn(seed): +# # 1226: faster execution +# # env = ParallelEnv( +# env = SerialEnv( +# num_workers=num_env, +# create_env_fn=make_make_env("vec"), +# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], +# ) +# return env +# +# if _policy_device is None: +# policy = make_policy("vec") +# else: +# policy = ParametricPolicy().to(torch.device(_policy_device)) +# +# collector = SyncDataCollector( +# create_env_fn=env_fn, +# create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=20, +# max_frames_per_traj=2000, +# total_frames=20000, +# device=_device, +# storing_device=_storing_device, +# ) +# for _, d in enumerate(collector): +# assert _is_consistent_device_type( +# device, policy_device, storing_device, d.device.type +# ) +# break +# assert d.names[-1] == "time" +# +# collector.shutdown() +# +# ccollector = aSyncDataCollector( +# create_env_fn=env_fn, +# create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=20, +# max_frames_per_traj=2000, +# total_frames=20000, +# device=_device, +# storing_device=_storing_device, +# ) +# +# for _, d in enumerate(ccollector): +# assert _is_consistent_device_type( +# device, policy_device, storing_device, d.device.type +# ) +# break +# assert d.names[-1] == "time" +# +# ccollector.shutdown() +# del ccollector - keys = { - "action", - "done", - "collector", - "hidden1", - "hidden2", - ("next", "hidden1"), - ("next", "hidden2"), - ("next", "observation"), - ("next", "done"), - ("next", "reward"), - "next", - "observation", - ("collector", "traj_ids"), - } - if split_trajs: - keys.add(("collector", "mask")) - - keys.add(("next", "terminated")) - keys.add("terminated") - keys.add(("next", "truncated")) - keys.add("truncated") - b = next(iter(collector)) - - assert set(b.keys(True)) == keys - collector.shutdown() - del collector - - -@pytest.mark.parametrize("device", ["cuda", "cpu"]) -@pytest.mark.parametrize("storing_device", ["cuda", "cpu"]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") -def test_collector_device_combinations(device, storing_device): - if IS_WINDOWS and PYTHON_3_10 and storing_device == "cuda" and device == "cuda": - pytest.skip("Windows fatal exception: access violation in torch.storage") - - def env_fn(seed): - env = make_make_env("conv")() - env.set_seed(seed) - return env - policy = dummypolicy_conv() +# Deprecated reset_when_done +# @pytest.mark.parametrize("num_env", [1, 2]) +# @pytest.mark.parametrize("env_name", ["vec"]) +# def test_collector_done_persist(num_env, env_name, seed=5): +# if num_env == 1: +# +# def env_fn(seed): +# env = MockSerialEnv(device="cpu") +# env.set_seed(seed) +# return env +# +# else: +# +# def env_fn(seed): +# def make_env(seed): +# env = MockSerialEnv(device="cpu") +# env.set_seed(seed) +# return env +# +# env = ParallelEnv( +# num_workers=num_env, +# create_env_fn=make_env, +# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], +# ) +# env.set_seed(seed) +# return env +# +# policy = make_policy(env_name) +# +# collector = SyncDataCollector( +# create_env_fn=env_fn, +# create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=200 * num_env, +# max_frames_per_traj=2000, +# total_frames=20000, +# device="cpu", +# reset_when_done=False, +# ) +# for _, d in enumerate(collector): # noqa +# break +# +# assert (d["done"].sum(-2) >= 1).all() +# assert torch.unique(d["collector", "traj_ids"], dim=-1).shape[-1] == 1 +# +# del collector - collector = SyncDataCollector( - create_env_fn=env_fn, - create_env_kwargs={"seed": 0}, - policy=policy, - frames_per_batch=20, - max_frames_per_traj=2000, - total_frames=20000, - device=device, - storing_device=storing_device, - ) - assert collector._use_buffers - batch = next(collector.iterator()) - assert batch.device == _make_ordinal_device(torch.device(storing_device)) - collector.shutdown() - - collector = MultiSyncDataCollector( - create_env_fn=[ - env_fn, - ], - create_env_kwargs=[ - {"seed": 0}, - ], - policy=policy, - frames_per_batch=20, - max_frames_per_traj=2000, - total_frames=20000, - device=[ - device, - ], - storing_device=[ - storing_device, - ], - cat_results="stack", - ) - batch = next(collector.iterator()) - assert batch.device == _make_ordinal_device(torch.device(storing_device)) - collector.shutdown() - collector = MultiaSyncDataCollector( - create_env_fn=[ - env_fn, - ], - create_env_kwargs=[ - {"seed": 0}, - ], - policy=policy, - frames_per_batch=20, - max_frames_per_traj=2000, - total_frames=20000, - device=[ - device, - ], - storing_device=[ - storing_device, - ], - ) - batch = next(collector.iterator()) - assert batch.device == _make_ordinal_device(torch.device(storing_device)) - collector.shutdown() - del collector +# TODO: design a test that ensures that collectors are interrupted even if __del__ is not called +# @pytest.mark.parametrize("should_shutdown", [True, False]) +# def test_shutdown_collector(should_shutdown, num_env=3, env_name="vec", seed=40): +# def env_fn(seed): +# env = ParallelEnv( +# num_workers=num_env, +# create_env_fn=make_make_env(env_name), +# create_env_kwargs=[{"seed": i} for i in range(seed, seed + num_env)], +# ) +# return env +# +# policy = make_policy(env_name) +# +# ccollector = aSyncDataCollector( +# create_env_fn=env_fn, +# create_env_kwargs={"seed": seed}, +# policy=policy, +# frames_per_batch=20, +# max_frames_per_traj=2000, +# total_frames=20000, +# ) +# for i, d in enumerate(ccollector): +# if i == 0: +# b1c = d +# elif i == 1: +# b2c = d +# else: +# break +# with pytest.raises(AssertionError): +# assert_allclose_td(b1c, b2c) +# +# if should_shutdown: +# ccollector.shutdown() @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv") @@ -1745,56 +2229,6 @@ def test_auto_wrap_error(self, collector_class, env_maker): ) -@pytest.mark.parametrize("env_class", [CountingEnv, CountingBatchedEnv]) -def test_initial_obs_consistency(env_class, seed=1): - # non regression test on #938 - torch.manual_seed(seed) - start_val = 4 - if env_class == CountingEnv: - num_envs = 1 - env = CountingEnv(device="cpu", max_steps=8, start_val=start_val) - max_steps = 8 - elif env_class == CountingBatchedEnv: - num_envs = 2 - env = CountingBatchedEnv( - device="cpu", - batch_size=[num_envs], - max_steps=torch.arange(num_envs) + 17, - start_val=torch.ones([num_envs]) * start_val, - ) - max_steps = env.max_steps.max().item() - env.set_seed(seed) - policy = lambda tensordict: tensordict.set( - "action", torch.ones(tensordict.shape, dtype=torch.int) - ) - collector = SyncDataCollector( - create_env_fn=env, - policy=policy, - frames_per_batch=((max_steps - 3) * 2 + 2) * num_envs, # at least two episodes - split_trajs=False, - total_frames=-1, - ) - for _d in collector: - break - obs = _d["observation"].squeeze() - if env_class == CountingEnv: - arange_0 = start_val + torch.arange(max_steps - 3) - arange = start_val + torch.arange(2) - expected = torch.cat([arange_0, arange_0, arange]) - else: - # the first env has a shorter horizon than the second - arange_0 = start_val + torch.arange(max_steps - 3 - 1) - arange = start_val + torch.arange(start_val) - expected_0 = torch.cat([arange_0, arange_0, arange]) - arange_0 = start_val + torch.arange(max_steps - 3) - arange = start_val + torch.arange(2) - expected_1 = torch.cat([arange_0, arange_0, arange]) - expected = torch.stack([expected_0, expected_1]) - assert torch.allclose(obs, expected.to(obs.dtype)) - collector.shutdown() - del collector - - def weight_reset(m): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): m.reset_parameters() @@ -1868,135 +2302,6 @@ def env_fn(seed): del collector -def test_maxframes_error(): - env = TransformedEnv(CountingEnv(), StepCounter(2)) - _ = SyncDataCollector( - env, RandomPolicy(env.action_spec), total_frames=10_000, frames_per_batch=1000 - ) - with pytest.raises(ValueError): - _ = SyncDataCollector( - env, - RandomPolicy(env.action_spec), - total_frames=10_000, - frames_per_batch=1000, - max_frames_per_traj=2, - ) - - -@retry(AssertionError, tries=10, delay=0) -@pytest.mark.parametrize("policy_device", [None, *get_available_devices()]) -@pytest.mark.parametrize("env_device", [None, *get_available_devices()]) -@pytest.mark.parametrize("storing_device", [None, *get_available_devices()]) -@pytest.mark.parametrize("parallel", [False, True]) -@pytest.mark.parametrize("share_individual_td", [False, True]) -def test_reset_heterogeneous_envs( - policy_device: torch.device, - env_device: torch.device, - storing_device: torch.device, - parallel, - share_individual_td, -): - if ( - policy_device is not None - and policy_device.type == "cuda" - and env_device is None - ): - env_device = torch.device("cpu") # explicit mapping - elif env_device is not None and env_device.type == "cuda" and policy_device is None: - policy_device = torch.device("cpu") - env1 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(2)) - env2 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(3)) - if parallel: - cls = ParallelEnv - else: - cls = SerialEnv - env = cls( - 2, [env1, env2], device=env_device, share_individual_td=share_individual_td - ) - collector = SyncDataCollector( - env, - RandomPolicy(env.action_spec), - total_frames=10_000, - frames_per_batch=100, - policy_device=policy_device, - env_device=env_device, - storing_device=storing_device, - ) - try: - for data in collector: # noqa: B007 - break - data_device = storing_device if storing_device is not None else env_device - assert ( - data[0]["next", "truncated"].squeeze() - == torch.tensor([False, True], device=data_device).repeat(25)[:50] - ).all(), data[0]["next", "truncated"] - assert ( - data[1]["next", "truncated"].squeeze() - == torch.tensor([False, False, True], device=data_device).repeat(17)[:50] - ).all(), data[1]["next", "truncated"][:10] - finally: - collector.shutdown() - del collector - - -def test_policy_with_mask(): - env = CountingBatchedEnv(start_val=torch.tensor(10), max_steps=torch.tensor(1e5)) - - def policy(td): - obs = td.get("observation") - # This policy cannot work with obs all 0s - if not obs.any(): - raise AssertionError - action = obs.clone() - td.set("action", action) - return td - - collector = SyncDataCollector( - env, policy=policy, frames_per_batch=10, total_frames=20 - ) - for _ in collector: - break - collector.shutdown() - - -@pytest.mark.parametrize( - "collector_cls", - [SyncDataCollector, MultiSyncDataCollector, MultiaSyncDataCollector], -) -def test_set_truncated(collector_cls): - env_fn = lambda: TransformedEnv( - NestedCountingEnv(), InitTracker() - ).add_truncated_keys() - env = env_fn() - policy = CloudpickleWrapper(env.rand_action) - if collector_cls == SyncDataCollector: - collector = collector_cls( - env, - policy=policy, - frames_per_batch=20, - total_frames=-1, - set_truncated=True, - trust_policy=True, - ) - else: - collector = collector_cls( - [env_fn, env_fn], - policy=policy, - frames_per_batch=20, - total_frames=-1, - cat_results="stack", - set_truncated=True, - trust_policy=True, - ) - try: - for data in collector: - assert data[..., -1]["next", "data", "truncated"].all() - break - finally: - collector.shutdown() - del collector - - class TestNestedEnvsCollector: def test_multi_collector_nested_env_consistency(self, seed=1): torch.manual_seed(seed) @@ -2652,56 +2957,6 @@ def test_aggregate_reset_to_root_errors(self): ) -@pytest.mark.parametrize( - "collector_class", - [ - functools.partial(MultiSyncDataCollector, cat_results="stack"), - MultiaSyncDataCollector, - SyncDataCollector, - ], -) -def test_collector_reloading(collector_class): - def make_env(): - return ContinuousActionVecMockEnv() - - dummy_env = make_env() - obs_spec = dummy_env.observation_spec["observation"] - policy_module = nn.Linear(obs_spec.shape[-1], dummy_env.action_spec.shape[-1]) - policy = Actor(policy_module, spec=dummy_env.action_spec) - policy_explore = TensorDictSequential( - policy, OrnsteinUhlenbeckProcessModule(spec=policy.spec) - ) - - collector_kwargs = { - "create_env_fn": make_env, - "policy": policy_explore, - "frames_per_batch": 30, - "total_frames": 90, - } - if collector_class is not SyncDataCollector: - collector_kwargs["create_env_fn"] = [ - collector_kwargs["create_env_fn"] for _ in range(3) - ] - - collector = collector_class(**collector_kwargs) - for i, _ in enumerate(collector): - if i == 3: - break - collector_frames = collector._frames - collector_iter = collector._iter - collector_state_dict = collector.state_dict() - collector.shutdown() - - collector = collector_class(**collector_kwargs) - collector.load_state_dict(collector_state_dict) - assert collector._frames == collector_frames - assert collector._iter == collector_iter - for _ in enumerate(collector): - raise AssertionError - collector.shutdown() - del collector - - class TestLibThreading: @pytest.mark.skipif( IS_OSX, @@ -3214,161 +3469,6 @@ def __deepcopy_error__(*args, **kwargs): raise RuntimeError("deepcopy not allowed") -@pytest.mark.filterwarnings( - "error::UserWarning", "ignore:Tensordict is registered in PyTree:UserWarning" -) -@pytest.mark.parametrize( - "collector_type", - [ - SyncDataCollector, - MultiaSyncDataCollector, - functools.partial(MultiSyncDataCollector, cat_results="stack"), - ], -) -def test_no_deepcopy_policy(collector_type): - # Tests that the collector instantiation does not make a deepcopy of the policy if not necessary. - # - # The only situation where we want to deepcopy the policy is when the policy_device differs from the actual device - # of the policy. This can only be checked if the policy is an nn.Module and any of the params is not on the desired - # device. - # - # If the policy is not a nn.Module or has no parameter, policy_device should warn (we don't know what to do but we - # can trust that the user knows what to do). - - # warnings.warn("Tensordict is registered in PyTree", category=UserWarning) - - shared_device = torch.device("cpu") - if torch.cuda.is_available(): - original_device = torch.device("cuda:0") - elif torch.mps.is_available(): - original_device = torch.device("mps") - else: - pytest.skip("No GPU or MPS device") - - def make_policy(device=None, nn_module=True): - if nn_module: - return TensorDictModule( - nn.Linear(7, 7, device=device), - in_keys=["observation"], - out_keys=["action"], - ) - policy = make_policy(device=device) - return CloudpickleWrapper(policy) - - def make_and_test_policy( - policy, - policy_device=None, - env_device=None, - device=None, - trust_policy=None, - ): - # make sure policy errors when copied - - policy.__deepcopy__ = __deepcopy_error__ - envs = ContinuousActionVecMockEnv(device=env_device) - if collector_type is not SyncDataCollector: - envs = [envs, envs] - c = collector_type( - envs, - policy=policy, - total_frames=1000, - frames_per_batch=10, - policy_device=policy_device, - env_device=env_device, - device=device, - trust_policy=trust_policy, - ) - for _ in c: - return - - # Simplest use cases - policy = make_policy() - make_and_test_policy(policy) - - if collector_type is SyncDataCollector or original_device.type != "mps": - # mps cannot be shared - policy = make_policy(device=original_device) - make_and_test_policy(policy, env_device=original_device) - - if collector_type is SyncDataCollector or original_device.type != "mps": - policy = make_policy(device=original_device) - make_and_test_policy( - policy, policy_device=original_device, env_device=original_device - ) - - # a deepcopy must occur when the policy_device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): - policy = make_policy(device=original_device) - make_and_test_policy( - policy, policy_device=shared_device, env_device=shared_device - ) - - # a deepcopy must occur when device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): - policy = make_policy(device=original_device) - make_and_test_policy(policy, device=shared_device) - - # If the policy is not an nn.Module, we can't cast it to device, so we assume that the policy device - # is there to inform us - substitute_device = ( - original_device if torch.cuda.is_available() else torch.device("cpu") - ) - policy = make_policy(substitute_device, nn_module=False) - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=substitute_device - ) - # For instance, if the env is on CPU, knowing the policy device helps with casting stuff on the right device - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=shared_device - ) - make_and_test_policy( - policy, - policy_device=substitute_device, - env_device=shared_device, - trust_policy=True, - ) - - # If there is no policy_device, we assume that the user is doing things right too but don't warn - if collector_type is SyncDataCollector or original_device.type != "mps": - policy = make_policy(original_device, nn_module=False) - make_and_test_policy(policy, env_device=original_device) - - # If the policy is a CudaGraphModule, we know it's on cuda - no need to warn - if torch.cuda.is_available() and collector_type is SyncDataCollector: - policy = make_policy(original_device) - cudagraph_policy = CudaGraphModule(policy) - make_and_test_policy( - cudagraph_policy, - policy_device=original_device, - env_device=shared_device, - ) - - -@pytest.mark.parametrize( - "ctype", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector] -) -def test_no_stopiteration(ctype): - # Tests that there is no StopIteration raised and that the length of the collector is properly set - if ctype is SyncDataCollector: - envs = SerialEnv(16, CountingEnv) - else: - envs = [SerialEnv(8, CountingEnv), SerialEnv(8, CountingEnv)] - - collector = ctype(create_env_fn=envs, frames_per_batch=173, total_frames=300) - try: - c_iter = iter(collector) - assert len(collector) == 2 - for i in range(len(collector)): # noqa: B007 - c = next(c_iter) - assert c is not None - assert i == 1 - finally: - collector.shutdown() - del collector - - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index d9d0850f9e3..1b57270bb3e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -6,6 +6,7 @@ import _pickle import abc +import collections import contextlib @@ -20,6 +21,7 @@ from copy import deepcopy from multiprocessing import connection, queues from multiprocessing.managers import SyncManager +from queue import Empty from textwrap import indent from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Union @@ -258,7 +260,11 @@ def update_policy_weights_( self.policy_weights.data.update_(self.get_weights_fn()) def __iter__(self) -> Iterator[TensorDictBase]: - yield from self.iterator() + try: + yield from self.iterator() + except Exception: + self.shutdown() + raise def next(self): try: @@ -2325,8 +2331,28 @@ def iterator(self) -> Iterator[TensorDictBase]: while self.queue_out.qsize() < int(self.num_workers): continue + recv = collections.deque() + t0 = time.time() + while len(recv) < self.num_workers and ( + (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT) + ): + for _ in range(self.num_workers): + try: + new_data, j = self.queue_out.get(timeout=_TIMEOUT) + recv.append((new_data, j)) + except (TimeoutError, Empty): + _check_for_faulty_process(self.procs) + if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT): + try: + self.shutdown() + finally: + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." + ) + for _ in range(self.num_workers): - new_data, j = self.queue_out.get() + new_data, j = recv.popleft() use_buffers = self._use_buffers if self.replay_buffer is not None: idx = new_data @@ -2659,12 +2685,19 @@ def iterator(self) -> Iterator[TensorDictBase]: workers_frames = [0 for _ in range(self.num_workers)] while self._frames < self.total_frames: self._iter += 1 + counter = 0 while True: try: - idx, j, out = self._get_from_queue(timeout=10.0) + idx, j, out = self._get_from_queue(timeout=_TIMEOUT) break - except TimeoutError: + except (TimeoutError, Empty): + counter += _TIMEOUT _check_for_faulty_process(self.procs) + if counter > (_TIMEOUT * _MAX_IDLE_COUNT): + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." + ) if self.replay_buffer is None: worker_frames = out.numel() if self.split_trajs: