Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sota-implementations/a2c/a2c_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

import hydra
import torch

Expand Down Expand Up @@ -149,6 +153,10 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
adv_module = torch.compile(adv_module, mode=compile_mode)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)
adv_module = CudaGraphModule(adv_module)

Expand Down
8 changes: 8 additions & 0 deletions sota-implementations/a2c/a2c_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

import hydra
import torch

Expand Down Expand Up @@ -145,6 +149,10 @@ def update(batch):
adv_module = torch.compile(adv_module, mode=compile_mode)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=20)
adv_module = CudaGraphModule(adv_module, warmup=20)

Expand Down
1 change: 1 addition & 0 deletions sota-implementations/a2c/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import numpy as np
import torch.nn
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/a2c/utils_mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import numpy as np
import torch.nn
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/bandits/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import argparse

Expand Down
15 changes: 14 additions & 1 deletion sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
The helper functions are coded in the utils.py associated with this script.

"""
from __future__ import annotations

import time
import warnings

import hydra
import numpy as np

import torch
import tqdm
from tensordict.nn import CudaGraphModule
Expand All @@ -32,6 +36,8 @@
make_offline_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="offline_config", version_base="1.1")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -77,7 +83,9 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_env.start()

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
loss_module, target_net_updater = make_continuous_loss(
cfg.loss, model, device=device
)

# Create Optimizer
(
Expand Down Expand Up @@ -134,6 +142,10 @@ def update(data, policy_eval_start, iteration):
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
Expand All @@ -154,6 +166,7 @@ def update(data, policy_eval_start, iteration):

with timeit("update"):
# compute loss
torch.compiler.cudagraph_mark_step_begin()
i_device = torch.tensor(i, device=device)
loss, loss_vals = update(
data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
Expand Down
14 changes: 13 additions & 1 deletion sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
The helper functions are coded in the utils.py associated with this script.

"""
from __future__ import annotations

import warnings

import hydra
import numpy as np
import torch
Expand All @@ -34,6 +38,8 @@
make_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="online_config")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -103,7 +109,9 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
loss_module, target_net_updater = make_continuous_loss(
cfg.loss, model, device=device
)

# Create optimizer
(
Expand Down Expand Up @@ -140,6 +148,10 @@ def update(sampled_tensordict):
if compile_mode:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
Expand Down
13 changes: 12 additions & 1 deletion sota-implementations/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@

The helper functions are coded in the utils.py associated with this script.
"""
from __future__ import annotations

import warnings

import hydra
import numpy as np

import torch
import torch.cuda
import tqdm
Expand All @@ -33,6 +37,8 @@
make_replay_buffer,
)

torch.set_float32_matmul_precision("high")


@hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
def main(cfg: "DictConfig"): # noqa: F821
Expand Down Expand Up @@ -70,7 +76,7 @@ def main(cfg: "DictConfig"): # noqa: F821
model, explore_policy = make_discretecql_model(cfg, train_env, eval_env, device)

# Create loss
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model)
loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device)

compile_mode = None
if cfg.compile.compile:
Expand Down Expand Up @@ -123,6 +129,10 @@ def update(sampled_tensordict):
if compile_mode:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
Expand Down Expand Up @@ -170,6 +180,7 @@ def update(sampled_tensordict):
sampled_tensordict = replay_buffer.sample()
sampled_tensordict = sampled_tensordict.to(device)
with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
loss_dict = update(sampled_tensordict)
tds.append(loss_dict)

Expand Down
14 changes: 8 additions & 6 deletions sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools

import torch.nn
Expand Down Expand Up @@ -221,8 +223,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
# distribution_kwargs=TensorDictParams(
# TensorDict(
# {
# "low": action_spec.space.low,
# "high": action_spec.space.high,
# "low": torch.as_tensor(action_spec.space.low, device=device),
# "high": torch.as_tensor(action_spec.space.high, device=device),
# "tanh_loc": NonTensorData(False),
# }
# ),
Expand Down Expand Up @@ -326,7 +328,7 @@ def make_cql_modules_state(model_cfg, proof_environment):
# ---------


def make_continuous_loss(loss_cfg, model):
def make_continuous_loss(loss_cfg, model, device: torch.device | None = None):
loss_module = CQLLoss(
model[0],
model[1],
Expand All @@ -339,19 +341,19 @@ def make_continuous_loss(loss_cfg, model):
with_lagrange=loss_cfg.with_lagrange,
lagrange_thresh=loss_cfg.lagrange_thresh,
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)

return loss_module, target_net_updater


def make_discrete_loss(loss_cfg, model):
def make_discrete_loss(loss_cfg, model, device: torch.device | None = None):
loss_module = DiscreteCQLLoss(
model,
loss_function=loss_cfg.loss_function,
delay_value=True,
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device)
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)

return loss_module, target_net_updater
Expand Down
9 changes: 7 additions & 2 deletions sota-implementations/crossq/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ collector:
init_random_frames: 25000
frames_per_batch: 1000
init_env_steps: 1000
device: cpu
device:
env_per_collector: 1
reset_at_each_iter: False

Expand Down Expand Up @@ -46,7 +46,12 @@ network:
actor_activation: relu
default_policy_scale: 1.0
scale_lb: 0.1
device: "cuda:0"
device:

compile:
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
Expand Down
Loading
Loading