1 change: 1 addition & 0 deletions .gitignore
@@ -158,3 +158,4 @@ api_docs
opponent_pool
!/examples/selfplay/opponent_templates/tictactoe_opponent/info.json
wandb_run
examples/dmc/new.gif
10 changes: 10 additions & 0 deletions examples/dmc/ppo.yaml
@@ -0,0 +1,10 @@
episode_length: 25
lr: 5e-4
critic_lr: 5e-4
gamma: 0.99
ppo_epoch: 5
use_valuenorm: true
entropy_coef: 0.0
hidden_size: 128
layer_N: 4
data_chunk_length: 1
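These hyperparameters are consumed through OpenRL's config parser rather than read by hand. Below is a minimal sketch of loading the file and inspecting a few fields, assuming ppo.yaml sits in the working directory (as in train_ppo.py below) and that the parsed values are exposed as attributes named after the YAML keys:

```python
# Sketch: load ppo.yaml through OpenRL's config parser, as train_ppo.py does below.
from openrl.configs.config import create_config_parser

cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

# A few of the values defined above (attribute names assumed to match the YAML keys).
print(cfg.lr, cfg.hidden_size, cfg.ppo_epoch)
```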
110 changes: 110 additions & 0 deletions examples/dmc/train_ppo.py
@@ -0,0 +1,110 @@
import numpy as np
from gymnasium.wrappers import FlattenObservation

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers.base_wrapper import BaseWrapper
from openrl.envs.wrappers.extra_wrappers import GIFWrapper
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


class FrameSkip(BaseWrapper):
def __init__(self, env, num_frames: int = 8):
super().__init__(env)
self.num_frames = num_frames

def step(self, action):
num_skips = self.num_frames
total_reward = 0.0

for x in range(num_skips):
obs, rew, term, trunc, info = super().step(action)
total_reward += rew
if term or trunc:
break

return obs, total_reward, term, trunc, info


env_name = "dm_control/cartpole-balance-v0"
# env_name = "dm_control/walker-walk-v0"


def train():
# load the training configuration from ppo.yaml
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

# create the environment, setting environment parallelism to 64
env = make(
env_name,
env_num=64,
cfg=cfg,
asynchronous=True,
env_wrappers=[FrameSkip, FlattenObservation],
)

# create the neural network
net = Net(env, cfg=cfg, device="cuda")
# initialize the trainer
agent = Agent(
net,
)
# start training, setting the total number of training steps to 100,000
agent.train(total_time_steps=100000)
agent.save("./ppo_agent")
env.close()
return agent


def evaluation():
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
# begin to test
# Create an environment for testing, set the number of environments to interact with to 4,
# and set the rendering mode to group_rgb_array so frames can be captured for the GIF.
render_mode = "group_rgb_array"
env = make(
env_name,
render_mode=render_mode,
env_num=4,
asynchronous=True,
env_wrappers=[FrameSkip, FlattenObservation],
cfg=cfg,
)
env = GIFWrapper(env, gif_path="./new.gif", fps=5)

net = Net(env, cfg=cfg, device="cuda")
# initialize the trainer
agent = Agent(
net,
)
agent.load("./ppo_agent")

# Set up the environment for the trained agent to interact with.
agent.set_env(env)
# Initialize the environment and get the initial observations and info.
obs, info = env.reset()
done = False
step = 0
total_reward = 0.0
while not np.any(done):
if step > 500:
break
# Predict the next action from the current observation.
action, _ = agent.act(obs, deterministic=True)
obs, r, done, info = env.step(action)
step += 1
total_reward += np.mean(r)
if step % 50 == 0:
print(f"{step}: reward:{np.mean(r)}")
print("total steps:", step, "total reward:", total_reward)
env.close()

train()
evaluation()
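The FrameSkip wrapper above repeats each action num_frames times (8 by default), sums the intermediate rewards, and stops early if the episode ends. A small self-contained sketch of that accumulation logic, using a hypothetical stub environment rather than dm_control:

```python
# Sketch: the reward-accumulation behaviour of FrameSkip.step, shown on a
# hypothetical stub environment (not part of OpenRL or dm_control).
class StubEnv:
    def __init__(self):
        self.t = 0

    def step(self, action):
        self.t += 1
        # 1.0 reward per raw step; terminate after 5 raw steps
        return self.t, 1.0, self.t >= 5, False, {}


env = StubEnv()
total_reward, num_frames = 0.0, 8  # num_frames=8, as in the wrapper's default
for _ in range(num_frames):
    obs, rew, term, trunc, info = env.step(0)
    total_reward += rew
    if term or trunc:
        break

print(total_reward)  # 5.0 -- the loop exits early once the episode terminates
```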
4 changes: 3 additions & 1 deletion openrl/algorithms/dqn.py
@@ -167,7 +167,9 @@ def prepare_loss(
)

q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))  # mean squared error loss
q_loss = torch.mean(
F.mse_loss(q_values, q_targets.detach())
)  # mean squared error loss

loss_list.append(q_loss)

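The reformatted line is the standard one-step Q-learning target, r + γ·max_a' Q(s', a'), masked to zero at episode boundaries and trained with an MSE loss against the online Q-values. A hedged sketch on dummy tensors (shapes are illustrative, not OpenRL's actual buffer layout):

```python
# Sketch of the TD target and loss in the hunk above, on dummy tensors.
import torch
import torch.nn.functional as F

gamma = 0.99
q_values = torch.tensor([[1.0], [0.5]])           # Q(s, a) for the sampled actions
max_next_q_values = torch.tensor([[2.0], [1.5]])  # max_a' Q(s', a') from the target net
rewards_batch = torch.tensor([[1.0], [0.0]])
next_masks_batch = torch.tensor([[1.0], [0.0]])   # 0 where the episode ended

q_targets = rewards_batch + gamma * max_next_q_values * next_masks_batch
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))  # mean squared error loss
print(q_loss.item())
```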
4 changes: 3 additions & 1 deletion openrl/algorithms/vdn.py
@@ -211,7 +211,9 @@ def prepare_loss(
rewards_batch = rewards_batch.reshape(-1, self.n_agent, 1)
rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
q_targets = rewards_batch + self.gamma * max_next_q_values * next_masks_batch
q_loss = torch.mean(F.mse_loss(q_values, q_targets.detach()))  # mean squared error loss
q_loss = torch.mean(
F.mse_loss(q_values, q_targets.detach())
)  # mean squared error loss

loss_list.append(q_loss)
return loss_list
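VDN trains a joint value that decomposes as a sum over agents, so the hunk above first collapses per-agent rewards into one team reward per transition before forming the same TD target as DQN. A small sketch of that reshape-and-sum step (shapes are illustrative only):

```python
# Sketch: per-agent rewards summed into a single team reward per transition,
# mirroring the two reshape/sum lines in the hunk above.
import torch

n_agent = 2
# 3 transitions x 2 agents, flattened into a (6, 1) batch (assumed layout)
rewards_batch = torch.arange(6, dtype=torch.float32).reshape(-1, 1)

rewards_batch = rewards_batch.reshape(-1, n_agent, 1)
rewards_batch = torch.sum(rewards_batch, dim=1, keepdim=True).view(-1, 1)
print(rewards_batch)  # tensor([[1.], [5.], [9.]]) -- one team reward per transition
```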
9 changes: 8 additions & 1 deletion openrl/envs/common/registration.py
@@ -65,7 +65,13 @@ def make(
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
)
else:
if id in gym.envs.registry.keys():
if id.startswith("dm_control/"):
from openrl.envs.dmc import make_dmc_envs

env_fns = make_dmc_envs(
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
)
elif id in gym.envs.registry.keys():
from openrl.envs.gymnasium import make_gym_envs

env_fns = make_gym_envs(
@@ -77,6 +83,7 @@
env_fns = make_mpe_envs(
id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
)

elif id in openrl.envs.nlp_all_envs:
from openrl.envs.nlp import make_nlp_envs

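With this branch in place, any environment id prefixed with `dm_control/` is routed to `make_dmc_envs`, so the top-level `make` call looks the same as for every other environment family. A hedged usage sketch (assumes dm_control and its gymnasium bindings are installed):

```python
# Sketch: the dm_control/ prefix now dispatches to make_dmc_envs inside make().
from openrl.envs.common import make

env = make("dm_control/cartpole-balance-v0", env_num=2)
print(env.observation_space, env.action_space)
env.close()
```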
33 changes: 33 additions & 0 deletions openrl/envs/dmc/__init__.py
@@ -0,0 +1,33 @@
import copy
from typing import Callable, List, Optional, Union

import dmc2gym

from openrl.envs.common import build_envs
from openrl.envs.dmc.dmc_env import make


def make_dmc_envs(
id: str,
env_num: int = 1,
render_mode: Optional[Union[str, List[str]]] = None,
**kwargs,
):
from openrl.envs.wrappers import ( # AutoReset,; DictWrapper,
RemoveTruncated,
Single2MultiAgentWrapper,
)
from openrl.envs.wrappers.extra_wrappers import ConvertEmptyBoxWrapper

env_wrappers = copy.copy(kwargs.pop("env_wrappers", []))
env_wrappers += [ConvertEmptyBoxWrapper, RemoveTruncated, Single2MultiAgentWrapper]
env_fns = build_envs(
make=make,
id=id,
env_num=env_num,
render_mode=render_mode,
wrappers=env_wrappers,
**kwargs,
)

return env_fns
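`make_dmc_envs` does not create environments itself; it returns a list of factory callables (via `build_envs`) that the vectorized-environment layer instantiates later, each wrapped with `ConvertEmptyBoxWrapper`, `RemoveTruncated`, and `Single2MultiAgentWrapper`. A hedged sketch of using a single factory directly, assuming `build_envs` returns one callable per environment as elsewhere in OpenRL:

```python
# Sketch: instantiate one of the factories returned by make_dmc_envs.
from openrl.envs.dmc import make_dmc_envs

env_fns = make_dmc_envs(id="dm_control/cartpole-balance-v0", env_num=2)
env = env_fns[0]()       # build a single wrapped environment
obs, info = env.reset()  # gymnasium-style reset
env.close()
```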
45 changes: 45 additions & 0 deletions openrl/envs/dmc/dmc_env.py
@@ -0,0 +1,45 @@
from typing import Any, Optional

import dmc2gym
import gymnasium as gym
import numpy as np


def make(
id: str,
render_mode: Optional[str] = None,
**kwargs: Any,
):
env = gym.make(id, render_mode=render_mode)
return env
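The `make` helper above simply defers to gymnasium, which exposes dm_control tasks under ids like `dm_control/cartpole-balance-v0` (through the shimmy bindings); the raw observations are a dict of sensor readings, which is why the example script adds `FlattenObservation`. A hedged sketch of the underlying call (assumes dm_control and shimmy are installed):

```python
# Sketch: what the make() helper above resolves to for a dm_control id.
import gymnasium as gym

env = gym.make("dm_control/cartpole-balance-v0", render_mode=None)
obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(type(obs))  # typically a dict of sensor readings before flattening
env.close()
```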
10 changes: 4 additions & 6 deletions openrl/envs/mpe/rendering.py
@@ -29,12 +29,10 @@
except ImportError:
print(
"Error occured while running `from pyglet.gl import *`",
(
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
" install python-opengl'. If you're running on a server, you may need a"
" virtual frame buffer; something like this should work: 'xvfb-run -s"
' "-screen 0 1400x900x24" python <your_script.py>\''
),
"HINT: make sure you have OpenGL install. On Ubuntu, you can run 'apt-get"
" install python-opengl'. If you're running on a server, you may need a"
" virtual frame buffer; something like this should work: 'xvfb-run -s"
' "-screen 0 1400x900x24" python <your_script.py>\'',
)

import math
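The only change here is dropping a redundant pair of parentheses: adjacent string literals concatenate either way, so the printed hint is unchanged. A one-line check:

```python
# Adjacent string literals concatenate with or without wrapping parentheses.
assert ("one " "two") == "one " "two" == "one two"
```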
30 changes: 10 additions & 20 deletions openrl/envs/vec_env/async_venv.py
@@ -233,10 +233,8 @@ def reset_send(

if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `reset_send` while waiting for a pending call to"
f" `{self._state.value}` to complete"
),
"Calling `reset_send` while waiting for a pending call to"
f" `{self._state.value}` to complete",
self._state.value,
)

@@ -328,10 +326,8 @@ def step_send(self, actions: np.ndarray):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `step_send` while waiting for a pending call to"
f" `{self._state.value}` to complete."
),
"Calling `step_send` while waiting for a pending call to"
f" `{self._state.value}` to complete.",
self._state.value,
)

@@ -575,10 +571,8 @@ def call_send(self, name: str, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `call_send` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
"Calling `call_send` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)

@@ -635,10 +629,8 @@ def exec_func_send(self, func: Callable, indices, *args, **kwargs):
self._assert_is_running()
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `exec_func_send` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
"Calling `exec_func_send` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)

@@ -715,10 +707,8 @@ def set_attr(self, name: str, values: Union[List[Any], Tuple[Any], object]):

if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
(
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete."
),
"Calling `set_attr` while waiting "
f"for a pending call to `{self._state.value}` to complete.",
str(self._state.value),
)

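Each of these hunks applies the same parenthesis cleanup to the guard that every `*_send` method uses: if another asynchronous call is still pending, it raises `AlreadyPendingCallError` with the pending state's name. A simplified, hypothetical sketch of that guard (the classes here are stand-ins, not OpenRL's actual implementations):

```python
# Sketch of the pending-call guard pattern used by the *_send methods above.
# AsyncState and AlreadyPendingCallError are simplified stand-ins here.
from enum import Enum


class AsyncState(Enum):
    DEFAULT = "default"
    WAITING_RESET = "reset"


class AlreadyPendingCallError(Exception):
    def __init__(self, message: str, name: str):
        super().__init__(message)
        self.name = name


state = AsyncState.WAITING_RESET  # pretend reset_send() was already called
try:
    if state != AsyncState.DEFAULT:
        raise AlreadyPendingCallError(
            "Calling `step_send` while waiting for a pending call to"
            f" `{state.value}` to complete.",
            state.value,
        )
except AlreadyPendingCallError as err:
    print(err)  # the two literals concatenate into one message; no parentheses needed
```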