
Merged
Move to SAOEState
lihuoran committed Nov 8, 2022
commit 1f695d680d1442a6b942ccc06a1049e31ab70896
9 changes: 0 additions & 9 deletions qlib/rl/interpreter.py
@@ -31,15 +31,6 @@ class Interpreter:
states by calling ``self.env.register_state()``, but it's not planned for the first iteration.
"""

def __init__(self) -> None:
self.cur_step = 0

def reset(self) -> None:
self.cur_step = 0

def step(self) -> None:
self.cur_step += 1


class StateInterpreter(Generic[StateType, ObsType], Interpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
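With these deletions, interpreters no longer carry a mutable `cur_step` counter; the counter now travels with the simulator state (see `SAOEState.cur_step` below). A minimal sketch of the resulting pattern, using simplified, hypothetical class and field names rather than the actual qlib signatures:

```python
# Sketch only: simplified stand-ins for the qlib interpreter/state classes.
from typing import NamedTuple


class MiniState(NamedTuple):
    cur_step: int  # step counter provided by the simulator state


class MiniStateInterpreter:
    def __init__(self, max_step: int) -> None:
        self.max_step = max_step  # static configuration only, no mutable counters

    def interpret(self, state: MiniState) -> dict:
        # Read the step counter from the state instead of tracking it locally,
        # so no reset()/step() bookkeeping is needed between episodes.
        return {"cur_step": min(state.cur_step, self.max_step - 1)}
```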
10 changes: 5 additions & 5 deletions qlib/rl/order_execution/interpreter.py
@@ -109,7 +109,7 @@ def interpret(self, state: SAOEState) -> FullHistoryObs:
"data_processed_prev": np.array(processed.yesterday),
"acquiring": _to_int32(state.order.direction == state.order.BUY),
"cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)),
"cur_step": _to_int32(min(self.cur_step, self.max_step - 1)),
"cur_step": _to_int32(min(state.cur_step, self.max_step - 1)),
"num_step": _to_int32(self.max_step),
"target": _to_float32(state.order.amount),
"position": _to_float32(state.position),
@@ -173,10 +173,10 @@ def observation_space(self) -> spaces.Dict:
return spaces.Dict(space)

def interpret(self, state: SAOEState) -> CurrentStateObs:
assert self.cur_step <= self.max_step
assert state.cur_step <= self.max_step
obs = CurrentStateObs(
acquiring=state.order.direction == state.order.BUY,
cur_step=self.cur_step,
cur_step=state.cur_step,
num_step=self.max_step,
target=state.order.amount,
position=state.position,
@@ -212,7 +212,7 @@ def action_space(self) -> spaces.Discrete:

def interpret(self, state: SAOEState, action: int) -> float:
assert 0 <= action < len(self.action_values)
if self.max_step is not None and self.cur_step >= self.max_step - 1:
if self.max_step is not None and state.cur_step >= self.max_step - 1:
return state.position
else:
return min(state.position, state.order.amount * self.action_values[action])
@@ -233,7 +233,7 @@ def action_space(self) -> spaces.Box:

def interpret(self, state: SAOEState, action: float) -> float:
estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
twap_volume = state.position / (estimated_total_steps - self.cur_step)
twap_volume = state.position / (estimated_total_steps - state.cur_step)
return min(state.position, twap_volume * action)


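A small worked example of the `TwapRelativeActionInterpreter` computation above, using made-up numbers (not values from the test suite):

```python
import math

# Hypothetical inputs mirroring the fields used by interpret():
position = 10.0         # state.position, remaining volume to execute
ticks_for_order = 390   # len(state.ticks_for_order)
ticks_per_step = 30     # state.ticks_per_step
cur_step = 3            # state.cur_step, now read from SAOEState
action = 1.5            # policy output, a multiple of the TWAP volume

estimated_total_steps = math.ceil(ticks_for_order / ticks_per_step)  # 13
twap_volume = position / (estimated_total_steps - cur_step)          # 10 / 10 = 1.0
exec_vol = min(position, twap_volume * action)                       # 1.5
assert exec_vol == 1.5
```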
5 changes: 4 additions & 1 deletion qlib/rl/order_execution/policy.py
@@ -148,7 +148,10 @@ def __init__(
action_space=action_space,
)
if weight_file is not None:
set_weight(self, torch.load(weight_file, map_location="cpu")["vessel"]["policy"])
weight = torch.load(weight_file, map_location="cpu")
if "vessel" in weight:
weight = weight["vessel"]["policy"]
set_weight(self, weight)


# utilities: these should be put in a separate (common) file. #
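The change above makes weight loading accept both a bare policy `state_dict` and a full trainer checkpoint that nests the policy under `["vessel"]["policy"]`. A minimal sketch of the same idea, assuming a plain `torch.nn.Module` and using `load_state_dict` in place of qlib's `set_weight` helper:

```python
import torch


def load_policy_weight(policy: torch.nn.Module, weight_file: str) -> None:
    """Load weights saved either as a bare state_dict or inside a trainer checkpoint."""
    weight = torch.load(weight_file, map_location="cpu")
    if "vessel" in weight:
        # Full checkpoint produced by the trainer; unwrap the policy's state_dict.
        weight = weight["vessel"]["policy"]
    policy.load_state_dict(weight)
```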
3 changes: 3 additions & 0 deletions qlib/rl/order_execution/simulator_simple.py
@@ -96,6 +96,7 @@ def __init__(
self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time)

self.cur_time = self.ticks_for_order[0]
self.cur_step = 0
# NOTE: astype(float) is necessary in some systems.
# this will align the precision with `.to_numpy()` in `_split_exec_vol`
self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean())
@@ -192,11 +193,13 @@ def step(self, amount: float) -> None:
self.env.logger.add_any(key, value)

self.cur_time = self._next_time()
self.cur_step += 1

def get_state(self) -> SAOEState:
return SAOEState(
order=self.order,
cur_time=self.cur_time,
cur_step=self.cur_step,
position=self.position,
history_exec=self.history_exec,
history_steps=self.history_steps,
2 changes: 2 additions & 0 deletions qlib/rl/order_execution/state.py
@@ -74,6 +74,8 @@ class SAOEState(NamedTuple):
"""The order we are dealing with."""
cur_time: pd.Timestamp
"""Current time, e.g., 9:30."""
cur_step: int
"""Current step, e.g., 0."""
position: float
"""Current remaining volume to execute."""
history_exec: pd.DataFrame
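For illustration, a compact sketch of the producer/consumer contract that the new field establishes (field names follow the diff; the surrounding classes are simplified and hypothetical):

```python
from typing import NamedTuple

import pandas as pd


class MiniSAOEState(NamedTuple):
    cur_time: pd.Timestamp
    cur_step: int   # new field: 0-based step counter within the order
    position: float


# Producer (the simulator / state adapter) advances cur_step once per step:
state = MiniSAOEState(cur_time=pd.Timestamp("2022-11-08 09:30"), cur_step=0, position=30.0)
state = state._replace(cur_time=pd.Timestamp("2022-11-08 10:00"), cur_step=state.cur_step + 1)

# Consumers (state/action interpreters) read it instead of keeping their own counters:
assert state.cur_step == 1
```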
23 changes: 14 additions & 9 deletions qlib/rl/order_execution/strategy.py
@@ -17,7 +17,7 @@
from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange
from qlib.backtest.exchange import Exchange
from qlib.backtest.executor import BaseExecutor
from qlib.backtest.utils import LevelInfrastructure
from qlib.backtest.utils import LevelInfrastructure, get_start_end_idx
from qlib.constant import EPS, ONE_MIN, REG_CN
from qlib.rl.data.native import IntradayBacktestData, load_backtest_data
from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
@@ -84,6 +84,7 @@ class SAOEStateAdapter:
def __init__(
self,
order: Order,
trade_decision: BaseTradeDecision,
executor: BaseExecutor,
exchange: Exchange,
ticks_per_step: int,
@@ -94,6 +95,7 @@ def __init__(
self.executor = executor
self.exchange = exchange
self.backtest_data = backtest_data
self.start_idx, _ = get_start_end_idx(self.executor.trade_calendar, trade_decision)

self.twap_price = self.backtest_data.get_deal_price().mean()

@@ -273,6 +275,7 @@ def saoe_state(self) -> SAOEState:
return SAOEState(
order=self.order,
cur_time=self.cur_time,
cur_step=self.executor.trade_calendar.get_trade_step() - self.start_idx,
position=self.position,
history_exec=self.history_exec,
history_steps=self.history_steps,
@@ -306,11 +309,17 @@ def __init__(
self.adapter_dict: Dict[tuple, SAOEStateAdapter] = {}
self._last_step_range = (0, 0)

def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter:
def _create_qlib_backtest_adapter(
self,
order: Order,
trade_decision: BaseTradeDecision,
trade_range: TradeRange,
) -> SAOEStateAdapter:
backtest_data = load_backtest_data(order, self.trade_exchange, trade_range)

return SAOEStateAdapter(
order=order,
trade_decision=trade_decision,
executor=self.executor,
exchange=self.trade_exchange,
ticks_per_step=int(pd.Timedelta(self.trade_calendar.get_freq()) / ONE_MIN),
@@ -330,7 +339,9 @@ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -
self.adapter_dict = {}
for decision in outer_trade_decision.get_decision():
order = cast(Order, decision)
self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(order, trade_range)
self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter(
order, outer_trade_decision, trade_range
)

def get_saoe_state_by_order(self, order: Order) -> SAOEState:
return self.adapter_dict[order.key_by_day].saoe_state
@@ -480,9 +491,6 @@ def __init__(
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)

self._state_interpreter.reset()
self._action_interpreter.reset()

def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame:
assert hasattr(self.outer_trade_decision, "order_list")

@@ -515,9 +523,6 @@ def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDeci
act = policy_out.act.numpy() if torch.is_tensor(policy_out.act) else policy_out.act
exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)]

self._state_interpreter.step()
self._action_interpreter.step()

oh = self.trade_exchange.get_order_helper()
order_list = []
for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols):
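An illustrative check of the `cur_step` derivation above, assuming (as the diff suggests) that `get_start_end_idx` returns the calendar index at which the outer trade decision starts and that `get_trade_step` returns the executor's current calendar step:

```python
# Hypothetical numbers: the order's trade range starts at calendar index 3,
# and the executor is currently on calendar step 5.
start_idx = 3                 # from get_start_end_idx(trade_calendar, trade_decision)
current_trade_step = 5        # from executor.trade_calendar.get_trade_step()

cur_step = current_trade_step - start_idx
assert cur_step == 2          # the value reported as SAOEState.cur_step
```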
4 changes: 0 additions & 4 deletions qlib/rl/utils/env_wrapper.py
@@ -179,8 +179,6 @@ def reset(self, **kwargs: Any) -> ObsType:
)

self.simulator.env = cast(EnvWrapper, weakref.proxy(self))
self.state_interpreter.reset()
self.action_interpreter.reset()

sim_state = self.simulator.get_state()
obs = self.state_interpreter(sim_state)
@@ -215,8 +213,6 @@ def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, fl

# Use the converted action to update the simulator
self.simulator.step(action)
self.state_interpreter.step()
self.action_interpreter.step()

# Update "done" first, as this status might be used by reward_fn later
done = self.simulator.done()
1 change: 0 additions & 1 deletion tests/rl/test_qlib_simulator.py
@@ -183,7 +183,6 @@ def test_interpreter() -> None:
order = get_order()
simulator = get_simulator(order)
interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
interpreter_action.reset()

NUM_STEPS = 7
state = simulator.get_state()
10 changes: 0 additions & 10 deletions tests/rl/test_saoe_simple.py
@@ -174,7 +174,6 @@ class EmulateEnvWrapper(NamedTuple):

# second step
simulator.step(5.0)
interpreter.step()
interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs))

obs = interpreter(simulator.get_state())
@@ -186,26 +185,21 @@ class EmulateEnvWrapper(NamedTuple):
assert np.sum(obs["data_processed"][60:]) == 0

# second step: action
interpreter_action.reset()
interpreter_action_twap.reset()
action = interpreter_action(simulator.get_state(), 1)
assert action == 15 / 20

interpreter_action_twap.env = EmulateEnvWrapper(
status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs)
)
interpreter_action_twap.step()
action = interpreter_action_twap(simulator.get_state(), 1.5)
assert action == 1.5

# fast-forward
for _ in range(10):
simulator.step(0.0)
interpreter.step()

# last step
simulator.step(5.0)
interpreter.step()
interpreter.env = EmulateEnvWrapper(
status=EnvWrapperStatus(cur_step=12, done=simulator.done(), **wrapper_status_kwargs)
)
@@ -247,7 +241,6 @@ class EmulateEnvWrapper(NamedTuple):
assert 0 <= output["act"].item() <= 13
if i < 13:
simulator.step(1.0)
interpreter.step()
else:
assert obs["cur_tick"] == 389
assert obs["cur_step"] == 12
@@ -262,7 +255,6 @@ def test_twap_strategy(finite_env_type):

state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
action_interp = TwapRelativeActionInterpreter()
action_interp.reset()
policy = AllOne(state_interp.observation_space, action_interp.action_space)
csv_writer = CsvWriter(Path(__file__).parent / ".output")

@@ -292,7 +284,6 @@ def test_cn_ppo_strategy():

state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
action_interp = CategoricalActionInterpreter(4)
action_interp.reset()
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
@@ -324,7 +315,6 @@ def test_ppo_train():

state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
action_interp = CategoricalActionInterpreter(4)
action_interp.reset()
network = Recurrent(state_interp.observation_space)
policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
