Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions qlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def init(default_conf="client", **kwargs):

if "flask_server" in C:
logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}")
logger.info("qlib successfully initialized based on %s settings." % default_conf)
logger.info(f"qlib successfully initialized based on {default_conf} settings.")
data_path = {_freq: C.dpm.get_data_uri(_freq) for _freq in C.dpm.provider_uri.keys()}
logger.info(f"data_path={data_path}")

Expand Down Expand Up @@ -119,10 +119,9 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
LOG.warning(f"{provider_uri} already mounted at {mount_path}")
elif e.returncode == 53:
raise OSError("Network path not found") from e
elif "error" in error_output.lower() or "错误" in error_output:
if "error" in error_output.lower() or "错误" in error_output:
raise OSError("Invalid mount path") from e
else:
raise OSError(f"Unknown mount error: {error_output.strip()}") from e
raise OSError(f"Unknown mount error: {error_output.strip()}") from e
else:
# system: linux/Unix/Mac
# check mount
Expand Down Expand Up @@ -177,10 +176,9 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
except subprocess.CalledProcessError as e:
if e.returncode == 256:
raise OSError("Mount failed: requires sudo or permission denied") from e
elif e.returncode == 32512:
if e.returncode == 32512:
raise OSError(f"mount {provider_uri} on {mount_path} error! Command error") from e
else:
raise OSError(f"Mount failed: {e.stderr}") from e
raise OSError(f"Mount failed: {e.stderr}") from e
else:
LOG.warning(f"{_remote_uri} on {_mount_path} is already mounted")

Expand Down Expand Up @@ -308,7 +306,7 @@ def auto_init(**kwargs):

# merge the arguments
qlib_conf_update = conf.get("qlib_cfg_update", {})
for k, v in kwargs.items():
for k, _ in kwargs.items():
if k in qlib_conf_update:
logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'")
qlib_conf_update.update(kwargs)
Expand Down
19 changes: 13 additions & 6 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ def get_exchange(
start_time: Union[pd.Timestamp, str] = None,
end_time: Union[pd.Timestamp, str] = None,
codes: Union[list, str] = "all",
subscribe_fields: list = [],
subscribe_fields: list = None,
open_cost: float = 0.0015,
close_cost: float = 0.0025,
min_cost: float = 5.0,
limit_threshold: Union[Tuple[str, str], float, None] | None = None,
deal_price: Union[str, Tuple[str, str], List[str]] | None = None,
**kwargs: Any,
) -> Exchange:
if subscribe_fields is None:
subscribe_fields = []
"""get_exchange

Parameters
Expand Down Expand Up @@ -106,8 +108,7 @@ def get_exchange(
**kwargs,
)
return exchange
else:
return init_instance_by_config(exchange, accept_types=Exchange)
return init_instance_by_config(exchange, accept_types=Exchange)


def create_account_instance(
Expand Down Expand Up @@ -181,12 +182,14 @@ def get_strategy_executor(
executor: Union[str, dict, object, Path],
benchmark: Optional[str] = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
exchange_kwargs: dict = None,
pos_type: str = "Position",
) -> Tuple[BaseStrategy, BaseExecutor]:
# NOTE:
# - for avoiding recursive import
# - typing annotations is not reliable
if exchange_kwargs is None:
exchange_kwargs = {}
from ..strategy.base import BaseStrategy # pylint: disable=C0415
from .executor import BaseExecutor # pylint: disable=C0415

Expand Down Expand Up @@ -221,9 +224,11 @@ def backtest(
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
exchange_kwargs: dict = None,
pos_type: str = "Position",
) -> Tuple[PORT_METRIC, INDICATOR_METRIC]:
if exchange_kwargs is None:
exchange_kwargs = {}
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
executor in the nested decision execution

Expand Down Expand Up @@ -283,10 +288,12 @@ def collect_data(
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
exchange_kwargs: dict = None,
pos_type: str = "Position",
return_value: dict | None = None,
) -> Generator[object, None, None]:
if exchange_kwargs is None:
exchange_kwargs = {}
"""initialize the strategy and executor, then collect the trade decision data for rl training

please refer to the docs of the backtest for the explanation of the parameters
Expand Down
45 changes: 32 additions & 13 deletions qlib/backtest/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,16 @@ class Account:
def __init__(
self,
init_cash: float = 1e9,
position_dict: dict = {},
position_dict: dict = None,
freq: str = "day",
benchmark_config: dict = {},
benchmark_config: dict = None,
pos_type: str = "Position",
port_metr_enabled: bool = True,
) -> None:
if position_dict is None:
position_dict = {}
if benchmark_config is None:
benchmark_config = {}
"""the trade account of backtest.

Parameters
Expand Down Expand Up @@ -306,11 +310,19 @@ def update_indicator(
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
trade_info: list = None,
inner_order_indicators: List[BaseOrderIndicator] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
indicator_config: dict = None,
) -> None:
if trade_info is None:
trade_info = []
if inner_order_indicators is None:
inner_order_indicators = []
if decision_list is None:
decision_list = []
if indicator_config is None:
indicator_config = {}
"""update trade indicators and order indicators in each bar end"""
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`

Expand Down Expand Up @@ -342,11 +354,19 @@ def update_bar_end(
trade_exchange: Exchange,
atomic: bool,
outer_trade_decision: BaseTradeDecision,
trade_info: list = [],
inner_order_indicators: List[BaseOrderIndicator] = [],
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = [],
indicator_config: dict = {},
trade_info: list = None,
inner_order_indicators: List[BaseOrderIndicator] = None,
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
indicator_config: dict = None,
) -> None:
if trade_info is None:
trade_info = []
if inner_order_indicators is None:
inner_order_indicators = []
if decision_list is None:
decision_list = []
if indicator_config is None:
indicator_config = {}
"""update account at each trading bar step

Parameters
Expand Down Expand Up @@ -379,7 +399,7 @@ def update_bar_end(
"""
if atomic is True and trade_info is None:
raise ValueError("trade_info is necessary in atomic executor")
elif atomic is False and inner_order_indicators is None:
if atomic is False and inner_order_indicators is None:
raise ValueError("inner_order_indicators is necessary in un-atomic executor")

# update current position and hold bar count in each bar end
Expand Down Expand Up @@ -409,8 +429,7 @@ def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
_positions = self.get_hist_positions()
return _portfolio_metrics, _positions
else:
raise ValueError("generate_portfolio_metrics should be True if you want to generate portfolio_metrics")
raise ValueError("generate_portfolio_metrics should be True if you want to generate portfolio_metrics")

def get_trade_indicator(self) -> Indicator:
"""get the trade indicator instance, which has pa/pos/ffr info."""
Expand Down
3 changes: 2 additions & 1 deletion qlib/backtest/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def collect_data_loop(
indicator_dict: INDICATOR_METRIC = {}

for executor in all_executors:
key = "{}{}".format(*Freq.parse(executor.time_per_step))
_freq_parsed = Freq.parse(executor.time_per_step)
key = f"{_freq_parsed[0]}{_freq_parsed[1]}"
if executor.trade_account.is_port_metr_enabled():
portfolio_dict[key] = executor.trade_account.get_portfolio_metrics()

Expand Down
38 changes: 16 additions & 22 deletions qlib/backtest/decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,21 @@ def sign(self) -> int:
def parse_dir(direction: Union[str, int, np.integer, OrderDir, np.ndarray]) -> Union[OrderDir, np.ndarray]:
if isinstance(direction, OrderDir):
return direction
elif isinstance(direction, (int, float, np.integer, np.floating)):
if isinstance(direction, (int, float, np.integer, np.floating)):
return Order.BUY if direction > 0 else Order.SELL
elif isinstance(direction, str):
if isinstance(direction, str):
dl = direction.lower().strip()
if dl == "sell":
return OrderDir.SELL
elif dl == "buy":
if dl == "buy":
return OrderDir.BUY
else:
raise NotImplementedError(f"This type of input is not supported")
raise NotImplementedError(f"This type of input is not supported")
elif isinstance(direction, np.ndarray):
direction_array = direction.copy()
direction_array[direction_array > 0] = Order.BUY
direction_array[direction_array <= 0] = Order.SELL
return direction_array
else:
raise NotImplementedError(f"This type of input is not supported")
raise NotImplementedError(f"This type of input is not supported")

@property
def key_by_day(self) -> tuple:
Expand Down Expand Up @@ -385,8 +383,7 @@ def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDeci
def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
if self.trade_range is not None:
return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")))
else:
raise NotImplementedError("The decision didn't provide an index range")
raise NotImplementedError("The decision didn't provide an index range")

def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
"""
Expand Down Expand Up @@ -432,9 +429,8 @@ def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]:
except NotImplementedError as e:
if "default_value" in kwargs:
return kwargs["default_value"]
else:
# Default to get full index
raise NotImplementedError(f"The decision didn't provide an index range") from e
# Default to get full index
raise NotImplementedError(f"The decision didn't provide an index range") from e

# clip index
if getattr(self, "total_step", None) is not None:
Expand Down Expand Up @@ -493,17 +489,15 @@ def get_data_cal_range_limit(self, rtype: str = "full", raise_error: bool = Fals
if self.trade_range is None:
if raise_error:
raise NotImplementedError(f"There is no trade_range in this case")
else:
return 0, day_end_idx - day_start_idx
return 0, day_end_idx - day_start_idx
if rtype == "full":
val_start, val_end = self.trade_range.clip_time_range(day_start, day_end)
elif rtype == "step":
val_start, val_end = self.trade_range.clip_time_range(self.start_time, self.end_time)
else:
if rtype == "full":
val_start, val_end = self.trade_range.clip_time_range(day_start, day_end)
elif rtype == "step":
val_start, val_end = self.trade_range.clip_time_range(self.start_time, self.end_time)
else:
raise ValueError(f"This type of input {rtype} is not supported")
_, _, start_idx, end_index = Cal.locate_index(val_start, val_end, freq=freq)
return start_idx - day_start_idx, end_index - day_start_idx
raise ValueError(f"This type of input {rtype} is not supported")
_, _, start_idx, end_index = Cal.locate_index(val_start, val_end, freq=freq)
return start_idx - day_start_idx, end_index - day_start_idx

def empty(self) -> bool:
for obj in self.get_decision():
Expand Down
Loading