Closed
Labels: enhancement (New feature or request)
Description
Motivation
I got the following error when I used GAE with an LSTM-based value network:
```
RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.
```
Here is the code I ran:
```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE


class ValueNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=2,
            hidden_size=1,
            num_layers=1,
            bidirectional=False,
            batch_first=True
        )

    def forward(self, i):
        output, (hidden_state, cell_state) = self.lstm(i)
        return hidden_state


def main():
    value_network = ValueNetwork()
    value_dict_module = TensorDictModule(value_network, in_keys=["observation"], out_keys=["value"])
    gae = GAE(
        gamma=0.98,
        lmbda=0.95,
        value_network=value_dict_module
    )
    gae.set_keys(
        advantage="advantage",
        value_target="value_target",
        value="value",
    )
    tensor_dict = TensorDict({
        "next": {
            "observation": torch.FloatTensor([
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]]
            ]),
            "reward": torch.FloatTensor([[1], [-1]]),
            "done": torch.BoolTensor([[1], [1]]),
            "terminated": torch.BoolTensor([[1], [1]])
        },
        "observation": torch.FloatTensor([
            [[0, 1], [2, 3]],
            [[4, 5], [6, 7]]
        ])
    }, batch_size=2)
    output_tensor_dict = gae(tensor_dict)
    print(f"output_tensor_dict: {output_tensor_dict}")
    advantage = output_tensor_dict["advantage"]
    print(f"advantage: {advantage}")


main()
```

The error was caused by this exact line:
```python
output_tensor_dict = gae(tensor_dict)
```

I tried using unbatched input and found that GAE does not support unbatched input either.
For example, this is the unbatched input I tried:
```python
tensor_dict = TensorDict({
    "next": {
        "observation": torch.FloatTensor([[4, 5], [6, 7]]),
        "reward": torch.FloatTensor([1]),
        "done": torch.BoolTensor([1]),
        "terminated": torch.BoolTensor([1])
    },
    "observation": torch.FloatTensor([[0, 1], [2, 3]])
}, batch_size=[])
```

And I got this error from GAE:
```
RuntimeError: Expected input tensordict to have at least one dimensions, got tensordict.batch_size = torch.Size([])
```
Therefore, I concluded that GAE does not support an LSTM-based value network.
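For reference, the GAE recursion itself only needs V(s_t) and V(s_{t+1}) along a time axis. As far as I understand, torchrl's GAE treats the last batch dimension as time, which would explain why a tensordict with `batch_size=[]` is rejected: there is no axis to recurse over. Below is a minimal pure-Python sketch of the recursion for one trajectory. It is not torchrl's implementation. It assumes that `terminated` stops bootstrapping from the next value and that `done` cuts the advantage trace, and it reuses the `gamma`/`lmbda` values from the snippet above. The function name `gae_1d` is hypothetical.

```python
from typing import List, Sequence, Tuple


def gae_1d(
    reward: Sequence[float],
    value: Sequence[float],
    next_value: Sequence[float],
    done: Sequence[bool],
    terminated: Sequence[bool],
    gamma: float = 0.98,
    lmbda: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: (advantage, value_target) for one trajectory of length T."""
    T = len(reward)
    if T == 0:
        # Without a time axis there is nothing to recurse over.
        raise ValueError("expected at least one time step")
    advantage = [0.0] * T
    running = 0.0
    # Walk backwards in time:
    #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - terminated_t) - V(s_t)
    #   A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    for t in reversed(range(T)):
        bootstrap = 0.0 if terminated[t] else gamma * next_value[t]
        delta = reward[t] + bootstrap - value[t]
        running = delta + (0.0 if done[t] else gamma * lmbda * running)
        advantage[t] = running
    # The value target is the advantage plus the current value estimate.
    value_target = [a + v for a, v in zip(advantage, value)]
    return advantage, value_target


adv, target = gae_1d([1.0, -1.0], [0.5, 0.2], [0.2, 0.0], [False, True], [False, True])
print(adv, target)
```

Here `value` and `next_value` are plain numbers, so however the value network produces them (LSTM or not) does not matter to the recursion. The value network is only needed to fill in those two inputs.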
Solution
GAE should support an LSTM-based value network.
Alternatives
GAE should support an unbatched TensorDict as input.
Additional context
I'm using torchrl version 0.5.0.
I found ticket #2372, which might be related to this issue, but I was not sure how to make my code work.
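Separately from the vmap error, the shapes in the reproduction are also worth checking. `nn.LSTM` returns `hidden_state` with shape `(num_layers, batch, hidden_size)`, so `ValueNetwork.forward` above yields a `[1, 2, 1]` tensor, while `reward` has shape `[2, 1]`. Below is a minimal plain-PyTorch sketch of a value head that keeps only the last layer's hidden state, giving one value per batch element. It does not use torchrl and does not by itself fix the batching-rule error. The class name `LastHiddenValueNetwork` is hypothetical. The sketch calls the module directly on both `observation` and `("next", "observation")`, with no vmap involved.

```python
import torch
from torch import nn


class LastHiddenValueNetwork(nn.Module):
    """Hypothetical variant of ValueNetwork that returns one value per sequence."""

    def __init__(self, input_size: int = 2, hidden_size: int = 1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # obs: [batch, seq_len, input_size] because batch_first=True
        _, (hidden_state, _) = self.lstm(obs)
        # hidden_state: [num_layers, batch, hidden_size]; keep the last layer only
        return hidden_state[-1]  # [batch, hidden_size]


net = LastHiddenValueNetwork()
obs = torch.tensor([[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]])
next_obs = torch.tensor([[[8.0, 9.0], [10.0, 11.0]], [[12.0, 13.0], [14.0, 15.0]]])

with torch.no_grad():
    # Two plain forward calls, no vmap involved
    value = net(obs)
    next_value = net(next_obs)

print(value.shape, next_value.shape)  # torch.Size([2, 1]) torch.Size([2, 1])
```

With this head the value tensors line up with `reward` (`[2, 1]`), which is the shape a per-step advantage computation would expect.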
Checklist
- I have checked that there is no similar issue in the repo (required)
eryawww