I am trying to use the SAC loss with a memory model.
The memory model expects data whose leading batch dimensions are `[B, T]`.
I pass data with this shape to the loss,
but here it gets reshaped:

```python
tensordict_reshape = tensordict.reshape(-1)
```

This flattens the `[B, T]` batch into a single dimension of size `B*T`, so the time dimension `T` is lost and my model cannot retrieve it.
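To make the problem concrete, here is a minimal shape-only sketch in plain Python (no TorchRL). It assumes the loss flattens all batch dimensions into one, as `reshape(-1)` does. `flatten_batch` and `unflatten_batch` are hypothetical helpers written for illustration, not TorchRL functions. The point is that a flat batch of size `B*T` can be split into many different `[B, T]` layouts, so the model can only restore the right one if `T` reaches it through some other channel.

```python
from math import prod


def flatten_batch(batch_shape):
    """Hypothetical helper: mimic reshape(-1) on the batch dims, e.g. [B, T] -> [B*T]."""
    return [prod(batch_shape)]


def unflatten_batch(flat_shape, time_dim):
    """Hypothetical helper: rebuild [B, T] from [B*T]; only possible if T is known."""
    (n,) = flat_shape
    if n % time_dim != 0:
        raise ValueError(f"{n} elements cannot be split into sequences of length {time_dim}")
    return [n // time_dim, time_dim]


B, T = 4, 10
flat = flatten_batch([B, T])
print(flat)                      # [40]
print(unflatten_batch(flat, T))  # [4, 10]: correct only because T was passed in
print(unflatten_batch(flat, 8))  # [5, 8]: also a valid split, but the wrong sequences
```

Nothing in the flat shape `[40]` distinguishes `[4, 10]` from `[5, 8]` or `[2, 20]`, which is why the time structure cannot be inferred inside the model once the loss has flattened it.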
Would it be possible to remove the reshaping of the data from the loss?