Description
Without the primer, the collector does not feed any hidden state to the policy
In the RNN tutorial it is stated that the primer is optional and that it is only used to store the hidden states in the buffer.
This is not true in practice: without the primer, the collector does not feed the hidden states to the policy during execution, which silently causes the RNN to lose all recurrence.
To reproduce, comment out this line
rl/tutorials/sphinx-tutorials/dqn_with_rnn.py, line 269 (commit 0063741):
    env.append_transform(lstm.make_tensordict_primer())
and print the policy input at this line
rl/torchrl/collectors/collectors.py, line 733 (commit 0063741):
    policy_output = self.policy(policy_input)
You will see that no hidden state is fed to the RNN during execution, and no errors or warnings are thrown.
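A minimal sketch of how the missing hidden state can be observed (the CartPole env, the hidden size, and the default LSTMModule recurrent-state key names here are assumptions on my part; the actual tutorial pipeline differs):

```python
# Sketch (not the tutorial code): check whether reset() creates the
# recurrent-state entries that the collector would carry between steps.
from torchrl.envs import GymEnv, InitTracker, TransformedEnv
from torchrl.modules import LSTMModule

env = TransformedEnv(GymEnv("CartPole-v1"), InitTracker())
lstm = LSTMModule(
    input_size=env.observation_spec["observation"].shape[-1],
    hidden_size=64,
    in_key="observation",
    out_key="embed",
)

td = env.reset()
# Without the primer: no recurrent-state keys, so the policy never receives them.
print([k for k in td.keys() if "recurrent" in str(k)])  # []

env.append_transform(lstm.make_tensordict_primer())
td = env.reset()
# With the primer: the hidden-state entries exist and are fed to the policy.
print([k for k in td.keys() if "recurrent" in str(k)])  # e.g. ['recurrent_state_h', 'recurrent_state_c']
```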
The primer overwrites any nested spec
Consider an env with nested specs
env = VmasEnv(
    scenario="balance",
    num_envs=5,
)
Add to it a primer for a nested hidden state:
env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            "agents": CompositeSpec(
                {
                    "h": UnboundedContinuousTensorSpec(
                        shape=(*env.shape, env.n_agents, 2, 128)
                    )
                },
                shape=(*env.shape, env.n_agents),
            )
        }
    ),
)
The primer code at
rl/torchrl/envs/transforms/transforms.py, line 4649 (commit 0063741):
    observation_spec[key] = self.primers[key] = spec.to(device)
overwrites the whole existing "agents" entry of the observation spec instead of updating it, so any keys already nested under "agents" are lost.
The same result is obtained with
env = TransformedEnv(
    env,
    TensorDictPrimer(
        {
            ("agents", "h"): UnboundedContinuousTensorSpec(
                shape=(*env.shape, env.n_agents, 2, 128)
            )
        }
    ),
)
Updating the spec instead of overwriting it, at the line referenced above, should do the job.
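To make "updating instead of overwriting" concrete, here is a small self-contained sketch (the shapes and the "observation" key are made up; it only illustrates the difference between assigning to the nested entry and calling CompositeSpec.update on it):

```python
# Stand-ins for the env's observation spec and the primer spec (shapes are arbitrary).
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

observation_spec = CompositeSpec(
    agents=CompositeSpec(
        observation=UnboundedContinuousTensorSpec(shape=(5, 4, 16)),
        shape=(5, 4),
    ),
    shape=(5,),
)
primer_spec = CompositeSpec(
    h=UnboundedContinuousTensorSpec(shape=(5, 4, 2, 128)),
    shape=(5, 4),
)

# Current behaviour (the referenced line): the whole "agents" entry is replaced,
# so ("agents", "observation") disappears from the spec.
overwritten = observation_spec.clone()
overwritten["agents"] = primer_spec
print(list(overwritten["agents"].keys()))  # ['h']

# Suggested behaviour: update the existing nested spec so both entries survive.
updated = observation_spec.clone()
updated["agents"].update(primer_spec)
print(list(updated["agents"].keys()))  # ['observation', 'h']
```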
The order of the primer in the transforms seems to have an impact
In the same VMAS environment as above, if I put the primer and then the reward sum:
env = TransformedEnv(
    env,
    Compose(
        TensorDictPrimer(
            {
                "agents": CompositeSpec(
                    {
                        "h": UnboundedContinuousTensorSpec(
                            shape=(*env.shape, env.n_agents, 2, 128)
                        )
                    },
                    shape=(*env.shape, env.n_agents),
                )
            }
        ),
        RewardSum(
            in_keys=[env.reward_key],
            out_keys=[("agents", "episode_reward")],
        ),
    ),
)
everything works well.
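With this order, a plain reset shows both the primed state and the episode-reward entry (a quick hedged check; the exact set of keys printed depends on the VMAS version and scenario):

```python
td = env.reset()
# Expect "h" (from the primer) and "episode_reward" (from RewardSum)
# alongside the scenario's own observation keys.
print(list(td["agents"].keys()))
```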
But the opposite order
env = TransformedEnv(
    env,
    Compose(
        RewardSum(
            in_keys=[env.reward_key],
            out_keys=[("agents", "episode_reward")],
        ),
        TensorDictPrimer(
            {
                "agents": CompositeSpec(
                    {
                        "h": UnboundedContinuousTensorSpec(
                            shape=(*env.shape, env.n_agents, 2, 128)
                        )
                    },
                    shape=(*env.shape, env.n_agents),
                )
            }
        ),
    ),
)
causes:
Traceback (most recent call last):
  File "/Users/Matteo/PycharmProjects/torchrl/sota-implementations/multiagent/mappo_ippo.py", line 302, in train
    collector = SyncDataCollector(
                ^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/collectors/collectors.py", line 644, in __init__
    self._make_shuttle()
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/collectors/collectors.py", line 661, in _make_shuttle
    self._shuttle = self.env.reset()
                    ^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/common.py", line 2143, in reset
    tensordict_reset = self._reset(tensordict, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 814, in _reset
    tensordict_reset = self.transform._reset(tensordict, tensordict_reset)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 1129, in _reset
    tensordict_reset = t._reset(tensordict, tensordict_reset)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/envs/transforms/transforms.py", line 4722, in _reset
    value = self.default_value[key]
            ~~~~~~~~~~~~~~~~~~^^^^^
KeyError: ('agents', 'episode_reward')