From e68ffc89218bccd9352eaaa003ed96f8320c0ae4 Mon Sep 17 00:00:00 2001 From: Default Date: Wed, 24 Aug 2022 14:34:20 +0800 Subject: [PATCH 01/15] Backtest migration --- qlib/backtest/__init__.py | 8 +- qlib/rl/amc4th_migration/backtest.py | 230 ++++++++++++++++++ .../experiments/opds_15_225/opds_15_225.py | 52 ++++ .../experiments/opds_15_225/twap.yml | 25 ++ .../amc4th_migration/naive_config_parser.py | 133 ++++++++++ qlib/rl/amc4th_migration/utils.py | 29 +++ qlib/rl/data/exchange_wrapper.py | 4 +- qlib/rl/data/pickle_styled.py | 40 ++- qlib/rl/interpreter.py | 8 +- qlib/rl/order_execution/integration.py | 2 +- qlib/rl/order_execution/interpreter.py | 42 ++-- qlib/rl/order_execution/network.py | 46 ++++ qlib/rl/order_execution/state.py | 2 +- qlib/rl/order_execution/strategy.py | 196 ++++++++++++++- qlib/rl/trainer/__init__.py | 2 +- qlib/rl/utils/env_wrapper.py | 21 ++ tests/rl/test_saoe_simple.py | 10 +- 17 files changed, 808 insertions(+), 42 deletions(-) create mode 100644 qlib/rl/amc4th_migration/backtest.py create mode 100644 qlib/rl/amc4th_migration/experiments/opds_15_225/opds_15_225.py create mode 100644 qlib/rl/amc4th_migration/experiments/opds_15_225/twap.yml create mode 100644 qlib/rl/amc4th_migration/naive_config_parser.py create mode 100644 qlib/rl/amc4th_migration/utils.py diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index e8fe73c5a2..81c6437d6d 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -114,7 +114,7 @@ def get_exchange( def create_account_instance( start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], - benchmark: str, + benchmark: Optional[str], account: Union[float, int, dict], pos_type: str = "Position", ) -> Account: @@ -163,7 +163,9 @@ def create_account_instance( init_cash=init_cash, position_dict=position_dict, pos_type=pos_type, - benchmark_config={ + benchmark_config={} + if benchmark is None + else { "benchmark": benchmark, "start_time": start_time, "end_time": end_time, @@ -176,7 +178,7 @@ def get_strategy_executor( end_time: Union[pd.Timestamp, str], strategy: Union[str, dict, object, Path], executor: Union[str, dict, object, Path], - benchmark: str = "SH000300", + benchmark: Optional[str] = "SH000300", account: Union[float, int, dict] = 1e9, exchange_kwargs: dict = {}, pos_type: str = "Position", diff --git a/qlib/rl/amc4th_migration/backtest.py b/qlib/rl/amc4th_migration/backtest.py new file mode 100644 index 0000000000..4e88d71d57 --- /dev/null +++ b/qlib/rl/amc4th_migration/backtest.py @@ -0,0 +1,230 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
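+#
+# Overview: this script is the entry point of the migrated backtest. It reads
+# an order file, builds a nested per-frequency executor stack from the
+# strategy configs, runs one single-asset backtest per stock (or per day) in
+# parallel via joblib, and concatenates the per-order indicator records into
+# summary.csv.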
+from __future__ import annotations + +import copy +import pickle +import sys +from pathlib import Path +from typing import Optional, cast, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from joblib import Parallel, delayed + +from qlib.backtest import collect_data_loop, get_strategy_executor +from qlib.backtest.decision import TradeRangeByTime +from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor +from qlib.backtest.high_performance_ds import BaseOrderIndicator +from qlib.rl.amc4th_migration.naive_config_parser import convert_instance_config, get_backtest_config_fromfile +from qlib.rl.amc4th_migration.utils import read_order_file +from qlib.rl.order_execution.integration import init_qlib +from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper + + +def _get_multi_level_executor_config( + strategy_config: dict, + cash_limit: float = None, + generate_report: bool = False, +) -> dict: + strategy_config = cast(dict, convert_instance_config(strategy_config)) + executor_config = { + "class": "SimulatorExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": "1min", + "verbose": False, + "trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL, + "generate_report": generate_report, + "track_data": True, + }, + } + + freqs = list(strategy_config.keys()) + freqs.sort(key=lambda x: pd.Timedelta(x)) + for freq in freqs: + executor_config = { + "class": "NestedExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": freq, + "inner_strategy": strategy_config[freq], + "inner_executor": executor_config, + "track_data": True, + }, + } + + return executor_config + + +def _set_env_for_all_strategy(executor: BaseExecutor) -> None: + if isinstance(executor, NestedExecutor): + if hasattr(executor.inner_strategy, "set_env"): + executor.inner_strategy.set_env(CollectDataEnvWrapper()) + _set_env_for_all_strategy(executor.inner_executor) + + +def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: + record_list = [] + for time, value_dict in indicator.items(): + if isinstance(value_dict, BaseOrderIndicator): + # HACK: for qlib v0.8 + value_dict = value_dict.to_series() + try: + value_dict = {k: v for k, v in value_dict.items()} + if value_dict["ffr"].empty: + continue + except Exception: + value_dict = {k: v for k, v in value_dict.items() if k != "pa"} + value_dict = pd.DataFrame(value_dict) + value_dict["datetime"] = time + record_list.append(value_dict) + + if not record_list: + return None + + records: pd.DataFrame = pd.concat(record_list, 0).reset_index().rename(columns={"index": "instrument"}) + records = records.set_index(["instrument", "datetime"]) + return records + + +def _generate_report(decisions: list, report_dict: dict) -> dict: + report = {} + decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")]) + for key in ["1minute", "5minute", "30minute", "1day"]: + if key not in report_dict["indicator"]: + continue + report[key] = report_dict["indicator"][key] + report[key + "_obj"] = _convert_indicator_to_dataframe( + report_dict["indicator"][key + "_obj"].order_indicator_his + ) + cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"]) + if len(cur_details) > 0: + cur_details.pop("freq") + report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer") + if "1minute" in report_dict["report"]: + report["simulator"] = 
report_dict["report"]["1minute"][0] + return report + + +def single( + backtest_config: dict, + orders: pd.DataFrame, + split: str = "stock", + cash_limit: float = None, + generate_report: bool = False, +) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: + if split == "stock": + stock_id = orders.iloc[0].instrument + init_qlib(backtest_config["qlib"], part=stock_id) + else: + day = orders.iloc[0].datetime + init_qlib(backtest_config["qlib"], part=day) + + trade_start_time = orders["datetime"].min() + trade_end_time = orders["datetime"].max() + stocks = orders.instrument.unique().tolist() + + top_strategy_config = { + "class": "FileOrderStrategy", + "module_path": "qlib.contrib.strategy.rule_strategy", + "kwargs": { + "file": orders, + "trade_range": TradeRangeByTime( + pd.Timestamp(backtest_config["start_time"]).time(), + pd.Timestamp(backtest_config["end_time"]).time(), + ), + }, + } + + top_executor_config = _get_multi_level_executor_config( + strategy_config=backtest_config["strategies"], + cash_limit=cash_limit, + generate_report=generate_report, + ) + + tmp_backtest_config = copy.deepcopy(backtest_config["exchange"]) + tmp_backtest_config.update( + { + "codes": stocks, + "freq": "1min", + } + ) + + strategy, executor = get_strategy_executor( + start_time=pd.Timestamp(trade_start_time), + end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1), + strategy=top_strategy_config, + executor=top_executor_config, + benchmark=None, + account=cash_limit if cash_limit is not None else int(1e12), + exchange_kwargs=tmp_backtest_config, + pos_type="Position" if cash_limit is not None else "InfPosition", + ) + _set_env_for_all_strategy(executor=executor) + + report_dict: dict = {} + decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict)) + + records = _convert_indicator_to_dataframe(report_dict["indicator"]["1day_obj"].order_indicator_his) + assert records is None or not np.isnan(records["ffr"]).any() + + if generate_report: + report = _generate_report(decisions, report_dict) + if split == "stock": + stock_id = orders.iloc[0].instrument + report = {stock_id: report} + else: + day = orders.iloc[0].datetime + report = {day: report} + return records, report + else: + return records + + +def backtest(backtest_config: dict) -> pd.DataFrame: + order_df = read_order_file(backtest_config["order_file"]) + + cash_limit = backtest_config["exchange"].pop("cash_limit") + generate_report = backtest_config["exchange"].pop("generate_report") + + stock_pool = order_df["instrument"].unique().tolist() + stock_pool.sort() + + mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} + torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 + res = Parallel(**mp_config)( + delayed(single)( + backtest_config=backtest_config, + orders=order_df[order_df["instrument"] == stock].copy(), + split="stock", + cash_limit=cash_limit, + generate_report=generate_report, + ) + for stock in stock_pool + ) + + output_path = Path(backtest_config["output_dir"]) + if generate_report: + with (output_path / "report.pkl").open("wb") as f: + report = {} + for r in res: + report.update(r[1]) + pickle.dump(report, f) + res = pd.concat([r[0] for r in res], 0) + else: + res = pd.concat(res) + + res.to_csv(output_path / "summary.csv") + return res + + +if __name__ == "__main__": + import warnings + + warnings.filterwarnings("ignore", category=DeprecationWarning) + warnings.filterwarnings("ignore", category=RuntimeWarning) + + path = 
sys.argv[1] + backtest(get_backtest_config_fromfile(path)) diff --git a/qlib/rl/amc4th_migration/experiments/opds_15_225/opds_15_225.py b/qlib/rl/amc4th_migration/experiments/opds_15_225/opds_15_225.py new file mode 100644 index 0000000000..615749922d --- /dev/null +++ b/qlib/rl/amc4th_migration/experiments/opds_15_225/opds_15_225.py @@ -0,0 +1,52 @@ +_base_ = ["./twap.yml"] + +strategies = { + "_delete_": True, + "5min": {"type": "qlib.contrib.strategy.rule_strategy.TWAPStrategy"}, + "30min": {"type": "qlib.rl.order_execution.strategy.MultiplexStrategyOnTradeStep", "strategies": []}, + "1day": { + "type": "qlib.rl.order_execution.strategy.SAOEIntStrategy", + "state_interpreter": { + "type": "qlib.rl.order_execution.interpreter.FullHistoryStateInterpreter", + "max_step": 8, + "data_ticks": 240, + "data_dim": 16, + }, + "action_interpreter": { + "type": "qlib.rl.order_execution.interpreter.CategoricalActionInterpreter", + "values": 4, + "max_step": 8, + }, + "network": { + "type": "qlib.rl.order_execution.network.DualAttentionRNN", + }, + "policy": { + "type": "qlib.rl.order_execution.policy.PPO", + "lr": 1.0e-4, + "weight_file": "data/amc-checkpoints/opds_15_225/opds_15_225_30r_4_80", + }, + }, +} + +import copy + +# for mypy +assert isinstance(strategies["1day"], dict) +assert isinstance(strategies["30min"], dict) + +for step in range(8): + step_start, step_end = max(15, step * 30), min(225, step * 30 + 30) + num_inner_steps = (step_end - step_start + 5 - 1) // 5 + strategy: dict = copy.deepcopy(strategies["1day"]) + strategy["state_interpreter"]["max_step"] = num_inner_steps + action_values = [3, 6, 6, 6, 6, 6, 6, 3] + + strategy["network"] = {"type": "qlib.rl.order_execution.network.DualAttentionRNN"} + strategy["action_interpreter"]["values"] = action_values[step] + strategy["action_interpreter"]["max_step"] = num_inner_steps + strategy["policy"]["weight_file"] = f"data/amc-checkpoints/opds_15_225/opds_{step_start}_{step_end}" + + strategies["30min"]["strategies"].append(strategy) + + +del copy, step, step_start, step_end, num_inner_steps, strategy, action_values diff --git a/qlib/rl/amc4th_migration/experiments/opds_15_225/twap.yml b/qlib/rl/amc4th_migration/experiments/opds_15_225/twap.yml new file mode 100644 index 0000000000..118cd29964 --- /dev/null +++ b/qlib/rl/amc4th_migration/experiments/opds_15_225/twap.yml @@ -0,0 +1,25 @@ +order_file: data/amc-real-order/orders_v4/csi300_nostop.pkl +start_time: "9:45" +end_time: "14:44" +qlib: + provider_uri_day: data/amc-qlib/huaxia_1d_qlib + provider_uri_1min: data/amc-qlib/huaxia_1min_qlib + feature_root_dir: data/amc-qlib-stock + feature_columns_today: [ + "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume", + "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5" + ] + feature_columns_yesterday: [ + "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1", + "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1" + ] +exchange: + limit_threshold: ['$ask == 0', '$bid == 0'] + deal_price: ["If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"] + volume_threshold: + all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"] + buy: ["current", "$askV1"] + sell: ["current", "$bidV1"] +strategies: + 1day: + type: neutrader.qlib_integration.strategy.TWAPStrategy diff --git a/qlib/rl/amc4th_migration/naive_config_parser.py b/qlib/rl/amc4th_migration/naive_config_parser.py new file mode 100644 index 0000000000..4add7ac071 --- /dev/null 
+++ b/qlib/rl/amc4th_migration/naive_config_parser.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os +import platform +import shutil +import sys +import tempfile +from importlib import import_module + +import yaml + + +def merge_a_into_b(a: dict, b: dict) -> dict: + b = b.copy() + for k, v in a.items(): + if isinstance(v, dict) and k in b: + v.pop("_delete_", False) # TODO: make this more elegant + b[k] = merge_a_into_b(v, b[k]) + else: + b[k] = v + return b + + +def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None: + if not os.path.isfile(filename): + raise FileNotFoundError(msg_tmpl.format(filename)) + + +def parse_backtest_config(path: str) -> dict: + abs_path = os.path.abspath(path) + check_file_exist(abs_path) + + file_ext_name = os.path.splitext(abs_path)[1] + if file_ext_name not in (".py", ".json", ".yaml", ".yml"): + raise IOError("Only py/yml/yaml/json type are supported now!") + + with tempfile.TemporaryDirectory() as tmp_config_dir: + tmp_config_file = tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name) + if platform.system() == "Windows": + tmp_config_file.close() + + tmp_config_name = os.path.basename(tmp_config_file.name) + shutil.copyfile(abs_path, tmp_config_file.name) + + if abs_path.endswith(".py"): + tmp_module_name = os.path.splitext(tmp_config_name)[0] + sys.path.insert(0, tmp_config_dir) + module = import_module(tmp_module_name) + sys.path.pop(0) + + config = {k: v for k, v in module.__dict__.items() if not k.startswith("__")} + + del sys.modules[tmp_module_name] + else: + config = yaml.safe_load(open(os.path.join(tmp_config_dir, tmp_config_file.name))) + + if "_base_" in config: + base_file_name = config.pop("_base_") + if not isinstance(base_file_name, list): + base_file_name = [base_file_name] + + for f in base_file_name: + base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f)) + config = merge_a_into_b(a=config, b=base_config) + + return config + + +def _convert_all_list_to_tuple(config: dict) -> dict: + for k, v in config.items(): + if isinstance(v, list): + config[k] = tuple(v) + elif isinstance(v, dict): + config[k] = _convert_all_list_to_tuple(v) + return config + + +def get_backtest_config_fromfile(path: str) -> dict: + backtest_config = parse_backtest_config(path) + + exchange_config_default = { + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5.0, + "trade_unit": 100.0, + "cash_limit": None, + "generate_report": False, + } + backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default) + backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"]) + + backtest_config_default = { + "debug_single_stock": None, + "debug_single_day": None, + "concurrency": -1, + "multiplier": 1.0, + "output_dir": "outputs/", + # "runtime": {}, + } + backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default) + + return backtest_config + + +def convert_instance_config(config: object) -> object: + if isinstance(config, dict): + if "type" in config: + type_name = config["type"] + if "." 
in type_name: + idx = type_name.rindex(".") + module_path, class_name = type_name[:idx], type_name[idx + 1 :] + else: + module_path, class_name = "", type_name + + kwargs = {} + for k, v in config.items(): + if k == "type": + continue + kwargs[k] = convert_instance_config(v) + return { + "class": class_name, + "module_path": module_path, + "kwargs": kwargs, + } + else: + return {k: convert_instance_config(v) for k, v in config.items()} + elif isinstance(config, list): + return [convert_instance_config(item) for item in config] + elif isinstance(config, tuple): + return tuple([convert_instance_config(item) for item in config]) + else: + return config diff --git a/qlib/rl/amc4th_migration/utils.py b/qlib/rl/amc4th_migration/utils.py new file mode 100644 index 0000000000..cad25e0dba --- /dev/null +++ b/qlib/rl/amc4th_migration/utils.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from pathlib import Path + +import pandas as pd + + +def read_order_file(order_file: Path | pd.DataFrame) -> pd.DataFrame: + if isinstance(order_file, pd.DataFrame): + return order_file + + order_file = Path(order_file) + + if order_file.suffix == ".pkl": + order_df = pd.read_pickle(order_file).reset_index() + elif order_file.suffix == ".csv": + order_df = pd.read_csv(order_file) + else: + raise TypeError(f"Unsupported order file type: {order_file}") + + if "date" in order_df.columns: + # legacy dataframe columns + order_df = order_df.rename(columns={"date": "datetime", "order_type": "direction"}) + order_df["datetime"] = order_df["datetime"].astype(str) + + return order_df diff --git a/qlib/rl/data/exchange_wrapper.py b/qlib/rl/data/exchange_wrapper.py index 94bb1dcbbd..004074d0b8 100644 --- a/qlib/rl/data/exchange_wrapper.py +++ b/qlib/rl/data/exchange_wrapper.py @@ -5,12 +5,12 @@ import cachetools import pandas as pd - from qlib.backtest import Exchange, Order from qlib.backtest.decision import TradeRange, TradeRangeByTime -from qlib.constant import ONE_DAY, EPS_T +from qlib.constant import EPS_T, ONE_DAY from qlib.rl.order_execution.utils import get_ticks_slice from qlib.utils.index_data import IndexData + from .pickle_styled import BaseIntradayBacktestData diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 43fe9dd5ad..7856cb5a16 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -22,7 +22,7 @@ from abc import abstractmethod from functools import lru_cache from pathlib import Path -from typing import List, Sequence, cast +from typing import List, Optional, Sequence, cast import cachetools import numpy as np @@ -30,6 +30,7 @@ from cachetools.keys import hashkey from qlib.backtest.decision import Order, OrderDir +from qlib.rl.order_execution.integration import fetch_features from qlib.typehint import Literal DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"] @@ -178,7 +179,7 @@ def get_time_index(self) -> pd.DatetimeIndex: return cast(pd.DatetimeIndex, self.data.index) -class IntradayProcessedData: +class BaseIntradayProcessedData: """Processed market data after data cleanup and feature engineering. It contains both processed data for "today" and "yesterday", as some algorithms @@ -193,6 +194,10 @@ class IntradayProcessedData: """Processed data for "yesterday". Number of records must be ``time_length``, and columns must be ``feature_dim``.""" + +class IntradayProcessedData(BaseIntradayProcessedData): + """Subclass of IntradayProcessedData. 
Used to handle Dataset Handler style data.""" + def __init__( self, data_dir: Path, @@ -233,6 +238,25 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.today}, {self.yesterday})" +class NTIntradayProcessedData(BaseIntradayProcessedData): + """Subclass of IntradayProcessedData. Used to handle NT style data.""" + + def __init__( + self, + stock_id: str, + date: pd.Timestamp, + ) -> None: + def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame: + return df.reset_index().drop(columns=["instrument"]).set_index(["datetime"]) + + self.today = _drop_stock_id(fetch_features(stock_id, date)) + self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True)) + + def __repr__(self) -> str: + with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): + return f"{self.__class__.__name__}({self.today}, {self.yesterday})" + + @lru_cache(maxsize=100) # 100 * 50K = 5MB def load_simple_intraday_backtest_data( data_dir: Path, @@ -249,13 +273,19 @@ def load_simple_intraday_backtest_data( key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date), ) def load_intraday_processed_data( - data_dir: Path, + data_dir: Optional[Path], stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index, -) -> IntradayProcessedData: - return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index) +) -> BaseIntradayProcessedData: + from qlib.rl.order_execution.integration import dataset # pylint: disable=C0415 + + if dataset is None: + assert data_dir is not None + return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index) + else: + return NTIntradayProcessedData(stock_id, date) def load_orders( diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py index 61c9b83819..5ff2780db7 100644 --- a/qlib/rl/interpreter.py +++ b/qlib/rl/interpreter.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union import numpy as np @@ -12,7 +12,7 @@ from .simulator import ActType, StateType if TYPE_CHECKING: - from .utils.env_wrapper import EnvWrapper + from .utils.env_wrapper import CollectDataEnvWrapper, EnvWrapper import gym from gym import spaces @@ -40,7 +40,7 @@ class Interpreter: class StateInterpreter(Generic[StateType, ObsType], Interpreter): """State Interpreter that interpret execution result of qlib executor into rl env state""" - env: Optional[EnvWrapper] = None + env: Union[EnvWrapper, CollectDataEnvWrapper, None] = None @property def observation_space(self) -> gym.Space: @@ -74,7 +74,7 @@ def interpret(self, simulator_state: StateType) -> ObsType: class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter): """Action Interpreter that interpret rl agent action into qlib orders""" - env: Optional[EnvWrapper] = None + env: Union[EnvWrapper, CollectDataEnvWrapper, None] = None @property def action_space(self) -> gym.Space: diff --git a/qlib/rl/order_execution/integration.py b/qlib/rl/order_execution/integration.py index 07ca381613..d32ce49c82 100644 --- a/qlib/rl/order_execution/integration.py +++ b/qlib/rl/order_execution/integration.py @@ -41,7 +41,7 @@ def __init__( @cachetools.cached( # type: ignore cache=cachetools.LRUCache(100), - key=lambda stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), backtest), + key=lambda _, stock_id, date, backtest: (stock_id, date.replace(hour=0, minute=0, second=0), 
backtest), ) def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame: start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59) diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 089fc553cf..a074322402 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -57,8 +57,6 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]): Parameters ---------- - data_dir - Path to load data after feature engineering. max_step Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps. data_ticks @@ -66,9 +64,12 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]): the total ticks is the length of day in minutes. data_dim Number of dimensions in data. + data_dir + Path to load data after feature engineering. It is optional since in some cases we do not need explicit + path to load data. For example, the data has already been preloaded in `init_qlib()`. """ - def __init__(self, data_dir: Path, max_step: int, data_ticks: int, data_dim: int) -> None: + def __init__(self, max_step: int, data_ticks: int, data_dim: int, data_dir: Path = None) -> None: self.data_dir = data_dir self.max_step = max_step self.data_ticks = data_ticks @@ -96,15 +97,15 @@ def interpret(self, state: SAOEState) -> FullHistoryObs: FullHistoryObs, canonicalize( { - "data_processed": self._mask_future_info(processed.today, state.cur_time), - "data_processed_prev": processed.yesterday, - "acquiring": state.order.direction == state.order.BUY, - "cur_tick": min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1), - "cur_step": min(self.env.status["cur_step"], self.max_step - 1), - "num_step": self.max_step, - "target": state.order.amount, - "position": state.position, - "position_history": position_history[: self.max_step], + "data_processed": np.array(self._mask_future_info(processed.today, state.cur_time)), + "data_processed_prev": np.array(processed.yesterday), + "acquiring": _to_int32(state.order.direction == state.order.BUY), + "cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)), + "cur_step": _to_int32(min(self.env.status["cur_step"], self.max_step - 1)), + "num_step": _to_int32(self.max_step), + "target": _to_float32(state.order.amount), + "position": _to_float32(state.position), + "position_history": _to_float32(position_history[: self.max_step]), }, ), ) @@ -186,10 +187,11 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]): i.e., $[0, 1/n, 2/n, \\ldots, n/n]$. 
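+        For example, $values=4$ expands to $[0, 0.25, 0.5, 0.75, 1.0]$, so action $2$
+        requests half of the total order amount, capped by the remaining position.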
""" - def __init__(self, values: int | List[float]) -> None: + def __init__(self, values: int | List[float], max_step: int = None) -> None: if isinstance(values, int): values = [i / values for i in range(0, values + 1)] self.action_values = values + self.max_step = max_step @property def action_space(self) -> spaces.Discrete: @@ -197,7 +199,11 @@ def action_space(self) -> spaces.Discrete: def interpret(self, state: SAOEState, action: int) -> float: assert 0 <= action < len(self.action_values) - return min(state.position, state.order.amount * self.action_values[action]) + assert self.env is not None + if self.max_step is not None and self.env.status["cur_step"] >= self.max_step - 1: + return state.position + else: + return min(state.position, state.order.amount * self.action_values[action]) class TwapRelativeActionInterpreter(ActionInterpreter[SAOEState, float, float]): @@ -218,3 +224,11 @@ def interpret(self, state: SAOEState, action: float) -> float: estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step) twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"]) return min(state.position, twap_volume * action) + + +def _to_int32(val): + return np.array(int(val), dtype=np.int32) + + +def _to_float32(val): + return np.array(val, dtype=np.float32) diff --git a/qlib/rl/order_execution/network.py b/qlib/rl/order_execution/network.py index 3d0279559e..b1467782c0 100644 --- a/qlib/rl/order_execution/network.py +++ b/qlib/rl/order_execution/network.py @@ -117,3 +117,49 @@ def forward(self, batch: Batch) -> torch.Tensor: out = torch.cat(sources, -1) return self.fc(out) + + +class Attention(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.q_net = nn.Linear(in_dim, out_dim) + self.k_net = nn.Linear(in_dim, out_dim) + self.v_net = nn.Linear(in_dim, out_dim) + + def forward(self, Q, K, V): + q = self.q_net(Q) + k = self.k_net(K) + v = self.v_net(V) + + attn = torch.einsum("ijk,ilk->ijl", q, k) + attn = attn.to(Q.device) + attn_prob = torch.softmax(attn, dim=-1) + + attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, v) + + return attn_vec + + +class DualAttentionRNN(Recurrent): + """ + Dual-attention RNN leverages features from yesterday and fuses them into features today. 
+ """ + + def _init_extra_branches(self): + self.attention = Attention(self.hidden_dim, self.hidden_dim) + self.num_sources += 1 + + def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]: + sources, data_out = super()._source_features(obs, device) + + data_prev = obs["data_processed_prev"] + cur_time = obs["cur_tick"].long() + bs_indices = torch.arange(cur_time.size(0), device=device) + + data_prev_in = self.raw_fc(data_prev) + data_prev_out, _ = self.prev_rnn(data_prev_in) + att_out = self.attention(data_out, data_prev_out, data_prev_out) + att_out = att_out[bs_indices, cur_time] + sources.insert(1, att_out) + + return sources, data_out diff --git a/qlib/rl/order_execution/state.py b/qlib/rl/order_execution/state.py index d6bbeaea5a..d1fd85502c 100644 --- a/qlib/rl/order_execution/state.py +++ b/qlib/rl/order_execution/state.py @@ -13,8 +13,8 @@ from qlib.rl.data.exchange_wrapper import IntradayBacktestData from qlib.rl.data.pickle_styled import BaseIntradayBacktestData from qlib.rl.order_execution.utils import dataframe_append, price_advantage +from qlib.typehint import TypedDict from qlib.utils.time import get_day_min_idx_range -from typing_extensions import TypedDict def _get_all_timestamps( diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py index 4a85bc76ed..0327f65ac5 100644 --- a/qlib/rl/order_execution/strategy.py +++ b/qlib/rl/order_execution/strategy.py @@ -4,18 +4,25 @@ from __future__ import annotations import collections +from abc import ABCMeta from types import GeneratorType -from typing import Any, Optional, Union, cast, Dict, Generator +from typing import Any, cast, Dict, Generator, List, Optional, Union import pandas as pd - -from qlib.backtest import CommonInfrastructure, Order +import torch +from qlib.backtest import CommonInfrastructure, Exchange, Order from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange from qlib.backtest.utils import LevelInfrastructure from qlib.constant import ONE_MIN from qlib.rl.data.exchange_wrapper import load_qlib_backtest_data -from qlib.rl.order_execution.state import SAOEStateAdapter, SAOEState -from qlib.strategy.base import RLStrategy +from qlib.rl.interpreter import ActionInterpreter, StateInterpreter +from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter +from qlib.rl.utils import EnvWrapper +from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper +from qlib.strategy.base import BaseStrategy, RLStrategy +from qlib.utils import init_instance_by_config +from tianshou.data import Batch +from tianshou.policy import BasePolicy class SAOEStrategy(RLStrategy): @@ -106,7 +113,10 @@ def generate_trade_decision( return decision - def _generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]: + def _generate_trade_decision( + self, + execute_result: list = None, + ) -> Union[BaseTradeDecision, Generator[Any, Any, BaseTradeDecision]]: raise NotImplementedError @@ -146,3 +156,177 @@ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) - order_list = outer_trade_decision.order_list assert len(order_list) == 1 self._order = order_list[0] + + +class SAOEIntStrategy(SAOEStrategy): + """(SAOE)state based strategy with (Int)preters.""" + + def __init__( + self, + policy: dict | BasePolicy, + state_interpreter: dict | StateInterpreter, + action_interpreter: dict | ActionInterpreter, + network: object = None, # TODO: add accurate typehint later. 
+ outer_trade_decision: BaseTradeDecision = None, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, + **kwargs: Any, + ) -> None: + super(SAOEIntStrategy, self).__init__( + policy=policy, + outer_trade_decision=outer_trade_decision, + level_infra=level_infra, + common_infra=common_infra, + **kwargs, + ) + + self._state_interpreter: StateInterpreter = init_instance_by_config( + state_interpreter, + accept_types=StateInterpreter, + ) + self._action_interpreter: ActionInterpreter = init_instance_by_config( + action_interpreter, + accept_types=ActionInterpreter, + ) + + if isinstance(policy, dict): + assert network is not None + + if isinstance(network, dict): + network["kwargs"].update( + { + "obs_space": self._state_interpreter.observation_space, + } + ) + network_inst = init_instance_by_config(network) + else: + network_inst = network + + policy["kwargs"].update( + { + "obs_space": self._state_interpreter.observation_space, + "action_space": self._action_interpreter.action_space, + "network": network_inst, + } + ) + self._policy = init_instance_by_config(policy) + elif isinstance(policy, BasePolicy): + self._policy = policy + else: + raise ValueError(f"Unsupported policy type: {type(policy)}.") + + if self._policy is not None: + self._policy.eval() + + def set_env(self, env: EnvWrapper | CollectDataEnvWrapper) -> None: + self._env = env + self._state_interpreter.env = self._action_interpreter.env = self._env + + def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None: + super().reset(outer_trade_decision=outer_trade_decision, **kwargs) + + if isinstance(self._env, CollectDataEnvWrapper): + self._env.reset() + + def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: + states = [] + obs_batch = [] + for decision in self.outer_trade_decision.get_decision(): + order = cast(Order, decision) + state = self.get_saoe_state_by_order(order) + + states.append(state) + obs_batch.append({"obs": self._state_interpreter.interpret(state)}) + + with torch.no_grad(): + policy_out = self._policy(Batch(obs_batch)) + act = policy_out.act.numpy() if torch.is_tensor(policy_out.act) else policy_out.act + exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)] + + if isinstance(self._env, CollectDataEnvWrapper): + self._env.step(None) + + oh = self.trade_exchange.get_order_helper() + order_list = [] + for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols): + if exec_vol != 0: + order = cast(Order, decision) + order_list.append(oh.create(order.stock_id, exec_vol, order.direction)) + + return TradeDecisionWO(order_list=order_list, strategy=self) + + +class MultiplexStrategyBase(BaseStrategy, metaclass=ABCMeta): + def __init__( + self, + strategies: List[BaseStrategy] | List[dict], + outer_trade_decision: BaseTradeDecision = None, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, + trade_exchange: Exchange = None, + ) -> None: + super().__init__( + outer_trade_decision=outer_trade_decision, + level_infra=level_infra, + common_infra=common_infra, + trade_exchange=trade_exchange, + ) + + self._strategies = [init_instance_by_config(strategy, accept_types=BaseStrategy) for strategy in strategies] + + def set_env(self, env: EnvWrapper | CollectDataEnvWrapper) -> None: + for strategy in self._strategies: + if hasattr(strategy, "set_env"): + strategy.set_env(env) + + +class MultiplexStrategyOnTradeStep(MultiplexStrategyBase): + """To 
use different strategy on different step of the outer calendar""" + + def __init__( + self, + strategies: List[BaseStrategy] | List[dict], + outer_trade_decision: BaseTradeDecision = None, + level_infra: LevelInfrastructure = None, + common_infra: CommonInfrastructure = None, + trade_exchange: Exchange = None, + ) -> None: + super(MultiplexStrategyOnTradeStep, self).__init__( + strategies=strategies, + outer_trade_decision=outer_trade_decision, + level_infra=level_infra, + common_infra=common_infra, + trade_exchange=trade_exchange, + ) + + def reset_level_infra(self, level_infra: LevelInfrastructure) -> None: + for strategy in self._strategies: + strategy.reset_level_infra(level_infra) + + def reset_common_infra(self, common_infra: CommonInfrastructure) -> None: + for strategy in self._strategies: + strategy.reset_common_infra(common_infra) + + def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None: + super().reset(outer_trade_decision=outer_trade_decision, **kwargs) + + if outer_trade_decision is not None: + strategy = self._get_current_strategy() + strategy.reset(outer_trade_decision=outer_trade_decision, **kwargs) + + def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: + if self.outer_trade_decision is not None: + strategy = self._get_current_strategy() + return strategy.generate_trade_decision(execute_result=execute_result) + else: + return TradeDecisionWO([], self) + + def post_exe_step(self, execute_result: list) -> None: + if self.outer_trade_decision is not None: + strategy = self._get_current_strategy() + if isinstance(strategy, RLStrategy): + strategy.post_exe_step(execute_result=execute_result) + + def _get_current_strategy(self) -> BaseStrategy: + outer_calendar = self.outer_trade_decision.strategy.trade_calendar + return self._strategies[outer_calendar.get_trade_step()] diff --git a/qlib/rl/trainer/__init__.py b/qlib/rl/trainer/__init__.py index efce804c41..0a197b3781 100644 --- a/qlib/rl/trainer/__init__.py +++ b/qlib/rl/trainer/__init__.py @@ -4,6 +4,6 @@ """Train, test, inference utilities.""" from .api import backtest, train -from .callbacks import EarlyStopping, Checkpoint +from .callbacks import Checkpoint, EarlyStopping from .trainer import Trainer from .vessel import TrainingVessel, TrainingVesselBase diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py index 529bfe5973..84ed2cfbd4 100644 --- a/qlib/rl/utils/env_wrapper.py +++ b/qlib/rl/utils/env_wrapper.py @@ -249,3 +249,24 @@ def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, fl def render(self, mode: str = "human") -> None: raise NotImplementedError("Render is not implemented in EnvWrapper.") + + +class CollectDataEnvWrapper: + """Dummy EnvWrapper for collect_data_loop. 
It only has minimal interfaces to support the collect_data_loop."""
+
+    def __init__(self) -> None:
+        self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
+        self.reset()
+
+    def reset(self, **kwargs: Any) -> None:
+        self.status = EnvWrapperStatus(
+            cur_step=0,
+            done=False,
+            initial_state=None,
+            obs_history=[],
+            action_history=[],
+            reward_history=[],
+        )
+
+    def step(self, policy_action: Any = None, **kwargs: Any) -> None:
+        self.status["cur_step"] += 1
diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py
index 78df41690a..1abddaf191 100644
--- a/tests/rl/test_saoe_simple.py
+++ b/tests/rl/test_saoe_simple.py
@@ -146,7 +146,7 @@ def test_interpreter():
     class EmulateEnvWrapper(NamedTuple):
         status: EnvWrapperStatus

-    interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
+    interpreter = FullHistoryStateInterpreter(13, 390, 5, FEATURE_DATA_DIR)
     interpreter_step = CurrentStepStateInterpreter(13)
     interpreter_action = CategoricalActionInterpreter(20)
     interpreter_action_twap = TwapRelativeActionInterpreter()
@@ -225,7 +225,7 @@ def test_network_sanity():
     class EmulateEnvWrapper(NamedTuple):
         status: EnvWrapperStatus

-    interpreter = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
+    interpreter = FullHistoryStateInterpreter(13, 390, 5, FEATURE_DATA_DIR)
     action_interp = CategoricalActionInterpreter(13)

     wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
@@ -253,7 +253,7 @@ def test_twap_strategy(finite_env_type):
     orders = pickle_styled.load_orders(ORDER_DIR)
     assert len(orders) == 248

-    state_interp = FullHistoryStateInterpreter(FEATURE_DATA_DIR, 13, 390, 5)
+    state_interp = FullHistoryStateInterpreter(13, 390, 5, FEATURE_DATA_DIR)
     action_interp = TwapRelativeActionInterpreter()
     policy = AllOne(state_interp.observation_space, action_interp.action_space)
     csv_writer = CsvWriter(Path(__file__).parent / ".output")
@@ -282,7 +282,7 @@ def test_cn_ppo_strategy():
     orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
     assert len(orders) == 40

-    state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
+    state_interp = FullHistoryStateInterpreter(8, 240, 6, CN_FEATURE_DATA_DIR)
     action_interp = CategoricalActionInterpreter(4)
     network = Recurrent(state_interp.observation_space)
     policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
@@ -313,7 +313,7 @@ def test_ppo_train():
     orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
     assert len(orders) == 40

-    state_interp = FullHistoryStateInterpreter(CN_FEATURE_DATA_DIR, 8, 240, 6)
+    state_interp = FullHistoryStateInterpreter(8, 240, 6, CN_FEATURE_DATA_DIR)
     action_interp = CategoricalActionInterpreter(4)
     network = Recurrent(state_interp.observation_space)
     policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)

From 32c24940f2cd6f2dff4ed1587ee75d3cbe3c05eb Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Wed, 24 Aug 2022 16:26:03 +0800
Subject: [PATCH 02/15] Minor bug fix in test

---
 tests/rl/test_qlib_simulator.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/rl/test_qlib_simulator.py b/tests/rl/test_qlib_simulator.py
index b7d548e9ea..0a42d71394 100644
--- a/tests/rl/test_qlib_simulator.py
+++ b/tests/rl/test_qlib_simulator.py
@@ -11,6 +11,7 @@
 from qlib.backtest.executor import SimulatorExecutor
 from qlib.rl.order_execution import 
CategoricalActionInterpreter from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution +from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper TOTAL_POSITION = 2100.0 @@ -192,6 +193,7 @@ def test_interpreter() -> None: order = get_order() simulator = get_simulator(order) interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION) + interpreter_action.env = CollectDataEnvWrapper() NUM_STEPS = 7 state = simulator.get_state() From d0cdffbf5ff452d97c3ef136414ad66e4b53acca Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Wed, 24 Aug 2022 19:17:15 +0800 Subject: [PATCH 03/15] Reorganize file to avoid loop import --- qlib/rl/amc4th_migration/backtest.py | 2 +- .../{order_execution => data}/integration.py | 0 qlib/rl/data/pickle_styled.py | 4 ++-- qlib/rl/order_execution/simulator_qlib.py | 2 +- qlib/rl/order_execution/simulator_simple.py | 4 ++-- qlib/rl/order_execution/state.py | 7 +++++-- tests/rl/test_saoe_simple.py | 20 +++++++++---------- 7 files changed, 21 insertions(+), 18 deletions(-) rename qlib/rl/{order_execution => data}/integration.py (100%) diff --git a/qlib/rl/amc4th_migration/backtest.py b/qlib/rl/amc4th_migration/backtest.py index 4e88d71d57..0ff565db4a 100644 --- a/qlib/rl/amc4th_migration/backtest.py +++ b/qlib/rl/amc4th_migration/backtest.py @@ -19,7 +19,7 @@ from qlib.backtest.high_performance_ds import BaseOrderIndicator from qlib.rl.amc4th_migration.naive_config_parser import convert_instance_config, get_backtest_config_fromfile from qlib.rl.amc4th_migration.utils import read_order_file -from qlib.rl.order_execution.integration import init_qlib +from qlib.rl.data.integration import init_qlib from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper diff --git a/qlib/rl/order_execution/integration.py b/qlib/rl/data/integration.py similarity index 100% rename from qlib/rl/order_execution/integration.py rename to qlib/rl/data/integration.py diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 7856cb5a16..25db10c32e 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -30,7 +30,7 @@ from cachetools.keys import hashkey from qlib.backtest.decision import Order, OrderDir -from qlib.rl.order_execution.integration import fetch_features +from qlib.rl.data.integration import fetch_features from qlib.typehint import Literal DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"] @@ -279,7 +279,7 @@ def load_intraday_processed_data( feature_dim: int, time_index: pd.Index, ) -> BaseIntradayProcessedData: - from qlib.rl.order_execution.integration import dataset # pylint: disable=C0415 + from qlib.rl.data.integration import dataset # pylint: disable=C0415 if dataset is None: assert data_dir is not None diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 3002fd333e..718c2ba572 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -11,7 +11,7 @@ from qlib.backtest.executor import NestedExecutor from qlib.rl.simulator import Simulator -from .integration import init_qlib +from qlib.rl.data.integration import init_qlib from .state import SAOEState, SAOEStateAdapter from .strategy import SAOEStrategy diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index f95aeebad0..17efb4b093 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -18,10 +18,10 @@ # TODO: Integrating Qlib's native 
data with simulator_simple -__all__ = ["SingleAssetOrderExecution"] +__all__ = ["SingleAssetOrderExecutionSimple"] -class SingleAssetOrderExecution(Simulator[Order, SAOEState, float]): +class SingleAssetOrderExecutionSimple(Simulator[Order, SAOEState, float]): """Single-asset order execution (SAOE) simulator. As there's no "calendar" in the simple simulator, ticks are used to trade. diff --git a/qlib/rl/order_execution/state.py b/qlib/rl/order_execution/state.py index d1fd85502c..dfa125cafa 100644 --- a/qlib/rl/order_execution/state.py +++ b/qlib/rl/order_execution/state.py @@ -3,6 +3,7 @@ from __future__ import annotations +import typing from typing import cast, NamedTuple, Optional, Tuple import numpy as np @@ -10,12 +11,14 @@ from qlib.backtest import Exchange, Order from qlib.backtest.executor import BaseExecutor from qlib.constant import EPS, ONE_MIN, REG_CN -from qlib.rl.data.exchange_wrapper import IntradayBacktestData -from qlib.rl.data.pickle_styled import BaseIntradayBacktestData from qlib.rl.order_execution.utils import dataframe_append, price_advantage from qlib.typehint import TypedDict from qlib.utils.time import get_day_min_idx_range +if typing.TYPE_CHECKING: + from qlib.rl.data.exchange_wrapper import IntradayBacktestData + from qlib.rl.data.pickle_styled import BaseIntradayBacktestData + def _get_all_timestamps( start: pd.Timestamp, diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py index 1abddaf191..8bcc04aaad 100644 --- a/tests/rl/test_saoe_simple.py +++ b/tests/rl/test_saoe_simple.py @@ -49,7 +49,7 @@ def test_pickle_data_inspect(): def test_simulator_first_step(): order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59")) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) state = simulator.get_state() assert state.cur_time == pd.Timestamp("2013-12-11 09:30:00") assert state.position == 30.0 @@ -83,7 +83,7 @@ def test_simulator_first_step(): def test_simulator_stop_twap(): order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59")) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) for _ in range(13): simulator.step(1.0) @@ -106,10 +106,10 @@ def test_simulator_stop_early(): order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59")) with pytest.raises(ValueError): - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) simulator.step(2.0) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) simulator.step(1.0) with pytest.raises(AssertionError): @@ -119,7 +119,7 @@ def test_simulator_stop_early(): def test_simulator_start_middle(): order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59")) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) assert len(simulator.ticks_for_order) == 330 assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00") simulator.step(2.0) @@ -138,7 +138,7 @@ def test_simulator_start_middle(): def test_interpreter(): order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59")) - 
simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) assert len(simulator.ticks_for_order) == 330 assert simulator.cur_time == pd.Timestamp("2013-12-11 10:15:00") @@ -219,7 +219,7 @@ def test_network_sanity(): # we won't check the correctness of networks here order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59")) - simulator = SingleAssetOrderExecution(order, BACKTEST_DATA_DIR) + simulator = SingleAssetOrderExecutionSimple(order, BACKTEST_DATA_DIR) assert len(simulator.ticks_for_order) == 390 class EmulateEnvWrapper(NamedTuple): @@ -259,7 +259,7 @@ def test_twap_strategy(finite_env_type): csv_writer = CsvWriter(Path(__file__).parent / ".output") backtest( - partial(SingleAssetOrderExecution, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30), + partial(SingleAssetOrderExecutionSimple, data_dir=BACKTEST_DATA_DIR, ticks_per_step=30), state_interp, action_interp, orders, @@ -290,7 +290,7 @@ def test_cn_ppo_strategy(): csv_writer = CsvWriter(Path(__file__).parent / ".output") backtest( - partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), + partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), state_interp, action_interp, orders, @@ -319,7 +319,7 @@ def test_ppo_train(): policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4) train( - partial(SingleAssetOrderExecution, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), + partial(SingleAssetOrderExecutionSimple, data_dir=CN_BACKTEST_DATA_DIR, ticks_per_step=30), state_interp, action_interp, orders, From 28f88cb8726ef5e1b9d84dd738b9759264e0132d Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Wed, 24 Aug 2022 20:18:15 +0800 Subject: [PATCH 04/15] Fix test SAOE bug --- tests/rl/test_saoe_simple.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py index 8bcc04aaad..95fa7a2958 100644 --- a/tests/rl/test_saoe_simple.py +++ b/tests/rl/test_saoe_simple.py @@ -19,6 +19,7 @@ from qlib.rl.order_execution import * from qlib.rl.trainer import backtest, train from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus +from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8") @@ -185,6 +186,8 @@ class EmulateEnvWrapper(NamedTuple): assert np.sum(obs["data_processed"][60:]) == 0 # second step: action + interpreter_action.env = CollectDataEnvWrapper() + interpreter_action_twap.env = CollectDataEnvWrapper() action = interpreter_action(simulator.get_state(), 1) assert action == 15 / 20 @@ -255,6 +258,7 @@ def test_twap_strategy(finite_env_type): state_interp = FullHistoryStateInterpreter(13, 390, 5, FEATURE_DATA_DIR) action_interp = TwapRelativeActionInterpreter() + action_interp.env = CollectDataEnvWrapper() policy = AllOne(state_interp.observation_space, action_interp.action_space) csv_writer = CsvWriter(Path(__file__).parent / ".output") @@ -284,6 +288,7 @@ def test_cn_ppo_strategy(): state_interp = FullHistoryStateInterpreter(8, 240, 6, CN_FEATURE_DATA_DIR) action_interp = CategoricalActionInterpreter(4) + action_interp.env = CollectDataEnvWrapper() network = Recurrent(state_interp.observation_space) policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4) policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / 
"ppo_recurrent_30min.pth", map_location="cpu")) @@ -315,6 +320,7 @@ def test_ppo_train(): state_interp = FullHistoryStateInterpreter(8, 240, 6, CN_FEATURE_DATA_DIR) action_interp = CategoricalActionInterpreter(4) + action_interp.env = CollectDataEnvWrapper() network = Recurrent(state_interp.observation_space) policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4) From 9b5db217bd54136a2205b702b5e851cbb62c7a87 Mon Sep 17 00:00:00 2001 From: Huoran Li Date: Tue, 30 Aug 2022 15:29:24 +0800 Subject: [PATCH 05/15] Remove unnecessary names --- .../experiments/opds_15_225/opds_15_225.py | 52 ------------------- .../experiments/opds_15_225/twap.yml | 25 --------- .../backtest.py | 4 +- .../naive_config_parser.py | 0 .../utils.py | 0 5 files changed, 2 insertions(+), 79 deletions(-) delete mode 100644 qlib/rl/amc4th_migration/experiments/opds_15_225/opds_15_225.py delete mode 100644 qlib/rl/amc4th_migration/experiments/opds_15_225/twap.yml rename qlib/rl/{amc4th_migration => neural_trading_migration}/backtest.py (97%) rename qlib/rl/{amc4th_migration => neural_trading_migration}/naive_config_parser.py (100%) rename qlib/rl/{amc4th_migration => neural_trading_migration}/utils.py (100%) diff --git a/qlib/rl/amc4th_migration/experiments/opds_15_225/opds_15_225.py b/qlib/rl/amc4th_migration/experiments/opds_15_225/opds_15_225.py deleted file mode 100644 index 615749922d..0000000000 --- a/qlib/rl/amc4th_migration/experiments/opds_15_225/opds_15_225.py +++ /dev/null @@ -1,52 +0,0 @@ -_base_ = ["./twap.yml"] - -strategies = { - "_delete_": True, - "5min": {"type": "qlib.contrib.strategy.rule_strategy.TWAPStrategy"}, - "30min": {"type": "qlib.rl.order_execution.strategy.MultiplexStrategyOnTradeStep", "strategies": []}, - "1day": { - "type": "qlib.rl.order_execution.strategy.SAOEIntStrategy", - "state_interpreter": { - "type": "qlib.rl.order_execution.interpreter.FullHistoryStateInterpreter", - "max_step": 8, - "data_ticks": 240, - "data_dim": 16, - }, - "action_interpreter": { - "type": "qlib.rl.order_execution.interpreter.CategoricalActionInterpreter", - "values": 4, - "max_step": 8, - }, - "network": { - "type": "qlib.rl.order_execution.network.DualAttentionRNN", - }, - "policy": { - "type": "qlib.rl.order_execution.policy.PPO", - "lr": 1.0e-4, - "weight_file": "data/amc-checkpoints/opds_15_225/opds_15_225_30r_4_80", - }, - }, -} - -import copy - -# for mypy -assert isinstance(strategies["1day"], dict) -assert isinstance(strategies["30min"], dict) - -for step in range(8): - step_start, step_end = max(15, step * 30), min(225, step * 30 + 30) - num_inner_steps = (step_end - step_start + 5 - 1) // 5 - strategy: dict = copy.deepcopy(strategies["1day"]) - strategy["state_interpreter"]["max_step"] = num_inner_steps - action_values = [3, 6, 6, 6, 6, 6, 6, 3] - - strategy["network"] = {"type": "qlib.rl.order_execution.network.DualAttentionRNN"} - strategy["action_interpreter"]["values"] = action_values[step] - strategy["action_interpreter"]["max_step"] = num_inner_steps - strategy["policy"]["weight_file"] = f"data/amc-checkpoints/opds_15_225/opds_{step_start}_{step_end}" - - strategies["30min"]["strategies"].append(strategy) - - -del copy, step, step_start, step_end, num_inner_steps, strategy, action_values diff --git a/qlib/rl/amc4th_migration/experiments/opds_15_225/twap.yml b/qlib/rl/amc4th_migration/experiments/opds_15_225/twap.yml deleted file mode 100644 index 118cd29964..0000000000 --- a/qlib/rl/amc4th_migration/experiments/opds_15_225/twap.yml +++ /dev/null @@ -1,25 +0,0 
@@
-order_file: data/amc-real-order/orders_v4/csi300_nostop.pkl
-start_time: "9:45"
-end_time: "14:44"
-qlib:
-  provider_uri_day: data/amc-qlib/huaxia_1d_qlib
-  provider_uri_1min: data/amc-qlib/huaxia_1min_qlib
-  feature_root_dir: data/amc-qlib-stock
-  feature_columns_today: [
-    "$open", "$high", "$low", "$close", "$vwap", "$bid", "$ask", "$volume",
-    "$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"
-  ]
-  feature_columns_yesterday: [
-    "$open_1", "$high_1", "$low_1", "$close_1", "$vwap_1", "$bid_1", "$ask_1", "$volume_1",
-    "$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"
-  ]
-exchange:
-  limit_threshold: ['$ask == 0', '$bid == 0']
-  deal_price: ["If($ask == 0, $bid, $ask)", "If($bid == 0, $ask, $bid)"]
-  volume_threshold:
-    all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
-    buy: ["current", "$askV1"]
-    sell: ["current", "$bidV1"]
-strategies:
-  1day:
-    type: neutrader.qlib_integration.strategy.TWAPStrategy
diff --git a/qlib/rl/amc4th_migration/backtest.py b/qlib/rl/neural_trading_migration/backtest.py
similarity index 97%
rename from qlib/rl/amc4th_migration/backtest.py
rename to qlib/rl/neural_trading_migration/backtest.py
index 0ff565db4a..e6c309033c 100644
--- a/qlib/rl/amc4th_migration/backtest.py
+++ b/qlib/rl/neural_trading_migration/backtest.py
@@ -17,8 +17,8 @@
 from qlib.backtest.decision import TradeRangeByTime
 from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
 from qlib.backtest.high_performance_ds import BaseOrderIndicator
-from qlib.rl.amc4th_migration.naive_config_parser import convert_instance_config, get_backtest_config_fromfile
-from qlib.rl.amc4th_migration.utils import read_order_file
+from qlib.rl.neural_trading_migration.naive_config_parser import convert_instance_config, get_backtest_config_fromfile
+from qlib.rl.neural_trading_migration.utils import read_order_file
 from qlib.rl.data.integration import init_qlib
 from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
diff --git a/qlib/rl/amc4th_migration/naive_config_parser.py b/qlib/rl/neural_trading_migration/naive_config_parser.py
similarity index 100%
rename from qlib/rl/amc4th_migration/naive_config_parser.py
rename to qlib/rl/neural_trading_migration/naive_config_parser.py
diff --git a/qlib/rl/amc4th_migration/utils.py b/qlib/rl/neural_trading_migration/utils.py
similarity index 100%
rename from qlib/rl/amc4th_migration/utils.py
rename to qlib/rl/neural_trading_migration/utils.py

From 6e09470aaa4772b7968f1523fcee3cd67f16f9b0 Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Thu, 8 Sep 2022 11:30:12 +0800
Subject: [PATCH 06/15] Resolve PR comments; remove private classes;

---
 qlib/rl/order_execution/interpreter.py |  6 +-
 qlib/rl/order_execution/network.py     | 25 ---------
 qlib/rl/order_execution/strategy.py    | 76 --------------------------
 3 files changed, 4 insertions(+), 103 deletions(-)

diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py
index a074322402..aacfe86fee 100644
--- a/qlib/rl/order_execution/interpreter.py
+++ b/qlib/rl/order_execution/interpreter.py
@@ -5,7 +5,7 @@
 import math
 from pathlib import Path
-from typing import Any, List, cast
+from typing import Any, List, Optional, cast

 import numpy as np
 import pandas as pd
@@ -185,9 +185,11 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]):
         Then when the policy gives decision $x$, $a_x$ times order amount is the output. 
diff --git a/qlib/rl/order_execution/network.py b/qlib/rl/order_execution/network.py
index b1467782c0..d6a11189cf 100644
--- a/qlib/rl/order_execution/network.py
+++ b/qlib/rl/order_execution/network.py
@@ -138,28 +138,3 @@ def forward(self, Q, K, V):
 
         attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, v)
         return attn_vec
-
-
-class DualAttentionRNN(Recurrent):
-    """
-    Dual-attention RNN leverages features from yesterday and fuses them into features today.
-    """
-
-    def _init_extra_branches(self):
-        self.attention = Attention(self.hidden_dim, self.hidden_dim)
-        self.num_sources += 1
-
-    def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]:
-        sources, data_out = super()._source_features(obs, device)
-
-        data_prev = obs["data_processed_prev"]
-        cur_time = obs["cur_tick"].long()
-        bs_indices = torch.arange(cur_time.size(0), device=device)
-
-        data_prev_in = self.raw_fc(data_prev)
-        data_prev_out, _ = self.prev_rnn(data_prev_in)
-        att_out = self.attention(data_out, data_prev_out, data_prev_out)
-        att_out = att_out[bs_indices, cur_time]
-        sources.insert(1, att_out)
-
-        return sources, data_out
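What the deleted DualAttentionRNN computed can be reproduced with bare tensors: attend from today's RNN states (query) over yesterday's (key/value) using the same einsum pattern as Attention.forward above, then keep the row at the current tick. A self-contained sketch with made-up shapes (batch 2, 240 ticks, hidden 64); the 1/sqrt(H) score scaling is an assumption, since Attention's score computation is not shown in this hunk:

    import torch

    B, T, H = 2, 240, 64                   # batch, ticks per day, hidden dim
    data_out = torch.randn(B, T, H)        # today's RNN outputs (query)
    data_prev_out = torch.randn(B, T, H)   # yesterday's RNN outputs (key/value)

    scores = torch.einsum("ijk,ilk->ijl", data_out, data_prev_out) / H ** 0.5   # (B, T, T)
    attn_prob = torch.softmax(scores, dim=-1)
    attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, data_prev_out)           # (B, T, H), as in forward()

    # DualAttentionRNN then picked one fused vector per sample, at its current tick:
    cur_tick = torch.tensor([10, 55])
    att_out = attn_vec[torch.arange(B), cur_tick]  # (B, H)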
diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py
index 0327f65ac5..308940ec99 100644
--- a/qlib/rl/order_execution/strategy.py
+++ b/qlib/rl/order_execution/strategy.py
@@ -254,79 +254,3 @@ def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDeci
             order_list.append(oh.create(order.stock_id, exec_vol, order.direction))
 
         return TradeDecisionWO(order_list=order_list, strategy=self)
-
-
-class MultiplexStrategyBase(BaseStrategy, metaclass=ABCMeta):
-    def __init__(
-        self,
-        strategies: List[BaseStrategy] | List[dict],
-        outer_trade_decision: BaseTradeDecision = None,
-        level_infra: LevelInfrastructure = None,
-        common_infra: CommonInfrastructure = None,
-        trade_exchange: Exchange = None,
-    ) -> None:
-        super().__init__(
-            outer_trade_decision=outer_trade_decision,
-            level_infra=level_infra,
-            common_infra=common_infra,
-            trade_exchange=trade_exchange,
-        )
-
-        self._strategies = [init_instance_by_config(strategy, accept_types=BaseStrategy) for strategy in strategies]
-
-    def set_env(self, env: EnvWrapper | CollectDataEnvWrapper) -> None:
-        for strategy in self._strategies:
-            if hasattr(strategy, "set_env"):
-                strategy.set_env(env)
-
-
-class MultiplexStrategyOnTradeStep(MultiplexStrategyBase):
-    """To use different strategy on different step of the outer calendar"""
-
-    def __init__(
-        self,
-        strategies: List[BaseStrategy] | List[dict],
-        outer_trade_decision: BaseTradeDecision = None,
-        level_infra: LevelInfrastructure = None,
-        common_infra: CommonInfrastructure = None,
-        trade_exchange: Exchange = None,
-    ) -> None:
-        super(MultiplexStrategyOnTradeStep, self).__init__(
-            strategies=strategies,
-            outer_trade_decision=outer_trade_decision,
-            level_infra=level_infra,
-            common_infra=common_infra,
-            trade_exchange=trade_exchange,
-        )
-
-    def reset_level_infra(self, level_infra: LevelInfrastructure) -> None:
-        for strategy in self._strategies:
-            strategy.reset_level_infra(level_infra)
-
-    def reset_common_infra(self, common_infra: CommonInfrastructure) -> None:
-        for strategy in self._strategies:
-            strategy.reset_common_infra(common_infra)
-
-    def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
-        super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
-
-        if outer_trade_decision is not None:
-            strategy = self._get_current_strategy()
-            strategy.reset(outer_trade_decision=outer_trade_decision, **kwargs)
-
-    def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
-        if self.outer_trade_decision is not None:
-            strategy = self._get_current_strategy()
-            return strategy.generate_trade_decision(execute_result=execute_result)
-        else:
-            return TradeDecisionWO([], self)
-
-    def post_exe_step(self, execute_result: list) -> None:
-        if self.outer_trade_decision is not None:
-            strategy = self._get_current_strategy()
-            if isinstance(strategy, RLStrategy):
-                strategy.post_exe_step(execute_result=execute_result)
-
-    def _get_current_strategy(self) -> BaseStrategy:
-        outer_calendar = self.outer_trade_decision.strategy.trade_calendar
-        return self._strategies[outer_calendar.get_trade_step()]
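The dispatch idea behind the removed classes, indexing a per-step list of strategies by the outer calendar's current trade step, is worth recording even after the deletion. A minimal sketch of the same pattern with generic stand-ins (none of these names are qlib's):

    from typing import Callable, List

    class StepMultiplexer:
        """Pick a different handler for each step of an outer calendar."""

        def __init__(self, handlers: List[Callable[[], str]]) -> None:
            self._handlers = handlers

        def decide(self, outer_step: int) -> str:
            # Mirrors _get_current_strategy above:
            # self._strategies[outer_calendar.get_trade_step()]
            return self._handlers[outer_step]()

    mux = StepMultiplexer([lambda: "TWAP", lambda: "RL", lambda: "TWAP"])
    assert mux.decide(1) == "RL"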
From 0d84962fe7f23db5df078eef03bc7b8fd168eaea Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Thu, 8 Sep 2022 11:39:27 +0800
Subject: [PATCH 07/15] Fix CI error

---
 qlib/rl/order_execution/strategy.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py
index 308940ec99..747d1d7270 100644
--- a/qlib/rl/order_execution/strategy.py
+++ b/qlib/rl/order_execution/strategy.py
@@ -4,13 +4,15 @@
 from __future__ import annotations
 
 import collections
-from abc import ABCMeta
 from types import GeneratorType
-from typing import Any, cast, Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, Optional, Union, cast
 
 import pandas as pd
 import torch
-from qlib.backtest import CommonInfrastructure, Exchange, Order
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
+
+from qlib.backtest import CommonInfrastructure, Order
 from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange
 from qlib.backtest.utils import LevelInfrastructure
 from qlib.constant import ONE_MIN
@@ -19,10 +21,8 @@
 from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter
 from qlib.rl.utils import EnvWrapper
 from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
-from qlib.strategy.base import BaseStrategy, RLStrategy
+from qlib.strategy.base import RLStrategy
 from qlib.utils import init_instance_by_config
-from tianshou.data import Batch
-from tianshou.policy import BasePolicy
 
 
 class SAOEStrategy(RLStrategy):

From 523571b5b4658b5df6ee6d6bf0fd2d7330dc74c0 Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Thu, 15 Sep 2022 11:21:30 +0800
Subject: [PATCH 08/15] Resolve PR comments

---
 .../backtest.py                        |  8 +++--
 .../naive_config_parser.py             |  0
 .../utils.py                           |  0
 qlib/rl/interpreter.py                 |  9 +++--
 qlib/rl/order_execution/interpreter.py |  2 ++
 qlib/rl/order_execution/strategy.py    | 16 +++++----
 qlib/rl/utils/env_wrapper.py           | 33 ++++++++++++-------
 tests/rl/test_qlib_simulator.py        |  1 +
 tests/rl/test_saoe_simple.py           |  5 +++
 9 files changed, 48 insertions(+), 26 deletions(-)
 rename qlib/rl/{neural_trading_migration => contrib}/backtest.py (96%)
 rename qlib/rl/{neural_trading_migration => contrib}/naive_config_parser.py (100%)
 rename qlib/rl/{neural_trading_migration => contrib}/utils.py (100%)

diff --git a/qlib/rl/neural_trading_migration/backtest.py b/qlib/rl/contrib/backtest.py
similarity index 96%
rename from qlib/rl/neural_trading_migration/backtest.py
rename to qlib/rl/contrib/backtest.py
index e6c309033c..1185fd9bbc 100644
--- a/qlib/rl/neural_trading_migration/backtest.py
+++ b/qlib/rl/contrib/backtest.py
@@ -17,8 +17,8 @@
 from qlib.backtest.decision import TradeRangeByTime
 from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
 from qlib.backtest.high_performance_ds import BaseOrderIndicator
-from qlib.rl.neural_trading_migration.naive_config_parser import convert_instance_config, get_backtest_config_fromfile
-from qlib.rl.neural_trading_migration.utils import read_order_file
+from qlib.rl.contrib.naive_config_parser import convert_instance_config, get_backtest_config_fromfile
+from qlib.rl.contrib.utils import read_order_file
 from qlib.rl.data.integration import init_qlib
 from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
 
@@ -61,7 +61,9 @@
 def _set_env_for_all_strategy(executor: BaseExecutor) -> None:
     if isinstance(executor, NestedExecutor):
         if hasattr(executor.inner_strategy, "set_env"):
-            executor.inner_strategy.set_env(CollectDataEnvWrapper())
+            env = CollectDataEnvWrapper()
+            env.reset()
+            executor.inner_strategy.set_env(env)
         _set_env_for_all_strategy(executor.inner_executor)
 
diff --git a/qlib/rl/neural_trading_migration/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py
similarity index 100%
rename from qlib/rl/neural_trading_migration/naive_config_parser.py
rename to qlib/rl/contrib/naive_config_parser.py
diff --git a/qlib/rl/neural_trading_migration/utils.py b/qlib/rl/contrib/utils.py
similarity index 100%
rename from qlib/rl/neural_trading_migration/utils.py
rename to qlib/rl/contrib/utils.py
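The change to _set_env_for_all_strategy above (reset() before set_env) matters because a freshly constructed CollectDataEnvWrapper has no status until reset() populates it. The traversal itself is a plain recursion over the executor tree; a stripped-down sketch with a hypothetical factory argument instead of the concrete executor classes:

    def set_env_recursively(executor, make_env):
        # Walk NestedExecutor-style objects: configure this level's strategy,
        # then descend into the inner executor.
        if not hasattr(executor, "inner_executor"):
            return  # innermost (simulator) level reached
        if hasattr(executor.inner_strategy, "set_env"):
            env = make_env()
            env.reset()  # populate env.status before any strategy reads it
            executor.inner_strategy.set_env(env)
        set_env_recursively(executor.inner_executor, make_env)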
diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py
index 5ff2780db7..d2d81f81cd 100644
--- a/qlib/rl/interpreter.py
+++ b/qlib/rl/interpreter.py
@@ -3,16 +3,15 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union
+from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar
 
 import numpy as np
 
 from qlib.typehint import final
-
 from .simulator import ActType, StateType
 
 if TYPE_CHECKING:
-    from .utils.env_wrapper import CollectDataEnvWrapper, EnvWrapper
+    from .utils.env_wrapper import BaseEnvWrapper
 
 import gym
 from gym import spaces
@@ -40,7 +39,7 @@ class Interpreter:
 class StateInterpreter(Generic[StateType, ObsType], Interpreter):
     """State Interpreter that interpret execution result of qlib executor into rl env state"""
 
-    env: Union[EnvWrapper, CollectDataEnvWrapper, None] = None
+    env: Optional[BaseEnvWrapper] = None
 
     @property
     def observation_space(self) -> gym.Space:
@@ -74,7 +73,7 @@ def interpret(self, simulator_state: StateType) -> ObsType:
 class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter):
     """Action Interpreter that interpret rl agent action into qlib orders"""
 
-    env: Union[EnvWrapper, CollectDataEnvWrapper, None] = None
+    env: Optional[BaseEnvWrapper] = None
 
     @property
     def action_space(self) -> gym.Space:
diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py
index aacfe86fee..331967e6c9 100644
--- a/qlib/rl/order_execution/interpreter.py
+++ b/qlib/rl/order_execution/interpreter.py
@@ -69,6 +69,8 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
         path to load data. For example, the data has already been preloaded in `init_qlib()`.
     """
 
+    # TODO: All implementations related to `data_dir` is coupled with the specific data format for that specific case.
+    # TODO: So it should be redesigned after the data interface is well-designed.
    def __init__(self, max_step: int, data_ticks: int, data_dim: int, data_dir: Path = None) -> None:
         self.data_dir = data_dir
         self.max_step = max_step
diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py
index 747d1d7270..136266a5b1 100644
--- a/qlib/rl/order_execution/strategy.py
+++ b/qlib/rl/order_execution/strategy.py
@@ -5,7 +5,7 @@
 
 import collections
 from types import GeneratorType
-from typing import Any, Dict, Generator, Optional, Union, cast
+from typing import Any, cast, Dict, Generator, Optional, Union
 
 import pandas as pd
 import torch
@@ -19,8 +19,7 @@
 from qlib.rl.data.exchange_wrapper import load_qlib_backtest_data
 from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
 from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter
-from qlib.rl.utils import EnvWrapper
-from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
+from qlib.rl.utils.env_wrapper import BaseEnvWrapper
 from qlib.strategy.base import RLStrategy
 from qlib.utils import init_instance_by_config
 
@@ -170,6 +169,7 @@ def __init__(
         outer_trade_decision: BaseTradeDecision = None,
         level_infra: LevelInfrastructure = None,
         common_infra: CommonInfrastructure = None,
+        backtest: bool = False,
         **kwargs: Any,
     ) -> None:
         super(SAOEIntStrategy, self).__init__(
@@ -180,6 +180,8 @@ def __init__(
             **kwargs,
         )
 
+        self._backtest = backtest
+
         self._state_interpreter: StateInterpreter = init_instance_by_config(
             state_interpreter,
             accept_types=StateInterpreter,
@@ -218,14 +220,15 @@ def __init__(
         if self._policy is not None:
             self._policy.eval()
 
-    def set_env(self, env: EnvWrapper | CollectDataEnvWrapper) -> None:
+    def set_env(self, env: BaseEnvWrapper) -> None:
         self._env = env
         self._state_interpreter.env = self._action_interpreter.env = self._env
 
     def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs: Any) -> None:
         super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
 
-        if isinstance(self._env, CollectDataEnvWrapper):
+        # In backtest, env.reset() needs to be manually called since there is no outer trainer to call it
+        if self._backtest:
             self._env.reset()
 
     def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
@@ -243,7 +246,8 @@ def _generate_trade_decision(self, execute_result: list = None) -> BaseTradeDeci
         act = policy_out.act.numpy() if torch.is_tensor(policy_out.act) else policy_out.act
         exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)]
 
-        if isinstance(self._env, CollectDataEnvWrapper):
+        # In backtest, env.step() needs to be manually called since there is no outer trainer to call it
+        if self._backtest:
             self._env.step(None)
 
         oh = self.trade_exchange.get_order_helper()
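The backtest flag introduced above exists because, outside training, nothing else drives the env's lifecycle: the strategy must call reset() when a new decision starts and step(None) after each action, purely to keep env.status current. Usage in isolation; a sketch assuming, as the dummy wrapper's code suggests, that reset() initializes the status and step() merely advances its step counter:

    from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper

    env = CollectDataEnvWrapper()
    env.reset()                    # initializes env.status, cur_step starts at 0
    env.step(None)                 # bookkeeping only; there is no simulator behind it
    print(env.status["cur_step"])  # expected: 1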
diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py
index 84ed2cfbd4..f082f3b013 100644
--- a/qlib/rl/utils/env_wrapper.py
+++ b/qlib/rl/utils/env_wrapper.py
@@ -4,7 +4,7 @@
 from __future__ import annotations
 
 import weakref
-from typing import Any, Callable, Dict, Generic, Iterable, Iterator, Optional, Tuple, cast
+from typing import Any, Callable, cast, Dict, Generic, Iterable, Iterator, Optional, Tuple
 
 import gym
 from gym import Space
@@ -14,7 +14,6 @@
 from qlib.rl.reward import Reward
 from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType
 from qlib.typehint import TypedDict
-
 from .finite_env import generate_nan_observation
 from .log import LogCollector, LogLevel
 
@@ -49,9 +48,24 @@ class EnvWrapperStatus(TypedDict):
     reward_history: list
 
 
-class EnvWrapper(
+class BaseEnvWrapper(
     gym.Env[ObsType, PolicyActType],
     Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType],
+):
+    """Base env wrapper for RL environments. It has two implementations:
+    - EnvWrapper: Qlib-based RL environment used in training.
+    - CollectDataEnvWrapper: Dummy environment used in collect_data_loop.
+    """
+
+    def __init__(self) -> None:
+        self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
+
+    def render(self, mode: str = "human") -> None:
+        raise NotImplementedError("Render is not implemented in BaseEnvWrapper.")
+
+
+class EnvWrapper(
+    BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType],
 ):
     """Qlib-based RL environment, subclassing ``gym.Env``.
     A wrapper of components, including simulator, state-interpreter, action-interpreter, reward.
@@ -115,6 +129,8 @@ def __init__(
         # 3. Avoid circular reference.
         # 4. When the components get serialized, we can throw away the env without any burden.
         #    (though this part is not implemented yet)
+        super().__init__()
+
         for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]:
             if obj is not None:
                 obj.env = weakref.proxy(self)  # type: ignore
@@ -247,16 +263,9 @@ def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, fl
         info_dict = InfoDict(log=self.logger.logs(), aux_info=aux_info)
         return obs, rew, done, info_dict
 
-    def render(self, mode: str = "human") -> None:
-        raise NotImplementedError("Render is not implemented in EnvWrapper.")
-
 
-class CollectDataEnvWrapper:
-    """Dummy EnvWrapper for collect_data_loop. It only has minimium interfaces to support the collect_data_loop."""
-
-    def __init__(self) -> None:
-        self.status: EnvWrapperStatus = cast(EnvWrapperStatus, None)
-        self.reset()
+class CollectDataEnvWrapper(BaseEnvWrapper[InitialStateType, StateType, ActType, ObsType, PolicyActType]):
+    """Dummy EnvWrapper for collect_data_loop. It only has minimum interfaces to support the collect_data_loop."""
 
     def reset(self, **kwargs: Any) -> None:
         self.status = EnvWrapperStatus(
diff --git a/tests/rl/test_qlib_simulator.py b/tests/rl/test_qlib_simulator.py
index 0a42d71394..14bf8b5a11 100644
--- a/tests/rl/test_qlib_simulator.py
+++ b/tests/rl/test_qlib_simulator.py
@@ -194,6 +194,7 @@ def test_interpreter() -> None:
     simulator = get_simulator(order)
     interpreter_action = CategoricalActionInterpreter(values=NUM_EXECUTION)
     interpreter_action.env = CollectDataEnvWrapper()
+    interpreter_action.env.reset()
 
     NUM_STEPS = 7
     state = simulator.get_state()
diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py
index 95fa7a2958..9bf86e018f 100644
--- a/tests/rl/test_saoe_simple.py
+++ b/tests/rl/test_saoe_simple.py
@@ -188,6 +188,8 @@ class EmulateEnvWrapper(NamedTuple):
     # second step: action
     interpreter_action.env = CollectDataEnvWrapper()
     interpreter_action_twap.env = CollectDataEnvWrapper()
+    interpreter_action.env.reset()
+    interpreter_action_twap.env.reset()
 
     action = interpreter_action(simulator.get_state(), 1)
     assert action == 15 / 20
@@ -259,6 +261,7 @@ def test_twap_strategy(finite_env_type):
     state_interp = FullHistoryStateInterpreter(13, 390, 5, FEATURE_DATA_DIR)
     action_interp = TwapRelativeActionInterpreter()
     action_interp.env = CollectDataEnvWrapper()
+    action_interp.env.reset()
     policy = AllOne(state_interp.observation_space, action_interp.action_space)
 
     csv_writer = CsvWriter(Path(__file__).parent / ".output")
@@ -289,6 +292,7 @@ def test_cn_ppo_strategy():
     state_interp = FullHistoryStateInterpreter(8, 240, 6, CN_FEATURE_DATA_DIR)
     action_interp = CategoricalActionInterpreter(4)
     action_interp.env = CollectDataEnvWrapper()
+    action_interp.env.reset()
     network = Recurrent(state_interp.observation_space)
     policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
     policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
@@ -321,6 +325,7 @@ def test_ppo_train():
     state_interp = FullHistoryStateInterpreter(8, 240, 6, CN_FEATURE_DATA_DIR)
     action_interp = CategoricalActionInterpreter(4)
     action_interp.env = CollectDataEnvWrapper()
+    action_interp.env.reset()
     network = Recurrent(state_interp.observation_space)
     policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
From c9534ff3010278891db5b00bebedc5099f12cb9e Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Fri, 16 Sep 2022 14:07:57 +0800
Subject: [PATCH 09/15] Refactor data interfaces

---
 qlib/rl/data/base.py                        |  65 +++++++++++
 .../data/{exchange_wrapper.py => native.py} |  43 +++++++-
 qlib/rl/data/pickle_styled.py               | 103 +++++-------------
 qlib/rl/order_execution/interpreter.py      |  34 +++---
 qlib/rl/order_execution/state.py            |   4 +-
 qlib/rl/order_execution/strategy.py         |   2 +-
 tests/rl/test_saoe_simple.py                |  16 +--
 7 files changed, 166 insertions(+), 101 deletions(-)
 create mode 100644 qlib/rl/data/base.py
 rename qlib/rl/data/{exchange_wrapper.py => native.py} (67%)

diff --git a/qlib/rl/data/base.py b/qlib/rl/data/base.py
new file mode 100644
index 0000000000..e258abe869
--- /dev/null
+++ b/qlib/rl/data/base.py
@@ -0,0 +1,65 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from abc import abstractmethod
+
+import pandas as pd
+
+
+class BaseIntradayBacktestData:
+    """
+    Raw market data that is often used in backtesting (thus called BacktestData).
+
+    Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
+    data type.
+    """
+
+    @abstractmethod
+    def __repr__(self) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_deal_price(self) -> pd.Series:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_volume(self) -> pd.Series:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_time_index(self) -> pd.DatetimeIndex:
+        raise NotImplementedError
+
+
+class BaseIntradayProcessedData:
+    """Processed market data after data cleanup and feature engineering.
+
+    It contains both processed data for "today" and "yesterday", as some algorithms
+    might use the market information of the previous day to assist decision making.
+    """
+
+    today: pd.DataFrame
+    """Processed data for "today".
+    Number of records must be ``time_length``, and columns must be ``feature_dim``."""
+
+    yesterday: pd.DataFrame
+    """Processed data for "yesterday".
+    Number of records must be ``time_length``, and columns must be ``feature_dim``."""
+
+
+class ProcessedDataProvider:
+    """Provider of processed data"""
+
+    def get_data(
+        self,
+        stock_id: str,
+        date: pd.Timestamp,
+        feature_dim: int,
+        time_index: pd.Index,
+    ) -> BaseIntradayProcessedData:
+        raise NotImplementedError
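ProcessedDataProvider is the seam this refactor adds: anything that returns "today"/"yesterday" frames for a (stock, date) pair can feed the state interpreter, regardless of storage format. A minimal synthetic provider, purely illustrative (random features, not part of qlib):

    import numpy as np
    import pandas as pd

    from qlib.rl.data.base import BaseIntradayProcessedData, ProcessedDataProvider

    class RandomProcessedData(BaseIntradayProcessedData):
        def __init__(self, feature_dim: int, time_index: pd.Index) -> None:
            shape = (len(time_index), feature_dim)
            self.today = pd.DataFrame(np.random.randn(*shape), index=time_index)
            self.yesterday = pd.DataFrame(np.random.randn(*shape), index=time_index)

    class RandomProcessedDataProvider(ProcessedDataProvider):
        """Synthetic provider, e.g. for smoke-testing an interpreter offline."""

        def get_data(
            self, stock_id: str, date: pd.Timestamp, feature_dim: int, time_index: pd.Index
        ) -> BaseIntradayProcessedData:
            return RandomProcessedData(feature_dim, time_index)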
diff --git a/qlib/rl/data/exchange_wrapper.py b/qlib/rl/data/native.py
similarity index 67%
rename from qlib/rl/data/exchange_wrapper.py
rename to qlib/rl/data/native.py
index 004074d0b8..ca97a54266 100644
--- a/qlib/rl/data/exchange_wrapper.py
+++ b/qlib/rl/data/native.py
@@ -1,17 +1,21 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+from __future__ import annotations
 
 from typing import cast
 
 import cachetools
 import pandas as pd
+from cachetools.keys import hashkey
+
 from qlib.backtest import Exchange, Order
 from qlib.backtest.decision import TradeRange, TradeRangeByTime
 from qlib.constant import EPS_T, ONE_DAY
 from qlib.rl.order_execution.utils import get_ticks_slice
 from qlib.utils.index_data import IndexData
 
-from .pickle_styled import BaseIntradayBacktestData
+from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
+from .integration import fetch_features
 
 
 class IntradayBacktestData(BaseIntradayBacktestData):
@@ -108,3 +112,40 @@ def load_qlib_backtest_data(
         ticks_for_order=ticks_for_order,
     )
     return backtest_data
+
+
+class NTIntradayProcessedData(BaseIntradayProcessedData):
+    """Subclass of IntradayProcessedData. Used to handle NT style data."""
+
+    def __init__(
+        self,
+        stock_id: str,
+        date: pd.Timestamp,
+    ) -> None:
+        def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
+            return df.reset_index().drop(columns=["instrument"]).set_index(["datetime"])
+
+        self.today = _drop_stock_id(fetch_features(stock_id, date))
+        self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
+
+    def __repr__(self) -> str:
+        with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
+            return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
+
+
+@cachetools.cached(  # type: ignore
+    cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB
+)
+def load_nt_intraday_processed_data(stock_id: str, date: pd.Timestamp) -> NTIntradayProcessedData:
+    return NTIntradayProcessedData(stock_id, date)
+
+
+class NTProcessedDataProvider(ProcessedDataProvider):
+    def get_data(
+        self,
+        stock_id: str,
+        date: pd.Timestamp,
+        feature_dim: int,
+        time_index: pd.Index,
+    ) -> BaseIntradayProcessedData:
+        return load_nt_intraday_processed_data(stock_id, date)
diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py
index 25db10c32e..ed62a4180d 100644
--- a/qlib/rl/data/pickle_styled.py
+++ b/qlib/rl/data/pickle_styled.py
@@ -19,10 +19,9 @@
 
 from __future__ import annotations
 
-from abc import abstractmethod
 from functools import lru_cache
 from pathlib import Path
-from typing import List, Optional, Sequence, cast
+from typing import List, Sequence, cast
 
 import cachetools
 import numpy as np
@@ -30,7 +29,7 @@
 from cachetools.keys import hashkey
 
 from qlib.backtest.decision import Order, OrderDir
-from qlib.rl.data.integration import fetch_features
+from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider
 from qlib.typehint import Literal
 
 DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"]
@@ -87,35 +86,6 @@ def _read_pickle(filename_without_suffix: Path) -> pd.DataFrame:
     return pd.read_pickle(_find_pickle(filename_without_suffix))
 
 
-class BaseIntradayBacktestData:
-    """
-    Raw market data that is often used in backtesting (thus called BacktestData).
-
-    Base class for all types of backtest data. Currently, each type of simulator has its corresponding backtest
-    data type.
-    """
-
-    @abstractmethod
-    def __repr__(self) -> str:
-        raise NotImplementedError
-
-    @abstractmethod
-    def __len__(self) -> int:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_deal_price(self) -> pd.Series:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_volume(self) -> pd.Series:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_time_index(self) -> pd.DatetimeIndex:
-        raise NotImplementedError
-
-
 class SimpleIntradayBacktestData(BaseIntradayBacktestData):
     """Backtest data for simple simulator"""
 
@@ -179,22 +149,6 @@ def get_time_index(self) -> pd.DatetimeIndex:
         return cast(pd.DatetimeIndex, self.data.index)
 
 
-class BaseIntradayProcessedData:
-    """Processed market data after data cleanup and feature engineering.
-
-    It contains both processed data for "today" and "yesterday", as some algorithms
-    might use the market information of the previous day to assist decision making.
-    """
-
-    today: pd.DataFrame
-    """Processed data for "today".
-    Number of records must be ``time_length``, and columns must be ``feature_dim``."""
-
-    yesterday: pd.DataFrame
-    """Processed data for "yesterday".
-    Number of records must be ``time_length``, and columns must be ``feature_dim``."""
-
-
 class IntradayProcessedData(BaseIntradayProcessedData):
     """Subclass of IntradayProcessedData. Used to handle Dataset Handler style data."""
 
@@ -238,25 +192,6 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
 
 
-class NTIntradayProcessedData(BaseIntradayProcessedData):
-    """Subclass of IntradayProcessedData. Used to handle NT style data."""
-
-    def __init__(
-        self,
-        stock_id: str,
-        date: pd.Timestamp,
-    ) -> None:
-        def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame:
-            return df.reset_index().drop(columns=["instrument"]).set_index(["datetime"])
-
-        self.today = _drop_stock_id(fetch_features(stock_id, date))
-        self.yesterday = _drop_stock_id(fetch_features(stock_id, date, yesterday=True))
-
-    def __repr__(self) -> str:
-        with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"):
-            return f"{self.__class__.__name__}({self.today}, {self.yesterday})"
-
-
 @lru_cache(maxsize=100)  # 100 * 50K = 5MB
 def load_simple_intraday_backtest_data(
     data_dir: Path,
@@ -270,22 +205,38 @@ def load_simple_intraday_backtest_data(
 
 @cachetools.cached(  # type: ignore
     cache=cachetools.LRUCache(100),  # 100 * 50K = 5MB
-    key=lambda data_dir, stock_id, date, _, __: hashkey(data_dir, stock_id, date),
+    key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date),
 )
-def load_intraday_processed_data(
-    data_dir: Optional[Path],
+def load_pickled_intraday_processed_data(
+    data_dir: Path,
     stock_id: str,
     date: pd.Timestamp,
     feature_dim: int,
     time_index: pd.Index,
 ) -> BaseIntradayProcessedData:
-    from qlib.rl.data.integration import dataset  # pylint: disable=C0415
-
-    if dataset is None:
-        assert data_dir is not None
-        return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
-    else:
-        return NTIntradayProcessedData(stock_id, date)
+    return IntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index)
+
+
+class PickleProcessedDataProvider(ProcessedDataProvider):
+    def __init__(self, data_dir: Path) -> None:
+        super().__init__()
+
+        self._data_dir = data_dir
+
+    def get_data(
+        self,
+        stock_id: str,
+        date: pd.Timestamp,
+        feature_dim: int,
+        time_index: pd.Index,
+    ) -> BaseIntradayProcessedData:
+        return load_pickled_intraday_processed_data(
+            data_dir=self._data_dir,
+            stock_id=stock_id,
+            date=date,
+            feature_dim=feature_dim,
+            time_index=time_index,
+        )
 
 
 def load_orders(
diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py
index 331967e6c9..123db5af1e 100644
--- a/qlib/rl/order_execution/interpreter.py
+++ b/qlib/rl/order_execution/interpreter.py
@@ -4,7 +4,6 @@
 from __future__ import annotations
 
 import math
-from pathlib import Path
 from typing import Any, List, Optional, cast
 
 import numpy as np
@@ -12,7 +11,7 @@
 from gym import spaces
 
 from qlib.constant import EPS
-from qlib.rl.data import pickle_styled
+from qlib.rl.data.base import ProcessedDataProvider
 from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
 from qlib.rl.order_execution.state import SAOEState
 from qlib.typehint import TypedDict
@@ -25,6 +24,8 @@
     "FullHistoryObs",
 ]
 
+from qlib.utils import init_instance_by_config
+
 
 def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict:
     """To 32-bit numeric types. Recursively."""
@@ -64,26 +65,33 @@ class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]):
         the total ticks is the length of day in minutes.
     data_dim
         Number of dimensions in data.
-    data_dir
-        Path to load data after feature engineering. It is optional since in some cases we do not need explicit
-        path to load data. For example, the data has already been preloaded in `init_qlib()`.
+    processed_data_provider
+        Provider of the processed data.
     """
 
-    # TODO: All implementations related to `data_dir` is coupled with the specific data format for that specific case.
-    # TODO: So it should be redesigned after the data interface is well-designed.
-    def __init__(self, max_step: int, data_ticks: int, data_dim: int, data_dir: Path = None) -> None:
-        self.data_dir = data_dir
+    def __init__(
+        self,
+        max_step: int,
+        data_ticks: int,
+        data_dim: int,
+        processed_data_provider: dict | ProcessedDataProvider,
+    ) -> None:
         self.max_step = max_step
         self.data_ticks = data_ticks
         self.data_dim = data_dim
+        self.processed_data_provider: ProcessedDataProvider = init_instance_by_config(
+            processed_data_provider,
+            accept_types=ProcessedDataProvider,
+        )
 
     def interpret(self, state: SAOEState) -> FullHistoryObs:
-        processed = pickle_styled.load_intraday_processed_data(
-            self.data_dir,
-            state.order.stock_id,
-            pd.Timestamp(state.order.start_time.date()),
-            self.data_dim,
-            state.ticks_index,
+        processed = self.processed_data_provider.get_data(
+            stock_id=state.order.stock_id,
+            date=pd.Timestamp(state.order.start_time.date()),
+            feature_dim=self.data_dim,
+            time_index=state.ticks_index,
         )
 
         position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32)
diff --git a/qlib/rl/order_execution/state.py b/qlib/rl/order_execution/state.py
index dfa125cafa..a46928ee89 100644
--- a/qlib/rl/order_execution/state.py
+++ b/qlib/rl/order_execution/state.py
@@ -16,8 +16,8 @@
 from qlib.utils.time import get_day_min_idx_range
 
 if typing.TYPE_CHECKING:
-    from qlib.rl.data.exchange_wrapper import IntradayBacktestData
-    from qlib.rl.data.pickle_styled import BaseIntradayBacktestData
+    from qlib.rl.data.base import BaseIntradayBacktestData
+    from qlib.rl.data.native import IntradayBacktestData
 
 
 def _get_all_timestamps(
diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py
index 136266a5b1..d09009fc60 100644
--- a/qlib/rl/order_execution/strategy.py
+++ b/qlib/rl/order_execution/strategy.py
@@ -16,7 +16,7 @@
 from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange
 from qlib.backtest.utils import LevelInfrastructure
 from qlib.constant import ONE_MIN
-from qlib.rl.data.exchange_wrapper import load_qlib_backtest_data
+from qlib.rl.data.native import load_qlib_backtest_data
 from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
 from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter
 from qlib.rl.utils.env_wrapper import BaseEnvWrapper
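After this change the interpreter takes a provider instead of a path, and accepts either an instance or a config dict (resolved through init_instance_by_config). Both spellings, sketched with a placeholder data directory:

    from pathlib import Path

    from qlib.rl.data.pickle_styled import PickleProcessedDataProvider
    from qlib.rl.order_execution.interpreter import FullHistoryStateInterpreter

    # 1) Pass a ready-made provider instance...
    state_interp = FullHistoryStateInterpreter(
        max_step=13,
        data_ticks=390,
        data_dim=5,
        processed_data_provider=PickleProcessedDataProvider(Path("data/processed")),
    )

    # 2) ...or the equivalent class/module_path/kwargs config dict.
    state_interp = FullHistoryStateInterpreter(
        13, 390, 5,
        processed_data_provider={
            "class": "PickleProcessedDataProvider",
            "module_path": "qlib.rl.data.pickle_styled",
            "kwargs": {"data_dir": Path("data/processed")},
        },
    )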
diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py
index 95fa7a2958..22bd039096 100644
--- a/tests/rl/test_saoe_simple.py
+++ b/tests/rl/test_saoe_simple.py
@@ -16,6 +16,7 @@
 from qlib.config import C
 from qlib.log import set_log_with_config
 from qlib.rl.data import pickle_styled
+from qlib.rl.data.pickle_styled import PickleProcessedDataProvider
 from qlib.rl.order_execution import *
 from qlib.rl.trainer import backtest, train
 from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus
@@ -41,9 +42,8 @@ def test_pickle_data_inspect():
     data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
     assert len(data) == 390
 
-    data = pickle_styled.load_intraday_processed_data(
-        DATA_DIR / "processed", "AAL", "2013-12-11", 5, data.get_time_index()
-    )
+    provider = PickleProcessedDataProvider(DATA_DIR / "processed")
+    data = provider.get_data("AAL", "2013-12-11", 5, data.get_time_index())
     assert len(data.today) == len(data.yesterday) == 390
 
@@ -147,7 +147,7 @@ def test_interpreter():
     class EmulateEnvWrapper(NamedTuple):
         status: EnvWrapperStatus
 
-    interpreter = FullHistoryStateInterpreter(13, 390, 5, FEATURE_DATA_DIR)
+    interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
     interpreter_step = CurrentStepStateInterpreter(13)
     interpreter_action = CategoricalActionInterpreter(20)
     interpreter_action_twap = TwapRelativeActionInterpreter()
@@ -230,7 +230,7 @@ def test_network_sanity():
     class EmulateEnvWrapper(NamedTuple):
         status: EnvWrapperStatus
 
-    interpreter = FullHistoryStateInterpreter(13, 390, 5, FEATURE_DATA_DIR)
+    interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
     action_interp = CategoricalActionInterpreter(13)
 
     wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
@@ -258,7 +258,7 @@ def test_twap_strategy(finite_env_type):
     orders = pickle_styled.load_orders(ORDER_DIR)
     assert len(orders) == 248
 
-    state_interp = FullHistoryStateInterpreter(13, 390, 5, FEATURE_DATA_DIR)
+    state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
     action_interp = TwapRelativeActionInterpreter()
     action_interp.env = CollectDataEnvWrapper()
     action_interp.env.reset()
@@ -289,7 +289,7 @@ def test_cn_ppo_strategy():
     orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
     assert len(orders) == 40
 
-    state_interp = FullHistoryStateInterpreter(8, 240, 6, CN_FEATURE_DATA_DIR)
+    state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
     action_interp = CategoricalActionInterpreter(4)
     action_interp.env = CollectDataEnvWrapper()
     action_interp.env.reset()
@@ -322,7 +322,7 @@ def test_ppo_train():
     orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
     assert len(orders) == 40
 
-    state_interp = FullHistoryStateInterpreter(8, 240, 6, CN_FEATURE_DATA_DIR)
+    state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
     action_interp = CategoricalActionInterpreter(4)
     action_interp.env = CollectDataEnvWrapper()
     action_interp.env.reset()
From 944ea3085934d811574a33467572065bc2d98565 Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Fri, 16 Sep 2022 14:18:25 +0800
Subject: [PATCH 10/15] Remove convert_instance_config and change config

---
 qlib/rl/contrib/backtest.py            |  3 +--
 qlib/rl/contrib/naive_config_parser.py | 30 --------------------------
 2 files changed, 1 insertion(+), 32 deletions(-)

diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py
index 1185fd9bbc..993eb20b80 100644
--- a/qlib/rl/contrib/backtest.py
+++ b/qlib/rl/contrib/backtest.py
@@ -17,7 +17,7 @@
 from qlib.backtest.decision import TradeRangeByTime
 from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
 from qlib.backtest.high_performance_ds import BaseOrderIndicator
-from qlib.rl.contrib.naive_config_parser import convert_instance_config, get_backtest_config_fromfile
+from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
 from qlib.rl.contrib.utils import read_order_file
 from qlib.rl.data.integration import init_qlib
 from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
@@ -28,7 +28,6 @@ def _get_multi_level_executor_config(
     cash_limit: float = None,
     generate_report: bool = False,
 ) -> dict:
-    strategy_config = cast(dict, convert_instance_config(strategy_config))
     executor_config = {
         "class": "SimulatorExecutor",
         "module_path": "qlib.backtest.executor",
diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py
index 4add7ac071..0432c72dd8 100644
--- a/qlib/rl/contrib/naive_config_parser.py
+++ b/qlib/rl/contrib/naive_config_parser.py
@@ -101,33 +101,3 @@ def get_backtest_config_fromfile(path: str) -> dict:
         backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
 
     return backtest_config
-
-
-def convert_instance_config(config: object) -> object:
-    if isinstance(config, dict):
-        if "type" in config:
-            type_name = config["type"]
-            if "." in type_name:
-                idx = type_name.rindex(".")
-                module_path, class_name = type_name[:idx], type_name[idx + 1 :]
-            else:
-                module_path, class_name = "", type_name
-
-            kwargs = {}
-            for k, v in config.items():
-                if k == "type":
-                    continue
-                kwargs[k] = convert_instance_config(v)
-            return {
-                "class": class_name,
-                "module_path": module_path,
-                "kwargs": kwargs,
-            }
-        else:
-            return {k: convert_instance_config(v) for k, v in config.items()}
-    elif isinstance(config, list):
-        return [convert_instance_config(item) for item in config]
-    elif isinstance(config, tuple):
-        return tuple([convert_instance_config(item) for item in config])
-    else:
-        return config

From 2f7de4cf76929cd81d8ed6727e994c3860cc39c2 Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Fri, 16 Sep 2022 14:25:01 +0800
Subject: [PATCH 11/15] Pylint issue

---
 qlib/rl/data/native.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py
index ca97a54266..025dce3710 100644
--- a/qlib/rl/data/native.py
+++ b/qlib/rl/data/native.py
@@ -6,7 +6,6 @@
 
 import cachetools
 import pandas as pd
-from cachetools.keys import hashkey
 
 from qlib.backtest import Exchange, Order
 from qlib.backtest.decision import TradeRange, TradeRangeByTime

From 289ca9e8d09e83c2c8fa8426405543ad98ca84c4 Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Fri, 16 Sep 2022 14:33:06 +0800
Subject: [PATCH 12/15] Pylint issue

---
 qlib/rl/contrib/backtest.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py
index 993eb20b80..709c050dfb 100644
--- a/qlib/rl/contrib/backtest.py
+++ b/qlib/rl/contrib/backtest.py
@@ -6,7 +6,7 @@
 import pickle
 import sys
 from pathlib import Path
-from typing import Optional, cast, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
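For the record, the removed convert_instance_config rewrote "type"-style configs (as used by the migrated codebase) into qlib's class/module_path/kwargs form, recursing through dicts, lists, and tuples. A before/after pair, with the type string borrowed from the experiment config deleted in PATCH 05:

    src = {"5min": {"type": "qlib.contrib.strategy.rule_strategy.TWAPStrategy"}}

    # What convert_instance_config returned for it, and what config files
    # must now spell out directly:
    dst = {
        "5min": {
            "class": "TWAPStrategy",
            "module_path": "qlib.contrib.strategy.rule_strategy",
            "kwargs": {},
        }
    }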
From 0716f8588f44bd040728cfe64732042d7b23da92 Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Fri, 16 Sep 2022 16:27:18 +0800
Subject: [PATCH 13/15] Fix tempfile warning

---
 qlib/rl/contrib/naive_config_parser.py | 28 +++++++++++++-------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py
index 0432c72dd8..eaf62636cc 100644
--- a/qlib/rl/contrib/naive_config_parser.py
+++ b/qlib/rl/contrib/naive_config_parser.py
@@ -36,24 +36,24 @@ def parse_backtest_config(path: str) -> dict:
         raise IOError("Only py/yml/yaml/json type are supported now!")
 
     with tempfile.TemporaryDirectory() as tmp_config_dir:
-        tmp_config_file = tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name)
-        if platform.system() == "Windows":
-            tmp_config_file.close()
+        with tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name) as tmp_config_file:
+            if platform.system() == "Windows":
+                tmp_config_file.close()
 
-        tmp_config_name = os.path.basename(tmp_config_file.name)
-        shutil.copyfile(abs_path, tmp_config_file.name)
+            tmp_config_name = os.path.basename(tmp_config_file.name)
+            shutil.copyfile(abs_path, tmp_config_file.name)
 
-        if abs_path.endswith(".py"):
-            tmp_module_name = os.path.splitext(tmp_config_name)[0]
-            sys.path.insert(0, tmp_config_dir)
-            module = import_module(tmp_module_name)
-            sys.path.pop(0)
+            if abs_path.endswith(".py"):
+                tmp_module_name = os.path.splitext(tmp_config_name)[0]
+                sys.path.insert(0, tmp_config_dir)
+                module = import_module(tmp_module_name)
+                sys.path.pop(0)
 
-            config = {k: v for k, v in module.__dict__.items() if not k.startswith("__")}
+                config = {k: v for k, v in module.__dict__.items() if not k.startswith("__")}
 
-            del sys.modules[tmp_module_name]
-        else:
-            config = yaml.safe_load(open(os.path.join(tmp_config_dir, tmp_config_file.name)))
+                del sys.modules[tmp_module_name]
+            else:
+                config = yaml.safe_load(open(tmp_config_file.name))
 
     if "_base_" in config:
         base_file_name = config.pop("_base_")

From 6a35bee05d3de17a4115fa94e8e90a1c71a1a924 Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Mon, 19 Sep 2022 10:55:56 +0800
Subject: [PATCH 14/15] Resolve PR comments

---
 qlib/rl/data/native.py                 |  2 +-
 qlib/rl/order_execution/interpreter.py | 16 ++++++++++++++++
 qlib/rl/order_execution/strategy.py    |  4 ++--
 3 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py
index 025dce3710..eb612cf64e 100644
--- a/qlib/rl/data/native.py
+++ b/qlib/rl/data/native.py
@@ -77,7 +77,7 @@ def get_time_index(self) -> pd.DatetimeIndex:
     cache=cachetools.LRUCache(100),
     key=lambda order, _, __: order.key_by_day,
 )
-def load_qlib_backtest_data(
+def load_backtest_data(
     order: Order,
     trade_exchange: Exchange,
     trade_range: TradeRange,
diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py
index 123db5af1e..0b89977491 100644
--- a/qlib/rl/order_execution/interpreter.py
+++ b/qlib/rl/order_execution/interpreter.py
@@ -87,6 +87,10 @@ def interpret(self, state: SAOEState) -> FullHistoryObs:
+        # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
+        #  backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
+        #  way to decompose interpreter and EnvWrapper in the future.
+
         processed = self.processed_data_provider.get_data(
             stock_id=state.order.stock_id,
             date=pd.Timestamp(state.order.start_time.date()),
             feature_dim=self.data_dim,
             time_index=state.ticks_index,
@@ -173,6 +177,10 @@ def observation_space(self) -> spaces.Dict:
         return spaces.Dict(space)
 
     def interpret(self, state: SAOEState) -> CurrentStateObs:
+        # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
+        #  backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
+        #  way to decompose interpreter and EnvWrapper in the future.
+
         assert self.env is not None
         assert self.env.status["cur_step"] <= self.max_step
 
         obs = CurrentStateObs(
@@ -210,6 +218,10 @@ def action_space(self) -> spaces.Discrete:
         return spaces.Discrete(len(self.action_values))
 
     def interpret(self, state: SAOEState, action: int) -> float:
+        # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
+        #  backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
+        #  way to decompose interpreter and EnvWrapper in the future.
+
         assert 0 <= action < len(self.action_values)
         assert self.env is not None
         if self.max_step is not None and self.env.status["cur_step"] >= self.max_step - 1:
@@ -232,6 +244,10 @@ def action_space(self) -> spaces.Box:
         return spaces.Box(0, np.inf, shape=(), dtype=np.float32)
 
     def interpret(self, state: SAOEState, action: float) -> float:
+        # TODO: This interpreter relies on EnvWrapper.status, so we have to give it a dummy EnvWrapper when running
+        #  backtest. Currently, the dummy EnvWrapper is CollectDataEnvWrapper. We should find a more elegant
+        #  way to decompose interpreter and EnvWrapper in the future.
+
         assert self.env is not None
 
         estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step)
         twap_volume = state.position / (estimated_total_steps - self.env.status["cur_step"])
diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py
index d09009fc60..5ee9ad403a 100644
--- a/qlib/rl/order_execution/strategy.py
+++ b/qlib/rl/order_execution/strategy.py
@@ -16,7 +16,7 @@
 from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWO, TradeRange
 from qlib.backtest.utils import LevelInfrastructure
 from qlib.constant import ONE_MIN
-from qlib.rl.data.native import load_qlib_backtest_data
+from qlib.rl.data.native import load_backtest_data
 from qlib.rl.interpreter import ActionInterpreter, StateInterpreter
 from qlib.rl.order_execution.state import SAOEState, SAOEStateAdapter
 from qlib.rl.utils.env_wrapper import BaseEnvWrapper
@@ -47,7 +47,7 @@ def __init__(
         self._last_step_range = (0, 0)
 
     def _create_qlib_backtest_adapter(self, order: Order, trade_range: TradeRange) -> SAOEStateAdapter:
-        backtest_data = load_qlib_backtest_data(order, self.trade_exchange, trade_range)
+        backtest_data = load_backtest_data(order, self.trade_exchange, trade_range)
 
         return SAOEStateAdapter(
             order=order,

From 47fe0c0231b9e40e0ceeac6c42186af5b7866ece Mon Sep 17 00:00:00 2001
From: Huoran Li
Date: Mon, 19 Sep 2022 11:01:06 +0800
Subject: [PATCH 15/15] Add more comments

---
 qlib/rl/order_execution/strategy.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py
index 5ee9ad403a..ecc879bf51 100644
--- a/qlib/rl/order_execution/strategy.py
+++ b/qlib/rl/order_execution/strategy.py
@@ -221,6 +221,10 @@ def __init__(
             self._policy.eval()
 
     def set_env(self, env: BaseEnvWrapper) -> None:
+        # TODO: This method is used to set EnvWrapper for interpreters since they rely on EnvWrapper.
+        #  We should decompose the interpreters with EnvWrapper in the future and we should remove this method
+        #  after that.
+
         self._env = env
         self._state_interpreter.env = self._action_interpreter.env = self._env
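Finally, the arithmetic inside TwapRelativeActionInterpreter.interpret (annotated with TODOs above) in worked-example form; the numbers are illustrative, not taken from any real order:

    import math

    ticks_for_order, ticks_per_step = 390, 30    # one trading day of 1-minute ticks
    position, cur_step, action = 1000.0, 3, 1.0  # action is a relative multiplier

    estimated_total_steps = math.ceil(ticks_for_order / ticks_per_step)  # 13
    twap_volume = position / (estimated_total_steps - cur_step)          # 100.0
    print(twap_volume * action)  # 100.0 shares scheduled for this step

With action == 1.0 this is plain TWAP over the remaining steps; the policy scales that baseline up or down.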