Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions experiments/buildmarines_check_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from gymnasium import spaces
from pysc2.env import sc2_env
from stable_baselines3.common.env_checker import check_env

from urnai.environments.stablebaselines3.custom_env_buildmarines import (
CustomEnvBuildMarines,
)
from urnai.sc2.actions.buildmarines import BuildMarinesActionSpace
from urnai.sc2.environments.sc2environment import SC2Env
from urnai.sc2.rewards.buildmarines import BuildMarinesReward
from urnai.sc2.states.buildmarines import BuildMarinesState

# Step multiplier shared by the SC2 environment and the wrapper. The wrapper
# advances its own step counter by its step_mul on every step(), so both must
# agree for the wrapper's step accounting to track the game.
STEP_MUL = 16


def main():
    """Build the BuildMarines wrapper env and validate it with SB3's ``check_env``.

    The wrapper is closed in a ``finally`` block so resources are released even
    when ``check_env`` raises. NOTE(review): this assumes the wrapper's
    ``close()`` forwards to the underlying SC2Env — confirm in CustomEnv.
    """
    players = [sc2_env.Agent(sc2_env.Race.terran)]
    # Gym-facing spaces: 4 discrete actions, 4-feature observation in [0, 1].
    action_space = spaces.Discrete(n=4, start=0)
    observation_space = spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=float)

    env = SC2Env(map_name='BuildMarines', visualize=False,
                 step_mul=STEP_MUL, players=players)
    state = BuildMarinesState()
    urnai_action_space = BuildMarinesActionSpace()
    reward = BuildMarinesReward()
    custom_env = CustomEnvBuildMarines(env, state, urnai_action_space, reward,
                                       observation_space, action_space,
                                       step_mul=STEP_MUL)

    try:
        check_env(custom_env, warn=True)
    finally:
        custom_env.close()


if __name__ == '__main__':
    main()
126 changes: 126 additions & 0 deletions experiments/solves/solve_buildmarines_sb3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from absl import app
from gymnasium import spaces
from pysc2.env import sc2_env
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor

import wandb
from urnai.environments.stablebaselines3.custom_env_buildmarines import (
CustomEnvBuildMarines,
)
from urnai.loggers.wandb_logger import WandbLogger
from urnai.sc2.actions.buildmarines import BuildMarinesActionSpace
from urnai.sc2.environments.sc2environment import SC2Env
from urnai.sc2.rewards.buildmarines import BuildMarinesReward
from urnai.sc2.states.buildmarines import BuildMarinesState
from urnai.trainers.stablebaselines3_trainer import SB3Trainer

# Episode length budget, expressed in game steps.
EPISODE_MINUTES = 15
# Approximate StarCraft II default step rate.
STEPS_PER_SECOND = 22
STEPS_PER_MINUTE = int(60 * STEPS_PER_SECOND)
MAX_STEPS = STEPS_PER_MINUTE * EPISODE_MINUTES


def declare_wandb_run(config_dict: dict, run_id: str = None):
    """Start or resume a Weights & Biases run for the BuildMarines experiment.

    Args:
        config_dict: Run configuration. It is logged as the W&B config, and its
            ``'model_save_name'`` entry is used as the run name.
        run_id: Id of an existing run. When truthy, the run is resumed with
            ``resume="must"``; otherwise ``resume`` is passed as ``None``.

    Returns:
        Whatever ``wandb.init`` returns (the run handle).
    """
    resume_mode = "must" if run_id else None
    return wandb.init(
        project='solve_buildmarines',
        name=config_dict['model_save_name'],
        config=config_dict,
        id=run_id,
        resume=resume_mode,
        sync_tensorboard=True,
    )

def _make_monitored_env(config_dict: dict, players, observation_space,
                        action_space, logger, step_mul: int):
    """Build one Monitor-wrapped BuildMarines env with its own SC2Env and components.

    Every call creates fresh state/action/reward instances. These components
    hold state, so the train and eval environments must not share them.
    """
    # Named to avoid shadowing the imported ``sc2_env`` module.
    sc2_environment = SC2Env(map_name='BuildMarines', step_mul=step_mul,
                             players=players)
    custom_env = CustomEnvBuildMarines(sc2_environment, BuildMarinesState(),
                                       BuildMarinesActionSpace(),
                                       BuildMarinesReward(config_dict),
                                       observation_space, action_space, logger,
                                       step_mul, MAX_STEPS)
    return Monitor(custom_env)


def declare_trainer(config_dict: dict, hyperparameters: dict = None,
                    step_mul: int = 32,
                    saves_root: str = "/home/mambauser/saves"):
    """Create an SB3Trainer with a PPO model and separate train/eval envs.

    Args:
        config_dict: Experiment configuration. ``'policy'`` is passed to PPO,
            and ``'model_save_name'`` names the model directory and the
            trainer's model. The whole dict is also given to each
            ``BuildMarinesReward``.
        hyperparameters: Extra keyword arguments forwarded to ``PPO``;
            ``None`` means no extras.
        step_mul: Step multiplier given to both the SC2Env and the wrapper env.
        saves_root: Root directory holding ``models/<model_save_name>`` and
            ``logs``.

    Returns:
        The configured ``SB3Trainer``.
    """
    players = [sc2_env.Agent(sc2_env.Race.terran)]
    action_space = spaces.Discrete(n=4, start=0)
    observation_space = spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=float)

    # A single logger instance shared by the train env, the eval env and the trainer.
    logger = WandbLogger()

    # Train and eval each get their own SC2Env and stateful components.
    train_env = _make_monitored_env(config_dict, players, observation_space,
                                    action_space, logger, step_mul)
    eval_env = _make_monitored_env(config_dict, players, observation_space,
                                   action_space, logger, step_mul)

    models_dir = os.path.join(saves_root, "models",
                              config_dict['model_save_name'])
    logdir = os.path.join(saves_root, "logs")

    model = PPO(config_dict['policy'], train_env, verbose=1,
                tensorboard_log=logdir,
                **(hyperparameters if hyperparameters is not None else {}))

    return SB3Trainer(
        train_env, eval_env, models_dir, logdir, model,
        config_dict['model_save_name'], logger=logger
    )

def main(unused_argv):
    """Configure W&B, build the trainer, and alternate training with evaluation.

    Args:
        unused_argv: Leftover command-line arguments passed in by
            ``absl.app.run``; ignored.
    """
    config_dict = {
        "policy": "MlpPolicy",
        "model_save_name": "TestMoreBarracks2",
        "w_supply": 1.0,
        "w_barrack": 30.0,
        "w_marine": 1.5,
        "penalty_no_supply": 0.0,
        "penalty_no_barrack": 0.0,
    }
    # Each training phase covers roughly this many episodes of MAX_STEPS game steps.
    episodes_per_train_phase = 50
    # Must match the step multiplier used by declare_trainer (its default is 32).
    step_mul = 32

    wandb_run = None
    try:
        wandb_run = declare_wandb_run(config_dict)
        trainer = declare_trainer(config_dict)
        # Other modes: trainer.load_most_recent_model(trainer.models_dir) to resume,
        # trainer.test_model(...) to evaluate only, trainer.train_model(...) to train only.
        trainer.alternate_train_test(
            iterations=100000,
            train_steps=int(episodes_per_train_phase * MAX_STEPS / step_mul),
            test_episodes=10,
            callback=None,
            return_episode_rewards=True, wandb_log=True,
        )
    except KeyboardInterrupt:
        print("Training interrupted by user")

    # Reached on normal completion and after a user interrupt. Finish the run
    # explicitly in both cases; an interrupt previously skipped this call.
    if wandb_run is not None:
        wandb_run.finish()


if __name__ == '__main__':
    # absl parses flags, then calls main() with the remaining argv.
    app.run(main)
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import unittest
from unittest.mock import ANY, MagicMock, patch

from urnai.environments.stablebaselines3.custom_env_buildmarines import (
CustomEnvBuildMarines,
)


@patch('urnai.environments.stablebaselines3.custom_env_buildmarines.CustomEnv.reset')
class TestCustomEnvBuildMarines(unittest.TestCase):
    """Unit tests for CustomEnvBuildMarines with every collaborator mocked.

    The class-level patch replaces ``CustomEnv.reset``. Each ``test_*`` method
    receives that mock as its last positional argument, after any
    method-level patches.
    """

    def setUp(self):
        # Fresh mocks for every collaborator handed to the wrapper.
        (self.mock_env_internal, self.mock_state, self.mock_urnai_action_space,
         self.mock_reward, self.mock_observation_space, self.mock_action_space,
         self.mock_logger) = (MagicMock() for _ in range(7))
        self.env_wrapper = CustomEnvBuildMarines(
            env=self.mock_env_internal,
            state=self.mock_state,
            urnai_action_space=self.mock_urnai_action_space,
            reward=self.mock_reward,
            observation_space=self.mock_observation_space,
            action_space=self.mock_action_space,
            logger=self.mock_logger,
            step_mul=32,
            max_steps=5000,
        )

    def _seed_episode_stats(self, counts, rewards, total_reward):
        """Install per-action counters, per-action reward sums and the episode total."""
        self.env_wrapper.action_map_count = counts
        self.env_wrapper.action_map_reward = rewards
        self.env_wrapper._reward.total_reward = total_reward

    def test_initialization(self, mock_super_reset):
        # setUp already built the wrapper; check what the constructor stored.
        wrapper = self.env_wrapper
        stored_vs_given = (
            (wrapper._env, self.mock_env_internal),
            (wrapper._state, self.mock_state),
            (wrapper._action_space, self.mock_urnai_action_space),
            (wrapper._reward, self.mock_reward),
            (wrapper.observation_space, self.mock_observation_space),
            (wrapper.action_space, self.mock_action_space),
        )
        for stored, given in stored_vs_given:
            self.assertIs(stored, given)
        self.assertEqual(wrapper.step_count, 0)
        self.assertEqual(wrapper.max_steps, 5000)
        self.assertEqual(wrapper.action_map_count["BuildMarine"], 0)
        self.assertIn(3, wrapper.actions)
        self.assertEqual(wrapper.actions[3], "BuildMarine")

    def test_reset(self, mock_super_reset):
        wrapper = self.env_wrapper
        wrapper.step_count = 100
        wrapper.action_map_count["Collect"] = 5
        wrapper.action_map_reward["Collect"] = 10.0

        wrapper.reset()

        self.assertEqual(wrapper.step_count, 0)
        self.assertEqual(wrapper.action_map_count["Collect"], 0)
        self.assertEqual(wrapper.action_map_reward["Collect"], 0)
        mock_super_reset.assert_called_once()

    def test_step_basic_flow(self, mock_super_reset):
        build_marine = 3
        raw_obs = "observation_from_sc2"
        built_obs = "observation_from_state_builder"
        expected_reward = 1.5
        self.mock_urnai_action_space.get_action.return_value = "urnai_action"
        self.mock_env_internal.step.return_value = (raw_obs, 1.0, False, False)
        self.mock_state.update.return_value = built_obs
        self.mock_reward.get.return_value = expected_reward

        obs, reward, terminated, truncated, _info = self.env_wrapper.step(build_marine)

        # Collaborator interactions.
        self.mock_urnai_action_space.get_action.assert_called_once_with(build_marine, ANY)
        self.mock_env_internal.step.assert_called_once_with("urnai_action")
        self.mock_state.update.assert_called_once_with(raw_obs)
        self.mock_reward.get.assert_called_once_with(raw_obs, 1.0, False, False,
                                                     build_marine)
        # Bookkeeping on the wrapper.
        wrapper = self.env_wrapper
        self.assertEqual(wrapper.step_count, wrapper.step_mul)
        self.assertEqual(wrapper.action_map_count["BuildMarine"], 1)
        self.assertEqual(wrapper.action_map_reward["BuildMarine"], expected_reward)
        # Gymnasium step tuple.
        self.assertEqual(obs, built_obs)
        self.assertEqual(reward, expected_reward)
        self.assertFalse(terminated)
        self.assertFalse(truncated)

    def test_step_reaches_max_steps_and_logs(self, mock_super_reset):
        wrapper = self.env_wrapper
        wrapper.max_steps = 100
        wrapper.step_mul = 32
        # One step short of the limit, so the next step reaches it.
        wrapper.step_count = 100 - wrapper.step_mul
        self.mock_env_internal.step.return_value = ("some_obs", 0.0, False, False)
        wrapper.log_reward_per_action = MagicMock()

        _obs, _reward, _terminated, truncated, _info = wrapper.step(0)

        self.assertTrue(truncated)
        wrapper.log_reward_per_action.assert_called_once()

    @patch('urnai.environments.stablebaselines3.custom_env_buildmarines.sc2aux.get_my_units_amount')
    def test_log_reward_per_action(self, mock_get_units, mock_super_reset):
        self._seed_episode_stats(
            counts={"Collect": 10, "BuildSupplyDepot": 2,
                    "BuildBarrack": 1, "BuildMarine": 0},
            rewards={"Collect": 5.0, "BuildSupplyDepot": 4.0,
                     "BuildBarrack": 1.0, "BuildMarine": 0.0},
            total_reward=10.0,
        )
        # Unit counts returned in call order: marines, supply depots, barracks.
        mock_get_units.side_effect = [0, 2, 1]

        self.env_wrapper.log_reward_per_action()

        self.mock_logger.log.assert_called_once()
        payload = self.mock_logger.log.call_args[0][0]
        exact_values = {
            "action/count/Collect": 10,
            "action/count/BuildSupplyDepot": 2,
            "total_reward": 10.0,
            "marines_built": 0,
            "supply_depots_built": 2,
            "barracks_built": 1,
        }
        approximate_values = {
            "action/avg_reward/Collect": 0.5,
            "action/avg_reward/BuildSupplyDepot": 2.0,
        }
        for key, expected in exact_values.items():
            self.assertEqual(payload[key], expected)
        for key, expected in approximate_values.items():
            self.assertAlmostEqual(payload[key], expected)

    @patch('urnai.environments.stablebaselines3.custom_env_buildmarines.sc2aux.get_my_units_amount')
    def test_log_reward_per_action_when_logger_is_none(self, mock_get_units,
                                                       mock_super_reset):
        self._seed_episode_stats(
            counts={"Collect": 10, "BuildSupplyDepot": 2,
                    "BuildBarrack": 1, "BuildMarine": 5},
            rewards={"Collect": 5.0, "BuildSupplyDepot": 4.0,
                     "BuildBarrack": 3.0, "BuildMarine": 10.0},
            total_reward=22.0,
        )
        mock_get_units.side_effect = [20, 2, 1]
        # Detach the logger: nothing must be sent to the original mock.
        self.env_wrapper.logger = None

        self.env_wrapper.log_reward_per_action()

        self.mock_logger.log.assert_not_called()
Loading
Loading