GAE does not support an LSTM-based value network #2444

@levelrin

Description

Motivation

I got the following error when I used GAE with an LSTM-based value network:

RuntimeError: Batching rule not implemented for aten::lstm.input. We could not generate a fallback.

Here is the code I ran:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn
from torchrl.objectives.value import GAE

class ValueNetwork(nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=2,
            hidden_size=1,
            num_layers=1,
            bidirectional=False,
            batch_first=True
        )

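    def forward(self, i):
        # i has shape (batch, seq_len, input_size) because batch_first=True
        output, (hidden_state, cell_state) = self.lstm(i)
        # hidden_state is (num_layers, batch, hidden_size) even with batch_first=True,
        # so keep the last layer to get a (batch, hidden_size) value
        return hidden_state[-1]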

def main():
    value_network = ValueNetwork()
    value_dict_module = TensorDictModule(value_network, in_keys=["observation"], out_keys=["value"])
    gae = GAE(
        gamma=0.98,
        lmbda=0.95,
        value_network=value_dict_module
    )
    gae.set_keys(
        advantage="advantage",
        value_target="value_target",
        value="value",
    )
    tensor_dict = TensorDict({
        "next": {
            "observation": torch.FloatTensor([
                [[8, 9], [10, 11]],
                [[12, 13], [14, 15]]
            ]),
            "reward": torch.FloatTensor([[1], [-1]]),
            "done": torch.BoolTensor([[1], [1]]),
            "terminated": torch.BoolTensor([[1], [1]])
        },
        "observation": torch.FloatTensor([
            [[0, 1], [2, 3]],
            [[4, 5], [6, 7]]
        ])
    }, batch_size=2)

    output_tensor_dict = gae(tensor_dict)
    print(f"output_tensor_dict: {output_tensor_dict}")
    advantage = output_tensor_dict["advantage"]
    print(f"advantage: {advantage}")

if __name__ == "__main__":
    main()

The error was caused by this exact line:

output_tensor_dict = gae(tensor_dict)

I then tried an unbatched input, but GAE does not accept unbatched input either.
This is the unbatched input I tried:

tensor_dict = TensorDict({
    "next": {
        "observation": torch.FloatTensor([[4, 5], [6, 7]]),
        "reward": torch.FloatTensor([1]),
        "done": torch.BoolTensor([1]),
        "terminated": torch.BoolTensor([1])
    },
    "observation": torch.FloatTensor([[0, 1], [2, 3]])
}, batch_size=[])

And I got this error from GAE:

RuntimeError: Expected input tensordict to have at least one dimensions, got tensordict.batch_size = torch.Size([])

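The first error is a vmap batching error (there is no batching rule for aten::lstm), and the unbatched input is rejected outright. Therefore, I concluded that GAE does not currently support an LSTM-based value network.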

Solution

GAE should support an LSTM-based value network.

Alternatives

GAE should accept an unbatched TensorDict as input.
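As a possible stopgap (not something GAE itself provides), the values could be computed by calling the LSTM value network directly on "observation" and ("next", "observation"), which sidesteps the batching error, and the advantage could then be computed by hand with the standard GAE recursion. Below is a minimal pure-Python sketch of that recursion for a single trajectory. The function name `gae_advantages` and its arguments are my own, and it simplifies torchrl's semantics by using a single `dones` flag both to stop bootstrapping and to reset the recursion (torchrl distinguishes `done` from `terminated`).

```python
from typing import List, Sequence


def gae_advantages(
    rewards: Sequence[float],
    values: Sequence[float],
    next_values: Sequence[float],
    dones: Sequence[bool],
    gamma: float = 0.98,
    lmbda: float = 0.95,
) -> List[float]:
    """Generalized Advantage Estimation over one trajectory (hypothetical helper).

    values[t] is V(s_t) and next_values[t] is V(s_{t+1}), assumed to come from
    calling the value network directly rather than through GAE.
    Simplification: dones[t] both disables bootstrapping and resets the recursion.
    """
    n = len(rewards)
    advantages = [0.0] * n
    running = 0.0
    # Walk backwards so each step can reuse the accumulated advantage from t + 1.
    for t in reversed(range(n)):
        not_done = 0.0 if dones[t] else 1.0
        # TD error: bootstrap from V(s_{t+1}) only if the episode continues.
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}, cut at episode boundaries.
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    return advantages
```

When every step is done, as in my example above, bootstrapping is disabled and each advantage reduces to `reward - value`. The value target is then `advantage + value`.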

Additional context

I'm using torchrl version 0.5.0.

I found issue #2372, which might be related, but I was not sure how to apply it to make my code work.

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels: enhancement (New feature or request)