From 8a75a6ed0338bf8160cc19e3955fef70a7498d68 Mon Sep 17 00:00:00 2001 From: lambda Date: Thu, 4 Dec 2025 02:09:04 +0530 Subject: [PATCH 01/11] feat(cli): add OAuth authentication support Replace the `keys` command with a unified `auth` system supporting multiple authentication methods per provider (OAuth and API keys). --- Cargo.lock | 707 +++++++++++++++++++++++- crates/rullm-cli/Cargo.toml | 10 + crates/rullm-cli/src/api_keys.rs | 87 --- crates/rullm-cli/src/args.rs | 16 +- crates/rullm-cli/src/auth.rs | 337 +++++++++++ crates/rullm-cli/src/client.rs | 13 +- crates/rullm-cli/src/commands/auth.rs | 276 +++++++++ crates/rullm-cli/src/commands/info.rs | 5 +- crates/rullm-cli/src/commands/keys.rs | 118 ---- crates/rullm-cli/src/commands/mod.rs | 20 +- crates/rullm-cli/src/constants.rs | 1 - crates/rullm-cli/src/main.rs | 7 +- crates/rullm-cli/src/oauth/anthropic.rs | 252 +++++++++ crates/rullm-cli/src/oauth/mod.rs | 12 + crates/rullm-cli/src/oauth/openai.rs | 303 ++++++++++ crates/rullm-cli/src/oauth/pkce.rs | 86 +++ crates/rullm-cli/src/oauth/server.rs | 226 ++++++++ 17 files changed, 2210 insertions(+), 266 deletions(-) delete mode 100644 crates/rullm-cli/src/api_keys.rs create mode 100644 crates/rullm-cli/src/auth.rs create mode 100644 crates/rullm-cli/src/commands/auth.rs delete mode 100644 crates/rullm-cli/src/commands/keys.rs create mode 100644 crates/rullm-cli/src/oauth/anthropic.rs create mode 100644 crates/rullm-cli/src/oauth/mod.rs create mode 100644 crates/rullm-cli/src/oauth/openai.rs create mode 100644 crates/rullm-cli/src/oauth/pkce.rs create mode 100644 crates/rullm-cli/src/oauth/server.rs diff --git a/Cargo.lock b/Cargo.lock index c91c6f95..185c6b1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,6 +142,12 @@ dependencies = [ "syn", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name 
= "atty" version = "0.2.14" @@ -180,6 +186,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "1.3.2" @@ -195,6 +207,15 @@ dependencies = [ "serde", ] +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bumpalo" version = "3.19.0" @@ -216,6 +237,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-if" version = "1.0.1" @@ -234,7 +261,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -295,6 +322,16 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -305,12 +342,31 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + 
"libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crossterm" version = "0.28.1" @@ -337,6 +393,26 @@ dependencies = [ "winapi", ] +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -526,6 +602,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -566,7 +652,26 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.12", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.4.0", "indexmap", "slab", "tokio", @@ -595,6 +700,12 @@ 
dependencies = [ "libc", ] +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "home" version = "0.5.11" @@ -615,6 +726,16 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -622,7 +743,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.4.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http 1.4.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -648,9 +792,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -662,6 +806,44 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", + "httparse", + 
"itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http 1.4.0", + "hyper 1.8.1", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -669,12 +851,54 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.32", "native-tls", "tokio", "tokio-native-tls", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.8.1", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "hyper 1.8.1", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2 0.6.0", + "system-configuration", + "tokio", + "tower-service", + "tracing", + "windows-registry", +] + [[package]] name = "iana-time-zone" version = "0.1.63" @@ -833,6 +1057,16 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +[[package]] +name = "iri-string" +version = "0.7.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f867b9d1d896b67beb18518eda36fdb77a32ea590de864f1325b294a6d14397" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "is_executable" version = "1.0.4" @@ -863,6 +1097,28 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "js-sys" version = "0.3.77" @@ -988,6 +1244,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -1016,6 +1278,31 @@ dependencies = [ "autocfg", ] +[[package]] +name = "objc2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c2599ce0ec54857b29ce62166b0ed9b4f6f1a70ccc9a71165b6154caca8c05" +dependencies = [ + "objc2-encode", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" 
+dependencies = [ + "bitflags 2.9.1", + "objc2", +] + [[package]] name = "object" version = "0.36.7" @@ -1215,8 +1502,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", ] [[package]] @@ -1226,7 +1523,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", ] [[package]] @@ -1238,6 +1545,15 @@ dependencies = [ "getrandom 0.2.16", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.3", +] + [[package]] name = "redox_syscall" version = "0.5.17" @@ -1317,16 +1633,16 @@ version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", - "hyper-tls", + "h2 0.3.27", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.32", + "hyper-tls 0.5.0", "ipnet", "js-sys", "log", @@ -1339,7 +1655,7 @@ 
dependencies = [ "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "system-configuration", "tokio", "tokio-native-tls", @@ -1353,22 +1669,82 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest" +version = "0.12.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-core", + "h2 0.4.12", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.8.1", + "hyper-rustls", + "hyper-tls 0.6.0", + "hyper-util", + "js-sys", + "log", + "mime", + "native-tls", + "percent-encoding", + "pin-project-lite", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.2", + "tokio", + "tokio-native-tls", + "tower 0.5.2", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rullm-cli" version = "0.1.0" dependencies = [ "anyhow", "atty", + "base64 0.22.1", "chrono", "clap", "clap_complete", "etcetera", "futures", + "hex", "owo-colors", + "rand 0.9.2", "reedline", + "reqwest 0.12.24", "rullm-core", "serde", "serde_json", + "serde_urlencoded", + "sha2", "strum 0.27.2", "strum_macros 0.27.2", "tempfile", @@ -1376,6 +1752,8 @@ dependencies = [ "toml", "tracing", "tracing-subscriber", + "urlencoding", + "webbrowser", ] [[package]] @@ -1390,8 +1768,8 @@ dependencies = [ "log", "metrics", "once_cell", - "rand", - "reqwest", + "rand 0.8.5", + "reqwest 0.11.27", "serde", "serde_json", "strum 0.27.2", @@ -1401,7 +1779,7 @@ dependencies = [ "tokio", "tokio-test", 
"toml", - "tower", + "tower 0.4.13", "tower-service", "tracing-subscriber", ] @@ -1438,13 +1816,46 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "rustls" +version = "0.23.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64", + "base64 0.21.7", +] + +[[package]] +name = "rustls-pki-types" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "708c0f9d5f54ba0272468c1d306a52c495b31fa155e91bc25371e6df7996908c" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", ] [[package]] @@ -1459,6 +1870,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.27" @@ -1481,7 +1901,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.9.1", - "core-foundation", + "core-foundation 0.9.4", "core-foundation-sys", "libc", "security-framework-sys", @@ 
-1550,6 +1970,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1685,6 +2116,12 @@ dependencies = [ "syn", ] +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.104" @@ -1702,6 +2139,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + [[package]] name = "synstructure" version = "0.13.2" @@ -1720,7 +2166,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -1847,6 +2293,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.17" @@ -1942,6 +2398,39 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" 
+dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 1.0.2", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" +dependencies = [ + "bitflags 2.9.1", + "bytes", + "futures-util", + "http 1.4.0", + "http-body 1.0.1", + "iri-string", + "pin-project-lite", + "tower 0.5.2", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -2023,6 +2512,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -2041,6 +2536,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.4" @@ -2052,6 +2553,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -2091,6 +2598,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2209,6 +2726,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webbrowser" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00f1243ef785213e3a32fa0396093424a3a6ea566f9948497e5a2309261a4c97" +dependencies = [ + "core-foundation 0.10.1", + "jni", + "log", + "ndk-context", + "objc2", + "objc2-foundation", + "url", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -2225,6 +2758,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.60.2", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -2239,9 +2781,9 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", - "windows-link", - "windows-result", - "windows-strings", + "windows-link 0.1.3", + "windows-result 0.3.4", + "windows-strings 0.4.2", ] [[package]] @@ -2272,13 +2814,39 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link 0.2.1", + "windows-result 0.4.1", + "windows-strings 0.5.1", +] + [[package]] name = "windows-result" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-link", + "windows-link 0.1.3", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link 0.2.1", ] [[package]] @@ -2287,7 +2855,25 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-link", + "windows-link 0.1.3", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link 0.2.1", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", ] [[package]] @@ -2326,6 +2912,21 @@ dependencies = [ "windows-targets 0.53.3", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = 
"windows-targets" version = "0.48.5" @@ -2363,7 +2964,7 @@ version = "0.53.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5fe6031c4041849d7c496a8ded650796e7b6ecc19df1a431c1a363342e5dc91" dependencies = [ - "windows-link", + "windows-link 0.1.3", "windows_aarch64_gnullvm 0.53.0", "windows_aarch64_msvc 0.53.0", "windows_i686_gnu 0.53.0", @@ -2374,6 +2975,12 @@ dependencies = [ "windows_x86_64_msvc 0.53.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -2392,6 +2999,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -2410,6 +3023,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -2440,6 +3059,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -2458,6 +3083,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -2476,6 +3107,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -2494,6 +3131,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -2611,6 +3254,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zerotrie" version = "0.2.2" diff --git a/crates/rullm-cli/Cargo.toml b/crates/rullm-cli/Cargo.toml index a3606882..90bcc976 100644 --- a/crates/rullm-cli/Cargo.toml +++ b/crates/rullm-cli/Cargo.toml @@ -32,6 +32,16 @@ 
chrono.workspace = true reedline.workspace = true tempfile.workspace = true +# OAuth dependencies +sha2 = "0.10" +rand = "0.9" +base64 = "0.22" +webbrowser = "1.0" +reqwest = { version = "0.12", features = ["json"] } +serde_urlencoded = "0.7" +urlencoding = "2.1" +hex = "0.4" + [dev-dependencies] tempfile.workspace = true diff --git a/crates/rullm-cli/src/api_keys.rs b/crates/rullm-cli/src/api_keys.rs deleted file mode 100644 index f368bdeb..00000000 --- a/crates/rullm-cli/src/api_keys.rs +++ /dev/null @@ -1,87 +0,0 @@ -use crate::provider::Provider; -use rullm_core::error::LlmError; -use serde::{Deserialize, Serialize}; -use std::path::Path; - -#[derive(Default, Deserialize, Serialize, Debug, Clone)] -pub struct ApiKeys { - pub openai_api_key: Option, - pub groq_api_key: Option, - pub openrouter_api_key: Option, - pub anthropic_api_key: Option, - pub google_ai_api_key: Option, -} - -impl ApiKeys { - /// Load API keys from a TOML file - pub fn load_from_file>(path: P) -> Result { - let path = path.as_ref(); - - if !path.exists() { - return Ok(Self::default()); - } - - let content = std::fs::read_to_string(path) - .map_err(|e| LlmError::validation(format!("Failed to read API keys config: {e}")))?; - - // Handle empty files gracefully - if content.trim().is_empty() { - return Ok(Self::default()); - } - - toml::from_str(&content) - .map_err(|e| LlmError::validation(format!("Failed to parse API keys config: {e}"))) - } - - /// Save API keys to a TOML file - pub fn save_to_file>(&self, path: P) -> Result<(), LlmError> { - let path = path.as_ref(); - - // Create directory if it doesn't exist - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).map_err(|e| { - LlmError::validation(format!("Failed to create data directory: {e}")) - })?; - } - - let content = toml::to_string_pretty(self).map_err(|e| { - LlmError::validation(format!("Failed to serialize API keys config: {e}")) - })?; - - std::fs::write(path, content) - .map_err(|e| 
LlmError::validation(format!("Failed to write API keys config: {e}"))) - } - - pub fn get_api_key(provider: &Provider, api_keys: &ApiKeys) -> Option { - let key = match provider { - Provider::OpenAI => api_keys.openai_api_key.as_ref(), - Provider::Groq => api_keys.groq_api_key.as_ref(), - Provider::OpenRouter => api_keys.openrouter_api_key.as_ref(), - Provider::Anthropic => api_keys.anthropic_api_key.as_ref(), - Provider::Google => api_keys.google_ai_api_key.as_ref(), - }; - - key.cloned() - .or_else(|| std::env::var(provider.env_key()).ok()) - } - - pub fn set_api_key_for_provider(provider: &Provider, api_keys: &mut ApiKeys, key: &str) { - match provider { - Provider::OpenAI => api_keys.openai_api_key = Some(key.to_string()), - Provider::Groq => api_keys.groq_api_key = Some(key.to_string()), - Provider::OpenRouter => api_keys.openrouter_api_key = Some(key.to_string()), - Provider::Anthropic => api_keys.anthropic_api_key = Some(key.to_string()), - Provider::Google => api_keys.google_ai_api_key = Some(key.to_string()), - } - } - - pub fn delete_api_key_for_provider(provider: &Provider, api_keys: &mut ApiKeys) { - match provider { - Provider::OpenAI => api_keys.openai_api_key = None, - Provider::Groq => api_keys.groq_api_key = None, - Provider::OpenRouter => api_keys.openrouter_api_key = None, - Provider::Anthropic => api_keys.anthropic_api_key = None, - Provider::Google => api_keys.google_ai_api_key = None, - } - } -} diff --git a/crates/rullm-cli/src/args.rs b/crates/rullm-cli/src/args.rs index 039ab818..36c2ab08 100644 --- a/crates/rullm-cli/src/args.rs +++ b/crates/rullm-cli/src/args.rs @@ -6,11 +6,11 @@ use clap_complete::CompletionCandidate; use clap_complete::engine::ArgValueCompleter; use std::ffi::OsStr; -use crate::api_keys::ApiKeys; +use crate::auth::AuthConfig; use crate::commands::models::load_models_cache; use crate::commands::{Commands, ModelsCache}; use crate::config::{self, Config}; -use crate::constants::{BINARY_NAME, KEYS_CONFIG_FILE}; +use 
crate::constants::BINARY_NAME; use crate::templates::TemplateStore; // Example strings for after_long_help @@ -85,7 +85,7 @@ pub struct CliConfig { pub data_base_path: PathBuf, pub config: Config, pub models: Models, - pub api_keys: ApiKeys, + pub auth_config: AuthConfig, } impl CliConfig { @@ -97,21 +97,19 @@ impl CliConfig { let config = config::Config::load(&config_base_path).unwrap(); let models = Models::load(&data_base_path).unwrap(); - let api_keys = - ApiKeys::load_from_file(config_base_path.join(KEYS_CONFIG_FILE)).unwrap_or_default(); + let auth_config = AuthConfig::load(&config_base_path).unwrap_or_default(); Self { config_base_path, data_base_path, config, models, - api_keys, + auth_config, } } - pub fn save_api_keys(&self) -> Result<(), rullm_core::error::LlmError> { - let keys_path = self.config_base_path.join(KEYS_CONFIG_FILE); - self.api_keys.save_to_file(&keys_path) + pub fn save_auth_config(&self) -> anyhow::Result<()> { + self.auth_config.save(&self.config_base_path) } } diff --git a/crates/rullm-cli/src/auth.rs b/crates/rullm-cli/src/auth.rs new file mode 100644 index 00000000..b9015773 --- /dev/null +++ b/crates/rullm-cli/src/auth.rs @@ -0,0 +1,337 @@ +//! Authentication credential management for rullm. +//! +//! Supports multiple authentication methods per provider: +//! - OAuth (for Claude Max/Pro, ChatGPT Plus/Pro subscriptions) +//! - API keys (traditional method) + +use crate::provider::Provider; +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::path::{Path, PathBuf}; +use std::time::{SystemTime, UNIX_EPOCH}; + +/// File name for auth credentials +pub const AUTH_CONFIG_FILE: &str = "auth.toml"; + +/// Buffer time (in ms) before token expiration to trigger refresh +const TOKEN_EXPIRY_BUFFER_MS: u64 = 5 * 60 * 1000; // 5 minutes + +/// A credential for authenticating with an LLM provider. 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum Credential { + /// OAuth credential with access/refresh tokens + OAuth { + access_token: String, + refresh_token: String, + /// Unix timestamp in milliseconds when the access token expires + expires_at: u64, + }, + /// API key credential + Api { api_key: String }, +} + +impl Credential { + /// Create a new OAuth credential + pub fn oauth(access_token: String, refresh_token: String, expires_at: u64) -> Self { + Self::OAuth { + access_token, + refresh_token, + expires_at, + } + } + + /// Create a new API key credential + pub fn api(api_key: String) -> Self { + Self::Api { api_key } + } + + /// Get the access token or API key for use in requests + pub fn get_token(&self) -> &str { + match self { + Self::OAuth { access_token, .. } => access_token, + Self::Api { api_key } => api_key, + } + } + + /// Check if this is an OAuth credential + pub fn is_oauth(&self) -> bool { + matches!(self, Self::OAuth { .. }) + } + + /// Check if an OAuth token is expired or about to expire + pub fn is_expired(&self) -> bool { + match self { + Self::OAuth { expires_at, .. } => { + let now_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + *expires_at <= now_ms + TOKEN_EXPIRY_BUFFER_MS + } + Self::Api { .. } => false, // API keys don't expire + } + } + + /// Get the refresh token if this is an OAuth credential + pub fn refresh_token(&self) -> Option<&str> { + match self { + Self::OAuth { refresh_token, .. } => Some(refresh_token), + Self::Api { .. } => None, + } + } + + /// Return a display string for the credential type + pub fn type_display(&self) -> &'static str { + match self { + Self::OAuth { .. } => "oauth", + Self::Api { .. } => "api", + } + } +} + +/// Authentication configuration containing credentials for all providers. 
+#[derive(Debug, Default, Serialize, Deserialize)] +pub struct AuthConfig { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub anthropic: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub openai: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub groq: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub openrouter: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub google: Option, +} + +impl AuthConfig { + /// Load auth config from the default location + pub fn load(config_base_path: &Path) -> Result { + let path = config_base_path.join(AUTH_CONFIG_FILE); + Self::load_from_file(&path) + } + + /// Load auth config from a specific file path + pub fn load_from_file(path: &Path) -> Result { + if !path.exists() { + return Ok(Self::default()); + } + + let content = fs::read_to_string(path) + .with_context(|| format!("Failed to read auth config from {}", path.display()))?; + + if content.trim().is_empty() { + return Ok(Self::default()); + } + + toml::from_str(&content) + .with_context(|| format!("Failed to parse auth config from {}", path.display())) + } + + /// Save auth config to the default location + pub fn save(&self, config_base_path: &Path) -> Result<()> { + let path = config_base_path.join(AUTH_CONFIG_FILE); + self.save_to_file(&path) + } + + /// Save auth config to a specific file path + pub fn save_to_file(&self, path: &Path) -> Result<()> { + // Create parent directory if needed + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory {}", parent.display()))?; + } + + let content = toml::to_string_pretty(self) + .with_context(|| "Failed to serialize auth config")?; + + fs::write(path, &content) + .with_context(|| format!("Failed to write auth config to {}", path.display()))?; + + // Set file permissions to 0600 (owner read/write only) on Unix + #[cfg(unix)] + { + use 
std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + fs::set_permissions(path, perms) + .with_context(|| format!("Failed to set permissions on {}", path.display()))?; + } + + Ok(()) + } + + /// Get credential for a provider from config + pub fn get(&self, provider: &Provider) -> Option<&Credential> { + match provider { + Provider::Anthropic => self.anthropic.as_ref(), + Provider::OpenAI => self.openai.as_ref(), + Provider::Groq => self.groq.as_ref(), + Provider::OpenRouter => self.openrouter.as_ref(), + Provider::Google => self.google.as_ref(), + } + } + + /// Get mutable credential for a provider + pub fn get_mut(&mut self, provider: &Provider) -> &mut Option { + match provider { + Provider::Anthropic => &mut self.anthropic, + Provider::OpenAI => &mut self.openai, + Provider::Groq => &mut self.groq, + Provider::OpenRouter => &mut self.openrouter, + Provider::Google => &mut self.google, + } + } + + /// Set credential for a provider + pub fn set(&mut self, provider: &Provider, credential: Credential) { + *self.get_mut(provider) = Some(credential); + } + + /// Remove credential for a provider + pub fn remove(&mut self, provider: &Provider) { + *self.get_mut(provider) = None; + } + + /// Check if a provider has a credential configured + pub fn has(&self, provider: &Provider) -> bool { + self.get(provider).is_some() + } +} + +/// Get the auth config file path +pub fn auth_config_path(config_base_path: &Path) -> PathBuf { + config_base_path.join(AUTH_CONFIG_FILE) +} + +/// Source of a credential (file or environment variable) +#[derive(Debug, Clone, PartialEq)] +pub enum CredentialSource { + File, + Environment(String), +} + +/// Result of credential lookup including the source +#[derive(Debug)] +pub struct CredentialInfo { + pub credential: Credential, + pub source: CredentialSource, +} + +/// Get credential for a provider, checking file first, then environment variable. 
+/// +/// File credentials take precedence over environment variables. +pub fn get_credential( + provider: &Provider, + auth_config: &AuthConfig, +) -> Option { + // Check file credentials first (higher precedence) + if let Some(cred) = auth_config.get(provider) { + return Some(CredentialInfo { + credential: cred.clone(), + source: CredentialSource::File, + }); + } + + // Fall back to environment variable + let env_key = provider.env_key(); + if let Ok(api_key) = std::env::var(env_key) { + return Some(CredentialInfo { + credential: Credential::api(api_key), + source: CredentialSource::Environment(env_key.to_string()), + }); + } + + None +} + +/// Get just the token string for a provider (for backward compatibility). +pub fn get_token(provider: &Provider, auth_config: &AuthConfig) -> Option { + get_credential(provider, auth_config).map(|info| info.credential.get_token().to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn test_credential_oauth() { + let cred = Credential::oauth( + "access".to_string(), + "refresh".to_string(), + u64::MAX, // Far future + ); + assert!(cred.is_oauth()); + assert_eq!(cred.get_token(), "access"); + assert_eq!(cred.refresh_token(), Some("refresh")); + assert!(!cred.is_expired()); + } + + #[test] + fn test_credential_api() { + let cred = Credential::api("sk-test-key".to_string()); + assert!(!cred.is_oauth()); + assert_eq!(cred.get_token(), "sk-test-key"); + assert_eq!(cred.refresh_token(), None); + assert!(!cred.is_expired()); + } + + #[test] + fn test_credential_expired() { + let cred = Credential::oauth( + "access".to_string(), + "refresh".to_string(), + 0, // Long expired + ); + assert!(cred.is_expired()); + } + + #[test] + fn test_auth_config_serialization() { + let mut config = AuthConfig::default(); + config.anthropic = Some(Credential::oauth( + "sk-ant-oat01-test".to_string(), + "sk-ant-ort01-test".to_string(), + 1764813330304, + )); + config.openai = 
Some(Credential::api("sk-proj-test".to_string())); + + let toml_str = toml::to_string_pretty(&config).unwrap(); + + // Verify it can be deserialized back + let parsed: AuthConfig = toml::from_str(&toml_str).unwrap(); + assert!(parsed.anthropic.is_some()); + assert!(parsed.openai.is_some()); + assert!(parsed.groq.is_none()); + } + + #[test] + fn test_auth_config_save_load() { + let temp_dir = TempDir::new().unwrap(); + let config_path = temp_dir.path(); + + let mut config = AuthConfig::default(); + config.groq = Some(Credential::api("test-groq-key".to_string())); + + config.save(config_path).unwrap(); + + let loaded = AuthConfig::load(config_path).unwrap(); + assert_eq!( + loaded.groq.as_ref().map(|c| c.get_token()), + Some("test-groq-key") + ); + } + + #[test] + fn test_get_credential_file_precedence() { + let mut config = AuthConfig::default(); + config.anthropic = Some(Credential::api("file-key".to_string())); + + // File credential should be returned + let info = get_credential(&Provider::Anthropic, &config).unwrap(); + assert_eq!(info.source, CredentialSource::File); + assert_eq!(info.credential.get_token(), "file-key"); + } +} diff --git a/crates/rullm-cli/src/client.rs b/crates/rullm-cli/src/client.rs index 7104ccb4..451158c1 100644 --- a/crates/rullm-cli/src/client.rs +++ b/crates/rullm-cli/src/client.rs @@ -1,7 +1,6 @@ use super::provider::Provider; -use crate::api_keys::ApiKeys; use crate::args::{Cli, CliConfig}; -use crate::constants; +use crate::auth; use anyhow::{Context, Result}; use rullm_core::simple::{SimpleLlmBuilder, SimpleLlmClient, SimpleLlmConfig}; @@ -222,13 +221,13 @@ pub fn from_model(model_str: &str, cli: &Cli, cli_config: &CliConfig) -> Result< .resolve(model_str) .context("Invalid model format")?; - let api_key = ApiKeys::get_api_key(&provider, &cli_config.api_keys).ok_or_else(|| { + let token = auth::get_token(&provider, &cli_config.auth_config).ok_or_else(|| { anyhow::anyhow!( - "API key required. 
Set {} environment variable or add it to {} in config directory", - provider.env_key(), - constants::CONFIG_FILE_NAME + "Credentials required. Run 'rullm auth login {}' or set {} environment variable", + provider, + provider.env_key() ) })?; - create_client(&provider, &api_key, None, cli, &model_name).map_err(anyhow::Error::from) + create_client(&provider, &token, None, cli, &model_name).map_err(anyhow::Error::from) } diff --git a/crates/rullm-cli/src/commands/auth.rs b/crates/rullm-cli/src/commands/auth.rs new file mode 100644 index 00000000..953641dc --- /dev/null +++ b/crates/rullm-cli/src/commands/auth.rs @@ -0,0 +1,276 @@ +//! Auth command handlers for managing credentials. + +use anyhow::Result; +use clap::{Args, Subcommand}; +use etcetera::BaseStrategy; +use strum::IntoEnumIterator; + +use crate::auth::{self, AuthConfig, Credential}; +use crate::oauth::{anthropic::AnthropicOAuth, openai::OpenAIOAuth}; +use crate::output::OutputLevel; +use crate::provider::Provider; + +#[derive(Args)] +pub struct AuthArgs { + #[command(subcommand)] + pub action: AuthAction, +} + +#[derive(Subcommand)] +pub enum AuthAction { + /// Login to a provider (OAuth or API key) + Login { + /// Provider name (anthropic, openai, groq, openrouter, google) + provider: Option, + }, + /// Logout from a provider (remove stored credentials) + Logout { + /// Provider name (anthropic, openai, groq, openrouter, google) + provider: Option, + }, + /// List all credentials and environment variables + #[command(alias = "ls")] + List, +} + +/// Authentication method selection. 
+#[derive(Debug, Clone, Copy)] +pub enum AuthMethod { + OAuth, + ApiKey, +} + +impl AuthArgs { + pub async fn run( + &self, + output_level: OutputLevel, + config_base_path: &std::path::Path, + ) -> Result<()> { + match &self.action { + AuthAction::Login { provider } => { + let provider = match provider { + Some(p) => p.clone(), + None => select_provider()?, + }; + + // Determine available auth methods for the provider + let method = select_auth_method(&provider)?; + + match method { + AuthMethod::OAuth => { + let credential = match provider { + Provider::Anthropic => { + let oauth = AnthropicOAuth::new(); + oauth.login().await? + } + Provider::OpenAI => { + let oauth = OpenAIOAuth::new(); + oauth.login().await? + } + _ => { + anyhow::bail!( + "OAuth is not supported for {}. Use API key instead.", + provider + ); + } + }; + + // Save the credential + let mut auth_config = AuthConfig::load(config_base_path)?; + auth_config.set(&provider, credential); + auth_config.save(config_base_path)?; + + crate::output::success( + &format!("Successfully logged in to {provider}"), + output_level, + ); + } + AuthMethod::ApiKey => { + let api_key = prompt_api_key(&provider)?; + + if api_key.is_empty() { + anyhow::bail!("API key cannot be empty"); + } + + let mut auth_config = AuthConfig::load(config_base_path)?; + auth_config.set(&provider, Credential::api(api_key)); + auth_config.save(config_base_path)?; + + crate::output::success( + &format!("API key for {provider} has been saved"), + output_level, + ); + } + } + } + + AuthAction::Logout { provider } => { + let provider = match provider { + Some(p) => p.clone(), + None => select_provider()?, + }; + + let mut auth_config = AuthConfig::load(config_base_path)?; + auth_config.remove(&provider); + auth_config.save(config_base_path)?; + + crate::output::success( + &format!("Logged out from {provider}"), + output_level, + ); + } + + AuthAction::List => { + let auth_config = AuthConfig::load(config_base_path)?; + 
print_credentials_list(&auth_config, output_level); + } + } + + Ok(()) + } +} + +/// Select a provider interactively. +fn select_provider() -> Result { + use std::io::{self, Write}; + + println!("\n? Select provider"); + let providers: Vec = Provider::iter().collect(); + + for (i, provider) in providers.iter().enumerate() { + println!(" {}) {}", i + 1, format_provider_display(provider)); + } + + print!("\nEnter number (1-{}): ", providers.len()); + io::stdout().flush()?; + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + + let choice: usize = input.trim().parse().map_err(|_| { + anyhow::anyhow!("Invalid selection") + })?; + + if choice == 0 || choice > providers.len() { + anyhow::bail!("Invalid selection"); + } + + Ok(providers[choice - 1].clone()) +} + +/// Select authentication method for a provider. +fn select_auth_method(provider: &Provider) -> Result { + use std::io::{self, Write}; + + // Check if OAuth is available for this provider + let oauth_available = matches!(provider, Provider::Anthropic | Provider::OpenAI); + + if !oauth_available { + // Only API key available + return Ok(AuthMethod::ApiKey); + } + + println!("\n? Select authentication method"); + println!(" 1) OAuth (subscription-based access)"); + println!(" 2) API Key"); + + print!("\nEnter number (1-2): "); + io::stdout().flush()?; + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + + match input.trim() { + "1" => Ok(AuthMethod::OAuth), + "2" => Ok(AuthMethod::ApiKey), + _ => anyhow::bail!("Invalid selection"), + } +} + +/// Prompt for API key input. +fn prompt_api_key(provider: &Provider) -> Result { + use std::io::{self, Write}; + + print!("Enter API key for {provider}: "); + io::stdout().flush()?; + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + + Ok(input.trim().to_string()) +} + +/// Format provider name for display. 
+fn format_provider_display(provider: &Provider) -> &'static str { + match provider { + Provider::Anthropic => "Anthropic", + Provider::OpenAI => "OpenAI", + Provider::Groq => "Groq", + Provider::OpenRouter => "OpenRouter", + Provider::Google => "Google", + } +} + +/// Print the credentials list in a nice format. +fn print_credentials_list(auth_config: &AuthConfig, _output_level: OutputLevel) { + let mut file_creds: Vec<(Provider, String)> = Vec::new(); + let mut env_creds: Vec<(Provider, String)> = Vec::new(); + + for provider in Provider::iter() { + // Check file credentials + if let Some(cred) = auth_config.get(&provider) { + file_creds.push((provider, cred.type_display().to_string())); + } else { + // Check environment variable + let env_key = provider.env_key(); + if std::env::var(env_key).is_ok() { + env_creds.push((provider, env_key.to_string())); + } + } + } + + // Print file credentials section + if !file_creds.is_empty() { + let config_path = auth::auth_config_path( + &etcetera::choose_base_strategy() + .unwrap() + .config_dir() + .join(crate::constants::BINARY_NAME), + ); + println!("\n\u{250c} Credentials {}", config_path.display()); + println!("\u{2502}"); + + for (provider, cred_type) in &file_creds { + println!( + "\u{25cf} {} {}", + format_provider_display(provider), + cred_type + ); + println!("\u{2502}"); + } + + println!("\u{2514} {} credentials", file_creds.len()); + } + + // Print environment variables section + if !env_creds.is_empty() { + println!("\n\u{250c} Environment"); + println!("\u{2502}"); + + for (provider, env_key) in &env_creds { + println!( + "\u{25cf} {} {}", + format_provider_display(provider), + env_key + ); + println!("\u{2502}"); + } + + println!("\u{2514} {} environment variables", env_creds.len()); + } + + if file_creds.is_empty() && env_creds.is_empty() { + println!("\nNo credentials configured."); + println!("Use 'rullm auth login' to add credentials."); + } +} diff --git a/crates/rullm-cli/src/commands/info.rs 
b/crates/rullm-cli/src/commands/info.rs index a1cb4dcb..cb96f42d 100644 --- a/crates/rullm-cli/src/commands/info.rs +++ b/crates/rullm-cli/src/commands/info.rs @@ -3,6 +3,7 @@ use clap::Args; use crate::{ args::{Cli, CliConfig}, + auth, commands::env_var_status, constants::*, output::OutputLevel, @@ -22,7 +23,7 @@ impl InfoArgs { ) -> Result<()> { let config_path = cli_config.config_base_path.join(CONFIG_FILE_NAME); let models_path = cli_config.data_base_path.join(MODEL_FILE_NAME); - let keys_path = cli_config.config_base_path.join(KEYS_CONFIG_FILE); + let auth_path = auth::auth_config_path(&cli_config.config_base_path); let templates_path = cli_config.config_base_path.join(TEMPLATES_DIR_NAME); // crate::output::heading("Config files:", output_level); @@ -30,7 +31,7 @@ impl InfoArgs { &format!("config file: {}", config_path.display()), output_level, ); - crate::output::note(&format!("keys file: {}", keys_path.display()), output_level); + crate::output::note(&format!("auth file: {}", auth_path.display()), output_level); crate::output::note( &format!("models cache file: {}", models_path.display()), output_level, diff --git a/crates/rullm-cli/src/commands/keys.rs b/crates/rullm-cli/src/commands/keys.rs deleted file mode 100644 index 49696e2f..00000000 --- a/crates/rullm-cli/src/commands/keys.rs +++ /dev/null @@ -1,118 +0,0 @@ -use anyhow::Result; -use clap::{Args, Subcommand}; -use strum::IntoEnumIterator; - -use crate::provider::Provider; -use crate::{ - api_keys::ApiKeys, - args::{Cli, CliConfig}, - output::OutputLevel, -}; - -#[derive(Args)] -pub struct KeysArgs { - #[command(subcommand)] - pub action: KeysAction, -} - -#[derive(Subcommand)] -pub enum KeysAction { - /// Set an API key for a provider - Set { - /// Provider name (openai, anthropic, google) - provider: Provider, - /// API key (if not provided, will read from stdin) - #[arg(short, long)] - key: Option, - }, - /// Delete an API key for a provider - Delete { - /// Provider name (openai, anthropic, 
google) - provider: Provider, - }, - /// List which providers have API keys set - List, -} - -impl KeysArgs { - pub async fn run( - &self, - output_level: OutputLevel, - cli_config: &mut CliConfig, - _cli: &Cli, - ) -> Result<()> { - match &self.action { - KeysAction::Set { provider, key } => { - let api_key = if let Some(key) = key { - key.clone() - } else { - use std::io::{self, Write}; - print!("Enter API key for {provider}: "); - io::stdout().flush()?; - - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - input.trim().to_string() - }; - - if api_key.is_empty() { - return Err(anyhow::anyhow!("API key cannot be empty")); - } - - let api_keys = &mut cli_config.api_keys; - ApiKeys::set_api_key_for_provider(provider, api_keys, &api_key); - cli_config.save_api_keys()?; - - crate::output::success( - &format!("API key for {provider} has been saved"), - output_level, - ); - } - KeysAction::Delete { provider } => { - let api_keys = &mut cli_config.api_keys; - ApiKeys::delete_api_key_for_provider(provider, api_keys); - cli_config.save_api_keys()?; - - crate::output::success( - &format!("API key for {provider} has been deleted"), - output_level, - ); - } - KeysAction::List => { - let api_keys = cli_config.api_keys.clone(); - - for provider in Provider::iter() { - let has_cli_key = match provider { - Provider::OpenAI => api_keys.openai_api_key.is_some(), - Provider::Groq => api_keys.groq_api_key.is_some(), - Provider::OpenRouter => api_keys.openrouter_api_key.is_some(), - Provider::Anthropic => api_keys.anthropic_api_key.is_some(), - Provider::Google => api_keys.google_ai_api_key.is_some(), - }; - - let has_env_key = std::env::var(provider.env_key()).is_ok(); - - let source_info = if has_cli_key { - Some("cli".to_string()) - } else if has_env_key { - Some(format!("env ({})", provider.env_key())) - } else { - None - }; - - if let Some(source) = source_info { - crate::output::note( - &format!( - "{}: {}", - 
crate::output::format_provider(&provider.to_string()), - source - ), - output_level, - ); - } - } - } - } - Ok(()) - } -} diff --git a/crates/rullm-cli/src/commands/mod.rs b/crates/rullm-cli/src/commands/mod.rs index bb645f5d..bd168fa3 100644 --- a/crates/rullm-cli/src/commands/mod.rs +++ b/crates/rullm-cli/src/commands/mod.rs @@ -13,20 +13,20 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; pub mod alias; +pub mod auth; pub mod chat; pub mod completions; pub mod info; pub mod templates; -pub mod keys; pub mod models; // Re-export the command args structs pub use alias::AliasArgs; +pub use auth::AuthArgs; pub use chat::ChatArgs; pub use completions::CompletionsArgs; pub use info::InfoArgs; -pub use keys::KeysArgs; pub use models::ModelsArgs; // Example strings for after_long_help @@ -42,11 +42,11 @@ const MODELS_EXAMPLES: &str = r#"EXAMPLES: rullm models default openai/gpt-4o # Set default model rullm models clear # Clear model cache"#; -const KEYS_EXAMPLES: &str = r#"EXAMPLES: - rullm keys set openai # Set OpenAI API key (prompted) - rullm keys set anthropic -k sk-ant-... 
# Set Anthropic key directly - rullm keys list # Show which providers have keys - rullm keys delete google # Remove Google API key"#; +const AUTH_EXAMPLES: &str = r#"EXAMPLES: + rullm auth login # Login interactively + rullm auth login anthropic # Login to Anthropic (OAuth or API key) + rullm auth logout openai # Logout from OpenAI + rullm auth list # Show all credentials"#; const ALIAS_EXAMPLES: &str = r#"EXAMPLES: rullm alias list # Show all available aliases @@ -73,9 +73,9 @@ pub enum Commands { /// Show configuration and system information #[command(after_long_help = INFO_EXAMPLES)] Info(InfoArgs), - /// Manage API keys - #[command(after_long_help = KEYS_EXAMPLES)] - Keys(KeysArgs), + /// Manage authentication credentials + #[command(after_long_help = AUTH_EXAMPLES)] + Auth(AuthArgs), /// Manage model aliases #[command(after_long_help = ALIAS_EXAMPLES)] Alias(AliasArgs), diff --git a/crates/rullm-cli/src/constants.rs b/crates/rullm-cli/src/constants.rs index b0f35475..3b55f925 100644 --- a/crates/rullm-cli/src/constants.rs +++ b/crates/rullm-cli/src/constants.rs @@ -1,6 +1,5 @@ pub const CONFIG_FILE_NAME: &str = "config.toml"; pub const MODEL_FILE_NAME: &str = "models.json"; pub const ALIASES_CONFIG_FILE: &str = "aliases.toml"; -pub const KEYS_CONFIG_FILE: &str = "keys.toml"; pub const TEMPLATES_DIR_NAME: &str = "templates"; pub const BINARY_NAME: &str = env!("CARGO_BIN_NAME"); diff --git a/crates/rullm-cli/src/main.rs b/crates/rullm-cli/src/main.rs index 8e8babf1..674b462e 100644 --- a/crates/rullm-cli/src/main.rs +++ b/crates/rullm-cli/src/main.rs @@ -1,13 +1,14 @@ // Binary entry point for rullm-cli mod aliases; -mod api_keys; mod args; +mod auth; mod cli_helpers; mod client; mod commands; mod config; mod constants; +mod oauth; mod output; mod provider; mod spinner; @@ -50,7 +51,7 @@ pub async fn run() -> Result<()> { if cli.model.is_some() { match &cli.command { Some(Commands::Info(_)) - | Some(Commands::Keys(_)) + | Some(Commands::Auth(_)) | 
Some(Commands::Alias(_)) | Some(Commands::Completions(_)) => { use clap::error::ErrorKind; @@ -83,7 +84,7 @@ pub async fn run() -> Result<()> { Some(Commands::Chat(args)) => args.run(output_level, &cli_config, &cli).await?, Some(Commands::Models(args)) => args.run(output_level, &mut cli_config, &cli).await?, Some(Commands::Info(args)) => args.run(output_level, &cli_config, &cli).await?, - Some(Commands::Keys(args)) => args.run(output_level, &mut cli_config, &cli).await?, + Some(Commands::Auth(args)) => args.run(output_level, &cli_config.config_base_path).await?, Some(Commands::Alias(args)) => args.run(output_level, &cli_config, &cli).await?, Some(Commands::Completions(args)) => args.run(output_level, &cli_config, &cli).await?, Some(Commands::Templates(args)) => args.run(output_level, &cli_config, &cli).await?, diff --git a/crates/rullm-cli/src/oauth/anthropic.rs b/crates/rullm-cli/src/oauth/anthropic.rs new file mode 100644 index 00000000..0c2ae43f --- /dev/null +++ b/crates/rullm-cli/src/oauth/anthropic.rs @@ -0,0 +1,252 @@ +//! Anthropic OAuth flow implementation. +//! +//! Supports Claude Max/Pro subscription authentication. + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +use super::{CallbackServer, PkceChallenge}; +use crate::auth::Credential; + +/// Anthropic OAuth configuration. 
+pub struct AnthropicOAuth { + /// Authorization URL + pub authorization_url: &'static str, + /// Token URL + pub token_url: &'static str, + /// Client ID (Claude Code's public ID) + pub client_id: &'static str, + /// Callback port + pub callback_port: u16, + /// Required scopes + pub scopes: &'static [&'static str], +} + +impl Default for AnthropicOAuth { + fn default() -> Self { + Self { + authorization_url: "https://console.anthropic.com/oauth/authorize", + token_url: "https://api.anthropic.com/oauth/token", + client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e", + callback_port: 8765, + scopes: &["org:create_api_key", "user:profile", "user:inference"], + } + } +} + +/// Token response from Anthropic OAuth. +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + refresh_token: String, + expires_in: u64, + #[allow(dead_code)] + token_type: String, +} + +/// Token refresh request body. +#[derive(Debug, Serialize)] +struct RefreshRequest<'a> { + grant_type: &'static str, + client_id: &'a str, + refresh_token: &'a str, +} + +/// Token exchange request body. +#[derive(Debug, Serialize)] +struct TokenRequest<'a> { + grant_type: &'static str, + client_id: &'a str, + code: &'a str, + redirect_uri: &'a str, + code_verifier: &'a str, +} + +impl AnthropicOAuth { + /// Create a new Anthropic OAuth handler with default configuration. + pub fn new() -> Self { + Self::default() + } + + /// Build the authorization URL for the OAuth flow. 
+ pub fn build_authorization_url(&self, pkce: &PkceChallenge, state: &str) -> String { + let redirect_uri = format!("http://localhost:{}/callback", self.callback_port); + let scope = self.scopes.join(" "); + + format!( + "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method={}&state={}", + self.authorization_url, + urlencoding::encode(self.client_id), + urlencoding::encode(&redirect_uri), + urlencoding::encode(&scope), + urlencoding::encode(&pkce.challenge), + pkce.method(), + urlencoding::encode(state) + ) + } + + /// Start the OAuth flow and return the credential on success. + /// + /// This will: + /// 1. Start a local callback server + /// 2. Open the browser to the authorization URL + /// 3. Wait for the callback with the authorization code + /// 4. Exchange the code for tokens + pub async fn login(&self) -> Result { + // Generate PKCE challenge + let pkce = PkceChallenge::generate(); + + // Generate state for CSRF protection + let state = generate_state(); + + // Start callback server + let server = CallbackServer::new(self.callback_port) + .context("Failed to start callback server")?; + + // Build and open authorization URL + let auth_url = self.build_authorization_url(&pkce, &state); + + println!("Opening browser for Anthropic authentication..."); + webbrowser::open(&auth_url).context("Failed to open browser")?; + + println!("Waiting for authentication (timeout: 5 minutes)..."); + + // Wait for callback + let callback = server + .wait_for_callback(Duration::from_secs(300)) + .context("Failed to receive OAuth callback")?; + + // Verify state + if callback.state.as_deref() != Some(&state) { + anyhow::bail!("State mismatch in OAuth callback (possible CSRF attack)"); + } + + // Exchange code for tokens + let credential = self.exchange_code(&callback.code, &pkce.verifier).await?; + + Ok(credential) + } + + /// Exchange authorization code for tokens. 
+ async fn exchange_code(&self, code: &str, code_verifier: &str) -> Result { + let redirect_uri = format!("http://localhost:{}/callback", self.callback_port); + + let request_body = TokenRequest { + grant_type: "authorization_code", + client_id: self.client_id, + code, + redirect_uri: &redirect_uri, + code_verifier, + }; + + let client = reqwest::Client::new(); + let response = client + .post(self.token_url) + .header("Content-Type", "application/x-www-form-urlencoded") + .body(serde_urlencoded::to_string(&request_body)?) + .send() + .await + .context("Failed to send token request")?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("Token exchange failed: {} - {}", status, body); + } + + let token_response: TokenResponse = response + .json() + .await + .context("Failed to parse token response")?; + + // Calculate expiration timestamp + let expires_at = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64 + token_response.expires_in * 1000) + .unwrap_or(0); + + Ok(Credential::oauth( + token_response.access_token, + token_response.refresh_token, + expires_at, + )) + } + + /// Refresh an expired OAuth token. + pub async fn refresh_token(&self, refresh_token: &str) -> Result { + let request_body = RefreshRequest { + grant_type: "refresh_token", + client_id: self.client_id, + refresh_token, + }; + + let client = reqwest::Client::new(); + let response = client + .post(self.token_url) + .header("Content-Type", "application/x-www-form-urlencoded") + .body(serde_urlencoded::to_string(&request_body)?) 
+ .send() + .await + .context("Failed to send refresh request")?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("Token refresh failed: {} - {}", status, body); + } + + let token_response: TokenResponse = response + .json() + .await + .context("Failed to parse refresh response")?; + + let expires_at = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64 + token_response.expires_in * 1000) + .unwrap_or(0); + + Ok(Credential::oauth( + token_response.access_token, + token_response.refresh_token, + expires_at, + )) + } +} + +/// Generate a random state string for CSRF protection. +fn generate_state() -> String { + use rand::RngCore; + let mut bytes = [0u8; 16]; + rand::rng().fill_bytes(&mut bytes); + hex::encode(bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_authorization_url() { + let oauth = AnthropicOAuth::new(); + let pkce = PkceChallenge::generate(); + let state = "test-state"; + + let url = oauth.build_authorization_url(&pkce, state); + + assert!(url.starts_with("https://console.anthropic.com/oauth/authorize")); + assert!(url.contains("response_type=code")); + assert!(url.contains("client_id=")); + assert!(url.contains("redirect_uri=")); + assert!(url.contains("code_challenge=")); + assert!(url.contains("code_challenge_method=S256")); + assert!(url.contains("state=test-state")); + } + + #[test] + fn test_default_config() { + let oauth = AnthropicOAuth::new(); + assert_eq!(oauth.callback_port, 8765); + assert!(oauth.scopes.contains(&"user:inference")); + } +} diff --git a/crates/rullm-cli/src/oauth/mod.rs b/crates/rullm-cli/src/oauth/mod.rs new file mode 100644 index 00000000..93bbe19a --- /dev/null +++ b/crates/rullm-cli/src/oauth/mod.rs @@ -0,0 +1,12 @@ +//! OAuth authentication module for rullm. +//! +//! Provides OAuth 2.0 authentication flows for supported providers. 
+ +mod pkce; +mod server; + +pub mod anthropic; +pub mod openai; + +pub use pkce::PkceChallenge; +pub use server::{CallbackResult, CallbackServer}; diff --git a/crates/rullm-cli/src/oauth/openai.rs b/crates/rullm-cli/src/oauth/openai.rs new file mode 100644 index 00000000..78eef86a --- /dev/null +++ b/crates/rullm-cli/src/oauth/openai.rs @@ -0,0 +1,303 @@ +//! OpenAI OAuth flow implementation. +//! +//! Supports ChatGPT Plus/Pro subscription authentication via OAuth discovery. + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +use super::{CallbackServer, PkceChallenge}; +use crate::auth::Credential; + +/// OpenAI OAuth configuration. +pub struct OpenAIOAuth { + /// Issuer URL for discovery + pub issuer_url: &'static str, + /// Callback port + pub callback_port: u16, +} + +impl Default for OpenAIOAuth { + fn default() -> Self { + Self { + issuer_url: "https://auth.openai.com", + callback_port: 1455, + } + } +} + +/// OAuth authorization server metadata from discovery. +#[derive(Debug, Deserialize)] +struct AuthorizationServerMetadata { + authorization_endpoint: String, + token_endpoint: String, + #[allow(dead_code)] + issuer: String, +} + +/// Token response from OpenAI OAuth. +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + refresh_token: Option, + expires_in: u64, + #[allow(dead_code)] + token_type: String, +} + +/// Token exchange request body. +#[derive(Debug, Serialize)] +struct TokenRequest<'a> { + grant_type: &'static str, + code: &'a str, + redirect_uri: &'a str, + code_verifier: &'a str, +} + +/// Token refresh request body. +#[derive(Debug, Serialize)] +struct RefreshRequest<'a> { + grant_type: &'static str, + refresh_token: &'a str, +} + +impl OpenAIOAuth { + /// Create a new OpenAI OAuth handler with default configuration. + pub fn new() -> Self { + Self::default() + } + + /// Discover OAuth endpoints from the authorization server. 
+ async fn discover(&self) -> Result { + let discovery_url = format!( + "{}/.well-known/oauth-authorization-server", + self.issuer_url + ); + + let client = reqwest::Client::new(); + let response = client + .get(&discovery_url) + .send() + .await + .with_context(|| format!("Failed to fetch OAuth discovery from {}", discovery_url))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("OAuth discovery failed: {} - {}", status, body); + } + + response + .json() + .await + .context("Failed to parse OAuth discovery response") + } + + /// Build the authorization URL for the OAuth flow. + fn build_authorization_url( + &self, + authorization_endpoint: &str, + pkce: &PkceChallenge, + state: &str, + ) -> String { + let redirect_uri = format!("http://localhost:{}/callback", self.callback_port); + + format!( + "{}?response_type=code&redirect_uri={}&code_challenge={}&code_challenge_method={}&state={}", + authorization_endpoint, + urlencoding::encode(&redirect_uri), + urlencoding::encode(&pkce.challenge), + pkce.method(), + urlencoding::encode(state) + ) + } + + /// Start the OAuth flow and return the credential on success. + /// + /// This will: + /// 1. Discover OAuth endpoints + /// 2. Start a local callback server + /// 3. Open the browser to the authorization URL + /// 4. Wait for the callback with the authorization code + /// 5. 
Exchange the code for tokens + pub async fn login(&self) -> Result { + // Discover OAuth endpoints + println!("Discovering OpenAI OAuth endpoints..."); + let metadata = self.discover().await?; + + // Generate PKCE challenge + let pkce = PkceChallenge::generate(); + + // Generate state for CSRF protection + let state = generate_state(); + + // Start callback server + let server = CallbackServer::new(self.callback_port) + .context("Failed to start callback server")?; + + // Build and open authorization URL + let auth_url = self.build_authorization_url(&metadata.authorization_endpoint, &pkce, &state); + + println!("Opening browser for OpenAI authentication..."); + webbrowser::open(&auth_url).context("Failed to open browser")?; + + println!("Waiting for authentication (timeout: 5 minutes)..."); + + // Wait for callback + let callback = server + .wait_for_callback(Duration::from_secs(300)) + .context("Failed to receive OAuth callback")?; + + // Verify state + if callback.state.as_deref() != Some(&state) { + anyhow::bail!("State mismatch in OAuth callback (possible CSRF attack)"); + } + + // Exchange code for tokens + let credential = self + .exchange_code(&metadata.token_endpoint, &callback.code, &pkce.verifier) + .await?; + + Ok(credential) + } + + /// Exchange authorization code for tokens. + async fn exchange_code( + &self, + token_endpoint: &str, + code: &str, + code_verifier: &str, + ) -> Result { + let redirect_uri = format!("http://localhost:{}/callback", self.callback_port); + + let request_body = TokenRequest { + grant_type: "authorization_code", + code, + redirect_uri: &redirect_uri, + code_verifier, + }; + + let client = reqwest::Client::new(); + let response = client + .post(token_endpoint) + .header("Content-Type", "application/x-www-form-urlencoded") + .body(serde_urlencoded::to_string(&request_body)?) 
+ .send() + .await + .context("Failed to send token request")?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("Token exchange failed: {} - {}", status, body); + } + + let token_response: TokenResponse = response + .json() + .await + .context("Failed to parse token response")?; + + // Calculate expiration timestamp + let expires_at = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64 + token_response.expires_in * 1000) + .unwrap_or(0); + + // OpenAI might not always return a refresh token + let refresh_token = token_response + .refresh_token + .unwrap_or_else(|| String::new()); + + Ok(Credential::oauth( + token_response.access_token, + refresh_token, + expires_at, + )) + } + + /// Refresh an expired OAuth token. + pub async fn refresh_token(&self, refresh_token: &str) -> Result { + // Need to discover endpoints first + let metadata = self.discover().await?; + + let request_body = RefreshRequest { + grant_type: "refresh_token", + refresh_token, + }; + + let client = reqwest::Client::new(); + let response = client + .post(&metadata.token_endpoint) + .header("Content-Type", "application/x-www-form-urlencoded") + .body(serde_urlencoded::to_string(&request_body)?) 
+ .send() + .await + .context("Failed to send refresh request")?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("Token refresh failed: {} - {}", status, body); + } + + let token_response: TokenResponse = response + .json() + .await + .context("Failed to parse refresh response")?; + + let expires_at = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64 + token_response.expires_in * 1000) + .unwrap_or(0); + + let new_refresh_token = token_response + .refresh_token + .unwrap_or_else(|| refresh_token.to_string()); + + Ok(Credential::oauth( + token_response.access_token, + new_refresh_token, + expires_at, + )) + } +} + +/// Generate a random state string for CSRF protection. +fn generate_state() -> String { + use rand::RngCore; + let mut bytes = [0u8; 16]; + rand::rng().fill_bytes(&mut bytes); + hex::encode(bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let oauth = OpenAIOAuth::new(); + assert_eq!(oauth.callback_port, 1455); + assert_eq!(oauth.issuer_url, "https://auth.openai.com"); + } + + #[test] + fn test_build_authorization_url() { + let oauth = OpenAIOAuth::new(); + let pkce = PkceChallenge::generate(); + let state = "test-state"; + + let url = oauth.build_authorization_url( + "https://auth.openai.com/authorize", + &pkce, + state, + ); + + assert!(url.starts_with("https://auth.openai.com/authorize")); + assert!(url.contains("response_type=code")); + assert!(url.contains("redirect_uri=")); + assert!(url.contains("code_challenge=")); + assert!(url.contains("code_challenge_method=S256")); + assert!(url.contains("state=test-state")); + } +} diff --git a/crates/rullm-cli/src/oauth/pkce.rs b/crates/rullm-cli/src/oauth/pkce.rs new file mode 100644 index 00000000..d2a02849 --- /dev/null +++ b/crates/rullm-cli/src/oauth/pkce.rs @@ -0,0 +1,86 @@ +//! 
PKCE (Proof Key for Code Exchange) implementation for OAuth 2.0. +//! +//! Implements RFC 7636 for secure authorization code flow. + +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; +use rand::RngCore; +use sha2::{Digest, Sha256}; + +/// PKCE challenge pair consisting of verifier and challenge. +#[derive(Debug, Clone)] +pub struct PkceChallenge { + /// The code verifier (sent with token request) + pub verifier: String, + /// The code challenge (sent with authorization request) + pub challenge: String, +} + +impl PkceChallenge { + /// Generate a new PKCE challenge pair. + /// + /// Creates a 64-byte random code verifier and derives the S256 challenge from it. + pub fn generate() -> Self { + // Generate 64 random bytes for the code verifier + let mut verifier_bytes = [0u8; 64]; + rand::rng().fill_bytes(&mut verifier_bytes); + + // Base64url encode the verifier (no padding) + let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); + + // Create code challenge: base64url(sha256(verifier)) + let mut hasher = Sha256::new(); + hasher.update(verifier.as_bytes()); + let hash = hasher.finalize(); + let challenge = URL_SAFE_NO_PAD.encode(hash); + + Self { verifier, challenge } + } + + /// Get the challenge method (always "S256"). 
+ pub fn method(&self) -> &'static str { + "S256" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pkce_generation() { + let pkce = PkceChallenge::generate(); + + // Verifier should be base64url encoded 64 bytes = 86 chars + assert_eq!(pkce.verifier.len(), 86); + + // Challenge should be base64url encoded SHA256 = 43 chars + assert_eq!(pkce.challenge.len(), 43); + + // Method should be S256 + assert_eq!(pkce.method(), "S256"); + } + + #[test] + fn test_pkce_uniqueness() { + let pkce1 = PkceChallenge::generate(); + let pkce2 = PkceChallenge::generate(); + + // Each generation should produce unique values + assert_ne!(pkce1.verifier, pkce2.verifier); + assert_ne!(pkce1.challenge, pkce2.challenge); + } + + #[test] + fn test_challenge_derivation() { + // Verify that the challenge is correctly derived from verifier + let pkce = PkceChallenge::generate(); + + // Manually compute the expected challenge + let mut hasher = Sha256::new(); + hasher.update(pkce.verifier.as_bytes()); + let hash = hasher.finalize(); + let expected_challenge = URL_SAFE_NO_PAD.encode(hash); + + assert_eq!(pkce.challenge, expected_challenge); + } +} diff --git a/crates/rullm-cli/src/oauth/server.rs b/crates/rullm-cli/src/oauth/server.rs new file mode 100644 index 00000000..325cf124 --- /dev/null +++ b/crates/rullm-cli/src/oauth/server.rs @@ -0,0 +1,226 @@ +//! Local HTTP server for OAuth callback handling. +//! +//! Starts a temporary server to receive the OAuth authorization code. + +use anyhow::{Context, Result, anyhow}; +use std::io::{Read, Write}; +use std::net::TcpListener; +use std::time::Duration; + +/// Result of waiting for an OAuth callback. +#[derive(Debug)] +pub struct CallbackResult { + /// The authorization code received + pub code: String, + /// The state parameter (for CSRF verification) + pub state: Option, +} + +/// Local callback server for OAuth flows. 
+pub struct CallbackServer { + listener: TcpListener, + port: u16, +} + +impl CallbackServer { + /// Create a new callback server on the specified port. + pub fn new(port: u16) -> Result { + let addr = format!("127.0.0.1:{port}"); + let listener = TcpListener::bind(&addr) + .with_context(|| format!("Failed to bind to {addr}"))?; + + // Get the actual port if 0 was specified + let actual_port = listener.local_addr()?.port(); + + Ok(Self { + listener, + port: actual_port, + }) + } + + /// Get the redirect URI for this callback server. + pub fn redirect_uri(&self) -> String { + format!("http://localhost:{}/callback", self.port) + } + + /// Wait for the OAuth callback and extract the authorization code. + /// + /// This blocks until a request is received or the timeout is reached. + pub fn wait_for_callback(&self, timeout: Duration) -> Result { + // Set non-blocking mode on the listener and poll with timeout + self.listener + .set_nonblocking(true) + .context("Failed to set non-blocking mode")?; + + let start = std::time::Instant::now(); + let poll_interval = Duration::from_millis(100); + + loop { + match self.listener.accept() { + Ok((mut stream, _addr)) => { + // Set the stream back to blocking for read/write + stream.set_nonblocking(false).ok(); + stream.set_read_timeout(Some(Duration::from_secs(5))).ok(); + + // Read the HTTP request + let mut buffer = [0u8; 4096]; + let n = stream + .read(&mut buffer) + .context("Failed to read from connection")?; + + let request = String::from_utf8_lossy(&buffer[..n]); + + // Parse the request to extract code and state + let result = Self::parse_callback_request(&request)?; + + // Send a success response + let response_body = r#" + +Authentication Successful + +
+

Authentication successful!

+

You can close this window and return to the terminal.

+
+ +"#; + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + response_body.len(), + response_body + ); + + stream + .write_all(response.as_bytes()) + .context("Failed to send response")?; + + stream.flush().ok(); + + return Ok(result); + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // No connection yet, check timeout + if start.elapsed() >= timeout { + return Err(anyhow!("Timeout waiting for OAuth callback")); + } + std::thread::sleep(poll_interval); + } + Err(e) => { + return Err(e).context("Failed to accept connection"); + } + } + } + } + + /// Parse the callback request to extract code and state parameters. + fn parse_callback_request(request: &str) -> Result { + // Extract the request line (GET /callback?code=xxx&state=yyy HTTP/1.1) + let first_line = request + .lines() + .next() + .ok_or_else(|| anyhow!("Empty request"))?; + + // Extract the path with query string + let parts: Vec<&str> = first_line.split_whitespace().collect(); + if parts.len() < 2 { + return Err(anyhow!("Invalid HTTP request line")); + } + + let path = parts[1]; + + // Check for error response + if let Some(error) = Self::extract_query_param(path, "error") { + let description = Self::extract_query_param(path, "error_description") + .unwrap_or_else(|| "Unknown error".to_string()); + return Err(anyhow!("OAuth error: {} - {}", error, description)); + } + + // Extract the code + let code = Self::extract_query_param(path, "code") + .ok_or_else(|| anyhow!("No authorization code in callback"))?; + + // Extract state (optional) + let state = Self::extract_query_param(path, "state"); + + Ok(CallbackResult { code, state }) + } + + /// Extract a query parameter value from a URL path. 
+ fn extract_query_param(path: &str, param: &str) -> Option { + let query = path.split('?').nth(1)?; + for pair in query.split('&') { + let mut kv = pair.splitn(2, '='); + let key = kv.next()?; + let value = kv.next()?; + if key == param { + // URL decode the value + return Some(urlencoding::decode(value).ok()?.into_owned()); + } + } + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_callback_success() { + let request = "GET /callback?code=abc123&state=xyz789 HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let result = CallbackServer::parse_callback_request(request).unwrap(); + assert_eq!(result.code, "abc123"); + assert_eq!(result.state, Some("xyz789".to_string())); + } + + #[test] + fn test_parse_callback_no_state() { + let request = "GET /callback?code=abc123 HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let result = CallbackServer::parse_callback_request(request).unwrap(); + assert_eq!(result.code, "abc123"); + assert_eq!(result.state, None); + } + + #[test] + fn test_parse_callback_error() { + let request = "GET /callback?error=access_denied&error_description=User%20denied%20access HTTP/1.1\r\n"; + let result = CallbackServer::parse_callback_request(request); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("access_denied")); + } + + #[test] + fn test_parse_callback_no_code() { + let request = "GET /callback?state=xyz HTTP/1.1\r\n"; + let result = CallbackServer::parse_callback_request(request); + assert!(result.is_err()); + } + + #[test] + fn test_extract_query_param() { + let path = "/callback?code=abc&state=def&other=ghi"; + assert_eq!( + CallbackServer::extract_query_param(path, "code"), + Some("abc".to_string()) + ); + assert_eq!( + CallbackServer::extract_query_param(path, "state"), + Some("def".to_string()) + ); + assert_eq!( + CallbackServer::extract_query_param(path, "missing"), + None + ); + } + + #[test] + fn test_redirect_uri() { + // Note: This test requires an available port + if let Ok(server) 
= CallbackServer::new(0) { + let uri = server.redirect_uri(); + assert!(uri.starts_with("http://localhost:")); + assert!(uri.ends_with("/callback")); + } + } +} From 8bb00ff8940040543c581e39df60cb372bc90732 Mon Sep 17 00:00:00 2001 From: lambda Date: Thu, 4 Dec 2025 23:11:47 +0530 Subject: [PATCH 02/11] feat(oauth): add automatic token refresh and fix provider endpoints --- .justfile | 9 +-- crates/rullm-cli/src/args.rs | 4 - crates/rullm-cli/src/auth.rs | 90 +++++++++++++++++------ crates/rullm-cli/src/client.rs | 24 ++++-- crates/rullm-cli/src/commands/auth.rs | 82 +++++++++++++-------- crates/rullm-cli/src/commands/chat/mod.rs | 4 +- crates/rullm-cli/src/commands/models.rs | 2 +- crates/rullm-cli/src/main.rs | 4 +- crates/rullm-cli/src/oauth/anthropic.rs | 56 +++++++++----- crates/rullm-cli/src/oauth/mod.rs | 2 +- crates/rullm-cli/src/oauth/openai.rs | 83 ++++++++++++++------- crates/rullm-cli/src/oauth/pkce.rs | 5 +- crates/rullm-cli/src/oauth/server.rs | 19 +++-- 13 files changed, 262 insertions(+), 122 deletions(-) diff --git a/.justfile b/.justfile index 5ee1b391..61346da2 100644 --- a/.justfile +++ b/.justfile @@ -1,9 +1,6 @@ lint: - cargo fmt --all - cargo clippy --fix + cargo clippy --all-targets --all-features -- -D warnings -fmt: - cargo fmt --check -clippy: - cargo clippy +fmt: + cargo fmt diff --git a/crates/rullm-cli/src/args.rs b/crates/rullm-cli/src/args.rs index 36c2ab08..87dfa3c4 100644 --- a/crates/rullm-cli/src/args.rs +++ b/crates/rullm-cli/src/args.rs @@ -107,10 +107,6 @@ impl CliConfig { auth_config, } } - - pub fn save_auth_config(&self) -> anyhow::Result<()> { - self.auth_config.save(&self.config_base_path) - } } #[derive(Parser)] diff --git a/crates/rullm-cli/src/auth.rs b/crates/rullm-cli/src/auth.rs index b9015773..4bd7935c 100644 --- a/crates/rullm-cli/src/auth.rs +++ b/crates/rullm-cli/src/auth.rs @@ -55,11 +55,6 @@ impl Credential { } } - /// Check if this is an OAuth credential - pub fn is_oauth(&self) -> bool { - matches!(self, 
Self::OAuth { .. }) - } - /// Check if an OAuth token is expired or about to expire pub fn is_expired(&self) -> bool { match self { @@ -144,8 +139,8 @@ impl AuthConfig { .with_context(|| format!("Failed to create directory {}", parent.display()))?; } - let content = toml::to_string_pretty(self) - .with_context(|| "Failed to serialize auth config")?; + let content = + toml::to_string_pretty(self).with_context(|| "Failed to serialize auth config")?; fs::write(path, &content) .with_context(|| format!("Failed to write auth config to {}", path.display()))?; @@ -193,11 +188,6 @@ impl AuthConfig { pub fn remove(&mut self, provider: &Provider) { *self.get_mut(provider) = None; } - - /// Check if a provider has a credential configured - pub fn has(&self, provider: &Provider) -> bool { - self.get(provider).is_some() - } } /// Get the auth config file path @@ -222,10 +212,7 @@ pub struct CredentialInfo { /// Get credential for a provider, checking file first, then environment variable. /// /// File credentials take precedence over environment variables. -pub fn get_credential( - provider: &Provider, - auth_config: &AuthConfig, -) -> Option { +pub fn get_credential(provider: &Provider, auth_config: &AuthConfig) -> Option { // Check file credentials first (higher precedence) if let Some(cred) = auth_config.get(provider) { return Some(CredentialInfo { @@ -246,9 +233,70 @@ pub fn get_credential( None } -/// Get just the token string for a provider (for backward compatibility). -pub fn get_token(provider: &Provider, auth_config: &AuthConfig) -> Option { - get_credential(provider, auth_config).map(|info| info.credential.get_token().to_string()) +/// Get token for a provider, automatically refreshing OAuth tokens if expired. +/// +/// This is the preferred method for getting tokens as it handles expiration. +/// If the token is refreshed, the new credential is saved to the config file. 
+pub async fn get_or_refresh_token( + provider: &Provider, + auth_config: &mut AuthConfig, + config_base_path: &Path, +) -> Result { + // Get credential info + let info = get_credential(provider, auth_config) + .ok_or_else(|| anyhow::anyhow!("No credential found for {}", provider))?; + + // If from environment, just return the token (can't refresh env vars) + if matches!(info.source, CredentialSource::Environment(_)) { + return Ok(info.credential.get_token().to_string()); + } + + // Check if OAuth token is expired + if info.credential.is_expired() { + if let Some(refresh_tok) = info.credential.refresh_token() { + // Attempt to refresh + eprintln!("OAuth token expired, refreshing..."); + match refresh_oauth_token(provider, refresh_tok).await { + Ok(new_credential) => { + let token = new_credential.get_token().to_string(); + auth_config.set(provider, new_credential); + auth_config.save(config_base_path)?; + eprintln!("Token refreshed successfully."); + return Ok(token); + } + Err(e) => { + // Refresh failed - user needs to re-authenticate + return Err(anyhow::anyhow!( + "OAuth token expired and refresh failed: {}. Please run 'rullm auth login {}'", + e, + provider + )); + } + } + } + } + + Ok(info.credential.get_token().to_string()) +} + +/// Refresh an OAuth token for a specific provider. +async fn refresh_oauth_token(provider: &Provider, refresh_token: &str) -> Result { + use crate::oauth::{anthropic::AnthropicOAuth, openai::OpenAIOAuth}; + + match provider { + Provider::Anthropic => { + let oauth = AnthropicOAuth::new(); + oauth.refresh_token(refresh_token).await + } + Provider::OpenAI => { + let oauth = OpenAIOAuth::new(); + oauth.refresh_token(refresh_token).await + } + _ => Err(anyhow::anyhow!( + "Provider {} does not support OAuth token refresh", + provider + )), + } } #[cfg(test)] @@ -263,7 +311,7 @@ mod tests { "refresh".to_string(), u64::MAX, // Far future ); - assert!(cred.is_oauth()); + assert!(matches!(cred, Credential::OAuth { .. 
})); assert_eq!(cred.get_token(), "access"); assert_eq!(cred.refresh_token(), Some("refresh")); assert!(!cred.is_expired()); @@ -272,7 +320,7 @@ mod tests { #[test] fn test_credential_api() { let cred = Credential::api("sk-test-key".to_string()); - assert!(!cred.is_oauth()); + assert!(matches!(cred, Credential::Api { .. })); assert_eq!(cred.get_token(), "sk-test-key"); assert_eq!(cred.refresh_token(), None); assert!(!cred.is_expired()); diff --git a/crates/rullm-cli/src/client.rs b/crates/rullm-cli/src/client.rs index 451158c1..77ba66c5 100644 --- a/crates/rullm-cli/src/client.rs +++ b/crates/rullm-cli/src/client.rs @@ -209,9 +209,15 @@ pub fn create_client( } } -/// Create a SimpleLlmClient from a model string, CLI arguments, and configuration -/// This is the promoted version of the create_client_from_model closure from lib.rs -pub fn from_model(model_str: &str, cli: &Cli, cli_config: &CliConfig) -> Result { +/// Create a SimpleLlmClient from a model string, CLI arguments, and configuration. +/// +/// This function handles OAuth token refresh automatically if the token is expired. +/// The `cli_config` is mutable because refreshing a token requires saving the new credential. +pub async fn from_model( + model_str: &str, + cli: &Cli, + cli_config: &mut CliConfig, +) -> Result { // Use the global alias resolver for CLI functionality let resolver = crate::aliases::get_global_alias_resolver(&cli_config.config_base_path); let resolver = resolver @@ -221,9 +227,17 @@ pub fn from_model(model_str: &str, cli: &Cli, cli_config: &CliConfig) -> Result< .resolve(model_str) .context("Invalid model format")?; - let token = auth::get_token(&provider, &cli_config.auth_config).ok_or_else(|| { + // Get token with automatic refresh for OAuth + let token = auth::get_or_refresh_token( + &provider, + &mut cli_config.auth_config, + &cli_config.config_base_path, + ) + .await + .map_err(|e| { anyhow::anyhow!( - "Credentials required. 
Run 'rullm auth login {}' or set {} environment variable", + "{}. Run 'rullm auth login {}' or set {} environment variable", + e, provider, provider.env_key() ) diff --git a/crates/rullm-cli/src/commands/auth.rs b/crates/rullm-cli/src/commands/auth.rs index 953641dc..9fb0d630 100644 --- a/crates/rullm-cli/src/commands/auth.rs +++ b/crates/rullm-cli/src/commands/auth.rs @@ -114,10 +114,7 @@ impl AuthArgs { auth_config.remove(&provider); auth_config.save(config_base_path)?; - crate::output::success( - &format!("Logged out from {provider}"), - output_level, - ); + crate::output::success(&format!("Logged out from {provider}"), output_level); } AuthAction::List => { @@ -147,9 +144,10 @@ fn select_provider() -> Result { let mut input = String::new(); io::stdin().read_line(&mut input)?; - let choice: usize = input.trim().parse().map_err(|_| { - anyhow::anyhow!("Invalid selection") - })?; + let choice: usize = input + .trim() + .parse() + .map_err(|_| anyhow::anyhow!("Invalid selection"))?; if choice == 0 || choice > providers.len() { anyhow::bail!("Invalid selection"); @@ -212,7 +210,7 @@ fn format_provider_display(provider: &Provider) -> &'static str { } /// Print the credentials list in a nice format. 
-fn print_credentials_list(auth_config: &AuthConfig, _output_level: OutputLevel) { +fn print_credentials_list(auth_config: &AuthConfig, output_level: OutputLevel) { let mut file_creds: Vec<(Provider, String)> = Vec::new(); let mut env_creds: Vec<(Provider, String)> = Vec::new(); @@ -231,46 +229,68 @@ fn print_credentials_list(auth_config: &AuthConfig, _output_level: OutputLevel) // Print file credentials section if !file_creds.is_empty() { + let base_strategy = etcetera::choose_base_strategy().unwrap(); let config_path = auth::auth_config_path( - &etcetera::choose_base_strategy() - .unwrap() + &base_strategy .config_dir() .join(crate::constants::BINARY_NAME), ); - println!("\n\u{250c} Credentials {}", config_path.display()); - println!("\u{2502}"); + + // Display path with ~ for home directory + let display_path = config_path + .strip_prefix(base_strategy.home_dir()) + .map(|p| format!("~/{}", p.display())) + .unwrap_or_else(|_| config_path.display().to_string()); + + crate::output::heading(&format!("Credentials ({display_path}):"), output_level); for (provider, cred_type) in &file_creds { - println!( - "\u{25cf} {} {}", - format_provider_display(provider), - cred_type + crate::output::note( + &format!( + " {}: {}", + crate::output::format_provider(format_provider_display(provider)), + cred_type + ), + output_level, ); - println!("\u{2502}"); } - - println!("\u{2514} {} credentials", file_creds.len()); } // Print environment variables section if !env_creds.is_empty() { - println!("\n\u{250c} Environment"); - println!("\u{2502}"); + if !file_creds.is_empty() { + crate::output::note("", output_level); + } + crate::output::heading("Environment variables:", output_level); for (provider, env_key) in &env_creds { - println!( - "\u{25cf} {} {}", - format_provider_display(provider), - env_key + crate::output::note( + &format!( + " {}: {}", + crate::output::format_provider(format_provider_display(provider)), + env_key + ), + output_level, ); - println!("\u{2502}"); } - - 
println!("\u{2514} {} environment variables", env_creds.len()); } - if file_creds.is_empty() && env_creds.is_empty() { - println!("\nNo credentials configured."); - println!("Use 'rullm auth login' to add credentials."); + // Print summary + let total = file_creds.len() + env_creds.len(); + if total > 0 { + crate::output::note("", output_level); + crate::output::note( + &format!("{} credential(s) configured.", total), + output_level, + ); + } else { + crate::output::note("No credentials configured.", output_level); + crate::output::hint( + &format!( + "Run {} to add credentials", + crate::output::format_command("rullm auth login") + ), + output_level, + ); } } diff --git a/crates/rullm-cli/src/commands/chat/mod.rs b/crates/rullm-cli/src/commands/chat/mod.rs index 897b0115..3c4af395 100644 --- a/crates/rullm-cli/src/commands/chat/mod.rs +++ b/crates/rullm-cli/src/commands/chat/mod.rs @@ -35,11 +35,11 @@ impl ChatArgs { pub async fn run( &self, _output_level: OutputLevel, - cli_config: &CliConfig, + cli_config: &mut CliConfig, cli: &Cli, ) -> Result<()> { let model_str = resolve_model(&cli.model, &self.model, &cli_config.config.default_model)?; - let client = client::from_model(&model_str, cli, cli_config)?; + let client = client::from_model(&model_str, cli, cli_config).await?; run_interactive_chat(&client, None, cli_config, !cli.no_streaming).await?; Ok(()) } diff --git a/crates/rullm-cli/src/commands/models.rs b/crates/rullm-cli/src/commands/models.rs index 19c99ae7..b2cabaec 100644 --- a/crates/rullm-cli/src/commands/models.rs +++ b/crates/rullm-cli/src/commands/models.rs @@ -75,7 +75,7 @@ impl ModelsArgs { let provider = format!("{provider}"); // Try to create a client for this provider let model_hint = format!("{provider}:dummy"); // dummy model name, just to get the client - let client = match client::from_model(&model_hint, cli, cli_config) { + let client = match client::from_model(&model_hint, cli, cli_config).await { Ok(c) => c, Err(_) => { 
skipped.push(provider); diff --git a/crates/rullm-cli/src/main.rs b/crates/rullm-cli/src/main.rs index 674b462e..dc5bc899 100644 --- a/crates/rullm-cli/src/main.rs +++ b/crates/rullm-cli/src/main.rs @@ -81,7 +81,7 @@ pub async fn run() -> Result<()> { // Handle commands match &cli.command { - Some(Commands::Chat(args)) => args.run(output_level, &cli_config, &cli).await?, + Some(Commands::Chat(args)) => args.run(output_level, &mut cli_config, &cli).await?, Some(Commands::Models(args)) => args.run(output_level, &mut cli_config, &cli).await?, Some(Commands::Info(args)) => args.run(output_level, &cli_config, &cli).await?, Some(Commands::Auth(args)) => args.run(output_level, &cli_config.config_base_path).await?, @@ -92,7 +92,7 @@ pub async fn run() -> Result<()> { if let Some(query) = &cli.query { let model_str = resolve_direct_query_model(&cli.model, &cli_config.config.default_model)?; - let client = client::from_model(&model_str, &cli, &cli_config)?; + let client = client::from_model(&model_str, &cli, &mut cli_config).await?; // Handle template if provided let (system_prompt, final_query) = if let Some(template_name) = &cli.template { diff --git a/crates/rullm-cli/src/oauth/anthropic.rs b/crates/rullm-cli/src/oauth/anthropic.rs index 0c2ae43f..80ffbfef 100644 --- a/crates/rullm-cli/src/oauth/anthropic.rs +++ b/crates/rullm-cli/src/oauth/anthropic.rs @@ -27,7 +27,9 @@ impl Default for AnthropicOAuth { fn default() -> Self { Self { authorization_url: "https://console.anthropic.com/oauth/authorize", - token_url: "https://api.anthropic.com/oauth/token", + // The token endpoint lives on the console domain (not the public API) + // and requires the `/v1` prefix; posting to the API host returns 404. 
+ token_url: "https://console.anthropic.com/v1/oauth/token", client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e", callback_port: 8765, scopes: &["org:create_api_key", "user:profile", "user:inference"], @@ -61,6 +63,8 @@ struct TokenRequest<'a> { code: &'a str, redirect_uri: &'a str, code_verifier: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + state: Option<&'a str>, } impl AnthropicOAuth { @@ -70,15 +74,19 @@ impl AnthropicOAuth { } /// Build the authorization URL for the OAuth flow. - pub fn build_authorization_url(&self, pkce: &PkceChallenge, state: &str) -> String { - let redirect_uri = format!("http://localhost:{}/callback", self.callback_port); + fn build_authorization_url( + &self, + pkce: &PkceChallenge, + state: &str, + redirect_uri: &str, + ) -> String { let scope = self.scopes.join(" "); format!( "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method={}&state={}", self.authorization_url, urlencoding::encode(self.client_id), - urlencoding::encode(&redirect_uri), + urlencoding::encode(redirect_uri), urlencoding::encode(&scope), urlencoding::encode(&pkce.challenge), pkce.method(), @@ -101,11 +109,13 @@ impl AnthropicOAuth { let state = generate_state(); // Start callback server - let server = CallbackServer::new(self.callback_port) - .context("Failed to start callback server")?; + let server = + CallbackServer::new(self.callback_port).context("Failed to start callback server")?; + + let redirect_uri = server.redirect_uri(); // Build and open authorization URL - let auth_url = self.build_authorization_url(&pkce, &state); + let auth_url = self.build_authorization_url(&pkce, &state, &redirect_uri); println!("Opening browser for Anthropic authentication..."); webbrowser::open(&auth_url).context("Failed to open browser")?; @@ -123,28 +133,40 @@ impl AnthropicOAuth { } // Exchange code for tokens - let credential = self.exchange_code(&callback.code, &pkce.verifier).await?; + let credential = self + 
.exchange_code( + &callback.code, + &pkce.verifier, + &redirect_uri, + callback.state.as_deref(), + ) + .await?; Ok(credential) } /// Exchange authorization code for tokens. - async fn exchange_code(&self, code: &str, code_verifier: &str) -> Result { - let redirect_uri = format!("http://localhost:{}/callback", self.callback_port); - + async fn exchange_code( + &self, + code: &str, + code_verifier: &str, + redirect_uri: &str, + state: Option<&str>, + ) -> Result { let request_body = TokenRequest { grant_type: "authorization_code", client_id: self.client_id, code, - redirect_uri: &redirect_uri, + redirect_uri, code_verifier, + state, }; let client = reqwest::Client::new(); let response = client .post(self.token_url) - .header("Content-Type", "application/x-www-form-urlencoded") - .body(serde_urlencoded::to_string(&request_body)?) + // Anthropic expects JSON payloads for the token exchange. + .json(&request_body) .send() .await .context("Failed to send token request")?; @@ -184,8 +206,7 @@ impl AnthropicOAuth { let client = reqwest::Client::new(); let response = client .post(self.token_url) - .header("Content-Type", "application/x-www-form-urlencoded") - .body(serde_urlencoded::to_string(&request_body)?) 
+ .json(&request_body) .send() .await .context("Failed to send refresh request")?; @@ -231,8 +252,9 @@ mod tests { let oauth = AnthropicOAuth::new(); let pkce = PkceChallenge::generate(); let state = "test-state"; + let redirect_uri = "http://localhost:8765/callback"; - let url = oauth.build_authorization_url(&pkce, state); + let url = oauth.build_authorization_url(&pkce, state, redirect_uri); assert!(url.starts_with("https://console.anthropic.com/oauth/authorize")); assert!(url.contains("response_type=code")); diff --git a/crates/rullm-cli/src/oauth/mod.rs b/crates/rullm-cli/src/oauth/mod.rs index 93bbe19a..2c1dc24a 100644 --- a/crates/rullm-cli/src/oauth/mod.rs +++ b/crates/rullm-cli/src/oauth/mod.rs @@ -9,4 +9,4 @@ pub mod anthropic; pub mod openai; pub use pkce::PkceChallenge; -pub use server::{CallbackResult, CallbackServer}; +pub use server::CallbackServer; diff --git a/crates/rullm-cli/src/oauth/openai.rs b/crates/rullm-cli/src/oauth/openai.rs index 78eef86a..8d1d8827 100644 --- a/crates/rullm-cli/src/oauth/openai.rs +++ b/crates/rullm-cli/src/oauth/openai.rs @@ -13,6 +13,10 @@ use crate::auth::Credential; pub struct OpenAIOAuth { /// Issuer URL for discovery pub issuer_url: &'static str, + /// Public client ID for Codex CLI (OpenAI consumer OAuth) + pub client_id: &'static str, + /// Authorization scopes requested + pub scopes: &'static [&'static str], /// Callback port pub callback_port: u16, } @@ -21,17 +25,19 @@ impl Default for OpenAIOAuth { fn default() -> Self { Self { issuer_url: "https://auth.openai.com", + client_id: "app_EMoamEEZ73f0CkXaXp7hrann", + scopes: &["openid", "profile", "email", "offline_access"], callback_port: 1455, } } } -/// OAuth authorization server metadata from discovery. +/// OpenID Connect discovery document (subset used). 
#[derive(Debug, Deserialize)] -struct AuthorizationServerMetadata { +#[allow(dead_code)] +struct OpenIdConfiguration { authorization_endpoint: String, token_endpoint: String, - #[allow(dead_code)] issuer: String, } @@ -49,6 +55,7 @@ struct TokenResponse { #[derive(Debug, Serialize)] struct TokenRequest<'a> { grant_type: &'static str, + client_id: &'a str, code: &'a str, redirect_uri: &'a str, code_verifier: &'a str, @@ -58,6 +65,7 @@ struct TokenRequest<'a> { #[derive(Debug, Serialize)] struct RefreshRequest<'a> { grant_type: &'static str, + client_id: &'a str, refresh_token: &'a str, } @@ -68,15 +76,13 @@ impl OpenAIOAuth { } /// Discover OAuth endpoints from the authorization server. - async fn discover(&self) -> Result { - let discovery_url = format!( - "{}/.well-known/oauth-authorization-server", - self.issuer_url - ); + async fn discover(&self) -> Result { + let discovery_url = format!("{}/.well-known/openid-configuration", self.issuer_url); let client = reqwest::Client::new(); let response = client .get(&discovery_url) + .header("Accept", "application/json") .send() .await .with_context(|| format!("Failed to fetch OAuth discovery from {}", discovery_url))?; @@ -99,13 +105,15 @@ impl OpenAIOAuth { authorization_endpoint: &str, pkce: &PkceChallenge, state: &str, + redirect_uri: &str, ) -> String { - let redirect_uri = format!("http://localhost:{}/callback", self.callback_port); - + let scope = self.scopes.join(" "); format!( - "{}?response_type=code&redirect_uri={}&code_challenge={}&code_challenge_method={}&state={}", + "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method={}&state={}&id_token_add_organizations=true&codex_cli_simplified_flow=true&originator=codex_cli_rs", authorization_endpoint, - urlencoding::encode(&redirect_uri), + urlencoding::encode(self.client_id), + urlencoding::encode(redirect_uri), + urlencoding::encode(&scope), urlencoding::encode(&pkce.challenge), pkce.method(), 
urlencoding::encode(state) @@ -123,7 +131,7 @@ impl OpenAIOAuth { pub async fn login(&self) -> Result { // Discover OAuth endpoints println!("Discovering OpenAI OAuth endpoints..."); - let metadata = self.discover().await?; + let _metadata = self.discover().await?; // Generate PKCE challenge let pkce = PkceChallenge::generate(); @@ -132,11 +140,19 @@ impl OpenAIOAuth { let state = generate_state(); // Start callback server - let server = CallbackServer::new(self.callback_port) - .context("Failed to start callback server")?; + let server = + CallbackServer::new(self.callback_port).context("Failed to start callback server")?; + + // OpenAI's public client is registered with /auth/callback for localhost. + let redirect_uri = server.redirect_uri_with_path("/auth/callback"); // Build and open authorization URL - let auth_url = self.build_authorization_url(&metadata.authorization_endpoint, &pkce, &state); + // The public web consumer OAuth lives at /oauth/authorize (different from OIDC discovery value). + let authorization_endpoint = format!("{}/oauth/authorize", self.issuer_url); + // The working token endpoint for consumer OAuth lives on auth.openai.com, not the discovery host. 
+ let token_endpoint = format!("{}/oauth/token", self.issuer_url); + let auth_url = + self.build_authorization_url(&authorization_endpoint, &pkce, &state, &redirect_uri); println!("Opening browser for OpenAI authentication..."); webbrowser::open(&auth_url).context("Failed to open browser")?; @@ -155,7 +171,12 @@ impl OpenAIOAuth { // Exchange code for tokens let credential = self - .exchange_code(&metadata.token_endpoint, &callback.code, &pkce.verifier) + .exchange_code( + &token_endpoint, + &callback.code, + &pkce.verifier, + &redirect_uri, + ) .await?; Ok(credential) @@ -167,13 +188,13 @@ impl OpenAIOAuth { token_endpoint: &str, code: &str, code_verifier: &str, + redirect_uri: &str, ) -> Result { - let redirect_uri = format!("http://localhost:{}/callback", self.callback_port); - let request_body = TokenRequest { grant_type: "authorization_code", + client_id: self.client_id, code, - redirect_uri: &redirect_uri, + redirect_uri, code_verifier, }; @@ -217,17 +238,18 @@ impl OpenAIOAuth { /// Refresh an expired OAuth token. pub async fn refresh_token(&self, refresh_token: &str) -> Result { - // Need to discover endpoints first - let metadata = self.discover().await?; + // Use the known consumer token endpoint (discovery token endpoint points to auth0.openai.com and fails). + let token_endpoint = format!("{}/oauth/token", self.issuer_url); let request_body = RefreshRequest { grant_type: "refresh_token", + client_id: self.client_id, refresh_token, }; let client = reqwest::Client::new(); let response = client - .post(&metadata.token_endpoint) + .post(&token_endpoint) .header("Content-Type", "application/x-www-form-urlencoded") .body(serde_urlencoded::to_string(&request_body)?) 
.send() @@ -279,6 +301,11 @@ mod tests { let oauth = OpenAIOAuth::new(); assert_eq!(oauth.callback_port, 1455); assert_eq!(oauth.issuer_url, "https://auth.openai.com"); + assert_eq!(oauth.client_id, "app_EMoamEEZ73f0CkXaXp7hrann"); + assert_eq!( + oauth.scopes, + ["openid", "profile", "email", "offline_access"] + ); } #[test] @@ -286,18 +313,24 @@ mod tests { let oauth = OpenAIOAuth::new(); let pkce = PkceChallenge::generate(); let state = "test-state"; + let redirect_uri = "http://localhost:1455/auth/callback"; let url = oauth.build_authorization_url( - "https://auth.openai.com/authorize", + "https://auth.openai.com/oauth/authorize", &pkce, state, + redirect_uri, ); - assert!(url.starts_with("https://auth.openai.com/authorize")); + assert!(url.starts_with("https://auth.openai.com/oauth/authorize")); assert!(url.contains("response_type=code")); + assert!(url.contains("client_id=app_EMoamEEZ73f0CkXaXp7hrann")); + assert!(url.contains("scope=openid%20profile%20email%20offline_access")); assert!(url.contains("redirect_uri=")); assert!(url.contains("code_challenge=")); assert!(url.contains("code_challenge_method=S256")); assert!(url.contains("state=test-state")); + assert!(url.contains("codex_cli_simplified_flow=true")); + assert!(url.contains("originator=codex_cli_rs")); } } diff --git a/crates/rullm-cli/src/oauth/pkce.rs b/crates/rullm-cli/src/oauth/pkce.rs index d2a02849..7086212d 100644 --- a/crates/rullm-cli/src/oauth/pkce.rs +++ b/crates/rullm-cli/src/oauth/pkce.rs @@ -33,7 +33,10 @@ impl PkceChallenge { let hash = hasher.finalize(); let challenge = URL_SAFE_NO_PAD.encode(hash); - Self { verifier, challenge } + Self { + verifier, + challenge, + } } /// Get the challenge method (always "S256"). 
diff --git a/crates/rullm-cli/src/oauth/server.rs b/crates/rullm-cli/src/oauth/server.rs index 325cf124..065ee4f7 100644 --- a/crates/rullm-cli/src/oauth/server.rs +++ b/crates/rullm-cli/src/oauth/server.rs @@ -26,8 +26,8 @@ impl CallbackServer { /// Create a new callback server on the specified port. pub fn new(port: u16) -> Result { let addr = format!("127.0.0.1:{port}"); - let listener = TcpListener::bind(&addr) - .with_context(|| format!("Failed to bind to {addr}"))?; + let listener = + TcpListener::bind(&addr).with_context(|| format!("Failed to bind to {addr}"))?; // Get the actual port if 0 was specified let actual_port = listener.local_addr()?.port(); @@ -43,6 +43,16 @@ impl CallbackServer { format!("http://localhost:{}/callback", self.port) } + /// Build a redirect URI using a custom path (must start with '/'). + pub fn redirect_uri_with_path(&self, path: &str) -> String { + let normalized = if path.starts_with('/') { + path.to_string() + } else { + format!("/{}", path) + }; + format!("http://localhost:{}{}", self.port, normalized) + } + /// Wait for the OAuth callback and extract the authorization code. /// /// This blocks until a request is received or the timeout is reached. 
@@ -208,10 +218,7 @@ mod tests { CallbackServer::extract_query_param(path, "state"), Some("def".to_string()) ); - assert_eq!( - CallbackServer::extract_query_param(path, "missing"), - None - ); + assert_eq!(CallbackServer::extract_query_param(path, "missing"), None); } #[test] From 631c4cb30173acddc990c23d5670f3d52be08139 Mon Sep 17 00:00:00 2001 From: lambda Date: Thu, 4 Dec 2025 23:15:12 +0530 Subject: [PATCH 03/11] fix lints --- crates/rullm-cli/src/auth.rs | 28 +++++++++++++++++----------- crates/rullm-cli/src/client.rs | 17 ++++++++++------- crates/rullm-cli/src/oauth/openai.rs | 4 +--- crates/rullm-cli/src/oauth/server.rs | 4 +--- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/crates/rullm-cli/src/auth.rs b/crates/rullm-cli/src/auth.rs index 4bd7935c..dcee5b5a 100644 --- a/crates/rullm-cli/src/auth.rs +++ b/crates/rullm-cli/src/auth.rs @@ -338,13 +338,15 @@ mod tests { #[test] fn test_auth_config_serialization() { - let mut config = AuthConfig::default(); - config.anthropic = Some(Credential::oauth( - "sk-ant-oat01-test".to_string(), - "sk-ant-ort01-test".to_string(), - 1764813330304, - )); - config.openai = Some(Credential::api("sk-proj-test".to_string())); + let config = AuthConfig { + anthropic: Some(Credential::oauth( + "sk-ant-oat01-test".to_string(), + "sk-ant-ort01-test".to_string(), + 1764813330304, + )), + openai: Some(Credential::api("sk-proj-test".to_string())), + ..Default::default() + }; let toml_str = toml::to_string_pretty(&config).unwrap(); @@ -360,8 +362,10 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let config_path = temp_dir.path(); - let mut config = AuthConfig::default(); - config.groq = Some(Credential::api("test-groq-key".to_string())); + let config = AuthConfig { + groq: Some(Credential::api("test-groq-key".to_string())), + ..Default::default() + }; config.save(config_path).unwrap(); @@ -374,8 +378,10 @@ mod tests { #[test] fn test_get_credential_file_precedence() { - let mut config = AuthConfig::default(); 
- config.anthropic = Some(Credential::api("file-key".to_string())); + let config = AuthConfig { + anthropic: Some(Credential::api("file-key".to_string())), + ..Default::default() + }; // File credential should be returned let info = get_credential(&Provider::Anthropic, &config).unwrap(); diff --git a/crates/rullm-cli/src/client.rs b/crates/rullm-cli/src/client.rs index 77ba66c5..bade2eb5 100644 --- a/crates/rullm-cli/src/client.rs +++ b/crates/rullm-cli/src/client.rs @@ -219,13 +219,16 @@ pub async fn from_model( cli_config: &mut CliConfig, ) -> Result { // Use the global alias resolver for CLI functionality - let resolver = crate::aliases::get_global_alias_resolver(&cli_config.config_base_path); - let resolver = resolver - .read() - .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on global resolver"))?; - let (provider, model_name) = resolver - .resolve(model_str) - .context("Invalid model format")?; + // Resolve provider and model inside a block so the lock is dropped before the await + let (provider, model_name) = { + let resolver = crate::aliases::get_global_alias_resolver(&cli_config.config_base_path); + let resolver = resolver + .read() + .map_err(|_| anyhow::anyhow!("Failed to acquire read lock on global resolver"))?; + resolver + .resolve(model_str) + .context("Invalid model format")? 
+ }; // Get token with automatic refresh for OAuth let token = auth::get_or_refresh_token( diff --git a/crates/rullm-cli/src/oauth/openai.rs b/crates/rullm-cli/src/oauth/openai.rs index 8d1d8827..b9fca518 100644 --- a/crates/rullm-cli/src/oauth/openai.rs +++ b/crates/rullm-cli/src/oauth/openai.rs @@ -225,9 +225,7 @@ impl OpenAIOAuth { .unwrap_or(0); // OpenAI might not always return a refresh token - let refresh_token = token_response - .refresh_token - .unwrap_or_else(|| String::new()); + let refresh_token = token_response.refresh_token.unwrap_or_default(); Ok(Credential::oauth( token_response.access_token, diff --git a/crates/rullm-cli/src/oauth/server.rs b/crates/rullm-cli/src/oauth/server.rs index 065ee4f7..3a8ab07e 100644 --- a/crates/rullm-cli/src/oauth/server.rs +++ b/crates/rullm-cli/src/oauth/server.rs @@ -160,9 +160,7 @@ impl CallbackServer { fn extract_query_param(path: &str, param: &str) -> Option { let query = path.split('?').nth(1)?; for pair in query.split('&') { - let mut kv = pair.splitn(2, '='); - let key = kv.next()?; - let value = kv.next()?; + let (key, value) = pair.split_once('=')?; if key == param { // URL decode the value return Some(urlencoding::decode(value).ok()?.into_owned()); From 0031a7e07f825b1c3b94ce7e9d90947bd8236c62 Mon Sep 17 00:00:00 2001 From: lambda Date: Thu, 4 Dec 2025 23:58:55 +0530 Subject: [PATCH 04/11] move to AGENTS.md and rework the whole file --- AGENTS.md | 55 ++++++++++++++++++ CLAUDE.md | 166 +----------------------------------------------------- 2 files changed, 56 insertions(+), 165 deletions(-) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..7ed07507 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,55 @@ +# rullm + +## Build Commands + +```bash +# Build all crates +cargo build --all + +# Lint (format + clippy) +just lint + +# Format only +just fmt + +# Check examples compile +cargo check --examples + +# Run tests +cargo test +``` + +## Project Structure + +This is a 
Rust workspace with two crates: + +- **rullm-core** (`crates/rullm-core/`) - Core library for LLM provider interactions +- **rullm-cli** (`crates/rullm-cli/`) - CLI binary for querying LLMs + +### Core Library Architecture + +The core library uses a trait-based provider system with two API levels: + +1. **Simple API** - String-based, minimal configuration +2. **Advanced API** - Full control with `ChatRequestBuilder` + +Key modules: +- `providers/` - Provider implementations (OpenAI, Anthropic, Google, OpenAI-compatible) +- `compat_types.rs` - OpenAI-compatible message/response types used across providers +- `config.rs` - Provider configuration traits and builders +- `error.rs` - `LlmError` enum with comprehensive error variants +- `utils/sse.rs` - Server-sent event parsing for streaming + +### CLI Architecture + +The CLI is organized by commands in `commands/`: +- `auth.rs` - OAuth and API key management +- `chat.rs` - Interactive chat mode with reedline +- `models.rs` - Model listing and updates +- `alias.rs` - User-defined model aliases +- `templates.rs` - TOML template management + +OAuth implementation in `oauth/`: +- `openai.rs`, `anthropic.rs` - Provider-specific OAuth flows +- `server.rs` - Local callback server for OAuth redirects +- `pkce.rs` - PKCE challenge generation diff --git a/CLAUDE.md b/CLAUDE.md index 763940cd..43c994c2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,165 +1 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -`rullm` is a Rust library and CLI for interacting with multiple LLM providers (OpenAI, Anthropic, Google AI, Groq, OpenRouter). 
The project uses a workspace structure with two main crates: - -- **rullm-core**: Core library implementing provider integrations, middleware (Tower-based), and streaming support -- **rullm-cli**: Command-line interface built on top of rullm-core - -## Architecture - -### Provider System - -All LLM providers implement two core traits defined in `crates/rullm-core/src/types.rs`: -- `LlmProvider`: Base trait with provider metadata (name, aliases, env_key, default_base_url, available_models, health_check) -- `ChatCompletion`: Extends LlmProvider with chat completion methods (blocking and streaming) - -Provider implementations are in `crates/rullm-core/src/providers/`: -- `openai.rs`: OpenAI GPT models -- `anthropic.rs`: Anthropic Claude models -- `google.rs`: Google Gemini models -- `openai_compatible.rs`: Generic provider for OpenAI-compatible APIs -- `groq.rs`: Groq provider (uses `openai_compatible`) -- `openrouter.rs`: OpenRouter provider (uses `openai_compatible`) - -The `openai_compatible` provider is a generic implementation that other providers like Groq and OpenRouter extend. It uses a `ProviderIdentity` struct to define provider-specific metadata. - -### Middleware Stack - -The library uses Tower middleware (see `crates/rullm-core/src/middleware.rs`): -- Rate limiting -- Timeouts -- Connection pooling -- Logging and metrics - -Configuration is done via `MiddlewareConfig` and `LlmServiceBuilder`. - -### Simple API - -`crates/rullm-core/src/simple.rs` provides a simplified string-based API (`SimpleLlmClient`, `SimpleLlmBuilder`) that wraps the advanced provider APIs for ease of use. - -### CLI Architecture - -The CLI entry point is `crates/rullm-cli/src/main.rs`, which: -1. Parses arguments using clap (see `args.rs`) -2. Loads configuration from `~/.config/rullm/` (see `config.rs`) -3. 
Dispatches to commands in `crates/rullm-cli/src/commands/` - -Key CLI modules: -- `client.rs`: Creates provider clients from model strings (format: `provider:model`) -- `provider.rs`: Resolves provider names and aliases -- `config.rs`: Manages CLI configuration (models list, aliases, default model) -- `api_keys.rs`: Manages API key storage in system keychain -- `templates.rs`: TOML-based prompt templates with `{{input}}` placeholders -- `commands/chat/`: Interactive chat mode using reedline for advanced REPL features - -### Model Format - -Models are specified using the format `provider:model`: -- Example: `openai:gpt-4`, `anthropic:claude-3-opus-20240229`, `groq:llama-3-8b` -- The CLI resolves this via `client::from_model()` which creates the appropriate provider client - -## Common Development Tasks - -### Building and Running - -```bash -# Build everything -cargo build --all - -# Build release binary -cargo build --release - -# Run the CLI (from workspace root) -cargo run -p rullm-cli -- "your query" - -# Or after building -./target/debug/rullm "your query" -./target/release/rullm "your query" -``` - -### Testing - -```bash -# Run all tests (note: some require API keys) -cargo test --all - -# Run tests for specific crate -cargo test -p rullm-core -cargo test -p rullm-cli - -# Run a specific test -cargo test test_name - -# Check examples compile -cargo check --examples -``` - -### Code Quality - -```bash -# Format code -cargo fmt - -# Check formatting -cargo fmt -- --check - -# Run clippy (linter) -cargo clippy --all-targets --all-features -- -D warnings - -# Fix clippy suggestions automatically -cargo clippy --fix --all-targets --all-features -``` - -### Running Examples - -```bash -# Run examples from rullm-core (requires API keys) -cargo run --example openai_simple -cargo run --example anthropic_simple -cargo run --example google_simple -cargo run --example openai_stream # Streaming example -cargo run --example test_all_providers # Test all providers at once 
-``` - -### Adding a New Provider - -When adding a new provider: - -1. **OpenAI-compatible providers**: Use `OpenAICompatibleProvider` with a `ProviderIdentity` in `providers/openai_compatible.rs`. See `groq.rs` or `openrouter.rs` for examples. - -2. **Non-compatible providers**: Create a new file in `crates/rullm-core/src/providers/`: - - Implement `LlmProvider` and `ChatCompletion` traits - - Add provider config struct in `crates/rullm-core/src/config.rs` - - Export from `providers/mod.rs` and `lib.rs` - - Add client creation logic in `crates/rullm-cli/src/client.rs` - - Update `crates/rullm-cli/src/provider.rs` for CLI support - -3. Update `DEFAULT_MODELS` in `crates/rullm-core/src/simple.rs` if adding default model mappings - -### Streaming Implementation - -All providers should implement `chat_completion_stream()` returning `StreamResult`. The stream emits: -- `ChatStreamEvent::Token(String)`: Each token/chunk -- `ChatStreamEvent::Done`: Completion marker -- `ChatStreamEvent::Error(String)`: Errors during streaming - -See provider implementations for SSE parsing patterns using `utils::sse::sse_lines()`. 
- -## Configuration Files - -- **User config**: `~/.config/rullm/config.toml` (or system equivalent) - - Stores: default model, model aliases, cached models list -- **Templates**: `~/.config/rullm/templates/*.toml` -- **API keys**: Stored in system keychain via `api_keys.rs` - -## Important Notes - -- The project uses Rust edition 2024 (rust-version 1.85+) -- Model separator changed from `/` to `:` (e.g., `openai:gpt-4` not `openai/gpt-4`) -- Chat history is persisted in `~/.config/rullm/chat_history/` -- The CLI uses `reedline` for advanced REPL features (syntax highlighting, history, multiline editing) -- In chat mode: Alt+Enter for multiline, Ctrl+O for buffer editing, `/edit` to open $EDITOR +@AGENTS.md From 8971be82e50b7ad79e95933d6c12f4890dab0b7c Mon Sep 17 00:00:00 2001 From: lambda Date: Sun, 14 Dec 2025 01:14:51 +0530 Subject: [PATCH 05/11] feat(oauth): use bearer auth headers for Anthropic OAuth tokens --- .gitignore | 1 + crates/rullm-cli/src/auth.rs | 53 ++++++++++++++ crates/rullm-cli/src/cli_client.rs | 3 +- crates/rullm-cli/src/client.rs | 9 +-- crates/rullm-cli/src/oauth/anthropic.rs | 94 ++++++++++--------------- crates/rullm-core/src/config.rs | 29 +++++++- 6 files changed, 124 insertions(+), 65 deletions(-) diff --git a/.gitignore b/.gitignore index 41cfe11c..a4132301 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ TODO.md .cursor .taskmaster/ .cursorignore +docs/ diff --git a/crates/rullm-cli/src/auth.rs b/crates/rullm-cli/src/auth.rs index dcee5b5a..ba552a7f 100644 --- a/crates/rullm-cli/src/auth.rs +++ b/crates/rullm-cli/src/auth.rs @@ -237,6 +237,10 @@ pub fn get_credential(provider: &Provider, auth_config: &AuthConfig) -> Option Result<(String, bool)> { + // Get credential info + let info = get_credential(provider, auth_config) + .ok_or_else(|| anyhow::anyhow!("No credential found for {}", provider))?; + + // If from environment, it's always an API key (not OAuth) + if matches!(info.source, CredentialSource::Environment(_)) { + 
return Ok((info.credential.get_token().to_string(), false)); + } + + // Check if OAuth token is expired and needs refresh + let credential = if info.credential.is_expired() { + if let Some(refresh_tok) = info.credential.refresh_token() { + eprintln!("OAuth token expired, refreshing..."); + match refresh_oauth_token(provider, refresh_tok).await { + Ok(new_credential) => { + auth_config.set(provider, new_credential.clone()); + auth_config.save(config_base_path)?; + eprintln!("Token refreshed successfully."); + new_credential + } + Err(e) => { + return Err(anyhow::anyhow!( + "OAuth token expired and refresh failed: {}. Please run 'rullm auth login {}'", + e, + provider + )); + } + } + } else { + info.credential + } + } else { + info.credential + }; + + let is_oauth = matches!(credential, Credential::OAuth { .. }); + Ok((credential.get_token().to_string(), is_oauth)) +} + /// Refresh an OAuth token for a specific provider. async fn refresh_oauth_token(provider: &Provider, refresh_token: &str) -> Result { use crate::oauth::{anthropic::AnthropicOAuth, openai::OpenAIOAuth}; diff --git a/crates/rullm-cli/src/cli_client.rs b/crates/rullm-cli/src/cli_client.rs index d221be52..8f1373f3 100644 --- a/crates/rullm-cli/src/cli_client.rs +++ b/crates/rullm-cli/src/cli_client.rs @@ -67,8 +67,9 @@ impl CliClient { api_key: impl Into, model: impl Into, config: CliConfig, + use_oauth: bool, ) -> Result { - let client_config = AnthropicConfig::new(api_key); + let client_config = AnthropicConfig::new(api_key).with_oauth(use_oauth); let client = AnthropicClient::new(client_config)?; Ok(Self::Anthropic { client, diff --git a/crates/rullm-cli/src/client.rs b/crates/rullm-cli/src/client.rs index 6a6f670e..98c7e2a5 100644 --- a/crates/rullm-cli/src/client.rs +++ b/crates/rullm-cli/src/client.rs @@ -12,6 +12,7 @@ pub fn create_client( _base_url: Option<&str>, cli: &Cli, model_name: &str, + is_oauth: bool, ) -> Result { // Build CoreCliConfig based on CLI args let mut config = 
CoreCliConfig::default(); @@ -39,7 +40,7 @@ pub fn create_client( Provider::OpenAI => CliClient::openai(api_key, model_name, config), Provider::Groq => CliClient::groq(api_key, model_name, config), Provider::OpenRouter => CliClient::openrouter(api_key, model_name, config), - Provider::Anthropic => CliClient::anthropic(api_key, model_name, config), + Provider::Anthropic => CliClient::anthropic(api_key, model_name, config, is_oauth), Provider::Google => CliClient::google(api_key, model_name, config), } } @@ -65,8 +66,8 @@ pub async fn from_model( .context("Invalid model format")? }; - // Get token with automatic refresh for OAuth - let token = auth::get_or_refresh_token( + // Get token with automatic refresh for OAuth, including credential type + let (token, is_oauth) = auth::get_token_with_type( &provider, &mut cli_config.auth_config, &cli_config.config_base_path, @@ -81,5 +82,5 @@ pub async fn from_model( ) })?; - create_client(&provider, &token, None, cli, &model_name).map_err(anyhow::Error::from) + create_client(&provider, &token, None, cli, &model_name, is_oauth).map_err(anyhow::Error::from) } diff --git a/crates/rullm-cli/src/oauth/anthropic.rs b/crates/rullm-cli/src/oauth/anthropic.rs index 80ffbfef..19b8513b 100644 --- a/crates/rullm-cli/src/oauth/anthropic.rs +++ b/crates/rullm-cli/src/oauth/anthropic.rs @@ -6,7 +6,8 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; use std::time::Duration; -use super::{CallbackServer, PkceChallenge}; +use super::PkceChallenge; +use super::server::CallbackServer; use crate::auth::Credential; /// Anthropic OAuth configuration. 
@@ -17,22 +18,22 @@ pub struct AnthropicOAuth { pub token_url: &'static str, /// Client ID (Claude Code's public ID) pub client_id: &'static str, - /// Callback port - pub callback_port: u16, /// Required scopes pub scopes: &'static [&'static str], + /// Local callback port + pub callback_port: u16, } impl Default for AnthropicOAuth { fn default() -> Self { Self { - authorization_url: "https://console.anthropic.com/oauth/authorize", + authorization_url: "https://claude.ai/oauth/authorize", // The token endpoint lives on the console domain (not the public API) // and requires the `/v1` prefix; posting to the API host returns 404. token_url: "https://console.anthropic.com/v1/oauth/token", client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e", - callback_port: 8765, scopes: &["org:create_api_key", "user:profile", "user:inference"], + callback_port: 8765, } } } @@ -63,8 +64,6 @@ struct TokenRequest<'a> { code: &'a str, redirect_uri: &'a str, code_verifier: &'a str, - #[serde(skip_serializing_if = "Option::is_none")] - state: Option<&'a str>, } impl AnthropicOAuth { @@ -74,74 +73,66 @@ impl AnthropicOAuth { } /// Build the authorization URL for the OAuth flow. 
- fn build_authorization_url( - &self, - pkce: &PkceChallenge, - state: &str, - redirect_uri: &str, - ) -> String { + fn build_authorization_url(&self, redirect_uri: &str, pkce: &PkceChallenge) -> String { let scope = self.scopes.join(" "); format!( - "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method={}&state={}", + "{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method={}&state={}", self.authorization_url, urlencoding::encode(self.client_id), urlencoding::encode(redirect_uri), urlencoding::encode(&scope), urlencoding::encode(&pkce.challenge), pkce.method(), - urlencoding::encode(state) + urlencoding::encode(&pkce.verifier) // Use verifier as state (like OpenCode) ) } /// Start the OAuth flow and return the credential on success. /// - /// This will: - /// 1. Start a local callback server - /// 2. Open the browser to the authorization URL - /// 3. Wait for the callback with the authorization code - /// 4. Exchange the code for tokens + /// This opens a browser for the user to authorize, then captures + /// the callback on a local server. 
pub async fn login(&self) -> Result { - // Generate PKCE challenge - let pkce = PkceChallenge::generate(); - - // Generate state for CSRF protection - let state = generate_state(); - - // Start callback server + // Start local callback server let server = CallbackServer::new(self.callback_port).context("Failed to start callback server")?; let redirect_uri = server.redirect_uri(); - // Build and open authorization URL - let auth_url = self.build_authorization_url(&pkce, &state, &redirect_uri); + // Generate PKCE challenge + let pkce = PkceChallenge::generate(); + + // Build authorization URL + let auth_url = self.build_authorization_url(&redirect_uri, &pkce); println!("Opening browser for Anthropic authentication..."); - webbrowser::open(&auth_url).context("Failed to open browser")?; - println!("Waiting for authentication (timeout: 5 minutes)..."); + // Try to open browser, but don't fail if it doesn't work + if let Err(e) = webbrowser::open(&auth_url) { + println!("Could not open browser automatically: {}", e); + println!("Please open this URL manually:"); + println!("{}", auth_url); + } - // Wait for callback + // Wait for the callback + println!("Waiting for authorization callback..."); let callback = server .wait_for_callback(Duration::from_secs(300)) .context("Failed to receive OAuth callback")?; - // Verify state - if callback.state.as_deref() != Some(&state) { - anyhow::bail!("State mismatch in OAuth callback (possible CSRF attack)"); + // Verify state matches (we use verifier as state) + if let Some(state) = &callback.state { + if state != &pkce.verifier { + anyhow::bail!("State mismatch in OAuth callback"); + } } // Exchange code for tokens let credential = self - .exchange_code( - &callback.code, - &pkce.verifier, - &redirect_uri, - callback.state.as_deref(), - ) + .exchange_code(&callback.code, &redirect_uri, &pkce.verifier) .await?; + println!("Authentication successful!"); Ok(credential) } @@ -149,9 +140,8 @@ impl AnthropicOAuth { async fn 
exchange_code( &self, code: &str, - code_verifier: &str, redirect_uri: &str, - state: Option<&str>, + code_verifier: &str, ) -> Result { let request_body = TokenRequest { grant_type: "authorization_code", @@ -159,7 +149,6 @@ impl AnthropicOAuth { code, redirect_uri, code_verifier, - state, }; let client = reqwest::Client::new(); @@ -235,14 +224,6 @@ impl AnthropicOAuth { } } -/// Generate a random state string for CSRF protection. -fn generate_state() -> String { - use rand::RngCore; - let mut bytes = [0u8; 16]; - rand::rng().fill_bytes(&mut bytes); - hex::encode(bytes) -} - #[cfg(test)] mod tests { use super::*; @@ -251,24 +232,21 @@ mod tests { fn test_build_authorization_url() { let oauth = AnthropicOAuth::new(); let pkce = PkceChallenge::generate(); - let state = "test-state"; - let redirect_uri = "http://localhost:8765/callback"; - let url = oauth.build_authorization_url(&pkce, state, redirect_uri); + let url = oauth.build_authorization_url("http://localhost:8765/callback", &pkce); - assert!(url.starts_with("https://console.anthropic.com/oauth/authorize")); + assert!(url.starts_with("https://claude.ai/oauth/authorize")); assert!(url.contains("response_type=code")); assert!(url.contains("client_id=")); assert!(url.contains("redirect_uri=")); assert!(url.contains("code_challenge=")); assert!(url.contains("code_challenge_method=S256")); - assert!(url.contains("state=test-state")); + assert!(url.contains("state=")); } #[test] fn test_default_config() { let oauth = AnthropicOAuth::new(); - assert_eq!(oauth.callback_port, 8765); assert!(oauth.scopes.contains(&"user:inference")); } } diff --git a/crates/rullm-core/src/config.rs b/crates/rullm-core/src/config.rs index be8a3488..d3f2ddc6 100644 --- a/crates/rullm-core/src/config.rs +++ b/crates/rullm-core/src/config.rs @@ -200,6 +200,9 @@ pub struct AnthropicConfig { pub api_key: String, pub base_url: Option, pub timeout_seconds: u64, + /// Whether to use OAuth authentication (Bearer token) instead of API key 
(x-api-key) + #[serde(default)] + pub use_oauth: bool, } impl AnthropicConfig { @@ -208,6 +211,7 @@ impl AnthropicConfig { api_key: api_key.into(), base_url: None, timeout_seconds: 30, + use_oauth: false, } } @@ -215,6 +219,11 @@ impl AnthropicConfig { self.base_url = Some(base_url.into()); self } + + pub fn with_oauth(mut self, use_oauth: bool) -> Self { + self.use_oauth = use_oauth; + self + } } impl ProviderConfig for AnthropicConfig { @@ -234,9 +243,25 @@ impl ProviderConfig for AnthropicConfig { fn headers(&self) -> HashMap { let mut headers = HashMap::new(); - headers.insert("x-api-key".to_string(), self.api_key.clone()); + + if self.use_oauth { + // OAuth: use Bearer token + required beta headers + // Note: OpenCode doesn't send anthropic-version for OAuth requests + headers.insert( + "Authorization".to_string(), + format!("Bearer {}", self.api_key), + ); + headers.insert( + "anthropic-beta".to_string(), + "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14".to_string(), + ); + } else { + // API key: use x-api-key header + headers.insert("x-api-key".to_string(), self.api_key.clone()); + headers.insert("anthropic-version".to_string(), "2023-06-01".to_string()); + } + headers.insert("Content-Type".to_string(), "application/json".to_string()); - headers.insert("anthropic-version".to_string(), "2023-06-01".to_string()); headers } From 835eefd410bf60ede9cdd882ef0e597acbb405ab Mon Sep 17 00:00:00 2001 From: lambda Date: Sun, 14 Dec 2025 01:24:40 +0530 Subject: [PATCH 06/11] refactor(core): move provider configs to provider modules Move AnthropicConfig, GoogleAiConfig, and OpenAICompatibleConfig from config.rs to their respective provider modules. 
--- crates/rullm-cli/src/cli_client.rs | 7 +- crates/rullm-core/examples/openai_config.rs | 3 +- .../rullm-core/examples/test_all_providers.rs | 6 +- crates/rullm-core/src/config.rs | 245 +----------------- crates/rullm-core/src/lib.rs | 5 +- .../src/providers/anthropic/client.rs | 3 +- .../src/providers/anthropic/config.rs | 85 ++++++ .../rullm-core/src/providers/anthropic/mod.rs | 2 + .../rullm-core/src/providers/google/client.rs | 3 +- .../rullm-core/src/providers/google/config.rs | 58 +++++ crates/rullm-core/src/providers/google/mod.rs | 2 + crates/rullm-core/src/providers/mod.rs | 5 + .../rullm-core/src/providers/openai/client.rs | 3 +- .../src/providers/openai_compatible/config.rs | 111 ++++++++ .../mod.rs} | 14 +- 15 files changed, 291 insertions(+), 261 deletions(-) create mode 100644 crates/rullm-core/src/providers/anthropic/config.rs create mode 100644 crates/rullm-core/src/providers/google/config.rs create mode 100644 crates/rullm-core/src/providers/openai_compatible/config.rs rename crates/rullm-core/src/providers/{openai_compatible.rs => openai_compatible/mod.rs} (97%) diff --git a/crates/rullm-cli/src/cli_client.rs b/crates/rullm-cli/src/cli_client.rs index 8f1373f3..4fe6b45f 100644 --- a/crates/rullm-cli/src/cli_client.rs +++ b/crates/rullm-cli/src/cli_client.rs @@ -4,9 +4,12 @@ //! basic chat operations without exposing the full complexity of each provider's API. 
use futures::StreamExt; -use rullm_core::config::{AnthropicConfig, GoogleAiConfig, OpenAICompatibleConfig, OpenAIConfig}; use rullm_core::error::LlmError; -use rullm_core::providers::openai_compatible::{OpenAICompatibleProvider, identities}; +use rullm_core::providers::anthropic::AnthropicConfig; +use rullm_core::providers::google::GoogleAiConfig; +use rullm_core::providers::openai_compatible::{ + OpenAICompatibleConfig, OpenAICompatibleProvider, OpenAIConfig, identities, +}; use rullm_core::providers::{AnthropicClient, GoogleClient, OpenAIClient}; use std::pin::Pin; diff --git a/crates/rullm-core/examples/openai_config.rs b/crates/rullm-core/examples/openai_config.rs index 19279a9f..2ba628d7 100644 --- a/crates/rullm-core/examples/openai_config.rs +++ b/crates/rullm-core/examples/openai_config.rs @@ -1,7 +1,8 @@ -use rullm_core::config::{OpenAIConfig, ProviderConfig}; +use rullm_core::config::ProviderConfig; use rullm_core::providers::openai::{ ChatCompletionRequest, ChatMessage, ContentPart, MessageContent, OpenAIClient, }; +use rullm_core::providers::openai_compatible::OpenAIConfig; // Helper to extract text from MessageContent fn extract_text(content: &MessageContent) -> String { diff --git a/crates/rullm-core/examples/test_all_providers.rs b/crates/rullm-core/examples/test_all_providers.rs index 87041eb6..e6fa52c0 100644 --- a/crates/rullm-core/examples/test_all_providers.rs +++ b/crates/rullm-core/examples/test_all_providers.rs @@ -1,7 +1,7 @@ -use rullm_core::config::{AnthropicConfig, GoogleAiConfig, OpenAIConfig}; -use rullm_core::providers::anthropic::AnthropicClient; -use rullm_core::providers::google::GoogleClient; +use rullm_core::providers::anthropic::{AnthropicClient, AnthropicConfig}; +use rullm_core::providers::google::{GoogleAiConfig, GoogleClient}; use rullm_core::providers::openai::OpenAIClient; +use rullm_core::providers::openai_compatible::OpenAIConfig; use std::env; #[tokio::main] diff --git a/crates/rullm-core/src/config.rs 
b/crates/rullm-core/src/config.rs index d3f2ddc6..8528f4ba 100644 --- a/crates/rullm-core/src/config.rs +++ b/crates/rullm-core/src/config.rs @@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Duration; +use crate::providers::{AnthropicConfig, GoogleAiConfig, OpenAICompatibleConfig, OpenAIConfig}; + /// Configuration trait for LLM providers pub trait ProviderConfig: Send + Sync { /// Get the API key for this provider @@ -88,249 +90,6 @@ impl ProviderConfig for HttpProviderConfig { } } -/// OpenAI-compatible configuration (supports OpenAI, Groq, OpenRouter, etc.) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OpenAICompatibleConfig { - pub api_key: String, - pub organization: Option, - pub project: Option, - pub base_url: Option, - pub timeout_seconds: u64, -} - -/// Type alias for backwards compatibility -pub type OpenAIConfig = OpenAICompatibleConfig; - -impl OpenAICompatibleConfig { - pub fn new(api_key: impl Into) -> Self { - Self { - api_key: api_key.into(), - organization: None, - project: None, - base_url: None, - timeout_seconds: 30, - } - } - - pub fn groq(api_key: impl Into) -> Self { - Self { - api_key: api_key.into(), - organization: None, - project: None, - base_url: Some("https://api.groq.com/openai/v1".to_string()), - timeout_seconds: 30, - } - } - - pub fn openrouter(api_key: impl Into) -> Self { - Self { - api_key: api_key.into(), - organization: None, - project: None, - base_url: Some("https://openrouter.ai/api/v1".to_string()), - timeout_seconds: 30, - } - } - - pub fn with_organization(mut self, org: impl Into) -> Self { - self.organization = Some(org.into()); - self - } - - pub fn with_project(mut self, project: impl Into) -> Self { - self.project = Some(project.into()); - self - } - - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = Some(base_url.into()); - self - } -} - -impl ProviderConfig for OpenAICompatibleConfig { - fn api_key(&self) -> &str { - 
&self.api_key - } - - fn base_url(&self) -> &str { - self.base_url - .as_deref() - .unwrap_or("https://api.openai.com/v1") - } - - fn timeout(&self) -> Duration { - Duration::from_secs(self.timeout_seconds) - } - - fn headers(&self) -> HashMap { - let mut headers = HashMap::new(); - headers.insert( - "Authorization".to_string(), - format!("Bearer {}", self.api_key), - ); - headers.insert("Content-Type".to_string(), "application/json".to_string()); - - if let Some(org) = &self.organization { - headers.insert("OpenAI-Organization".to_string(), org.clone()); - } - - if let Some(project) = &self.project { - headers.insert("OpenAI-Project".to_string(), project.clone()); - } - - headers - } - - fn validate(&self) -> Result<(), crate::error::LlmError> { - if self.api_key.is_empty() { - return Err(crate::error::LlmError::configuration("API key is required")); - } - - // Relaxed validation: don't require 'sk-' prefix since Groq and OpenRouter use different formats - // OpenAI keys start with 'sk-', Groq uses 'gsk_', OpenRouter uses different format - - Ok(()) - } -} - -/// Anthropic-specific configuration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AnthropicConfig { - pub api_key: String, - pub base_url: Option, - pub timeout_seconds: u64, - /// Whether to use OAuth authentication (Bearer token) instead of API key (x-api-key) - #[serde(default)] - pub use_oauth: bool, -} - -impl AnthropicConfig { - pub fn new(api_key: impl Into) -> Self { - Self { - api_key: api_key.into(), - base_url: None, - timeout_seconds: 30, - use_oauth: false, - } - } - - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = Some(base_url.into()); - self - } - - pub fn with_oauth(mut self, use_oauth: bool) -> Self { - self.use_oauth = use_oauth; - self - } -} - -impl ProviderConfig for AnthropicConfig { - fn api_key(&self) -> &str { - &self.api_key - } - - fn base_url(&self) -> &str { - self.base_url - .as_deref() - .unwrap_or("https://api.anthropic.com") - 
} - - fn timeout(&self) -> Duration { - Duration::from_secs(self.timeout_seconds) - } - - fn headers(&self) -> HashMap { - let mut headers = HashMap::new(); - - if self.use_oauth { - // OAuth: use Bearer token + required beta headers - // Note: OpenCode doesn't send anthropic-version for OAuth requests - headers.insert( - "Authorization".to_string(), - format!("Bearer {}", self.api_key), - ); - headers.insert( - "anthropic-beta".to_string(), - "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14".to_string(), - ); - } else { - // API key: use x-api-key header - headers.insert("x-api-key".to_string(), self.api_key.clone()); - headers.insert("anthropic-version".to_string(), "2023-06-01".to_string()); - } - - headers.insert("Content-Type".to_string(), "application/json".to_string()); - headers - } - - fn validate(&self) -> Result<(), crate::error::LlmError> { - if self.api_key.is_empty() { - return Err(crate::error::LlmError::configuration( - "Anthropic API key is required", - )); - } - - Ok(()) - } -} - -/// Google AI configuration -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GoogleAiConfig { - pub api_key: String, - pub base_url: Option, - pub timeout_seconds: u64, -} - -impl GoogleAiConfig { - pub fn new(api_key: impl Into) -> Self { - Self { - api_key: api_key.into(), - base_url: None, - timeout_seconds: 30, - } - } - - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = Some(base_url.into()); - self - } -} - -impl ProviderConfig for GoogleAiConfig { - fn api_key(&self) -> &str { - &self.api_key - } - - fn base_url(&self) -> &str { - self.base_url - .as_deref() - .unwrap_or("https://generativelanguage.googleapis.com/v1beta") - } - - fn timeout(&self) -> Duration { - Duration::from_secs(self.timeout_seconds) - } - - fn headers(&self) -> HashMap { - let mut headers = HashMap::new(); - headers.insert("Content-Type".to_string(), "application/json".to_string()); - 
headers - } - - fn validate(&self) -> Result<(), crate::error::LlmError> { - if self.api_key.is_empty() { - return Err(crate::error::LlmError::configuration( - "Google AI API key is required", - )); - } - - Ok(()) - } -} - /// Configuration builder for creating provider configs from environment variables pub struct ConfigBuilder; diff --git a/crates/rullm-core/src/lib.rs b/crates/rullm-core/src/lib.rs index e137a565..0f7700ab 100644 --- a/crates/rullm-core/src/lib.rs +++ b/crates/rullm-core/src/lib.rs @@ -157,10 +157,7 @@ pub mod utils; // Concrete client exports pub use providers::{AnthropicClient, GoogleClient, OpenAIClient, OpenAICompatibleProvider}; -pub use config::{ - AnthropicConfig, ConfigBuilder, GoogleAiConfig, OpenAICompatibleConfig, OpenAIConfig, - ProviderConfig, -}; +pub use config::{ConfigBuilder, HttpProviderConfig, ProviderConfig}; pub use error::LlmError; pub use utils::sse::sse_lines; diff --git a/crates/rullm-core/src/providers/anthropic/client.rs b/crates/rullm-core/src/providers/anthropic/client.rs index 2da1cf5d..4bda4938 100644 --- a/crates/rullm-core/src/providers/anthropic/client.rs +++ b/crates/rullm-core/src/providers/anthropic/client.rs @@ -1,5 +1,6 @@ +use super::config::AnthropicConfig; use super::types::*; -use crate::config::{AnthropicConfig, ProviderConfig}; +use crate::config::ProviderConfig; use crate::error::LlmError; use crate::utils::sse::sse_lines; use futures::Stream; diff --git a/crates/rullm-core/src/providers/anthropic/config.rs b/crates/rullm-core/src/providers/anthropic/config.rs new file mode 100644 index 00000000..40fe80d8 --- /dev/null +++ b/crates/rullm-core/src/providers/anthropic/config.rs @@ -0,0 +1,85 @@ +use crate::config::ProviderConfig; +use crate::error::LlmError; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::Duration; + +/// Anthropic-specific configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AnthropicConfig { + pub api_key: String, + pub 
base_url: Option, + pub timeout_seconds: u64, + /// Whether to use OAuth authentication (Bearer token) instead of API key (x-api-key) + #[serde(default)] + pub use_oauth: bool, +} + +impl AnthropicConfig { + pub fn new(api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + base_url: None, + timeout_seconds: 30, + use_oauth: false, + } + } + + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = Some(base_url.into()); + self + } + + pub fn with_oauth(mut self, use_oauth: bool) -> Self { + self.use_oauth = use_oauth; + self + } +} + +impl ProviderConfig for AnthropicConfig { + fn api_key(&self) -> &str { + &self.api_key + } + + fn base_url(&self) -> &str { + self.base_url + .as_deref() + .unwrap_or("https://api.anthropic.com") + } + + fn timeout(&self) -> Duration { + Duration::from_secs(self.timeout_seconds) + } + + fn headers(&self) -> HashMap { + let mut headers = HashMap::new(); + + if self.use_oauth { + // OAuth: use Bearer token + required beta headers + // Note: OpenCode doesn't send anthropic-version for OAuth requests + headers.insert( + "Authorization".to_string(), + format!("Bearer {}", self.api_key), + ); + headers.insert( + "anthropic-beta".to_string(), + "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14".to_string(), + ); + } else { + // API key: use x-api-key header + headers.insert("x-api-key".to_string(), self.api_key.clone()); + headers.insert("anthropic-version".to_string(), "2023-06-01".to_string()); + } + + headers.insert("Content-Type".to_string(), "application/json".to_string()); + headers + } + + fn validate(&self) -> Result<(), LlmError> { + if self.api_key.is_empty() { + return Err(LlmError::configuration("Anthropic API key is required")); + } + + Ok(()) + } +} diff --git a/crates/rullm-core/src/providers/anthropic/mod.rs b/crates/rullm-core/src/providers/anthropic/mod.rs index 794d2e67..081ca01a 100644 --- 
a/crates/rullm-core/src/providers/anthropic/mod.rs +++ b/crates/rullm-core/src/providers/anthropic/mod.rs @@ -24,7 +24,9 @@ //! ``` pub mod client; +pub mod config; pub mod types; pub use client::AnthropicClient; +pub use config::AnthropicConfig; pub use types::*; diff --git a/crates/rullm-core/src/providers/google/client.rs b/crates/rullm-core/src/providers/google/client.rs index 7c1fe8f6..adbdde1f 100644 --- a/crates/rullm-core/src/providers/google/client.rs +++ b/crates/rullm-core/src/providers/google/client.rs @@ -1,5 +1,6 @@ +use super::config::GoogleAiConfig; use super::types::*; -use crate::config::{GoogleAiConfig, ProviderConfig}; +use crate::config::ProviderConfig; use crate::error::LlmError; use crate::utils::sse::sse_lines; use futures::Stream; diff --git a/crates/rullm-core/src/providers/google/config.rs b/crates/rullm-core/src/providers/google/config.rs new file mode 100644 index 00000000..236b3811 --- /dev/null +++ b/crates/rullm-core/src/providers/google/config.rs @@ -0,0 +1,58 @@ +use crate::config::ProviderConfig; +use crate::error::LlmError; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::Duration; + +/// Google AI configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GoogleAiConfig { + pub api_key: String, + pub base_url: Option, + pub timeout_seconds: u64, +} + +impl GoogleAiConfig { + pub fn new(api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + base_url: None, + timeout_seconds: 30, + } + } + + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = Some(base_url.into()); + self + } +} + +impl ProviderConfig for GoogleAiConfig { + fn api_key(&self) -> &str { + &self.api_key + } + + fn base_url(&self) -> &str { + self.base_url + .as_deref() + .unwrap_or("https://generativelanguage.googleapis.com/v1beta") + } + + fn timeout(&self) -> Duration { + Duration::from_secs(self.timeout_seconds) + } + + fn headers(&self) -> HashMap { + let mut headers = 
HashMap::new(); + headers.insert("Content-Type".to_string(), "application/json".to_string()); + headers + } + + fn validate(&self) -> Result<(), LlmError> { + if self.api_key.is_empty() { + return Err(LlmError::configuration("Google AI API key is required")); + } + + Ok(()) + } +} diff --git a/crates/rullm-core/src/providers/google/mod.rs b/crates/rullm-core/src/providers/google/mod.rs index 48967a30..13fcbb20 100644 --- a/crates/rullm-core/src/providers/google/mod.rs +++ b/crates/rullm-core/src/providers/google/mod.rs @@ -22,7 +22,9 @@ //! ``` pub mod client; +pub mod config; pub mod types; pub use client::GoogleClient; +pub use config::GoogleAiConfig; pub use types::*; diff --git a/crates/rullm-core/src/providers/mod.rs b/crates/rullm-core/src/providers/mod.rs index f728d3bb..71bec32b 100644 --- a/crates/rullm-core/src/providers/mod.rs +++ b/crates/rullm-core/src/providers/mod.rs @@ -9,3 +9,8 @@ pub use anthropic::AnthropicClient; pub use google::GoogleClient; pub use openai::OpenAIClient; pub use openai_compatible::{OpenAICompatibleProvider, ProviderIdentity, identities}; + +// Export provider-specific configs +pub use anthropic::AnthropicConfig; +pub use google::GoogleAiConfig; +pub use openai_compatible::{OpenAICompatibleConfig, OpenAIConfig}; diff --git a/crates/rullm-core/src/providers/openai/client.rs b/crates/rullm-core/src/providers/openai/client.rs index ccaa47ec..2ba78faa 100644 --- a/crates/rullm-core/src/providers/openai/client.rs +++ b/crates/rullm-core/src/providers/openai/client.rs @@ -1,6 +1,7 @@ use super::types::*; -use crate::config::{OpenAIConfig, ProviderConfig}; +use crate::config::ProviderConfig; use crate::error::LlmError; +use crate::providers::openai_compatible::OpenAIConfig; use crate::utils::sse::sse_lines; use futures::Stream; use futures::StreamExt; diff --git a/crates/rullm-core/src/providers/openai_compatible/config.rs b/crates/rullm-core/src/providers/openai_compatible/config.rs new file mode 100644 index 00000000..9eeac7cb --- 
/dev/null +++ b/crates/rullm-core/src/providers/openai_compatible/config.rs @@ -0,0 +1,111 @@ +use crate::config::ProviderConfig; +use crate::error::LlmError; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::Duration; + +/// OpenAI-compatible configuration (supports OpenAI, Groq, OpenRouter, etc.) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenAICompatibleConfig { + pub api_key: String, + pub organization: Option, + pub project: Option, + pub base_url: Option, + pub timeout_seconds: u64, +} + +/// Type alias for backwards compatibility +pub type OpenAIConfig = OpenAICompatibleConfig; + +impl OpenAICompatibleConfig { + pub fn new(api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + organization: None, + project: None, + base_url: None, + timeout_seconds: 30, + } + } + + pub fn groq(api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + organization: None, + project: None, + base_url: Some("https://api.groq.com/openai/v1".to_string()), + timeout_seconds: 30, + } + } + + pub fn openrouter(api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + organization: None, + project: None, + base_url: Some("https://openrouter.ai/api/v1".to_string()), + timeout_seconds: 30, + } + } + + pub fn with_organization(mut self, org: impl Into) -> Self { + self.organization = Some(org.into()); + self + } + + pub fn with_project(mut self, project: impl Into) -> Self { + self.project = Some(project.into()); + self + } + + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = Some(base_url.into()); + self + } +} + +impl ProviderConfig for OpenAICompatibleConfig { + fn api_key(&self) -> &str { + &self.api_key + } + + fn base_url(&self) -> &str { + self.base_url + .as_deref() + .unwrap_or("https://api.openai.com/v1") + } + + fn timeout(&self) -> Duration { + Duration::from_secs(self.timeout_seconds) + } + + fn headers(&self) -> HashMap { + let mut headers = HashMap::new(); 
+ headers.insert( + "Authorization".to_string(), + format!("Bearer {}", self.api_key), + ); + headers.insert("Content-Type".to_string(), "application/json".to_string()); + + if let Some(org) = &self.organization { + headers.insert("OpenAI-Organization".to_string(), org.clone()); + } + + if let Some(project) = &self.project { + headers.insert("OpenAI-Project".to_string(), project.clone()); + } + + headers + } + + fn validate(&self) -> Result<(), LlmError> { + if self.api_key.is_empty() { + return Err(LlmError::configuration("API key is required")); + } + + // Relaxed validation: don't require 'sk-' prefix since Groq and OpenRouter use different formats + // OpenAI keys start with 'sk-', Groq uses 'gsk_', OpenRouter uses different format + + Ok(()) + } +} diff --git a/crates/rullm-core/src/providers/openai_compatible.rs b/crates/rullm-core/src/providers/openai_compatible/mod.rs similarity index 97% rename from crates/rullm-core/src/providers/openai_compatible.rs rename to crates/rullm-core/src/providers/openai_compatible/mod.rs index e0c9a105..0fe42db7 100644 --- a/crates/rullm-core/src/providers/openai_compatible.rs +++ b/crates/rullm-core/src/providers/openai_compatible/mod.rs @@ -1,3 +1,7 @@ +pub mod config; + +pub use config::{OpenAICompatibleConfig, OpenAIConfig}; + use crate::compat_types::{ ChatMessage, ChatRequest, ChatResponse, ChatRole, ChatStreamEvent, TokenUsage, }; @@ -46,7 +50,7 @@ pub mod identities { /// Generic OpenAI-compatible provider implementation #[derive(Clone)] pub struct OpenAICompatibleProvider { - config: crate::config::OpenAICompatibleConfig, + config: OpenAICompatibleConfig, client: Client, identity: ProviderIdentity, } @@ -54,7 +58,7 @@ pub struct OpenAICompatibleProvider { impl OpenAICompatibleProvider { /// Create a new OpenAI-compatible provider with custom identity pub fn new( - config: crate::config::OpenAICompatibleConfig, + config: OpenAICompatibleConfig, identity: ProviderIdentity, ) -> Result { config.validate()?; @@ -67,17 
+71,17 @@ impl OpenAICompatibleProvider { } /// Create an OpenAI provider - pub fn openai(config: crate::config::OpenAICompatibleConfig) -> Result { + pub fn openai(config: OpenAICompatibleConfig) -> Result { Self::new(config, identities::OPENAI) } /// Create a Groq provider - pub fn groq(config: crate::config::OpenAICompatibleConfig) -> Result { + pub fn groq(config: OpenAICompatibleConfig) -> Result { Self::new(config, identities::GROQ) } /// Create an OpenRouter provider - pub fn openrouter(config: crate::config::OpenAICompatibleConfig) -> Result { + pub fn openrouter(config: OpenAICompatibleConfig) -> Result { Self::new(config, identities::OPENROUTER) } From 07aa957aff6781e902526e4086fa4fdafde4cfe8 Mon Sep 17 00:00:00 2001 From: lambda Date: Sun, 21 Dec 2025 00:15:20 +0530 Subject: [PATCH 07/11] feat(oauth): fix anthropic oauth by spoofing claude code - Prepend Claude Code system block with cache control to OAuth requests - Add required state parameter to token exchange endpoint - Include anthropic-version header for OAuth requests --- crates/rullm-cli/src/cli_client.rs | 35 ++ crates/rullm-cli/src/oauth/anthropic.rs | 13 +- .../src/providers/anthropic/config.rs | 2 +- .../src/providers/anthropic/types.rs | 29 ++ scripts/anthropic_oauth.py | 390 ++++++++++++++++++ scripts/test_token.rs | 58 +++ 6 files changed, 523 insertions(+), 4 deletions(-) create mode 100755 scripts/anthropic_oauth.py create mode 100644 scripts/test_token.rs diff --git a/crates/rullm-cli/src/cli_client.rs b/crates/rullm-cli/src/cli_client.rs index 4fe6b45f..e22a692d 100644 --- a/crates/rullm-cli/src/cli_client.rs +++ b/crates/rullm-cli/src/cli_client.rs @@ -13,6 +13,29 @@ use rullm_core::providers::openai_compatible::{ use rullm_core::providers::{AnthropicClient, GoogleClient, OpenAIClient}; use std::pin::Pin; +/// Claude Code identification text for OAuth requests +const CLAUDE_CODE_SPOOF_TEXT: &str = "You are Claude Code, Anthropic's official CLI for Claude."; + +/// Prepend Claude Code 
system block to an existing system prompt (for OAuth requests) +fn prepend_claude_code_system( + existing: Option, +) -> rullm_core::providers::anthropic::SystemPrompt { + use rullm_core::providers::anthropic::{SystemBlock, SystemPrompt}; + + let spoof_block = SystemBlock::text_with_cache(CLAUDE_CODE_SPOOF_TEXT); + + match existing { + None => SystemPrompt::Blocks(vec![spoof_block]), + Some(SystemPrompt::Text(text)) => { + SystemPrompt::Blocks(vec![spoof_block, SystemBlock::text(text)]) + } + Some(SystemPrompt::Blocks(mut blocks)) => { + blocks.insert(0, spoof_block); + SystemPrompt::Blocks(blocks) + } + } +} + /// Simple configuration for CLI adapter #[derive(Debug, Clone, Default)] pub struct CliConfig { @@ -31,6 +54,7 @@ pub enum CliClient { client: AnthropicClient, model: String, config: CliConfig, + is_oauth: bool, }, Google { client: GoogleClient, @@ -78,6 +102,7 @@ impl CliClient { client, model: model.into(), config, + is_oauth: use_oauth, }) } @@ -163,6 +188,7 @@ impl CliClient { client, model, config, + is_oauth, } => { use rullm_core::providers::anthropic::{Message, MessagesRequest}; @@ -174,6 +200,10 @@ impl CliClient { request.temperature = Some(temp); } + if *is_oauth { + request.system = Some(prepend_claude_code_system(request.system.take())); + } + let response = client.messages(request).await?; let content = response .content @@ -319,6 +349,7 @@ impl CliClient { client, model, config, + is_oauth, } => { use rullm_core::providers::anthropic::{Message, MessagesRequest}; @@ -339,6 +370,10 @@ impl CliClient { request.temperature = Some(temp); } + if *is_oauth { + request.system = Some(prepend_claude_code_system(request.system.take())); + } + let stream = client.messages_stream(request).await?; Ok(Box::pin(stream.filter_map(|event_result| async move { match event_result { diff --git a/crates/rullm-cli/src/oauth/anthropic.rs b/crates/rullm-cli/src/oauth/anthropic.rs index 19b8513b..3f03d75f 100644 --- a/crates/rullm-cli/src/oauth/anthropic.rs +++ 
b/crates/rullm-cli/src/oauth/anthropic.rs @@ -64,6 +64,7 @@ struct TokenRequest<'a> { code: &'a str, redirect_uri: &'a str, code_verifier: &'a str, + state: &'a str, } impl AnthropicOAuth { @@ -127,9 +128,14 @@ impl AnthropicOAuth { } } - // Exchange code for tokens + // Exchange code for tokens (state is required by Anthropic's token endpoint) let credential = self - .exchange_code(&callback.code, &redirect_uri, &pkce.verifier) + .exchange_code( + &callback.code, + &redirect_uri, + &pkce.verifier, + &pkce.verifier, + ) .await?; println!("Authentication successful!"); @@ -142,6 +148,7 @@ impl AnthropicOAuth { code: &str, redirect_uri: &str, code_verifier: &str, + state: &str, ) -> Result { let request_body = TokenRequest { grant_type: "authorization_code", @@ -149,12 +156,12 @@ impl AnthropicOAuth { code, redirect_uri, code_verifier, + state, }; let client = reqwest::Client::new(); let response = client .post(self.token_url) - // Anthropic expects JSON payloads for the token exchange. .json(&request_body) .send() .await diff --git a/crates/rullm-core/src/providers/anthropic/config.rs b/crates/rullm-core/src/providers/anthropic/config.rs index 40fe80d8..baf54655 100644 --- a/crates/rullm-core/src/providers/anthropic/config.rs +++ b/crates/rullm-core/src/providers/anthropic/config.rs @@ -56,7 +56,6 @@ impl ProviderConfig for AnthropicConfig { if self.use_oauth { // OAuth: use Bearer token + required beta headers - // Note: OpenCode doesn't send anthropic-version for OAuth requests headers.insert( "Authorization".to_string(), format!("Bearer {}", self.api_key), @@ -65,6 +64,7 @@ impl ProviderConfig for AnthropicConfig { "anthropic-beta".to_string(), "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14".to_string(), ); + headers.insert("anthropic-version".to_string(), "2023-06-01".to_string()); } else { // API key: use x-api-key header headers.insert("x-api-key".to_string(), self.api_key.clone()); diff --git 
a/crates/rullm-core/src/providers/anthropic/types.rs b/crates/rullm-core/src/providers/anthropic/types.rs index e152315a..f0a9c54f 100644 --- a/crates/rullm-core/src/providers/anthropic/types.rs +++ b/crates/rullm-core/src/providers/anthropic/types.rs @@ -143,6 +143,35 @@ pub struct CacheControl { pub cache_type: String, // "ephemeral" } +impl CacheControl { + /// Create an ephemeral cache control + pub fn ephemeral() -> Self { + Self { + cache_type: "ephemeral".to_string(), + } + } +} + +impl SystemBlock { + /// Create a text system block + pub fn text(text: impl Into) -> Self { + Self { + block_type: "text".to_string(), + text: text.into(), + cache_control: None, + } + } + + /// Create a text system block with ephemeral cache control + pub fn text_with_cache(text: impl Into) -> Self { + Self { + block_type: "text".to_string(), + text: text.into(), + cache_control: Some(CacheControl::ephemeral()), + } + } +} + /// Request metadata #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Metadata { diff --git a/scripts/anthropic_oauth.py b/scripts/anthropic_oauth.py new file mode 100755 index 00000000..71685e79 --- /dev/null +++ b/scripts/anthropic_oauth.py @@ -0,0 +1,390 @@ +#!/usr/bin/env python3 +""" +Simple Anthropic OAuth flow test script. +Implements PKCE OAuth to get access token and test API calls. +Caches auth code and tokens to avoid repeated logins. 
+""" + +import argparse +import base64 +import hashlib +import http.server +import json +import os +import secrets +import socketserver +import time +import urllib.parse +import urllib.request +import webbrowser +from pathlib import Path +from typing import Optional + +# OAuth Configuration +CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" +SCOPES = "org:create_api_key user:profile user:inference" +REDIRECT_URI = "http://localhost:8765/callback" +AUTH_URL = "https://claude.ai/oauth/authorize" +TOKEN_URL = "https://console.anthropic.com/v1/oauth/token" +API_URL = "https://api.anthropic.com/v1/messages" + +# Callback server port +PORT = 8765 + +# Cache files +CACHE_DIR = Path.home() / ".cache" / "anthropic-oauth-test" +TOKEN_CACHE_FILE = CACHE_DIR / "tokens.json" +AUTH_CODE_CACHE_FILE = CACHE_DIR / "auth_code.json" + +# Token expiry buffer (refresh 5 minutes before expiry) +EXPIRY_BUFFER_SECONDS = 5 * 60 + + +def generate_pkce() -> tuple[str, str]: + """Generate PKCE code_verifier and code_challenge (S256).""" + verifier_bytes = secrets.token_bytes(64) + code_verifier = base64.urlsafe_b64encode(verifier_bytes).rstrip(b"=").decode("ascii") + challenge_hash = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = base64.urlsafe_b64encode(challenge_hash).rstrip(b"=").decode("ascii") + return code_verifier, code_challenge + + +def load_cached_auth_code() -> Optional[dict]: + """Load cached auth code and verifier.""" + if not AUTH_CODE_CACHE_FILE.exists(): + return None + try: + with open(AUTH_CODE_CACHE_FILE, "r") as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + return None + + +def save_auth_code(auth_code: str, code_verifier: str) -> None: + """Save auth code and verifier for retry.""" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + data = { + "auth_code": auth_code, + "code_verifier": code_verifier, + "timestamp": time.time(), + } + with open(AUTH_CODE_CACHE_FILE, "w") as f: + json.dump(data, f, indent=2) + 
os.chmod(AUTH_CODE_CACHE_FILE, 0o600) + print(f" Auth code cached to {AUTH_CODE_CACHE_FILE}") + + +def load_cached_tokens() -> Optional[dict]: + """Load tokens from cache file if they exist and are valid.""" + if not TOKEN_CACHE_FILE.exists(): + return None + try: + with open(TOKEN_CACHE_FILE, "r") as f: + cached = json.load(f) + expires_at = cached.get("expires_at", 0) + if time.time() >= expires_at - EXPIRY_BUFFER_SECONDS: + print(" Cached token expired or expiring soon") + return None + return cached + except (json.JSONDecodeError, IOError) as e: + print(f" Failed to load cache: {e}") + return None + + +def save_tokens(token_response: dict) -> None: + """Save tokens to cache file.""" + CACHE_DIR.mkdir(parents=True, exist_ok=True) + expires_in = token_response.get("expires_in", 3600) + token_response["expires_at"] = time.time() + expires_in + with open(TOKEN_CACHE_FILE, "w") as f: + json.dump(token_response, f, indent=2) + os.chmod(TOKEN_CACHE_FILE, 0o600) + print(f" Tokens cached to {TOKEN_CACHE_FILE}") + + +def refresh_token(refresh_token: str) -> dict: + """Refresh the access token using the refresh token.""" + import requests + + data = { + "grant_type": "refresh_token", + "client_id": CLIENT_ID, + "refresh_token": refresh_token, + } + response = requests.post(TOKEN_URL, json=data) + response.raise_for_status() + return response.json() + + +class CallbackHandler(http.server.BaseHTTPRequestHandler): + """Handle OAuth callback.""" + auth_code: Optional[str] = None + state: Optional[str] = None + error: Optional[str] = None + + @classmethod + def reset(cls): + cls.auth_code = None + cls.state = None + cls.error = None + + def do_GET(self): + parsed = urllib.parse.urlparse(self.path) + if parsed.path != "/callback": + self.send_response(404) + self.end_headers() + return + + params = urllib.parse.parse_qs(parsed.query) + if "error" in params: + CallbackHandler.error = params.get("error", ["unknown"])[0] + error_desc = params.get("error_description", ["No 
description"])[0] + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(f"<html><body><h1>Error: {CallbackHandler.error}</h1><p>{error_desc}</p></body></html>".encode()) + return + + CallbackHandler.auth_code = params.get("code", [None])[0] + CallbackHandler.state = params.get("state", [None])[0] + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write(b"<html><body><h1>Authorization successful!</h1><p>You can close this window.</p></body></html>
") + + def log_message(self, format, *args): + pass + + +def wait_for_callback() -> tuple[Optional[str], Optional[str]]: + """Start server and wait for callback.""" + socketserver.TCPServer.allow_reuse_address = True + with socketserver.TCPServer(("", PORT), CallbackHandler) as httpd: + httpd.timeout = 300 + httpd.handle_request() + return CallbackHandler.auth_code, CallbackHandler.state + + +def exchange_code_for_token(code: str, code_verifier: str, state: str) -> dict: + """Exchange authorization code for tokens.""" + import requests + + data = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "redirect_uri": REDIRECT_URI, + "code_verifier": code_verifier, + "state": state, # Required! The state from the callback + } + + response = requests.post(TOKEN_URL, json=data) + if not response.ok: + print(f" [ERROR] {response.status_code}: {response.text}") + response.raise_for_status() + + return response.json() + + +def test_api_call(access_token: str) -> dict: + """Test API call with the access token.""" + data = { + "model": "claude-opus-4-5-20251101", + "max_tokens": 10000, + "system": [ + { + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + "cache_control": { + "type": "ephemeral" + } + }, + ], + "messages": [ + {"role": "assistant", "content": "Say hello in exactly 5 words."}, + ], + } + headers = { + "Authorization": f"Bearer {access_token}", + "anthropic-version": "2023-06-01", + "anthropic-beta": "oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14", + "Content-Type": "application/json", + "user-agent": "ai-sdk/anthropic/2.0.50 ai-sdk/provider-utils/3.0.18 runtime/bun/1.3.5" + } + req = urllib.request.Request( + API_URL, + data=json.dumps(data).encode("utf-8"), + headers=headers, + method="POST", + ) + with urllib.request.urlopen(req) as response: + return json.loads(response.read().decode("utf-8")) + + +def do_oauth_flow() -> 
tuple[Optional[str], Optional[str]]: + """Perform the OAuth flow to get auth code. Returns (auth_code, code_verifier).""" + CallbackHandler.reset() + + print("\n[1] Generating PKCE credentials...") + code_verifier, code_challenge = generate_pkce() + print(f" Code verifier: {code_verifier}") + print(f" Code challenge: {code_challenge}") + + print("\n[2] Building authorization URL...") + auth_params = { + "code": "true", + "client_id": CLIENT_ID, + "response_type": "code", + "redirect_uri": REDIRECT_URI, + "scope": SCOPES, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "state": code_verifier, + } + auth_url = f"{AUTH_URL}?{urllib.parse.urlencode(auth_params)}" + print(f" URL: {auth_url}") + + print("\n[3] Opening browser for authorization...") + print(f" Waiting for callback on port {PORT}...") + + webbrowser.open(auth_url) + auth_code, state = wait_for_callback() + + if CallbackHandler.error: + print(f"\n[ERROR] Authorization failed: {CallbackHandler.error}") + return None, None + + if not auth_code: + print("\n[ERROR] No authorization code received") + return None, None + + print(f" Received authorization code: {auth_code}") + print(f" Received state: {state}") + + # Cache the auth code for retrying token exchange + save_auth_code(auth_code, code_verifier) + + return auth_code, code_verifier + + +def main(): + parser = argparse.ArgumentParser(description="Anthropic OAuth test") + parser.add_argument("--retry-token", action="store_true", + help="Retry token exchange with cached auth code") + parser.add_argument("--login", action="store_true", + help="Force new login even if cached") + parser.add_argument("--clear", action="store_true", + help="Clear all cached data") + parser.add_argument("--auth-only", action="store_true", + help="Only get auth code, don't exchange for token") + parser.add_argument("--test-api", action="store_true", + help="Test API call after getting token") + args = parser.parse_args() + + print("=" * 60) + 
print("Anthropic OAuth Flow Test") + print("=" * 60) + + if args.clear: + print("\n[*] Clearing cache...") + if TOKEN_CACHE_FILE.exists(): + TOKEN_CACHE_FILE.unlink() + print(f" Deleted {TOKEN_CACHE_FILE}") + if AUTH_CODE_CACHE_FILE.exists(): + AUTH_CODE_CACHE_FILE.unlink() + print(f" Deleted {AUTH_CODE_CACHE_FILE}") + print(" Cache cleared.") + return + + # Check for cached tokens first (unless forcing login or retry) + if not args.login and not args.retry_token: + print("\n[*] Checking for cached tokens...") + cached = load_cached_tokens() + if cached: + expires_at = cached.get("expires_at", 0) + remaining = int(expires_at - time.time()) + access_token = cached.get("access_token", "") + refresh_tok = cached.get("refresh_token", "") + print(f" Found valid cached token!") + print(f" Access token: {access_token[:50]}...") + print(f" Refresh token: {refresh_tok[:50]}...") + print(f" Expires in: {remaining} seconds ({remaining // 3600}h {(remaining % 3600) // 60}m)") + + # Test API call with cached token if requested + if args.test_api: + print("\n[*] Testing API call with cached token...") + try: + api_response = test_api_call(access_token) + print(" API call successful!") + print(f" Model: {api_response.get('model', 'N/A')}") + content = api_response.get("content", []) + if content: + print(f" Response: {content[0].get('text', 'N/A')}") + except urllib.error.HTTPError as e: + print(f"\n[ERROR] API call failed: {e.code}") + print(f" Response: {e.read().decode('utf-8')}") + else: + print(f"\n Use --login to force re-authentication") + return + + # Try to get auth code (from cache if --retry-token, otherwise new login) + if args.retry_token: + print("\n[*] Loading cached auth code...") + cached_auth = load_cached_auth_code() + if not cached_auth: + print("[ERROR] No cached auth code found. 
Run without --retry-token first.") + return + auth_code = cached_auth["auth_code"] + code_verifier = cached_auth["code_verifier"] + print(f" Auth code: {auth_code}") + print(f" Code verifier: {code_verifier}") + else: + auth_code, code_verifier = do_oauth_flow() + if not auth_code: + print("\n[ERROR] Failed to get authorization code") + return + + # If auth-only, stop here + if args.auth_only: + print("\n[*] Auth code obtained. Use --retry-token to exchange for tokens.") + print(f" Cached at: {AUTH_CODE_CACHE_FILE}") + return + + # Exchange code for tokens + # Note: state is the same as code_verifier (we sent state=code_verifier in auth URL) + print("\n[4] Exchanging code for tokens...") + try: + token_response = exchange_code_for_token(auth_code, code_verifier, state=code_verifier) + print(" Token exchange successful!") + print(f" Access token: {token_response.get('access_token', 'N/A')[:50]}...") + print(f" Refresh token: {token_response.get('refresh_token', 'N/A')[:50]}...") + print(f" Expires in: {token_response.get('expires_in', 'N/A')} seconds") + save_tokens(token_response) + except Exception as e: + print(f"\n[ERROR] Token exchange failed: {e}") + print("\n Hint: Auth codes are single-use. 
If this failed, run --login to get a new one.") + return + + # Test API call (optional) + if args.test_api: + print("\n[5] Testing API call...") + access_token = token_response.get("access_token") + try: + api_response = test_api_call(access_token) + print(" API call successful!") + print(f" Model: {api_response.get('model', 'N/A')}") + content = api_response.get("content", []) + if content: + print(f" Response: {content[0].get('text', 'N/A')}") + except urllib.error.HTTPError as e: + print(f"\n[ERROR] API call failed: {e.code}") + print(f" Response: {e.read().decode('utf-8')}") + return + + print("\n" + "=" * 60) + print("Token obtained successfully!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/scripts/test_token.rs b/scripts/test_token.rs new file mode 100644 index 00000000..0316e458 --- /dev/null +++ b/scripts/test_token.rs @@ -0,0 +1,58 @@ +//! Quick test of token exchange + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize)] +struct TokenRequest<'a> { + grant_type: &'static str, + client_id: &'a str, + code: &'a str, + redirect_uri: &'a str, + code_verifier: &'a str, +} + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + refresh_token: String, + expires_in: u64, + token_type: String, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Read from auth_code.json + let cache_path = dirs::home_dir() + .unwrap() + .join(".cache/anthropic-oauth-test/auth_code.json"); + + let cached: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(&cache_path)?)?; + + let auth_code = cached["auth_code"].as_str().unwrap(); + let code_verifier = cached["code_verifier"].as_str().unwrap(); + + println!("Auth code: {}", auth_code); + println!("Code verifier: {}", code_verifier); + + let request_body = TokenRequest { + grant_type: "authorization_code", + client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e", + code: auth_code, + redirect_uri: "http://localhost:8765/callback", + 
code_verifier, + }; + + println!("Request JSON: {}", serde_json::to_string_pretty(&request_body)?); + + let client = reqwest::Client::new(); + let response = client + .post("https://console.anthropic.com/v1/oauth/token") + .json(&request_body) + .send() + .await?; + + println!("Response status: {}", response.status()); + println!("Response body: {}", response.text().await?); + + Ok(()) +} From 426124ddb0e84b515d20a6ade453971222ef6055 Mon Sep 17 00:00:00 2001 From: lambda Date: Sun, 21 Dec 2025 02:31:31 +0530 Subject: [PATCH 08/11] fix lints --- Cargo.lock | 1 - crates/rullm-cli/Cargo.toml | 1 - crates/rullm-cli/src/auth.rs | 25 +- crates/rullm-cli/src/commands/auth.rs | 57 +++-- crates/rullm-cli/src/oauth/mod.rs | 2 - crates/rullm-cli/src/oauth/openai.rs | 334 -------------------------- crates/rullm-cli/src/oauth/server.rs | 10 - 7 files changed, 56 insertions(+), 374 deletions(-) delete mode 100644 crates/rullm-cli/src/oauth/openai.rs diff --git a/Cargo.lock b/Cargo.lock index 185c6b1a..ff8613cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1743,7 +1743,6 @@ dependencies = [ "rullm-core", "serde", "serde_json", - "serde_urlencoded", "sha2", "strum 0.27.2", "strum_macros 0.27.2", diff --git a/crates/rullm-cli/Cargo.toml b/crates/rullm-cli/Cargo.toml index 90bcc976..96fdf084 100644 --- a/crates/rullm-cli/Cargo.toml +++ b/crates/rullm-cli/Cargo.toml @@ -38,7 +38,6 @@ rand = "0.9" base64 = "0.22" webbrowser = "1.0" reqwest = { version = "0.12", features = ["json"] } -serde_urlencoded = "0.7" urlencoding = "2.1" hex = "0.4" diff --git a/crates/rullm-cli/src/auth.rs b/crates/rullm-cli/src/auth.rs index ba552a7f..27b159d8 100644 --- a/crates/rullm-cli/src/auth.rs +++ b/crates/rullm-cli/src/auth.rs @@ -1,7 +1,7 @@ //! Authentication credential management for rullm. //! //! Supports multiple authentication methods per provider: -//! - OAuth (for Claude Max/Pro, ChatGPT Plus/Pro subscriptions) +//! - OAuth (for Claude Max/Pro subscriptions) //! 
- API keys (traditional method) use crate::provider::Provider; @@ -334,17 +334,13 @@ pub async fn get_token_with_type( /// Refresh an OAuth token for a specific provider. async fn refresh_oauth_token(provider: &Provider, refresh_token: &str) -> Result { - use crate::oauth::{anthropic::AnthropicOAuth, openai::OpenAIOAuth}; + use crate::oauth::anthropic::AnthropicOAuth; match provider { Provider::Anthropic => { let oauth = AnthropicOAuth::new(); oauth.refresh_token(refresh_token).await } - Provider::OpenAI => { - let oauth = OpenAIOAuth::new(); - oauth.refresh_token(refresh_token).await - } _ => Err(anyhow::anyhow!( "Provider {} does not support OAuth token refresh", provider @@ -441,4 +437,21 @@ mod tests { assert_eq!(info.source, CredentialSource::File); assert_eq!(info.credential.get_token(), "file-key"); } + + #[tokio::test] + async fn test_get_token_with_type_api_is_not_oauth() { + let temp_dir = TempDir::new().unwrap(); + let mut config = AuthConfig { + openai: Some(Credential::api("sk-test".to_string())), + ..Default::default() + }; + + let (token, is_oauth) = + get_token_with_type(&Provider::OpenAI, &mut config, temp_dir.path()) + .await + .unwrap(); + + assert_eq!(token, "sk-test"); + assert!(!is_oauth); + } } diff --git a/crates/rullm-cli/src/commands/auth.rs b/crates/rullm-cli/src/commands/auth.rs index 9fb0d630..377385e1 100644 --- a/crates/rullm-cli/src/commands/auth.rs +++ b/crates/rullm-cli/src/commands/auth.rs @@ -6,7 +6,7 @@ use etcetera::BaseStrategy; use strum::IntoEnumIterator; use crate::auth::{self, AuthConfig, Credential}; -use crate::oauth::{anthropic::AnthropicOAuth, openai::OpenAIOAuth}; +use crate::oauth::anthropic::AnthropicOAuth; use crate::output::OutputLevel; use crate::provider::Provider; @@ -64,8 +64,9 @@ impl AuthArgs { oauth.login().await? } Provider::OpenAI => { - let oauth = OpenAIOAuth::new(); - oauth.login().await? + anyhow::bail!( + "OpenAI OAuth login is not implemented yet. Use API key instead." 
+ ); } _ => { anyhow::bail!( @@ -160,28 +161,44 @@ fn select_provider() -> Result { fn select_auth_method(provider: &Provider) -> Result { use std::io::{self, Write}; - // Check if OAuth is available for this provider - let oauth_available = matches!(provider, Provider::Anthropic | Provider::OpenAI); + match provider { + Provider::Anthropic => { + println!("\n? Select authentication method"); + println!(" 1) OAuth (subscription-based access)"); + println!(" 2) API Key"); - if !oauth_available { - // Only API key available - return Ok(AuthMethod::ApiKey); - } + print!("\nEnter number (1-2): "); + io::stdout().flush()?; - println!("\n? Select authentication method"); - println!(" 1) OAuth (subscription-based access)"); - println!(" 2) API Key"); + let mut input = String::new(); + io::stdin().read_line(&mut input)?; - print!("\nEnter number (1-2): "); - io::stdout().flush()?; + match input.trim() { + "1" => Ok(AuthMethod::OAuth), + "2" => Ok(AuthMethod::ApiKey), + _ => anyhow::bail!("Invalid selection"), + } + } + Provider::OpenAI => { + println!("\n? Select authentication method"); + println!(" 1) OAuth (not implemented)"); + println!(" 2) API Key"); - let mut input = String::new(); - io::stdin().read_line(&mut input)?; + print!("\nEnter number (1-2): "); + io::stdout().flush()?; + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; - match input.trim() { - "1" => Ok(AuthMethod::OAuth), - "2" => Ok(AuthMethod::ApiKey), - _ => anyhow::bail!("Invalid selection"), + match input.trim() { + "1" => { + anyhow::bail!("OpenAI OAuth login is not implemented yet. 
Use API key instead.") + } + "2" => Ok(AuthMethod::ApiKey), + _ => anyhow::bail!("Invalid selection"), + } + } + _ => Ok(AuthMethod::ApiKey), } } diff --git a/crates/rullm-cli/src/oauth/mod.rs b/crates/rullm-cli/src/oauth/mod.rs index 2c1dc24a..e4115303 100644 --- a/crates/rullm-cli/src/oauth/mod.rs +++ b/crates/rullm-cli/src/oauth/mod.rs @@ -6,7 +6,5 @@ mod pkce; mod server; pub mod anthropic; -pub mod openai; pub use pkce::PkceChallenge; -pub use server::CallbackServer; diff --git a/crates/rullm-cli/src/oauth/openai.rs b/crates/rullm-cli/src/oauth/openai.rs deleted file mode 100644 index b9fca518..00000000 --- a/crates/rullm-cli/src/oauth/openai.rs +++ /dev/null @@ -1,334 +0,0 @@ -//! OpenAI OAuth flow implementation. -//! -//! Supports ChatGPT Plus/Pro subscription authentication via OAuth discovery. - -use anyhow::{Context, Result}; -use serde::{Deserialize, Serialize}; -use std::time::Duration; - -use super::{CallbackServer, PkceChallenge}; -use crate::auth::Credential; - -/// OpenAI OAuth configuration. -pub struct OpenAIOAuth { - /// Issuer URL for discovery - pub issuer_url: &'static str, - /// Public client ID for Codex CLI (OpenAI consumer OAuth) - pub client_id: &'static str, - /// Authorization scopes requested - pub scopes: &'static [&'static str], - /// Callback port - pub callback_port: u16, -} - -impl Default for OpenAIOAuth { - fn default() -> Self { - Self { - issuer_url: "https://auth.openai.com", - client_id: "app_EMoamEEZ73f0CkXaXp7hrann", - scopes: &["openid", "profile", "email", "offline_access"], - callback_port: 1455, - } - } -} - -/// OpenID Connect discovery document (subset used). -#[derive(Debug, Deserialize)] -#[allow(dead_code)] -struct OpenIdConfiguration { - authorization_endpoint: String, - token_endpoint: String, - issuer: String, -} - -/// Token response from OpenAI OAuth. 
-#[derive(Debug, Deserialize)] -struct TokenResponse { - access_token: String, - refresh_token: Option, - expires_in: u64, - #[allow(dead_code)] - token_type: String, -} - -/// Token exchange request body. -#[derive(Debug, Serialize)] -struct TokenRequest<'a> { - grant_type: &'static str, - client_id: &'a str, - code: &'a str, - redirect_uri: &'a str, - code_verifier: &'a str, -} - -/// Token refresh request body. -#[derive(Debug, Serialize)] -struct RefreshRequest<'a> { - grant_type: &'static str, - client_id: &'a str, - refresh_token: &'a str, -} - -impl OpenAIOAuth { - /// Create a new OpenAI OAuth handler with default configuration. - pub fn new() -> Self { - Self::default() - } - - /// Discover OAuth endpoints from the authorization server. - async fn discover(&self) -> Result { - let discovery_url = format!("{}/.well-known/openid-configuration", self.issuer_url); - - let client = reqwest::Client::new(); - let response = client - .get(&discovery_url) - .header("Accept", "application/json") - .send() - .await - .with_context(|| format!("Failed to fetch OAuth discovery from {}", discovery_url))?; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - anyhow::bail!("OAuth discovery failed: {} - {}", status, body); - } - - response - .json() - .await - .context("Failed to parse OAuth discovery response") - } - - /// Build the authorization URL for the OAuth flow. 
- fn build_authorization_url( - &self, - authorization_endpoint: &str, - pkce: &PkceChallenge, - state: &str, - redirect_uri: &str, - ) -> String { - let scope = self.scopes.join(" "); - format!( - "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method={}&state={}&id_token_add_organizations=true&codex_cli_simplified_flow=true&originator=codex_cli_rs", - authorization_endpoint, - urlencoding::encode(self.client_id), - urlencoding::encode(redirect_uri), - urlencoding::encode(&scope), - urlencoding::encode(&pkce.challenge), - pkce.method(), - urlencoding::encode(state) - ) - } - - /// Start the OAuth flow and return the credential on success. - /// - /// This will: - /// 1. Discover OAuth endpoints - /// 2. Start a local callback server - /// 3. Open the browser to the authorization URL - /// 4. Wait for the callback with the authorization code - /// 5. Exchange the code for tokens - pub async fn login(&self) -> Result { - // Discover OAuth endpoints - println!("Discovering OpenAI OAuth endpoints..."); - let _metadata = self.discover().await?; - - // Generate PKCE challenge - let pkce = PkceChallenge::generate(); - - // Generate state for CSRF protection - let state = generate_state(); - - // Start callback server - let server = - CallbackServer::new(self.callback_port).context("Failed to start callback server")?; - - // OpenAI's public client is registered with /auth/callback for localhost. - let redirect_uri = server.redirect_uri_with_path("/auth/callback"); - - // Build and open authorization URL - // The public web consumer OAuth lives at /oauth/authorize (different from OIDC discovery value). - let authorization_endpoint = format!("{}/oauth/authorize", self.issuer_url); - // The working token endpoint for consumer OAuth lives on auth.openai.com, not the discovery host. 
- let token_endpoint = format!("{}/oauth/token", self.issuer_url); - let auth_url = - self.build_authorization_url(&authorization_endpoint, &pkce, &state, &redirect_uri); - - println!("Opening browser for OpenAI authentication..."); - webbrowser::open(&auth_url).context("Failed to open browser")?; - - println!("Waiting for authentication (timeout: 5 minutes)..."); - - // Wait for callback - let callback = server - .wait_for_callback(Duration::from_secs(300)) - .context("Failed to receive OAuth callback")?; - - // Verify state - if callback.state.as_deref() != Some(&state) { - anyhow::bail!("State mismatch in OAuth callback (possible CSRF attack)"); - } - - // Exchange code for tokens - let credential = self - .exchange_code( - &token_endpoint, - &callback.code, - &pkce.verifier, - &redirect_uri, - ) - .await?; - - Ok(credential) - } - - /// Exchange authorization code for tokens. - async fn exchange_code( - &self, - token_endpoint: &str, - code: &str, - code_verifier: &str, - redirect_uri: &str, - ) -> Result { - let request_body = TokenRequest { - grant_type: "authorization_code", - client_id: self.client_id, - code, - redirect_uri, - code_verifier, - }; - - let client = reqwest::Client::new(); - let response = client - .post(token_endpoint) - .header("Content-Type", "application/x-www-form-urlencoded") - .body(serde_urlencoded::to_string(&request_body)?) 
- .send() - .await - .context("Failed to send token request")?; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - anyhow::bail!("Token exchange failed: {} - {}", status, body); - } - - let token_response: TokenResponse = response - .json() - .await - .context("Failed to parse token response")?; - - // Calculate expiration timestamp - let expires_at = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64 + token_response.expires_in * 1000) - .unwrap_or(0); - - // OpenAI might not always return a refresh token - let refresh_token = token_response.refresh_token.unwrap_or_default(); - - Ok(Credential::oauth( - token_response.access_token, - refresh_token, - expires_at, - )) - } - - /// Refresh an expired OAuth token. - pub async fn refresh_token(&self, refresh_token: &str) -> Result { - // Use the known consumer token endpoint (discovery token endpoint points to auth0.openai.com and fails). - let token_endpoint = format!("{}/oauth/token", self.issuer_url); - - let request_body = RefreshRequest { - grant_type: "refresh_token", - client_id: self.client_id, - refresh_token, - }; - - let client = reqwest::Client::new(); - let response = client - .post(&token_endpoint) - .header("Content-Type", "application/x-www-form-urlencoded") - .body(serde_urlencoded::to_string(&request_body)?) 
- .send() - .await - .context("Failed to send refresh request")?; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - anyhow::bail!("Token refresh failed: {} - {}", status, body); - } - - let token_response: TokenResponse = response - .json() - .await - .context("Failed to parse refresh response")?; - - let expires_at = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64 + token_response.expires_in * 1000) - .unwrap_or(0); - - let new_refresh_token = token_response - .refresh_token - .unwrap_or_else(|| refresh_token.to_string()); - - Ok(Credential::oauth( - token_response.access_token, - new_refresh_token, - expires_at, - )) - } -} - -/// Generate a random state string for CSRF protection. -fn generate_state() -> String { - use rand::RngCore; - let mut bytes = [0u8; 16]; - rand::rng().fill_bytes(&mut bytes); - hex::encode(bytes) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_default_config() { - let oauth = OpenAIOAuth::new(); - assert_eq!(oauth.callback_port, 1455); - assert_eq!(oauth.issuer_url, "https://auth.openai.com"); - assert_eq!(oauth.client_id, "app_EMoamEEZ73f0CkXaXp7hrann"); - assert_eq!( - oauth.scopes, - ["openid", "profile", "email", "offline_access"] - ); - } - - #[test] - fn test_build_authorization_url() { - let oauth = OpenAIOAuth::new(); - let pkce = PkceChallenge::generate(); - let state = "test-state"; - let redirect_uri = "http://localhost:1455/auth/callback"; - - let url = oauth.build_authorization_url( - "https://auth.openai.com/oauth/authorize", - &pkce, - state, - redirect_uri, - ); - - assert!(url.starts_with("https://auth.openai.com/oauth/authorize")); - assert!(url.contains("response_type=code")); - assert!(url.contains("client_id=app_EMoamEEZ73f0CkXaXp7hrann")); - assert!(url.contains("scope=openid%20profile%20email%20offline_access")); - assert!(url.contains("redirect_uri=")); - 
assert!(url.contains("code_challenge=")); - assert!(url.contains("code_challenge_method=S256")); - assert!(url.contains("state=test-state")); - assert!(url.contains("codex_cli_simplified_flow=true")); - assert!(url.contains("originator=codex_cli_rs")); - } -} diff --git a/crates/rullm-cli/src/oauth/server.rs b/crates/rullm-cli/src/oauth/server.rs index 3a8ab07e..e93a6360 100644 --- a/crates/rullm-cli/src/oauth/server.rs +++ b/crates/rullm-cli/src/oauth/server.rs @@ -43,16 +43,6 @@ impl CallbackServer { format!("http://localhost:{}/callback", self.port) } - /// Build a redirect URI using a custom path (must start with '/'). - pub fn redirect_uri_with_path(&self, path: &str) -> String { - let normalized = if path.starts_with('/') { - path.to_string() - } else { - format!("/{}", path) - }; - format!("http://localhost:{}{}", self.port, normalized) - } - /// Wait for the OAuth callback and extract the authorization code. /// /// This blocks until a request is received or the timeout is reached. From b043ac4be6de533146c41758afc3fcb41af7e520 Mon Sep 17 00:00:00 2001 From: lambda Date: Sun, 21 Dec 2025 10:58:39 +0530 Subject: [PATCH 09/11] feat(models): add list_models for Anthropic and implement update_models --- crates/rullm-cli/src/cli_client.rs | 12 ++++ crates/rullm-cli/src/commands/models.rs | 38 +++++++++--- .../rullm-core/examples/test_all_providers.rs | 31 +++++----- .../src/providers/anthropic/client.rs | 61 +++++++++++++++++++ 4 files changed, 117 insertions(+), 25 deletions(-) diff --git a/crates/rullm-cli/src/cli_client.rs b/crates/rullm-cli/src/cli_client.rs index e22a692d..743105fb 100644 --- a/crates/rullm-cli/src/cli_client.rs +++ b/crates/rullm-cli/src/cli_client.rs @@ -486,6 +486,18 @@ impl CliClient { } } + /// Get available models for the provider + pub async fn available_models(&self) -> Result, LlmError> { + match self { + Self::OpenAI { client, .. } => client.list_models().await, + Self::Anthropic { client, .. 
} => client.list_models().await, + Self::Google { client, .. } => client.list_models().await, + Self::Groq { client, .. } | Self::OpenRouter { client, .. } => { + client.available_models().await + } + } + } + /// Get provider name pub fn provider_name(&self) -> &'static str { match self { diff --git a/crates/rullm-cli/src/commands/models.rs b/crates/rullm-cli/src/commands/models.rs index a226cff0..5954980b 100644 --- a/crates/rullm-cli/src/commands/models.rs +++ b/crates/rullm-cli/src/commands/models.rs @@ -217,7 +217,7 @@ pub fn clear_models_cache(cli_config: &CliConfig, output_level: OutputLevel) -> } pub async fn update_models( - _cli_config: &mut CliConfig, + cli_config: &mut CliConfig, client: &CliClient, output_level: OutputLevel, ) -> Result<(), LlmError> { @@ -229,15 +229,37 @@ pub async fn update_models( output_level, ); - // TODO: Implement models() method on CliClient - // For now, just return an error - crate::output::error( - "Fetching models not yet implemented for new client architecture", + let mut models = client.available_models().await.map_err(|e| { + crate::output::error(&format!("Failed to fetch models: {e}"), output_level); + e + })?; + + if models.is_empty() { + crate::output::error("No models returned by provider", output_level); + return Err(LlmError::model("No models returned by provider".to_string())); + } + + models.sort(); + models.dedup(); + + _cache_models(cli_config, client.provider_name(), &models).map_err(|e| { + crate::output::error( + &format!("Failed to update models cache: {e}"), + output_level, + ); + LlmError::unknown(e.to_string()) + })?; + + crate::output::success( + &format!( + "Updated {} models for {}", + models.len(), + client.provider_name() + ), output_level, ); - Err(LlmError::unknown( - "Models fetching not yet implemented".to_string(), - )) + + Ok(()) } fn _cache_models(cli_config: &CliConfig, provider_name: &str, models: &[String]) -> Result<()> { diff --git a/crates/rullm-core/examples/test_all_providers.rs 
b/crates/rullm-core/examples/test_all_providers.rs index e6fa52c0..67a948a3 100644 --- a/crates/rullm-core/examples/test_all_providers.rs +++ b/crates/rullm-core/examples/test_all_providers.rs @@ -37,12 +37,18 @@ async fn main() -> Result<(), Box> { // 2. Test Anthropic Provider println!("🔍 Testing Anthropic Provider..."); match test_anthropic_provider().await { - Ok(model_count) => { + Ok(models) => { + println!("✅ Anthropic: Found {} models", models.len()); println!( - "✅ Anthropic: API is working ({} models available)", - model_count + " Models (first 5): {}", + models + .iter() + .take(5) + .map(|s| s.as_str()) + .collect::>() + .join(", ") ); - results.push(("Anthropic", true, model_count)); + results.push(("Anthropic", true, models.len())); } Err(e) => { println!("❌ Anthropic: Failed - {e}"); @@ -133,7 +139,7 @@ async fn test_openai_provider() -> Result, Box Result> { +async fn test_anthropic_provider() -> Result, Box> { let api_key = env::var("ANTHROPIC_API_KEY") .map_err(|_| "ANTHROPIC_API_KEY environment variable not set")?; @@ -146,19 +152,10 @@ async fn test_anthropic_provider() -> Result> Err(e) => println!(" Health check: ⚠️ Warning - {e}"), } - // Anthropic doesn't have a list models endpoint, so we'll just return known models - let known_models = vec![ - "claude-3-5-sonnet-20241022", - "claude-3-opus-20240229", - "claude-3-sonnet-20240229", - "claude-3-haiku-20240307", - ]; - - for model in &known_models { - println!(" Known model: {model}"); - } + // Get available models + let models = client.list_models().await?; - Ok(known_models.len()) + Ok(models) } async fn test_google_provider() -> Result, Box> { diff --git a/crates/rullm-core/src/providers/anthropic/client.rs b/crates/rullm-core/src/providers/anthropic/client.rs index 4bda4938..4b21bad0 100644 --- a/crates/rullm-core/src/providers/anthropic/client.rs +++ b/crates/rullm-core/src/providers/anthropic/client.rs @@ -186,6 +186,67 @@ impl AnthropicClient { Ok(tokens) } + /// List available models + 
pub async fn list_models(&self) -> Result, LlmError> { + let url = format!("{}/v1/models", self.base_url); + + let mut req = self.client.get(&url); + for (key, value) in self.config.headers() { + req = req.header(key, value); + } + + let response = req.send().await?; + + if !response.status().is_success() { + return Err(LlmError::api( + "anthropic", + "Failed to fetch available models", + Some(response.status().to_string()), + None, + )); + } + + let json: serde_json::Value = response.json().await.map_err(|e| { + LlmError::serialization("Failed to parse models response", Box::new(e)) + })?; + + let models_array = json + .get("data") + .and_then(|d| d.as_array()) + .or_else(|| json.get("models").and_then(|m| m.as_array())) + .ok_or_else(|| { + LlmError::serialization( + "Invalid models response format", + Box::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Missing data array", + )), + ) + })?; + + let models: Vec = models_array + .iter() + .filter_map(|m| { + m.get("id") + .and_then(|id| id.as_str()) + .or_else(|| m.get("name").and_then(|name| name.as_str())) + .or_else(|| m.get("model").and_then(|model| model.as_str())) + .map(|s| s.to_string()) + }) + .collect(); + + if models.is_empty() { + return Err(LlmError::api( + "anthropic", + "No models found in response", + None, + None, + )); + } + + Ok(models) + } + /// Health check pub async fn health_check(&self) -> Result<(), LlmError> { // Anthropic doesn't have a dedicated health endpoint From 3ee53859148dac3d99bbb33a86692ee8d468321e Mon Sep 17 00:00:00 2001 From: lambda Date: Sun, 21 Dec 2025 11:35:41 +0530 Subject: [PATCH 10/11] chore: remove unused code and fix formatting Remove dead_code get_or_refresh_token function and delete test script. 
--- crates/rullm-cli/src/auth.rs | 50 ---------------- crates/rullm-cli/src/commands/models.rs | 9 ++- .../src/providers/anthropic/client.rs | 7 ++- scripts/test_token.rs | 58 ------------------- 4 files changed, 8 insertions(+), 116 deletions(-) delete mode 100644 scripts/test_token.rs diff --git a/crates/rullm-cli/src/auth.rs b/crates/rullm-cli/src/auth.rs index 27b159d8..fe78c7a1 100644 --- a/crates/rullm-cli/src/auth.rs +++ b/crates/rullm-cli/src/auth.rs @@ -233,56 +233,6 @@ pub fn get_credential(provider: &Provider, auth_config: &AuthConfig) -> Option Result { - // Get credential info - let info = get_credential(provider, auth_config) - .ok_or_else(|| anyhow::anyhow!("No credential found for {}", provider))?; - - // If from environment, just return the token (can't refresh env vars) - if matches!(info.source, CredentialSource::Environment(_)) { - return Ok(info.credential.get_token().to_string()); - } - - // Check if OAuth token is expired - if info.credential.is_expired() { - if let Some(refresh_tok) = info.credential.refresh_token() { - // Attempt to refresh - eprintln!("OAuth token expired, refreshing..."); - match refresh_oauth_token(provider, refresh_tok).await { - Ok(new_credential) => { - let token = new_credential.get_token().to_string(); - auth_config.set(provider, new_credential); - auth_config.save(config_base_path)?; - eprintln!("Token refreshed successfully."); - return Ok(token); - } - Err(e) => { - // Refresh failed - user needs to re-authenticate - return Err(anyhow::anyhow!( - "OAuth token expired and refresh failed: {}. Please run 'rullm auth login {}'", - e, - provider - )); - } - } - } - } - - Ok(info.credential.get_token().to_string()) -} - /// Get token and credential type for a provider. /// /// Returns (token, is_oauth) where is_oauth is true if the credential is OAuth. 
diff --git a/crates/rullm-cli/src/commands/models.rs b/crates/rullm-cli/src/commands/models.rs index 5954980b..f66439ab 100644 --- a/crates/rullm-cli/src/commands/models.rs +++ b/crates/rullm-cli/src/commands/models.rs @@ -236,17 +236,16 @@ pub async fn update_models( if models.is_empty() { crate::output::error("No models returned by provider", output_level); - return Err(LlmError::model("No models returned by provider".to_string())); + return Err(LlmError::model( + "No models returned by provider".to_string(), + )); } models.sort(); models.dedup(); _cache_models(cli_config, client.provider_name(), &models).map_err(|e| { - crate::output::error( - &format!("Failed to update models cache: {e}"), - output_level, - ); + crate::output::error(&format!("Failed to update models cache: {e}"), output_level); LlmError::unknown(e.to_string()) })?; diff --git a/crates/rullm-core/src/providers/anthropic/client.rs b/crates/rullm-core/src/providers/anthropic/client.rs index 4b21bad0..bd08bc21 100644 --- a/crates/rullm-core/src/providers/anthropic/client.rs +++ b/crates/rullm-core/src/providers/anthropic/client.rs @@ -206,9 +206,10 @@ impl AnthropicClient { )); } - let json: serde_json::Value = response.json().await.map_err(|e| { - LlmError::serialization("Failed to parse models response", Box::new(e)) - })?; + let json: serde_json::Value = response + .json() + .await + .map_err(|e| LlmError::serialization("Failed to parse models response", Box::new(e)))?; let models_array = json .get("data") diff --git a/scripts/test_token.rs b/scripts/test_token.rs deleted file mode 100644 index 0316e458..00000000 --- a/scripts/test_token.rs +++ /dev/null @@ -1,58 +0,0 @@ -//! 
Quick test of token exchange - -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize)] -struct TokenRequest<'a> { - grant_type: &'static str, - client_id: &'a str, - code: &'a str, - redirect_uri: &'a str, - code_verifier: &'a str, -} - -#[derive(Debug, Deserialize)] -struct TokenResponse { - access_token: String, - refresh_token: String, - expires_in: u64, - token_type: String, -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - // Read from auth_code.json - let cache_path = dirs::home_dir() - .unwrap() - .join(".cache/anthropic-oauth-test/auth_code.json"); - - let cached: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(&cache_path)?)?; - - let auth_code = cached["auth_code"].as_str().unwrap(); - let code_verifier = cached["code_verifier"].as_str().unwrap(); - - println!("Auth code: {}", auth_code); - println!("Code verifier: {}", code_verifier); - - let request_body = TokenRequest { - grant_type: "authorization_code", - client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e", - code: auth_code, - redirect_uri: "http://localhost:8765/callback", - code_verifier, - }; - - println!("Request JSON: {}", serde_json::to_string_pretty(&request_body)?); - - let client = reqwest::Client::new(); - let response = client - .post("https://console.anthropic.com/v1/oauth/token") - .json(&request_body) - .send() - .await?; - - println!("Response status: {}", response.status()); - println!("Response body: {}", response.text().await?); - - Ok(()) -} From 7c310d1772bcd55019c76a5cad48728818800587 Mon Sep 17 00:00:00 2001 From: lambda Date: Sun, 21 Dec 2025 11:36:14 +0530 Subject: [PATCH 11/11] add chat completion spec --- spec/chat-completion.md | 947 +++++++++++++++++++++++++++++++++++++++ spec/chat-completion2.md | 289 ++++++++++++ 2 files changed, 1236 insertions(+) create mode 100644 spec/chat-completion.md create mode 100644 spec/chat-completion2.md diff --git a/spec/chat-completion.md b/spec/chat-completion.md new file mode 100644 index 
00000000..4185c723 --- /dev/null +++ b/spec/chat-completion.md @@ -0,0 +1,947 @@ +# OpenAI Chat Completions API Specification + +A comprehensive specification for implementing a client for the OpenAI Chat Completions API. + +--- + +## Table of Contents + +1. [Overview](#overview) +2. [Endpoint](#endpoint) +3. [Authentication](#authentication) +4. [Request Structure](#request-structure) +5. [Message Types](#message-types) +6. [Request Parameters](#request-parameters) +7. [Response Structure](#response-structure) +8. [Streaming](#streaming) +9. [Tool/Function Calling](#toolfunction-calling) +10. [Structured Outputs](#structured-outputs) +11. [Vision (Image Input)](#vision-image-input) +12. [Audio Input/Output](#audio-inputoutput) +13. [Web Search](#web-search) +14. [Predicted Outputs](#predicted-outputs) +15. [Error Handling](#error-handling) +16. [Rate Limiting](#rate-limiting) +17. [Model Reference](#model-reference) + +--- + +## Overview + +The Chat Completions API generates model responses from a list of messages comprising a conversation. It supports text, images, and audio as inputs and can generate text, audio, and tool calls as outputs. + +**Base URL:** `https://api.openai.com/v1` + +--- + +## Endpoint + +### Create Chat Completion + +``` +POST /chat/completions +``` + +Creates a model response for the given chat conversation. + +--- + +## Authentication + +All requests require an API key in the `Authorization` header: + +``` +Authorization: Bearer YOUR_API_KEY +``` + +Optional organization header: +``` +OpenAI-Organization: YOUR_ORG_ID +``` + +--- + +## Request Structure + +```json +{ + "model": "gpt-4o", + "messages": [...], + // ... additional parameters +} +``` + +--- + +## Message Types + +Messages are the core of the Chat Completions API. Each message has a `role` and `content`. + +### System Message +Sets context and behavioral instructions for the model. +```json +{ + "role": "system", + "content": "You are a helpful assistant." 
+} +``` + +### Developer Message (O-Series Models Only) +Used instead of system messages for reasoning models (o1, o3, o4-mini). +```json +{ + "role": "developer", + "content": "Instructions for the model behavior." +} +``` +**Note:** When using system message with o1/o3 models, it's treated as a developer message. Don't mix both in the same request. + +### User Message +Represents input from the user. + +**Text only:** +```json +{ + "role": "user", + "content": "What is the capital of France?" +} +``` + +**With images (multimodal):** +```json +{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image.jpg", + "detail": "auto" // "low", "high", or "auto" + } + } + ] +} +``` + +**With audio:** +```json +{ + "role": "user", + "content": [ + {"type": "text", "text": "What is being said?"}, + { + "type": "input_audio", + "input_audio": { + "data": "", + "format": "wav" // or "mp3" + } + } + ] +} +``` + +### Assistant Message +Model-generated response or injected assistant context. +```json +{ + "role": "assistant", + "content": "The capital of France is Paris." +} +``` + +**With tool calls:** +```json +{ + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"Paris\"}" + } + } + ] +} +``` + +### Tool Message +Result of a tool/function call. 
+```json +{ + "role": "tool", + "tool_call_id": "call_abc123", + "content": "{\"temperature\": 22, \"condition\": \"sunny\"}" +} +``` + +--- + +## Request Parameters + +### Required Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `model` | string | Model ID to use (e.g., `gpt-4o`, `gpt-4o-mini`, `o1`, `o3`) | +| `messages` | array | List of messages comprising the conversation | + +### Optional Parameters - Sampling & Generation + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `max_completion_tokens` | integer | null | Maximum tokens to generate (replaces deprecated `max_tokens`). Required for o-series models. | +| `max_tokens` | integer | null | **Deprecated.** Use `max_completion_tokens` instead. Not compatible with o-series models. | +| `temperature` | number | 1.0 | Sampling temperature (0-2). Higher = more random. **Not supported for reasoning models.** | +| `top_p` | number | 1.0 | Nucleus sampling parameter. **Not supported for reasoning models.** | +| `n` | integer | 1 | Number of completions to generate. | +| `stop` | string/array | null | Up to 4 sequences where the model stops generating. | +| `presence_penalty` | number | 0.0 | Penalizes tokens based on presence in text (-2.0 to 2.0). **Not supported for reasoning models.** | +| `frequency_penalty` | number | 0.0 | Penalizes tokens based on frequency in text (-2.0 to 2.0). **Not supported for reasoning models.** | +| `logit_bias` | object | null | Map of token IDs to bias values (-100 to 100). **Not supported for reasoning models.** | + +### Optional Parameters - Advanced + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `seed` | integer | null | For deterministic sampling (best effort). Monitor `system_fingerprint` for backend changes. | +| `logprobs` | boolean | false | Return log probabilities of output tokens. 
**Not supported for reasoning models.** | +| `top_logprobs` | integer | null | Number of most likely tokens to return (0-20). Requires `logprobs: true`. | +| `user` | string | null | Unique end-user identifier for abuse detection. | +| `stream` | boolean | false | Enable streaming responses via SSE. | +| `stream_options` | object | null | Streaming options: `{"include_usage": true}` to get token usage in final chunk. | + +### Optional Parameters - Response Format + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `response_format` | object | null | Controls output format. See [Structured Outputs](#structured-outputs). | + +### Optional Parameters - Tools & Functions + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `tools` | array | null | List of tools/functions the model may call. See [Tool Calling](#toolfunction-calling). | +| `tool_choice` | string/object | "auto" | Controls tool use: `"none"`, `"auto"`, `"required"`, or specific function. | +| `parallel_tool_calls` | boolean | true | Allow parallel function calls. Disable for gpt-4.1-nano-2025-04-14. | + +### Optional Parameters - Reasoning Models (O-Series) + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `reasoning_effort` | string | "medium" | Reasoning depth: `"minimal"` (gpt-5 only), `"low"`, `"medium"`, `"high"`. | + +### Optional Parameters - Multimodal + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `modalities` | array | ["text"] | Output types to generate: `["text"]`, `["text", "audio"]`. Audio requires audio-preview models. | +| `audio` | object | null | Audio output config: `{"voice": "alloy", "format": "wav"}`. Voices: `alloy`, `echo`, `shimmer`. Formats: `wav`, `mp3`, `pcm16` (streaming only). 
| + +### Optional Parameters - Service & Storage + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `service_tier` | string | "auto" | Processing tier: `"auto"`, `"default"`, `"flex"` (50% cheaper, slower), `"priority"` (lower latency, premium). | +| `store` | boolean | false | Store completion for model distillation/evals. | +| `metadata` | object | null | Up to 16 key-value pairs (keys: max 64 chars, values: max 512 chars). | + +### Optional Parameters - Web Search + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `web_search_options` | object | null | Enable web search for search-preview models. See [Web Search](#web-search). | + +### Optional Parameters - Predicted Outputs + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `prediction` | object | null | Predicted output for faster responses. See [Predicted Outputs](#predicted-outputs). | + +--- + +## Response Structure + +### Non-Streaming Response + +```json +{ + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1702685778, + "model": "gpt-4o-2024-08-06", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The capital of France is Paris.", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 13, + "completion_tokens": 7, + "total_tokens": 20, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0 + } + }, + "service_tier": "default" +} +``` + +### Response Fields + +| Field | Type | Description | +|-------|------|-------------| +| `id` | string | Unique identifier for the completion | +| `object` | string | Always `"chat.completion"` | +| `created` | integer | 
Unix timestamp of creation | +| `model` | string | Model used for completion | +| `system_fingerprint` | string | Backend configuration fingerprint (for determinism tracking with `seed`) | +| `choices` | array | List of completion choices | +| `usage` | object | Token usage statistics | +| `service_tier` | string | Service tier used (may differ from requested) | + +### Choice Object + +| Field | Type | Description | +|-------|------|-------------| +| `index` | integer | Index of this choice | +| `message` | object | The generated message | +| `finish_reason` | string | Why generation stopped | +| `logprobs` | object/null | Log probability information (if requested) | + +### Finish Reasons + +| Value | Description | +|-------|-------------| +| `stop` | Natural stop or stop sequence reached | +| `length` | Max token limit reached | +| `tool_calls` | Model requested tool execution | +| `content_filter` | Content filtered by safety systems | +| `function_call` | **Deprecated.** Function call requested | + +### Usage Object + +| Field | Type | Description | +|-------|------|-------------| +| `prompt_tokens` | integer | Tokens in the prompt | +| `completion_tokens` | integer | Tokens generated | +| `total_tokens` | integer | Total tokens used | +| `prompt_tokens_details.cached_tokens` | integer | Cached prompt tokens | +| `prompt_tokens_details.audio_tokens` | integer | Audio input tokens | +| `completion_tokens_details.reasoning_tokens` | integer | Hidden reasoning tokens (o-series) | +| `completion_tokens_details.audio_tokens` | integer | Audio output tokens | +| `completion_tokens_details.accepted_prediction_tokens` | integer | Accepted predicted tokens | +| `completion_tokens_details.rejected_prediction_tokens` | integer | Rejected predicted tokens | + +--- + +## Streaming + +Enable streaming with `stream: true`. Responses are sent as Server-Sent Events (SSE). 
+ +### Request +```json +{ + "model": "gpt-4o", + "messages": [...], + "stream": true, + "stream_options": {"include_usage": true} +} +``` + +### Response Headers +``` +Content-Type: text/event-stream +Transfer-Encoding: chunked +``` + +### Chunk Format + +```json +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","system_fingerprint":"fp_44709d6fcb","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + +data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}} + +data: [DONE] +``` + +### Key Differences from Non-Streaming + +| Aspect | Non-Streaming | Streaming | +|--------|---------------|-----------| +| Object type | `chat.completion` | `chat.completion.chunk` | +| Content field | `message.content` | `delta.content` | +| Role field | `message.role` | `delta.role` (first chunk only) | +| Usage | Always included | Only with `stream_options.include_usage: true`, in final chunk | +| Termination | Single response | `data: [DONE]` event | + +--- + +## Tool/Function Calling + +Tools allow the model to call external functions. The API returns tool call requests; execution is your responsibility. 
+ +### Defining Tools + +```json +{ + "model": "gpt-4o", + "messages": [...], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "strict": true, + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City and state, e.g., San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location", "unit"], + "additionalProperties": false + } + } + } + ], + "tool_choice": "auto" +} +``` + +### Tool Choice Options + +| Value | Description | +|-------|-------------| +| `"none"` | Don't call any tools | +| `"auto"` | Model decides whether to call tools | +| `"required"` | Must call at least one tool | +| `{"type": "function", "function": {"name": "my_func"}}` | Force specific function | + +### Strict Mode + +Setting `strict: true` ensures function calls reliably adhere to the schema. Requirements: +- `additionalProperties: false` for each object +- All fields in `properties` must be in `required` +- Optional fields: add `null` as a type option + +### Tool Call Response + +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\":\"Paris\",\"unit\":\"celsius\"}" + } + } + ] + }, + "finish_reason": "tool_calls" + }] +} +``` + +### Providing Tool Results + +```json +{ + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "What's the weather in Paris?"}, + { + "role": "assistant", + "content": null, + "tool_calls": [{"id": "call_abc123", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"Paris\",\"unit\":\"celsius\"}"}}] + }, + { + "role": "tool", + "tool_call_id": "call_abc123", + "content": "{\"temperature\": 22, \"condition\": \"sunny\"}" + } + ] +} +``` + +--- + +## 
Structured Outputs + +Force the model to output valid JSON matching a schema. + +### JSON Mode (Basic) + +```json +{ + "model": "gpt-4o", + "messages": [...], + "response_format": {"type": "json_object"} +} +``` + +### JSON Schema Mode (Strict) + +```json +{ + "model": "gpt-4o", + "messages": [...], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "person_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "email": {"type": ["string", "null"]} + }, + "required": ["name", "age", "email"], + "additionalProperties": false + } + } + } +} +``` + +### Schema Requirements + +- Root cannot be `anyOf` type +- All fields must be in `required` +- `additionalProperties: false` required +- Some JSON Schema keywords not supported (e.g., `format` for dates) + +### Refusals + +When the model refuses to generate structured output: +```json +{ + "choices": [{ + "message": { + "role": "assistant", + "content": null, + "refusal": "I cannot provide information about that topic." + } + }] +} +``` + +--- + +## Vision (Image Input) + +### Image URL + +```json +{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image.jpg", + "detail": "auto" + } + } + ] +} +``` + +### Base64 Image + +```json +{ + "type": "image_url", + "image_url": { + "url": "data:image/jpeg;base64,{base64_encoded_data}", + "detail": "high" + } +} +``` + +### Detail Levels + +| Value | Description | Token Cost | +|-------|-------------|------------| +| `low` | Fixed low resolution | Base cost (e.g., 85 tokens) | +| `high` | Full resolution, tiled | Variable based on dimensions | +| `auto` | Model decides | Variable | + +### Token Calculation (High Detail) + +1. Scale to fit 2048x2048 (maintaining aspect ratio) +2. Scale shortest side to 768px +3. Count 512px tiles +4. 
Cost = (tiles × 170) + 85 tokens (approximate) + +--- + +## Audio Input/Output + +Requires `gpt-4o-audio-preview` or `gpt-4o-mini-audio-preview` models. + +### Request with Audio Output + +```json +{ + "model": "gpt-4o-audio-preview", + "modalities": ["text", "audio"], + "audio": { + "voice": "alloy", + "format": "wav" + }, + "messages": [...] +} +``` + +### Audio Configuration + +| Field | Options | Description | +|-------|---------|-------------| +| `voice` | `alloy`, `echo`, `shimmer` | Voice for audio output | +| `format` | `wav`, `mp3`, `pcm16` | Output format. `pcm16` only for streaming. | + +### Audio Input Message + +```json +{ + "role": "user", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": "", + "format": "wav" + } + } + ] +} +``` + +**Max audio file size:** 20 MB + +--- + +## Web Search + +Available with search-preview models: `gpt-4o-search-preview`, `gpt-4o-mini-search-preview`, `gpt-5-search-api`. + +### Basic Request + +```json +{ + "model": "gpt-4o-search-preview", + "web_search_options": {}, + "messages": [ + {"role": "user", "content": "What happened in tech news today?"} + ] +} +``` + +### With Options + +```json +{ + "model": "gpt-4o-search-preview", + "web_search_options": { + "search_context_size": "medium", + "user_location": { + "type": "approximate", + "approximate": { + "country": "US", + "city": "San Francisco", + "region": "California", + "timezone": "America/Los_Angeles" + } + } + }, + "messages": [...] 
+} +``` + +### Search Context Size + +| Value | Description | +|-------|-------------| +| `low` | Faster, cheaper, less accurate | +| `medium` | Default balance | +| `high` | More thorough, slower, more expensive | + +### Response Annotations + +```json +{ + "choices": [{ + "message": { + "content": "According to recent news [1]...", + "annotations": [ + { + "type": "url_citation", + "start_index": 28, + "end_index": 31, + "url": "https://example.com/article", + "title": "Article Title" + } + ] + } + }] +} +``` + +--- + +## Predicted Outputs + +Reduce latency when most output is known ahead of time. Supported on GPT-4o and GPT-4o-mini. + +### Request + +```json +{ + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "Update the email address to john@new.com"}, + {"role": "user", "content": "Current data: {\"name\": \"John\", \"email\": \"john@old.com\"}"} + ], + "prediction": { + "type": "content", + "content": "{\"name\": \"John\", \"email\": \"john@old.com\"}" + } +} +``` + +### Limitations + +- Text-only (no audio/images) +- Not compatible with: `n > 1`, `logprobs`, `presence_penalty > 0` +- Rejected tokens are still billed +- May increase costs if predictions are poor matches + +### Tracking Prediction Usage + +Check `usage.completion_tokens_details`: +- `accepted_prediction_tokens`: Tokens that matched +- `rejected_prediction_tokens`: Tokens that didn't match (still billed) + +--- + +## Error Handling + +### Error Response Format + +```json +{ + "error": { + "message": "Invalid API key", + "type": "invalid_request_error", + "param": null, + "code": "invalid_api_key" + } +} +``` + +### HTTP Status Codes + +| Code | Description | Action | +|------|-------------|--------| +| 400 | Bad Request | Fix request parameters | +| 401 | Unauthorized | Check API key | +| 403 | Forbidden | Check permissions/organization | +| 404 | Not Found | Check model ID/endpoint | +| 429 | Rate Limited | Implement exponential backoff | +| 500 | Server Error | Retry with 
backoff | +| 502 | Bad Gateway | Retry with backoff | +| 503 | Service Unavailable | Retry with backoff | + +### Common Error Codes + +| Code | Description | +|------|-------------| +| `invalid_api_key` | API key is invalid | +| `insufficient_quota` | Quota exceeded | +| `rate_limit_exceeded` | Too many requests | +| `model_not_found` | Model doesn't exist or no access | +| `context_length_exceeded` | Input too long | +| `invalid_request_error` | Malformed request | +| `content_policy_violation` | Content filtered | + +### Retry Strategy + +``` +1. Wait: min(2^attempt * 1000ms, 60000ms) + random_jitter +2. Max attempts: 5 +3. On 429: Check x-ratelimit-reset-* headers +4. On 5xx: Always retry +``` + +--- + +## Rate Limiting + +### Rate Limit Types + +| Type | Description | +|------|-------------| +| RPM | Requests per minute | +| RPD | Requests per day | +| TPM | Tokens per minute | +| TPD | Tokens per day | +| IPM | Images per minute | + +### Response Headers + +| Header | Description | +|--------|-------------| +| `x-ratelimit-limit-requests` | Max requests allowed | +| `x-ratelimit-limit-tokens` | Max tokens allowed | +| `x-ratelimit-remaining-requests` | Requests remaining | +| `x-ratelimit-remaining-tokens` | Tokens remaining | +| `x-ratelimit-reset-requests` | Time until request limit resets | +| `x-ratelimit-reset-tokens` | Time until token limit resets | + +### Token Refill + +Tokens refill on a rolling 60-second window, not all at once. 
+ +--- + +## Model Reference + +### GPT-4o Series + +| Model | Context Window | Max Output | Notes | +|-------|----------------|------------|-------| +| `gpt-4o` | 128k | 16k | Latest GPT-4o | +| `gpt-4o-2024-08-06` | 128k | 16k | Structured outputs support | +| `gpt-4o-mini` | 128k | 16k | Faster, cheaper | +| `gpt-4o-audio-preview` | 128k | 16k | Audio I/O support | +| `gpt-4o-search-preview` | 128k | 16k | Web search | + +### GPT-4.1 Series + +| Model | Context Window | Max Output | Notes | +|-------|----------------|------------|-------| +| `gpt-4.1` | 1M | ~32k | Extended context | +| `gpt-4.1-mini` | 1M | ~32k | Smaller, faster | +| `gpt-4.1-nano` | 1M | ~32k | Smallest | + +### O-Series (Reasoning Models) + +| Model | Context Window | Max Output | Notes | +|-------|----------------|------------|-------| +| `o1` | 200k | 100k | Advanced reasoning | +| `o1-mini` | 128k | 65k | Faster reasoning | +| `o1-preview` | 128k | 32k | Preview version | +| `o3` | 200k | 100k | Latest reasoning | +| `o3-mini` | 200k | 100k | Smaller reasoning | +| `o4-mini` | 200k | 100k | Newest mini | + +### GPT-5 Series + +| Model | Context Window | Max Output | Notes | +|-------|----------------|------------|-------| +| `gpt-5` | 128k+ | Varies | Flagship model | +| `gpt-5-mini` | 128k | 16k | Smaller variant | +| `gpt-5-search-api` | 128k | 16k | Web search | + +### Parameter Support by Model Type + +| Parameter | GPT-4o/4.1 | O-Series | Notes | +|-----------|------------|----------|-------| +| `temperature` | ✅ | ❌ | | +| `top_p` | ✅ | ❌ | | +| `presence_penalty` | ✅ | ❌ | | +| `frequency_penalty` | ✅ | ❌ | | +| `logprobs` | ✅ | ❌ | | +| `logit_bias` | ✅ | ❌ | | +| `max_tokens` | ✅ | ❌ | Deprecated | +| `max_completion_tokens` | ✅ | ✅ | Preferred | +| `reasoning_effort` | ❌ | ✅ | | +| `developer` role | ❌ | ✅ | Use `system` for GPT models | + +--- + +## Implementation Checklist + +### Core Features +- [ ] Basic text completion +- [ ] Multi-turn conversations +- [ ] System/developer 
messages +- [ ] Streaming with SSE parsing +- [ ] Token usage tracking + +### Advanced Features +- [ ] Tool/function calling +- [ ] Structured outputs (JSON mode & JSON Schema) +- [ ] Vision (image input) +- [ ] Audio input/output +- [ ] Web search integration +- [ ] Predicted outputs + +### Error Handling +- [ ] HTTP error responses +- [ ] Rate limit headers parsing +- [ ] Exponential backoff retry +- [ ] Content filter detection + +### Model-Specific +- [ ] Reasoning model parameters +- [ ] Service tier selection +- [ ] Seed for reproducibility + +--- + +## Sources + +- [Chat Completions API Reference](https://platform.openai.com/docs/api-reference/chat/) +- [Chat Completions Guide](https://platform.openai.com/docs/guides/chat-completions) +- [Function Calling Guide](https://platform.openai.com/docs/guides/function-calling) +- [Structured Outputs Guide](https://platform.openai.com/docs/guides/structured-outputs) +- [Images and Vision Guide](https://platform.openai.com/docs/guides/images-vision) +- [Audio and Speech Guide](https://platform.openai.com/docs/guides/audio) +- [Web Search Guide](https://platform.openai.com/docs/guides/tools-web-search) +- [Reasoning Models Guide](https://platform.openai.com/docs/guides/reasoning) +- [Error Codes Guide](https://platform.openai.com/docs/guides/error-codes) +- [Rate Limits Guide](https://platform.openai.com/docs/guides/rate-limits) +- [Models Overview](https://platform.openai.com/docs/models) +- [OpenAI Cookbook - Using Logprobs](https://cookbook.openai.com/examples/using_logprobs) +- [OpenAI Cookbook - Reproducible Outputs](https://cookbook.openai.com/examples/reproducible_outputs_with_the_seed_parameter) + diff --git a/spec/chat-completion2.md b/spec/chat-completion2.md new file mode 100644 index 00000000..30d9c1f6 --- /dev/null +++ b/spec/chat-completion2.md @@ -0,0 +1,289 @@ +# OpenAI Chat Completions API + +This document specifies the **Chat Completions** REST API under `POST /v1/chat/completions` (plus 
stored-completion CRUD endpoints), including request/response shapes, streaming (SSE), and related features (tools/function calling, structured outputs, multimodal inputs, audio, and web search).
+
+---
+
+## 1) Positioning / when to use Chat Completions
+OpenAI’s docs now emphasize the **Responses API** for many new integrations, but **Chat Completions** remains a supported API surface and is still the correct choice if you specifically want `/v1/chat/completions` semantics and object shapes.
+
+---
+
+## 2) Base URL, auth, headers, and versioning
+
+### 2.1 Base URL
+- Base URL: `https://api.openai.com/v1`
+
+### 2.2 Authentication
+- Header auth: `Authorization: Bearer <OPENAI_API_KEY>`
+
+### 2.3 Core request headers
+- `Content-Type: application/json` for JSON requests.
+- Optional account routing headers:
+  - `OpenAI-Organization: <organization_id>`
+  - `OpenAI-Project: <project_id>`
+
+### 2.4 Debugging / observability headers
+- Responses include a request identifier header you should log (notably `x-request-id`) to correlate failures and support requests.
+- Responses include timing information like `openai-processing-ms`.
+
+### 2.5 Backward compatibility expectations
+- OpenAI publishes a backward-compatibility policy for the API; clients should be resilient to additive fields and new enum values.
+
+**Client requirement:** Implement JSON decoding as forward-compatible: ignore unknown fields; do not exhaustively match enums without a fallback.
+
+---
+
+## 3) Error model, retries, and rate limits
+
+### 3.1 Error response shape (high-level)
+Errors are returned with a top-level `error` object (containing fields like `message`, `type`, and sometimes `param`/`code`), alongside standard HTTP status codes.
+
+### 3.2 Recommended retry policy (client-side)
+- Retry only **transient** failures (commonly `429`, `500`, `503`) using exponential backoff + jitter; do **not** blindly retry `400`/`401`/`403`/`404` because they’re usually permanent for that request.
+- Always log `x-request-id` (and any client request id you add) for diagnostics. + +### 3.3 Rate limit headers +OpenAI documents rate limits and returns `x-ratelimit-*` headers (covering request and token budgets with limit/remaining/reset patterns). Your client should parse and surface these for adaptive throttling. + +--- + +## 4) Endpoint inventory (Chat Completions) + +### 4.1 Create (optionally stream) +- `POST /v1/chat/completions` — generate an assistant response; supports streaming via SSE; can optionally store the completion. + +### 4.2 Stored completion retrieval & management (requires `store: true` on create) +- `GET /v1/chat/completions/{completion_id}` — retrieve a stored completion. +- `GET /v1/chat/completions` — list stored completions (pagination + filters). +- `GET /v1/chat/completions/{completion_id}/messages` — list messages from a stored completion (pagination). +- `POST /v1/chat/completions/{completion_id}` — update metadata on a stored completion. +- `DELETE /v1/chat/completions/{completion_id}` — delete a stored completion. + +--- + +## 5) `POST /v1/chat/completions` — Create + +### 5.1 Primary use cases +- Standard “chat” generation: model responds to a conversation history you provide in `messages`. +- Tool/function calling: model asks your client to call functions; client executes and feeds results back. +- Streaming UI: token-by-token deltas over SSE. +- Structured outputs: force JSON or schema-conformant JSON. +- Multimodal input: images and audio in the conversation (model-dependent). +- Web search models: model performs a web search and returns citations/annotations. +- Persisting outputs for later retrieval: set `store: true` and then use stored completion endpoints. + +--- + +### 5.2 Request body (top-level fields) +**Required** +- `model: string` — model identifier. +- `messages: array` — conversation inputs (text and/or content parts). 
+ +**Common optional fields (sampling / stopping / token limits)** +- `temperature?: number` — sampling temperature. +- `top_p?: number` — nucleus sampling. +- `n?: integer` — number of choices to generate. +- `stop?: string | string[] | null` — stop sequences; docs note model-specific support limitations (e.g., not supported by some “o” models). +- `max_completion_tokens?: integer | null` — cap for generated tokens (including reasoning tokens where applicable). +- `max_tokens?: integer | null` — deprecated; docs note incompatibility with some newer model families. +- `presence_penalty?: number | null`, `frequency_penalty?: number | null` + +**Logprobs** +- `logprobs?: boolean | null` — request logprobs in output. +- `top_logprobs?: integer | null` — number of top tokens to return (when `logprobs` is true). + +**Streaming** +- `stream?: boolean | null` — enable SSE. +- `stream_options?: object | null` — stream options; docs show `include_usage` behavior. + +**Tools / function calling** +- `tools?: array` — tool definitions (notably function tools). +- `tool_choice?: "none" | "auto" | "required" | object` — tool selection policy (including forcing a specific tool). +- `parallel_tool_calls?: boolean` — whether model may emit multiple tool calls in one turn. +- `functions` / `function_call` — deprecated tool/function fields. + +**Structured outputs** +- `response_format?: object` — JSON mode and schema mode. + +**Multimodal output** +- `modalities?: string[]` — request output modalities (model-dependent). +- `audio?: object | null` — audio output settings when requesting audio modality. + +**Other** +- `reasoning_effort?: string` — reasoning control for supported models (docs note model-specific constraints). +- `verbosity?: string` — output verbosity control for supported models. +- `prediction?: object` — “predicted output” optimization payload. +- `web_search_options?: object` — used with web-search models. +- `service_tier?: string` — service tier selection. 
+- `seed?: integer | null` — deprecated determinism hint. +- `logit_bias?: object | null` — token-level biasing. + +**Storing & metadata** +- `store?: boolean | null` — enable stored-completion retrieval; docs warn that some large inputs (e.g., large images) may be dropped when storing. +- `metadata?: object` — user-defined key/value metadata for stored items. +- `user?: string` — deprecated in favor of newer identifiers (docs reference `safety_identifier` and `prompt_cache_key`). + +--- + +### 5.3 `messages[]` — conversation schema (client-facing) +Chat Completions are **stateless**: you send the prior turns each request (or a summarized state). + +**Message object (conceptual)** +- `role: string` — typical roles include `system`, `user`, `assistant`, and tool-related roles (exact accepted roles depend on the API mode and model). +- `content: string | array` — either plain text or a list of typed content parts for multimodal inputs. + +**Content parts (multimodal)** +- Image input uses a content-part pattern documented in the Images/Vision guide for Chat Completions. +- Audio input uses a content-part pattern documented in the Audio guide for Chat Completions. + +**Client requirement:** Model capabilities differ; your client should not hard-code “text only” assumptions, and should treat message `content` as a tagged union. + +--- + +## 6) `POST /v1/chat/completions` — Response + +### 6.1 Non-streaming response object +A successful create returns a **chat completion object** containing identifiers and an array of `choices`. + +**Top-level (commonly present)** +- `id: string` +- `object: "chat.completion"` +- `created: integer` (unix seconds) +- `model: string` +- `choices: array` +- `usage: object` (token accounting) + +**Choice object (conceptual)** +- `index: integer` +- `message: { role, content, ... 
}` +- `finish_reason: string | null` +- `logprobs: object | null` (if requested/supported) + +**Client requirement:** Do not assume a single choice; handle `n > 1` by returning multiple candidate messages. + +--- + +## 7) Streaming (SSE) — `stream: true` + +### 7.1 Transport / framing +Chat Completions streaming uses **Server-Sent Events** (SSE) delivering **chat completion chunk** objects incrementally. + +The OpenAI API reference for legacy streaming explicitly describes “data-only SSE” messages and a terminal `data: [DONE]` sentinel; Chat Completions streaming is documented in terms of chunk objects and deltas. For robust clients, support both: end on connection close and/or `[DONE]` sentinel if present. + +### 7.2 Chunk object shape +Each event contains a `chat.completion.chunk` object with `choices[].delta` carrying incremental text/tool-call deltas. + +Key fields: +- `object: "chat.completion.chunk"` +- `id`, `created`, `model` +- `choices[]: { index, delta, finish_reason? }` +- Optional `usage` when using `stream_options` to include usage. + +### 7.3 Delta accumulation rules (client requirement) +- Concatenate `choices[i].delta.content` fragments in-order. +- Treat tool-call deltas as structured fragments that must be assembled into a complete tool call before execution. +- Terminate a choice when `finish_reason` becomes non-null. + +--- + +## 8) Tools / function calling (Chat Completions) + +### 8.1 Use case +Let the model request structured, executable actions (API calls, database queries, etc.) by emitting tool calls; your client executes them and returns tool outputs back into the conversation. + +### 8.2 Request fields (high-level) +- Provide available tools in `tools`. +- Control selection with `tool_choice`. +- Allow or disallow multiple calls with `parallel_tool_calls`. + +### 8.3 Response behavior (high-level) +- Assistant messages may include tool-call descriptors instead of (or in addition to) normal text content. 
+- Client must translate tool calls into actual executions and then append tool results as subsequent messages and call the API again. + +--- + +## 9) Structured outputs (`response_format`) + +### 9.1 Use case +- Enforce machine-parseable JSON output (basic JSON mode) or schema-conformant JSON (structured outputs) for deterministic integration with downstream code. + +### 9.2 Modes (documented) +- JSON mode via a `response_format` object (older JSON mode). +- Schema-based structured outputs via a `response_format` object (JSON Schema). + +**Client requirement:** Treat `response_format` as a tagged union; do not assume only one sub-variant. + +--- + +## 10) Multimodal inputs (images + audio) and audio outputs + +### 10.1 Image inputs (vision) +- Chat Completions supports image input via content parts as documented in the Images/Vision guide when `api-mode=chat`. +- Client should support: + - Remote URLs + - Base64 “data:” URLs + - Any per-image options described in the guide (e.g., detail level), model-dependent. + +### 10.2 Audio inputs and outputs +- Audio input: send base64-encoded audio as typed content parts (guide documents the required fields). +- Audio output: request audio via `modalities` + `audio` settings (voice/format), then decode audio bytes from the response. + +**Client requirement:** Audio and image support are model-dependent; surface capability errors clearly (do not silently fall back unless the caller asked you to). + +--- + +## 11) Web search (Chat Completions) + +### 11.1 Use case +For web-search-capable models, allow the model to retrieve information from the web before responding, returning citation annotations suitable for UI rendering and attribution. + +### 11.2 Request fields +- Use `web_search_options` alongside a web-search model. + +### 11.3 Response fields (citations / annotations) +- The Web Search tool guide documents citation annotations (including URL citations) and how to render them. 
+ +**Client requirement:** Preserve and expose annotations separately from text so callers can render citations reliably even if the visible text format changes. + +--- + +## 12) Stored completions (`store: true`) — retrieval, listing, message listing, metadata update, delete + +### 12.1 Use case +- Persist responses for later inspection, evaluation, or building UIs that revisit prior outputs without keeping your own transcript store. + +### 12.2 Key behaviors +- `store: true` on create enables subsequent retrieval/listing endpoints. +- Update endpoint is for metadata updates on stored items. +- Messages endpoint returns the stored message list with pagination controls. + +### 12.3 Pagination (high-level) +List endpoints support typical cursor pagination parameters like `after`, plus `limit` and `order`. + +--- + +## 13) Practical client requirements checklist + +### 13.1 HTTP layer +- Keep-alive connections; configurable request timeout; request body size limits; gzip/deflate handling as supported by your HTTP library. + +### 13.2 JSON decoding +- Tolerate unknown fields and enum expansions; treat tagged unions (`content`, `response_format`, tools) as extensible. + +### 13.3 Streaming +- Implement a correct SSE parser: + - parse `data:` lines into JSON chunks + - assemble `delta` fragments per `choice.index` + - handle both “connection close” and `[DONE]` style termination defensively. + +### 13.4 Tool calling +- Support multiple tool calls (including parallel), and tool-call assembly in both non-streaming and streaming modes. + +### 13.5 Errors & rate limits +- Parse error objects; map to typed errors; classify retryable vs permanent. +- Parse `x-ratelimit-*` headers for adaptive throttling and caller visibility. + +--- + +If you want, I can convert this into a single “contract” section (request/response JSON Schema-like tables for every field and nested object) purely as documentation—still no code.