diff --git a/Cargo.lock b/Cargo.lock index 6c5b682..518fb4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -576,6 +576,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + [[package]] name = "futures-macro" version = "0.3.32" @@ -612,8 +618,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", + "futures-io", "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "slab", ] @@ -1044,6 +1053,7 @@ dependencies = [ "cliclack", "console", "ctrlc", + "futures-util", "httpmock", "mime_guess", "reqwest", @@ -1530,12 +1540,14 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", ] @@ -2396,6 +2408,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.94" diff --git a/Cargo.toml b/Cargo.toml index 78f55dc..acd84ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,8 @@ clap_complete = "4.6.5" cliclack = "0.5.4" console = "0.16.3" ctrlc = "3.5.2" -reqwest = { version = "0.12.15", default-features = false, features = ["brotli", "deflate", "gzip", "json", "multipart", "rustls-tls"] } +reqwest = { version = "0.12.15", default-features = false, features = ["brotli", "deflate", "gzip", "json", "multipart", "rustls-tls", "stream"] } +futures-util = "0.3.31" scraper = "0.27.0" serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.150" diff --git a/src/api.rs b/src/api.rs index a139b2c..c18d5a0 100644 --- a/src/api.rs +++ b/src/api.rs @@ -4,6 +4,7 @@ use std::future::Future; use std::path::{Path, PathBuf}; use std::time::Duration; +use futures_util::StreamExt; use reqwest::multipart; use reqwest::{Client, StatusCode, Url, header}; use scraper::{Html, Selector}; @@ -21,28 +22,28 @@ use crate::error::KagiError; use crate::http::{self, map_transport_error}; use crate::local; use crate::parser::{ - parse_assistant_profile_form, parse_assistant_profile_list, parse_assistant_thread_list, - parse_custom_bang_form, parse_custom_bang_list, parse_lens_form, parse_lens_list, - parse_redirect_form, parse_redirect_list, + parse_assistant_model_catalog, parse_assistant_profile_form, parse_assistant_profile_list, + parse_assistant_thread_list, parse_custom_bang_form, parse_custom_bang_list, parse_lens_form, + parse_lens_list, parse_redirect_form, parse_redirect_list, }; #[cfg(test)] use crate::types::ApiMeta; use crate::types::{ AlternativeTranslationsResponse, AskPageRequest, AskPageResponse, AskPageSource, - AssistantMessage, AssistantMeta, AssistantProfileCreateRequest, AssistantProfileDetails, - AssistantProfileSummary, AssistantProfileUpdateRequest, AssistantPromptRequest, - AssistantPromptResponse, AssistantThread, AssistantThreadDeleteResponse, - AssistantThreadExportResponse, AssistantThreadListResponse, AssistantThreadOpenResponse, - AssistantThreadPagination, CustomBangCreateRequest, CustomBangDetails, CustomBangSummary, - CustomBangUpdateRequest, DeletedResourceResponse, EnrichResponse, ExtractPageInput, - ExtractRequest, ExtractResponse, FastGptRequest, FastGptResponse, LensCreateRequest, - LensDetails, LensSummary, LensUpdateRequest, NewsBatchCategories, NewsBatchCategory, - NewsCategoriesResponse, NewsCategoryMetadata, NewsCategoryMetadataList, NewsChaos, - NewsChaosResponse, NewsContentFilterSummary, NewsFilterPresetListEntry, - NewsFilterPresetListResponse, NewsLatestBatch, NewsResolvedCategory, NewsStoriesPayload, - NewsStoriesResponse, NewsStoryContentFilterSummary, RedirectRuleCreateRequest, - RedirectRuleDetails, RedirectRuleSummary, RedirectRuleUpdateRequest, SmallWebFeed, - SubscriberSummarization, SubscriberSummarizeMeta, SubscriberSummarizeRequest, + AssistantMessage, AssistantMeta, AssistantModelCatalog, AssistantProfileCreateRequest, + AssistantProfileDetails, AssistantProfileSummary, AssistantProfileUpdateRequest, + AssistantPromptRequest, AssistantPromptResponse, AssistantPromptStreamEvent, AssistantThread, + AssistantThreadDeleteResponse, AssistantThreadExportResponse, AssistantThreadListResponse, + AssistantThreadOpenResponse, AssistantThreadPagination, CustomBangCreateRequest, + CustomBangDetails, CustomBangSummary, CustomBangUpdateRequest, DeletedResourceResponse, + EnrichResponse, ExtractPageInput, ExtractRequest, ExtractResponse, FastGptRequest, + FastGptResponse, LensCreateRequest, LensDetails, LensSummary, LensUpdateRequest, + NewsBatchCategories, NewsBatchCategory, NewsCategoriesResponse, NewsCategoryMetadata, + NewsCategoryMetadataList, NewsChaos, NewsChaosResponse, NewsContentFilterSummary, + NewsFilterPresetListEntry, NewsFilterPresetListResponse, NewsLatestBatch, NewsResolvedCategory, + NewsStoriesPayload, NewsStoriesResponse, NewsStoryContentFilterSummary, + RedirectRuleCreateRequest, RedirectRuleDetails, RedirectRuleSummary, RedirectRuleUpdateRequest, + SmallWebFeed, SubscriberSummarization, SubscriberSummarizeMeta, SubscriberSummarizeRequest, SubscriberSummarizeResponse, SummarizeRequest, SummarizeResponse, TextAlignmentsResponse, ToggleResourceResponse, TranslateBootstrapMetadata, TranslateCommandRequest, TranslateDetectedLanguage, TranslateOptionState, TranslateResponse, TranslateTextResponse, @@ -781,6 +782,53 @@ pub async fn execute_assistant_prompt( parse_assistant_prompt_stream(&body) } +/// Sends a prompt to Kagi Assistant and calls `on_event` for every message update. +/// +/// The returned value is the same final response produced by [`execute_assistant_prompt`]. +pub async fn execute_assistant_prompt_stream( + request: &AssistantPromptRequest, + token: &str, + mut on_event: F, +) -> Result +where + F: FnMut(&AssistantPromptStreamEvent) -> Result<(), KagiError>, +{ + let response = match build_assistant_prompt_payload(request)? { + AssistantPromptPayload::Json(state) => { + send_assistant_stream_request( + &http::kagi_url(KAGI_ASSISTANT_PROMPT_PATH), + &state, + token, + ) + .await? + } + AssistantPromptPayload::Multipart { state, attachments } => { + send_assistant_multipart_stream_request( + &http::kagi_url(KAGI_ASSISTANT_PROMPT_PATH), + &state, + &attachments, + token, + ) + .await? + } + }; + + handle_assistant_prompt_stream_response(response, "Assistant prompt", &mut on_event).await +} + +/// Lists Assistant base models exposed by the custom assistant form. +pub async fn execute_assistant_model_catalog( + token: &str, +) -> Result { + let html = fetch_authenticated_html( + &http::kagi_url(KAGI_SETTINGS_CUSTOM_ASSISTANT_PATH), + token, + "custom assistant form", + ) + .await?; + parse_assistant_model_catalog(&html) +} + /// Lists all Kagi Assistant threads for the authenticated user. /// /// # Arguments @@ -3437,6 +3485,15 @@ async fn execute_assistant_stream( token: &str, surface: &str, ) -> Result { + let response = send_assistant_stream_request(url, payload, token).await?; + handle_assistant_stream_response(response, surface).await +} + +async fn send_assistant_stream_request( + url: &str, + payload: &Value, + token: &str, +) -> Result { if token.trim().is_empty() { return Err(KagiError::Auth( "missing Kagi session token (expected KAGI_SESSION_TOKEN)".to_string(), @@ -3444,7 +3501,7 @@ async fn execute_assistant_stream( } let client = http::client_assistant_stream()?; - let response = client + client .post(url) .header(header::COOKIE, format!("kagi_session={token}")) .header(header::CONTENT_TYPE, "application/json") @@ -3452,9 +3509,7 @@ async fn execute_assistant_stream( .json(payload) .send() .await - .map_err(map_transport_error)?; - - handle_assistant_stream_response(response, surface).await + .map_err(map_transport_error) } async fn execute_assistant_multipart_stream( @@ -3464,6 +3519,16 @@ async fn execute_assistant_multipart_stream( token: &str, surface: &str, ) -> Result { + let response = send_assistant_multipart_stream_request(url, state, attachments, token).await?; + handle_assistant_stream_response(response, surface).await +} + +async fn send_assistant_multipart_stream_request( + url: &str, + state: &Value, + attachments: &[AssistantAttachmentPayload], + token: &str, +) -> Result { if token.trim().is_empty() { return Err(KagiError::Auth( "missing Kagi session token (expected KAGI_SESSION_TOKEN)".to_string(), @@ -3498,16 +3563,14 @@ async fn execute_assistant_multipart_stream( form = form.part("file", file_part); } - let response = client + client .post(url) .header(header::COOKIE, format!("kagi_session={token}")) .header(header::ACCEPT, "application/vnd.kagi.stream") .multipart(form) .send() .await - .map_err(map_transport_error)?; - - handle_assistant_stream_response(response, surface).await + .map_err(map_transport_error) } async fn handle_assistant_stream_response( @@ -3569,14 +3632,105 @@ async fn handle_assistant_stream_response( } } +async fn handle_assistant_prompt_stream_response( + response: reqwest::Response, + surface: &str, + on_event: &mut F, +) -> Result +where + F: FnMut(&AssistantPromptStreamEvent) -> Result<(), KagiError>, +{ + match response.status() { + StatusCode::OK => { + let mut parser = AssistantPromptStreamParser::default(); + let mut pending = String::new(); + let mut stream = response.bytes_stream(); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|error| { + KagiError::Network(format!("failed to read {surface} response body: {error}")) + })?; + pending.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some(index) = pending.find("\0\n") { + let frame = pending[..index].to_string(); + pending.drain(..index + 2); + if let Some(event) = parser.process_frame(&frame)? { + on_event(&event)?; + } + } + } + + if looks_like_html_document(&pending) { + return Err(KagiError::Auth( + "invalid or expired Kagi session token".to_string(), + )); + } + + if !pending.trim().is_empty() + && let Some(event) = parser.process_frame(&pending)? + { + on_event(&event)?; + } + + parser.finish() + } + status @ (StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN) => { + let body = http::read_error_body(response, surface).await; + Err(KagiError::Auth(format!( + "invalid or expired Kagi session token for {surface}: HTTP {status}{}", + format_client_error_suffix(&body) + ))) + } + status if status.is_client_error() => { + let body = http::read_error_body(response, surface).await; + Err(KagiError::Config(format!( + "Kagi {surface} request rejected: HTTP {status}{}", + format_client_error_suffix(&body) + ))) + } + status if status.is_server_error() => Err(KagiError::Network(format!( + "Kagi {surface} server error: HTTP {status}{}", + { + let body = http::read_error_body(response, surface).await; + if body.trim().is_empty() { + String::new() + } else { + format_client_error_suffix(&body) + } + } + ))), + status => Err(KagiError::Network(format!( + "unexpected Kagi {surface} response status: HTTP {status}" + ))), + } +} + fn parse_assistant_prompt_stream(body: &str) -> Result { - let mut meta = AssistantMeta::default(); - let mut thread = None; - let mut message = None; + let mut parser = AssistantPromptStreamParser::default(); for frame in body.split("\0\n").filter(|frame| !frame.trim().is_empty()) { + parser.process_frame(frame)?; + } + + parser.finish() +} + +#[derive(Default)] +struct AssistantPromptStreamParser { + meta: AssistantMeta, + thread: Option, + message: Option, + previous_markdown: String, +} + +impl AssistantPromptStreamParser { + fn process_frame( + &mut self, + frame: &str, + ) -> Result, KagiError> { let Some((tag, payload)) = frame.split_once(':') else { - continue; + return Ok(None); }; match tag { @@ -3584,15 +3738,17 @@ fn parse_assistant_prompt_stream(body: &str) -> Result { let payload: AssistantThreadPayload = serde_json::from_str(payload).map_err(|error| { KagiError::Parse(format!("failed to parse assistant thread frame: {error}")) })?; - thread = Some(AssistantThread::from(payload)); + self.thread = Some(AssistantThread::from(payload)); + Ok(None) } "new_message.json" => { let payload: AssistantMessagePayload = @@ -3601,50 +3757,67 @@ fn parse_assistant_prompt_stream(body: &str) -> Result { let detail = strip_html_to_text(payload); - return Err(KagiError::Config(if detail.is_empty() { + Err(KagiError::Config(if detail.is_empty() { "Kagi Assistant rate limited this request".to_string() } else { detail - })); - } - "unauthorized" => { - return Err(KagiError::Auth( - "invalid or expired Kagi session token".to_string(), - )); + })) } + "unauthorized" => Err(KagiError::Auth( + "invalid or expired Kagi session token".to_string(), + )), _ => { debug!(tag, "ignoring unknown assistant prompt stream frame"); + Ok(None) } } } - let thread = thread.ok_or_else(|| { - KagiError::Parse("assistant response did not include a thread.json frame".to_string()) - })?; - let message = message.ok_or_else(|| { - KagiError::Parse("assistant response did not include a new_message.json frame".to_string()) - })?; + fn finish(self) -> Result { + let thread = self.thread.ok_or_else(|| { + KagiError::Parse("assistant response did not include a thread.json frame".to_string()) + })?; + let message = self.message.ok_or_else(|| { + KagiError::Parse( + "assistant response did not include a new_message.json frame".to_string(), + ) + })?; - if message.state == "error" { - return Err(KagiError::Network( - message - .markdown - .as_deref() - .or(message.reply_html.as_deref()) - .unwrap_or("Kagi Assistant returned an error state") - .to_string(), - )); - } + if message.state == "error" { + return Err(KagiError::Network( + message + .markdown + .as_deref() + .or(message.reply_html.as_deref()) + .unwrap_or("Kagi Assistant returned an error state") + .to_string(), + )); + } - Ok(AssistantPromptResponse { - meta, - thread, - message, - }) + Ok(AssistantPromptResponse { + meta: self.meta, + thread, + message, + }) + } } fn parse_assistant_thread_open_stream( @@ -4697,8 +4870,8 @@ pub struct KagiEnvelope { #[cfg(test)] mod tests { use super::{ - ApiErrorBody, AssistantPromptPayload, KagiEnvelope, NewsFilterRequest, - TRANSLATE_BOOTSTRAP_MISSING_COOKIE_ERROR, TranslateSuggestionContext, + ApiErrorBody, AssistantPromptPayload, AssistantPromptStreamParser, KagiEnvelope, + NewsFilterRequest, TRANSLATE_BOOTSTRAP_MISSING_COOKIE_ERROR, TranslateSuggestionContext, apply_news_content_filters, build_ask_page_prompt, build_assistant_prompt_payload, build_translate_option_state, build_translate_payload, build_translate_suggestions_payload, build_translate_word_insights_payload, capture_optional_translate_section, @@ -5190,6 +5363,34 @@ mod tests { assert_eq!(parsed.message.trace_id.as_deref(), Some("trace-message-1")); } + #[test] + fn assistant_prompt_stream_events_include_markdown_delta() { + let mut parser = AssistantPromptStreamParser::default(); + parser + .process_frame("hi:{\"v\":\"test\",\"trace\":\"trace-stream\"}") + .expect("hello should parse"); + parser + .process_frame("thread.json:{\"id\":\"thread-1\",\"title\":\"Greeting\",\"ack\":\"2026-03-16T06:19:07Z\",\"created_at\":\"2026-03-16T06:19:07Z\",\"saved\":false,\"shared\":false,\"branch_id\":\"00000000-0000-4000-0000-000000000000\",\"tag_ids\":[]}") + .expect("thread should parse"); + + let first = parser + .process_frame("new_message.json:{\"id\":\"msg-1\",\"thread_id\":\"thread-1\",\"created_at\":\"2026-03-16T06:19:07Z\",\"state\":\"streaming\",\"prompt\":\"Hello\",\"md\":\"Hel\",\"documents\":[]}") + .expect("first message should parse") + .expect("first message should emit"); + let second = parser + .process_frame("new_message.json:{\"id\":\"msg-1\",\"thread_id\":\"thread-1\",\"created_at\":\"2026-03-16T06:19:07Z\",\"state\":\"done\",\"prompt\":\"Hello\",\"md\":\"Hello\",\"documents\":[]}") + .expect("second message should parse") + .expect("second message should emit"); + + assert_eq!(first.md_delta, "Hel"); + assert_eq!(second.md_delta, "lo"); + assert_eq!(second.meta.trace.as_deref(), Some("trace-stream")); + assert_eq!( + second.thread.as_ref().map(|thread| thread.id.as_str()), + Some("thread-1") + ); + } + #[test] fn parses_assistant_prompt_stream_without_expires_at() { let raw = concat!( diff --git a/src/cli.rs b/src/cli.rs index 2e81f45..f37c2d8 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -749,6 +749,10 @@ pub struct AssistantArgs { #[arg(long, value_name = "FORMAT", default_value_t = AssistantOutputFormat::Json)] pub format: AssistantOutputFormat, + /// Emit prompt updates as newline-delimited JSON + #[arg(long, conflicts_with = "export")] + pub stream: bool, + /// Disable colored terminal output (only affects pretty format) #[arg(long)] pub no_color: bool, @@ -757,6 +761,10 @@ pub struct AssistantArgs { #[arg(long, value_name = "MODEL")] pub model: Option, + /// Create a temporary custom assistant for --model and delete it after this prompt + #[arg(long, requires = "model")] + pub once: bool, + /// Override the Assistant lens id for this prompt #[arg(long, value_name = "LENS_ID")] pub lens: Option, @@ -787,6 +795,8 @@ pub struct AssistantArgs { pub enum AssistantSubcommand { /// Manage Assistant threads Thread(AssistantThreadArgs), + /// List Assistant base-model slugs available to custom assistants + Models, /// Manage custom assistants Custom(AssistantCustomArgs), /// Start an interactive Assistant REPL with automatic thread continuity diff --git a/src/main.rs b/src/main.rs index af16a08..4616806 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,7 +17,8 @@ use clap::{CommandFactory, Parser}; use clap_complete::{generate, shells}; use crate::api::{ - NewsFilterRequest, execute_ask_page, execute_assistant_prompt, execute_assistant_thread_delete, + NewsFilterRequest, execute_ask_page, execute_assistant_model_catalog, execute_assistant_prompt, + execute_assistant_prompt_stream, execute_assistant_thread_delete, execute_assistant_thread_export, execute_assistant_thread_get, execute_assistant_thread_list, execute_custom_assistant_create, execute_custom_assistant_delete, execute_custom_assistant_get, execute_custom_assistant_list, execute_custom_assistant_update, execute_custom_bang_create, @@ -60,7 +61,7 @@ use std::fs; use std::future::Future; use std::io::{self, BufRead, Read, Write}; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use tokio::sync::Semaphore; use tracing::error; use tracing_subscriber::EnvFilter; @@ -321,6 +322,10 @@ async fn run() -> Result<(), KagiError> { } }, }, + AssistantSubcommand::Models => { + let response = execute_assistant_model_catalog(&token).await?; + print_json(&response) + } AssistantSubcommand::Repl(repl_args) => { run_assistant_repl(repl_args, &token).await } @@ -411,8 +416,24 @@ async fn run() -> Result<(), KagiError> { _ => None, }, }; - let response = execute_assistant_prompt(&request, &token).await?; - print_assistant_response(&response, args.format, !args.no_color) + if args.once { + let response = + execute_once_assistant_prompt(&request, args.stream, &token).await?; + if args.stream { + Ok(()) + } else { + print_assistant_response(&response, args.format, !args.no_color) + } + } else if args.stream { + execute_assistant_prompt_stream(&request, &token, |event| { + print_compact_json(event) + }) + .await?; + Ok(()) + } else { + let response = execute_assistant_prompt(&request, &token).await?; + print_assistant_response(&response, args.format, !args.no_color) + } } } Commands::AskPage(args) => { @@ -1045,6 +1066,68 @@ fn print_json(value: &T) -> Result<(), KagiError> { Ok(()) } +async fn execute_once_assistant_prompt( + request: &AssistantPromptRequest, + stream: bool, + token: &str, +) -> Result { + let model = request + .model + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .ok_or_else(|| KagiError::Config("--once requires --model".to_string()))?; + if request + .profile_id + .as_deref() + .is_some_and(|value| !value.trim().is_empty()) + { + return Err(KagiError::Config( + "--once cannot be combined with --assistant".to_string(), + )); + } + + let created = execute_custom_assistant_create( + &AssistantProfileCreateRequest { + name: temporary_assistant_name(), + bang_trigger: None, + internet_access: request.internet_access, + selected_lens: request.lens_id.map(|lens_id| lens_id.to_string()), + personalizations: request.personalizations, + base_model: Some(model.to_string()), + custom_instructions: None, + }, + token, + ) + .await?; + + let delete_target = created.profile_id.clone().unwrap_or(created.name.clone()); + let mut prompt_request = request.clone(); + prompt_request.profile_id = Some(delete_target.clone()); + prompt_request.model = None; + + let prompt_result = if stream { + execute_assistant_prompt_stream(&prompt_request, token, print_compact_json).await + } else { + execute_assistant_prompt(&prompt_request, token).await + }; + + let delete_result = execute_custom_assistant_delete(&delete_target, token).await; + match (prompt_result, delete_result) { + (Ok(response), Ok(_)) => Ok(response), + (Err(error), _) => Err(error), + (Ok(_), Err(error)) => Err(error), + } +} + +fn temporary_assistant_name() -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_millis()) + .unwrap_or_default(); + format!("kagi-cli-once-{millis}-{}", std::process::id()) +} + fn print_compact_json(value: &T) -> Result<(), KagiError> { let output = serde_json::to_string(value) .map_err(|error| KagiError::Parse(format!("failed to serialize JSON output: {error}")))?; diff --git a/src/parser.rs b/src/parser.rs index 3849eb8..6ded9f7 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -10,9 +10,9 @@ use scraper::{Html, Selector}; use crate::error::KagiError; use crate::types::{ - AssistantProfileDetails, AssistantProfileSummary, AssistantThreadSummary, CustomBangDetails, - CustomBangSummary, LensDetails, LensSummary, NewsSearchCluster, NewsSearchResult, - RedirectRuleDetails, RedirectRuleSummary, SearchResult, + AssistantModelCatalog, AssistantModelOption, AssistantProfileDetails, AssistantProfileSummary, + AssistantThreadSummary, CustomBangDetails, CustomBangSummary, LensDetails, LensSummary, + NewsSearchCluster, NewsSearchResult, RedirectRuleDetails, RedirectRuleSummary, SearchResult, }; /// Parse Kagi search results from HTML. @@ -380,6 +380,46 @@ pub fn parse_assistant_profile_form(html: &str) -> Result Result { + let document = Html::parse_document(html); + let selector = selector(r#"input[type="radio"][name="base_model"]"#)?; + let models = document + .select(&selector) + .filter_map(|node| { + let id = node.value().attr("value")?.trim(); + if id.is_empty() { + return None; + } + + let label = node + .value() + .attr("aria-label") + .or_else(|| node.value().attr("title")) + .map(str::trim) + .filter(|value| !value.is_empty()) + .unwrap_or(id); + + Some(AssistantModelOption { + id: id.to_string(), + label: label.to_string(), + selected: node.value().attr("checked").is_some(), + }) + }) + .collect(); + + Ok(AssistantModelCatalog { models }) +} + /// Parses a list of Kagi lenses from the settings HTML. /// /// # Arguments @@ -769,9 +809,10 @@ fn parse_query_value(href: &str, key: &str) -> Option { #[cfg(test)] mod tests { use super::{ - parse_assistant_profile_form, parse_assistant_profile_list, parse_assistant_thread_list, - parse_custom_bang_form, parse_custom_bang_list, parse_lens_form, parse_lens_list, - parse_news_search_results, parse_redirect_form, parse_redirect_list, parse_search_results, + parse_assistant_model_catalog, parse_assistant_profile_form, parse_assistant_profile_list, + parse_assistant_thread_list, parse_custom_bang_form, parse_custom_bang_list, + parse_lens_form, parse_lens_list, parse_news_search_results, parse_redirect_form, + parse_redirect_list, parse_search_results, }; use crate::error::KagiError; @@ -973,6 +1014,27 @@ mod tests { assert_eq!(details.selected_lens, "0"); } + #[test] + fn parses_assistant_model_catalog_from_base_model_radios() { + let html = r#" +
+ + + +
+ "#; + + let catalog = parse_assistant_model_catalog(html).expect("catalog should parse"); + + assert_eq!(catalog.models.len(), 2); + assert_eq!(catalog.models[0].id, "gpt-5-5"); + assert_eq!(catalog.models[0].label, "GPT 5.5"); + assert!(catalog.models[0].selected); + assert_eq!(catalog.models[1].id, "claude-4-7-opus"); + assert_eq!(catalog.models[1].label, "Claude Opus"); + assert!(!catalog.models[1].selected); + } + #[test] fn parses_lens_list_items() { let html = r#" diff --git a/src/types.rs b/src/types.rs index 2915a86..44fdca3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -456,6 +456,30 @@ pub struct AssistantPromptResponse { pub message: AssistantMessage, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +/// One incremental Assistant message update emitted by `assistant --stream`. +pub struct AssistantPromptStreamEvent { + pub meta: AssistantMeta, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub thread: Option, + pub message: AssistantMessage, + pub md_delta: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +/// One model option from the Assistant custom-profile form. +pub struct AssistantModelOption { + pub id: String, + pub label: String, + pub selected: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +/// Stable JSON shape for Assistant model catalog output. +pub struct AssistantModelCatalog { + pub models: Vec, +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] /// Source information for an "ask about a page" query. pub struct AskPageSource { diff --git a/tests/integration-cli.rs b/tests/integration-cli.rs index 06d6ac4..576af19 100644 --- a/tests/integration-cli.rs +++ b/tests/integration-cli.rs @@ -110,6 +110,49 @@ fn api_meta() -> Value { }) } +fn assistant_form_html(profile_id: &str, name: &str) -> String { + format!( + r#" +
+ + + + + + + + + + + +
+
+ "# + ) +} + +fn assistant_list_html() -> &'static str { + r#" +
+
    +
  • +
    + Once +
    +
    +
    Model:
    GPT 5 Mini
    +
    +
    Internet Access:
    On
    +
    +
    + Edit +
    +
  • +
+
+ "# +} + fn search_payload(title: &str, url: &str, snippet: &str) -> Value { json!({ "meta": api_meta(), @@ -1088,6 +1131,133 @@ fn assistant_thread_list_paginates_with_cursor_id() { assert_eq!(body["pagination"]["total_counts"]["all"], 2); } +#[test] +fn assistant_models_prints_json_catalog() { + let server = MockServer::start(); + let _form = server.mock(|when, then| { + when.method(GET) + .path("/settings/custom_assistant") + .header("cookie", "kagi_session=test-session"); + then.status(200) + .body(assistant_form_html("profile-once", "Once")); + }); + + let tempdir = TempDir::new().expect("tempdir"); + let env = session_env(&server); + let output = run_kagi(&["assistant", "models"], &env_refs(&env), tempdir.path()); + + assert_success(&output); + let body: Value = serde_json::from_slice(&output.stdout).expect("json output should parse"); + assert_eq!(body["models"][0]["id"], "gpt-5-mini"); + assert_eq!(body["models"][0]["label"], "GPT 5 Mini"); + assert_eq!(body["models"][0]["selected"], true); + assert_eq!(body["models"][1]["id"], "claude-4-7-opus"); +} + +#[test] +fn assistant_stream_prints_ndjson_updates() { + let server = MockServer::start(); + let _prompt = server.mock(|when, then| { + when.method(POST) + .path("/assistant/prompt") + .header("cookie", "kagi_session=test-session") + .header("accept", "application/vnd.kagi.stream") + .header("content-type", "application/json"); + then.status(200) + .header("content-type", "application/vnd.kagi.stream") + .body(concat!( + "hi:{\"v\":\"test\",\"trace\":\"trace-stream\"}\0\n", + "thread.json:{\"id\":\"thread-1\",\"title\":\"Greeting\",\"ack\":\"2026-03-16T06:19:07Z\",\"created_at\":\"2026-03-16T06:19:07Z\",\"saved\":false,\"shared\":false,\"branch_id\":\"00000000-0000-4000-0000-000000000000\",\"tag_ids\":[]}\0\n", + "new_message.json:{\"id\":\"msg-1\",\"thread_id\":\"thread-1\",\"created_at\":\"2026-03-16T06:19:07Z\",\"state\":\"streaming\",\"prompt\":\"Hello\",\"md\":\"Hel\",\"documents\":[]}\0\n", + "new_message.json:{\"id\":\"msg-1\",\"thread_id\":\"thread-1\",\"created_at\":\"2026-03-16T06:19:07Z\",\"state\":\"done\",\"prompt\":\"Hello\",\"md\":\"Hello\",\"documents\":[]}\0\n" + )); + }); + + let tempdir = TempDir::new().expect("tempdir"); + let env = session_env(&server); + let output = run_kagi( + &["assistant", "--stream", "Hello"], + &env_refs(&env), + tempdir.path(), + ); + + assert_success(&output); + let lines = String::from_utf8_lossy(&output.stdout) + .lines() + .map(|line| serde_json::from_str::(line).expect("line should parse as json")) + .collect::>(); + assert_eq!(lines.len(), 2); + assert_eq!(lines[0]["md_delta"], "Hel"); + assert_eq!(lines[1]["md_delta"], "lo"); + assert_eq!(lines[1]["message"]["state"], "done"); +} + +#[test] +fn assistant_once_creates_prompts_and_deletes_temporary_profile() { + let server = MockServer::start(); + let _new_form = server.mock(|when, then| { + when.method(GET) + .path("/settings/custom_assistant") + .header("cookie", "kagi_session=test-session"); + then.status(200) + .body(assistant_form_html("profile-once", "Once")); + }); + let _create = server.mock(|when, then| { + when.method(POST) + .path("/settings/ast/profiles/update") + .header("cookie", "kagi_session=test-session") + .body_includes("base_model=gpt-5-mini"); + then.status(303) + .header("location", "/settings/custom_assistant?id=profile-once"); + }); + let _list = server.mock(|when, then| { + when.method(GET) + .path("/html/settings/assistant") + .header("cookie", "kagi_session=test-session"); + then.status(200).body(assistant_list_html()); + }); + let _edit_form = server.mock(|when, then| { + when.method(GET) + .path("/settings/custom_assistant") + .query_param("id", "profile-once") + .header("cookie", "kagi_session=test-session"); + then.status(200) + .body(assistant_form_html("profile-once", "Once")); + }); + let _prompt = server.mock(|when, then| { + when.method(POST) + .path("/assistant/prompt") + .header("cookie", "kagi_session=test-session") + .header("accept", "application/vnd.kagi.stream"); + then.status(200) + .header("content-type", "application/vnd.kagi.stream") + .body(concat!( + "hi:{\"v\":\"test\",\"trace\":\"trace-once\"}\0\n", + "thread.json:{\"id\":\"thread-once\",\"title\":\"Once\",\"ack\":\"2026-03-16T06:19:07Z\",\"created_at\":\"2026-03-16T06:19:07Z\",\"saved\":false,\"shared\":false,\"branch_id\":\"00000000-0000-4000-0000-000000000000\",\"tag_ids\":[]}\0\n", + "new_message.json:{\"id\":\"msg-once\",\"thread_id\":\"thread-once\",\"created_at\":\"2026-03-16T06:19:07Z\",\"state\":\"done\",\"prompt\":\"Hi\",\"md\":\"ok\",\"documents\":[]}\0\n" + )); + }); + let _delete = server.mock(|when, then| { + when.method(POST) + .path("/settings/ast/profiles/delete") + .header("cookie", "kagi_session=test-session") + .body_includes("profile_id=profile-once"); + then.status(200).body(""); + }); + + let tempdir = TempDir::new().expect("tempdir"); + let env = session_env(&server); + let output = run_kagi( + &["assistant", "--once", "--model", "gpt-5-mini", "Hi"], + &env_refs(&env), + tempdir.path(), + ); + + assert_success(&output); + let body: Value = serde_json::from_slice(&output.stdout).expect("json output should parse"); + assert_eq!(body["message"]["markdown"], "ok"); +} + #[test] fn batch_command_reads_queries_from_stdin() { let server = MockServer::start();