diff --git a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift index 4f49d13..6be5d02 100644 --- a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift @@ -76,11 +76,6 @@ public struct OllamaLanguageModel: LanguageModel { includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("OllamaLanguageModel only supports generating String content") - } - let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description) let (ollamaText, ollamaImages) = convertSegmentsToOllama(userSegments) let messages = [ @@ -90,6 +85,13 @@ public struct OllamaLanguageModel: LanguageModel { let ollamaTools = try session.tools.map { tool in try convertToolToOllamaFormat(tool) } + let ollamaFormat: JSONValue? + if type == String.self { + ollamaFormat = nil + } else { + let schema = try convertSchemaToOllamaFormat(type.generationSchema) + ollamaFormat = try JSONValue(schema) + } let params = try createChatParams( model: model, @@ -97,7 +99,8 @@ public struct OllamaLanguageModel: LanguageModel { tools: ollamaTools.isEmpty ? nil : ollamaTools, options: ollamaOptions, stream: false, - images: ollamaImages.isEmpty ? nil : ollamaImages + images: ollamaImages.isEmpty ? nil : ollamaImages, + format: ollamaFormat ) let url = baseURL.appendingPathComponent("api/chat") @@ -134,9 +137,19 @@ public struct OllamaLanguageModel: LanguageModel { } let text = chatResponse.message.content ?? "" + if type == String.self { + return LanguageModelSession.Response( + content: text as! Content, + rawContent: GeneratedContent(text), + transcriptEntries: ArraySlice(entries) + ) + } + + let generatedContent = try GeneratedContent(json: text) + let content = try type.init(generatedContent) return LanguageModelSession.Response( - content: text as! Content, - rawContent: GeneratedContent(text), + content: content, + rawContent: generatedContent, transcriptEntries: ArraySlice(entries) ) } @@ -148,72 +161,92 @@ public struct OllamaLanguageModel: LanguageModel { includeSchemaInPrompt: Bool, options: GenerationOptions ) -> sending LanguageModelSession.ResponseStream where Content: Generable { - // For now, only String is supported - guard type == String.self else { - fatalError("OllamaLanguageModel only supports generating String content") - } - let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description) let (ollamaText, ollamaImages) = convertSegmentsToOllama(userSegments) let messages = [ OllamaMessage(role: .user, content: ollamaText) ] let ollamaOptions = convertOptions(options) - let ollamaTools = try? session.tools.map { tool in - try convertToolToOllamaFormat(tool) - } - - let params = try? createChatParams( - model: model, - messages: messages, - tools: (ollamaTools?.isEmpty == false) ? ollamaTools : nil, - options: ollamaOptions, - stream: true, - images: (ollamaImages.isEmpty ? nil : ollamaImages) - ) - let url = baseURL.appendingPathComponent("api/chat") - let body = (try? JSONEncoder().encode(params)) ?? Data() // Transform the newline-delimited JSON stream from Ollama into ResponseStream snapshots let stream: AsyncThrowingStream.Snapshot, any Error> = AsyncThrowingStream { continuation in - let task = Task { - do { + do { + let ollamaTools = try session.tools.map { tool in + try convertToolToOllamaFormat(tool) + } + let ollamaFormat: JSONValue? + if type == String.self { + ollamaFormat = nil + } else { + let schema = try convertSchemaToOllamaFormat(type.generationSchema) + ollamaFormat = try JSONValue(schema) + } + + let params = try createChatParams( + model: model, + messages: messages, + tools: ollamaTools.isEmpty ? nil : ollamaTools, + options: ollamaOptions, + stream: true, + images: (ollamaImages.isEmpty ? nil : ollamaImages), + format: ollamaFormat + ) + let body = try JSONEncoder().encode(params) + + let task = Task { // Reuse ChatResponse as each streamed line shares the same shape - let chunks = - urlSession.fetchStream( - .post, - url: url, - body: body, - dateDecodingStrategy: .iso8601WithFractionalSeconds - ) as AsyncThrowingStream - - var partialText = "" - - for try await chunk in chunks { - if let piece = chunk.message.content { - partialText += piece - let snapshot = LanguageModelSession.ResponseStream.Snapshot( - content: (partialText as! Content).asPartiallyGenerated(), - rawContent: GeneratedContent(partialText) - ) - continuation.yield(snapshot) + do { + let chunks = + urlSession.fetchStream( + .post, + url: url, + body: body, + dateDecodingStrategy: .iso8601WithFractionalSeconds + ) as AsyncThrowingStream + + var partialText = "" + + for try await chunk in chunks { + if let piece = chunk.message.content { + partialText += piece + if type == String.self { + let snapshot = LanguageModelSession.ResponseStream.Snapshot( + content: (partialText as! Content).asPartiallyGenerated(), + rawContent: GeneratedContent(partialText) + ) + continuation.yield(snapshot) + } else if let raw = try? GeneratedContent(json: partialText), + let parsed = try? type.init(raw) + { + let snapshot = LanguageModelSession.ResponseStream.Snapshot( + content: parsed.asPartiallyGenerated(), + rawContent: raw + ) + continuation.yield(snapshot) + } else { + // Structured responses can stream as incomplete JSON fragments. + // Skip snapshots until the accumulated JSON parses cleanly. + } + } + + if chunk.done { + break + } } - if chunk.done { - break - } + continuation.finish() + } catch { + continuation.finish(throwing: error) } - - continuation.finish() - } catch { - continuation.finish(throwing: error) } - } - continuation.onTermination = { _ in - task.cancel() + continuation.onTermination = { _ in + task.cancel() + } + } catch { + continuation.finish(throwing: error) } } @@ -402,6 +435,12 @@ private func convertToolToOllamaFormat(_ tool: any Tool) throws -> [String: JSON ] } +private func convertSchemaToOllamaFormat(_ schema: GenerationSchema) throws -> JSONSchema { + let resolvedSchema = schema.withResolvedRoot() ?? schema + let data = try JSONEncoder().encode(resolvedSchema) + return try JSONDecoder().decode(JSONSchema.self, from: data) +} + private func toGeneratedContent(_ value: JSONValue?) throws -> GeneratedContent { guard let value else { return GeneratedContent(properties: [:]) } let data = try JSONEncoder().encode(value) @@ -415,7 +454,8 @@ private func createChatParams( tools: [[String: JSONValue]]?, options: [String: JSONValue]?, stream: Bool, - images: [String]? + images: [String]?, + format: JSONValue? ) throws -> [String: JSONValue] { var params: [String: JSONValue] = [ "model": .string(model), @@ -435,6 +475,10 @@ private func createChatParams( params["images"] = .array(images.map { .string($0) }) } + if let format { + params["format"] = format + } + return params } diff --git a/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift b/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift index 234d17e..5a8911f 100644 --- a/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift @@ -3,6 +3,12 @@ import Testing @testable import AnyLanguageModel +@Generable +private struct OllamaStructuredForecast { + var summary: String + var temperatureCelsius: Int +} + @Suite( "OllamaLanguageModel", .serialized, @@ -61,6 +67,36 @@ struct OllamaLanguageModelTests { #expect(!snapshots.last!.rawContent.jsonString.isEmpty) } + @Test func structuredResponse() async throws { + let session = LanguageModelSession(model: model) + + let response = try await session.respond( + to: "Summarize the weather with a short summary and a celsius temperature.", + generating: OllamaStructuredForecast.self + ) + + #expect(!response.content.summary.isEmpty) + #expect(response.rawContent.jsonString.contains("summary")) + } + + @Test func streamingStructured() async throws { + let session = LanguageModelSession(model: model) + + let stream = session.streamResponse( + to: "Provide a short weather forecast summary and a celsius temperature.", + generating: OllamaStructuredForecast.self + ) + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in stream { + snapshots.append(snapshot) + } + + #expect(!snapshots.isEmpty) + #expect(!snapshots.last!.rawContent.jsonString.isEmpty) + #expect(!(snapshots.last!.content.summary ?? "").isEmpty) + } + @Test func withGenerationOptions() async throws { let session = LanguageModelSession(model: model)