Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Evaluation script double-integrator throws ValueError #8

@luigiberducci

Description

@luigiberducci

Hello,
Thanks for sharing your code!

I tried to run the experiments on the double-integrator and got stuck in the evaluation, due to a ValueError when resuming the checkpoint.

How to reproduce

I used the commands in the README.
Training a policy with

python scripts/train_dbint_inner.py --name run1

The training ends without any issue.

Then, I launch the evaluation script

python scripts/eval_dbint_inner.py runs/DbInt_inner/<run_dir>/ckpts/00099999/default

and throws the following error:

ValueError("Dict key mismatch; expected keys: ['alg', 'alg_cfg', 'collect_cfg']; dict: {'alg': {'update_idx': RestoreArgs(restore_type=None, dtype=None), 'key': RestoreArgs(restore_type=None, dtype=None), 'policy': {'step': RestoreArgs(restore_type=None, dtype=None), 'params': {'DiscretePolicyNet_0': {'MLP_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'Dense_1': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}, 'OutputDense': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}, 'ZEncoder_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'opt_state': {'count': RestoreArgs(restore_type=None, dtype=None), 'hyperparams': {'eps': RestoreArgs(restore_type=None, dtype=None), 'learning_rate': RestoreArgs(restore_type=None, dtype=None), 'wd': RestoreArgs(restore_type=None, dtype=None)}, 'hyperparams_states': {'learning_rate': {'count': RestoreArgs(restore_type=None, dtype=None)}}, 'inner_state': {'notfinite_count': RestoreArgs(restore_type=None, dtype=None), 'last_finite': RestoreArgs(restore_type=None, dtype=None), 'total_notfinite': RestoreArgs(restore_type=None, dtype=None), 'inner_state': [{'count': RestoreArgs(restore_type=None, dtype=None), 'mu': {'DiscretePolicyNet_0': {'MLP_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'Dense_1': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}, 'OutputDense': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}, 'ZEncoder_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'nu': {'DiscretePolicyNet_0': {'MLP_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'Dense_1': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}, 'OutputDense': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}, 'ZEncoder_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}}, {'inner_state': RestoreArgs(restore_type=None, dtype=None)}, RestoreArgs(restore_type=None, dtype=None)]}}}, 'Vl': {'step': RestoreArgs(restore_type=None, dtype=None), 'params': {'CostValueNet_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'MLP_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'Dense_1': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'ZEncoder_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'opt_state': {'count': RestoreArgs(restore_type=None, dtype=None), 'hyperparams': {'eps': RestoreArgs(restore_type=None, dtype=None), 'learning_rate': RestoreArgs(restore_type=None, dtype=None), 'wd': RestoreArgs(restore_type=None, dtype=None)}, 'hyperparams_states': {'learning_rate': {'count': RestoreArgs(restore_type=None, dtype=None)}}, 'inner_state': {'notfinite_count': RestoreArgs(restore_type=None, dtype=None), 'last_finite': RestoreArgs(restore_type=None, dtype=None), 'total_notfinite': RestoreArgs(restore_type=None, dtype=None), 'inner_state': [{'count': RestoreArgs(restore_type=None, dtype=None), 'mu': {'CostValueNet_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'MLP_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'Dense_1': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'ZEncoder_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'nu': {'CostValueNet_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'MLP_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'Dense_1': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'ZEncoder_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}}, {'inner_state': RestoreArgs(restore_type=None, dtype=None)}, RestoreArgs(restore_type=None, dtype=None)]}}}, 'Vh': {'step': RestoreArgs(restore_type=None, dtype=None), 'params': {'ConstrValueNet_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'MLP_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'Dense_1': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'ZEncoder_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'opt_state': {'count': RestoreArgs(restore_type=None, dtype=None), 'hyperparams': {'eps': RestoreArgs(restore_type=None, dtype=None), 'learning_rate': RestoreArgs(restore_type=None, dtype=None), 'wd': RestoreArgs(restore_type=None, dtype=None)}, 'hyperparams_states': {'learning_rate': {'count': RestoreArgs(restore_type=None, dtype=None)}}, 'inner_state': {'notfinite_count': RestoreArgs(restore_type=None, dtype=None), 'last_finite': RestoreArgs(restore_type=None, dtype=None), 'total_notfinite': RestoreArgs(restore_type=None, dtype=None), 'inner_state': [{'count': RestoreArgs(restore_type=None, dtype=None), 'mu': {'ConstrValueNet_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'MLP_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'Dense_1': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'ZEncoder_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'nu': {'ConstrValueNet_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'MLP_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}, 'Dense_1': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}, 'ZEncoder_0': {'Dense_0': {'bias': RestoreArgs(restore_type=None, dtype=None), 'kernel': RestoreArgs(restore_type=None, dtype=None)}}}}, {'inner_state': RestoreArgs(restore_type=None, dtype=None)}, RestoreArgs(restore_type=None, dtype=None)]}}}, 'disc_gamma': RestoreArgs(restore_type=None, dtype=None)}}.")

Potential issue

I suspect it is related to some mismatch in dependencies and versions.
Could you please share the setup used in the experiments or any information related to jax, orbax versions?

Local setup

I am using a conda environment with Python 3.11.
The pip packages:

Package                  Version            Editable project location
------------------------ ------------------ -----------------------------
absl-py                  2.1.0
asttokens                3.0.0
attrs                    24.2.0
certifi                  2024.8.30
charset-normalizer       3.4.0
chex                     0.1.87
click                    8.1.7
cloudpickle              3.1.0
contourpy                1.3.1
cycler                   0.12.1
decorator                5.1.1
dm-tree                  0.1.8
docker-pycreds           0.4.0
efppo                    0.0.0              /home/luigi/Development/efppo
einops                   0.8.0
equinox                  0.11.9
etils                    1.11.0
executing                2.1.0
flax                     0.10.2
fonttools                4.55.0
fsspec                   2024.10.0
gast                     0.6.0
gitdb                    4.0.11
GitPython                3.1.43
humanize                 4.11.0
idna                     3.10
importlib_resources      6.4.5
ipdb                     0.13.13
ipython                  8.30.0
jax                      0.4.35
jax-cuda12-pjrt          0.4.35
jax-cuda12-plugin        0.4.35
jax-f16                  0.0.2
jaxlib                   0.4.34
jaxtyping                0.2.36
jedi                     0.19.2
kiwisolver               1.4.7
loguru                   0.7.2
markdown-it-py           3.0.0
matplotlib               3.9.3
matplotlib-inline        0.1.7
mdurl                    0.1.2
ml_dtypes                0.5.0
msgpack                  1.1.0
nest-asyncio             1.6.0
numpy                    2.1.3
nvidia-cublas-cu12       12.6.4.1
nvidia-cuda-cupti-cu12   12.6.80
nvidia-cuda-nvcc-cu12    12.6.85
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12        9.5.1.17
nvidia-cufft-cu12        11.3.0.4
nvidia-cusolver-cu12     11.7.1.2
nvidia-cusparse-cu12     12.5.4.2
nvidia-nccl-cu12         2.23.4
nvidia-nvjitlink-cu12    12.6.85
opt_einsum               3.4.0
optax                    0.2.4
orbax-checkpoint         0.10.1
packaging                24.2
pandas                   2.2.3
parso                    0.8.4
pexpect                  4.9.0
pillow                   11.0.0
pip                      24.2
platformdirs             4.3.6
prompt_toolkit           3.0.48
protobuf                 5.29.0
psutil                   6.1.0
ptyprocess               0.7.0
pure_eval                0.2.3
Pygments                 2.18.0
pyparsing                3.2.0
python-dateutil          2.9.0.post0
pytz                     2024.2
PyYAML                   6.0.2
requests                 2.32.3
rich                     13.9.4
scipy                    1.14.1
seaborn                  0.13.2
sentry-sdk               2.19.0
setproctitle             1.3.4
setuptools               75.1.0
shapely                  2.0.6
shellingham              1.5.4
simplejson               3.19.3
six                      1.16.0
smmap                    5.0.1
stack-data               0.6.3
tensorstore              0.1.69
tfp-nightly              0.26.0.dev20241202
toolz                    1.0.0
traitlets                5.14.3
typer                    0.14.0
typing_extensions        4.12.2
tzdata                   2024.2
urllib3                  2.2.3
wandb                    0.18.7
wcwidth                  0.2.13
wheel                    0.44.0
zipp                     3.21.0

Thanks you very much!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions