From 24048e2b9896d94f38804418b94e6c10cbc1cad4 Mon Sep 17 00:00:00 2001 From: Luca Date: Sun, 22 Mar 2026 16:00:18 +0000 Subject: [PATCH 1/3] AI --- .github/copilot-instructions.md | 5 + dev/quantflow.dockerfile | 2 +- docs/api/options/vol_surface.md | 12 + docs/index.md | 74 ++++- pyproject.toml | 22 +- quantflow/__init__.py | 2 +- quantflow/ai/__init__.py | 1 + quantflow/ai/server.py | 27 ++ quantflow/{cli => ai/tools}/__init__.py | 0 quantflow/ai/tools/base.py | 44 +++ quantflow/ai/tools/charts.py | 40 +++ quantflow/ai/tools/crypto.py | 90 ++++++ quantflow/ai/tools/fred.py | 82 ++++++ quantflow/ai/tools/stocks.py | 104 +++++++ quantflow/ai/tools/vault.py | 35 +++ quantflow/cli/app.py | 100 ------- quantflow/cli/commands/__init__.py | 18 -- quantflow/cli/commands/base.py | 134 --------- quantflow/cli/commands/crypto.py | 149 ---------- quantflow/cli/commands/fred.py | 118 -------- quantflow/cli/commands/stocks.py | 134 --------- quantflow/cli/commands/vault.py | 52 ---- quantflow/cli/script.py | 14 - quantflow/cli/settings.py | 12 - quantflow/data/deribit.py | 13 +- quantflow/options/inputs.py | 87 ++++-- quantflow/options/surface.py | 364 ++++++++++++++---------- readme.md | 67 ++++- taplo.toml | 1 + uv.lock | 92 +----- 30 files changed, 879 insertions(+), 1016 deletions(-) create mode 100644 quantflow/ai/__init__.py create mode 100644 quantflow/ai/server.py rename quantflow/{cli => ai/tools}/__init__.py (100%) create mode 100644 quantflow/ai/tools/base.py create mode 100644 quantflow/ai/tools/charts.py create mode 100644 quantflow/ai/tools/crypto.py create mode 100644 quantflow/ai/tools/fred.py create mode 100644 quantflow/ai/tools/stocks.py create mode 100644 quantflow/ai/tools/vault.py delete mode 100644 quantflow/cli/app.py delete mode 100644 quantflow/cli/commands/__init__.py delete mode 100644 quantflow/cli/commands/base.py delete mode 100644 quantflow/cli/commands/crypto.py delete mode 100644 quantflow/cli/commands/fred.py delete mode 100644 quantflow/cli/commands/stocks.py delete mode 100644 quantflow/cli/commands/vault.py delete mode 100644 quantflow/cli/script.py delete mode 100644 quantflow/cli/settings.py diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 7ad16aab..0df6a55e 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -8,6 +8,11 @@ applyTo: '/**' # Quantflow Instructions +## Development + +* Always run `make lint` after code changes — runs taplo, isort, black, ruff, and mypy +* Never edit `readme.md` directly — it is generated from `docs/index.md` via `make docs` + ## Docker * The Dockerfile is at `dev/quantflow.dockerfile` diff --git a/dev/quantflow.dockerfile b/dev/quantflow.dockerfile index 9ca19034..f4bcd0af 100644 --- a/dev/quantflow.dockerfile +++ b/dev/quantflow.dockerfile @@ -8,7 +8,7 @@ WORKDIR /build COPY pyproject.toml uv.lock readme.md ./ # Install dependencies (no root package, with needed extras) -RUN uv sync --frozen --no-install-project --extra book --extra docs --extra data +RUN uv sync --frozen --no-install-project --extra ai --extra book --extra docs --extra data # Copy source and build docs COPY mkdocs.yml ./ diff --git a/docs/api/options/vol_surface.md b/docs/api/options/vol_surface.md index b291183c..32bbe1bb 100644 --- a/docs/api/options/vol_surface.md +++ b/docs/api/options/vol_surface.md @@ -9,4 +9,16 @@ ::: quantflow.options.surface.VolSurfaceLoader +::: quantflow.options.surface.OptionPrice + ::: quantflow.options.surface.OptionSelection + +::: quantflow.options.inputs.VolSurfaceInputs + +::: quantflow.options.inputs.VolSurfaceInput + +::: quantflow.options.inputs.SpotInput + +::: quantflow.options.inputs.ForwardInput + +::: quantflow.options.inputs.OptionInput diff --git a/docs/index.md b/docs/index.md index 21715530..ce864123 100644 --- a/docs/index.md +++ b/docs/index.md @@ -18,7 +18,7 @@ pip install quantflow ## Modules -* [quantflow.cli](https://github.com/quantmind/quantflow/tree/main/quantflow/cli) command line client (requires `quantflow[cli,data]`) +* [quantflow.ai](https://github.com/quantmind/quantflow/tree/main/quantflow/ai) MCP server for AI clients (requires `quantflow[ai,data]`) * [quantflow.data](https://github.com/quantmind/quantflow/tree/main/quantflow/data) data APIs (requires `quantflow[data]`) * [quantflow.options](https://github.com/quantmind/quantflow/tree/main/quantflow/options) option pricing and calibration * [quantflow.sp](https://github.com/quantmind/quantflow/tree/main/quantflow/sp) stochastic process primitives @@ -27,23 +27,69 @@ pip install quantflow ## Optional dependencies -Quantflow comes with two optional dependencies: +* `data` — data retrieval: `pip install quantflow[data]` +* `ai` — MCP server for AI clients: `pip install quantflow[ai,data]` -* `data` for data retrieval, to install it use - ``` - pip install quantflow[data] - ``` -* `cli` for command line interface, to install it use - ``` - pip install quantflow[data,cli] - ``` +## MCP Server -## Command line tools +Quantflow exposes its data tools as an [MCP](https://modelcontextprotocol.io) server, allowing AI clients such as Claude to query market data, crypto volatility surfaces, and economic indicators directly. -The command line tools are available when installing with the extra `cli` and `data` dependencies. +Install with the `ai` and `data` extras: ```bash -pip install quantflow[cli,data] +pip install quantflow[ai,data] ``` -It is possible to use the command line tool `qf` to download data and run pricing and calibration scripts. +### API keys + +Store your API keys in `~/.quantflow/.vault`: + +``` +fmp=your-fmp-key +fred=your-fred-key +``` + +Or let the AI manage them for you via the `vault_add` tool once connected. + +### Claude Code + +```bash +claude mcp add quantflow -- uv run qf-mcp +``` + +### Claude Desktop + +Add to your Claude Desktop config (`~/Library/Application Support/Claude/claude_desktop_config.json` on macOS): + +```json +{ + "mcpServers": { + "quantflow": { + "command": "uv", + "args": ["run", "qf-mcp"] + } + } +} +``` + +### Available tools + +| Tool | Description | +|---|---| +| `vault_keys` | List stored API keys | +| `vault_add` | Add or update an API key | +| `vault_delete` | Delete an API key | +| `stock_indices` | List stock market indices | +| `stock_search` | Search companies by name or symbol | +| `stock_profile` | Get company profile | +| `stock_prices` | Get OHLC price history | +| `sector_performance` | Sector performance and PE ratios | +| `crypto_instruments` | List Deribit instruments | +| `crypto_historical_volatility` | Historical volatility from Deribit | +| `crypto_term_structure` | Volatility term structure | +| `crypto_implied_volatility` | Implied volatility surface | +| `crypto_prices` | Crypto OHLC price history | +| `ascii_chart` | ASCII chart for any stock or crypto symbol | +| `fred_subcategories` | Browse FRED categories | +| `fred_series` | List series in a FRED category | +| `fred_data` | Fetch FRED observations | diff --git a/pyproject.toml b/pyproject.toml index 5ff8c826..20fd3b4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "quantflow" -version = "0.4.4" +version = "0.5.0" description = "quantitative analysis" authors = [ { name = "Luca Sbardella", email = "luca@quantmind.com" } ] license = "BSD-3-Clause" @@ -21,27 +21,22 @@ Repository = "https://github.com/quantmind/quantflow" Documentation = "https://quantmind.github.io/quantflow/" [project.optional-dependencies] +ai = [ + "google-genai>=1.61.0", + "mcp>=1.26.0", + "openai>=2.16.0", + "pydantic-ai-slim>=1.51.0", + "rich>=13.9.4", +] book = [ "altair>=6.0.0", "autodocsumm>=0.2.14", "duckdb>=1.4.4", "fastapi>=0.129.0", - "google-genai>=1.61.0", "marimo>=0.19.7", - "mcp>=1.26.0", - "openai>=2.16.0", "plotly>=6.2.0", - "pydantic-ai-slim>=1.51.0", "sympy>=1.12", ] -cli = [ - "asciichartpy>=1.5.25", - "async-cache>=1.1.1", - "click>=8.1.7", - "holidays>=0.63", - "prompt-toolkit>=3.0.43", - "rich>=13.9.4", -] data = [ "aio-fluid[http]>=1.2.1" ] dev = [ "black>=26.3.1", @@ -67,6 +62,7 @@ ml = [ [project.scripts] qf = "quantflow.cli.script:main" +qf-mcp = "quantflow.ai.server:main" [build-system] requires = [ "hatchling" ] diff --git a/quantflow/__init__.py b/quantflow/__init__.py index c12a161e..f1b1d162 100644 --- a/quantflow/__init__.py +++ b/quantflow/__init__.py @@ -1,3 +1,3 @@ """Quantitative analysis and pricing""" -__version__ = "0.4.4" +__version__ = "0.5.0" diff --git a/quantflow/ai/__init__.py b/quantflow/ai/__init__.py new file mode 100644 index 00000000..b633b2b2 --- /dev/null +++ b/quantflow/ai/__init__.py @@ -0,0 +1 @@ +"""AI module for quantflow - MCP server exposing quantflow data tools.""" diff --git a/quantflow/ai/server.py b/quantflow/ai/server.py new file mode 100644 index 00000000..c985392f --- /dev/null +++ b/quantflow/ai/server.py @@ -0,0 +1,27 @@ +"""Quantflow MCP server.""" + +from mcp.server.fastmcp import FastMCP + +from quantflow.ai.tools import charts, crypto, fred, stocks, vault + +from .tools.base import McpTool + + +def create_server() -> FastMCP: + mcp = FastMCP("quantflow") + tool = McpTool() + vault.register(mcp, tool) + crypto.register(mcp, tool) + stocks.register(mcp, tool) + fred.register(mcp, tool) + charts.register(mcp, tool) + return mcp + + +def main() -> None: + server = create_server() + server.run() + + +if __name__ == "__main__": + main() diff --git a/quantflow/cli/__init__.py b/quantflow/ai/tools/__init__.py similarity index 100% rename from quantflow/cli/__init__.py rename to quantflow/ai/tools/__init__.py diff --git a/quantflow/ai/tools/base.py b/quantflow/ai/tools/base.py new file mode 100644 index 00000000..ce53ed5f --- /dev/null +++ b/quantflow/ai/tools/base.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass, field +from io import StringIO +from pathlib import Path + +import pandas as pd +from ccy.cli.console import df_to_rich +from mcp.server.fastmcp.exceptions import ToolError +from rich.console import Console + +from quantflow.data.fmp import FMP +from quantflow.data.fred import Fred +from quantflow.data.vault import Vault + +VAULT_PATH = Path.home() / ".quantflow" / ".vault" + + +@dataclass +class McpTool: + vault: Vault = field(default_factory=lambda: Vault(VAULT_PATH)) + + def fmp(self) -> FMP: + key = self.vault.get("fmp") + if not key: + raise ToolError( + "FMP API key not found in vault. " + " Please add it using the vault_add tool." + ) + return FMP(key=key) + + def fred(self) -> Fred: + key = self.vault.get("fred") + if not key: + raise ToolError( + "FRED API key not found in vault. " + " Please add it using the vault_add tool." + ) + return Fred(key=key) + + def rich(self, df: pd.DataFrame) -> str: + table = df_to_rich(df) + buf = StringIO() + console = Console(file=buf, no_color=True) + console.print(table) + return buf.getvalue() diff --git a/quantflow/ai/tools/charts.py b/quantflow/ai/tools/charts.py new file mode 100644 index 00000000..b475228c --- /dev/null +++ b/quantflow/ai/tools/charts.py @@ -0,0 +1,40 @@ +"""Chart tools for the quantflow MCP server.""" + +from mcp.server.fastmcp import FastMCP + +from .base import McpTool + + +def register(mcp: FastMCP, tool: McpTool) -> None: + + @mcp.tool() + async def ascii_chart(symbol: str, frequency: str = "", height: int = 20) -> str: + """Plot an ASCII candlestick chart for a stock or cryptocurrency. + + Args: + symbol: Ticker symbol e.g. AAPL, BTCUSD, ETHUSD + frequency: Data frequency - 1min, 5min, 15min, 30min, 1hour, 4hour, + or empty for daily + height: Chart height in terminal rows (default: 20) + """ + import asciichartpy as ac + + async with tool.fmp() as client: + df = await client.prices(symbol, frequency=frequency) + if df.empty: + return f"No price data for {symbol}" + + df = df.sort_values("date").tail(50) + prices = df["close"].tolist() + first_date = df["date"].iloc[0] + last_date = df["date"].iloc[-1] + low = min(prices) + high = max(prices) + last = prices[-1] + + chart = ac.plot(prices, {"height": height, "format": "{:8,.0f}"}) + return ( + f"{symbol} Close Price ({first_date} → {last_date})\n" + f"High: {high:,.2f} Low: {low:,.2f} Last: {last:,.2f}\n\n" + f"{chart}" + ) diff --git a/quantflow/ai/tools/crypto.py b/quantflow/ai/tools/crypto.py new file mode 100644 index 00000000..2efea5ce --- /dev/null +++ b/quantflow/ai/tools/crypto.py @@ -0,0 +1,90 @@ +"""Crypto tools for the quantflow MCP server.""" + +from mcp.server.fastmcp import FastMCP + +from quantflow.data.deribit import Deribit, InstrumentKind + +from .base import McpTool + + +def register(mcp: FastMCP, tool: McpTool) -> None: + + @mcp.tool() + async def crypto_instruments(currency: str, kind: str = "spot") -> str: + """List available instruments for a cryptocurrency on Deribit. + + Args: + currency: Cryptocurrency symbol e.g. BTC, ETH + kind: Instrument kind - spot, future, option (default: spot) + """ + async with Deribit() as client: + data = await client.get_instruments( + currency=currency, kind=InstrumentKind(kind) + ) + if not data: + return f"No instruments found for {currency} ({kind})" + rows = "\n".join(str(d) for d in data[:20]) + return f"Instruments for {currency} ({kind}):\n{rows}" + + @mcp.tool() + async def crypto_historical_volatility(currency: str) -> str: + """Get historical volatility for a cryptocurrency from Deribit. + + Args: + currency: Cryptocurrency symbol e.g. BTC, ETH + """ + async with Deribit() as client: + df = await client.get_volatility(currency) + if df.empty: + return f"No volatility data for {currency}" + return f"Historical volatility for {currency}:\n{df.to_string(index=False)}" + + @mcp.tool() + async def crypto_term_structure(currency: str) -> str: + """Get the volatility term structure for a cryptocurrency from Deribit. + + Args: + currency: Cryptocurrency symbol e.g. BTC, ETH + """ + from quantflow.options.surface import VolSurface + + async with Deribit() as client: + loader = await client.volatility_surface_loader(currency) + vs: VolSurface = loader.surface() + ts = vs.term_structure().round({"ttm": 4}) + return f"Term structure for {currency}:\n{ts.to_string(index=False)}" + + @mcp.tool() + async def crypto_implied_volatility(currency: str, maturity_index: int = -1) -> str: + """Get the implied volatility surface for a cryptocurrency from Deribit. + + Args: + currency: Cryptocurrency symbol e.g. BTC, ETH + maturity_index: Maturity index (-1 for all maturities) + """ + from quantflow.options.surface import VolSurface + + async with Deribit() as client: + loader = await client.volatility_surface_loader(currency) + vs: VolSurface = loader.surface() + index = None if maturity_index < 0 else maturity_index + vs.bs(index=index) + df = vs.options_df(index=index) + df["implied_vol"] = df["implied_vol"].map("{:.2%}".format) + return f"Implied volatility for {currency}:\n{df.to_string(index=False)}" + + @mcp.tool() + async def crypto_prices(symbol: str, frequency: str = "") -> str: + """Get OHLC price history for a cryptocurrency via FMP. + + Args: + symbol: Cryptocurrency symbol e.g. BTCUSD + frequency: Data frequency - 1min, 5min, 15min, 30min, 1hour, 4hour, + or empty for daily + """ + async with tool.fmp() as client: + df = await client.prices(symbol, frequency=frequency) + if df.empty: + return f"No price data for {symbol}" + df = df[["date", "open", "high", "low", "close", "volume"]].sort_values("date") + return f"Prices for {symbol}:\n{df.tail(50).to_string(index=False)}" diff --git a/quantflow/ai/tools/fred.py b/quantflow/ai/tools/fred.py new file mode 100644 index 00000000..beb19cd5 --- /dev/null +++ b/quantflow/ai/tools/fred.py @@ -0,0 +1,82 @@ +"""FRED tools for the quantflow MCP server.""" + +from mcp.server.fastmcp import FastMCP + +from .base import McpTool + + +def register(mcp: FastMCP, tool: McpTool) -> None: + + @mcp.tool() + async def fred_subcategories(category_id: str | None = None) -> str: + """List FRED categories. Omit category_id to get top-level categories. + + Args: + category_id: FRED category ID (optional, defaults to root) + """ + from fluid.utils.data import compact_dict + + async with tool.fred() as client: + data = await client.subcategories( + params=compact_dict(category_id=category_id) + ) + cats = data.get("categories", []) + if not cats: + return "No categories found" + import pandas as pd + + df = pd.DataFrame(cats, columns=["id", "name"]) + return f"FRED categories:\n{df.to_string(index=False)}" + + @mcp.tool() + async def fred_series(category_id: str) -> str: + """List data series available in a FRED category. + + Args: + category_id: FRED category ID + """ + from fluid.utils.data import compact_dict + + async with tool.fred() as client: + data = await client.series(params=compact_dict(category_id=category_id)) + series = data.get("seriess", []) + if not series: + return f"No series found for category {category_id}" + import pandas as pd + + df = pd.DataFrame( + series, + columns=[ + "id", + "popularity", + "title", + "frequency", + "observation_start", + "observation_end", + ], + ).sort_values("popularity", ascending=False) + return f"FRED series for category {category_id}:\n{df.to_string(index=False)}" + + @mcp.tool() + async def fred_data( + series_id: str, + length: int = 100, + frequency: str = "d", + ) -> str: + """Fetch observations for a FRED data series. + + Args: + series_id: FRED series ID e.g. GDP, UNRATE, DGS10 + length: Number of data points to return (default: 100) + frequency: Frequency - d, w, bw, m, q, sa, a (default: d for daily) + """ + async with tool.fred() as client: + df = await client.serie_data( + params=dict( + series_id=series_id, + limit=length, + frequency=frequency, + sort_order="desc", + ) + ) + return f"FRED data for {series_id}:\n{df.to_string(index=False)}" diff --git a/quantflow/ai/tools/stocks.py b/quantflow/ai/tools/stocks.py new file mode 100644 index 00000000..2b033220 --- /dev/null +++ b/quantflow/ai/tools/stocks.py @@ -0,0 +1,104 @@ +"""Stocks tools for the quantflow MCP server.""" + +from datetime import timedelta + +import pandas as pd +from mcp.server.fastmcp import FastMCP + +from quantflow.utils.dates import utcnow + +from .base import McpTool + + +def register(mcp: FastMCP, tool: McpTool) -> None: + + @mcp.tool() + async def stock_indices() -> str: + """List available stock market indices.""" + async with tool.fmp() as client: + data = await client.indices() + return tool.rich(pd.DataFrame(data)) + + @mcp.tool() + async def stock_search(query: str) -> str: + """Search for stocks by company name or symbol. + + Args: + query: Company name or ticker symbol to search for + """ + async with tool.fmp() as client: + data = await client.search(query) + + df = pd.DataFrame(data, columns=["symbol", "name", "currency", "stockExchange"]) + return f"Search results for '{query}':\n{df.to_string(index=False)}" + + @mcp.tool() + async def stock_profile(symbol: str) -> str: + """Get company profile for a stock symbol. + + Args: + symbol: Stock ticker symbol e.g. AAPL, MSFT + """ + async with tool.fmp() as client: + data = await client.profile(symbol) + if not data: + return f"No profile found for {symbol}" + d = dict(data[0]) + description = d.pop("description", "") or "" + lines = "\n".join(f"{k}: {v}" for k, v in d.items()) + return f"{description}\n\n{lines}".strip() + + @mcp.tool() + async def stock_prices(symbol: str, frequency: str = "") -> str: + """Get OHLC price history for a stock. + + Args: + symbol: Stock ticker symbol e.g. AAPL, MSFT + frequency: Data frequency - 1min, 5min, 15min, 30min, 1hour, 4hour, + or empty for daily + """ + async with tool.fmp() as client: + df = await client.prices(symbol, frequency=frequency) + if df.empty: + return f"No price data for {symbol}" + df = df[["date", "open", "high", "low", "close", "volume"]].sort_values("date") + return f"Prices for {symbol}:\n{df.tail(50).to_string(index=False)}" + + @mcp.tool() + async def sector_performance(period: str = "1d") -> str: + """Get sector performance and PE ratios. + + Args: + period: Time period - 1d, 1w, 1m, 3m, 6m, 1y (default: 1d) + """ + from ccy import period as to_period + from ccy.tradingcentres import prevbizday + from fluid.utils.data import compact_dict + + async with tool.fmp() as client: + to_date = utcnow().date() + if period != "1d": + from_date = to_date - timedelta(days=to_period(period).totaldays) + sp = await client.sector_performance( + from_date=prevbizday(from_date, 0), + to_date=prevbizday(to_date, 0), + summary=True, + ) + else: + sp = await client.sector_performance() + pe = await client.sector_pe( + params=compact_dict(date=prevbizday(to_date, 0).isoformat()) + ) + + from typing import cast + + import pandas as pd + + spd = cast(dict, sp) + pes = {k["sector"]: round(float(k["pe"]), 3) for k in pe if k["sector"] in spd} + rows = [ + {"sector": k, "performance": float(v), "pe": pes.get(k, float("nan"))} + for k, v in spd.items() + ] + df = pd.DataFrame(rows).sort_values("performance", ascending=False) + return f"Sector performance ({period}):\n{df.to_string(index=False)}" diff --git a/quantflow/ai/tools/vault.py b/quantflow/ai/tools/vault.py new file mode 100644 index 00000000..bf359a1c --- /dev/null +++ b/quantflow/ai/tools/vault.py @@ -0,0 +1,35 @@ +"""Vault tools for the quantflow MCP server.""" + +from mcp.server.fastmcp import FastMCP + +from .base import McpTool + + +def register(mcp: FastMCP, tool: McpTool) -> None: + + @mcp.tool() + def vault_keys() -> list[str]: + """List all API keys stored in the vault.""" + return tool.vault.keys() + + @mcp.tool() + def vault_add(key: str, value: str) -> str: + """Add or update an API key in the vault. + + Args: + key: Key name e.g. fmp, fred + value: API key value + """ + tool.vault.add(key, value) + return f"Key '{key}' saved to vault" + + @mcp.tool() + def vault_delete(key: str) -> str: + """Delete an API key from the vault. + + Args: + key: Key name to delete + """ + if tool.vault.delete(key): + return f"Key '{key}' deleted from vault" + return f"Key '{key}' not found in vault" diff --git a/quantflow/cli/app.py b/quantflow/cli/app.py deleted file mode 100644 index 73d541a0..00000000 --- a/quantflow/cli/app.py +++ /dev/null @@ -1,100 +0,0 @@ -import os -from dataclasses import dataclass, field -from functools import partial -from typing import Any - -import click -from fluid.utils.http_client import HttpResponseError -from prompt_toolkit import PromptSession -from prompt_toolkit.completion import NestedCompleter -from prompt_toolkit.formatted_text import HTML -from prompt_toolkit.history import FileHistory -from rich.console import Console -from rich.text import Text - -from quantflow.data.vault import Vault - -from . import settings -from .commands import quantflow -from .commands.base import QuantGroup - - -@dataclass -class QfApp: - console: Console = field(default_factory=Console) - vault: Vault = field(default_factory=partial(Vault, settings.VAULT_FILE_PATH)) - sections: list[QuantGroup] = field(default_factory=lambda: [quantflow]) - - def __call__(self) -> None: - os.makedirs(settings.SETTINGS_DIRECTORY, exist_ok=True) - history = FileHistory(str(settings.HIST_FILE_PATH)) - session: PromptSession = PromptSession(history=history) - - self.print("Welcome to QuantFlow!", style="bold green") - self.handle_command("help") - - try: - while True: - try: - text = session.prompt( - self.prompt_message(), - completer=self.prompt_completer(), - complete_while_typing=True, - bottom_toolbar=self.bottom_toolbar, - ) - except KeyboardInterrupt: - break - else: - self.handle_command(text) - except click.Abort: - self.console.print(Text("Bye!", style="bold magenta")) - - def prompt_message(self) -> str: - name = ":".join([str(section.name) for section in self.sections]) - return f"{name} > " - - def prompt_completer(self) -> NestedCompleter: - return NestedCompleter.from_nested_dict( - {command: None for command in self.sections[-1].commands} - ) - - def set_section(self, section: QuantGroup) -> None: - self.sections.append(section) - - def back(self) -> None: - self.sections.pop() - - def print(self, text_alike: Any, style: str = "") -> None: - if isinstance(text_alike, str): - style = style or "cyan" - text_alike = Text(f"\n{text_alike}\n", style="cyan") - self.console.print(text_alike) - - def error(self, err: str | Exception) -> None: - self.console.print(Text(f"\n{err}\n", style="bold red")) - - def handle_command(self, text: str) -> None: - if not text: - return - command = self.sections[-1] - try: - command.main(text.split(), standalone_mode=False, obj=self) - except ( - click.exceptions.MissingParameter, - click.exceptions.NoSuchOption, - click.exceptions.UsageError, - HttpResponseError, - ) as e: - self.error(e) - - def bottom_toolbar(self) -> HTML: - sections = "/".join([str(section.name) for section in self.sections]) - back = ( - (' ' "to exit the current section,") - if len(self.sections) > 1 - else "" - ) - return HTML( - f"Your are in {sections}, type{back} " - ' to exit' - ) diff --git a/quantflow/cli/commands/__init__.py b/quantflow/cli/commands/__init__.py deleted file mode 100644 index 30c0a55a..00000000 --- a/quantflow/cli/commands/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from .base import QuantContext, quant_group -from .crypto import crypto -from .fred import fred -from .stocks import stocks -from .vault import vault - - -@quant_group() -def quantflow() -> None: - ctx = QuantContext.current() - if ctx.invoked_subcommand is None: - ctx.qf.print(ctx.get_help()) - - -quantflow.add_command(vault) -quantflow.add_command(crypto) -quantflow.add_command(stocks) -quantflow.add_command(fred) diff --git a/quantflow/cli/commands/base.py b/quantflow/cli/commands/base.py deleted file mode 100644 index 17d09f8e..00000000 --- a/quantflow/cli/commands/base.py +++ /dev/null @@ -1,134 +0,0 @@ -from __future__ import annotations - -import enum -from typing import TYPE_CHECKING, Any, Self, cast - -import click - -from quantflow.data.fmp import FMP -from quantflow.data.fred import Fred - -if TYPE_CHECKING: - from quantflow.cli.app import QfApp - - -FREQUENCIES = tuple(FMP().historical_frequencies()) - - -class HistoricalPeriod(enum.StrEnum): - day = "1d" - week = "1w" - month = "1m" - three_months = "3m" - six_months = "6m" - year = "1y" - - -class QuantContext(click.Context): - - @classmethod - def current(cls) -> Self: - return cast(Self, click.get_current_context()) - - @property - def qf(self) -> QfApp: - return self.obj # type: ignore - - def set_as_section(self) -> None: - group = cast(QuantGroup, self.command) - group.add_command(back) - self.qf.set_section(group) - self.qf.print(self.get_help()) - - def fmp(self) -> FMP: - if key := self.qf.vault.get("fmp"): - return FMP(key=key) - else: - raise click.UsageError("No FMP API key found") - - def fred(self) -> Fred: - if key := self.qf.vault.get("fred"): - return Fred(key=key) - else: - raise click.UsageError("No FRED API key found") - - -class QuantCommand(click.Command): - context_class = QuantContext - - -class QuantGroup(click.Group): - context_class = QuantContext - command_class = QuantCommand - - -@click.command(cls=QuantCommand) -def exit() -> None: - """Exit the program""" - raise click.Abort() - - -@click.command(cls=QuantCommand) -def help() -> None: - """display the commands""" - if ctx := QuantContext.current().parent: - cast(QuantContext, ctx).qf.print(ctx.get_help()) - - -@click.command(cls=QuantCommand) -def back() -> None: - """Exit the current section""" - ctx = QuantContext.current() - ctx.qf.back() - ctx.qf.handle_command("help") - - -def quant_group() -> Any: - return click.group( - cls=QuantGroup, - commands=[exit, help], - invoke_without_command=True, - add_help_option=False, - ) - - -class options: - length = click.option( - "-l", - "--length", - type=int, - default=100, - show_default=True, - help="Number of data points", - ) - height = click.option( - "-h", - "--height", - type=int, - default=20, - show_default=True, - help="Chart height", - ) - chart = click.option("-c", "--chart", is_flag=True, help="Display chart") - period = click.option( - "-p", - "--period", - type=click.Choice(tuple(p.value for p in HistoricalPeriod)), - default="1d", - show_default=True, - help="Historical period", - ) - index = click.option( - "-i", - "--index", - type=int, - default=-1, - help="maturity index", - ) - frequency = click.option( - "-f", - "--frequency", - type=click.Choice(FREQUENCIES), - default="", - help="Frequency of data - if not provided it is daily", - ) diff --git a/quantflow/cli/commands/crypto.py b/quantflow/cli/commands/crypto.py deleted file mode 100644 index 33e35dee..00000000 --- a/quantflow/cli/commands/crypto.py +++ /dev/null @@ -1,149 +0,0 @@ -from __future__ import annotations - -import asyncio - -import click -import pandas as pd -from asciichartpy import plot -from cache import AsyncTTL -from ccy.cli.console import df_to_rich - -from quantflow.data.deribit import Deribit, InstrumentKind -from quantflow.options.surface import VolSurface -from quantflow.utils.numbers import round_to_step - -from .base import QuantContext, options, quant_group -from .stocks import get_prices - - -@quant_group() -def crypto() -> None: - """Crypto currencies commands""" - ctx = QuantContext.current() - if ctx.invoked_subcommand is None: - ctx.set_as_section() - - -@crypto.command() -@click.argument("currency") -@click.option( - "-k", - "--kind", - type=click.Choice(list(InstrumentKind)), - default=InstrumentKind.spot.value, -) -def instruments(currency: str, kind: str) -> None: - """Provides information about instruments - - Instruments for given cryptocurrency from Deribit API""" - ctx = QuantContext.current() - data = asyncio.run(get_instruments(ctx, currency, kind)) - df = pd.DataFrame(data) - ctx.qf.print(df_to_rich(df)) - - -@crypto.command() -@click.argument("currency") -@options.length -@options.height -@options.chart -def volatility(currency: str, length: int, height: int, chart: bool) -> None: - """Provides information about historical volatility - - Historical volatility for given cryptocurrency from Deribit API - """ - ctx = QuantContext.current() - df = asyncio.run(get_volatility(ctx, currency)) - df["volatility"] = df["volatility"].map(lambda p: round_to_step(p, "0.01")) - if chart: - data = df["volatility"].tolist()[:length] - ctx.qf.print(plot(data, {"height": height})) - else: - ctx.qf.print(df_to_rich(df)) - - -@crypto.command() -@click.argument("currency") -def term_structure(currency: str) -> None: - """Provides information about the term structure for given cryptocurrency""" - ctx = QuantContext.current() - vs = asyncio.run(get_vol_surface(currency)) - ts = vs.term_structure().round({"ttm": 4}) - ts["open_interest"] = ts["open_interest"].map("{:,d}".format) - ts["volume"] = ts["volume"].map("{:,d}".format) - ctx.qf.print(df_to_rich(ts)) - - -@crypto.command() -@click.argument("currency") -@options.index -@options.height -@options.chart -def implied_vol(currency: str, index: int, height: int, chart: bool) -> None: - """Display the Volatility Surface for given cryptocurrency - at a given maturity index - """ - ctx = QuantContext.current() - vs = asyncio.run(get_vol_surface(currency)) - index_or_none = None if index < 0 else index - vs.bs(index=index_or_none) - df = vs.options_df(index=index_or_none) - if chart: - data = (df["implied_vol"] * 100).tolist() - ctx.qf.print(plot(data, {"height": height})) - else: - df[["ttm", "moneyness", "moneyness_ttm"]] = df[ - ["ttm", "moneyness", "moneyness_ttm"] - ].map("{:.4f}".format) - df["implied_vol"] = df["implied_vol"].map("{:.2%}".format) - df["price"] = df["price"].map(lambda p: round_to_step(p, vs.tick_size_options)) - df["forward_price"] = df["forward_price"].map( - lambda p: round_to_step(p, vs.tick_size_forwards) - ) - ctx.qf.print(df_to_rich(df)) - - -@crypto.command() -@click.argument("symbol") -@options.height -@options.length -@options.chart -@options.frequency -def prices(symbol: str, height: int, length: int, chart: bool, frequency: str) -> None: - """Fetch OHLC prices for given cryptocurrency""" - ctx = QuantContext.current() - df = asyncio.run(get_prices(ctx, symbol, frequency)) - if df.empty: - raise click.UsageError( - f"No data for {symbol} - are you sure the symbol exists?" - ) - if chart: - data = list(reversed(df["close"].tolist()[:length])) - ctx.qf.print(plot(data, {"height": height})) - else: - ctx.qf.print( - df_to_rich( - df[["date", "open", "high", "low", "close", "volume"]].sort_values( - "date" - ) - ) - ) - - -async def get_instruments(ctx: QuantContext, currency: str, kind: str) -> list[dict]: - async with Deribit() as client: - return await client.get_instruments( - currency=currency, kind=InstrumentKind(kind) - ) - - -async def get_volatility(ctx: QuantContext, currency: str) -> pd.DataFrame: - async with Deribit() as client: - return await client.get_volatility(currency) - - -@AsyncTTL(time_to_live=10) -async def get_vol_surface(currency: str) -> VolSurface: - async with Deribit() as client: - loader = await client.volatility_surface_loader(currency) - return loader.surface() diff --git a/quantflow/cli/commands/fred.py b/quantflow/cli/commands/fred.py deleted file mode 100644 index 7a130ae6..00000000 --- a/quantflow/cli/commands/fred.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -import asyncio - -import click -import pandas as pd -from asciichartpy import plot -from ccy.cli.console import df_to_rich -from fluid.utils.data import compact_dict -from fluid.utils.http_client import HttpResponseError - -from quantflow.data.fred import Fred - -from .base import QuantContext, options, quant_group - -FREQUENCIES = tuple(Fred.freq) - - -@quant_group() -def fred() -> None: - """Federal Reserve of St. Louis data commands""" - ctx = QuantContext.current() - if ctx.invoked_subcommand is None: - ctx.set_as_section() - - -@fred.command() -@click.argument("category-id", required=False) -def subcategories(category_id: str | None = None) -> None: - """List subcategories for a Fred category""" - ctx = QuantContext.current() - try: - data = asyncio.run(get_subcategories(ctx, category_id)) - except HttpResponseError as e: - ctx.qf.error(e) - else: - df = pd.DataFrame(data["categories"], columns=["id", "name"]) - ctx.qf.print(df_to_rich(df)) - - -@fred.command() -@click.argument("category-id") -@click.option("-j", "--json", is_flag=True, help="Output as JSON") -def series(category_id: str, json: bool = False) -> None: - """List series for a Fred category""" - ctx = QuantContext.current() - try: - data = asyncio.run(get_series(ctx, category_id)) - except HttpResponseError as e: - ctx.qf.error(e) - else: - if json: - ctx.qf.print(data) - else: - df = pd.DataFrame( - data["seriess"], - columns=[ - "id", - "popularity", - "title", - "frequency", - "observation_start", - "observation_end", - ], - ).sort_values("popularity", ascending=False) - ctx.qf.print(df_to_rich(df)) - - -@fred.command() -@click.argument("series-id") -@options.length -@options.height -@options.chart -@click.option( - "-f", - "--frequency", - type=click.Choice(FREQUENCIES), - default="d", - show_default=True, - help="Frequency of data", -) -def data(series_id: str, length: int, height: int, chart: bool, frequency: str) -> None: - """Display a series data""" - ctx = QuantContext.current() - try: - df = asyncio.run(get_serie_data(ctx, series_id, length, frequency)) - except HttpResponseError as e: - ctx.qf.error(e) - else: - if chart: - data = list(reversed(df["value"].tolist()[:length])) - ctx.qf.print(plot(data, {"height": height})) - else: - ctx.qf.print(df_to_rich(df)) - - -async def get_subcategories(ctx: QuantContext, category_id: str | None) -> dict: - async with ctx.fred() as cli: - return await cli.subcategories(params=compact_dict(category_id=category_id)) - - -async def get_series(ctx: QuantContext, category_id: str) -> dict: - async with ctx.fred() as cli: - return await cli.series(params=compact_dict(category_id=category_id)) - - -async def get_serie_data( - ctx: QuantContext, series_id: str, length: int, frequency: str -) -> dict: - async with ctx.fred() as cli: - return await cli.serie_data( - params=dict( - series_id=series_id, - limit=length, - frequency=frequency, - sort_order="desc", - ) - ) diff --git a/quantflow/cli/commands/stocks.py b/quantflow/cli/commands/stocks.py deleted file mode 100644 index 3c62d4ce..00000000 --- a/quantflow/cli/commands/stocks.py +++ /dev/null @@ -1,134 +0,0 @@ -from __future__ import annotations - -import asyncio -from datetime import timedelta -from typing import cast - -import click -import pandas as pd -from asciichartpy import plot -from ccy import period as to_period -from ccy.cli.console import df_to_rich -from ccy.tradingcentres import prevbizday - -from quantflow.utils.dates import utcnow - -from .base import HistoricalPeriod, QuantContext, options, quant_group - - -@quant_group() -def stocks() -> None: - """Stocks commands""" - ctx = QuantContext.current() - if ctx.invoked_subcommand is None: - ctx.set_as_section() - - -@stocks.command() -def indices() -> None: - """Search companies""" - ctx = QuantContext.current() - data = asyncio.run(get_indices(ctx)) - df = pd.DataFrame(data) - ctx.qf.print(df_to_rich(df)) - - -@stocks.command() -@click.argument("symbol") -def profile(symbol: str) -> None: - """Company profile""" - ctx = QuantContext.current() - data = asyncio.run(get_profile(ctx, symbol)) - if not data: - raise click.UsageError(f"Company {symbol} not found - try searching") - else: - d = data[0] - ctx.qf.print(d.pop("description") or "") - df = pd.DataFrame(d.items(), columns=["Key", "Value"]) - ctx.qf.print(df_to_rich(df)) - - -@stocks.command() -@click.argument("text") -def search(text: str) -> None: - """Search companies""" - ctx = QuantContext.current() - data = asyncio.run(search_company(ctx, text)) - df = pd.DataFrame(data, columns=["symbol", "name", "currency", "stockExchange"]) - ctx.qf.print(df_to_rich(df)) - - -@stocks.command() -@click.argument("symbol") -@options.height -@options.length -@options.frequency -def chart(symbol: str, height: int, length: int, frequency: str) -> None: - """Symbol chart""" - ctx = QuantContext.current() - df = asyncio.run(get_prices(ctx, symbol, frequency)) - if df.empty: - raise click.UsageError( - f"No data for {symbol} - are you sure the symbol exists?" - ) - data = list(reversed(df["close"].tolist()[:length])) - print(plot(data, {"height": height})) - - -@stocks.command() -@options.period -def sectors(period: str) -> None: - """Sectors performance and PE ratios""" - ctx = QuantContext.current() - data = asyncio.run(sector_performance(ctx, HistoricalPeriod(period))) - df = pd.DataFrame(data, columns=["sector", "performance", "pe"]).sort_values( - "performance", ascending=False - ) - ctx.qf.print(df_to_rich(df)) - - -async def get_indices(ctx: QuantContext) -> list[dict]: - async with ctx.fmp() as cli: - return await cli.indices() - - -async def get_prices(ctx: QuantContext, symbol: str, frequency: str) -> pd.DataFrame: - async with ctx.fmp() as cli: - return await cli.prices(symbol, frequency=frequency) - - -async def get_profile(ctx: QuantContext, symbol: str) -> list[dict]: - async with ctx.fmp() as cli: - return await cli.profile(symbol) - - -async def search_company(ctx: QuantContext, text: str) -> list[dict]: - async with ctx.fmp() as cli: - return await cli.search(text) - - -async def sector_performance( - ctx: QuantContext, period: HistoricalPeriod -) -> dict | list[dict]: - async with ctx.fmp() as cli: - to_date = utcnow().date() - if period != HistoricalPeriod.day: - from_date = to_date - timedelta(days=to_period(period.value).totaldays) - sp = await cli.sector_performance( - from_date=prevbizday(from_date, 0).isoformat(), # type: ignore - to_date=prevbizday(to_date, 0).isoformat(), # type: ignore - summary=True, - ) - else: - sp = await cli.sector_performance() - spd = cast(dict, sp) - pe = await cli.sector_pe(params=dict(date=prevbizday(to_date, 0).isoformat())) # type: ignore - pes = {} - for k in pe: - sector = k["sector"] - if sector in spd: - pes[sector] = round(float(k["pe"]), 3) - return [ - dict(sector=k, performance=float(v), pe=pes.get(k, float("nan"))) - for k, v in spd.items() - ] diff --git a/quantflow/cli/commands/vault.py b/quantflow/cli/commands/vault.py deleted file mode 100644 index b0644263..00000000 --- a/quantflow/cli/commands/vault.py +++ /dev/null @@ -1,52 +0,0 @@ -import click - -from .base import QuantContext, quant_group - -API_KEYS = ("fmp", "fred") - - -@quant_group() -def vault() -> None: - """Manage vault secrets""" - ctx = QuantContext.current() - if ctx.invoked_subcommand is None: - ctx.set_as_section() - - -@vault.command() -@click.argument("key", type=click.Choice(API_KEYS)) -@click.argument("value") -def add(key: str, value: str) -> None: - """Add an API key to the vault""" - app = QuantContext.current().qf - app.vault.add(key, value) - - -@vault.command() -@click.argument("key") -def delete(key: str) -> None: - """Delete an API key from the vault""" - app = QuantContext.current().qf - if app.vault.delete(key): - app.print(f"Deleted key {key}") - else: - app.error(f"Key {key} not found") - - -@vault.command() -@click.argument("key") -def show(key: str) -> None: - """Show the value of an API key""" - app = QuantContext.current().qf - if value := app.vault.get(key): - app.print(value) - else: - app.error(f"Key {key} not found") - - -@vault.command() -def keys() -> None: - """Show the keys in the vault""" - app = QuantContext.current().qf - for key in app.vault.keys(): - app.print(key) diff --git a/quantflow/cli/script.py b/quantflow/cli/script.py deleted file mode 100644 index c88aa1fe..00000000 --- a/quantflow/cli/script.py +++ /dev/null @@ -1,14 +0,0 @@ -import dotenv - -dotenv.load_dotenv() - -try: - from .app import QfApp -except ImportError as ex: - raise ImportError( - "Cannot run qf command line, " - "quantflow needs to be installed with cli & data extras, " - "pip install quantflow[cli, data]" - ) from ex - -main = QfApp() diff --git a/quantflow/cli/settings.py b/quantflow/cli/settings.py deleted file mode 100644 index 68bffa8b..00000000 --- a/quantflow/cli/settings.py +++ /dev/null @@ -1,12 +0,0 @@ -# IMPORTATION STANDARD -from pathlib import Path - -# Installation related paths -HOME_DIRECTORY = Path.home() -PACKAGE_DIRECTORY = Path(__file__).parent.parent.parent -REPOSITORY_DIRECTORY = PACKAGE_DIRECTORY.parent - -SETTINGS_DIRECTORY = HOME_DIRECTORY / ".quantflow" -SETTINGS_ENV_FILE = SETTINGS_DIRECTORY / ".env" -HIST_FILE_PATH = SETTINGS_DIRECTORY / ".quantflow.his" -VAULT_FILE_PATH = SETTINGS_DIRECTORY / ".vault" diff --git a/quantflow/data/deribit.py b/quantflow/data/deribit.py index 10f6f600..93df60e8 100644 --- a/quantflow/data/deribit.py +++ b/quantflow/data/deribit.py @@ -12,8 +12,8 @@ from fluid.utils.http_client import AioHttpClient, HttpResponse, HttpResponseError from typing_extensions import Annotated, Doc -from quantflow.options.inputs import OptionType -from quantflow.options.surface import VolSecurityType, VolSurfaceLoader +from quantflow.options.inputs import DefaultVolSecurity, OptionType +from quantflow.options.surface import VolSurfaceLoader from quantflow.utils.numbers import ( Number, round_to_step, @@ -112,7 +112,8 @@ async def volatility_surface_loader( Number | None, Doc("Exclude options with volume below this threshold") ] = None, ) -> VolSurfaceLoader: - """Create a :class:`.VolSurfaceLoader` for a given crypto-currency""" + """Create a [VolSurfaceLoader][quantflow.options.surface.VolSurfaceLoader] + for a given crypto-currency""" loader = VolSurfaceLoader( asset=currency, exclude_open_interest=to_decimal_or_none(exclude_open_interest), @@ -137,7 +138,7 @@ async def volatility_surface_loader( ask = round_to_step(ask_, tick_size) if meta["settlement_period"] == "perpetual": loader.add_spot( - VolSecurityType.spot, + DefaultVolSecurity.spot(), bid=bid, ask=ask, open_interest=to_decimal(entry["open_interest"]), @@ -150,7 +151,7 @@ async def volatility_surface_loader( utc=True, ).to_pydatetime() loader.add_forward( - VolSecurityType.forward, + DefaultVolSecurity.forward(), maturity=maturity, bid=bid, ask=ask, @@ -167,7 +168,7 @@ async def volatility_surface_loader( tick_size = to_decimal(meta["tick_size"]) min_tick_size = min(min_tick_size, tick_size) loader.add_option( - VolSecurityType.option, + DefaultVolSecurity.option(), strike=round_to_step(meta["strike"], tick_size), maturity=pd.to_datetime( meta["expiration_timestamp"], diff --git a/quantflow/options/inputs.py b/quantflow/options/inputs.py index da2461d8..f60b176b 100644 --- a/quantflow/options/inputs.py +++ b/quantflow/options/inputs.py @@ -3,9 +3,9 @@ import enum from datetime import datetime from decimal import Decimal -from typing import TypeVar +from typing import Self, TypeVar -from pydantic import BaseModel +from pydantic import BaseModel, Field from quantflow.utils.numbers import ZERO @@ -39,34 +39,87 @@ class VolSecurityType(enum.StrEnum): forward = enum.auto() option = enum.auto() + +class VolSurfaceSecurity(BaseModel): + def vol_surface_type(self) -> VolSecurityType: + raise NotImplementedError("vol_surface_type must be implemented by subclasses") + + +class DefaultVolSecurity(VolSurfaceSecurity): + security_type: VolSecurityType = Field( + default=VolSecurityType.spot, + description="Type of security for the volatility surface", + ) + def vol_surface_type(self) -> VolSecurityType: - return self + return self.security_type + + @classmethod + def spot(cls) -> Self: + return cls(security_type=VolSecurityType.spot) + + @classmethod + def forward(cls) -> Self: + return cls(security_type=VolSecurityType.forward) + + @classmethod + def option(cls) -> Self: + return cls(security_type=VolSecurityType.option) class VolSurfaceInput(BaseModel): - bid: Decimal - ask: Decimal - open_interest: Decimal = ZERO - volume: Decimal = ZERO + """Base class for volatility surface inputs""" + + bid: Decimal = Field(description="Bid price of the security") + ask: Decimal = Field(description="Ask price of the security") + open_interest: Decimal = Field( + default=ZERO, description="Open interest of the security" + ) + volume: Decimal = Field(default=ZERO, description="Volume of the security") class SpotInput(VolSurfaceInput): - security_type: VolSecurityType = VolSecurityType.spot + """Input data for a spot contract in the volatility surface""" + + security_type: VolSecurityType = Field( + default=VolSecurityType.spot, + description="Type of security for the volatility surface", + ) class ForwardInput(VolSurfaceInput): - maturity: datetime - security_type: VolSecurityType = VolSecurityType.forward + """Input data for a forward contract in the volatility surface""" + + maturity: datetime = Field(description="Expiry date of the forward contract") + security_type: VolSecurityType = Field( + default=VolSecurityType.forward, + description="Type of security for the volatility surface", + ) class OptionInput(VolSurfaceInput): - strike: Decimal - maturity: datetime - option_type: OptionType - security_type: VolSecurityType = VolSecurityType.option + """Input data for an option in the volatility surface""" + + strike: Decimal = Field(description="Strike price of the option") + maturity: datetime = Field(description="Expiry date of the option") + option_type: OptionType = Field(description="Type of the option - call or put") + security_type: VolSecurityType = Field( + default=VolSecurityType.option, + description="Type of security for the volatility surface", + ) + iv_bid: Decimal | None = Field( + default=None, description="Implied volatility based on the bid price" + ) + iv_ask: Decimal | None = Field( + default=None, description="Implied volatility based on the ask price" + ) class VolSurfaceInputs(BaseModel): - asset: str - ref_date: datetime - inputs: list[ForwardInput | SpotInput | OptionInput] + """Class representing the inputs for a volatility surface""" + + asset: str = Field(description="Underlying asset of the volatility surface") + ref_date: datetime = Field(description="Reference date for the volatility surface") + inputs: list[ForwardInput | SpotInput | OptionInput] = Field( + description="List of inputs for the volatility surface" + ) diff --git a/quantflow/options/surface.py b/quantflow/options/surface.py index f8e19fa8..ccf45225 100644 --- a/quantflow/options/surface.py +++ b/quantflow/options/surface.py @@ -2,15 +2,15 @@ import enum import warnings -from dataclasses import dataclass, field, replace from datetime import datetime, timedelta from decimal import Decimal -from typing import Any, Generic, Iterator, NamedTuple, Protocol, Self, TypeVar +from typing import Any, Generic, Iterator, NamedTuple, Self, TypeVar import numpy as np import pandas as pd from ccy.core.daycounter import ActAct, DayCounter -from pydantic import BaseModel +from pydantic import BaseModel, Field +from typing_extensions import Annotated, Doc from quantflow.utils import plot from quantflow.utils.dates import utcnow @@ -19,6 +19,7 @@ from .bs import black_price, implied_black_volatility from .inputs import ( + DefaultVolSecurity, ForwardInput, OptionInput, OptionType, @@ -27,16 +28,13 @@ VolSecurityType, VolSurfaceInput, VolSurfaceInputs, + VolSurfaceSecurity, ) INITIAL_VOL = 0.5 default_day_counter = ActAct() -class VolSurfaceSecurity(Protocol): - def vol_surface_type(self) -> VolSecurityType: ... - - S = TypeVar("S", bound=VolSurfaceSecurity) @@ -58,21 +56,21 @@ class OptionSelection(enum.Enum): """Select the put options only""" -@dataclass -class Price(Generic[S]): - security: S - bid: Decimal - ask: Decimal +class Price(BaseModel, Generic[S]): + security: S = Field(description="The underlying security of the price") + bid: Decimal = Field(description="Bid price") + ask: Decimal = Field(description="Ask price") @property def mid(self) -> Decimal: return (self.bid + self.ask) / 2 -@dataclass class SpotPrice(Price[S]): - open_interest: Decimal = ZERO - volume: Decimal = ZERO + open_interest: Decimal = Field( + default=ZERO, description="Open interest of the spot price" + ) + volume: Decimal = Field(default=ZERO, description="Volume of the spot price") def inputs(self) -> SpotInput: return SpotInput( @@ -83,11 +81,12 @@ def inputs(self) -> SpotInput: ) -@dataclass class FwdPrice(Price[S]): - maturity: datetime - open_interest: Decimal = ZERO - volume: Decimal = ZERO + maturity: datetime = Field(description="Maturity date of the forward price") + open_interest: Decimal = Field( + default=ZERO, description="Open interest of the forward price" + ) + volume: Decimal = Field(default=ZERO, description="Volume of the forward price") def inputs(self) -> ForwardInput: return ForwardInput( @@ -100,33 +99,36 @@ def inputs(self) -> ForwardInput: class OptionMetadata(BaseModel): - strike: Decimal - """Strike price""" - option_type: OptionType - """Type of the option""" - maturity: datetime - """Maturity date""" - forward: Decimal = ZERO - """Forward price of the underlying""" - ttm: float = 0 - """Time to maturity in years""" - open_interest: Decimal = ZERO - """Open interest of the option""" - volume: Decimal = ZERO - """Volume of the option in USD""" + strike: Decimal = Field(description="Strike price of the option") + option_type: OptionType = Field(description="Type of the option, call or put") + maturity: datetime = Field(description="Maturity date of the option") + forward: Decimal = Field( + default=ZERO, description="Forward price of the underlying" + ) + ttm: float = Field(default=0, description="Time to maturity in years") + open_interest: Decimal = Field( + default=ZERO, description="Open interest of the option" + ) + volume: Decimal = Field(default=ZERO, description="Volume of the option") class OptionPrice(BaseModel): - price: Decimal - """Price of the option divided by the forward price""" - meta: OptionMetadata - """Metadata of the option price""" - implied_vol: float = 0 - """Implied Black volatility""" - side: Side = Side.bid - """Side of the market""" - converged: bool = True - """Flag indicating if implied vol calculation converged""" + """Represents the price of an option quoted in the market along with + its metadata and implied volatility information.""" + + price: Decimal = Field( + description="Price of the option as a percentage of the forward price" + ) + meta: OptionMetadata = Field(description="Metadata of the option price") + implied_vol: float = Field( + default=0, description="Implied volatility of the option" + ) + side: Side = Field( + default=Side.bid, description="Side of the market for the option price" + ) + converged: bool = Field( + default=True, description="Flag indicating if implied vol calculation converged" + ) @classmethod def create( @@ -238,6 +240,7 @@ def put_price(self) -> Decimal: return self.price def can_price(self, converged: bool, select: OptionSelection) -> bool: + """Check if the option price can be used for implied volatility calculation""" if self.price_time > ZERO and not np.isnan(self.implied_vol): if not self.converged and converged is True: return False @@ -288,29 +291,26 @@ class OptionArrays(NamedTuple): call_put: np.ndarray -@dataclass -class OptionPrices(Generic[S]): - security: S - meta: OptionMetadata - bid: OptionPrice - ask: OptionPrice +class OptionPrices(BaseModel, Generic[S]): + security: S = Field(description="The underlying security of the option prices") + meta: OptionMetadata = Field(description="Metadata for the option prices") + bid: OptionPrice = Field(description="Bid option price") + ask: OptionPrice = Field(description="Ask option price") def prices( self, - forward: Decimal, - ttm: float, + forward: Annotated[Decimal, Doc("Forward price of the underlying asset")], + ttm: Annotated[float, Doc("Time to maturity in years")], *, - select: OptionSelection = OptionSelection.best, - initial_vol: float = INITIAL_VOL, - converged: bool = True, + select: Annotated[ + OptionSelection, Doc("Option selection method") + ] = OptionSelection.best, + initial_vol: Annotated[ + float, Doc("Initial volatility for the root finding algorithm") + ] = INITIAL_VOL, + converged: Annotated[bool, Doc("Whether the calculation has converged")] = True, ) -> Iterator[OptionPrice]: - """Iterator over bid/ask option prices - - :param forward: Forward price of the underlying asset - :param ttm: Time to maturity in years - :param select: the :class:`.OptionSelection` method - :param initial_vol: Initial volatility for the root finding algorithm - """ + """Iterator over bid/ask option prices""" self.meta.forward = forward self.meta.ttm = ttm for o in (self.bid, self.ask): @@ -322,6 +322,7 @@ def prices( yield o def inputs(self) -> OptionInput: + """Convert the option prices to an OptionInput instance""" return OptionInput( bid=self.bid.price, ask=self.ask.price, @@ -330,25 +331,34 @@ def inputs(self) -> OptionInput: strike=self.meta.strike, maturity=self.meta.maturity, option_type=self.meta.option_type, + iv_bid=to_decimal(self.bid.implied_vol), + iv_ask=to_decimal(self.ask.implied_vol), ) -@dataclass -class Strike(Generic[S]): +class Strike(BaseModel, Generic[S]): """Option prices for a single strike""" - strike: Decimal - call: OptionPrices[S] | None = None - put: OptionPrices[S] | None = None + strike: Decimal = Field(description="Strike price of the options") + call: OptionPrices[S] | None = Field( + default=None, description="Call option prices for the strike" + ) + put: OptionPrices[S] | None = Field( + default=None, description="Put option prices for the strike" + ) def option_prices( self, - forward: Decimal, - ttm: float, + forward: Annotated[Decimal, Doc("Forward price of the underlying asset")], + ttm: Annotated[float, Doc("Time to maturity in years")], *, - select: OptionSelection = OptionSelection.best, - initial_vol: float = INITIAL_VOL, - converged: bool = True, + select: Annotated[ + OptionSelection, Doc("Option selection method") + ] = OptionSelection.best, + initial_vol: Annotated[ + float, Doc("Initial volatility for the root finding algorithm") + ] = INITIAL_VOL, + converged: Annotated[bool, Doc("Whether the calculation has converged")] = True, ) -> Iterator[OptionPrice]: if select != OptionSelection.put and self.call: yield from self.call.prices( @@ -368,8 +378,7 @@ def option_prices( ) -@dataclass -class VolCrossSection(Generic[S]): +class VolCrossSection(BaseModel, Generic[S], arbitrary_types_allowed=True): """Represents a cross section of a volatility surface at a specific maturity.""" maturity: datetime @@ -401,11 +410,17 @@ def info_dict(self, ref_date: datetime, spot: SpotPrice[S]) -> dict: def option_prices( self, - ref_date: datetime, + ref_date: Annotated[ + datetime, Doc("Reference date for time to maturity calculation") + ], *, - select: OptionSelection = OptionSelection.best, - initial_vol: float = INITIAL_VOL, - converged: bool = True, + select: Annotated[ + OptionSelection, Doc("Option selection method") + ] = OptionSelection.best, + initial_vol: Annotated[ + float, Doc("Initial volatility for the root finding algorithm") + ] = INITIAL_VOL, + converged: Annotated[bool, Doc("Whether the calculation has converged")] = True, ) -> Iterator[OptionPrice]: """Iterator over option prices in the cross section""" for s in self.strikes: @@ -427,8 +442,7 @@ def securities(self) -> Iterator[FwdPrice[S] | OptionPrices[S]]: yield strike.put -@dataclass -class VolSurface(Generic[S]): +class VolSurface(BaseModel, Generic[S], arbitrary_types_allowed=True): """Represents a volatility surface, which captures the implied volatility of an option for different strikes and maturities. @@ -476,13 +490,15 @@ def securities(self) -> Iterator[SpotPrice[S] | FwdPrice[S] | OptionPrices[S]]: yield from maturity.securities() def inputs(self) -> VolSurfaceInputs: + """Convert the volatility surface to a + [VolSurfaceInputs][quantflow.options.inputs.VolSurfaceInputs] instance""" return VolSurfaceInputs( asset=self.asset, ref_date=self.ref_date, inputs=list(s.inputs() for s in self.securities()), ) - def term_structure(self, frequency: float = 0) -> pd.DataFrame: + def term_structure(self) -> pd.DataFrame: """Return the term structure of the volatility surface""" return pd.DataFrame( cross.info_dict(self.ref_date, self.spot) for cross in self.maturities @@ -490,22 +506,31 @@ def term_structure(self, frequency: float = 0) -> pd.DataFrame: def trim(self, num_maturities: int) -> Self: """Create a new volatility surface with the last `num_maturities` maturities""" - return replace(self, maturities=self.maturities[-num_maturities:]) + return self.model_copy( + update=dict(maturities=self.maturities[-num_maturities:]) + ) def option_prices( self, *, - select: OptionSelection = OptionSelection.best, - index: int | None = None, - initial_vol: float = INITIAL_VOL, - converged: bool = True, + select: Annotated[ + OptionSelection, Doc("Option selection method") + ] = OptionSelection.best, + index: Annotated[ + int | None, Doc("Index of the cross section to use, if None use all") + ] = None, + initial_vol: Annotated[ + float, Doc("Initial volatility for the root finding algorithm") + ] = INITIAL_VOL, + converged: Annotated[ + bool, + Doc( + "Returns options with converged implied volatility " + "calculation only if True" + ), + ] = True, ) -> Iterator[OptionPrice]: - """Iterator over selected option prices in the surface - - :param select: the :class:`.OptionSelection` method - :param index: Index of the cross section to use, if None use all - :param initial_vol: Initial volatility for the root finding algorithm - """ + """Iterator over selected option prices in the surface""" if index is not None: yield from self.maturities[index].option_prices( self.ref_date, @@ -535,18 +560,20 @@ def option_list( def bs( self, *, - select: OptionSelection = OptionSelection.best, - index: int | None = None, - initial_vol: float = INITIAL_VOL, + select: Annotated[ + OptionSelection, Doc("Option selection method") + ] = OptionSelection.best, + index: Annotated[ + int | None, Doc("Index of the cross section to use, if None use all") + ] = None, + initial_vol: Annotated[ + float, Doc("Initial volatility for the root finding algorithm") + ] = INITIAL_VOL, ) -> list[OptionPrice]: - """calculate Black-Scholes implied volatility for all options - in the surface - - :param select: the :class:`.OptionSelection` method - :param index: Index of the cross section to use, if None use all - :param initial_vol: Initial volatility for the root finding algorithm - - Some options may not converge, in this case the implied volatility is not + """Calculate Black-Scholes implied volatility for all options + in the surface. + For some options, the implied volatility calculation may not converge, + in this case the implied volatility is not calculated correctly and the option is marked as not converged. """ d = self.as_array( @@ -574,8 +601,12 @@ def bs( def calc_bs_prices( self, *, - select: OptionSelection = OptionSelection.best, - index: int | None = None, + select: Annotated[ + OptionSelection, Doc("Option selection method") + ] = OptionSelection.best, + index: Annotated[ + int | None, Doc("Index of the cross section to use, if None use all") + ] = None, ) -> np.ndarray: """calculate Black-Scholes prices for all options in the surface""" d = self.as_array(select=select, index=index) @@ -584,10 +615,16 @@ def calc_bs_prices( def options_df( self, *, - select: OptionSelection = OptionSelection.best, - index: int | None = None, - initial_vol: float = INITIAL_VOL, - converged: bool = True, + select: Annotated[ + OptionSelection, Doc("Option selection method") + ] = OptionSelection.best, + index: Annotated[ + int | None, Doc("Index of the cross section to use, if None use all") + ] = None, + initial_vol: Annotated[ + float, Doc("Initial volatility for the root finding algorithm") + ] = INITIAL_VOL, + converged: Annotated[bool, Doc("Whether the calculation has converged")] = True, ) -> pd.DataFrame: """Time frame of Black-Scholes call input data""" data = self.option_prices( @@ -601,17 +638,18 @@ def options_df( def as_array( self, *, - select: OptionSelection = OptionSelection.best, - index: int | None = None, - initial_vol: float = INITIAL_VOL, - converged: bool = True, + select: Annotated[ + OptionSelection, Doc("Option selection method") + ] = OptionSelection.best, + index: Annotated[ + int | None, Doc("Index of the cross section to use, if None use all") + ] = None, + initial_vol: Annotated[ + float, Doc("Initial volatility for the root finding algorithm") + ] = INITIAL_VOL, + converged: Annotated[bool, Doc("Whether the calculation has converged")] = True, ) -> OptionArrays: - """Organize option prices in a numpy arrays for black volatility calculation - - :param select: the :class:`.OptionSelection` method - :param index: Index of the cross section to use, if None use all - :param initial_vol: Initial volatility for the root finding algorithm - """ + """Organize option prices in a numpy arrays for black volatility calculation""" options = list( self.option_prices( select=select, @@ -672,14 +710,25 @@ def plot3d( return plot.plot_vol_surface_3d(df, **kwargs) -@dataclass -class VolCrossSectionLoader(Generic[S]): - maturity: datetime - forward: FwdPrice[S] | None = None - """Forward price of the underlying asset at the time of the cross section""" - strikes: dict[Decimal, Strike[S]] = field(default_factory=dict) - """List of strikes and their corresponding option prices""" - day_counter: DayCounter = default_day_counter +class VolCrossSectionLoader(BaseModel, Generic[S], arbitrary_types_allowed=True): + maturity: datetime = Field(description="Maturity date of the cross section") + forward: FwdPrice[S] | None = Field( + default=None, + description=( + "Forward price of the underlying asset at the time of the cross section" + ), + ) + strikes: dict[Decimal, Strike[S]] = Field( + default_factory=dict, + description="Dictionary of strikes and their corresponding option prices", + ) + day_counter: DayCounter = Field( + default=default_day_counter, + description=( + "Day counter for time to maturity calculations " + "- by default it uses Act/Act" + ), + ) def add_option( self, @@ -732,26 +781,41 @@ def cross_section(self) -> VolCrossSection[S] | None: ) -@dataclass -class GenericVolSurfaceLoader(Generic[S]): +class GenericVolSurfaceLoader(BaseModel, Generic[S], arbitrary_types_allowed=True): """Helper class to build a volatility surface from a list of securities""" - asset: str = "" - """Name of the underlying asset""" - spot: SpotPrice[S] | None = None - """Spot price of the underlying asset""" - maturities: dict[datetime, VolCrossSectionLoader[S]] = field(default_factory=dict) - """Dictionary of maturities and their corresponding cross section loaders""" - day_counter: DayCounter = default_day_counter + asset: str = Field(default="", description="Name of the underlying asset") + spot: SpotPrice[S] | None = Field( + default=None, description="Spot price of the underlying asset" + ) + maturities: dict[datetime, VolCrossSectionLoader[S]] = Field( + default_factory=dict, + description=( + "Dictionary of maturities and their corresponding cross section loaders" + ), + ) + day_counter: DayCounter = Field( + default=default_day_counter, + description=( + "Day counter for time to maturity calculations " + "by default it uses Act/Act" + ), + ) """Day counter for time to maturity calculations - by default it uses Act/Act""" - tick_size_forwards: Decimal | None = None - """Tick size for rounding forward and spot prices - optional""" - tick_size_options: Decimal | None = None - """Tick size for rounding option prices - optional""" - exclude_open_interest: Decimal | None = None - """Exclude options with open interest at or below this value""" - exclude_volume: Decimal | None = None - """Exclude options with volume at or below this value""" + tick_size_forwards: Decimal | None = Field( + default=None, + description="Tick size for rounding forward and spot prices - optional", + ) + tick_size_options: Decimal | None = Field( + default=None, description="Tick size for rounding option prices - optional" + ) + exclude_open_interest: Decimal | None = Field( + default=None, + description="Exclude options with open interest at or below this value", + ) + exclude_volume: Decimal | None = Field( + default=None, description="Exclude options with volume at or below this value" + ) def get_or_create_maturity(self, maturity: datetime) -> VolCrossSectionLoader[S]: if maturity not in self.maturities: @@ -773,7 +837,7 @@ def add_spot( if security.vol_surface_type() != VolSecurityType.spot: raise ValueError("Security is not a spot") self.spot = SpotPrice( - security, + security=security, bid=bid, ask=ask, open_interest=open_interest, @@ -793,7 +857,7 @@ def add_forward( if security.vol_surface_type() != VolSecurityType.forward: raise ValueError("Security is not a forward") self.get_or_create_maturity(maturity=maturity).forward = FwdPrice( - security, + security=security, bid=bid, ask=ask, maturity=maturity, @@ -851,7 +915,7 @@ def surface(self, ref_date: datetime | None = None) -> VolSurface[S]: ) -class VolSurfaceLoader(GenericVolSurfaceLoader[VolSecurityType]): +class VolSurfaceLoader(GenericVolSurfaceLoader[DefaultVolSecurity]): """A volatility surface loader""" def add(self, input: VolSurfaceInput) -> None: @@ -862,7 +926,7 @@ def add(self, input: VolSurfaceInput) -> None: """ if isinstance(input, SpotInput): self.add_spot( - VolSecurityType.spot, + DefaultVolSecurity.spot(), bid=input.bid, ask=input.ask, open_interest=input.open_interest, @@ -870,7 +934,7 @@ def add(self, input: VolSurfaceInput) -> None: ) elif isinstance(input, ForwardInput): self.add_forward( - VolSecurityType.forward, + DefaultVolSecurity.forward(), maturity=input.maturity, bid=input.bid, ask=input.ask, @@ -879,7 +943,7 @@ def add(self, input: VolSurfaceInput) -> None: ) elif isinstance(input, OptionInput): self.add_option( - VolSecurityType.option, + DefaultVolSecurity.option(), strike=input.strike, option_type=input.option_type, maturity=input.maturity, @@ -892,7 +956,7 @@ def add(self, input: VolSurfaceInput) -> None: raise ValueError(f"Unknown input type {type(input)}") -def surface_from_inputs(inputs: VolSurfaceInputs) -> VolSurface[VolSecurityType]: +def surface_from_inputs(inputs: VolSurfaceInputs) -> VolSurface[DefaultVolSecurity]: loader = VolSurfaceLoader() for input in inputs.inputs: loader.add(input) diff --git a/readme.md b/readme.md index 9fd27ad9..ce864123 100644 --- a/readme.md +++ b/readme.md @@ -18,21 +18,78 @@ pip install quantflow ## Modules -* [quantflow.cli](https://github.com/quantmind/quantflow/tree/main/quantflow/cli) command line client (requires `quantflow[cli,data]`) +* [quantflow.ai](https://github.com/quantmind/quantflow/tree/main/quantflow/ai) MCP server for AI clients (requires `quantflow[ai,data]`) * [quantflow.data](https://github.com/quantmind/quantflow/tree/main/quantflow/data) data APIs (requires `quantflow[data]`) * [quantflow.options](https://github.com/quantmind/quantflow/tree/main/quantflow/options) option pricing and calibration * [quantflow.sp](https://github.com/quantmind/quantflow/tree/main/quantflow/sp) stochastic process primitives * [quantflow.ta](https://github.com/quantmind/quantflow/tree/main/quantflow/ta) timeseries analysis tools * [quantflow.utils](https://github.com/quantmind/quantflow/tree/main/quantflow/utils) utilities and helpers +## Optional dependencies +* `data` — data retrieval: `pip install quantflow[data]` +* `ai` — MCP server for AI clients: `pip install quantflow[ai,data]` -## Command line tools +## MCP Server -The command line tools are available when installing with the extra `cli` and `data` dependencies. +Quantflow exposes its data tools as an [MCP](https://modelcontextprotocol.io) server, allowing AI clients such as Claude to query market data, crypto volatility surfaces, and economic indicators directly. + +Install with the `ai` and `data` extras: + +```bash +pip install quantflow[ai,data] +``` + +### API keys + +Store your API keys in `~/.quantflow/.vault`: + +``` +fmp=your-fmp-key +fred=your-fred-key +``` + +Or let the AI manage them for you via the `vault_add` tool once connected. + +### Claude Code ```bash -pip install quantflow[cli,data] +claude mcp add quantflow -- uv run qf-mcp +``` + +### Claude Desktop + +Add to your Claude Desktop config (`~/Library/Application Support/Claude/claude_desktop_config.json` on macOS): + +```json +{ + "mcpServers": { + "quantflow": { + "command": "uv", + "args": ["run", "qf-mcp"] + } + } +} ``` -It is possible to use the command line tool `qf` to download data and run pricing and calibration scripts. +### Available tools + +| Tool | Description | +|---|---| +| `vault_keys` | List stored API keys | +| `vault_add` | Add or update an API key | +| `vault_delete` | Delete an API key | +| `stock_indices` | List stock market indices | +| `stock_search` | Search companies by name or symbol | +| `stock_profile` | Get company profile | +| `stock_prices` | Get OHLC price history | +| `sector_performance` | Sector performance and PE ratios | +| `crypto_instruments` | List Deribit instruments | +| `crypto_historical_volatility` | Historical volatility from Deribit | +| `crypto_term_structure` | Volatility term structure | +| `crypto_implied_volatility` | Implied volatility surface | +| `crypto_prices` | Crypto OHLC price history | +| `ascii_chart` | ASCII chart for any stock or crypto symbol | +| `fred_subcategories` | Browse FRED categories | +| `fred_series` | List series in a FRED category | +| `fred_data` | Fetch FRED observations | diff --git a/taplo.toml b/taplo.toml index fb011331..cbb5a42a 100644 --- a/taplo.toml +++ b/taplo.toml @@ -13,6 +13,7 @@ formatting = { reorder_arrays = true, reorder_keys = true } include = [ "pyproject.toml" ] keys = [ "project.optional-dependencies", + "project.optional-dependencies.ai", "project.optional-dependencies.book", "project.optional-dependencies.data", "project.optional-dependencies.dev", diff --git a/uv.lock b/uv.lock index 70328205..d36c208a 100644 --- a/uv.lock +++ b/uv.lock @@ -219,27 +219,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] -[[package]] -name = "asciichartpy" -version = "1.5.25" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "setuptools" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/41/3a/b01436be647f881515ec2f253616bf4a40c1d27d02a69e7f038e27fcdf81/asciichartpy-1.5.25.tar.gz", hash = "sha256:63a305302b2aad51da288b58226009b7b0313eba7d8e2452d5a21a13fcf44d74", size = 8201, upload-time = "2020-08-17T02:07:18.292Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/d0/7b958df957e4827837b590944008f0b28078f552b451f7407b4b3d54f574/asciichartpy-1.5.25-py2.py3-none-any.whl", hash = "sha256:33c417a3c8ef7d0a11b98eb9ea6dd9b2c1b17559e539b207a17d26d4302d0258", size = 7228, upload-time = "2020-08-17T02:07:16.386Z" }, -] - -[[package]] -name = "async-cache" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d9/fc/2c1f3ad6eeb791547512d220fd704dafb60550300aab0cf6d7e2fd726603/async_cache-2.0.0.tar.gz", hash = "sha256:62b17c216b0b437dcfdf1890b4770b8b11f3b42ce4be1c30949575c2c2194911", size = 17046, upload-time = "2026-02-28T11:23:08.875Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c9/de/67f89784b5f332e68cd666c998fc993abe6a6ac01a63a1912b0eeacf3396/async_cache-2.0.0-py3-none-any.whl", hash = "sha256:620b632adc7f26efdcae3f371d7627bbc4649815f42ec661ccc05953d5e93b76", size = 9755, upload-time = "2026-02-28T11:23:07.492Z" }, -] - [[package]] name = "async-timeout" version = "5.0.1" @@ -1091,18 +1070,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/7f/13cd798d180af4bf4c0ceddeefba2b864a63c71645abc0308b768d67bb81/hjson-3.1.0-py3-none-any.whl", hash = "sha256:65713cdcf13214fb554eb8b4ef803419733f4f5e551047c9b711098ab7186b89", size = 54018, upload-time = "2022-08-13T02:52:59.899Z" }, ] -[[package]] -name = "holidays" -version = "0.92" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "python-dateutil" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a4/9a/e397b5c64a17f198b7b9b719244b1ffb823ac685656e608b70de7a5b59da/holidays-0.92.tar.gz", hash = "sha256:5d716ececf94e0d354ccee255541f6ba702078d7ed17b693262f6446214904a5", size = 844925, upload-time = "2026-03-02T19:33:17.152Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/85/00/8ee09c2c671cc7e95c6212d1c15b2b67c2011468f352c21200e18c08e6c0/holidays-0.92-py3-none-any.whl", hash = "sha256:92c192a20d80cd2ddbdf3166d73a9692c59701ded34f6754115b3c849ac60857", size = 1385981, upload-time = "2026-03-02T19:33:15.627Z" }, -] - [[package]] name = "httpcore" version = "1.0.9" @@ -2531,18 +2498,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/74/c3/24a2f845e3917201628ecaba4f18bab4d18a337834c1df2a159ee9d22a42/prometheus_client-0.24.1-py3-none-any.whl", hash = "sha256:150db128af71a5c2482b36e588fc8a6b95e498750da4b17065947c16070f4055", size = 64057, upload-time = "2026-01-14T15:26:24.42Z" }, ] -[[package]] -name = "prompt-toolkit" -version = "3.0.52" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "wcwidth" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, -] - [[package]] name = "propcache" version = "0.4.1" @@ -3161,7 +3116,7 @@ wheels = [ [[package]] name = "quantflow" -version = "0.4.4" +version = "0.5.0" source = { editable = "." } dependencies = [ { name = "ccy" }, @@ -3173,27 +3128,22 @@ dependencies = [ ] [package.optional-dependencies] +ai = [ + { name = "google-genai" }, + { name = "mcp" }, + { name = "openai" }, + { name = "pydantic-ai-slim" }, + { name = "rich" }, +] book = [ { name = "altair" }, { name = "autodocsumm" }, { name = "duckdb" }, { name = "fastapi" }, - { name = "google-genai" }, { name = "marimo" }, - { name = "mcp" }, - { name = "openai" }, { name = "plotly" }, - { name = "pydantic-ai-slim" }, { name = "sympy" }, ] -cli = [ - { name = "asciichartpy" }, - { name = "async-cache" }, - { name = "click" }, - { name = "holidays" }, - { name = "prompt-toolkit" }, - { name = "rich" }, -] data = [ { name = "aio-fluid", extra = ["http"] }, ] @@ -3223,37 +3173,32 @@ ml = [ requires-dist = [ { name = "aio-fluid", extras = ["http"], marker = "extra == 'data'", specifier = ">=1.2.1" }, { name = "altair", marker = "extra == 'book'", specifier = ">=6.0.0" }, - { name = "asciichartpy", marker = "extra == 'cli'", specifier = ">=1.5.25" }, - { name = "async-cache", marker = "extra == 'cli'", specifier = ">=1.1.1" }, { name = "autodocsumm", marker = "extra == 'book'", specifier = ">=0.2.14" }, { name = "black", marker = "extra == 'dev'", specifier = ">=26.3.1" }, { name = "ccy", specifier = ">=1.7.1" }, - { name = "click", marker = "extra == 'cli'", specifier = ">=8.1.7" }, { name = "duckdb", marker = "extra == 'book'", specifier = ">=1.4.4" }, { name = "fastapi", marker = "extra == 'book'", specifier = ">=0.129.0" }, { name = "ghp-import", marker = "extra == 'dev'", specifier = ">=2.0.2" }, - { name = "google-genai", marker = "extra == 'book'", specifier = ">=1.61.0" }, + { name = "google-genai", marker = "extra == 'ai'", specifier = ">=1.61.0" }, { name = "griffe-pydantic", marker = "extra == 'docs'", specifier = ">=1.1.0" }, { name = "griffe-typingdoc", marker = "extra == 'docs'", specifier = ">=0.2.7" }, - { name = "holidays", marker = "extra == 'cli'", specifier = ">=0.63" }, { name = "isort", marker = "extra == 'dev'", specifier = ">=8.0.0" }, { name = "marimo", marker = "extra == 'book'", specifier = ">=0.19.7" }, - { name = "mcp", marker = "extra == 'book'", specifier = ">=1.26.0" }, + { name = "mcp", marker = "extra == 'ai'", specifier = ">=1.26.0" }, { name = "mkdocs-macros-plugin", marker = "extra == 'docs'", specifier = ">=1.3.7" }, { name = "mkdocs-material", marker = "extra == 'docs'", specifier = ">=9.7.0" }, { name = "mkdocs-redirects", marker = "extra == 'docs'", specifier = ">=1.2.1" }, { name = "mkdocstrings", extras = ["python"], marker = "extra == 'docs'", specifier = "==1.0.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.14.1" }, - { name = "openai", marker = "extra == 'book'", specifier = ">=2.16.0" }, + { name = "openai", marker = "extra == 'ai'", specifier = ">=2.16.0" }, { name = "plotly", marker = "extra == 'book'", specifier = ">=6.2.0" }, { name = "polars", extras = ["pandas", "pyarrow"], specifier = ">=1.11.0" }, - { name = "prompt-toolkit", marker = "extra == 'cli'", specifier = ">=3.0.43" }, { name = "pydantic", specifier = ">=2.0.2" }, - { name = "pydantic-ai-slim", marker = "extra == 'book'", specifier = ">=1.51.0" }, + { name = "pydantic-ai-slim", marker = "extra == 'ai'", specifier = ">=1.51.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=1.0.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=7.0.0" }, { name = "python-dotenv", specifier = ">=1.0.1" }, - { name = "rich", marker = "extra == 'cli'", specifier = ">=13.9.4" }, + { name = "rich", marker = "extra == 'ai'", specifier = ">=13.9.4" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.15.4" }, { name = "scipy", specifier = ">=1.14.1" }, { name = "statsmodels", specifier = ">=0.14.6,<0.15.0" }, @@ -3261,7 +3206,7 @@ requires-dist = [ { name = "torch", marker = "extra == 'ml'", specifier = ">=2.10.0", index = "https://download.pytorch.org/whl/cu126" }, { name = "types-python-dateutil", marker = "extra == 'dev'", specifier = ">=2.9.0.20251115" }, ] -provides-extras = ["book", "cli", "data", "dev", "docs", "ml"] +provides-extras = ["ai", "book", "data", "dev", "docs", "ml"] [[package]] name = "redis" @@ -4026,15 +3971,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] -[[package]] -name = "wcwidth" -version = "0.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a38726c948d3399905a4c7cabd0df578ede5dc51f0ec2/wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159", size = 159684, upload-time = "2026-02-06T19:19:40.919Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" }, -] - [[package]] name = "websockets" version = "16.0" From 44da7c80fb814fc2b087c36ba40153bcea9f98f5 Mon Sep 17 00:00:00 2001 From: Luca Date: Sun, 22 Mar 2026 17:24:20 +0000 Subject: [PATCH 2/3] Test ai --- pyproject.toml | 2 + quantflow/ai/tools/base.py | 11 - quantflow/ai/tools/crypto.py | 8 +- quantflow/ai/tools/fred.py | 6 +- quantflow/ai/tools/stocks.py | 21 +- quantflow_tests/test_ai.py | 469 +++++++++++++++++++++++++++++++++++ uv.lock | 33 +++ 7 files changed, 519 insertions(+), 31 deletions(-) create mode 100644 quantflow_tests/test_ai.py diff --git a/pyproject.toml b/pyproject.toml index 20fd3b4e..fbd1e7a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ Documentation = "https://quantmind.github.io/quantflow/" [project.optional-dependencies] ai = [ + "asciichartpy>=1.5.25", + "ccy[holidays]>=1.7.1", "google-genai>=1.61.0", "mcp>=1.26.0", "openai>=2.16.0", diff --git a/quantflow/ai/tools/base.py b/quantflow/ai/tools/base.py index ce53ed5f..e58b8a9f 100644 --- a/quantflow/ai/tools/base.py +++ b/quantflow/ai/tools/base.py @@ -1,11 +1,7 @@ from dataclasses import dataclass, field -from io import StringIO from pathlib import Path -import pandas as pd -from ccy.cli.console import df_to_rich from mcp.server.fastmcp.exceptions import ToolError -from rich.console import Console from quantflow.data.fmp import FMP from quantflow.data.fred import Fred @@ -35,10 +31,3 @@ def fred(self) -> Fred: " Please add it using the vault_add tool." ) return Fred(key=key) - - def rich(self, df: pd.DataFrame) -> str: - table = df_to_rich(df) - buf = StringIO() - console = Console(file=buf, no_color=True) - console.print(table) - return buf.getvalue() diff --git a/quantflow/ai/tools/crypto.py b/quantflow/ai/tools/crypto.py index 2efea5ce..24a51b31 100644 --- a/quantflow/ai/tools/crypto.py +++ b/quantflow/ai/tools/crypto.py @@ -37,7 +37,7 @@ async def crypto_historical_volatility(currency: str) -> str: df = await client.get_volatility(currency) if df.empty: return f"No volatility data for {currency}" - return f"Historical volatility for {currency}:\n{df.to_string(index=False)}" + return df.to_csv(index=False) @mcp.tool() async def crypto_term_structure(currency: str) -> str: @@ -52,7 +52,7 @@ async def crypto_term_structure(currency: str) -> str: loader = await client.volatility_surface_loader(currency) vs: VolSurface = loader.surface() ts = vs.term_structure().round({"ttm": 4}) - return f"Term structure for {currency}:\n{ts.to_string(index=False)}" + return ts.to_csv(index=False) @mcp.tool() async def crypto_implied_volatility(currency: str, maturity_index: int = -1) -> str: @@ -71,7 +71,7 @@ async def crypto_implied_volatility(currency: str, maturity_index: int = -1) -> vs.bs(index=index) df = vs.options_df(index=index) df["implied_vol"] = df["implied_vol"].map("{:.2%}".format) - return f"Implied volatility for {currency}:\n{df.to_string(index=False)}" + return df.to_csv(index=False) @mcp.tool() async def crypto_prices(symbol: str, frequency: str = "") -> str: @@ -87,4 +87,4 @@ async def crypto_prices(symbol: str, frequency: str = "") -> str: if df.empty: return f"No price data for {symbol}" df = df[["date", "open", "high", "low", "close", "volume"]].sort_values("date") - return f"Prices for {symbol}:\n{df.tail(50).to_string(index=False)}" + return df.tail(50).to_csv(index=False) diff --git a/quantflow/ai/tools/fred.py b/quantflow/ai/tools/fred.py index beb19cd5..dc363efc 100644 --- a/quantflow/ai/tools/fred.py +++ b/quantflow/ai/tools/fred.py @@ -26,7 +26,7 @@ async def fred_subcategories(category_id: str | None = None) -> str: import pandas as pd df = pd.DataFrame(cats, columns=["id", "name"]) - return f"FRED categories:\n{df.to_string(index=False)}" + return df.to_csv(index=False) @mcp.tool() async def fred_series(category_id: str) -> str: @@ -55,7 +55,7 @@ async def fred_series(category_id: str) -> str: "observation_end", ], ).sort_values("popularity", ascending=False) - return f"FRED series for category {category_id}:\n{df.to_string(index=False)}" + return df.to_csv(index=False) @mcp.tool() async def fred_data( @@ -79,4 +79,4 @@ async def fred_data( sort_order="desc", ) ) - return f"FRED data for {series_id}:\n{df.to_string(index=False)}" + return df.to_csv(index=False) diff --git a/quantflow/ai/tools/stocks.py b/quantflow/ai/tools/stocks.py index 2b033220..bc9cd595 100644 --- a/quantflow/ai/tools/stocks.py +++ b/quantflow/ai/tools/stocks.py @@ -1,8 +1,12 @@ """Stocks tools for the quantflow MCP server.""" from datetime import timedelta +from typing import cast import pandas as pd +from ccy import period as to_period +from ccy.tradingcentres import prevbizday +from fluid.utils.data import compact_dict from mcp.server.fastmcp import FastMCP from quantflow.utils.dates import utcnow @@ -17,7 +21,7 @@ async def stock_indices() -> str: """List available stock market indices.""" async with tool.fmp() as client: data = await client.indices() - return tool.rich(pd.DataFrame(data)) + return pd.DataFrame(data).to_csv(index=False) @mcp.tool() async def stock_search(query: str) -> str: @@ -30,7 +34,7 @@ async def stock_search(query: str) -> str: data = await client.search(query) df = pd.DataFrame(data, columns=["symbol", "name", "currency", "stockExchange"]) - return f"Search results for '{query}':\n{df.to_string(index=False)}" + return df.to_csv(index=False) @mcp.tool() async def stock_profile(symbol: str) -> str: @@ -62,7 +66,7 @@ async def stock_prices(symbol: str, frequency: str = "") -> str: if df.empty: return f"No price data for {symbol}" df = df[["date", "open", "high", "low", "close", "volume"]].sort_values("date") - return f"Prices for {symbol}:\n{df.tail(50).to_string(index=False)}" + return df.to_csv(index=False) @mcp.tool() async def sector_performance(period: str = "1d") -> str: @@ -71,10 +75,6 @@ async def sector_performance(period: str = "1d") -> str: Args: period: Time period - 1d, 1w, 1m, 3m, 6m, 1y (default: 1d) """ - from ccy import period as to_period - from ccy.tradingcentres import prevbizday - from fluid.utils.data import compact_dict - async with tool.fmp() as client: to_date = utcnow().date() if period != "1d": @@ -89,11 +89,6 @@ async def sector_performance(period: str = "1d") -> str: pe = await client.sector_pe( params=compact_dict(date=prevbizday(to_date, 0).isoformat()) ) - - from typing import cast - - import pandas as pd - spd = cast(dict, sp) pes = {k["sector"]: round(float(k["pe"]), 3) for k in pe if k["sector"] in spd} rows = [ @@ -101,4 +96,4 @@ async def sector_performance(period: str = "1d") -> str: for k, v in spd.items() ] df = pd.DataFrame(rows).sort_values("performance", ascending=False) - return f"Sector performance ({period}):\n{df.to_string(index=False)}" + return df.to_csv(index=False) diff --git a/quantflow_tests/test_ai.py b/quantflow_tests/test_ai.py new file mode 100644 index 00000000..0c0ddcdd --- /dev/null +++ b/quantflow_tests/test_ai.py @@ -0,0 +1,469 @@ +"""Unit tests for the quantflow MCP server tools.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pandas as pd +import pytest +from mcp.server.fastmcp import FastMCP + +from quantflow.ai.tools import charts, crypto, fred, stocks, vault +from quantflow.ai.tools.base import McpTool +from quantflow.data.vault import Vault +from quantflow.options.surface import VolSurfaceInputs, surface_from_inputs + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def text(result: Any) -> str: + """Extract text from a call_tool result (tuple of (blocks, metadata)).""" + blocks = result[0] if isinstance(result, tuple) else result + if blocks and hasattr(blocks[0], "text"): + return blocks[0].text + return str(result) + + +def raw(result: Any) -> Any: + """Get the raw return value from call_tool result.""" + if isinstance(result, tuple) and len(result) > 1: + return result[1].get("result") + return result + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def vault_path(tmp_path: Path) -> Path: + return tmp_path / ".vault" + + +@pytest.fixture +def mcp_tool(vault_path: Path) -> McpTool: + return McpTool(vault=Vault(vault_path)) + + +@pytest.fixture +def mock_fmp() -> AsyncMock: + mock = AsyncMock() + mock.__aenter__ = AsyncMock(return_value=mock) + mock.__aexit__ = AsyncMock(return_value=False) + return mock + + +@pytest.fixture +def mock_fred() -> AsyncMock: + mock = AsyncMock() + mock.__aenter__ = AsyncMock(return_value=mock) + mock.__aexit__ = AsyncMock(return_value=False) + return mock + + +@pytest.fixture +def vol_surface(): + with open("quantflow_tests/volsurface.json") as fp: + return surface_from_inputs(VolSurfaceInputs(**json.load(fp))) + + +@pytest.fixture +def vault_server(mcp_tool: McpTool) -> FastMCP: + mcp = FastMCP("test-vault") + vault.register(mcp, mcp_tool) + return mcp + + +@pytest.fixture +def stocks_server(mcp_tool: McpTool) -> FastMCP: + mcp = FastMCP("test-stocks") + stocks.register(mcp, mcp_tool) + return mcp + + +@pytest.fixture +def crypto_server(mcp_tool: McpTool) -> FastMCP: + mcp = FastMCP("test-crypto") + crypto.register(mcp, mcp_tool) + return mcp + + +@pytest.fixture +def fred_server(mcp_tool: McpTool) -> FastMCP: + mcp = FastMCP("test-fred") + fred.register(mcp, mcp_tool) + return mcp + + +@pytest.fixture +def charts_server(mcp_tool: McpTool) -> FastMCP: + mcp = FastMCP("test-charts") + charts.register(mcp, mcp_tool) + return mcp + + +# --------------------------------------------------------------------------- +# Vault tools +# --------------------------------------------------------------------------- + + +async def test_vault_keys_empty(vault_server: FastMCP) -> None: + result = await vault_server.call_tool("vault_keys", {}) + assert raw(result) == [] + + +async def test_vault_add(vault_server: FastMCP, mcp_tool: McpTool) -> None: + result = await vault_server.call_tool("vault_add", {"key": "fmp", "value": "abc"}) + assert "fmp" in text(result) + assert mcp_tool.vault.get("fmp") == "abc" + + +async def test_vault_keys_after_add(vault_server: FastMCP) -> None: + await vault_server.call_tool("vault_add", {"key": "fred", "value": "xyz"}) + result = await vault_server.call_tool("vault_keys", {}) + assert "fred" in raw(result) + + +async def test_vault_delete_existing(vault_server: FastMCP) -> None: + await vault_server.call_tool("vault_add", {"key": "fmp", "value": "abc"}) + result = await vault_server.call_tool("vault_delete", {"key": "fmp"}) + assert "deleted" in text(result) + + +async def test_vault_delete_missing(vault_server: FastMCP) -> None: + result = await vault_server.call_tool("vault_delete", {"key": "nope"}) + assert "not found" in text(result) + + +# --------------------------------------------------------------------------- +# Stock tools +# --------------------------------------------------------------------------- + + +async def test_stock_indices( + stocks_server: FastMCP, mcp_tool: McpTool, mock_fmp: AsyncMock +) -> None: + mcp_tool.vault.add("fmp", "test-key") + mock_fmp.indices.return_value = [ + {"symbol": "^GSPC", "name": "S&P 500"}, + {"symbol": "^IXIC", "name": "NASDAQ Composite"}, + ] + with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp): + result = await stocks_server.call_tool("stock_indices", {}) + assert "^GSPC" in text(result) + assert "S&P 500" in text(result) + + +async def test_stock_search( + stocks_server: FastMCP, mcp_tool: McpTool, mock_fmp: AsyncMock +) -> None: + mcp_tool.vault.add("fmp", "test-key") + mock_fmp.search.return_value = [ + { + "symbol": "AAPL", + "name": "Apple Inc.", + "currency": "USD", + "stockExchange": "NASDAQ", + }, + ] + with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp): + result = await stocks_server.call_tool("stock_search", {"query": "Apple"}) + assert "AAPL" in text(result) + + +async def test_stock_profile_found( + stocks_server: FastMCP, mcp_tool: McpTool, mock_fmp: AsyncMock +) -> None: + mcp_tool.vault.add("fmp", "test-key") + mock_fmp.profile.return_value = [ + { + "symbol": "AAPL", + "companyName": "Apple Inc.", + "description": "Tech company.", + "price": 200.0, + } + ] + with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp): + result = await stocks_server.call_tool("stock_profile", {"symbol": "AAPL"}) + assert "Tech company" in text(result) + assert "AAPL" in text(result) + + +async def test_stock_profile_not_found( + stocks_server: FastMCP, mcp_tool: McpTool, mock_fmp: AsyncMock +) -> None: + mcp_tool.vault.add("fmp", "test-key") + mock_fmp.profile.return_value = [] + with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp): + result = await stocks_server.call_tool("stock_profile", {"symbol": "FAKE"}) + assert "No profile" in text(result) + + +async def test_stock_prices( + stocks_server: FastMCP, mcp_tool: McpTool, mock_fmp: AsyncMock +) -> None: + mcp_tool.vault.add("fmp", "test-key") + mock_fmp.prices.return_value = pd.DataFrame( + [ + { + "date": "2025-01-01", + "open": 100, + "high": 110, + "low": 90, + "close": 105, + "volume": 1000, + } + ] + ) + with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp): + result = await stocks_server.call_tool("stock_prices", {"symbol": "AAPL"}) + assert "2025-01-01" in text(result) + + +async def test_stock_prices_empty( + stocks_server: FastMCP, mcp_tool: McpTool, mock_fmp: AsyncMock +) -> None: + mcp_tool.vault.add("fmp", "test-key") + mock_fmp.prices.return_value = pd.DataFrame() + with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp): + result = await stocks_server.call_tool("stock_prices", {"symbol": "FAKE"}) + assert "No price data" in text(result) + + +# --------------------------------------------------------------------------- +# Crypto tools +# --------------------------------------------------------------------------- + + +async def test_crypto_instruments(crypto_server: FastMCP) -> None: + mock_client = AsyncMock() + mock_client.get_instruments.return_value = [ + MagicMock(__str__=lambda self: "BTC-SPOT") + ] + + with patch("quantflow.ai.tools.crypto.Deribit") as MockDeribit: + MockDeribit.return_value.__aenter__ = AsyncMock(return_value=mock_client) + MockDeribit.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await crypto_server.call_tool( + "crypto_instruments", {"currency": "BTC"} + ) + assert "BTC" in text(result) + + +async def test_crypto_instruments_empty(crypto_server: FastMCP) -> None: + mock_client = AsyncMock() + mock_client.get_instruments.return_value = [] + + with patch("quantflow.ai.tools.crypto.Deribit") as MockDeribit: + MockDeribit.return_value.__aenter__ = AsyncMock(return_value=mock_client) + MockDeribit.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await crypto_server.call_tool( + "crypto_instruments", {"currency": "BTC"} + ) + assert "No instruments" in text(result) + + +async def test_crypto_historical_volatility(crypto_server: FastMCP) -> None: + mock_client = AsyncMock() + mock_client.get_volatility.return_value = pd.DataFrame( + [{"date": "2025-01-01", "volatility": 0.8}] + ) + + with patch("quantflow.ai.tools.crypto.Deribit") as MockDeribit: + MockDeribit.return_value.__aenter__ = AsyncMock(return_value=mock_client) + MockDeribit.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await crypto_server.call_tool( + "crypto_historical_volatility", {"currency": "BTC"} + ) + assert "volatility" in text(result) + assert "2025-01-01" in text(result) + + +async def test_crypto_historical_volatility_empty(crypto_server: FastMCP) -> None: + mock_client = AsyncMock() + mock_client.get_volatility.return_value = pd.DataFrame() + + with patch("quantflow.ai.tools.crypto.Deribit") as MockDeribit: + MockDeribit.return_value.__aenter__ = AsyncMock(return_value=mock_client) + MockDeribit.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await crypto_server.call_tool( + "crypto_historical_volatility", {"currency": "BTC"} + ) + assert "No volatility data" in text(result) + + +async def test_crypto_term_structure(crypto_server: FastMCP, vol_surface: Any) -> None: + mock_loader = MagicMock() + mock_loader.surface.return_value = vol_surface + mock_client = AsyncMock() + mock_client.volatility_surface_loader.return_value = mock_loader + + with patch("quantflow.ai.tools.crypto.Deribit") as MockDeribit: + MockDeribit.return_value.__aenter__ = AsyncMock(return_value=mock_client) + MockDeribit.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await crypto_server.call_tool( + "crypto_term_structure", {"currency": "ETH"} + ) + assert "ttm" in text(result) + + +async def test_crypto_implied_volatility( + crypto_server: FastMCP, vol_surface: Any +) -> None: + mock_loader = MagicMock() + mock_loader.surface.return_value = vol_surface + mock_client = AsyncMock() + mock_client.volatility_surface_loader.return_value = mock_loader + + with patch("quantflow.ai.tools.crypto.Deribit") as MockDeribit: + MockDeribit.return_value.__aenter__ = AsyncMock(return_value=mock_client) + MockDeribit.return_value.__aexit__ = AsyncMock(return_value=False) + + result = await crypto_server.call_tool( + "crypto_implied_volatility", {"currency": "ETH"} + ) + assert "implied_vol" in text(result) + + +async def test_crypto_prices( + crypto_server: FastMCP, mcp_tool: McpTool, mock_fmp: AsyncMock +) -> None: + mcp_tool.vault.add("fmp", "test-key") + mock_fmp.prices.return_value = pd.DataFrame( + [ + { + "date": "2025-01-01", + "open": 90000, + "high": 95000, + "low": 88000, + "close": 92000, + "volume": 500, + } + ] + ) + with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp): + result = await crypto_server.call_tool("crypto_prices", {"symbol": "BTCUSD"}) + assert "close" in text(result) + assert "2025-01-01" in text(result) + + +# --------------------------------------------------------------------------- +# FRED tools +# --------------------------------------------------------------------------- + + +async def test_fred_subcategories( + fred_server: FastMCP, mcp_tool: McpTool, mock_fred: AsyncMock +) -> None: + mcp_tool.vault.add("fred", "test-key") + mock_fred.subcategories.return_value = { + "categories": [{"id": "32991", "name": "Money, Banking, & Finance"}] + } + with patch("quantflow.ai.tools.base.Fred", return_value=mock_fred): + result = await fred_server.call_tool("fred_subcategories", {}) + assert "Money" in text(result) + + +async def test_fred_subcategories_empty( + fred_server: FastMCP, mcp_tool: McpTool, mock_fred: AsyncMock +) -> None: + mcp_tool.vault.add("fred", "test-key") + mock_fred.subcategories.return_value = {"categories": []} + with patch("quantflow.ai.tools.base.Fred", return_value=mock_fred): + result = await fred_server.call_tool("fred_subcategories", {}) + assert "No categories" in text(result) + + +async def test_fred_series( + fred_server: FastMCP, mcp_tool: McpTool, mock_fred: AsyncMock +) -> None: + mcp_tool.vault.add("fred", "test-key") + mock_fred.series.return_value = { + "seriess": [ + { + "id": "GDP", + "popularity": 90, + "title": "Gross Domestic Product", + "frequency": "Quarterly", + "observation_start": "1947-01-01", + "observation_end": "2025-01-01", + } + ] + } + with patch("quantflow.ai.tools.base.Fred", return_value=mock_fred): + result = await fred_server.call_tool("fred_series", {"category_id": "106"}) + assert "GDP" in text(result) + + +async def test_fred_series_empty( + fred_server: FastMCP, mcp_tool: McpTool, mock_fred: AsyncMock +) -> None: + mcp_tool.vault.add("fred", "test-key") + mock_fred.series.return_value = {"seriess": []} + with patch("quantflow.ai.tools.base.Fred", return_value=mock_fred): + result = await fred_server.call_tool("fred_series", {"category_id": "999"}) + assert "No series" in text(result) + + +async def test_fred_data( + fred_server: FastMCP, mcp_tool: McpTool, mock_fred: AsyncMock +) -> None: + mcp_tool.vault.add("fred", "test-key") + mock_fred.serie_data.return_value = pd.DataFrame( + [{"date": "2025-01-01", "value": 27000.0}] + ) + with patch("quantflow.ai.tools.base.Fred", return_value=mock_fred): + result = await fred_server.call_tool("fred_data", {"series_id": "GDP"}) + assert "value" in text(result) + assert "2025-01-01" in text(result) + + +# --------------------------------------------------------------------------- +# Charts tools +# --------------------------------------------------------------------------- + + +async def test_ascii_chart( + charts_server: FastMCP, mcp_tool: McpTool, mock_fmp: AsyncMock +) -> None: + mcp_tool.vault.add("fmp", "test-key") + mock_fmp.prices.return_value = pd.DataFrame( + [ + { + "date": f"2025-01-{i:02d}", + "open": 100 + i, + "high": 110 + i, + "low": 90 + i, + "close": 105 + i, + "volume": 1000, + } + for i in range(1, 11) + ] + ) + with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp): + result = await charts_server.call_tool("ascii_chart", {"symbol": "AAPL"}) + t = text(result) + assert "AAPL" in t + assert "High" in t + assert "Low" in t + + +async def test_ascii_chart_empty( + charts_server: FastMCP, mcp_tool: McpTool, mock_fmp: AsyncMock +) -> None: + mcp_tool.vault.add("fmp", "test-key") + mock_fmp.prices.return_value = pd.DataFrame() + with patch("quantflow.ai.tools.base.FMP", return_value=mock_fmp): + result = await charts_server.call_tool("ascii_chart", {"symbol": "FAKE"}) + assert "No price data" in text(result) diff --git a/uv.lock b/uv.lock index d36c208a..c239657e 100644 --- a/uv.lock +++ b/uv.lock @@ -219,6 +219,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] +[[package]] +name = "asciichartpy" +version = "1.5.25" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/3a/b01436be647f881515ec2f253616bf4a40c1d27d02a69e7f038e27fcdf81/asciichartpy-1.5.25.tar.gz", hash = "sha256:63a305302b2aad51da288b58226009b7b0313eba7d8e2452d5a21a13fcf44d74", size = 8201, upload-time = "2020-08-17T02:07:18.292Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/d0/7b958df957e4827837b590944008f0b28078f552b451f7407b4b3d54f574/asciichartpy-1.5.25-py2.py3-none-any.whl", hash = "sha256:33c417a3c8ef7d0a11b98eb9ea6dd9b2c1b17559e539b207a17d26d4302d0258", size = 7228, upload-time = "2020-08-17T02:07:16.386Z" }, +] + [[package]] name = "async-timeout" version = "5.0.1" @@ -322,6 +334,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/60/fd/4129e0b64de10b40fedcf8d0587a1f7e5e5e4513051554476a2d6d00c25a/ccy-1.7.2-py3-none-any.whl", hash = "sha256:dd5fae95005e7b9918543e508c41af75107682ce81b540dea8825a6bf1b56402", size = 15341, upload-time = "2025-12-28T20:38:43.605Z" }, ] +[package.optional-dependencies] +holidays = [ + { name = "holidays" }, +] + [[package]] name = "certifi" version = "2026.2.25" @@ -1070,6 +1087,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/7f/13cd798d180af4bf4c0ceddeefba2b864a63c71645abc0308b768d67bb81/hjson-3.1.0-py3-none-any.whl", hash = "sha256:65713cdcf13214fb554eb8b4ef803419733f4f5e551047c9b711098ab7186b89", size = 54018, upload-time = "2022-08-13T02:52:59.899Z" }, ] +[[package]] +name = "holidays" +version = "0.63" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b0/a9/5f62d56a59bbd3647872175829d881fbecb3f08fccd9bf751523da61e2e1/holidays-0.63.tar.gz", hash = "sha256:0e0fe872c9c4c18bbdf0ddf34990d99f077484ba21d28e14e7d1ad1643b72544", size = 603749, upload-time = "2024-12-16T21:01:02.811Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/9c/5235772fc9d2399f41401e6a054a26b4a993bd8a38e4ff849a6097a912a9/holidays-0.63-py3-none-any.whl", hash = "sha256:f834a86635e4593eb3d8c76c9b4207ca11b200767b53cdef3468b9fe71401412", size = 1181014, upload-time = "2024-12-16T21:00:58.575Z" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -3129,6 +3158,8 @@ dependencies = [ [package.optional-dependencies] ai = [ + { name = "asciichartpy" }, + { name = "ccy", extra = ["holidays"] }, { name = "google-genai" }, { name = "mcp" }, { name = "openai" }, @@ -3173,9 +3204,11 @@ ml = [ requires-dist = [ { name = "aio-fluid", extras = ["http"], marker = "extra == 'data'", specifier = ">=1.2.1" }, { name = "altair", marker = "extra == 'book'", specifier = ">=6.0.0" }, + { name = "asciichartpy", marker = "extra == 'ai'", specifier = ">=1.5.25" }, { name = "autodocsumm", marker = "extra == 'book'", specifier = ">=0.2.14" }, { name = "black", marker = "extra == 'dev'", specifier = ">=26.3.1" }, { name = "ccy", specifier = ">=1.7.1" }, + { name = "ccy", extras = ["holidays"], marker = "extra == 'ai'", specifier = ">=1.7.1" }, { name = "duckdb", marker = "extra == 'book'", specifier = ">=1.4.4" }, { name = "fastapi", marker = "extra == 'book'", specifier = ">=0.129.0" }, { name = "ghp-import", marker = "extra == 'dev'", specifier = ">=2.0.2" }, From d2b584a0f775b46c6f3e04d4deef6e764e5f1bf4 Mon Sep 17 00:00:00 2001 From: Luca Date: Sun, 22 Mar 2026 17:36:56 +0000 Subject: [PATCH 3/3] ai tests --- quantflow/options/surface.py | 16 +++++------- quantflow_tests/test_options.py | 46 +++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/quantflow/options/surface.py b/quantflow/options/surface.py index ccf45225..36542d03 100644 --- a/quantflow/options/surface.py +++ b/quantflow/options/surface.py @@ -15,7 +15,7 @@ from quantflow.utils import plot from quantflow.utils.dates import utcnow from quantflow.utils.interest_rates import rate_from_spot_and_forward -from quantflow.utils.numbers import ZERO, Number, sigfig, to_decimal +from quantflow.utils.numbers import ZERO, Number, sigfig, to_decimal, to_decimal_or_none from .bs import black_price, implied_black_volatility from .inputs import ( @@ -331,8 +331,12 @@ def inputs(self) -> OptionInput: strike=self.meta.strike, maturity=self.meta.maturity, option_type=self.meta.option_type, - iv_bid=to_decimal(self.bid.implied_vol), - iv_ask=to_decimal(self.ask.implied_vol), + iv_bid=to_decimal_or_none( + None if np.isnan(self.bid.implied_vol) else self.bid.implied_vol + ), + iv_ask=to_decimal_or_none( + None if np.isnan(self.ask.implied_vol) else self.ask.implied_vol + ), ) @@ -961,9 +965,3 @@ def surface_from_inputs(inputs: VolSurfaceInputs) -> VolSurface[DefaultVolSecuri for input in inputs.inputs: loader.add(input) return loader.surface(ref_date=inputs.ref_date) - - -def assert_same(a: Any, b: Any) -> Any: - if a != b: - raise ValueError(f"Values are not the same: {a} != {b}") - return a diff --git a/quantflow_tests/test_options.py b/quantflow_tests/test_options.py index 334a580e..8b354a26 100644 --- a/quantflow_tests/test_options.py +++ b/quantflow_tests/test_options.py @@ -6,6 +6,7 @@ from quantflow.options import bs from quantflow.options.calibration import HestonCalibration +from quantflow.options.inputs import OptionInput from quantflow.options.pricer import OptionPricer from quantflow.options.surface import ( OptionPrice, @@ -73,6 +74,51 @@ def test_vol_surface(vol_surface: VolSurface): assert len(options) == sum(len(cross) for cross in crosses) +def test_term_structure(vol_surface: VolSurface) -> None: + ts = vol_surface.term_structure() + assert len(ts) == len(vol_surface.maturities) + assert list(ts.columns) == [ + "maturity", + "ttm", + "forward", + "basis", + "rate_percent", + "open_interest", + "volume", + ] + assert (ts["ttm"] > 0).all() + assert ts["ttm"].is_monotonic_increasing + + +def test_trim(vol_surface: VolSurface) -> None: + n = len(vol_surface.maturities) + assert n > 2 + + trimmed = vol_surface.trim(2) + assert len(trimmed.maturities) == 2 + assert trimmed.maturities == vol_surface.maturities[-2:] + assert trimmed.spot == vol_surface.spot + assert trimmed.ref_date == vol_surface.ref_date + + +def test_trim_full(vol_surface: VolSurface) -> None: + n = len(vol_surface.maturities) + trimmed = vol_surface.trim(n) + assert trimmed == vol_surface + + +def test_inputs_implied_vols(vol_surface: VolSurface) -> None: + vol_surface.bs() + inputs = vol_surface.inputs() + option_inputs = [i for i in inputs.inputs if isinstance(i, OptionInput)] + assert option_inputs + assert all(i.iv_bid is not None or i.iv_ask is not None for i in option_inputs) + converged = [ + i for i in option_inputs if i.iv_bid is not None and i.iv_ask is not None + ] + assert converged + + def test_same_vol_surface(vol_surface: VolSurface): inputs = vol_surface.inputs() vol_surface2 = surface_from_inputs(inputs)