diff --git a/.gitignore b/.gitignore index 41cfe11c..a4132301 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ TODO.md .cursor .taskmaster/ .cursorignore +docs/ diff --git a/.justfile b/.justfile index b41bfaa9..c80bac5c 100644 --- a/.justfile +++ b/.justfile @@ -4,7 +4,4 @@ lint: fmt: - cargo fmt --check - -clippy: - cargo clippy + cargo fmt diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..7ed07507 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,55 @@ +# rullm + +## Build Commands + +```bash +# Build all crates +cargo build --all + +# Lint (format + clippy) +just lint + +# Format only +just fmt + +# Check examples compile +cargo check --examples + +# Run tests +cargo test +``` + +## Project Structure + +This is a Rust workspace with two crates: + +- **rullm-core** (`crates/rullm-core/`) - Core library for LLM provider interactions +- **rullm-cli** (`crates/rullm-cli/`) - CLI binary for querying LLMs + +### Core Library Architecture + +The core library uses a trait-based provider system with two API levels: + +1. **Simple API** - String-based, minimal configuration +2. **Advanced API** - Full control with `ChatRequestBuilder` + +Key modules: +- `providers/` - Provider implementations (OpenAI, Anthropic, Google, OpenAI-compatible) +- `compat_types.rs` - OpenAI-compatible message/response types used across providers +- `config.rs` - Provider configuration traits and builders +- `error.rs` - `LlmError` enum with comprehensive error variants +- `utils/sse.rs` - Server-sent event parsing for streaming + +### CLI Architecture + +The CLI is organized by commands in `commands/`: +- `auth.rs` - OAuth and API key management +- `chat.rs` - Interactive chat mode with reedline +- `models.rs` - Model listing and updates +- `alias.rs` - User-defined model aliases +- `templates.rs` - TOML template management + +OAuth implementation in `oauth/`: +- `openai.rs`, `anthropic.rs` - Provider-specific OAuth flows +- `server.rs` - Local callback server for OAuth redirects +- `pkce.rs` - PKCE challenge generation diff --git a/CLAUDE.md b/CLAUDE.md index 763940cd..43c994c2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,165 +1 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -`rullm` is a Rust library and CLI for interacting with multiple LLM providers (OpenAI, Anthropic, Google AI, Groq, OpenRouter). The project uses a workspace structure with two main crates: - -- **rullm-core**: Core library implementing provider integrations, middleware (Tower-based), and streaming support -- **rullm-cli**: Command-line interface built on top of rullm-core - -## Architecture - -### Provider System - -All LLM providers implement two core traits defined in `crates/rullm-core/src/types.rs`: -- `LlmProvider`: Base trait with provider metadata (name, aliases, env_key, default_base_url, available_models, health_check) -- `ChatCompletion`: Extends LlmProvider with chat completion methods (blocking and streaming) - -Provider implementations are in `crates/rullm-core/src/providers/`: -- `openai.rs`: OpenAI GPT models -- `anthropic.rs`: Anthropic Claude models -- `google.rs`: Google Gemini models -- `openai_compatible.rs`: Generic provider for OpenAI-compatible APIs -- `groq.rs`: Groq provider (uses `openai_compatible`) -- `openrouter.rs`: OpenRouter provider (uses `openai_compatible`) - -The `openai_compatible` provider is a generic implementation that other providers like Groq and OpenRouter extend. 
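To make that "generic implementation" pattern concrete — the `ProviderIdentity` struct described in the next sentence — here is a minimal sketch. The real struct lives in `providers/openai_compatible.rs` and its exact fields may differ; every field name and value below is illustrative only:

```rust
// Illustrative only: the actual `ProviderIdentity` in
// `providers/openai_compatible.rs` may have different fields.
struct ProviderIdentity {
    name: &'static str,
    aliases: &'static [&'static str],
    env_key: &'static str,
    default_base_url: &'static str,
}

// A Groq-style identity in this scheme; values are examples.
const GROQ: ProviderIdentity = ProviderIdentity {
    name: "groq",
    aliases: &["gq"],
    env_key: "GROQ_API_KEY",
    default_base_url: "https://api.groq.com/openai/v1",
};

fn main() {
    println!("{} -> {}", GROQ.name, GROQ.default_base_url);
}
```

Centralizing per-provider constants this way is presumably what keeps `groq.rs` and `openrouter.rs` thin wrappers over the shared OpenAI-compatible client.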
It uses a `ProviderIdentity` struct to define provider-specific metadata. - -### Middleware Stack - -The library uses Tower middleware (see `crates/rullm-core/src/middleware.rs`): -- Rate limiting -- Timeouts -- Connection pooling -- Logging and metrics - -Configuration is done via `MiddlewareConfig` and `LlmServiceBuilder`. - -### Simple API - -`crates/rullm-core/src/simple.rs` provides a simplified string-based API (`SimpleLlmClient`, `SimpleLlmBuilder`) that wraps the advanced provider APIs for ease of use. - -### CLI Architecture - -The CLI entry point is `crates/rullm-cli/src/main.rs`, which: -1. Parses arguments using clap (see `args.rs`) -2. Loads configuration from `~/.config/rullm/` (see `config.rs`) -3. Dispatches to commands in `crates/rullm-cli/src/commands/` - -Key CLI modules: -- `client.rs`: Creates provider clients from model strings (format: `provider:model`) -- `provider.rs`: Resolves provider names and aliases -- `config.rs`: Manages CLI configuration (models list, aliases, default model) -- `api_keys.rs`: Manages API key storage in system keychain -- `templates.rs`: TOML-based prompt templates with `{{input}}` placeholders -- `commands/chat/`: Interactive chat mode using reedline for advanced REPL features - -### Model Format - -Models are specified using the format `provider:model`: -- Example: `openai:gpt-4`, `anthropic:claude-3-opus-20240229`, `groq:llama-3-8b` -- The CLI resolves this via `client::from_model()` which creates the appropriate provider client - -## Common Development Tasks - -### Building and Running - -```bash -# Build everything -cargo build --all - -# Build release binary -cargo build --release - -# Run the CLI (from workspace root) -cargo run -p rullm-cli -- "your query" - -# Or after building -./target/debug/rullm "your query" -./target/release/rullm "your query" -``` - -### Testing - -```bash -# Run all tests (note: some require API keys) -cargo test --all - -# Run tests for specific crate -cargo test -p rullm-core -cargo test -p rullm-cli - -# Run a specific test -cargo test test_name - -# Check examples compile -cargo check --examples -``` - -### Code Quality - -```bash -# Format code -cargo fmt - -# Check formatting -cargo fmt -- --check - -# Run clippy (linter) -cargo clippy --all-targets --all-features -- -D warnings - -# Fix clippy suggestions automatically -cargo clippy --fix --all-targets --all-features -``` - -### Running Examples - -```bash -# Run examples from rullm-core (requires API keys) -cargo run --example openai_simple -cargo run --example anthropic_simple -cargo run --example google_simple -cargo run --example openai_stream # Streaming example -cargo run --example test_all_providers # Test all providers at once -``` - -### Adding a New Provider - -When adding a new provider: - -1. **OpenAI-compatible providers**: Use `OpenAICompatibleProvider` with a `ProviderIdentity` in `providers/openai_compatible.rs`. See `groq.rs` or `openrouter.rs` for examples. - -2. **Non-compatible providers**: Create a new file in `crates/rullm-core/src/providers/`: - - Implement `LlmProvider` and `ChatCompletion` traits - - Add provider config struct in `crates/rullm-core/src/config.rs` - - Export from `providers/mod.rs` and `lib.rs` - - Add client creation logic in `crates/rullm-cli/src/client.rs` - - Update `crates/rullm-cli/src/provider.rs` for CLI support - -3. 
Update `DEFAULT_MODELS` in `crates/rullm-core/src/simple.rs` if adding default model mappings - -### Streaming Implementation - -All providers should implement `chat_completion_stream()` returning `StreamResult`. The stream emits: -- `ChatStreamEvent::Token(String)`: Each token/chunk -- `ChatStreamEvent::Done`: Completion marker -- `ChatStreamEvent::Error(String)`: Errors during streaming - -See provider implementations for SSE parsing patterns using `utils::sse::sse_lines()`. - -## Configuration Files - -- **User config**: `~/.config/rullm/config.toml` (or system equivalent) - - Stores: default model, model aliases, cached models list -- **Templates**: `~/.config/rullm/templates/*.toml` -- **API keys**: Stored in system keychain via `api_keys.rs` - -## Important Notes - -- The project uses Rust edition 2024 (rust-version 1.85+) -- Model separator changed from `/` to `:` (e.g., `openai:gpt-4` not `openai/gpt-4`) -- Chat history is persisted in `~/.config/rullm/chat_history/` -- The CLI uses `reedline` for advanced REPL features (syntax highlighting, history, multiline editing) -- In chat mode: Alt+Enter for multiline, Ctrl+O for buffer editing, `/edit` to open $EDITOR +@AGENTS.md diff --git a/Cargo.lock b/Cargo.lock index c91c6f95..ff8613cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,6 +142,12 @@ dependencies = [ "syn", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "atty" version = "0.2.14" @@ -180,6 +186,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "1.3.2" @@ -195,6 +207,15 @@ dependencies = [ "serde", ] +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -216,6 +237,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-if" version = "1.0.1" @@ -234,7 +261,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -295,6 +322,16 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -305,12 +342,31 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + 
"core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crossterm" version = "0.28.1" @@ -337,6 +393,26 @@ dependencies = [ "winapi", ] +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -526,6 +602,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -566,7 +652,26 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.12", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.4.0", "indexmap", "slab", "tokio", @@ -595,6 +700,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "home" version = "0.5.11" @@ -615,6 +726,16 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -622,7 +743,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.4.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http 1.4.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -648,9 +792,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.3.27", + 
"http 0.2.12", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -662,6 +806,44 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", + "httparse", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http 1.4.0", + "hyper 1.8.1", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -669,12 +851,54 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.32", "native-tls", "tokio", "tokio-native-tls", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.8.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "hyper 1.8.1", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2 0.6.0", + "system-configuration", + "tokio", + "tower-service", + "tracing", + "windows-registry", +] + [[package]] name = "iana-time-zone" version = "0.1.63" @@ -833,6 +1057,16 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "iri-string" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f867b9d1d896b67beb18518eda36fdb77a32ea590de864f1325b294a6d14397" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is_executable" version = "1.0.4" @@ -863,6 +1097,28 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "js-sys" version = "0.3.77" @@ -988,6 +1244,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndk-context" +version = "0.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1016,6 +1278,31 @@ dependencies = [ "autocfg", ] +[[package]] +name = "objc2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c2599ce0ec54857b29ce62166b0ed9b4f6f1a70ccc9a71165b6154caca8c05" +dependencies = [ + "objc2-encode", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" +dependencies = [ + "bitflags 2.9.1", + "objc2", +] + [[package]] name = "object" version = "0.36.7" @@ -1215,8 +1502,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", ] [[package]] @@ -1226,7 +1523,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", ] [[package]] @@ -1238,6 +1545,15 @@ dependencies = [ "getrandom 0.2.16", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.3", +] + [[package]] name = "redox_syscall" version = "0.5.17" @@ -1317,16 +1633,16 @@ version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", - "hyper-tls", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper-tls 0.5.0", "ipnet", "js-sys", "log", @@ -1339,7 +1655,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "system-configuration", "tokio", "tokio-native-tls", @@ -1353,22 +1669,81 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest" +version = "0.12.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-core", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.8.1", + "hyper-rustls", + "hyper-tls 0.6.0", + "hyper-util", + "js-sys", + "log", + "mime", + 
"native-tls", + "percent-encoding", + "pin-project-lite", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.2", + "tokio", + "tokio-native-tls", + "tower 0.5.2", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rullm-cli" version = "0.1.0" dependencies = [ "anyhow", "atty", + "base64 0.22.1", "chrono", "clap", "clap_complete", "etcetera", "futures", + "hex", "owo-colors", + "rand 0.9.2", "reedline", + "reqwest 0.12.24", "rullm-core", "serde", "serde_json", + "sha2", "strum 0.27.2", "strum_macros 0.27.2", "tempfile", @@ -1376,6 +1751,8 @@ dependencies = [ "toml", "tracing", "tracing-subscriber", + "urlencoding", + "webbrowser", ] [[package]] @@ -1390,8 +1767,8 @@ dependencies = [ "log", "metrics", "once_cell", - "rand", - "reqwest", + "rand 0.8.5", + "reqwest 0.11.27", "serde", "serde_json", "strum 0.27.2", @@ -1401,7 +1778,7 @@ dependencies = [ "tokio", "tokio-test", "toml", - "tower", + "tower 0.4.13", "tower-service", "tracing-subscriber", ] @@ -1438,13 +1815,46 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "rustls" +version = "0.23.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64", + "base64 0.21.7", +] + +[[package]] +name = "rustls-pki-types" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "708c0f9d5f54ba0272468c1d306a52c495b31fa155e91bc25371e6df7996908c" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", ] [[package]] @@ -1459,6 +1869,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.27" @@ -1481,7 +1900,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.9.1", - "core-foundation", + "core-foundation 0.9.4", "core-foundation-sys", "libc", "security-framework-sys", @@ -1550,6 +1969,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" 
+dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1685,6 +2115,12 @@ dependencies = [ "syn", ] +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.104" @@ -1702,6 +2138,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + [[package]] name = "synstructure" version = "0.13.2" @@ -1720,7 +2165,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -1847,6 +2292,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.17" @@ -1942,6 +2397,39 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 1.0.2", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" +dependencies = [ + "bitflags 2.9.1", + "bytes", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "iri-string", + "pin-project-lite", + "tower 0.5.2", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -2023,6 +2511,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -2041,6 +2535,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.4" @@ -2052,6 +2552,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -2091,6 +2597,16 
@@ dependencies = [ "memchr", ] +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2209,6 +2725,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webbrowser" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00f1243ef785213e3a32fa0396093424a3a6ea566f9948497e5a2309261a4c97" +dependencies = [ + "core-foundation 0.10.1", + "jni", + "log", + "ndk-context", + "objc2", + "objc2-foundation", + "url", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -2225,6 +2757,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.60.2", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -2239,9 +2780,9 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", - "windows-link", - "windows-result", - "windows-strings", + "windows-link 0.1.3", + "windows-result 0.3.4", + "windows-strings 0.4.2", ] [[package]] @@ -2272,13 +2813,39 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link 0.2.1", + "windows-result 0.4.1", + "windows-strings 0.5.1", +] + [[package]] name = "windows-result" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-link", + "windows-link 0.1.3", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link 0.2.1", ] [[package]] @@ -2287,7 +2854,25 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-link", + "windows-link 0.1.3", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link 0.2.1", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", ] [[package]] @@ 
-2326,6 +2911,21 @@ dependencies = [ "windows-targets 0.53.3", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -2363,7 +2963,7 @@ version = "0.53.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" dependencies = [ - "windows-link", + "windows-link 0.1.3", "windows_aarch64_gnullvm 0.53.0", "windows_aarch64_msvc 0.53.0", "windows_i686_gnu 0.53.0", @@ -2374,6 +2974,12 @@ dependencies = [ "windows_x86_64_msvc 0.53.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -2392,6 +2998,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -2410,6 +3022,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -2440,6 +3058,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -2458,6 +3082,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -2476,6 +3106,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -2494,6 +3130,12 @@ 
version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -2611,6 +3253,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zerotrie" version = "0.2.2" diff --git a/crates/rullm-cli/Cargo.toml b/crates/rullm-cli/Cargo.toml index a3606882..96fdf084 100644 --- a/crates/rullm-cli/Cargo.toml +++ b/crates/rullm-cli/Cargo.toml @@ -32,6 +32,15 @@ chrono.workspace = true reedline.workspace = true tempfile.workspace = true +# OAuth dependencies +sha2 = "0.10" +rand = "0.9" +base64 = "0.22" +webbrowser = "1.0" +reqwest = { version = "0.12", features = ["json"] } +urlencoding = "2.1" +hex = "0.4" + [dev-dependencies] tempfile.workspace = true diff --git a/crates/rullm-cli/src/api_keys.rs b/crates/rullm-cli/src/api_keys.rs deleted file mode 100644 index f368bdeb..00000000 --- a/crates/rullm-cli/src/api_keys.rs +++ /dev/null @@ -1,87 +0,0 @@ -use crate::provider::Provider; -use rullm_core::error::LlmError; -use serde::{Deserialize, Serialize}; -use std::path::Path; - -#[derive(Default, Deserialize, Serialize, Debug, Clone)] -pub struct ApiKeys { - pub openai_api_key: Option, - pub groq_api_key: Option, - pub openrouter_api_key: Option, - pub anthropic_api_key: Option, - pub google_ai_api_key: Option, -} - -impl ApiKeys { - /// Load API keys from a TOML file - pub fn load_from_file>(path: P) -> Result { - let path = path.as_ref(); - - if !path.exists() { - return Ok(Self::default()); - } - - let content = std::fs::read_to_string(path) - .map_err(|e| LlmError::validation(format!("Failed to read API keys config: {e}")))?; - - // Handle empty files gracefully - if content.trim().is_empty() { - return Ok(Self::default()); - } - - toml::from_str(&content) - .map_err(|e| LlmError::validation(format!("Failed to parse API keys config: {e}"))) - } - - /// Save API keys to a TOML file - pub fn save_to_file>(&self, path: P) -> Result<(), LlmError> { - let path = path.as_ref(); - - // Create directory if it doesn't exist - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - LlmError::validation(format!("Failed to create data directory: {e}")) - })?; - } - - let content = toml::to_string_pretty(self).map_err(|e| { - LlmError::validation(format!("Failed to serialize API keys config: {e}")) - })?; - - std::fs::write(path, content) - .map_err(|e| LlmError::validation(format!("Failed to write API keys config: {e}"))) - } - - pub fn get_api_key(provider: &Provider, api_keys: &ApiKeys) -> Option { - let key = match provider { - Provider::OpenAI => api_keys.openai_api_key.as_ref(), - Provider::Groq => api_keys.groq_api_key.as_ref(), - Provider::OpenRouter => api_keys.openrouter_api_key.as_ref(), - Provider::Anthropic => api_keys.anthropic_api_key.as_ref(), - Provider::Google => api_keys.google_ai_api_key.as_ref(), - }; - - key.cloned() - .or_else(|| std::env::var(provider.env_key()).ok()) - } - - pub fn set_api_key_for_provider(provider: &Provider, api_keys: &mut ApiKeys, key: &str) { - match provider { - 
Provider::OpenAI => api_keys.openai_api_key = Some(key.to_string()), - Provider::Groq => api_keys.groq_api_key = Some(key.to_string()), - Provider::OpenRouter => api_keys.openrouter_api_key = Some(key.to_string()), - Provider::Anthropic => api_keys.anthropic_api_key = Some(key.to_string()), - Provider::Google => api_keys.google_ai_api_key = Some(key.to_string()), - } - } - - pub fn delete_api_key_for_provider(provider: &Provider, api_keys: &mut ApiKeys) { - match provider { - Provider::OpenAI => api_keys.openai_api_key = None, - Provider::Groq => api_keys.groq_api_key = None, - Provider::OpenRouter => api_keys.openrouter_api_key = None, - Provider::Anthropic => api_keys.anthropic_api_key = None, - Provider::Google => api_keys.google_ai_api_key = None, - } - } -} diff --git a/crates/rullm-cli/src/args.rs b/crates/rullm-cli/src/args.rs index 039ab818..87dfa3c4 100644 --- a/crates/rullm-cli/src/args.rs +++ b/crates/rullm-cli/src/args.rs @@ -6,11 +6,11 @@ use clap_complete::CompletionCandidate; use clap_complete::engine::ArgValueCompleter; use std::ffi::OsStr; -use crate::api_keys::ApiKeys; +use crate::auth::AuthConfig; use crate::commands::models::load_models_cache; use crate::commands::{Commands, ModelsCache}; use crate::config::{self, Config}; -use crate::constants::{BINARY_NAME, KEYS_CONFIG_FILE}; +use crate::constants::BINARY_NAME; use crate::templates::TemplateStore; // Example strings for after_long_help @@ -85,7 +85,7 @@ pub struct CliConfig { pub data_base_path: PathBuf, pub config: Config, pub models: Models, - pub api_keys: ApiKeys, + pub auth_config: AuthConfig, } impl CliConfig { @@ -97,22 +97,16 @@ impl CliConfig { let config = config::Config::load(&config_base_path).unwrap(); let models = Models::load(&data_base_path).unwrap(); - let api_keys = - ApiKeys::load_from_file(config_base_path.join(KEYS_CONFIG_FILE)).unwrap_or_default(); + let auth_config = AuthConfig::load(&config_base_path).unwrap_or_default(); Self { config_base_path, data_base_path, config, models, - api_keys, + auth_config, } } - - pub fn save_api_keys(&self) -> Result<(), rullm_core::error::LlmError> { - let keys_path = self.config_base_path.join(KEYS_CONFIG_FILE); - self.api_keys.save_to_file(&keys_path) - } } #[derive(Parser)] diff --git a/crates/rullm-cli/src/auth.rs b/crates/rullm-cli/src/auth.rs new file mode 100644 index 00000000..fe78c7a1 --- /dev/null +++ b/crates/rullm-cli/src/auth.rs @@ -0,0 +1,407 @@ +//! Authentication credential management for rullm. +//! +//! Supports multiple authentication methods per provider: +//! - OAuth (for Claude Max/Pro subscriptions) +//! - API keys (traditional method) + +use crate::provider::Provider; +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::path::{Path, PathBuf}; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// File name for auth credentials +pub const AUTH_CONFIG_FILE: &str = "auth.toml"; + +/// Buffer time (in ms) before token expiration to trigger refresh +const TOKEN_EXPIRY_BUFFER_MS: u64 = 5 * 60 * 1000; // 5 minutes + +/// A credential for authenticating with an LLM provider. 
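Given the `#[serde(tag = "type", rename_all = "lowercase")]` attributes on the enum below, the on-disk shape written to `auth.toml` can be previewed with a small round-trip. The mirror enum here is a sketch for illustration, not the crate's own type:

```rust
use serde::{Deserialize, Serialize};

// Mirror of the `Credential` enum below, to show the on-disk shape
// produced by `#[serde(tag = "type", rename_all = "lowercase")]`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
enum Credential {
    OAuth { access_token: String, refresh_token: String, expires_at: u64 },
    Api { api_key: String },
}

fn main() {
    let cred = Credential::Api { api_key: "sk-test".into() };
    // Prints:
    //   type = "api"
    //   api_key = "sk-test"
    println!("{}", toml::to_string_pretty(&cred).unwrap());
}
```

In `auth.toml` itself, each such table sits under a provider key (e.g. `[anthropic]`), since `AuthConfig` holds one optional `Credential` per provider.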
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum Credential { + /// OAuth credential with access/refresh tokens + OAuth { + access_token: String, + refresh_token: String, + /// Unix timestamp in milliseconds when the access token expires + expires_at: u64, + }, + /// API key credential + Api { api_key: String }, +} + +impl Credential { + /// Create a new OAuth credential + pub fn oauth(access_token: String, refresh_token: String, expires_at: u64) -> Self { + Self::OAuth { + access_token, + refresh_token, + expires_at, + } + } + + /// Create a new API key credential + pub fn api(api_key: String) -> Self { + Self::Api { api_key } + } + + /// Get the access token or API key for use in requests + pub fn get_token(&self) -> &str { + match self { + Self::OAuth { access_token, .. } => access_token, + Self::Api { api_key } => api_key, + } + } + + /// Check if an OAuth token is expired or about to expire + pub fn is_expired(&self) -> bool { + match self { + Self::OAuth { expires_at, .. } => { + let now_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + *expires_at <= now_ms + TOKEN_EXPIRY_BUFFER_MS + } + Self::Api { .. } => false, // API keys don't expire + } + } + + /// Get the refresh token if this is an OAuth credential + pub fn refresh_token(&self) -> Option<&str> { + match self { + Self::OAuth { refresh_token, .. } => Some(refresh_token), + Self::Api { .. } => None, + } + } + + /// Return a display string for the credential type + pub fn type_display(&self) -> &'static str { + match self { + Self::OAuth { .. } => "oauth", + Self::Api { .. } => "api", + } + } +} + +/// Authentication configuration containing credentials for all providers. 
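The five-minute `TOKEN_EXPIRY_BUFFER_MS` above means an OAuth token is treated as expired, and therefore refreshed, shortly before it actually lapses. Restating the `is_expired` check on its own:

```rust
use std::time::{SystemTime, UNIX_EPOCH};

const TOKEN_EXPIRY_BUFFER_MS: u64 = 5 * 60 * 1000;

// Standalone restatement of the `is_expired` check above: a token counts
// as expired once `now + buffer` reaches `expires_at`, triggering a
// refresh up to five minutes before the real expiry.
fn is_expired(expires_at: u64) -> bool {
    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_millis() as u64)
        .unwrap_or(0);
    expires_at <= now_ms + TOKEN_EXPIRY_BUFFER_MS
}

fn main() {
    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_millis() as u64;
    assert!(is_expired(now_ms + 60_000)); // expires in 1 min -> refresh now
    assert!(!is_expired(now_ms + 600_000)); // expires in 10 min -> still valid
    println!("buffer check ok");
}
```

The `unwrap_or(0)` means a pre-epoch system clock degrades to "treat long-lived tokens as valid" rather than panicking.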
+#[derive(Debug, Default, Serialize, Deserialize)] +pub struct AuthConfig { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub anthropic: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub openai: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub groq: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub openrouter: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub google: Option, +} + +impl AuthConfig { + /// Load auth config from the default location + pub fn load(config_base_path: &Path) -> Result { + let path = config_base_path.join(AUTH_CONFIG_FILE); + Self::load_from_file(&path) + } + + /// Load auth config from a specific file path + pub fn load_from_file(path: &Path) -> Result { + if !path.exists() { + return Ok(Self::default()); + } + + let content = fs::read_to_string(path) + .with_context(|| format!("Failed to read auth config from {}", path.display()))?; + + if content.trim().is_empty() { + return Ok(Self::default()); + } + + toml::from_str(&content) + .with_context(|| format!("Failed to parse auth config from {}", path.display())) + } + + /// Save auth config to the default location + pub fn save(&self, config_base_path: &Path) -> Result<()> { + let path = config_base_path.join(AUTH_CONFIG_FILE); + self.save_to_file(&path) + } + + /// Save auth config to a specific file path + pub fn save_to_file(&self, path: &Path) -> Result<()> { + // Create parent directory if needed + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory {}", parent.display()))?; + } + + let content = + toml::to_string_pretty(self).with_context(|| "Failed to serialize auth config")?; + + fs::write(path, &content) + .with_context(|| format!("Failed to write auth config to {}", path.display()))?; + + // Set file permissions to 0600 (owner read/write only) on Unix + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + fs::set_permissions(path, perms) + .with_context(|| format!("Failed to set permissions on {}", path.display()))?; + } + + Ok(()) + } + + /// Get credential for a provider from config + pub fn get(&self, provider: &Provider) -> Option<&Credential> { + match provider { + Provider::Anthropic => self.anthropic.as_ref(), + Provider::OpenAI => self.openai.as_ref(), + Provider::Groq => self.groq.as_ref(), + Provider::OpenRouter => self.openrouter.as_ref(), + Provider::Google => self.google.as_ref(), + } + } + + /// Get mutable credential for a provider + pub fn get_mut(&mut self, provider: &Provider) -> &mut Option { + match provider { + Provider::Anthropic => &mut self.anthropic, + Provider::OpenAI => &mut self.openai, + Provider::Groq => &mut self.groq, + Provider::OpenRouter => &mut self.openrouter, + Provider::Google => &mut self.google, + } + } + + /// Set credential for a provider + pub fn set(&mut self, provider: &Provider, credential: Credential) { + *self.get_mut(provider) = Some(credential); + } + + /// Remove credential for a provider + pub fn remove(&mut self, provider: &Provider) { + *self.get_mut(provider) = None; + } +} + +/// Get the auth config file path +pub fn auth_config_path(config_base_path: &Path) -> PathBuf { + config_base_path.join(AUTH_CONFIG_FILE) +} + +/// Source of a credential (file or environment variable) +#[derive(Debug, Clone, PartialEq)] +pub enum CredentialSource { + File, + Environment(String), +} + +/// Result of credential 
lookup including the source +#[derive(Debug)] +pub struct CredentialInfo { + pub credential: Credential, + pub source: CredentialSource, +} + +/// Get credential for a provider, checking file first, then environment variable. +/// +/// File credentials take precedence over environment variables. +pub fn get_credential(provider: &Provider, auth_config: &AuthConfig) -> Option { + // Check file credentials first (higher precedence) + if let Some(cred) = auth_config.get(provider) { + return Some(CredentialInfo { + credential: cred.clone(), + source: CredentialSource::File, + }); + } + + // Fall back to environment variable + let env_key = provider.env_key(); + if let Ok(api_key) = std::env::var(env_key) { + return Some(CredentialInfo { + credential: Credential::api(api_key), + source: CredentialSource::Environment(env_key.to_string()), + }); + } + + None +} + +/// Get token and credential type for a provider. +/// +/// Returns (token, is_oauth) where is_oauth is true if the credential is OAuth. +/// This is useful when the caller needs to know the credential type to configure +/// different authentication headers. +pub async fn get_token_with_type( + provider: &Provider, + auth_config: &mut AuthConfig, + config_base_path: &Path, +) -> Result<(String, bool)> { + // Get credential info + let info = get_credential(provider, auth_config) + .ok_or_else(|| anyhow::anyhow!("No credential found for {}", provider))?; + + // If from environment, it's always an API key (not OAuth) + if matches!(info.source, CredentialSource::Environment(_)) { + return Ok((info.credential.get_token().to_string(), false)); + } + + // Check if OAuth token is expired and needs refresh + let credential = if info.credential.is_expired() { + if let Some(refresh_tok) = info.credential.refresh_token() { + eprintln!("OAuth token expired, refreshing..."); + match refresh_oauth_token(provider, refresh_tok).await { + Ok(new_credential) => { + auth_config.set(provider, new_credential.clone()); + auth_config.save(config_base_path)?; + eprintln!("Token refreshed successfully."); + new_credential + } + Err(e) => { + return Err(anyhow::anyhow!( + "OAuth token expired and refresh failed: {}. Please run 'rullm auth login {}'", + e, + provider + )); + } + } + } else { + info.credential + } + } else { + info.credential + }; + + let is_oauth = matches!(credential, Credential::OAuth { .. }); + Ok((credential.get_token().to_string(), is_oauth)) +} + +/// Refresh an OAuth token for a specific provider. +async fn refresh_oauth_token(provider: &Provider, refresh_token: &str) -> Result { + use crate::oauth::anthropic::AnthropicOAuth; + + match provider { + Provider::Anthropic => { + let oauth = AnthropicOAuth::new(); + oauth.refresh_token(refresh_token).await + } + _ => Err(anyhow::anyhow!( + "Provider {} does not support OAuth token refresh", + provider + )), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn test_credential_oauth() { + let cred = Credential::oauth( + "access".to_string(), + "refresh".to_string(), + u64::MAX, // Far future + ); + assert!(matches!(cred, Credential::OAuth { .. })); + assert_eq!(cred.get_token(), "access"); + assert_eq!(cred.refresh_token(), Some("refresh")); + assert!(!cred.is_expired()); + } + + #[test] + fn test_credential_api() { + let cred = Credential::api("sk-test-key".to_string()); + assert!(matches!(cred, Credential::Api { .. 
})); + assert_eq!(cred.get_token(), "sk-test-key"); + assert_eq!(cred.refresh_token(), None); + assert!(!cred.is_expired()); + } + + #[test] + fn test_credential_expired() { + let cred = Credential::oauth( + "access".to_string(), + "refresh".to_string(), + 0, // Long expired + ); + assert!(cred.is_expired()); + } + + #[test] + fn test_auth_config_serialization() { + let config = AuthConfig { + anthropic: Some(Credential::oauth( + "sk-ant-oat01-test".to_string(), + "sk-ant-ort01-test".to_string(), + 1764813330304, + )), + openai: Some(Credential::api("sk-proj-test".to_string())), + ..Default::default() + }; + + let toml_str = toml::to_string_pretty(&config).unwrap(); + + // Verify it can be deserialized back + let parsed: AuthConfig = toml::from_str(&toml_str).unwrap(); + assert!(parsed.anthropic.is_some()); + assert!(parsed.openai.is_some()); + assert!(parsed.groq.is_none()); + } + + #[test] + fn test_auth_config_save_load() { + let temp_dir = TempDir::new().unwrap(); + let config_path = temp_dir.path(); + + let config = AuthConfig { + groq: Some(Credential::api("test-groq-key".to_string())), + ..Default::default() + }; + + config.save(config_path).unwrap(); + + let loaded = AuthConfig::load(config_path).unwrap(); + assert_eq!( + loaded.groq.as_ref().map(|c| c.get_token()), + Some("test-groq-key") + ); + } + + #[test] + fn test_get_credential_file_precedence() { + let config = AuthConfig { + anthropic: Some(Credential::api("file-key".to_string())), + ..Default::default() + }; + + // File credential should be returned + let info = get_credential(&Provider::Anthropic, &config).unwrap(); + assert_eq!(info.source, CredentialSource::File); + assert_eq!(info.credential.get_token(), "file-key"); + } + + #[tokio::test] + async fn test_get_token_with_type_api_is_not_oauth() { + let temp_dir = TempDir::new().unwrap(); + let mut config = AuthConfig { + openai: Some(Credential::api("sk-test".to_string())), + ..Default::default() + }; + + let (token, is_oauth) = + get_token_with_type(&Provider::OpenAI, &mut config, temp_dir.path()) + .await + .unwrap(); + + assert_eq!(token, "sk-test"); + assert!(!is_oauth); + } +} diff --git a/crates/rullm-cli/src/cli_client.rs b/crates/rullm-cli/src/cli_client.rs index d221be52..743105fb 100644 --- a/crates/rullm-cli/src/cli_client.rs +++ b/crates/rullm-cli/src/cli_client.rs @@ -4,12 +4,38 @@ //! basic chat operations without exposing the full complexity of each provider's API. 
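One behavioral detail of `get_credential` above worth keeping in mind: a credential stored in `auth.toml` always wins over the provider's environment variable, which is consulted only as a fallback. A reduced model of that ordering:

```rust
// Reduced model of the lookup order in `get_credential`: the auth.toml
// entry (if any) takes precedence; the environment variable is a fallback.
fn lookup(file_cred: Option<&str>, env_key: &str) -> Option<(String, &'static str)> {
    if let Some(cred) = file_cred {
        return Some((cred.to_string(), "file"));
    }
    std::env::var(env_key).ok().map(|key| (key, "environment"))
}

fn main() {
    // PATH stands in for e.g. ANTHROPIC_API_KEY only because it is
    // almost always set; the values are illustrative.
    assert_eq!(lookup(Some("file-key"), "PATH").unwrap().1, "file");
    assert_eq!(lookup(None, "PATH").unwrap().1, "environment");
    println!("precedence ok");
}
```

Environment credentials are also always treated as plain API keys, which is why `get_token_with_type` short-circuits the OAuth refresh branch for them.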
use futures::StreamExt; -use rullm_core::config::{AnthropicConfig, GoogleAiConfig, OpenAICompatibleConfig, OpenAIConfig}; use rullm_core::error::LlmError; -use rullm_core::providers::openai_compatible::{OpenAICompatibleProvider, identities}; +use rullm_core::providers::anthropic::AnthropicConfig; +use rullm_core::providers::google::GoogleAiConfig; +use rullm_core::providers::openai_compatible::{ + OpenAICompatibleConfig, OpenAICompatibleProvider, OpenAIConfig, identities, +}; use rullm_core::providers::{AnthropicClient, GoogleClient, OpenAIClient}; use std::pin::Pin; +/// Claude Code identification text for OAuth requests +const CLAUDE_CODE_SPOOF_TEXT: &str = "You are Claude Code, Anthropic's official CLI for Claude."; + +/// Prepend Claude Code system block to an existing system prompt (for OAuth requests) +fn prepend_claude_code_system( + existing: Option, +) -> rullm_core::providers::anthropic::SystemPrompt { + use rullm_core::providers::anthropic::{SystemBlock, SystemPrompt}; + + let spoof_block = SystemBlock::text_with_cache(CLAUDE_CODE_SPOOF_TEXT); + + match existing { + None => SystemPrompt::Blocks(vec![spoof_block]), + Some(SystemPrompt::Text(text)) => { + SystemPrompt::Blocks(vec![spoof_block, SystemBlock::text(text)]) + } + Some(SystemPrompt::Blocks(mut blocks)) => { + blocks.insert(0, spoof_block); + SystemPrompt::Blocks(blocks) + } + } +} + /// Simple configuration for CLI adapter #[derive(Debug, Clone, Default)] pub struct CliConfig { @@ -28,6 +54,7 @@ pub enum CliClient { client: AnthropicClient, model: String, config: CliConfig, + is_oauth: bool, }, Google { client: GoogleClient, @@ -67,13 +94,15 @@ impl CliClient { api_key: impl Into, model: impl Into, config: CliConfig, + use_oauth: bool, ) -> Result { - let client_config = AnthropicConfig::new(api_key); + let client_config = AnthropicConfig::new(api_key).with_oauth(use_oauth); let client = AnthropicClient::new(client_config)?; Ok(Self::Anthropic { client, model: model.into(), config, + is_oauth: use_oauth, }) } @@ -159,6 +188,7 @@ impl CliClient { client, model, config, + is_oauth, } => { use rullm_core::providers::anthropic::{Message, MessagesRequest}; @@ -170,6 +200,10 @@ impl CliClient { request.temperature = Some(temp); } + if *is_oauth { + request.system = Some(prepend_claude_code_system(request.system.take())); + } + let response = client.messages(request).await?; let content = response .content @@ -315,6 +349,7 @@ impl CliClient { client, model, config, + is_oauth, } => { use rullm_core::providers::anthropic::{Message, MessagesRequest}; @@ -335,6 +370,10 @@ impl CliClient { request.temperature = Some(temp); } + if *is_oauth { + request.system = Some(prepend_claude_code_system(request.system.take())); + } + let stream = client.messages_stream(request).await?; Ok(Box::pin(stream.filter_map(|event_result| async move { match event_result { @@ -447,6 +486,18 @@ impl CliClient { } } + /// Get available models for the provider + pub async fn available_models(&self) -> Result, LlmError> { + match self { + Self::OpenAI { client, .. } => client.list_models().await, + Self::Anthropic { client, .. } => client.list_models().await, + Self::Google { client, .. } => client.list_models().await, + Self::Groq { client, .. } | Self::OpenRouter { client, .. 
} => { + client.available_models().await + } + } + } + /// Get provider name pub fn provider_name(&self) -> &'static str { match self { diff --git a/crates/rullm-cli/src/client.rs b/crates/rullm-cli/src/client.rs index 69ce9504..98c7e2a5 100644 --- a/crates/rullm-cli/src/client.rs +++ b/crates/rullm-cli/src/client.rs @@ -1,8 +1,7 @@ use super::provider::Provider; -use crate::api_keys::ApiKeys; use crate::args::{Cli, CliConfig}; +use crate::auth; use crate::cli_client::{CliClient, CliConfig as CoreCliConfig}; -use crate::constants; use anyhow::{Context, Result}; use rullm_core::LlmError; @@ -13,6 +12,7 @@ pub fn create_client( _base_url: Option<&str>, cli: &Cli, model_name: &str, + is_oauth: bool, ) -> Result { // Build CoreCliConfig based on CLI args let mut config = CoreCliConfig::default(); @@ -40,30 +40,47 @@ pub fn create_client( Provider::OpenAI => CliClient::openai(api_key, model_name, config), Provider::Groq => CliClient::groq(api_key, model_name, config), Provider::OpenRouter => CliClient::openrouter(api_key, model_name, config), - Provider::Anthropic => CliClient::anthropic(api_key, model_name, config), + Provider::Anthropic => CliClient::anthropic(api_key, model_name, config, is_oauth), Provider::Google => CliClient::google(api_key, model_name, config), } } -/// Create a CliClient from a model string, CLI arguments, and configuration -/// This is the promoted version of the create_client_from_model closure from lib.rs -pub fn from_model(model_str: &str, cli: &Cli, cli_config: &CliConfig) -> Result { +/// Create a CliClient from a model string, CLI arguments, and configuration. +/// +/// This function handles OAuth token refresh automatically if the token is expired. +/// The `cli_config` is mutable because refreshing a token requires saving the new credential. +pub async fn from_model( + model_str: &str, + cli: &Cli, + cli_config: &mut CliConfig, +) -> Result { // Use the global alias resolver for CLI functionality - let resolver = crate::aliases::get_global_alias_resolver(&cli_config.config_base_path); - let resolver = resolver - .read() - .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on global resolver"))?; - let (provider, model_name) = resolver - .resolve(model_str) - .context("Invalid model format")?; + // Resolve provider and model inside a block so the lock is dropped before the await + let (provider, model_name) = { + let resolver = crate::aliases::get_global_alias_resolver(&cli_config.config_base_path); + let resolver = resolver + .read() + .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on global resolver"))?; + resolver + .resolve(model_str) + .context("Invalid model format")? + }; - let api_key = ApiKeys::get_api_key(&provider, &cli_config.api_keys).ok_or_else(|| { + // Get token with automatic refresh for OAuth, including credential type + let (token, is_oauth) = auth::get_token_with_type( + &provider, + &mut cli_config.auth_config, + &cli_config.config_base_path, + ) + .await + .map_err(|e| { anyhow::anyhow!( - "API key required. Set {} environment variable or add it to {} in config directory", - provider.env_key(), - constants::CONFIG_FILE_NAME + "{}. 
+            e,
+            provider,
+            provider.env_key()
         )
     })?;
 
-    create_client(&provider, &api_key, None, cli, &model_name).map_err(anyhow::Error::from)
+    create_client(&provider, &token, None, cli, &model_name, is_oauth).map_err(anyhow::Error::from)
 }
diff --git a/crates/rullm-cli/src/commands/auth.rs b/crates/rullm-cli/src/commands/auth.rs
new file mode 100644
index 00000000..377385e1
--- /dev/null
+++ b/crates/rullm-cli/src/commands/auth.rs
@@ -0,0 +1,313 @@
+//! Auth command handlers for managing credentials.
+
+use anyhow::Result;
+use clap::{Args, Subcommand};
+use etcetera::BaseStrategy;
+use strum::IntoEnumIterator;
+
+use crate::auth::{self, AuthConfig, Credential};
+use crate::oauth::anthropic::AnthropicOAuth;
+use crate::output::OutputLevel;
+use crate::provider::Provider;
+
+#[derive(Args)]
+pub struct AuthArgs {
+    #[command(subcommand)]
+    pub action: AuthAction,
+}
+
+#[derive(Subcommand)]
+pub enum AuthAction {
+    /// Login to a provider (OAuth or API key)
+    Login {
+        /// Provider name (anthropic, openai, groq, openrouter, google)
+        provider: Option<Provider>,
+    },
+    /// Logout from a provider (remove stored credentials)
+    Logout {
+        /// Provider name (anthropic, openai, groq, openrouter, google)
+        provider: Option<Provider>,
+    },
+    /// List all credentials and environment variables
+    #[command(alias = "ls")]
+    List,
+}
+
+/// Authentication method selection.
+#[derive(Debug, Clone, Copy)]
+pub enum AuthMethod {
+    OAuth,
+    ApiKey,
+}
+
+impl AuthArgs {
+    pub async fn run(
+        &self,
+        output_level: OutputLevel,
+        config_base_path: &std::path::Path,
+    ) -> Result<()> {
+        match &self.action {
+            AuthAction::Login { provider } => {
+                let provider = match provider {
+                    Some(p) => p.clone(),
+                    None => select_provider()?,
+                };
+
+                // Determine available auth methods for the provider
+                let method = select_auth_method(&provider)?;
+
+                match method {
+                    AuthMethod::OAuth => {
+                        let credential = match provider {
+                            Provider::Anthropic => {
+                                let oauth = AnthropicOAuth::new();
+                                oauth.login().await?
+                            }
+                            Provider::OpenAI => {
+                                anyhow::bail!(
+                                    "OpenAI OAuth login is not implemented yet. Use API key instead."
+                                );
+                            }
+                            _ => {
+                                anyhow::bail!(
+                                    "OAuth is not supported for {}. Use API key instead.",
+                                    provider
+                                );
+                            }
+                        };
+
+                        // Save the credential
+                        let mut auth_config = AuthConfig::load(config_base_path)?;
+                        auth_config.set(&provider, credential);
+                        auth_config.save(config_base_path)?;
+
+                        crate::output::success(
+                            &format!("Successfully logged in to {provider}"),
+                            output_level,
+                        );
+                    }
+                    AuthMethod::ApiKey => {
+                        let api_key = prompt_api_key(&provider)?;
+
+                        if api_key.is_empty() {
+                            anyhow::bail!("API key cannot be empty");
+                        }
+
+                        let mut auth_config = AuthConfig::load(config_base_path)?;
+                        auth_config.set(&provider, Credential::api(api_key));
+                        auth_config.save(config_base_path)?;
+
+                        crate::output::success(
+                            &format!("API key for {provider} has been saved"),
+                            output_level,
+                        );
+                    }
+                }
+            }
+
+            AuthAction::Logout { provider } => {
+                let provider = match provider {
+                    Some(p) => p.clone(),
+                    None => select_provider()?,
+                };
+
+                let mut auth_config = AuthConfig::load(config_base_path)?;
+                auth_config.remove(&provider);
+                auth_config.save(config_base_path)?;
+
+                crate::output::success(&format!("Logged out from {provider}"), output_level);
+            }
+
+            AuthAction::List => {
+                let auth_config = AuthConfig::load(config_base_path)?;
+                print_credentials_list(&auth_config, output_level);
+            }
+        }
+
+        Ok(())
+    }
+}
+
+/// Select a provider interactively.
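+///
+/// Lists every `Provider` variant, then reads a 1-based index from stdin and
+/// maps it back to the chosen variant.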
+fn select_provider() -> Result<Provider> {
+    use std::io::{self, Write};
+
+    println!("\n? Select provider");
+    let providers: Vec<Provider> = Provider::iter().collect();
+
+    for (i, provider) in providers.iter().enumerate() {
+        println!("  {}) {}", i + 1, format_provider_display(provider));
+    }
+
+    print!("\nEnter number (1-{}): ", providers.len());
+    io::stdout().flush()?;
+
+    let mut input = String::new();
+    io::stdin().read_line(&mut input)?;
+
+    let choice: usize = input
+        .trim()
+        .parse()
+        .map_err(|_| anyhow::anyhow!("Invalid selection"))?;
+
+    if choice == 0 || choice > providers.len() {
+        anyhow::bail!("Invalid selection");
+    }
+
+    Ok(providers[choice - 1].clone())
+}
+
+/// Select authentication method for a provider.
+fn select_auth_method(provider: &Provider) -> Result<AuthMethod> {
+    use std::io::{self, Write};
+
+    match provider {
+        Provider::Anthropic => {
+            println!("\n? Select authentication method");
+            println!("  1) OAuth (subscription-based access)");
+            println!("  2) API Key");
+
+            print!("\nEnter number (1-2): ");
+            io::stdout().flush()?;
+
+            let mut input = String::new();
+            io::stdin().read_line(&mut input)?;
+
+            match input.trim() {
+                "1" => Ok(AuthMethod::OAuth),
+                "2" => Ok(AuthMethod::ApiKey),
+                _ => anyhow::bail!("Invalid selection"),
+            }
+        }
+        Provider::OpenAI => {
+            println!("\n? Select authentication method");
+            println!("  1) OAuth (not implemented)");
+            println!("  2) API Key");
+
+            print!("\nEnter number (1-2): ");
+            io::stdout().flush()?;
+
+            let mut input = String::new();
+            io::stdin().read_line(&mut input)?;
+
+            match input.trim() {
+                "1" => {
+                    anyhow::bail!("OpenAI OAuth login is not implemented yet. Use API key instead.")
+                }
+                "2" => Ok(AuthMethod::ApiKey),
+                _ => anyhow::bail!("Invalid selection"),
+            }
+        }
+        _ => Ok(AuthMethod::ApiKey),
+    }
+}
+
+/// Prompt for API key input.
+fn prompt_api_key(provider: &Provider) -> Result<String> {
+    use std::io::{self, Write};
+
+    print!("Enter API key for {provider}: ");
+    io::stdout().flush()?;
+
+    let mut input = String::new();
+    io::stdin().read_line(&mut input)?;
+
+    Ok(input.trim().to_string())
+}
+
+/// Format provider name for display.
+fn format_provider_display(provider: &Provider) -> &'static str {
+    match provider {
+        Provider::Anthropic => "Anthropic",
+        Provider::OpenAI => "OpenAI",
+        Provider::Groq => "Groq",
+        Provider::OpenRouter => "OpenRouter",
+        Provider::Google => "Google",
+    }
+}
+
+/// Print the credentials list in a nice format.
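+///
+/// File-backed credentials are listed first, then providers configured only
+/// via environment variables, followed by a summary count.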
+fn print_credentials_list(auth_config: &AuthConfig, output_level: OutputLevel) { + let mut file_creds: Vec<(Provider, String)> = Vec::new(); + let mut env_creds: Vec<(Provider, String)> = Vec::new(); + + for provider in Provider::iter() { + // Check file credentials + if let Some(cred) = auth_config.get(&provider) { + file_creds.push((provider, cred.type_display().to_string())); + } else { + // Check environment variable + let env_key = provider.env_key(); + if std::env::var(env_key).is_ok() { + env_creds.push((provider, env_key.to_string())); + } + } + } + + // Print file credentials section + if !file_creds.is_empty() { + let base_strategy = etcetera::choose_base_strategy().unwrap(); + let config_path = auth::auth_config_path( + &base_strategy + .config_dir() + .join(crate::constants::BINARY_NAME), + ); + + // Display path with ~ for home directory + let display_path = config_path + .strip_prefix(base_strategy.home_dir()) + .map(|p| format!("~/{}", p.display())) + .unwrap_or_else(|_| config_path.display().to_string()); + + crate::output::heading(&format!("Credentials ({display_path}):"), output_level); + + for (provider, cred_type) in &file_creds { + crate::output::note( + &format!( + " {}: {}", + crate::output::format_provider(format_provider_display(provider)), + cred_type + ), + output_level, + ); + } + } + + // Print environment variables section + if !env_creds.is_empty() { + if !file_creds.is_empty() { + crate::output::note("", output_level); + } + crate::output::heading("Environment variables:", output_level); + + for (provider, env_key) in &env_creds { + crate::output::note( + &format!( + " {}: {}", + crate::output::format_provider(format_provider_display(provider)), + env_key + ), + output_level, + ); + } + } + + // Print summary + let total = file_creds.len() + env_creds.len(); + if total > 0 { + crate::output::note("", output_level); + crate::output::note( + &format!("{} credential(s) configured.", total), + output_level, + ); + } else { + crate::output::note("No credentials configured.", output_level); + crate::output::hint( + &format!( + "Run {} to add credentials", + crate::output::format_command("rullm auth login") + ), + output_level, + ); + } +} diff --git a/crates/rullm-cli/src/commands/chat/mod.rs b/crates/rullm-cli/src/commands/chat/mod.rs index 897b0115..3c4af395 100644 --- a/crates/rullm-cli/src/commands/chat/mod.rs +++ b/crates/rullm-cli/src/commands/chat/mod.rs @@ -35,11 +35,11 @@ impl ChatArgs { pub async fn run( &self, _output_level: OutputLevel, - cli_config: &CliConfig, + cli_config: &mut CliConfig, cli: &Cli, ) -> Result<()> { let model_str = resolve_model(&cli.model, &self.model, &cli_config.config.default_model)?; - let client = client::from_model(&model_str, cli, cli_config)?; + let client = client::from_model(&model_str, cli, cli_config).await?; run_interactive_chat(&client, None, cli_config, !cli.no_streaming).await?; Ok(()) } diff --git a/crates/rullm-cli/src/commands/info.rs b/crates/rullm-cli/src/commands/info.rs index a1cb4dcb..cb96f42d 100644 --- a/crates/rullm-cli/src/commands/info.rs +++ b/crates/rullm-cli/src/commands/info.rs @@ -3,6 +3,7 @@ use clap::Args; use crate::{ args::{Cli, CliConfig}, + auth, commands::env_var_status, constants::*, output::OutputLevel, @@ -22,7 +23,7 @@ impl InfoArgs { ) -> Result<()> { let config_path = cli_config.config_base_path.join(CONFIG_FILE_NAME); let models_path = cli_config.data_base_path.join(MODEL_FILE_NAME); - let keys_path = cli_config.config_base_path.join(KEYS_CONFIG_FILE); + let auth_path = 
auth::auth_config_path(&cli_config.config_base_path);
         let templates_path = cli_config.config_base_path.join(TEMPLATES_DIR_NAME);
 
         // crate::output::heading("Config files:", output_level);
@@ -30,7 +31,7 @@ impl InfoArgs {
             &format!("config file: {}", config_path.display()),
             output_level,
         );
-        crate::output::note(&format!("keys file: {}", keys_path.display()), output_level);
+        crate::output::note(&format!("auth file: {}", auth_path.display()), output_level);
         crate::output::note(
             &format!("models cache file: {}", models_path.display()),
             output_level,
diff --git a/crates/rullm-cli/src/commands/keys.rs b/crates/rullm-cli/src/commands/keys.rs
deleted file mode 100644
index 49696e2f..00000000
--- a/crates/rullm-cli/src/commands/keys.rs
+++ /dev/null
@@ -1,118 +0,0 @@
-use anyhow::Result;
-use clap::{Args, Subcommand};
-use strum::IntoEnumIterator;
-
-use crate::provider::Provider;
-use crate::{
-    api_keys::ApiKeys,
-    args::{Cli, CliConfig},
-    output::OutputLevel,
-};
-
-#[derive(Args)]
-pub struct KeysArgs {
-    #[command(subcommand)]
-    pub action: KeysAction,
-}
-
-#[derive(Subcommand)]
-pub enum KeysAction {
-    /// Set an API key for a provider
-    Set {
-        /// Provider name (openai, anthropic, google)
-        provider: Provider,
-        /// API key (if not provided, will read from stdin)
-        #[arg(short, long)]
-        key: Option<String>,
-    },
-    /// Delete an API key for a provider
-    Delete {
-        /// Provider name (openai, anthropic, google)
-        provider: Provider,
-    },
-    /// List which providers have API keys set
-    List,
-}
-
-impl KeysArgs {
-    pub async fn run(
-        &self,
-        output_level: OutputLevel,
-        cli_config: &mut CliConfig,
-        _cli: &Cli,
-    ) -> Result<()> {
-        match &self.action {
-            KeysAction::Set { provider, key } => {
-                let api_key = if let Some(key) = key {
-                    key.clone()
-                } else {
-                    use std::io::{self, Write};
-                    print!("Enter API key for {provider}: ");
-                    io::stdout().flush()?;
-
-                    let mut input = String::new();
-                    io::stdin().read_line(&mut input)?;
-                    input.trim().to_string()
-                };
-
-                if api_key.is_empty() {
-                    return Err(anyhow::anyhow!("API key cannot be empty"));
-                }
-
-                let api_keys = &mut cli_config.api_keys;
-                ApiKeys::set_api_key_for_provider(provider, api_keys, &api_key);
-                cli_config.save_api_keys()?;
-
-                crate::output::success(
-                    &format!("API key for {provider} has been saved"),
-                    output_level,
-                );
-            }
-            KeysAction::Delete { provider } => {
-                let api_keys = &mut cli_config.api_keys;
-                ApiKeys::delete_api_key_for_provider(provider, api_keys);
-                cli_config.save_api_keys()?;
-
-                crate::output::success(
-                    &format!("API key for {provider} has been deleted"),
-                    output_level,
-                );
-            }
-            KeysAction::List => {
-                let api_keys = cli_config.api_keys.clone();
-
-                for provider in Provider::iter() {
-                    let has_cli_key = match provider {
-                        Provider::OpenAI => api_keys.openai_api_key.is_some(),
-                        Provider::Groq => api_keys.groq_api_key.is_some(),
-                        Provider::OpenRouter => api_keys.openrouter_api_key.is_some(),
-                        Provider::Anthropic => api_keys.anthropic_api_key.is_some(),
-                        Provider::Google => api_keys.google_ai_api_key.is_some(),
-                    };
-
-                    let has_env_key = std::env::var(provider.env_key()).is_ok();
-
-                    let source_info = if has_cli_key {
-                        Some("cli".to_string())
-                    } else if has_env_key {
-                        Some(format!("env ({})", provider.env_key()))
-                    } else {
-                        None
-                    };
-
-                    if let Some(source) = source_info {
-                        crate::output::note(
-                            &format!(
-                                "{}: {}",
-                                crate::output::format_provider(&provider.to_string()),
-                                source
-                            ),
-                            output_level,
-                        );
-                    }
-                }
-            }
-        }
-        Ok(())
-    }
-}
diff --git a/crates/rullm-cli/src/commands/mod.rs
b/crates/rullm-cli/src/commands/mod.rs index 5439a214..718f5453 100644 --- a/crates/rullm-cli/src/commands/mod.rs +++ b/crates/rullm-cli/src/commands/mod.rs @@ -12,20 +12,20 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; pub mod alias; +pub mod auth; pub mod chat; pub mod completions; pub mod info; pub mod templates; -pub mod keys; pub mod models; // Re-export the command args structs pub use alias::AliasArgs; +pub use auth::AuthArgs; pub use chat::ChatArgs; pub use completions::CompletionsArgs; pub use info::InfoArgs; -pub use keys::KeysArgs; pub use models::ModelsArgs; // Example strings for after_long_help @@ -41,11 +41,11 @@ const MODELS_EXAMPLES: &str = r#"EXAMPLES: rullm models default openai/gpt-4o # Set default model rullm models clear # Clear model cache"#; -const KEYS_EXAMPLES: &str = r#"EXAMPLES: - rullm keys set openai # Set OpenAI API key (prompted) - rullm keys set anthropic -k sk-ant-... # Set Anthropic key directly - rullm keys list # Show which providers have keys - rullm keys delete google # Remove Google API key"#; +const AUTH_EXAMPLES: &str = r#"EXAMPLES: + rullm auth login # Login interactively + rullm auth login anthropic # Login to Anthropic (OAuth or API key) + rullm auth logout openai # Logout from OpenAI + rullm auth list # Show all credentials"#; const ALIAS_EXAMPLES: &str = r#"EXAMPLES: rullm alias list # Show all available aliases @@ -72,9 +72,9 @@ pub enum Commands { /// Show configuration and system information #[command(after_long_help = INFO_EXAMPLES)] Info(InfoArgs), - /// Manage API keys - #[command(after_long_help = KEYS_EXAMPLES)] - Keys(KeysArgs), + /// Manage authentication credentials + #[command(after_long_help = AUTH_EXAMPLES)] + Auth(AuthArgs), /// Manage model aliases #[command(after_long_help = ALIAS_EXAMPLES)] Alias(AliasArgs), diff --git a/crates/rullm-cli/src/commands/models.rs b/crates/rullm-cli/src/commands/models.rs index cb0f0cba..f66439ab 100644 --- a/crates/rullm-cli/src/commands/models.rs +++ b/crates/rullm-cli/src/commands/models.rs @@ -76,7 +76,7 @@ impl ModelsArgs { let provider = format!("{provider}"); // Try to create a client for this provider let model_hint = format!("{provider}:dummy"); // dummy model name, just to get the client - let client = match client::from_model(&model_hint, cli, cli_config) { + let client = match client::from_model(&model_hint, cli, cli_config).await { Ok(c) => c, Err(_) => { skipped.push(provider); @@ -217,7 +217,7 @@ pub fn clear_models_cache(cli_config: &CliConfig, output_level: OutputLevel) -> } pub async fn update_models( - _cli_config: &mut CliConfig, + cli_config: &mut CliConfig, client: &CliClient, output_level: OutputLevel, ) -> Result<(), LlmError> { @@ -229,15 +229,36 @@ pub async fn update_models( output_level, ); - // TODO: Implement models() method on CliClient - // For now, just return an error - crate::output::error( - "Fetching models not yet implemented for new client architecture", + let mut models = client.available_models().await.map_err(|e| { + crate::output::error(&format!("Failed to fetch models: {e}"), output_level); + e + })?; + + if models.is_empty() { + crate::output::error("No models returned by provider", output_level); + return Err(LlmError::model( + "No models returned by provider".to_string(), + )); + } + + models.sort(); + models.dedup(); + + _cache_models(cli_config, client.provider_name(), &models).map_err(|e| { + crate::output::error(&format!("Failed to update models cache: {e}"), output_level); + LlmError::unknown(e.to_string()) + })?; + + 
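+    // `models` was sorted and deduplicated above, so the count reported below
+    // matches exactly what was written to the cache.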
crate::output::success( + &format!( + "Updated {} models for {}", + models.len(), + client.provider_name() + ), output_level, ); - Err(LlmError::unknown( - "Models fetching not yet implemented".to_string(), - )) + + Ok(()) } fn _cache_models(cli_config: &CliConfig, provider_name: &str, models: &[String]) -> Result<()> { diff --git a/crates/rullm-cli/src/constants.rs b/crates/rullm-cli/src/constants.rs index b0f35475..3b55f925 100644 --- a/crates/rullm-cli/src/constants.rs +++ b/crates/rullm-cli/src/constants.rs @@ -1,6 +1,5 @@ pub const CONFIG_FILE_NAME: &str = "config.toml"; pub const MODEL_FILE_NAME: &str = "models.json"; pub const ALIASES_CONFIG_FILE: &str = "aliases.toml"; -pub const KEYS_CONFIG_FILE: &str = "keys.toml"; pub const TEMPLATES_DIR_NAME: &str = "templates"; pub const BINARY_NAME: &str = env!("CARGO_BIN_NAME"); diff --git a/crates/rullm-cli/src/main.rs b/crates/rullm-cli/src/main.rs index e9d3f72e..3758ce4f 100644 --- a/crates/rullm-cli/src/main.rs +++ b/crates/rullm-cli/src/main.rs @@ -1,14 +1,15 @@ // Binary entry point for rullm-cli mod aliases; -mod api_keys; mod args; +mod auth; mod cli_client; mod cli_helpers; mod client; mod commands; mod config; mod constants; +mod oauth; mod output; mod provider; mod spinner; @@ -51,7 +52,7 @@ pub async fn run() -> Result<()> { if cli.model.is_some() { match &cli.command { Some(Commands::Info(_)) - | Some(Commands::Keys(_)) + | Some(Commands::Auth(_)) | Some(Commands::Alias(_)) | Some(Commands::Completions(_)) => { use clap::error::ErrorKind; @@ -81,10 +82,10 @@ pub async fn run() -> Result<()> { // Handle commands match &cli.command { - Some(Commands::Chat(args)) => args.run(output_level, &cli_config, &cli).await?, + Some(Commands::Chat(args)) => args.run(output_level, &mut cli_config, &cli).await?, Some(Commands::Models(args)) => args.run(output_level, &mut cli_config, &cli).await?, Some(Commands::Info(args)) => args.run(output_level, &cli_config, &cli).await?, - Some(Commands::Keys(args)) => args.run(output_level, &mut cli_config, &cli).await?, + Some(Commands::Auth(args)) => args.run(output_level, &cli_config.config_base_path).await?, Some(Commands::Alias(args)) => args.run(output_level, &cli_config, &cli).await?, Some(Commands::Completions(args)) => args.run(output_level, &cli_config, &cli).await?, Some(Commands::Templates(args)) => args.run(output_level, &cli_config, &cli).await?, @@ -92,7 +93,7 @@ pub async fn run() -> Result<()> { if let Some(query) = &cli.query { let model_str = resolve_direct_query_model(&cli.model, &cli_config.config.default_model)?; - let client = client::from_model(&model_str, &cli, &cli_config)?; + let client = client::from_model(&model_str, &cli, &mut cli_config).await?; // Handle template if provided let (system_prompt, final_query) = if let Some(template_name) = &cli.template { diff --git a/crates/rullm-cli/src/oauth/anthropic.rs b/crates/rullm-cli/src/oauth/anthropic.rs new file mode 100644 index 00000000..3f03d75f --- /dev/null +++ b/crates/rullm-cli/src/oauth/anthropic.rs @@ -0,0 +1,259 @@ +//! Anthropic OAuth flow implementation. +//! +//! Supports Claude Max/Pro subscription authentication. + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +use super::PkceChallenge; +use super::server::CallbackServer; +use crate::auth::Credential; + +/// Anthropic OAuth configuration. 
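+///
+/// `Default` (and therefore [`AnthropicOAuth::new`]) fills in the endpoints and
+/// the public client ID used by Claude Code. A minimal construction sketch
+/// (illustrative only):
+///
+/// ```no_run
+/// let oauth = AnthropicOAuth::new();
+/// assert_eq!(oauth.callback_port, 8765);
+/// ```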
+pub struct AnthropicOAuth { + /// Authorization URL + pub authorization_url: &'static str, + /// Token URL + pub token_url: &'static str, + /// Client ID (Claude Code's public ID) + pub client_id: &'static str, + /// Required scopes + pub scopes: &'static [&'static str], + /// Local callback port + pub callback_port: u16, +} + +impl Default for AnthropicOAuth { + fn default() -> Self { + Self { + authorization_url: "https://claude.ai/oauth/authorize", + // The token endpoint lives on the console domain (not the public API) + // and requires the `/v1` prefix; posting to the API host returns 404. + token_url: "https://console.anthropic.com/v1/oauth/token", + client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e", + scopes: &["org:create_api_key", "user:profile", "user:inference"], + callback_port: 8765, + } + } +} + +/// Token response from Anthropic OAuth. +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + refresh_token: String, + expires_in: u64, + #[allow(dead_code)] + token_type: String, +} + +/// Token refresh request body. +#[derive(Debug, Serialize)] +struct RefreshRequest<'a> { + grant_type: &'static str, + client_id: &'a str, + refresh_token: &'a str, +} + +/// Token exchange request body. +#[derive(Debug, Serialize)] +struct TokenRequest<'a> { + grant_type: &'static str, + client_id: &'a str, + code: &'a str, + redirect_uri: &'a str, + code_verifier: &'a str, + state: &'a str, +} + +impl AnthropicOAuth { + /// Create a new Anthropic OAuth handler with default configuration. + pub fn new() -> Self { + Self::default() + } + + /// Build the authorization URL for the OAuth flow. + fn build_authorization_url(&self, redirect_uri: &str, pkce: &PkceChallenge) -> String { + let scope = self.scopes.join(" "); + + format!( + "{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method={}&state={}", + self.authorization_url, + urlencoding::encode(self.client_id), + urlencoding::encode(redirect_uri), + urlencoding::encode(&scope), + urlencoding::encode(&pkce.challenge), + pkce.method(), + urlencoding::encode(&pkce.verifier) // Use verifier as state (like OpenCode) + ) + } + + /// Start the OAuth flow and return the credential on success. + /// + /// This opens a browser for the user to authorize, then captures + /// the callback on a local server. 
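+    ///
+    /// A minimal call sketch (illustrative only); the returned credential is
+    /// what `rullm auth login` persists:
+    ///
+    /// ```no_run
+    /// # async fn demo(oauth: &AnthropicOAuth) -> anyhow::Result<()> {
+    /// let credential = oauth.login().await?;
+    /// # Ok(()) }
+    /// ```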
+    pub async fn login(&self) -> Result<Credential> {
+        // Start local callback server
+        let server =
+            CallbackServer::new(self.callback_port).context("Failed to start callback server")?;
+
+        let redirect_uri = server.redirect_uri();
+
+        // Generate PKCE challenge
+        let pkce = PkceChallenge::generate();
+
+        // Build authorization URL
+        let auth_url = self.build_authorization_url(&redirect_uri, &pkce);
+
+        println!("Opening browser for Anthropic authentication...");
+
+        // Try to open browser, but don't fail if it doesn't work
+        if let Err(e) = webbrowser::open(&auth_url) {
+            println!("Could not open browser automatically: {}", e);
+            println!("Please open this URL manually:");
+            println!("{}", auth_url);
+        }
+
+        // Wait for the callback
+        println!("Waiting for authorization callback...");
+        let callback = server
+            .wait_for_callback(Duration::from_secs(300))
+            .context("Failed to receive OAuth callback")?;
+
+        // Verify state matches (we use verifier as state)
+        if let Some(state) = &callback.state {
+            if state != &pkce.verifier {
+                anyhow::bail!("State mismatch in OAuth callback");
+            }
+        }
+
+        // Exchange code for tokens (state is required by Anthropic's token endpoint)
+        let credential = self
+            .exchange_code(
+                &callback.code,
+                &redirect_uri,
+                &pkce.verifier,
+                &pkce.verifier,
+            )
+            .await?;
+
+        println!("Authentication successful!");
+        Ok(credential)
+    }
+
+    /// Exchange authorization code for tokens.
+    async fn exchange_code(
+        &self,
+        code: &str,
+        redirect_uri: &str,
+        code_verifier: &str,
+        state: &str,
+    ) -> Result<Credential> {
+        let request_body = TokenRequest {
+            grant_type: "authorization_code",
+            client_id: self.client_id,
+            code,
+            redirect_uri,
+            code_verifier,
+            state,
+        };
+
+        let client = reqwest::Client::new();
+        let response = client
+            .post(self.token_url)
+            .json(&request_body)
+            .send()
+            .await
+            .context("Failed to send token request")?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let body = response.text().await.unwrap_or_default();
+            anyhow::bail!("Token exchange failed: {} - {}", status, body);
+        }
+
+        let token_response: TokenResponse = response
+            .json()
+            .await
+            .context("Failed to parse token response")?;
+
+        // Calculate expiration timestamp
+        let expires_at = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .map(|d| d.as_millis() as u64 + token_response.expires_in * 1000)
+            .unwrap_or(0);
+
+        Ok(Credential::oauth(
+            token_response.access_token,
+            token_response.refresh_token,
+            expires_at,
+        ))
+    }
+
+    /// Refresh an expired OAuth token.
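+    ///
+    /// Sketch (illustrative only): trade a stored refresh token for a fresh
+    /// credential with a new expiry.
+    ///
+    /// ```no_run
+    /// # async fn demo(oauth: &AnthropicOAuth, stored_refresh: &str) -> anyhow::Result<()> {
+    /// let renewed = oauth.refresh_token(stored_refresh).await?;
+    /// # Ok(()) }
+    /// ```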
+    pub async fn refresh_token(&self, refresh_token: &str) -> Result<Credential> {
+        let request_body = RefreshRequest {
+            grant_type: "refresh_token",
+            client_id: self.client_id,
+            refresh_token,
+        };
+
+        let client = reqwest::Client::new();
+        let response = client
+            .post(self.token_url)
+            .json(&request_body)
+            .send()
+            .await
+            .context("Failed to send refresh request")?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let body = response.text().await.unwrap_or_default();
+            anyhow::bail!("Token refresh failed: {} - {}", status, body);
+        }
+
+        let token_response: TokenResponse = response
+            .json()
+            .await
+            .context("Failed to parse refresh response")?;
+
+        let expires_at = std::time::SystemTime::now()
+            .duration_since(std::time::UNIX_EPOCH)
+            .map(|d| d.as_millis() as u64 + token_response.expires_in * 1000)
+            .unwrap_or(0);
+
+        Ok(Credential::oauth(
+            token_response.access_token,
+            token_response.refresh_token,
+            expires_at,
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_build_authorization_url() {
+        let oauth = AnthropicOAuth::new();
+        let pkce = PkceChallenge::generate();
+
+        let url = oauth.build_authorization_url("http://localhost:8765/callback", &pkce);
+
+        assert!(url.starts_with("https://claude.ai/oauth/authorize"));
+        assert!(url.contains("response_type=code"));
+        assert!(url.contains("client_id="));
+        assert!(url.contains("redirect_uri="));
+        assert!(url.contains("code_challenge="));
+        assert!(url.contains("code_challenge_method=S256"));
+        assert!(url.contains("state="));
+    }
+
+    #[test]
+    fn test_default_config() {
+        let oauth = AnthropicOAuth::new();
+        assert!(oauth.scopes.contains(&"user:inference"));
+    }
+}
diff --git a/crates/rullm-cli/src/oauth/mod.rs b/crates/rullm-cli/src/oauth/mod.rs
new file mode 100644
index 00000000..e4115303
--- /dev/null
+++ b/crates/rullm-cli/src/oauth/mod.rs
@@ -0,0 +1,10 @@
+//! OAuth authentication module for rullm.
+//!
+//! Provides OAuth 2.0 authentication flows for supported providers.
+
+mod pkce;
+mod server;
+
+pub mod anthropic;
+
+pub use pkce::PkceChallenge;
diff --git a/crates/rullm-cli/src/oauth/pkce.rs b/crates/rullm-cli/src/oauth/pkce.rs
new file mode 100644
index 00000000..7086212d
--- /dev/null
+++ b/crates/rullm-cli/src/oauth/pkce.rs
@@ -0,0 +1,89 @@
+//! PKCE (Proof Key for Code Exchange) implementation for OAuth 2.0.
+//!
+//! Implements RFC 7636 for secure authorization code flow.
+
+use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
+use rand::RngCore;
+use sha2::{Digest, Sha256};
+
+/// PKCE challenge pair consisting of verifier and challenge.
+#[derive(Debug, Clone)]
+pub struct PkceChallenge {
+    /// The code verifier (sent with token request)
+    pub verifier: String,
+    /// The code challenge (sent with authorization request)
+    pub challenge: String,
+}
+
+impl PkceChallenge {
+    /// Generate a new PKCE challenge pair.
+    ///
+    /// Creates a 64-byte random code verifier and derives the S256 challenge from it.
+    pub fn generate() -> Self {
+        // Generate 64 random bytes for the code verifier
+        let mut verifier_bytes = [0u8; 64];
+        rand::rng().fill_bytes(&mut verifier_bytes);
+
+        // Base64url encode the verifier (no padding)
+        let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
+
+        // Create code challenge: base64url(sha256(verifier))
+        let mut hasher = Sha256::new();
+        hasher.update(verifier.as_bytes());
+        let hash = hasher.finalize();
+        let challenge = URL_SAFE_NO_PAD.encode(hash);
+
+        Self {
+            verifier,
+            challenge,
+        }
+    }
+
+    /// Get the challenge method (always "S256").
+    pub fn method(&self) -> &'static str {
+        "S256"
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_pkce_generation() {
+        let pkce = PkceChallenge::generate();
+
+        // Verifier should be base64url encoded 64 bytes = 86 chars
+        assert_eq!(pkce.verifier.len(), 86);
+
+        // Challenge should be base64url encoded SHA256 = 43 chars
+        assert_eq!(pkce.challenge.len(), 43);
+
+        // Method should be S256
+        assert_eq!(pkce.method(), "S256");
+    }
+
+    #[test]
+    fn test_pkce_uniqueness() {
+        let pkce1 = PkceChallenge::generate();
+        let pkce2 = PkceChallenge::generate();
+
+        // Each generation should produce unique values
+        assert_ne!(pkce1.verifier, pkce2.verifier);
+        assert_ne!(pkce1.challenge, pkce2.challenge);
+    }
+
+    #[test]
+    fn test_challenge_derivation() {
+        // Verify that the challenge is correctly derived from verifier
+        let pkce = PkceChallenge::generate();
+
+        // Manually compute the expected challenge
+        let mut hasher = Sha256::new();
+        hasher.update(pkce.verifier.as_bytes());
+        let hash = hasher.finalize();
+        let expected_challenge = URL_SAFE_NO_PAD.encode(hash);
+
+        assert_eq!(pkce.challenge, expected_challenge);
+    }
+}
diff --git a/crates/rullm-cli/src/oauth/server.rs b/crates/rullm-cli/src/oauth/server.rs
new file mode 100644
index 00000000..e93a6360
--- /dev/null
+++ b/crates/rullm-cli/src/oauth/server.rs
@@ -0,0 +1,221 @@
+//! Local HTTP server for OAuth callback handling.
+//!
+//! Starts a temporary server to receive the OAuth authorization code.
+
+use anyhow::{Context, Result, anyhow};
+use std::io::{Read, Write};
+use std::net::TcpListener;
+use std::time::Duration;
+
+/// Result of waiting for an OAuth callback.
+#[derive(Debug)]
+pub struct CallbackResult {
+    /// The authorization code received
+    pub code: String,
+    /// The state parameter (for CSRF verification)
+    pub state: Option<String>,
+}
+
+/// Local callback server for OAuth flows.
+pub struct CallbackServer {
+    listener: TcpListener,
+    port: u16,
+}
+
+impl CallbackServer {
+    /// Create a new callback server on the specified port.
+    pub fn new(port: u16) -> Result<Self> {
+        let addr = format!("127.0.0.1:{port}");
+        let listener =
+            TcpListener::bind(&addr).with_context(|| format!("Failed to bind to {addr}"))?;
+
+        // Get the actual port if 0 was specified
+        let actual_port = listener.local_addr()?.port();
+
+        Ok(Self {
+            listener,
+            port: actual_port,
+        })
+    }
+
+    /// Get the redirect URI for this callback server.
+    pub fn redirect_uri(&self) -> String {
+        format!("http://localhost:{}/callback", self.port)
+    }
+
+    /// Wait for the OAuth callback and extract the authorization code.
+    ///
+    /// This blocks until a request is received or the timeout is reached.
+    pub fn wait_for_callback(&self, timeout: Duration) -> Result<CallbackResult> {
+        // Set non-blocking mode on the listener and poll with timeout
+        self.listener
+            .set_nonblocking(true)
+            .context("Failed to set non-blocking mode")?;
+
+        let start = std::time::Instant::now();
+        let poll_interval = Duration::from_millis(100);
+
+        loop {
+            match self.listener.accept() {
+                Ok((mut stream, _addr)) => {
+                    // Set the stream back to blocking for read/write
+                    stream.set_nonblocking(false).ok();
+                    stream.set_read_timeout(Some(Duration::from_secs(5))).ok();
+
+                    // Read the HTTP request
+                    let mut buffer = [0u8; 4096];
+                    let n = stream
+                        .read(&mut buffer)
+                        .context("Failed to read from connection")?;
+
+                    let request = String::from_utf8_lossy(&buffer[..n]);
+
+                    // Parse the request to extract code and state
+                    let result = Self::parse_callback_request(&request)?;
+
+                    // Send a success response
+                    let response_body = r#"<!DOCTYPE html>
+<html>
+<head><title>Authentication Successful</title></head>
+<body>
+<h1>Authentication successful!</h1>
+<p>You can close this window and return to the terminal.</p>
+</body>
+ +"#; + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + response_body.len(), + response_body + ); + + stream + .write_all(response.as_bytes()) + .context("Failed to send response")?; + + stream.flush().ok(); + + return Ok(result); + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // No connection yet, check timeout + if start.elapsed() >= timeout { + return Err(anyhow!("Timeout waiting for OAuth callback")); + } + std::thread::sleep(poll_interval); + } + Err(e) => { + return Err(e).context("Failed to accept connection"); + } + } + } + } + + /// Parse the callback request to extract code and state parameters. + fn parse_callback_request(request: &str) -> Result { + // Extract the request line (GET /callback?code=xxx&state=yyy HTTP/1.1) + let first_line = request + .lines() + .next() + .ok_or_else(|| anyhow!("Empty request"))?; + + // Extract the path with query string + let parts: Vec<&str> = first_line.split_whitespace().collect(); + if parts.len() < 2 { + return Err(anyhow!("Invalid HTTP request line")); + } + + let path = parts[1]; + + // Check for error response + if let Some(error) = Self::extract_query_param(path, "error") { + let description = Self::extract_query_param(path, "error_description") + .unwrap_or_else(|| "Unknown error".to_string()); + return Err(anyhow!("OAuth error: {} - {}", error, description)); + } + + // Extract the code + let code = Self::extract_query_param(path, "code") + .ok_or_else(|| anyhow!("No authorization code in callback"))?; + + // Extract state (optional) + let state = Self::extract_query_param(path, "state"); + + Ok(CallbackResult { code, state }) + } + + /// Extract a query parameter value from a URL path. + fn extract_query_param(path: &str, param: &str) -> Option { + let query = path.split('?').nth(1)?; + for pair in query.split('&') { + let (key, value) = pair.split_once('=')?; + if key == param { + // URL decode the value + return Some(urlencoding::decode(value).ok()?.into_owned()); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_callback_success() { + let request = "GET /callback?code=abc123&state=xyz789 HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let result = CallbackServer::parse_callback_request(request).unwrap(); + assert_eq!(result.code, "abc123"); + assert_eq!(result.state, Some("xyz789".to_string())); + } + + #[test] + fn test_parse_callback_no_state() { + let request = "GET /callback?code=abc123 HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let result = CallbackServer::parse_callback_request(request).unwrap(); + assert_eq!(result.code, "abc123"); + assert_eq!(result.state, None); + } + + #[test] + fn test_parse_callback_error() { + let request = "GET /callback?error=access_denied&error_description=User%20denied%20access HTTP/1.1\r\n"; + let result = CallbackServer::parse_callback_request(request); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("access_denied")); + } + + #[test] + fn test_parse_callback_no_code() { + let request = "GET /callback?state=xyz HTTP/1.1\r\n"; + let result = CallbackServer::parse_callback_request(request); + assert!(result.is_err()); + } + + #[test] + fn test_extract_query_param() { + let path = "/callback?code=abc&state=def&other=ghi"; + assert_eq!( + CallbackServer::extract_query_param(path, "code"), + Some("abc".to_string()) + ); + assert_eq!( + CallbackServer::extract_query_param(path, "state"), + Some("def".to_string()) + ); + 
+        assert_eq!(CallbackServer::extract_query_param(path, "missing"), None);
+    }
+
+    #[test]
+    fn test_redirect_uri() {
+        // Note: This test requires an available port
+        if let Ok(server) = CallbackServer::new(0) {
+            let uri = server.redirect_uri();
+            assert!(uri.starts_with("http://localhost:"));
+            assert!(uri.ends_with("/callback"));
+        }
+    }
+}
diff --git a/crates/rullm-core/examples/anthropic_stream.rs b/crates/rullm-core/examples/anthropic_stream.rs
index 736781c8..cd74fda8 100644
--- a/crates/rullm-core/examples/anthropic_stream.rs
+++ b/crates/rullm-core/examples/anthropic_stream.rs
@@ -28,11 +28,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     while let Some(event_result) = stream.next().await {
         match event_result {
             Ok(event) => match event {
-                StreamEvent::ContentBlockDelta { delta, .. } => {
-                    if let Delta::TextDelta { text } = delta {
-                        print!("{text}");
-                        std::io::Write::flush(&mut std::io::stdout())?;
-                    }
+                StreamEvent::ContentBlockDelta {
+                    delta: Delta::TextDelta { text },
+                    ..
+                } => {
+                    print!("{text}");
+                    std::io::Write::flush(&mut std::io::stdout())?;
                 }
                 StreamEvent::MessageStop => {
                     println!("\n✅ Stream completed successfully");
@@ -73,11 +74,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     while let Some(event_result) = conversation_stream.next().await {
         match event_result {
             Ok(event) => match event {
-                StreamEvent::ContentBlockDelta { delta, .. } => {
-                    if let Delta::TextDelta { text } = delta {
-                        print!("{text}");
-                        std::io::Write::flush(&mut std::io::stdout())?;
-                    }
+                StreamEvent::ContentBlockDelta {
+                    delta: Delta::TextDelta { text },
+                    ..
+                } => {
+                    print!("{text}");
+                    std::io::Write::flush(&mut std::io::stdout())?;
                 }
                 StreamEvent::MessageStop => {
                     println!("\n✅ Philosophical stream completed");
@@ -116,12 +118,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     while let Some(event_result) = creative_stream.next().await {
         match event_result {
             Ok(event) => match event {
-                StreamEvent::ContentBlockDelta { delta, .. } => {
-                    if let Delta::TextDelta { text } = delta {
-                        print!("{text}");
-                        std::io::Write::flush(&mut std::io::stdout())?;
-                        char_count += text.len();
-                    }
+                StreamEvent::ContentBlockDelta {
+                    delta: Delta::TextDelta { text },
+                    ..
+                } => {
+                    print!("{text}");
+                    std::io::Write::flush(&mut std::io::stdout())?;
+                    char_count += text.len();
                 }
                 StreamEvent::MessageStop => {
                     println!("\n✅ Story completed (~{char_count} characters)");
@@ -158,11 +161,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     while let Some(event_result) = code_stream.next().await {
         match event_result {
             Ok(event) => match event {
-                StreamEvent::ContentBlockDelta { delta, .. } => {
-                    if let Delta::TextDelta { text } = delta {
-                        print!("{text}");
-                        std::io::Write::flush(&mut std::io::stdout())?;
-                    }
+                StreamEvent::ContentBlockDelta {
+                    delta: Delta::TextDelta { text },
+                    ..
+                } => {
+                    print!("{text}");
+                    std::io::Write::flush(&mut std::io::stdout())?;
                 }
                 StreamEvent::MessageStop => {
                     println!("\n✅ Code explanation completed");
@@ -195,10 +199,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     while let Some(event_result) = error_stream.next().await {
         match event_result {
             Ok(event) => match event {
-                StreamEvent::ContentBlockDelta { delta, .. } => {
-                    if let Delta::TextDelta { text } = delta {
-                        print!("{text}");
-                    }
+                StreamEvent::ContentBlockDelta {
+                    delta: Delta::TextDelta { text },
+                    ..
+                } => {
+                    print!("{text}");
                 }
                 StreamEvent::Error { error } => {
                     println!("📡 Stream error event (as expected): {}", error.message);
diff --git a/crates/rullm-core/examples/openai_config.rs b/crates/rullm-core/examples/openai_config.rs
index 19279a9f..2ba628d7 100644
--- a/crates/rullm-core/examples/openai_config.rs
+++ b/crates/rullm-core/examples/openai_config.rs
@@ -1,7 +1,8 @@
-use rullm_core::config::{OpenAIConfig, ProviderConfig};
+use rullm_core::config::ProviderConfig;
 use rullm_core::providers::openai::{
     ChatCompletionRequest, ChatMessage, ContentPart, MessageContent, OpenAIClient,
 };
+use rullm_core::providers::openai_compatible::OpenAIConfig;
 
 // Helper to extract text from MessageContent
 fn extract_text(content: &MessageContent) -> String {
diff --git a/crates/rullm-core/examples/test_all_providers.rs b/crates/rullm-core/examples/test_all_providers.rs
index 87041eb6..67a948a3 100644
--- a/crates/rullm-core/examples/test_all_providers.rs
+++ b/crates/rullm-core/examples/test_all_providers.rs
@@ -1,7 +1,7 @@
-use rullm_core::config::{AnthropicConfig, GoogleAiConfig, OpenAIConfig};
-use rullm_core::providers::anthropic::AnthropicClient;
-use rullm_core::providers::google::GoogleClient;
+use rullm_core::providers::anthropic::{AnthropicClient, AnthropicConfig};
+use rullm_core::providers::google::{GoogleAiConfig, GoogleClient};
 use rullm_core::providers::openai::OpenAIClient;
+use rullm_core::providers::openai_compatible::OpenAIConfig;
 use std::env;
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -37,12 +37,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // 2. Test Anthropic Provider
     println!("🔍 Testing Anthropic Provider...");
     match test_anthropic_provider().await {
-        Ok(model_count) => {
+        Ok(models) => {
+            println!("✅ Anthropic: Found {} models", models.len());
             println!(
-                "✅ Anthropic: API is working ({} models available)",
-                model_count
+                "    Models (first 5): {}",
+                models
+                    .iter()
+                    .take(5)
+                    .map(|s| s.as_str())
+                    .collect::<Vec<_>>()
+                    .join(", ")
             );
-            results.push(("Anthropic", true, model_count));
+            results.push(("Anthropic", true, models.len()));
         }
         Err(e) => {
             println!("❌ Anthropic: Failed - {e}");
@@ -133,7 +139,7 @@ async fn test_openai_provider() -> Result<Vec<String>, Box<dyn std::error::Error>> {
 
-async fn test_anthropic_provider() -> Result<usize, Box<dyn std::error::Error>> {
+async fn test_anthropic_provider() -> Result<Vec<String>, Box<dyn std::error::Error>> {
     let api_key = env::var("ANTHROPIC_API_KEY")
         .map_err(|_| "ANTHROPIC_API_KEY environment variable not set")?;
 
@@ -146,19 +152,10 @@ async fn test_anthropic_provider() -> Result<Vec<String>, Box<dyn std::error::Error>> {
         Err(e) => println!("    Health check: ⚠️ Warning - {e}"),
     }
 
-    // Anthropic doesn't have a list models endpoint, so we'll just return known models
-    let known_models = vec![
-        "claude-3-5-sonnet-20241022",
-        "claude-3-opus-20240229",
-        "claude-3-sonnet-20240229",
-        "claude-3-haiku-20240307",
-    ];
-
-    for model in &known_models {
-        println!("    Known model: {model}");
-    }
+    // Get available models
+    let models = client.list_models().await?;
 
-    Ok(known_models.len())
+    Ok(models)
 }
 
 async fn test_google_provider() -> Result<Vec<String>, Box<dyn std::error::Error>> {
diff --git a/crates/rullm-core/src/config.rs b/crates/rullm-core/src/config.rs
index be8a3488..8528f4ba 100644
--- a/crates/rullm-core/src/config.rs
+++ b/crates/rullm-core/src/config.rs
@@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::time::Duration;
 
+use crate::providers::{AnthropicConfig, GoogleAiConfig, OpenAICompatibleConfig, OpenAIConfig};
+
 /// Configuration trait for LLM providers
 pub trait ProviderConfig: Send + Sync {
     /// Get the API key for this provider
@@ -88,224 +90,6 @@ impl ProviderConfig for HttpProviderConfig {
     }
 }
 
-/// OpenAI-compatible configuration (supports OpenAI, Groq, OpenRouter, etc.)
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct OpenAICompatibleConfig {
-    pub api_key: String,
-    pub organization: Option<String>,
-    pub project: Option<String>,
-    pub base_url: Option<String>,
-    pub timeout_seconds: u64,
-}
-
-/// Type alias for backwards compatibility
-pub type OpenAIConfig = OpenAICompatibleConfig;
-
-impl OpenAICompatibleConfig {
-    pub fn new(api_key: impl Into<String>) -> Self {
-        Self {
-            api_key: api_key.into(),
-            organization: None,
-            project: None,
-            base_url: None,
-            timeout_seconds: 30,
-        }
-    }
-
-    pub fn groq(api_key: impl Into<String>) -> Self {
-        Self {
-            api_key: api_key.into(),
-            organization: None,
-            project: None,
-            base_url: Some("https://api.groq.com/openai/v1".to_string()),
-            timeout_seconds: 30,
-        }
-    }
-
-    pub fn openrouter(api_key: impl Into<String>) -> Self {
-        Self {
-            api_key: api_key.into(),
-            organization: None,
-            project: None,
-            base_url: Some("https://openrouter.ai/api/v1".to_string()),
-            timeout_seconds: 30,
-        }
-    }
-
-    pub fn with_organization(mut self, org: impl Into<String>) -> Self {
-        self.organization = Some(org.into());
-        self
-    }
-
-    pub fn with_project(mut self, project: impl Into<String>) -> Self {
-        self.project = Some(project.into());
-        self
-    }
-
-    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
-        self.base_url = Some(base_url.into());
-        self
-    }
-}
-
-impl ProviderConfig for OpenAICompatibleConfig {
-    fn api_key(&self) -> &str {
-        &self.api_key
-    }
-
-    fn base_url(&self) -> &str {
-        self.base_url
-            .as_deref()
-            .unwrap_or("https://api.openai.com/v1")
-    }
-
-    fn timeout(&self) -> Duration {
-        Duration::from_secs(self.timeout_seconds)
-    }
-
-    fn headers(&self) -> HashMap<String, String> {
-        let mut headers = HashMap::new();
-        headers.insert(
-            "Authorization".to_string(),
-            format!("Bearer {}", self.api_key),
-        );
-        headers.insert("Content-Type".to_string(), "application/json".to_string());
-
-        if let Some(org) = &self.organization {
-            headers.insert("OpenAI-Organization".to_string(), org.clone());
-        }
-
-        if let Some(project) = &self.project {
-            headers.insert("OpenAI-Project".to_string(), project.clone());
-        }
-
-        headers
-    }
-
-    fn validate(&self) -> Result<(), crate::error::LlmError> {
-        if self.api_key.is_empty() {
-            return Err(crate::error::LlmError::configuration("API key is required"));
-        }
-
-        // Relaxed validation: don't require 'sk-' prefix since Groq and OpenRouter use different formats
-        // OpenAI keys start with 'sk-', Groq uses 'gsk_', OpenRouter uses different format
-
-        Ok(())
-    }
-}
-
-/// Anthropic-specific configuration
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct AnthropicConfig {
-    pub api_key: String,
-    pub base_url: Option<String>,
-    pub timeout_seconds: u64,
-}
-
-impl AnthropicConfig {
-    pub fn new(api_key: impl Into<String>) -> Self {
-        Self {
-            api_key: api_key.into(),
-            base_url: None,
-            timeout_seconds: 30,
-        }
-    }
-
-    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
-        self.base_url = Some(base_url.into());
-        self
-    }
-}
-
-impl ProviderConfig for AnthropicConfig {
-    fn api_key(&self) -> &str {
-        &self.api_key
-    }
-
-    fn base_url(&self) -> &str {
-        self.base_url
-            .as_deref()
-            .unwrap_or("https://api.anthropic.com")
-    }
-
-    fn timeout(&self) -> Duration {
-        Duration::from_secs(self.timeout_seconds)
-    }
-
-    fn headers(&self) -> HashMap<String, String> {
-        let mut headers = HashMap::new();
-        headers.insert("x-api-key".to_string(), self.api_key.clone());
-        headers.insert("Content-Type".to_string(), "application/json".to_string());
-        headers.insert("anthropic-version".to_string(), "2023-06-01".to_string());
"2023-06-01".to_string()); - headers - } - - fn validate(&self) -> Result<(), crate::error::LlmError> { - if self.api_key.is_empty() { - return Err(crate::error::LlmError::configuration( - "Anthropic API key is required", - )); - } - - Ok(()) - } -} - -/// Google AI configuration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GoogleAiConfig { - pub api_key: String, - pub base_url: Option, - pub timeout_seconds: u64, -} - -impl GoogleAiConfig { - pub fn new(api_key: impl Into) -> Self { - Self { - api_key: api_key.into(), - base_url: None, - timeout_seconds: 30, - } - } - - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = Some(base_url.into()); - self - } -} - -impl ProviderConfig for GoogleAiConfig { - fn api_key(&self) -> &str { - &self.api_key - } - - fn base_url(&self) -> &str { - self.base_url - .as_deref() - .unwrap_or("https://generativelanguage.googleapis.com/v1beta") - } - - fn timeout(&self) -> Duration { - Duration::from_secs(self.timeout_seconds) - } - - fn headers(&self) -> HashMap { - let mut headers = HashMap::new(); - headers.insert("Content-Type".to_string(), "application/json".to_string()); - headers - } - - fn validate(&self) -> Result<(), crate::error::LlmError> { - if self.api_key.is_empty() { - return Err(crate::error::LlmError::configuration( - "Google AI API key is required", - )); - } - - Ok(()) - } -} - /// Configuration builder for creating provider configs from environment variables pub struct ConfigBuilder; diff --git a/crates/rullm-core/src/lib.rs b/crates/rullm-core/src/lib.rs index 5aeb4864..0f7700ab 100644 --- a/crates/rullm-core/src/lib.rs +++ b/crates/rullm-core/src/lib.rs @@ -154,16 +154,10 @@ pub mod error; pub mod providers; pub mod utils; -#[cfg(test)] -mod tests; - // Concrete client exports pub use providers::{AnthropicClient, GoogleClient, OpenAIClient, OpenAICompatibleProvider}; -pub use config::{ - AnthropicConfig, ConfigBuilder, GoogleAiConfig, OpenAICompatibleConfig, OpenAIConfig, - ProviderConfig, -}; +pub use config::{ConfigBuilder, HttpProviderConfig, ProviderConfig}; pub use error::LlmError; pub use utils::sse::sse_lines; diff --git a/crates/rullm-core/src/providers/anthropic/client.rs b/crates/rullm-core/src/providers/anthropic/client.rs index 2da1cf5d..bd08bc21 100644 --- a/crates/rullm-core/src/providers/anthropic/client.rs +++ b/crates/rullm-core/src/providers/anthropic/client.rs @@ -1,5 +1,6 @@ +use super::config::AnthropicConfig; use super::types::*; -use crate::config::{AnthropicConfig, ProviderConfig}; +use crate::config::ProviderConfig; use crate::error::LlmError; use crate::utils::sse::sse_lines; use futures::Stream; @@ -185,6 +186,68 @@ impl AnthropicClient { Ok(tokens) } + /// List available models + pub async fn list_models(&self) -> Result, LlmError> { + let url = format!("{}/v1/models", self.base_url); + + let mut req = self.client.get(&url); + for (key, value) in self.config.headers() { + req = req.header(key, value); + } + + let response = req.send().await?; + + if !response.status().is_success() { + return Err(LlmError::api( + "anthropic", + "Failed to fetch available models", + Some(response.status().to_string()), + None, + )); + } + + let json: serde_json::Value = response + .json() + .await + .map_err(|e| LlmError::serialization("Failed to parse models response", Box::new(e)))?; + + let models_array = json + .get("data") + .and_then(|d| d.as_array()) + .or_else(|| json.get("models").and_then(|m| m.as_array())) + .ok_or_else(|| { + LlmError::serialization( + "Invalid models 
response format", + Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Missing data array", + )), + ) + })?; + + let models: Vec = models_array + .iter() + .filter_map(|m| { + m.get("id") + .and_then(|id| id.as_str()) + .or_else(|| m.get("name").and_then(|name| name.as_str())) + .or_else(|| m.get("model").and_then(|model| model.as_str())) + .map(|s| s.to_string()) + }) + .collect(); + + if models.is_empty() { + return Err(LlmError::api( + "anthropic", + "No models found in response", + None, + None, + )); + } + + Ok(models) + } + /// Health check pub async fn health_check(&self) -> Result<(), LlmError> { // Anthropic doesn't have a dedicated health endpoint diff --git a/crates/rullm-core/src/providers/anthropic/config.rs b/crates/rullm-core/src/providers/anthropic/config.rs new file mode 100644 index 00000000..baf54655 --- /dev/null +++ b/crates/rullm-core/src/providers/anthropic/config.rs @@ -0,0 +1,85 @@ +use crate::config::ProviderConfig; +use crate::error::LlmError; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::Duration; + +/// Anthropic-specific configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AnthropicConfig { + pub api_key: String, + pub base_url: Option, + pub timeout_seconds: u64, + /// Whether to use OAuth authentication (Bearer token) instead of API key (x-api-key) + #[serde(default)] + pub use_oauth: bool, +} + +impl AnthropicConfig { + pub fn new(api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + base_url: None, + timeout_seconds: 30, + use_oauth: false, + } + } + + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = Some(base_url.into()); + self + } + + pub fn with_oauth(mut self, use_oauth: bool) -> Self { + self.use_oauth = use_oauth; + self + } +} + +impl ProviderConfig for AnthropicConfig { + fn api_key(&self) -> &str { + &self.api_key + } + + fn base_url(&self) -> &str { + self.base_url + .as_deref() + .unwrap_or("https://api.anthropic.com") + } + + fn timeout(&self) -> Duration { + Duration::from_secs(self.timeout_seconds) + } + + fn headers(&self) -> HashMap { + let mut headers = HashMap::new(); + + if self.use_oauth { + // OAuth: use Bearer token + required beta headers + headers.insert( + "Authorization".to_string(), + format!("Bearer {}", self.api_key), + ); + headers.insert( + "anthropic-beta".to_string(), + "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14".to_string(), + ); + headers.insert("anthropic-version".to_string(), "2023-06-01".to_string()); + } else { + // API key: use x-api-key header + headers.insert("x-api-key".to_string(), self.api_key.clone()); + headers.insert("anthropic-version".to_string(), "2023-06-01".to_string()); + } + + headers.insert("Content-Type".to_string(), "application/json".to_string()); + headers + } + + fn validate(&self) -> Result<(), LlmError> { + if self.api_key.is_empty() { + return Err(LlmError::configuration("Anthropic API key is required")); + } + + Ok(()) + } +} diff --git a/crates/rullm-core/src/providers/anthropic/mod.rs b/crates/rullm-core/src/providers/anthropic/mod.rs index 794d2e67..081ca01a 100644 --- a/crates/rullm-core/src/providers/anthropic/mod.rs +++ b/crates/rullm-core/src/providers/anthropic/mod.rs @@ -24,7 +24,9 @@ //! 
 
 pub mod client;
+pub mod config;
 pub mod types;
 
 pub use client::AnthropicClient;
+pub use config::AnthropicConfig;
 pub use types::*;
diff --git a/crates/rullm-core/src/providers/anthropic/types.rs b/crates/rullm-core/src/providers/anthropic/types.rs
index e152315a..f0a9c54f 100644
--- a/crates/rullm-core/src/providers/anthropic/types.rs
+++ b/crates/rullm-core/src/providers/anthropic/types.rs
@@ -143,6 +143,35 @@ pub struct CacheControl {
     pub cache_type: String, // "ephemeral"
 }
 
+impl CacheControl {
+    /// Create an ephemeral cache control
+    pub fn ephemeral() -> Self {
+        Self {
+            cache_type: "ephemeral".to_string(),
+        }
+    }
+}
+
+impl SystemBlock {
+    /// Create a text system block
+    pub fn text(text: impl Into<String>) -> Self {
+        Self {
+            block_type: "text".to_string(),
+            text: text.into(),
+            cache_control: None,
+        }
+    }
+
+    /// Create a text system block with ephemeral cache control
+    pub fn text_with_cache(text: impl Into<String>) -> Self {
+        Self {
+            block_type: "text".to_string(),
+            text: text.into(),
+            cache_control: Some(CacheControl::ephemeral()),
+        }
+    }
+}
+
 /// Request metadata
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Metadata {
diff --git a/crates/rullm-core/src/providers/google/client.rs b/crates/rullm-core/src/providers/google/client.rs
index 7c1fe8f6..adbdde1f 100644
--- a/crates/rullm-core/src/providers/google/client.rs
+++ b/crates/rullm-core/src/providers/google/client.rs
@@ -1,5 +1,6 @@
+use super::config::GoogleAiConfig;
 use super::types::*;
-use crate::config::{GoogleAiConfig, ProviderConfig};
+use crate::config::ProviderConfig;
 use crate::error::LlmError;
 use crate::utils::sse::sse_lines;
 use futures::Stream;
diff --git a/crates/rullm-core/src/providers/google/config.rs b/crates/rullm-core/src/providers/google/config.rs
new file mode 100644
index 00000000..236b3811
--- /dev/null
+++ b/crates/rullm-core/src/providers/google/config.rs
@@ -0,0 +1,58 @@
+use crate::config::ProviderConfig;
+use crate::error::LlmError;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::time::Duration;
+
+/// Google AI configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GoogleAiConfig {
+    pub api_key: String,
+    pub base_url: Option<String>,
+    pub timeout_seconds: u64,
+}
+
+impl GoogleAiConfig {
+    pub fn new(api_key: impl Into<String>) -> Self {
+        Self {
+            api_key: api_key.into(),
+            base_url: None,
+            timeout_seconds: 30,
+        }
+    }
+
+    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
+        self.base_url = Some(base_url.into());
+        self
+    }
+}
+
+impl ProviderConfig for GoogleAiConfig {
+    fn api_key(&self) -> &str {
+        &self.api_key
+    }
+
+    fn base_url(&self) -> &str {
+        self.base_url
+            .as_deref()
+            .unwrap_or("https://generativelanguage.googleapis.com/v1beta")
+    }
+
+    fn timeout(&self) -> Duration {
+        Duration::from_secs(self.timeout_seconds)
+    }
+
+    fn headers(&self) -> HashMap<String, String> {
+        let mut headers = HashMap::new();
+        headers.insert("Content-Type".to_string(), "application/json".to_string());
+        headers
+    }
+
+    fn validate(&self) -> Result<(), LlmError> {
+        if self.api_key.is_empty() {
+            return Err(LlmError::configuration("Google AI API key is required"));
+        }
+
+        Ok(())
+    }
+}
diff --git a/crates/rullm-core/src/providers/google/mod.rs b/crates/rullm-core/src/providers/google/mod.rs
index 48967a30..13fcbb20 100644
--- a/crates/rullm-core/src/providers/google/mod.rs
+++ b/crates/rullm-core/src/providers/google/mod.rs
@@ -22,7 +22,9 @@
 //! 
``` pub mod client; +pub mod config; pub mod types; pub use client::GoogleClient; +pub use config::GoogleAiConfig; pub use types::*; diff --git a/crates/rullm-core/src/providers/mod.rs b/crates/rullm-core/src/providers/mod.rs index f728d3bb..71bec32b 100644 --- a/crates/rullm-core/src/providers/mod.rs +++ b/crates/rullm-core/src/providers/mod.rs @@ -9,3 +9,8 @@ pub use anthropic::AnthropicClient; pub use google::GoogleClient; pub use openai::OpenAIClient; pub use openai_compatible::{OpenAICompatibleProvider, ProviderIdentity, identities}; + +// Export provider-specific configs +pub use anthropic::AnthropicConfig; +pub use google::GoogleAiConfig; +pub use openai_compatible::{OpenAICompatibleConfig, OpenAIConfig}; diff --git a/crates/rullm-core/src/providers/openai/client.rs b/crates/rullm-core/src/providers/openai/client.rs index ccaa47ec..2ba78faa 100644 --- a/crates/rullm-core/src/providers/openai/client.rs +++ b/crates/rullm-core/src/providers/openai/client.rs @@ -1,6 +1,7 @@ use super::types::*; -use crate::config::{OpenAIConfig, ProviderConfig}; +use crate::config::ProviderConfig; use crate::error::LlmError; +use crate::providers::openai_compatible::OpenAIConfig; use crate::utils::sse::sse_lines; use futures::Stream; use futures::StreamExt; diff --git a/crates/rullm-core/src/providers/openai_compatible/config.rs b/crates/rullm-core/src/providers/openai_compatible/config.rs new file mode 100644 index 00000000..9eeac7cb --- /dev/null +++ b/crates/rullm-core/src/providers/openai_compatible/config.rs @@ -0,0 +1,111 @@ +use crate::config::ProviderConfig; +use crate::error::LlmError; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::Duration; + +/// OpenAI-compatible configuration (supports OpenAI, Groq, OpenRouter, etc.) 
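+///
+/// Construction sketch using the presets defined below (keys are placeholders):
+///
+/// ```no_run
+/// let groq = OpenAICompatibleConfig::groq("gsk_...");
+/// let custom = OpenAICompatibleConfig::new("api-key")
+///     .with_base_url("https://example.com/v1");
+/// ```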
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenAICompatibleConfig { + pub api_key: String, + pub organization: Option<String>, + pub project: Option<String>, + pub base_url: Option<String>, + pub timeout_seconds: u64, +} + +/// Type alias for backwards compatibility +pub type OpenAIConfig = OpenAICompatibleConfig; + +impl OpenAICompatibleConfig { + pub fn new(api_key: impl Into<String>) -> Self { + Self { + api_key: api_key.into(), + organization: None, + project: None, + base_url: None, + timeout_seconds: 30, + } + } + + pub fn groq(api_key: impl Into<String>) -> Self { + Self { + api_key: api_key.into(), + organization: None, + project: None, + base_url: Some("https://api.groq.com/openai/v1".to_string()), + timeout_seconds: 30, + } + } + + pub fn openrouter(api_key: impl Into<String>) -> Self { + Self { + api_key: api_key.into(), + organization: None, + project: None, + base_url: Some("https://openrouter.ai/api/v1".to_string()), + timeout_seconds: 30, + } + } + + pub fn with_organization(mut self, org: impl Into<String>) -> Self { + self.organization = Some(org.into()); + self + } + + pub fn with_project(mut self, project: impl Into<String>) -> Self { + self.project = Some(project.into()); + self + } + + pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self { + self.base_url = Some(base_url.into()); + self + } +} + +impl ProviderConfig for OpenAICompatibleConfig { + fn api_key(&self) -> &str { + &self.api_key + } + + fn base_url(&self) -> &str { + self.base_url + .as_deref() + .unwrap_or("https://api.openai.com/v1") + } + + fn timeout(&self) -> Duration { + Duration::from_secs(self.timeout_seconds) + } + + fn headers(&self) -> HashMap<String, String> { + let mut headers = HashMap::new(); + headers.insert( + "Authorization".to_string(), + format!("Bearer {}", self.api_key), + ); + headers.insert("Content-Type".to_string(), "application/json".to_string()); + + if let Some(org) = &self.organization { + headers.insert("OpenAI-Organization".to_string(), org.clone()); + } + + if let Some(project) = &self.project { + headers.insert("OpenAI-Project".to_string(), project.clone()); + } + + headers + } + + fn validate(&self) -> Result<(), LlmError> { + if self.api_key.is_empty() { + return Err(LlmError::configuration("API key is required")); + } + + // Relaxed validation: don't require an 'sk-' prefix, since providers use different key formats + // (OpenAI keys start with 'sk-', Groq keys with 'gsk_', and OpenRouter uses its own format) + + Ok(()) + } +} diff --git a/crates/rullm-core/src/providers/openai_compatible.rs b/crates/rullm-core/src/providers/openai_compatible/mod.rs similarity index 97% rename from crates/rullm-core/src/providers/openai_compatible.rs rename to crates/rullm-core/src/providers/openai_compatible/mod.rs index e0c9a105..0fe42db7 100644 --- a/crates/rullm-core/src/providers/openai_compatible.rs +++ b/crates/rullm-core/src/providers/openai_compatible/mod.rs @@ -1,3 +1,7 @@ +pub mod config; + +pub use config::{OpenAICompatibleConfig, OpenAIConfig}; + use crate::compat_types::{ ChatMessage, ChatRequest, ChatResponse, ChatRole, ChatStreamEvent, TokenUsage, }; @@ -46,7 +50,7 @@ pub mod identities { /// Generic OpenAI-compatible provider implementation #[derive(Clone)] pub struct OpenAICompatibleProvider { - config: crate::config::OpenAICompatibleConfig, + config: OpenAICompatibleConfig, client: Client, identity: ProviderIdentity, } @@ -54,7 +58,7 @@ pub struct OpenAICompatibleProvider { impl OpenAICompatibleProvider { /// Create a new OpenAI-compatible provider with custom identity pub fn new( - config: crate::config::OpenAICompatibleConfig, +
config: OpenAICompatibleConfig, identity: ProviderIdentity, ) -> Result<Self, LlmError> { config.validate()?; @@ -67,17 +71,17 @@ impl OpenAICompatibleProvider { } /// Create an OpenAI provider - pub fn openai(config: crate::config::OpenAICompatibleConfig) -> Result<Self, LlmError> { + pub fn openai(config: OpenAICompatibleConfig) -> Result<Self, LlmError> { Self::new(config, identities::OPENAI) } /// Create a Groq provider - pub fn groq(config: crate::config::OpenAICompatibleConfig) -> Result<Self, LlmError> { + pub fn groq(config: OpenAICompatibleConfig) -> Result<Self, LlmError> { Self::new(config, identities::GROQ) } /// Create an OpenRouter provider - pub fn openrouter(config: crate::config::OpenAICompatibleConfig) -> Result<Self, LlmError> { + pub fn openrouter(config: OpenAICompatibleConfig) -> Result<Self, LlmError> { Self::new(config, identities::OPENROUTER) } diff --git a/crates/rullm-core/src/tests.rs b/crates/rullm-core/src/tests.rs deleted file mode 100644 index 7bebc085..00000000 --- a/crates/rullm-core/src/tests.rs +++ /dev/null @@ -1,1044 +0,0 @@ -use crate::config::*; -use crate::error::LlmError; -use crate::middleware::{LlmServiceBuilder, MiddlewareConfig, RateLimit}; -use crate::types::{ - ChatCompletion, ChatMessage, ChatRequest, ChatRequestBuilder, ChatResponse, ChatRole, - LlmProvider, StreamConfig, TokenUsage, -}; -use std::time::Duration; - -// Mock provider for testing -#[derive(Clone)] -struct MockProvider { - name: &'static str, - should_fail: bool, -} - -impl MockProvider { - fn new(name: &'static str) -> Self { - Self { - name, - should_fail: false, - } - } - - fn with_failure(mut self) -> Self { - self.should_fail = true; - self - } -} - -#[async_trait::async_trait] -impl LlmProvider for MockProvider { - fn name(&self) -> &'static str { - self.name - } - - fn aliases(&self) -> &'static [&'static str] { - &[] - } - - fn default_base_url(&self) -> Option<&'static str> { - Some("") - } - - fn env_key(&self) -> &'static str { - "" - } - - async fn available_models(&self) -> Result<Vec<String>, LlmError> { - if self.should_fail { - return Err(LlmError::network("Mock network error")); - } - Ok(vec!["model-1".to_string(), "model-2".to_string()]) - } - - async fn health_check(&self) -> Result<(), LlmError> { - if self.should_fail { - return Err(LlmError::service_unavailable(self.name)); - } - Ok(()) - } -} - -#[async_trait::async_trait] -impl ChatCompletion for MockProvider { - async fn chat_completion( - &self, - _request: ChatRequest, - model: &str, - ) -> Result<ChatResponse, LlmError> { - if self.should_fail { - return Err(LlmError::api( - self.name, - "Mock API error", - Some("500".to_string()), - None, - )); - } - - Ok(ChatResponse { - message: ChatMessage { - role: ChatRole::Assistant, - content: "Mock response".to_string(), - }, - model: model.to_string(), - usage: TokenUsage { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - finish_reason: Some("stop".to_string()), - provider_metadata: None, - }) - } - - async fn chat_completion_stream( - &self, - _request: ChatRequest, - _model: &str, - _config: Option<StreamConfig>, - ) -> crate::types::StreamResult { - Box::pin(futures::stream::once(async { - Err(LlmError::model("Streaming not implemented in mock")) - })) - } - - async fn estimate_tokens(&self, text: &str, _model: &str) -> Result<u32, LlmError> { - Ok(text.len() as u32 / 4) // Rough estimate - } -} - -#[tokio::test] -async fn test_mock_provider_basic_functionality() { - let provider = MockProvider::new("test"); - - assert_eq!(provider.name(), "test"); - - let models = provider.available_models().await.unwrap(); - assert_eq!(models, vec!["model-1", "model-2"]); - - provider.health_check().await.unwrap(); -
let token_count = provider - .estimate_tokens("hello world", "test-model") - .await - .unwrap(); - assert_eq!(token_count, 2); // "hello world".len() / 4 = 11 / 4 = 2 -} - -#[tokio::test] -async fn test_mock_provider_chat_completion() { - let provider = MockProvider::new("test"); - - let request = ChatRequestBuilder::new() - .user("Hello, world!") - .temperature(0.7) - .max_tokens(100) - .build(); - - let response = provider - .chat_completion(request, "test-model") - .await - .unwrap(); - - assert_eq!(response.message.role, ChatRole::Assistant); - assert_eq!(response.message.content, "Mock response"); - assert_eq!(response.model, "test-model"); - assert_eq!(response.usage.total_tokens, 15); -} - -#[tokio::test] -async fn test_mock_provider_failure_cases() { - let provider = MockProvider::new("test").with_failure(); - - let health_result = provider.health_check().await; - assert!(health_result.is_err()); - assert!(matches!( - health_result.unwrap_err(), - LlmError::ServiceUnavailable { .. } - )); - - let models_result = provider.available_models().await; - assert!(models_result.is_err()); - assert!(matches!( - models_result.unwrap_err(), - LlmError::Network { .. } - )); - - let request = ChatRequestBuilder::new().user("test").build(); - let chat_result = provider.chat_completion(request, "test-model").await; - assert!(chat_result.is_err()); - assert!(matches!(chat_result.unwrap_err(), LlmError::Api { .. })); -} - -#[test] -fn test_chat_request_builder() { - let request = ChatRequestBuilder::new() - .system("You are a helpful assistant") - .user("What is 2+2?") - .assistant("2+2 equals 4") - .user("What about 3+3?") - .temperature(0.8) - .max_tokens(150) - .top_p(0.9) - // .frequency_penalty(0.1) - // .presence_penalty(0.1) - // .stop_sequences(vec!["END".to_string()]) - .stream(true) - .extra_param("custom_param", serde_json::json!("custom_value")) - .build(); - - assert_eq!(request.messages.len(), 4); - assert_eq!(request.messages[0].role, ChatRole::System); - assert_eq!(request.messages[1].role, ChatRole::User); - assert_eq!(request.messages[2].role, ChatRole::Assistant); - assert_eq!(request.messages[3].role, ChatRole::User); - assert_eq!(request.temperature, Some(0.8)); - assert_eq!(request.max_tokens, Some(150)); - assert_eq!(request.top_p, Some(0.9)); - // assert_eq!(request.frequency_penalty, Some(0.1)); - // assert_eq!(request.presence_penalty, Some(0.1)); - // assert_eq!(request.stop, Some(vec!["END".to_string()])); - assert_eq!(request.stream, Some(true)); - assert!(request.extra_params.is_some()); -} - -#[test] -fn test_openai_config() { - let config = OpenAIConfig::new("sk-test123") - .with_organization("org-123") - .with_project("proj-456"); - - assert_eq!(config.api_key(), "sk-test123"); - assert_eq!(config.base_url(), "https://api.openai.com/v1"); - - let headers = config.headers(); - assert_eq!( - headers.get("Authorization"), - Some(&"Bearer sk-test123".to_string()) - ); - assert_eq!( - headers.get("OpenAI-Organization"), - Some(&"org-123".to_string()) - ); - assert_eq!(headers.get("OpenAI-Project"), Some(&"proj-456".to_string())); - - config.validate().unwrap(); - - // Test invalid config (empty API key) - let invalid_config = OpenAIConfig::new(""); - assert!(invalid_config.validate().is_err()); -} - -#[test] -fn test_anthropic_config() { - let config = AnthropicConfig::new("sk-ant-test123"); - - assert_eq!(config.api_key(), "sk-ant-test123"); - assert_eq!(config.base_url(), "https://api.anthropic.com"); - - let headers = config.headers(); - assert_eq!( - 
headers.get("x-api-key"), - Some(&"sk-ant-test123".to_string()) - ); - assert_eq!( - headers.get("anthropic-version"), - Some(&"2023-06-01".to_string()) - ); - - config.validate().unwrap(); -} - -#[test] -fn test_google_ai_config() { - let config = GoogleAiConfig::new("AIza-test123"); - - assert_eq!(config.api_key(), "AIza-test123"); - assert_eq!( - config.base_url(), - "https://generativelanguage.googleapis.com/v1beta" - ); - - config.validate().unwrap(); -} - -#[test] -fn test_all_llm_error_variants() { - use std::collections::HashMap; - - // Test Network error - let network_error = LlmError::network("Connection failed"); - assert!(matches!(network_error, LlmError::Network { .. })); - assert_eq!( - network_error.to_string(), - "Network error: Connection failed" - ); - - // Test Network error with source - let source_error = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused"); - let network_with_source = LlmError::network_with_source("Connection refused", source_error); - assert!(matches!(network_with_source, LlmError::Network { .. })); - - // Test Authentication error - let auth_error = LlmError::authentication("Invalid API key"); - assert!(matches!(auth_error, LlmError::Authentication { .. })); - assert_eq!( - auth_error.to_string(), - "Authentication failed: Invalid API key" - ); - - // Test RateLimit error - let rate_limit = LlmError::rate_limit( - "Too many requests", - Some(std::time::Duration::from_secs(60)), - ); - assert!(matches!(rate_limit, LlmError::RateLimit { .. })); - assert!(rate_limit.to_string().contains("Rate limit exceeded")); - - // Test Api error with details - let mut details = HashMap::new(); - details.insert( - "error_code".to_string(), - serde_json::Value::String("QUOTA_EXCEEDED".to_string()), - ); - let api_error = LlmError::api( - "openai", - "Quota exceeded", - Some("429".to_string()), - Some(details), - ); - assert!(matches!(api_error, LlmError::Api { .. })); - assert!(api_error.to_string().contains("API error from openai")); - - // Test Configuration error - let config_error = LlmError::configuration("Missing API key"); - assert!(matches!(config_error, LlmError::Configuration { .. })); - assert_eq!( - config_error.to_string(), - "Configuration error: Missing API key" - ); - - // Test Validation error - let validation_error = LlmError::validation("Invalid model name"); - assert!(matches!(validation_error, LlmError::Validation { .. })); - assert_eq!( - validation_error.to_string(), - "Validation error: Invalid model name" - ); - - // Test Timeout error - let timeout_error = LlmError::timeout(std::time::Duration::from_secs(30)); - assert!(matches!(timeout_error, LlmError::Timeout { .. })); - assert!(timeout_error.to_string().contains("Request timed out")); - - // Test Serialization error - let json_error = serde_json::from_str::("invalid json").unwrap_err(); - let serialization_error = LlmError::serialization("JSON parse failed", json_error); - assert!(matches!( - serialization_error, - LlmError::Serialization { .. } - )); - assert!( - serialization_error - .to_string() - .contains("Serialization error") - ); - - // Test Model error - let model_error = LlmError::model("Model not found"); - assert!(matches!(model_error, LlmError::Model { .. })); - assert_eq!(model_error.to_string(), "Model error: Model not found"); - - // Test Resource error - let resource_error = LlmError::resource("Insufficient credits"); - assert!(matches!(resource_error, LlmError::Resource { .. 
})); - assert_eq!( - resource_error.to_string(), - "Resource error: Insufficient credits" - ); - - // Test ServiceUnavailable error - let service_error = LlmError::service_unavailable("anthropic"); - assert!(matches!(service_error, LlmError::ServiceUnavailable { .. })); - assert_eq!( - service_error.to_string(), - "Service unavailable: anthropic is currently unavailable" - ); - - // Test Unknown error - let unknown_error = LlmError::unknown("Unexpected error"); - assert!(matches!(unknown_error, LlmError::Unknown { .. })); - assert_eq!( - unknown_error.to_string(), - "Unexpected error: Unexpected error" - ); - - // Test Unknown error with source - let io_error = std::io::Error::other("unknown"); - let unknown_with_source = LlmError::unknown_with_source("Something went wrong", io_error); - assert!(matches!(unknown_with_source, LlmError::Unknown { .. })); -} - -#[test] -fn test_error_conversions() { - // Test From conversion - let json_error = serde_json::from_str::("invalid json").unwrap_err(); - let llm_error: LlmError = json_error.into(); - assert!(matches!(llm_error, LlmError::Serialization { .. })); - assert!(llm_error.to_string().contains("JSON serialization failed")); - - // Test From for timeout - // Note: We can't easily create specific reqwest errors without making actual requests - // so we test the error conversion logic indirectly through the provider implementations -} - -#[test] -fn test_provider_specific_error_mapping() { - use std::collections::HashMap; - - // Test OpenAI-style error mapping - let mut openai_details = HashMap::new(); - openai_details.insert( - "type".to_string(), - serde_json::Value::String("insufficient_quota".to_string()), - ); - openai_details.insert("param".to_string(), serde_json::Value::Null); - - let openai_quota_error = LlmError::api( - "openai", - "You exceeded your current quota, please check your plan and billing details.", - Some("429".to_string()), - Some(openai_details.clone()), - ); - - assert!(matches!(openai_quota_error, LlmError::Api { .. })); - if let LlmError::Api { - provider, - message, - code, - details, - } = openai_quota_error - { - assert_eq!(provider, "openai"); - assert!(message.contains("quota")); - assert_eq!(code, Some("429".to_string())); - assert!(details.is_some()); - let details = details.unwrap(); - assert_eq!( - details.get("type").unwrap(), - &serde_json::Value::String("insufficient_quota".to_string()) - ); - } - - // Test Anthropic-style error mapping - let mut anthropic_details = HashMap::new(); - anthropic_details.insert( - "error_type".to_string(), - serde_json::Value::String("authentication_error".to_string()), - ); - - let anthropic_auth_error = LlmError::api( - "anthropic", - "Invalid API key provided", - Some("401".to_string()), - Some(anthropic_details.clone()), - ); - - assert!(matches!(anthropic_auth_error, LlmError::Api { .. 
})); - if let LlmError::Api { - provider, - message, - code, - details, - } = anthropic_auth_error - { - assert_eq!(provider, "anthropic"); - assert_eq!(message, "Invalid API key provided"); - assert_eq!(code, Some("401".to_string())); - assert!(details.is_some()); - let details = details.unwrap(); - assert_eq!( - details.get("error_type").unwrap(), - &serde_json::Value::String("authentication_error".to_string()) - ); - } - - // Test Google-style error mapping - let mut google_details = HashMap::new(); - google_details.insert( - "reason".to_string(), - serde_json::Value::String("RATE_LIMIT_EXCEEDED".to_string()), - ); - google_details.insert( - "domain".to_string(), - serde_json::Value::String("global".to_string()), - ); - - let google_rate_error = LlmError::api( - "google", - "Rate limit exceeded", - Some("429".to_string()), - Some(google_details.clone()), - ); - - assert!(matches!(google_rate_error, LlmError::Api { .. })); - if let LlmError::Api { - provider, - message, - code, - details, - } = google_rate_error - { - assert_eq!(provider, "google"); - assert_eq!(message, "Rate limit exceeded"); - assert_eq!(code, Some("429".to_string())); - assert!(details.is_some()); - let details = details.unwrap(); - assert_eq!( - details.get("reason").unwrap(), - &serde_json::Value::String("RATE_LIMIT_EXCEEDED".to_string()) - ); - } -} - -#[test] -fn test_error_source_chaining() { - use std::error::Error; - - // Test that source errors are properly chained - let io_error = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "Connection refused"); - let network_error = LlmError::network_with_source("Failed to connect", io_error); - - // Verify the error source is accessible - assert!(network_error.source().is_some()); - assert!(matches!(network_error, LlmError::Network { .. })); - - // Test serialization error with source - let json_error = serde_json::from_str::("invalid json").unwrap_err(); - let serialization_error = LlmError::serialization("Failed to parse JSON", json_error); - - assert!(serialization_error.source().is_some()); - assert!(matches!( - serialization_error, - LlmError::Serialization { .. } - )); - - // Test unknown error with source - let unknown_source = std::io::Error::other("mysterious error"); - let unknown_error = - LlmError::unknown_with_source("Something unexpected happened", unknown_source); - - assert!(unknown_error.source().is_some()); - assert!(matches!(unknown_error, LlmError::Unknown { .. 
})); -} - -#[test] -fn test_openai_provider_creation() { - use crate::config::OpenAIConfig; - use crate::providers::OpenAIProvider; - - // Test with valid config - let config = OpenAIConfig::new("sk-test123"); - let provider = OpenAIProvider::new(config); - assert!(provider.is_ok()); - - // Test with invalid config (empty API key) - let invalid_config = OpenAIConfig::new(""); - let invalid_provider = OpenAIProvider::new(invalid_config); - assert!(invalid_provider.is_err()); -} - -#[tokio::test] -async fn test_openai_request_conversion() { - use crate::config::OpenAIConfig; - use crate::providers::OpenAIProvider; - - let config = OpenAIConfig::new("sk-test123"); - let provider = OpenAIProvider::new(config).unwrap(); - - let _request = ChatRequestBuilder::new() - .system("You are helpful") - .user("Hello") - .temperature(0.7) - .max_tokens(100) - .build(); - - // We can't easily test the private method, but we can verify - // the provider was created successfully and implements the traits - assert_eq!(provider.name(), "openai"); -} - -#[tokio::test] -async fn test_openai_token_estimation() { - use crate::config::OpenAIConfig; - use crate::providers::OpenAIProvider; - - let config = OpenAIConfig::new("sk-test123"); - let provider = OpenAIProvider::new(config).unwrap(); - - let tokens = provider - .estimate_tokens("Hello world", "gpt-4") - .await - .unwrap(); - - // Should be roughly 3 tokens for "Hello world" (8 chars / 4 = 2, rounded up to 3) - assert!((2..=4).contains(&tokens)); -} - -#[test] -fn test_anthropic_provider_creation() { - use crate::config::AnthropicConfig; - use crate::providers::AnthropicProvider; - - // Test with valid config - let config = AnthropicConfig::new("sk-ant-test123"); - let provider = AnthropicProvider::new(config); - assert!(provider.is_ok()); - - // Test with invalid config (empty API key) - let invalid_config = AnthropicConfig::new(""); - let invalid_provider = AnthropicProvider::new(invalid_config); - assert!(invalid_provider.is_err()); -} - -#[tokio::test] -async fn test_anthropic_token_estimation() { - use crate::config::AnthropicConfig; - use crate::providers::AnthropicProvider; - - let config = AnthropicConfig::new("sk-ant-test123"); - let provider = AnthropicProvider::new(config).unwrap(); - - let tokens = provider - .estimate_tokens("Hello world", "claude-3-haiku-20240307") - .await - .unwrap(); - - // Should be roughly 3 tokens for "Hello world" (11 chars / 3.5 ≈ 3.14, rounded up to 4) - assert!((3..=5).contains(&tokens)); -} - -#[test] -fn test_google_provider_creation() { - use crate::config::GoogleAiConfig; - use crate::providers::GoogleProvider; - - // Test with valid config - let config = GoogleAiConfig::new("AIza-test123"); - let provider = GoogleProvider::new(config); - assert!(provider.is_ok()); - - // Test with invalid config (empty API key) - let invalid_config = GoogleAiConfig::new(""); - let invalid_provider = GoogleProvider::new(invalid_config); - assert!(invalid_provider.is_err()); -} - -#[tokio::test] -async fn test_google_token_estimation() { - use crate::config::GoogleAiConfig; - use crate::providers::GoogleProvider; - - let config = GoogleAiConfig::new("AIza-test123"); - let provider = GoogleProvider::new(config).unwrap(); - - let tokens = provider - .estimate_tokens("Hello world", "gemini-1.5-pro") - .await - .unwrap(); - - // Should be roughly 3 tokens for "Hello world" (11 chars / 4 ≈ 2.75, rounded up to 3) - assert!((2..=4).contains(&tokens)); -} - -#[test] -fn test_google_request_format() { - use crate::config::GoogleAiConfig; - 
use crate::providers::GoogleProvider; - - let config = GoogleAiConfig::new("AIza-test123"); - let provider = GoogleProvider::new(config).unwrap(); - - let request = ChatRequestBuilder::new() - .system("You are a helpful assistant") - .user("Hello") - .assistant("Hi there!") - .user("How are you?") - .temperature(0.7) - .max_tokens(100) - .top_p(0.9) - // .stop_sequences(vec!["END".to_string()]) - .build(); - - // We can't easily test the private method directly, but we can verify - // the provider was created successfully and test the format indirectly - // by testing the available functionality - assert_eq!(provider.name(), "google"); - - // The request should have system message separated from user/assistant messages - let system_messages: Vec<_> = request - .messages - .iter() - .filter(|m| m.role == ChatRole::System) - .collect(); - assert_eq!(system_messages.len(), 1); - assert_eq!(system_messages[0].content, "You are a helpful assistant"); - - // Should have user and assistant messages - let conversation_messages: Vec<_> = request - .messages - .iter() - .filter(|m| matches!(m.role, ChatRole::User | ChatRole::Assistant)) - .collect(); - assert_eq!(conversation_messages.len(), 3); -} - -#[test] -fn test_anthropic_request_format() { - use crate::config::AnthropicConfig; - use crate::providers::AnthropicProvider; - - let config = AnthropicConfig::new("sk-ant-test123"); - let provider = AnthropicProvider::new(config).unwrap(); - - let request = ChatRequestBuilder::new() - .system("You are a helpful assistant") - .user("Hello") - .assistant("Hi there!") - .user("How are you?") - .temperature(0.7) - .max_tokens(100) - .top_p(0.9) - // .stop_sequences(vec!["END".to_string()]) - .build(); - - // We can't easily test the private method directly, but we can verify - // the provider was created successfully and test the format indirectly - // by testing the available functionality - assert_eq!(provider.name(), "anthropic"); - - // The request should have system message separated from user/assistant messages - let system_messages: Vec<_> = request - .messages - .iter() - .filter(|m| m.role == ChatRole::System) - .collect(); - let conversation_messages: Vec<_> = request - .messages - .iter() - .filter(|m| m.role != ChatRole::System) - .collect(); - - assert_eq!(system_messages.len(), 1); - assert_eq!(conversation_messages.len(), 3); - assert_eq!(system_messages[0].content, "You are a helpful assistant"); -} - -#[test] -fn test_anthropic_response_parsing() { - use crate::config::AnthropicConfig; - use crate::providers::AnthropicProvider; - - let config = AnthropicConfig::new("sk-ant-test123"); - let provider = AnthropicProvider::new(config).unwrap(); - - // Test parsing a mock Anthropic response - let mock_response = serde_json::json!({ - "content": [ - { - "type": "text", - "text": "Hello! I'm Claude, an AI assistant." - } - ], - "id": "msg_test123", - "model": "claude-3-haiku-20240307", - "role": "assistant", - "stop_reason": "end_turn", - "stop_sequence": null, - "type": "message", - "usage": { - "input_tokens": 15, - "output_tokens": 25 - } - }); - - let result = provider.parse_anthropic_response(mock_response); - assert!(result.is_ok()); - - let response = result.unwrap(); - assert_eq!(response.message.role, ChatRole::Assistant); - assert_eq!( - response.message.content, - "Hello! I'm Claude, an AI assistant." 
- ); - assert_eq!(response.model, "claude-3-haiku-20240307"); - assert_eq!(response.usage.prompt_tokens, 15); - assert_eq!(response.usage.completion_tokens, 25); - assert_eq!(response.usage.total_tokens, 40); - assert_eq!(response.finish_reason, Some("end_turn".to_string())); -} - -#[test] -fn test_anthropic_response_parsing_errors() { - use crate::config::AnthropicConfig; - use crate::providers::AnthropicProvider; - - let config = AnthropicConfig::new("sk-ant-test123"); - let provider = AnthropicProvider::new(config).unwrap(); - - // Test missing content array - let invalid_response = serde_json::json!({ - "model": "claude-3-haiku-20240307", - "usage": {"input_tokens": 10, "output_tokens": 5} - }); - - let result = provider.parse_anthropic_response(invalid_response); - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - LlmError::Serialization { .. } - )); - - // Test empty content array - let empty_content_response = serde_json::json!({ - "content": [], - "model": "claude-3-haiku-20240307", - "usage": {"input_tokens": 10, "output_tokens": 5} - }); - - let result = provider.parse_anthropic_response(empty_content_response); - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - LlmError::Serialization { .. } - )); - - // Test missing text in content block - let missing_text_response = serde_json::json!({ - "content": [{"type": "text"}], - "model": "claude-3-haiku-20240307", - "usage": {"input_tokens": 10, "output_tokens": 5} - }); - - let result = provider.parse_anthropic_response(missing_text_response); - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - LlmError::Serialization { .. } - )); -} - -// ============================================================================= -// Middleware Tests -// ============================================================================= - -#[test] -fn test_middleware_config_default() { - let config = MiddlewareConfig::default(); - - assert_eq!(config.timeout, Some(Duration::from_secs(30))); - assert!(config.rate_limit.is_none()); - assert!(config.enable_logging); - assert!(!config.enable_metrics); -} - -#[test] -fn test_middleware_config_custom() { - let rate_limit = RateLimit { - requests_per_period: 100, - period: Duration::from_secs(60), - }; - - let config = MiddlewareConfig { - timeout: Some(Duration::from_secs(45)), - rate_limit: Some(rate_limit.clone()), - enable_logging: false, - enable_metrics: true, - }; - - assert_eq!(config.timeout, Some(Duration::from_secs(45))); - assert!(config.rate_limit.is_some()); - assert!(!config.enable_logging); - assert!(config.enable_metrics); - - let rate_limit_config = config.rate_limit.as_ref().unwrap(); - assert_eq!(rate_limit_config.requests_per_period, 100); - assert_eq!(rate_limit_config.period, Duration::from_secs(60)); -} - -#[test] -fn test_llm_service_builder_default() { - let provider = MockProvider::new("test"); - let middleware_stack = LlmServiceBuilder::new().build(provider, "test-model".to_string()); - - let config = middleware_stack.config(); - assert_eq!(config.timeout, Some(Duration::from_secs(30))); - assert!(config.enable_logging); - assert!(!config.enable_metrics); -} - -#[test] -fn test_llm_service_builder_fluent_api() { - let provider = MockProvider::new("test"); - - let middleware_stack = LlmServiceBuilder::new() - .timeout(Duration::from_secs(60)) - .rate_limit(50, Duration::from_secs(30)) - .logging() - .metrics() - .build(provider, "test-model".to_string()); - - let config = middleware_stack.config(); - 
assert_eq!(config.timeout, Some(Duration::from_secs(60))); - assert!(config.rate_limit.is_some()); - assert!(config.enable_logging); - assert!(config.enable_metrics); - - let rate_limit = config.rate_limit.as_ref().unwrap(); - assert_eq!(rate_limit.requests_per_period, 50); - assert_eq!(rate_limit.period, Duration::from_secs(30)); -} - -#[test] -fn test_llm_service_builder_with_config() { - let custom_config = MiddlewareConfig { - timeout: Some(Duration::from_secs(20)), - rate_limit: None, - enable_logging: false, - enable_metrics: true, - }; - - let provider = MockProvider::new("test"); - let middleware_stack = LlmServiceBuilder::with_config(custom_config.clone()) - .build(provider, "test-model".to_string()); - - let config = middleware_stack.config(); - assert_eq!(config.timeout, custom_config.timeout); - assert_eq!(config.enable_logging, custom_config.enable_logging); - assert_eq!(config.enable_metrics, custom_config.enable_metrics); -} - -#[tokio::test] -async fn test_middleware_stack_basic_call() { - let provider = MockProvider::new("test"); - let mut middleware_stack = LlmServiceBuilder::new() - .logging() - .build(provider, "test-model".to_string()); - - let request = ChatRequestBuilder::new().user("Hello, middleware!").build(); - - let response = middleware_stack.call(request).await.unwrap(); - - assert_eq!(response.message.content, "Mock response"); - assert_eq!(response.model, "test-model"); - assert_eq!(response.usage.total_tokens, 15); -} - -#[tokio::test] -async fn test_middleware_logging_and_metrics() { - let provider = MockProvider::new("test"); - let mut middleware_stack = LlmServiceBuilder::new() - .logging() - .metrics() - .build(provider, "test-model".to_string()); - - let request = ChatRequestBuilder::new() - .user("Test logging and metrics") - .build(); - - // This test mainly ensures the logging/metrics code doesn't crash - // In a real scenario, you'd capture log output and verify metrics - let response = middleware_stack.call(request).await.unwrap(); - - assert_eq!(response.message.content, "Mock response"); - - let config = middleware_stack.config(); - assert!(config.enable_logging); - assert!(config.enable_metrics); -} - -#[test] -fn test_rate_limit_configuration() { - let rate_limit = RateLimit { - requests_per_period: 100, - period: Duration::from_secs(60), - }; - - assert_eq!(rate_limit.requests_per_period, 100); - assert_eq!(rate_limit.period, Duration::from_secs(60)); - - // Test with different values - let rate_limit2 = RateLimit { - requests_per_period: 50, - period: Duration::from_secs(30), - }; - - assert_eq!(rate_limit2.requests_per_period, 50); - assert_eq!(rate_limit2.period, Duration::from_secs(30)); -} - -#[tokio::test] -async fn test_middleware_error_propagation() { - let provider = MockProvider::new("test").with_failure(); - let mut middleware_stack = LlmServiceBuilder::new() - .logging() - .build(provider, "test-model".to_string()); - - let request = ChatRequestBuilder::new().user("This will fail").build(); - - let result = middleware_stack.call(request).await; - - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(matches!(error, LlmError::Api { .. 
})); -} - -#[test] -fn test_middleware_config_inspection() { - let custom_config = MiddlewareConfig { - timeout: Some(Duration::from_secs(25)), - rate_limit: Some(RateLimit { - requests_per_period: 75, - period: Duration::from_secs(45), - }), - enable_logging: true, - enable_metrics: false, - }; - - let provider = MockProvider::new("test"); - let middleware_stack = LlmServiceBuilder::with_config(custom_config.clone()) - .build(provider, "test-model".to_string()); - - let config = middleware_stack.config(); - - // Verify all configuration values are preserved - assert_eq!(config.timeout, Some(Duration::from_secs(25))); - assert!(config.enable_logging); - assert!(!config.enable_metrics); - - if let Some(rate_limit) = &config.rate_limit { - assert_eq!(rate_limit.requests_per_period, 75); - assert_eq!(rate_limit.period, Duration::from_secs(45)); - } else { - panic!("Expected rate limit configuration"); - } -} - -#[tokio::test] -async fn test_middleware_performance_timing() { - let provider = MockProvider::new("test"); - let mut middleware_stack = LlmServiceBuilder::new() - .metrics() // Enable metrics to test timing logic - .build(provider, "test-model".to_string()); - - let request = ChatRequestBuilder::new().user("Performance test").build(); - - let start = std::time::Instant::now(); - let _response = middleware_stack.call(request).await.unwrap(); - let duration = start.elapsed(); - - // The call should complete relatively quickly for a mock provider - assert!(duration < Duration::from_secs(1)); -} diff --git a/scripts/anthropic_oauth.py b/scripts/anthropic_oauth.py new file mode 100755 index 00000000..71685e79 --- /dev/null +++ b/scripts/anthropic_oauth.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +""" +Simple Anthropic OAuth flow test script. +Implements PKCE OAuth to get access token and test API calls. +Caches auth code and tokens to avoid repeated logins. 
+""" + +import argparse +import base64 +import hashlib +import http.server +import json +import os +import secrets +import socketserver +import time +import urllib.parse +import urllib.request +import webbrowser +from pathlib import Path +from typing import Optional + +# OAuth Configuration +CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" +SCOPES = "org:create_api_key user:profile user:inference" +REDIRECT_URI = "http://localhost:8765/callback" +AUTH_URL = "https://claude.ai/oauth/authorize" +TOKEN_URL = "https://console.anthropic.com/v1/oauth/token" +API_URL = "https://api.anthropic.com/v1/messages" + +# Callback server port +PORT = 8765 + +# Cache files +CACHE_DIR = Path.home() / ".cache" / "anthropic-oauth-test" +TOKEN_CACHE_FILE = CACHE_DIR / "tokens.json" +AUTH_CODE_CACHE_FILE = CACHE_DIR / "auth_code.json" + +# Token expiry buffer (refresh 5 minutes before expiry) +EXPIRY_BUFFER_SECONDS = 5 * 60 + + +def generate_pkce() -> tuple[str, str]: + """Generate PKCE code_verifier and code_challenge (S256).""" + verifier_bytes = secrets.token_bytes(64) + code_verifier = base64.urlsafe_b64encode(verifier_bytes).rstrip(b"=").decode("ascii") + challenge_hash = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = base64.urlsafe_b64encode(challenge_hash).rstrip(b"=").decode("ascii") + return code_verifier, code_challenge + + +def load_cached_auth_code() -> Optional[dict]: + """Load cached auth code and verifier.""" + if not AUTH_CODE_CACHE_FILE.exists(): + return None + try: + with open(AUTH_CODE_CACHE_FILE, "r") as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + return None + + +def save_auth_code(auth_code: str, code_verifier: str) -> None: + """Save auth code and verifier for retry.""" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + data = { + "auth_code": auth_code, + "code_verifier": code_verifier, + "timestamp": time.time(), + } + with open(AUTH_CODE_CACHE_FILE, "w") as f: + json.dump(data, f, indent=2) + os.chmod(AUTH_CODE_CACHE_FILE, 0o600) + print(f" Auth code cached to {AUTH_CODE_CACHE_FILE}") + + +def load_cached_tokens() -> Optional[dict]: + """Load tokens from cache file if they exist and are valid.""" + if not TOKEN_CACHE_FILE.exists(): + return None + try: + with open(TOKEN_CACHE_FILE, "r") as f: + cached = json.load(f) + expires_at = cached.get("expires_at", 0) + if time.time() >= expires_at - EXPIRY_BUFFER_SECONDS: + print(" Cached token expired or expiring soon") + return None + return cached + except (json.JSONDecodeError, IOError) as e: + print(f" Failed to load cache: {e}") + return None + + +def save_tokens(token_response: dict) -> None: + """Save tokens to cache file.""" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + expires_in = token_response.get("expires_in", 3600) + token_response["expires_at"] = time.time() + expires_in + with open(TOKEN_CACHE_FILE, "w") as f: + json.dump(token_response, f, indent=2) + os.chmod(TOKEN_CACHE_FILE, 0o600) + print(f" Tokens cached to {TOKEN_CACHE_FILE}") + + +def refresh_token(refresh_token: str) -> dict: + """Refresh the access token using the refresh token.""" + import requests + + data = { + "grant_type": "refresh_token", + "client_id": CLIENT_ID, + "refresh_token": refresh_token, + } + response = requests.post(TOKEN_URL, json=data) + response.raise_for_status() + return response.json() + + +class CallbackHandler(http.server.BaseHTTPRequestHandler): + """Handle OAuth callback.""" + auth_code: Optional[str] = None + state: Optional[str] = None + error: Optional[str] = None + + @classmethod + 
def reset(cls): + cls.auth_code = None + cls.state = None + cls.error = None + + def do_GET(self): + parsed = urllib.parse.urlparse(self.path) + if parsed.path != "/callback": + self.send_response(404) + self.end_headers() + return + + params = urllib.parse.parse_qs(parsed.query) + if "error" in params: + CallbackHandler.error = params.get("error", ["unknown"])[0] + error_desc = params.get("error_description", ["No description"])[0] + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(f"
<html><body><h1>Error: {CallbackHandler.error}</h1><p>{error_desc}</p></body></html>
".encode()) + return + + CallbackHandler.auth_code = params.get("code", [None])[0] + CallbackHandler.state = params.get("state", [None])[0] + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b"
<html><body><h1>Authorization successful!</h1><p>You can close this window.</p></body></html>
") + + def log_message(self, format, *args): + pass + + +def wait_for_callback() -> tuple[Optional[str], Optional[str]]: + """Start server and wait for callback.""" + socketserver.TCPServer.allow_reuse_address = True + with socketserver.TCPServer(("", PORT), CallbackHandler) as httpd: + httpd.timeout = 300 + httpd.handle_request() + return CallbackHandler.auth_code, CallbackHandler.state + + +def exchange_code_for_token(code: str, code_verifier: str, state: str) -> dict: + """Exchange authorization code for tokens.""" + import requests + + data = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "redirect_uri": REDIRECT_URI, + "code_verifier": code_verifier, + "state": state, # Required! The state from the callback + } + + response = requests.post(TOKEN_URL, json=data) + if not response.ok: + print(f" [ERROR] {response.status_code}: {response.text}") + response.raise_for_status() + + return response.json() + + +def test_api_call(access_token: str) -> dict: + """Test API call with the access token.""" + data = { + "model": "claude-opus-4-5-20251101", + "max_tokens": 10000, + "system": [ + { + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + "cache_control": { + "type": "ephemeral" + } + }, + ], + "messages": [ + {"role": "assistant", "content": "Say hello in exactly 5 words."}, + ], + } + headers = { + "Authorization": f"Bearer {access_token}", + "anthropic-version": "2023-06-01", + "anthropic-beta": "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14", + "Content-Type": "application/json", + "user-agent": "ai-sdk/anthropic/2.0.50 ai-sdk/provider-utils/3.0.18 runtime/bun/1.3.5" + } + req = urllib.request.Request( + API_URL, + data=json.dumps(data).encode("utf-8"), + headers=headers, + method="POST", + ) + with urllib.request.urlopen(req) as response: + return json.loads(response.read().decode("utf-8")) + + +def do_oauth_flow() -> tuple[Optional[str], Optional[str]]: + """Perform the OAuth flow to get auth code. 
Returns (auth_code, code_verifier).""" + CallbackHandler.reset() + + print("\n[1] Generating PKCE credentials...") + code_verifier, code_challenge = generate_pkce() + print(f" Code verifier: {code_verifier}") + print(f" Code challenge: {code_challenge}") + + print("\n[2] Building authorization URL...") + auth_params = { + "code": "true", + "client_id": CLIENT_ID, + "response_type": "code", + "redirect_uri": REDIRECT_URI, + "scope": SCOPES, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": code_verifier, + } + auth_url = f"{AUTH_URL}?{urllib.parse.urlencode(auth_params)}" + print(f" URL: {auth_url}") + + print("\n[3] Opening browser for authorization...") + print(f" Waiting for callback on port {PORT}...") + + webbrowser.open(auth_url) + auth_code, state = wait_for_callback() + + if CallbackHandler.error: + print(f"\n[ERROR] Authorization failed: {CallbackHandler.error}") + return None, None + + if not auth_code: + print("\n[ERROR] No authorization code received") + return None, None + + print(f" Received authorization code: {auth_code}") + print(f" Received state: {state}") + + # Cache the auth code for retrying token exchange + save_auth_code(auth_code, code_verifier) + + return auth_code, code_verifier + + +def main(): + parser = argparse.ArgumentParser(description="Anthropic OAuth test") + parser.add_argument("--retry-token", action="store_true", + help="Retry token exchange with cached auth code") + parser.add_argument("--login", action="store_true", + help="Force new login even if cached") + parser.add_argument("--clear", action="store_true", + help="Clear all cached data") + parser.add_argument("--auth-only", action="store_true", + help="Only get auth code, don't exchange for token") + parser.add_argument("--test-api", action="store_true", + help="Test API call after getting token") + args = parser.parse_args() + + print("=" * 60) + print("Anthropic OAuth Flow Test") + print("=" * 60) + + if args.clear: + print("\n[*] Clearing cache...") + if TOKEN_CACHE_FILE.exists(): + TOKEN_CACHE_FILE.unlink() + print(f" Deleted {TOKEN_CACHE_FILE}") + if AUTH_CODE_CACHE_FILE.exists(): + AUTH_CODE_CACHE_FILE.unlink() + print(f" Deleted {AUTH_CODE_CACHE_FILE}") + print(" Cache cleared.") + return + + # Check for cached tokens first (unless forcing login or retry) + if not args.login and not args.retry_token: + print("\n[*] Checking for cached tokens...") + cached = load_cached_tokens() + if cached: + expires_at = cached.get("expires_at", 0) + remaining = int(expires_at - time.time()) + access_token = cached.get("access_token", "") + refresh_tok = cached.get("refresh_token", "") + print(f" Found valid cached token!") + print(f" Access token: {access_token[:50]}...") + print(f" Refresh token: {refresh_tok[:50]}...") + print(f" Expires in: {remaining} seconds ({remaining // 3600}h {(remaining % 3600) // 60}m)") + + # Test API call with cached token if requested + if args.test_api: + print("\n[*] Testing API call with cached token...") + try: + api_response = test_api_call(access_token) + print(" API call successful!") + print(f" Model: {api_response.get('model', 'N/A')}") + content = api_response.get("content", []) + if content: + print(f" Response: {content[0].get('text', 'N/A')}") + except urllib.error.HTTPError as e: + print(f"\n[ERROR] API call failed: {e.code}") + print(f" Response: {e.read().decode('utf-8')}") + else: + print(f"\n Use --login to force re-authentication") + return + + # Try to get auth code (from cache if --retry-token, otherwise new login) + if 
args.retry_token: + print("\n[*] Loading cached auth code...") + cached_auth = load_cached_auth_code() + if not cached_auth: + print("[ERROR] No cached auth code found. Run without --retry-token first.") + return + auth_code = cached_auth["auth_code"] + code_verifier = cached_auth["code_verifier"] + print(f" Auth code: {auth_code}") + print(f" Code verifier: {code_verifier}") + else: + auth_code, code_verifier = do_oauth_flow() + if not auth_code: + print("\n[ERROR] Failed to get authorization code") + return + + # If auth-only, stop here + if args.auth_only: + print("\n[*] Auth code obtained. Use --retry-token to exchange for tokens.") + print(f" Cached at: {AUTH_CODE_CACHE_FILE}") + return + + # Exchange code for tokens + # Note: state is the same as code_verifier (we sent state=code_verifier in auth URL) + print("\n[4] Exchanging code for tokens...") + try: + token_response = exchange_code_for_token(auth_code, code_verifier, state=code_verifier) + print(" Token exchange successful!") + print(f" Access token: {token_response.get('access_token', 'N/A')[:50]}...") + print(f" Refresh token: {token_response.get('refresh_token', 'N/A')[:50]}...") + print(f" Expires in: {token_response.get('expires_in', 'N/A')} seconds") + save_tokens(token_response) + except Exception as e: + print(f"\n[ERROR] Token exchange failed: {e}") + print("\n Hint: Auth codes are single-use. If this failed, run --login to get a new one.") + return + + # Test API call (optional) + if args.test_api: + print("\n[5] Testing API call...") + access_token = token_response.get("access_token") + try: + api_response = test_api_call(access_token) + print(" API call successful!") + print(f" Model: {api_response.get('model', 'N/A')}") + content = api_response.get("content", []) + if content: + print(f" Response: {content[0].get('text', 'N/A')}") + except urllib.error.HTTPError as e: + print(f"\n[ERROR] API call failed: {e.code}") + print(f" Response: {e.read().decode('utf-8')}") + return + + print("\n" + "=" * 60) + print("Token obtained successfully!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/spec/chat-completion.md b/spec/chat-completion.md new file mode 100644 index 00000000..4185c723 --- /dev/null +++ b/spec/chat-completion.md @@ -0,0 +1,947 @@ +# OpenAI Chat Completions API Specification + +A comprehensive specification for implementing a client for the OpenAI Chat Completions API. + +--- + +## Table of Contents + +1. [Overview](#overview) +2. [Endpoint](#endpoint) +3. [Authentication](#authentication) +4. [Request Structure](#request-structure) +5. [Message Types](#message-types) +6. [Request Parameters](#request-parameters) +7. [Response Structure](#response-structure) +8. [Streaming](#streaming) +9. [Tool/Function Calling](#toolfunction-calling) +10. [Structured Outputs](#structured-outputs) +11. [Vision (Image Input)](#vision-image-input) +12. [Audio Input/Output](#audio-inputoutput) +13. [Web Search](#web-search) +14. [Predicted Outputs](#predicted-outputs) +15. [Error Handling](#error-handling) +16. [Rate Limiting](#rate-limiting) +17. [Model Reference](#model-reference) + +--- + +## Overview + +The Chat Completions API generates model responses from a list of messages comprising a conversation. It supports text, images, and audio as inputs and can generate text, audio, and tool calls as outputs. + +**Base URL:** `https://api.openai.com/v1` + +--- + +## Endpoint + +### Create Chat Completion + +``` +POST /chat/completions +``` + +Creates a model response for the given chat conversation. 
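+To make this concrete, here is a minimal non-streaming call sketched in Rust. This is illustrative only, not part of the API surface: it assumes the `reqwest` crate (with the `blocking` and `json` features) plus `serde_json`, and the model name and environment variable are placeholder choices. + +```rust +use serde_json::{Value, json}; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + // Assumption: the key is supplied via an environment variable. + let api_key = std::env::var("OPENAI_API_KEY")?; + let body = json!({ + "model": "gpt-4o-mini", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"} + ] + }); + let resp: Value = reqwest::blocking::Client::new() + .post("https://api.openai.com/v1/chat/completions") + .bearer_auth(&api_key) // see Authentication below + .json(&body) + .send()? + .error_for_status()? + .json()?; + // Text of the first choice; see Response Structure. + println!("{}", resp["choices"][0]["message"]["content"]); + Ok(()) +} +```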
+ +--- + +## Authentication + +All requests require an API key in the `Authorization` header: + +``` +Authorization: Bearer YOUR_API_KEY +``` + +Optional organization header: +``` +OpenAI-Organization: YOUR_ORG_ID +``` + +--- + +## Request Structure + +```json +{ + "model": "gpt-4o", + "messages": [...], + // ... additional parameters +} +``` + +--- + +## Message Types + +Messages are the core of the Chat Completions API. Each message has a `role` and `content`. + +### System Message +Sets context and behavioral instructions for the model. +```json +{ + "role": "system", + "content": "You are a helpful assistant." +} +``` + +### Developer Message (O-Series Models Only) +Used instead of system messages for reasoning models (o1, o3, o4-mini). +```json +{ + "role": "developer", + "content": "Instructions for the model behavior." +} +``` +**Note:** When using system message with o1/o3 models, it's treated as a developer message. Don't mix both in the same request. + +### User Message +Represents input from the user. + +**Text only:** +```json +{ + "role": "user", + "content": "What is the capital of France?" +} +``` + +**With images (multimodal):** +```json +{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image.jpg", + "detail": "auto" // "low", "high", or "auto" + } + } + ] +} +``` + +**With audio:** +```json +{ + "role": "user", + "content": [ + {"type": "text", "text": "What is being said?"}, + { + "type": "input_audio", + "input_audio": { + "data": "", + "format": "wav" // or "mp3" + } + } + ] +} +``` + +### Assistant Message +Model-generated response or injected assistant context. +```json +{ + "role": "assistant", + "content": "The capital of France is Paris." +} +``` + +**With tool calls:** +```json +{ + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"Paris\"}" + } + } + ] +} +``` + +### Tool Message +Result of a tool/function call. +```json +{ + "role": "tool", + "tool_call_id": "call_abc123", + "content": "{\"temperature\": 22, \"condition\": \"sunny\"}" +} +``` + +--- + +## Request Parameters + +### Required Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `model` | string | Model ID to use (e.g., `gpt-4o`, `gpt-4o-mini`, `o1`, `o3`) | +| `messages` | array | List of messages comprising the conversation | + +### Optional Parameters - Sampling & Generation + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `max_completion_tokens` | integer | null | Maximum tokens to generate (replaces deprecated `max_tokens`). Required for o-series models. | +| `max_tokens` | integer | null | **Deprecated.** Use `max_completion_tokens` instead. Not compatible with o-series models. | +| `temperature` | number | 1.0 | Sampling temperature (0-2). Higher = more random. **Not supported for reasoning models.** | +| `top_p` | number | 1.0 | Nucleus sampling parameter. **Not supported for reasoning models.** | +| `n` | integer | 1 | Number of completions to generate. | +| `stop` | string/array | null | Up to 4 sequences where the model stops generating. | +| `presence_penalty` | number | 0.0 | Penalizes tokens based on presence in text (-2.0 to 2.0). 
**Not supported for reasoning models.** | +| `frequency_penalty` | number | 0.0 | Penalizes tokens based on frequency in text (-2.0 to 2.0). **Not supported for reasoning models.** | +| `logit_bias` | object | null | Map of token IDs to bias values (-100 to 100). **Not supported for reasoning models.** | + +### Optional Parameters - Advanced + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `seed` | integer | null | For deterministic sampling (best effort). Monitor `system_fingerprint` for backend changes. | +| `logprobs` | boolean | false | Return log probabilities of output tokens. **Not supported for reasoning models.** | +| `top_logprobs` | integer | null | Number of most likely tokens to return (0-20). Requires `logprobs: true`. | +| `user` | string | null | Unique end-user identifier for abuse detection. | +| `stream` | boolean | false | Enable streaming responses via SSE. | +| `stream_options` | object | null | Streaming options: `{"include_usage": true}` to get token usage in final chunk. | + +### Optional Parameters - Response Format + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `response_format` | object | null | Controls output format. See [Structured Outputs](#structured-outputs). | + +### Optional Parameters - Tools & Functions + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `tools` | array | null | List of tools/functions the model may call. See [Tool Calling](#toolfunction-calling). | +| `tool_choice` | string/object | "auto" | Controls tool use: `"none"`, `"auto"`, `"required"`, or specific function. | +| `parallel_tool_calls` | boolean | true | Allow parallel function calls. Disable for gpt-4.1-nano-2025-04-14. | + +### Optional Parameters - Reasoning Models (O-Series) + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `reasoning_effort` | string | "medium" | Reasoning depth: `"minimal"` (gpt-5 only), `"low"`, `"medium"`, `"high"`. | + +### Optional Parameters - Multimodal + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `modalities` | array | ["text"] | Output types to generate: `["text"]`, `["text", "audio"]`. Audio requires audio-preview models. | +| `audio` | object | null | Audio output config: `{"voice": "alloy", "format": "wav"}`. Voices: `alloy`, `echo`, `shimmer`. Formats: `wav`, `mp3`, `pcm16` (streaming only). | + +### Optional Parameters - Service & Storage + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `service_tier` | string | "auto" | Processing tier: `"auto"`, `"default"`, `"flex"` (50% cheaper, slower), `"priority"` (lower latency, premium). | +| `store` | boolean | false | Store completion for model distillation/evals. | +| `metadata` | object | null | Up to 16 key-value pairs (keys: max 64 chars, values: max 512 chars). | + +### Optional Parameters - Web Search + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `web_search_options` | object | null | Enable web search for search-preview models. See [Web Search](#web-search). | + +### Optional Parameters - Predicted Outputs + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `prediction` | object | null | Predicted output for faster responses. See [Predicted Outputs](#predicted-outputs). 
| + +--- + +## Response Structure + +### Non-Streaming Response + +```json +{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1702685778, + "model": "gpt-4o-2024-08-06", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The capital of France is Paris.", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 13, + "completion_tokens": 7, + "total_tokens": 20, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "service_tier": "default" +} +``` + +### Response Fields + +| Field | Type | Description | +|-------|------|-------------| +| `id` | string | Unique identifier for the completion | +| `object` | string | Always `"chat.completion"` | +| `created` | integer | Unix timestamp of creation | +| `model` | string | Model used for completion | +| `system_fingerprint` | string | Backend configuration fingerprint (for determinism tracking with `seed`) | +| `choices` | array | List of completion choices | +| `usage` | object | Token usage statistics | +| `service_tier` | string | Service tier used (may differ from requested) | + +### Choice Object + +| Field | Type | Description | +|-------|------|-------------| +| `index` | integer | Index of this choice | +| `message` | object | The generated message | +| `finish_reason` | string | Why generation stopped | +| `logprobs` | object/null | Log probability information (if requested) | + +### Finish Reasons + +| Value | Description | +|-------|-------------| +| `stop` | Natural stop or stop sequence reached | +| `length` | Max token limit reached | +| `tool_calls` | Model requested tool execution | +| `content_filter` | Content filtered by safety systems | +| `function_call` | **Deprecated.** Function call requested | + +### Usage Object + +| Field | Type | Description | +|-------|------|-------------| +| `prompt_tokens` | integer | Tokens in the prompt | +| `completion_tokens` | integer | Tokens generated | +| `total_tokens` | integer | Total tokens used | +| `prompt_tokens_details.cached_tokens` | integer | Cached prompt tokens | +| `prompt_tokens_details.audio_tokens` | integer | Audio input tokens | +| `completion_tokens_details.reasoning_tokens` | integer | Hidden reasoning tokens (o-series) | +| `completion_tokens_details.audio_tokens` | integer | Audio output tokens | +| `completion_tokens_details.accepted_prediction_tokens` | integer | Accepted predicted tokens | +| `completion_tokens_details.rejected_prediction_tokens` | integer | Rejected predicted tokens | + +--- + +## Streaming + +Enable streaming with `stream: true`. Responses are sent as Server-Sent Events (SSE). 
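+As a rough sketch of the consuming side (assuming only `serde_json`, with the response body already split into lines), a client can extract `data:` payloads, stop at the `[DONE]` terminator, and accumulate the `delta.content` fragments described below: + +```rust +/// Collect assistant text from SSE lines such as `data: {...}` and `data: [DONE]`. +fn collect_stream_text(lines: &[&str]) -> String { + let mut out = String::new(); + for line in lines { + // Only `data:` events matter here; other SSE fields are ignored. + let Some(payload) = line.strip_prefix("data: ") else { continue }; + if payload == "[DONE]" { + break; // end-of-stream marker + } + if let Ok(chunk) = serde_json::from_str::<serde_json::Value>(payload) { + // Streamed text arrives under `delta.content`, not `message.content`. + if let Some(text) = chunk["choices"][0]["delta"]["content"].as_str() { + out.push_str(text); + } + } + } + out +} +```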
+ +### Request +```json +{ + "model": "gpt-4o", + "messages": [...], + "stream": true, + "stream_options": {"include_usage": true} +} +``` + +### Response Headers +``` +Content-Type: text/event-stream +Transfer-Encoding: chunked +``` + +### Chunk Format + +```json +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}} + +data: [DONE] +``` + +### Key Differences from Non-Streaming + +| Aspect | Non-Streaming | Streaming | +|--------|---------------|-----------| +| Object type | `chat.completion` | `chat.completion.chunk` | +| Content field | `message.content` | `delta.content` | +| Role field | `message.role` | `delta.role` (first chunk only) | +| Usage | Always included | Only with `stream_options.include_usage: true`, in final chunk | +| Termination | Single response | `data: [DONE]` event | + +--- + +## Tool/Function Calling + +Tools allow the model to call external functions. The API returns tool call requests; execution is your responsibility. + +### Defining Tools + +```json +{ + "model": "gpt-4o", + "messages": [...], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "strict": true, + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City and state, e.g., San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location", "unit"], + "additionalProperties": false + } + } + } + ], + "tool_choice": "auto" +} +``` + +### Tool Choice Options + +| Value | Description | +|-------|-------------| +| `"none"` | Don't call any tools | +| `"auto"` | Model decides whether to call tools | +| `"required"` | Must call at least one tool | +| `{"type": "function", "function": {"name": "my_func"}}` | Force specific function | + +### Strict Mode + +Setting `strict: true` ensures function calls reliably adhere to the schema. 
Requirements: +- `additionalProperties: false` for each object +- All fields in `properties` must be in `required` +- Optional fields: add `null` as a type option + +### Tool Call Response + +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\":\"Paris\",\"unit\":\"celsius\"}" + } + } + ] + }, + "finish_reason": "tool_calls" + }] +} +``` + +### Providing Tool Results + +```json +{ + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "What's the weather in Paris?"}, + { + "role": "assistant", + "content": null, + "tool_calls": [{"id": "call_abc123", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"Paris\",\"unit\":\"celsius\"}"}}] + }, + { + "role": "tool", + "tool_call_id": "call_abc123", + "content": "{\"temperature\": 22, \"condition\": \"sunny\"}" + } + ] +} +``` + +--- + +## Structured Outputs + +Force the model to output valid JSON matching a schema. + +### JSON Mode (Basic) + +```json +{ + "model": "gpt-4o", + "messages": [...], + "response_format": {"type": "json_object"} +} +``` + +### JSON Schema Mode (Strict) + +```json +{ + "model": "gpt-4o", + "messages": [...], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "person_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "email": {"type": ["string", "null"]} + }, + "required": ["name", "age", "email"], + "additionalProperties": false + } + } + } +} +``` + +### Schema Requirements + +- Root cannot be `anyOf` type +- All fields must be in `required` +- `additionalProperties: false` required +- Some JSON Schema keywords not supported (e.g., `format` for dates) + +### Refusals + +When the model refuses to generate structured output: +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": null, + "refusal": "I cannot provide information about that topic." + } + }] +} +``` + +--- + +## Vision (Image Input) + +### Image URL + +```json +{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image.jpg", + "detail": "auto" + } + } + ] +} +``` + +### Base64 Image + +```json +{ + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,{base64_encoded_data}", + "detail": "high" + } +} +``` + +### Detail Levels + +| Value | Description | Token Cost | +|-------|-------------|------------| +| `low` | Fixed low resolution | Base cost (e.g., 85 tokens) | +| `high` | Full resolution, tiled | Variable based on dimensions | +| `auto` | Model decides | Variable | + +### Token Calculation (High Detail) + +1. Scale to fit 2048x2048 (maintaining aspect ratio) +2. Scale shortest side to 768px +3. Count 512px tiles +4. Cost = (tiles × 170) + 85 tokens (approximate) + +--- + +## Audio Input/Output + +Requires `gpt-4o-audio-preview` or `gpt-4o-mini-audio-preview` models. + +### Request with Audio Output + +```json +{ + "model": "gpt-4o-audio-preview", + "modalities": ["text", "audio"], + "audio": { + "voice": "alloy", + "format": "wav" + }, + "messages": [...] +} +``` + +### Audio Configuration + +| Field | Options | Description | +|-------|---------|-------------| +| `voice` | `alloy`, `echo`, `shimmer` | Voice for audio output | +| `format` | `wav`, `mp3`, `pcm16` | Output format. 
`pcm16` only for streaming. | + +### Audio Input Message + +```json +{ + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": "", + "format": "wav" + } + } + ] +} +``` + +**Max audio file size:** 20 MB + +--- + +## Web Search + +Available with search-preview models: `gpt-4o-search-preview`, `gpt-4o-mini-search-preview`, `gpt-5-search-api`. + +### Basic Request + +```json +{ + "model": "gpt-4o-search-preview", + "web_search_options": {}, + "messages": [ + {"role": "user", "content": "What happened in tech news today?"} + ] +} +``` + +### With Options + +```json +{ + "model": "gpt-4o-search-preview", + "web_search_options": { + "search_context_size": "medium", + "user_location": { + "type": "approximate", + "approximate": { + "country": "US", + "city": "San Francisco", + "region": "California", + "timezone": "America/Los_Angeles" + } + } + }, + "messages": [...] +} +``` + +### Search Context Size + +| Value | Description | +|-------|-------------| +| `low` | Faster, cheaper, less accurate | +| `medium` | Default balance | +| `high` | More thorough, slower, more expensive | + +### Response Annotations + +```json +{ + "choices": [{ + "message": { + "content": "According to recent news [1]...", + "annotations": [ + { + "type": "url_citation", + "start_index": 28, + "end_index": 31, + "url": "https://example.com/article", + "title": "Article Title" + } + ] + } + }] +} +``` + +--- + +## Predicted Outputs + +Reduce latency when most output is known ahead of time. Supported on GPT-4o and GPT-4o-mini. + +### Request + +```json +{ + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "Update the email address to john@new.com"}, + {"role": "user", "content": "Current data: {\"name\": \"John\", \"email\": \"john@old.com\"}"} + ], + "prediction": { + "type": "content", + "content": "{\"name\": \"John\", \"email\": \"john@old.com\"}" + } +} +``` + +### Limitations + +- Text-only (no audio/images) +- Not compatible with: `n > 1`, `logprobs`, `presence_penalty > 0` +- Rejected tokens are still billed +- May increase costs if predictions are poor matches + +### Tracking Prediction Usage + +Check `usage.completion_tokens_details`: +- `accepted_prediction_tokens`: Tokens that matched +- `rejected_prediction_tokens`: Tokens that didn't match (still billed) + +--- + +## Error Handling + +### Error Response Format + +```json +{ + "error": { + "message": "Invalid API key", + "type": "invalid_request_error", + "param": null, + "code": "invalid_api_key" + } +} +``` + +### HTTP Status Codes + +| Code | Description | Action | +|------|-------------|--------| +| 400 | Bad Request | Fix request parameters | +| 401 | Unauthorized | Check API key | +| 403 | Forbidden | Check permissions/organization | +| 404 | Not Found | Check model ID/endpoint | +| 429 | Rate Limited | Implement exponential backoff | +| 500 | Server Error | Retry with backoff | +| 502 | Bad Gateway | Retry with backoff | +| 503 | Service Unavailable | Retry with backoff | + +### Common Error Codes + +| Code | Description | +|------|-------------| +| `invalid_api_key` | API key is invalid | +| `insufficient_quota` | Quota exceeded | +| `rate_limit_exceeded` | Too many requests | +| `model_not_found` | Model doesn't exist or no access | +| `context_length_exceeded` | Input too long | +| `invalid_request_error` | Malformed request | +| `content_policy_violation` | Content filtered | + +### Retry Strategy + +``` +1. Wait: min(2^attempt * 1000ms, 60000ms) + random_jitter +2. Max attempts: 5 +3. 
On 429: Check x-ratelimit-reset-* headers +4. On 5xx: Always retry +``` + +--- + +## Rate Limiting + +### Rate Limit Types + +| Type | Description | +|------|-------------| +| RPM | Requests per minute | +| RPD | Requests per day | +| TPM | Tokens per minute | +| TPD | Tokens per day | +| IPM | Images per minute | + +### Response Headers + +| Header | Description | +|--------|-------------| +| `x-ratelimit-limit-requests` | Max requests allowed | +| `x-ratelimit-limit-tokens` | Max tokens allowed | +| `x-ratelimit-remaining-requests` | Requests remaining | +| `x-ratelimit-remaining-tokens` | Tokens remaining | +| `x-ratelimit-reset-requests` | Time until request limit resets | +| `x-ratelimit-reset-tokens` | Time until token limit resets | + +### Token Refill + +Tokens refill on a rolling 60-second window, not all at once. + +--- + +## Model Reference + +### GPT-4o Series + +| Model | Context Window | Max Output | Notes | +|-------|----------------|------------|-------| +| `gpt-4o` | 128k | 16k | Latest GPT-4o | +| `gpt-4o-2024-08-06` | 128k | 16k | Structured outputs support | +| `gpt-4o-mini` | 128k | 16k | Faster, cheaper | +| `gpt-4o-audio-preview` | 128k | 16k | Audio I/O support | +| `gpt-4o-search-preview` | 128k | 16k | Web search | + +### GPT-4.1 Series + +| Model | Context Window | Max Output | Notes | +|-------|----------------|------------|-------| +| `gpt-4.1` | 1M | ~32k | Extended context | +| `gpt-4.1-mini` | 1M | ~32k | Smaller, faster | +| `gpt-4.1-nano` | 1M | ~32k | Smallest | + +### O-Series (Reasoning Models) + +| Model | Context Window | Max Output | Notes | +|-------|----------------|------------|-------| +| `o1` | 200k | 100k | Advanced reasoning | +| `o1-mini` | 128k | 65k | Faster reasoning | +| `o1-preview` | 128k | 32k | Preview version | +| `o3` | 200k | 100k | Latest reasoning | +| `o3-mini` | 200k | 100k | Smaller reasoning | +| `o4-mini` | 200k | 100k | Newest mini | + +### GPT-5 Series + +| Model | Context Window | Max Output | Notes | +|-------|----------------|------------|-------| +| `gpt-5` | 128k+ | Varies | Flagship model | +| `gpt-5-mini` | 128k | 16k | Smaller variant | +| `gpt-5-search-api` | 128k | 16k | Web search | + +### Parameter Support by Model Type + +| Parameter | GPT-4o/4.1 | O-Series | Notes | +|-----------|------------|----------|-------| +| `temperature` | ✅ | ❌ | | +| `top_p` | ✅ | ❌ | | +| `presence_penalty` | ✅ | ❌ | | +| `frequency_penalty` | ✅ | ❌ | | +| `logprobs` | ✅ | ❌ | | +| `logit_bias` | ✅ | ❌ | | +| `max_tokens` | ✅ | ❌ | Deprecated | +| `max_completion_tokens` | ✅ | ✅ | Preferred | +| `reasoning_effort` | ❌ | ✅ | | +| `developer` role | ❌ | ✅ | Use `system` for GPT models | + +--- + +## Implementation Checklist + +### Core Features +- [ ] Basic text completion +- [ ] Multi-turn conversations +- [ ] System/developer messages +- [ ] Streaming with SSE parsing +- [ ] Token usage tracking + +### Advanced Features +- [ ] Tool/function calling +- [ ] Structured outputs (JSON mode & JSON Schema) +- [ ] Vision (image input) +- [ ] Audio input/output +- [ ] Web search integration +- [ ] Predicted outputs + +### Error Handling +- [ ] HTTP error responses +- [ ] Rate limit headers parsing +- [ ] Exponential backoff retry +- [ ] Content filter detection + +### Model-Specific +- [ ] Reasoning model parameters +- [ ] Service tier selection +- [ ] Seed for reproducibility + +--- + +## Sources + +- [Chat Completions API Reference](https://platform.openai.com/docs/api-reference/chat/) +- [Chat Completions 
Guide](https://platform.openai.com/docs/guides/chat-completions)
+- [Function Calling Guide](https://platform.openai.com/docs/guides/function-calling)
+- [Structured Outputs Guide](https://platform.openai.com/docs/guides/structured-outputs)
+- [Images and Vision Guide](https://platform.openai.com/docs/guides/images-vision)
+- [Audio and Speech Guide](https://platform.openai.com/docs/guides/audio)
+- [Web Search Guide](https://platform.openai.com/docs/guides/tools-web-search)
+- [Reasoning Models Guide](https://platform.openai.com/docs/guides/reasoning)
+- [Error Codes Guide](https://platform.openai.com/docs/guides/error-codes)
+- [Rate Limits Guide](https://platform.openai.com/docs/guides/rate-limits)
+- [Models Overview](https://platform.openai.com/docs/models)
+- [OpenAI Cookbook - Using Logprobs](https://cookbook.openai.com/examples/using_logprobs)
+- [OpenAI Cookbook - Reproducible Outputs](https://cookbook.openai.com/examples/reproducible_outputs_with_the_seed_parameter)
+
diff --git a/spec/chat-completion2.md b/spec/chat-completion2.md
new file mode 100644
index 00000000..30d9c1f6
--- /dev/null
+++ b/spec/chat-completion2.md
@@ -0,0 +1,289 @@
+# OpenAI Chat Completions API
+
+This document specifies the **Chat Completions** REST API under `POST /v1/chat/completions` (plus stored-completion CRUD endpoints), including request/response shapes, streaming (SSE), and related features (tools/function calling, structured outputs, multimodal inputs, audio, and web search).
+
+---
+
+## 1) Positioning / when to use Chat Completions
+OpenAI’s docs now emphasize the **Responses API** for many new integrations, but **Chat Completions** remains a supported API surface and is still the correct choice if you specifically want `/v1/chat/completions` semantics and object shapes.
+
+---
+
+## 2) Base URL, auth, headers, and versioning
+
+### 2.1 Base URL
+- Base URL: `https://api.openai.com/v1`
+
+### 2.2 Authentication
+- Header auth: `Authorization: Bearer <OPENAI_API_KEY>`
+
+### 2.3 Core request headers
+- `Content-Type: application/json` for JSON requests.
+- Optional account routing headers:
+  - `OpenAI-Organization: <organization-id>`
+  - `OpenAI-Project: <project-id>`
+
+### 2.4 Debugging / observability headers
+- Responses include a request identifier header you should log (notably `x-request-id`) to correlate failures and support requests.
+- Responses include timing information like `openai-processing-ms`.
+
+### 2.5 Backward compatibility expectations
+- OpenAI publishes a backward-compatibility policy for the API; clients should be resilient to additive fields and new enum values.
+
+**Client requirement:** Implement JSON decoding as forward-compatible: ignore unknown fields; do not exhaustively match enums without a fallback.
+
+---
+
+## 3) Error model, retries, and rate limits
+
+### 3.1 Error response shape (high-level)
+Errors are returned with a top-level `error` object (containing fields like `message`, `type`, and sometimes `param`/`code`), alongside standard HTTP status codes.
+
+### 3.2 Recommended retry policy (client-side)
+- Retry only **transient** failures (commonly `429`, `500`, `503`) using exponential backoff + jitter; do **not** blindly retry `400`/`401`/`403`/`404` because they’re usually permanent for that request.
+- Always log `x-request-id` (and any client request id you add) for diagnostics.
+
+### 3.3 Rate limit headers
+OpenAI documents rate limits and returns `x-ratelimit-*` headers (covering request and token budgets with limit/remaining/reset patterns).
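+
+For example, a client might read these headers like this (a minimal sketch using `reqwest`; the struct shape and the subset of headers shown are illustrative, with names taken from the Rate Limits guide):
+
+```rust
+/// Snapshot of the rate-limit headers on one response (hypothetical shape).
+struct RateLimitInfo {
+    remaining_requests: Option<u64>,
+    remaining_tokens: Option<u64>,
+    reset_requests: Option<String>, // durations such as "1s" or "6m0s"
+    reset_tokens: Option<String>,
+}
+
+fn parse_rate_limits(resp: &reqwest::Response) -> RateLimitInfo {
+    // Read a header as an owned string, if present and valid UTF-8.
+    let header = |name: &str| {
+        resp.headers()
+            .get(name)
+            .and_then(|v| v.to_str().ok())
+            .map(str::to_owned)
+    };
+    RateLimitInfo {
+        remaining_requests: header("x-ratelimit-remaining-requests").and_then(|s| s.parse().ok()),
+        remaining_tokens: header("x-ratelimit-remaining-tokens").and_then(|s| s.parse().ok()),
+        reset_requests: header("x-ratelimit-reset-requests"),
+        reset_tokens: header("x-ratelimit-reset-tokens"),
+    }
+}
+```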
Your client should parse and surface these for adaptive throttling. + +--- + +## 4) Endpoint inventory (Chat Completions) + +### 4.1 Create (optionally stream) +- `POST /v1/chat/completions` — generate an assistant response; supports streaming via SSE; can optionally store the completion. + +### 4.2 Stored completion retrieval & management (requires `store: true` on create) +- `GET /v1/chat/completions/{completion_id}` — retrieve a stored completion. +- `GET /v1/chat/completions` — list stored completions (pagination + filters). +- `GET /v1/chat/completions/{completion_id}/messages` — list messages from a stored completion (pagination). +- `POST /v1/chat/completions/{completion_id}` — update metadata on a stored completion. +- `DELETE /v1/chat/completions/{completion_id}` — delete a stored completion. + +--- + +## 5) `POST /v1/chat/completions` — Create + +### 5.1 Primary use cases +- Standard “chat” generation: model responds to a conversation history you provide in `messages`. +- Tool/function calling: model asks your client to call functions; client executes and feeds results back. +- Streaming UI: token-by-token deltas over SSE. +- Structured outputs: force JSON or schema-conformant JSON. +- Multimodal input: images and audio in the conversation (model-dependent). +- Web search models: model performs a web search and returns citations/annotations. +- Persisting outputs for later retrieval: set `store: true` and then use stored completion endpoints. + +--- + +### 5.2 Request body (top-level fields) +**Required** +- `model: string` — model identifier. +- `messages: array` — conversation inputs (text and/or content parts). + +**Common optional fields (sampling / stopping / token limits)** +- `temperature?: number` — sampling temperature. +- `top_p?: number` — nucleus sampling. +- `n?: integer` — number of choices to generate. +- `stop?: string | string[] | null` — stop sequences; docs note model-specific support limitations (e.g., not supported by some “o” models). +- `max_completion_tokens?: integer | null` — cap for generated tokens (including reasoning tokens where applicable). +- `max_tokens?: integer | null` — deprecated; docs note incompatibility with some newer model families. +- `presence_penalty?: number | null`, `frequency_penalty?: number | null` + +**Logprobs** +- `logprobs?: boolean | null` — request logprobs in output. +- `top_logprobs?: integer | null` — number of top tokens to return (when `logprobs` is true). + +**Streaming** +- `stream?: boolean | null` — enable SSE. +- `stream_options?: object | null` — stream options; docs show `include_usage` behavior. + +**Tools / function calling** +- `tools?: array` — tool definitions (notably function tools). +- `tool_choice?: "none" | "auto" | "required" | object` — tool selection policy (including forcing a specific tool). +- `parallel_tool_calls?: boolean` — whether model may emit multiple tool calls in one turn. +- `functions` / `function_call` — deprecated tool/function fields. + +**Structured outputs** +- `response_format?: object` — JSON mode and schema mode. + +**Multimodal output** +- `modalities?: string[]` — request output modalities (model-dependent). +- `audio?: object | null` — audio output settings when requesting audio modality. + +**Other** +- `reasoning_effort?: string` — reasoning control for supported models (docs note model-specific constraints). +- `verbosity?: string` — output verbosity control for supported models. +- `prediction?: object` — “predicted output” optimization payload. 
+- `web_search_options?: object` — used with web-search models. +- `service_tier?: string` — service tier selection. +- `seed?: integer | null` — deprecated determinism hint. +- `logit_bias?: object | null` — token-level biasing. + +**Storing & metadata** +- `store?: boolean | null` — enable stored-completion retrieval; docs warn that some large inputs (e.g., large images) may be dropped when storing. +- `metadata?: object` — user-defined key/value metadata for stored items. +- `user?: string` — deprecated in favor of newer identifiers (docs reference `safety_identifier` and `prompt_cache_key`). + +--- + +### 5.3 `messages[]` — conversation schema (client-facing) +Chat Completions are **stateless**: you send the prior turns each request (or a summarized state). + +**Message object (conceptual)** +- `role: string` — typical roles include `system`, `user`, `assistant`, and tool-related roles (exact accepted roles depend on the API mode and model). +- `content: string | array` — either plain text or a list of typed content parts for multimodal inputs. + +**Content parts (multimodal)** +- Image input uses a content-part pattern documented in the Images/Vision guide for Chat Completions. +- Audio input uses a content-part pattern documented in the Audio guide for Chat Completions. + +**Client requirement:** Model capabilities differ; your client should not hard-code “text only” assumptions, and should treat message `content` as a tagged union. + +--- + +## 6) `POST /v1/chat/completions` — Response + +### 6.1 Non-streaming response object +A successful create returns a **chat completion object** containing identifiers and an array of `choices`. + +**Top-level (commonly present)** +- `id: string` +- `object: "chat.completion"` +- `created: integer` (unix seconds) +- `model: string` +- `choices: array` +- `usage: object` (token accounting) + +**Choice object (conceptual)** +- `index: integer` +- `message: { role, content, ... }` +- `finish_reason: string | null` +- `logprobs: object | null` (if requested/supported) + +**Client requirement:** Do not assume a single choice; handle `n > 1` by returning multiple candidate messages. + +--- + +## 7) Streaming (SSE) — `stream: true` + +### 7.1 Transport / framing +Chat Completions streaming uses **Server-Sent Events** (SSE) delivering **chat completion chunk** objects incrementally. + +The OpenAI API reference for legacy streaming explicitly describes “data-only SSE” messages and a terminal `data: [DONE]` sentinel; Chat Completions streaming is documented in terms of chunk objects and deltas. For robust clients, support both: end on connection close and/or `[DONE]` sentinel if present. + +### 7.2 Chunk object shape +Each event contains a `chat.completion.chunk` object with `choices[].delta` carrying incremental text/tool-call deltas. + +Key fields: +- `object: "chat.completion.chunk"` +- `id`, `created`, `model` +- `choices[]: { index, delta, finish_reason? }` +- Optional `usage` when using `stream_options` to include usage. + +### 7.3 Delta accumulation rules (client requirement) +- Concatenate `choices[i].delta.content` fragments in-order. +- Treat tool-call deltas as structured fragments that must be assembled into a complete tool call before execution. +- Terminate a choice when `finish_reason` becomes non-null. + +--- + +## 8) Tools / function calling (Chat Completions) + +### 8.1 Use case +Let the model request structured, executable actions (API calls, database queries, etc.) 
by emitting tool calls; your client executes them and returns tool outputs back into the conversation. + +### 8.2 Request fields (high-level) +- Provide available tools in `tools`. +- Control selection with `tool_choice`. +- Allow or disallow multiple calls with `parallel_tool_calls`. + +### 8.3 Response behavior (high-level) +- Assistant messages may include tool-call descriptors instead of (or in addition to) normal text content. +- Client must translate tool calls into actual executions and then append tool results as subsequent messages and call the API again. + +--- + +## 9) Structured outputs (`response_format`) + +### 9.1 Use case +- Enforce machine-parseable JSON output (basic JSON mode) or schema-conformant JSON (structured outputs) for deterministic integration with downstream code. + +### 9.2 Modes (documented) +- JSON mode via a `response_format` object (older JSON mode). +- Schema-based structured outputs via a `response_format` object (JSON Schema). + +**Client requirement:** Treat `response_format` as a tagged union; do not assume only one sub-variant. + +--- + +## 10) Multimodal inputs (images + audio) and audio outputs + +### 10.1 Image inputs (vision) +- Chat Completions supports image input via content parts as documented in the Images/Vision guide when `api-mode=chat`. +- Client should support: + - Remote URLs + - Base64 “data:” URLs + - Any per-image options described in the guide (e.g., detail level), model-dependent. + +### 10.2 Audio inputs and outputs +- Audio input: send base64-encoded audio as typed content parts (guide documents the required fields). +- Audio output: request audio via `modalities` + `audio` settings (voice/format), then decode audio bytes from the response. + +**Client requirement:** Audio and image support are model-dependent; surface capability errors clearly (do not silently fall back unless the caller asked you to). + +--- + +## 11) Web search (Chat Completions) + +### 11.1 Use case +For web-search-capable models, allow the model to retrieve information from the web before responding, returning citation annotations suitable for UI rendering and attribution. + +### 11.2 Request fields +- Use `web_search_options` alongside a web-search model. + +### 11.3 Response fields (citations / annotations) +- The Web Search tool guide documents citation annotations (including URL citations) and how to render them. + +**Client requirement:** Preserve and expose annotations separately from text so callers can render citations reliably even if the visible text format changes. + +--- + +## 12) Stored completions (`store: true`) — retrieval, listing, message listing, metadata update, delete + +### 12.1 Use case +- Persist responses for later inspection, evaluation, or building UIs that revisit prior outputs without keeping your own transcript store. + +### 12.2 Key behaviors +- `store: true` on create enables subsequent retrieval/listing endpoints. +- Update endpoint is for metadata updates on stored items. +- Messages endpoint returns the stored message list with pagination controls. + +### 12.3 Pagination (high-level) +List endpoints support typical cursor pagination parameters like `after`, plus `limit` and `order`. + +--- + +## 13) Practical client requirements checklist + +### 13.1 HTTP layer +- Keep-alive connections; configurable request timeout; request body size limits; gzip/deflate handling as supported by your HTTP library. 
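+
+A minimal sketch of this layer using `reqwest` (the crate choice and the specific values are assumptions of this example, not API requirements):
+
+```rust
+use std::time::Duration;
+
+// Build a client with keep-alive pooling (reqwest's default), an overall
+// request timeout, and gzip decompression (requires reqwest's `gzip` feature).
+fn build_http_client() -> reqwest::Result<reqwest::Client> {
+    reqwest::Client::builder()
+        .timeout(Duration::from_secs(120))          // per-request cap; tune per caller
+        .pool_idle_timeout(Duration::from_secs(90)) // how long idle connections are kept
+        .gzip(true)
+        .build()
+}
+```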
+
+### 13.2 JSON decoding
+- Tolerate unknown fields and enum expansions; treat tagged unions (`content`, `response_format`, tools) as extensible.
+
+### 13.3 Streaming
+- Implement a correct SSE parser:
+  - parse `data:` lines into JSON chunks
+  - assemble `delta` fragments per `choice.index`
+  - handle both “connection close” and `[DONE]`-style termination defensively.
+
+### 13.4 Tool calling
+- Support multiple tool calls (including parallel), and tool-call assembly in both non-streaming and streaming modes.
+
+### 13.5 Errors & rate limits
+- Parse error objects; map to typed errors; classify retryable vs permanent.
+- Parse `x-ratelimit-*` headers for adaptive throttling and caller visibility.
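+
+As a sketch of the retry classification above (sections 3.2 and 13.5); the policy values here are assumptions, not documented requirements:
+
+```rust
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+// Transient failures worth retrying with backoff; everything else is
+// treated as permanent for that request.
+fn is_retryable(status: u16) -> bool {
+    matches!(status, 429 | 500 | 502 | 503)
+}
+
+// Exponential backoff with jitter, capped at 60s.
+fn backoff_delay(attempt: u32) -> Duration {
+    let base_ms = 1000u64 << attempt.min(6); // 1s, 2s, 4s, ...
+    let capped = base_ms.min(60_000);
+    // Cheap jitter without an RNG dependency: derive from the clock.
+    let jitter = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .map(|d| u64::from(d.subsec_millis()))
+        .unwrap_or(0)
+        % 250;
+    Duration::from_millis(capped + jitter)
+}
+```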