Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 102 additions & 58 deletions Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,6 @@ public struct OllamaLanguageModel: LanguageModel {
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
// For now, only String is supported
guard type == String.self else {
fatalError("OllamaLanguageModel only supports generating String content")
}

let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
let (ollamaText, ollamaImages) = convertSegmentsToOllama(userSegments)
let messages = [
Expand All @@ -90,14 +85,22 @@ public struct OllamaLanguageModel: LanguageModel {
let ollamaTools = try session.tools.map { tool in
try convertToolToOllamaFormat(tool)
}
let ollamaFormat: JSONValue?
if type == String.self {
ollamaFormat = nil
} else {
let schema = try convertSchemaToOllamaFormat(type.generationSchema)
ollamaFormat = try JSONValue(schema)
}

let params = try createChatParams(
model: model,
messages: messages,
tools: ollamaTools.isEmpty ? nil : ollamaTools,
options: ollamaOptions,
stream: false,
images: ollamaImages.isEmpty ? nil : ollamaImages
images: ollamaImages.isEmpty ? nil : ollamaImages,
format: ollamaFormat
)

let url = baseURL.appendingPathComponent("api/chat")
Expand Down Expand Up @@ -134,9 +137,19 @@ public struct OllamaLanguageModel: LanguageModel {
}

let text = chatResponse.message.content ?? ""
if type == String.self {
return LanguageModelSession.Response(
content: text as! Content,
rawContent: GeneratedContent(text),
transcriptEntries: ArraySlice(entries)
)
}

let generatedContent = try GeneratedContent(json: text)
let content = try type.init(generatedContent)
return LanguageModelSession.Response(
content: text as! Content,
rawContent: GeneratedContent(text),
content: content,
rawContent: generatedContent,
transcriptEntries: ArraySlice(entries)
)
}
Expand All @@ -148,72 +161,92 @@ public struct OllamaLanguageModel: LanguageModel {
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
// For now, only String is supported
guard type == String.self else {
fatalError("OllamaLanguageModel only supports generating String content")
}

let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
let (ollamaText, ollamaImages) = convertSegmentsToOllama(userSegments)
let messages = [
OllamaMessage(role: .user, content: ollamaText)
]
let ollamaOptions = convertOptions(options)
let ollamaTools = try? session.tools.map { tool in
try convertToolToOllamaFormat(tool)
}

let params = try? createChatParams(
model: model,
messages: messages,
tools: (ollamaTools?.isEmpty == false) ? ollamaTools : nil,
options: ollamaOptions,
stream: true,
images: (ollamaImages.isEmpty ? nil : ollamaImages)
)

let url = baseURL.appendingPathComponent("api/chat")
let body = (try? JSONEncoder().encode(params)) ?? Data()

// Transform the newline-delimited JSON stream from Ollama into ResponseStream snapshots
let stream: AsyncThrowingStream<LanguageModelSession.ResponseStream<Content>.Snapshot, any Error> =
AsyncThrowingStream { continuation in
let task = Task {
do {
do {
let ollamaTools = try session.tools.map { tool in
try convertToolToOllamaFormat(tool)
}
let ollamaFormat: JSONValue?
if type == String.self {
ollamaFormat = nil
} else {
let schema = try convertSchemaToOllamaFormat(type.generationSchema)
ollamaFormat = try JSONValue(schema)
}

let params = try createChatParams(
model: model,
messages: messages,
tools: ollamaTools.isEmpty ? nil : ollamaTools,
options: ollamaOptions,
stream: true,
images: (ollamaImages.isEmpty ? nil : ollamaImages),
format: ollamaFormat
)
let body = try JSONEncoder().encode(params)

let task = Task {
// Reuse ChatResponse as each streamed line shares the same shape
let chunks =
urlSession.fetchStream(
.post,
url: url,
body: body,
dateDecodingStrategy: .iso8601WithFractionalSeconds
) as AsyncThrowingStream<ChatResponse, any Error>

var partialText = ""

for try await chunk in chunks {
if let piece = chunk.message.content {
partialText += piece
let snapshot = LanguageModelSession.ResponseStream<Content>.Snapshot(
content: (partialText as! Content).asPartiallyGenerated(),
rawContent: GeneratedContent(partialText)
)
continuation.yield(snapshot)
do {
let chunks =
urlSession.fetchStream(
.post,
url: url,
body: body,
dateDecodingStrategy: .iso8601WithFractionalSeconds
) as AsyncThrowingStream<ChatResponse, any Error>

var partialText = ""

for try await chunk in chunks {
if let piece = chunk.message.content {
partialText += piece
if type == String.self {
let snapshot = LanguageModelSession.ResponseStream<Content>.Snapshot(
content: (partialText as! Content).asPartiallyGenerated(),
rawContent: GeneratedContent(partialText)
)
continuation.yield(snapshot)
} else if let raw = try? GeneratedContent(json: partialText),
let parsed = try? type.init(raw)
{
let snapshot = LanguageModelSession.ResponseStream<Content>.Snapshot(
content: parsed.asPartiallyGenerated(),
rawContent: raw
)
continuation.yield(snapshot)
} else {
// Structured responses can stream as incomplete JSON fragments.
// Skip snapshots until the accumulated JSON parses cleanly.
}
}

if chunk.done {
break
}
}

if chunk.done {
break
}
continuation.finish()
} catch {
continuation.finish(throwing: error)
}

continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}

continuation.onTermination = { _ in
task.cancel()
continuation.onTermination = { _ in
task.cancel()
}
} catch {
continuation.finish(throwing: error)
}
}

Expand Down Expand Up @@ -402,6 +435,12 @@ private func convertToolToOllamaFormat(_ tool: any Tool) throws -> [String: JSON
]
}

/// Bridges a `GenerationSchema` into the `JSONSchema` shape that Ollama's
/// `format` parameter accepts.
///
/// The conversion round-trips through JSON: the schema (with its root
/// reference resolved when possible) is encoded, then decoded back as a
/// plain `JSONSchema`.
///
/// - Parameter schema: The generation schema describing the expected output.
/// - Returns: An equivalent `JSONSchema` value suitable for the request body.
/// - Throws: Any encoding or decoding error raised during the round-trip.
private func convertSchemaToOllamaFormat(_ schema: GenerationSchema) throws -> JSONSchema {
    // Prefer the resolved-root form; fall back to the schema as given.
    let source = schema.withResolvedRoot() ?? schema
    let encoded = try JSONEncoder().encode(source)
    return try JSONDecoder().decode(JSONSchema.self, from: encoded)
}

private func toGeneratedContent(_ value: JSONValue?) throws -> GeneratedContent {
guard let value else { return GeneratedContent(properties: [:]) }
let data = try JSONEncoder().encode(value)
Expand All @@ -415,7 +454,8 @@ private func createChatParams(
tools: [[String: JSONValue]]?,
options: [String: JSONValue]?,
stream: Bool,
images: [String]?
images: [String]?,
format: JSONValue?
) throws -> [String: JSONValue] {
var params: [String: JSONValue] = [
"model": .string(model),
Expand All @@ -435,6 +475,10 @@ private func createChatParams(
params["images"] = .array(images.map { .string($0) })
}

if let format {
params["format"] = format
}

return params
}

Expand Down
36 changes: 36 additions & 0 deletions Tests/AnyLanguageModelTests/OllamaLanguageModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ import Testing

@testable import AnyLanguageModel

// Minimal structured-output fixture used by the structured-response tests
// below; `@Generable` derives the generation schema sent to Ollama.
@Generable
private struct OllamaStructuredForecast {
    // Short human-readable forecast text.
    var summary: String
    // Forecast temperature in degrees Celsius.
    var temperatureCelsius: Int
}

@Suite(
"OllamaLanguageModel",
.serialized,
Expand Down Expand Up @@ -61,6 +67,36 @@ struct OllamaLanguageModelTests {
#expect(!snapshots.last!.rawContent.jsonString.isEmpty)
}

/// Verifies that a non-streaming respond call can generate a structured
/// `OllamaStructuredForecast` and that the raw content carries its fields.
@Test func structuredResponse() async throws {
    let session = LanguageModelSession(model: model)

    let prompt = "Summarize the weather with a short summary and a celsius temperature."
    let forecast = try await session.respond(
        to: prompt,
        generating: OllamaStructuredForecast.self
    )

    // The parsed struct should be populated and the raw JSON should
    // mention the schema's field name.
    #expect(!forecast.content.summary.isEmpty)
    #expect(forecast.rawContent.jsonString.contains("summary"))
}

/// Verifies that streaming a structured `OllamaStructuredForecast` yields at
/// least one snapshot whose raw JSON and parsed `summary` are non-empty.
@Test func streamingStructured() async throws {
    let session = LanguageModelSession(model: model)

    let stream = session.streamResponse(
        to: "Provide a short weather forecast summary and a celsius temperature.",
        generating: OllamaStructuredForecast.self
    )

    var snapshots: [LanguageModelSession.ResponseStream<OllamaStructuredForecast>.Snapshot] = []
    for try await snapshot in stream {
        snapshots.append(snapshot)
    }

    // Use #require instead of force-unwrapping `snapshots.last!` so an
    // empty stream fails with a diagnostic rather than crashing the runner.
    let last = try #require(snapshots.last)
    #expect(!last.rawContent.jsonString.isEmpty)
    // `summary` is optional in the partially-generated representation.
    #expect(!(last.content.summary ?? "").isEmpty)
}

@Test func withGenerationOptions() async throws {
let session = LanguageModelSession(model: model)

Expand Down