Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/ollama/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.0", "pydantic"]
dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.0", "pydantic", "tenacity>=8.2.3"]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific reason for >=8.2.3 for tenacity?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the initial commit, I did not include this, so the Python 3.13 check built with an older version of tenacity that did not include retry, retry_if_exception, or wait_exponential, which caused the verification checks to fail. Because of this, I had to explicitly specify a minimum version.


[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,44 @@
)
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from pydantic.json_schema import JsonSchemaValue
from tenacity import RetryCallState, retry, retry_if_exception, wait_exponential

from ollama import AsyncClient, ChatResponse, Client
from ollama import AsyncClient, ChatResponse, Client, ResponseError

# Maps Ollama completion reasons onto Haystack's FinishReason literals.
FINISH_REASON_MAPPING: dict[str, FinishReason] = {
    "stop": "stop",
    "tool_calls": "tool_calls",
    # we skip load and unload reasons
}

# HTTP status codes that _is_retryable_exception treats as transient.
HTTP_STATUS_TOO_MANY_REQUESTS = 429
# The 5xx server-error range is half-open: [500, 600).
HTTP_STATUS_SERVER_ERROR_MIN = 500
HTTP_STATUS_SERVER_ERROR_MAX_EXCLUSIVE = 600


def _stop_after_instance_max_retries(retry_state: RetryCallState) -> bool:
"""
Stop retries after `self.max_retries + 1` attempts.
"""
instance = retry_state.args[0]
return retry_state.attempt_number >= instance.max_retries + 1


def _is_retryable_exception(exc: BaseException) -> bool:
    """
    Decide whether a failed request should be retried.

    Transient failures are:
    - HTTP 429 responses (rate limiting)
    - HTTP 5xx responses (server errors)
    - transport-level connection/timeout errors

    NOTE(review): only the builtin ConnectionError/TimeoutError are matched for
    transport failures; if the underlying HTTP client raises its own exception
    types that do not inherit from these builtins, they are not retried — confirm.
    """
    if isinstance(exc, ResponseError):
        status = exc.status_code
        if status == HTTP_STATUS_TOO_MANY_REQUESTS:
            return True
        return HTTP_STATUS_SERVER_ERROR_MIN <= status < HTTP_STATUS_SERVER_ERROR_MAX_EXCLUSIVE
    return isinstance(exc, (ConnectionError, TimeoutError))


def _convert_chatmessage_to_ollama_format(message: ChatMessage) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -216,6 +245,7 @@ def __init__(
url: str = "http://localhost:11434",
generation_kwargs: dict[str, Any] | None = None,
timeout: int = 120,
max_retries: int = 0,
keep_alive: float | str | None = None,
streaming_callback: Callable[[StreamingChunk], None] | None = None,
tools: ToolsType | None = None,
Expand All @@ -233,6 +263,9 @@ def __init__(
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:param timeout:
The number of seconds before throwing a timeout error from the Ollama API.
:param max_retries:
Maximum number of retries to attempt for failed requests (HTTP 429, 5xx, connection/timeout errors).
Uses exponential backoff between attempts. Set to 0 (default) to disable retries.
:param think:
If True, the model will "think" before producing a response.
Only [thinking models](https://ollama.com/search?c=thinking) support this feature.
Expand Down Expand Up @@ -268,6 +301,7 @@ def __init__(
self.url = url
self.generation_kwargs = generation_kwargs or {}
self.timeout = timeout
self.max_retries = max_retries
self.keep_alive = keep_alive
self.streaming_callback = streaming_callback
self.tools = tools # Store original tools for serialization
Expand All @@ -292,6 +326,7 @@ def to_dict(self) -> dict[str, Any]:
url=self.url,
generation_kwargs=self.generation_kwargs,
timeout=self.timeout,
max_retries=self.max_retries,
keep_alive=self.keep_alive,
streaming_callback=callback_name,
tools=serialize_tools_or_toolset(self.tools),
Expand Down Expand Up @@ -469,6 +504,56 @@ async def _handle_streaming_response_async(

return {"replies": [reply]}

@retry(
    reraise=True,
    stop=_stop_after_instance_max_retries,
    retry=retry_if_exception(_is_retryable_exception),
    wait=wait_exponential(),
)
def _chat(
    self,
    *,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    is_stream: bool,
    generation_kwargs: dict[str, Any],
) -> ChatResponse | Iterator[ChatResponse]:
    """
    Send a chat request via the synchronous Ollama client, retrying transient failures.

    Retries (HTTP 429/5xx and connection/timeout errors, per `_is_retryable_exception`)
    use exponential backoff and stop after `self.max_retries + 1` total attempts
    (`_stop_after_instance_max_retries`); the final error is re-raised.

    :param messages: Messages already converted to the Ollama wire format.
    :param tools: Tool definitions in Ollama format, or None.
    :param is_stream: Whether to request a streaming response.
    :param generation_kwargs: Options forwarded to the Ollama `options` field.
    :returns: A single `ChatResponse`, or an iterator of chunks when streaming.
    """
    return self._client.chat(
        model=self.model,
        messages=messages,
        tools=tools,
        stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
        keep_alive=self.keep_alive,
        options=generation_kwargs,
        format=self.response_format,
        think=self.think,
    )

@retry(
    reraise=True,
    stop=_stop_after_instance_max_retries,
    retry=retry_if_exception(_is_retryable_exception),
    wait=wait_exponential(),
)
async def _chat_async(
    self,
    *,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    is_stream: bool,
    generation_kwargs: dict[str, Any],
) -> ChatResponse | AsyncIterator[ChatResponse]:
    """
    Send a chat request via the asynchronous Ollama client, retrying transient failures.

    Same retry policy as `_chat`: exponential backoff on HTTP 429/5xx and
    connection/timeout errors, stopping after `self.max_retries + 1` total
    attempts, then re-raising the final error.

    :param messages: Messages already converted to the Ollama wire format.
    :param tools: Tool definitions in Ollama format, or None.
    :param is_stream: Whether to request a streaming response.
    :param generation_kwargs: Options forwarded to the Ollama `options` field.
    :returns: A single `ChatResponse`, or an async iterator of chunks when streaming.
    """
    return await self._async_client.chat(
        model=self.model,
        messages=messages,
        tools=tools,
        stream=is_stream,  # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
        keep_alive=self.keep_alive,
        options=generation_kwargs,
        format=self.response_format,
        think=self.think,
    )

@component.output_types(replies=list[ChatMessage])
def run(
self,
Expand Down Expand Up @@ -518,15 +603,8 @@ def run(

ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]

response = self._client.chat(
model=self.model,
messages=ollama_messages,
tools=ollama_tools,
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
keep_alive=self.keep_alive,
options=generation_kwargs,
format=self.response_format,
think=self.think,
response = self._chat(
messages=ollama_messages, tools=ollama_tools, is_stream=is_stream, generation_kwargs=generation_kwargs
)

if isinstance(response, Iterator):
Expand Down Expand Up @@ -579,15 +657,8 @@ async def run_async(

ollama_messages = [_convert_chatmessage_to_ollama_format(m) for m in messages]

response = await self._async_client.chat(
model=self.model,
messages=ollama_messages,
tools=ollama_tools,
stream=is_stream, # type: ignore[call-overload] # Ollama expects Literal[True] or Literal[False], not bool
keep_alive=self.keep_alive,
options=generation_kwargs,
format=self.response_format,
think=self.think,
response = await self._chat_async(
messages=ollama_messages, tools=ollama_tools, is_stream=is_stream, generation_kwargs=generation_kwargs
)

if isinstance(response, AsyncIterator):
Expand Down
99 changes: 98 additions & 1 deletion integrations/ollama/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Annotated
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch

import pytest
from haystack.components.generators.utils import print_streaming_chunk
Expand Down Expand Up @@ -518,6 +518,7 @@ def test_init_default(self):
assert component.url == "http://localhost:11434"
assert component.generation_kwargs == {}
assert component.timeout == 120
assert component.max_retries == 0
assert component.streaming_callback is None
assert component.tools is None
assert component.keep_alive is None
Expand All @@ -529,6 +530,7 @@ def test_init(self, tools):
url="http://my-custom-endpoint:11434",
generation_kwargs={"temperature": 0.5},
timeout=5,
max_retries=2,
keep_alive="10m",
streaming_callback=print_streaming_chunk,
tools=tools,
Expand All @@ -539,6 +541,7 @@ def test_init(self, tools):
assert component.url == "http://my-custom-endpoint:11434"
assert component.generation_kwargs == {"temperature": 0.5}
assert component.timeout == 5
assert component.max_retries == 2
assert component.keep_alive == "10m"
assert component.streaming_callback is print_streaming_chunk
assert component.tools == tools
Expand Down Expand Up @@ -603,6 +606,7 @@ def test_to_dict(self):
"type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
"init_parameters": {
"timeout": 120,
"max_retries": 0,
"model": "llama2",
"url": "custom_url",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
Expand Down Expand Up @@ -650,6 +654,7 @@ def test_from_dict(self):
"type": "haystack_integrations.components.generators.ollama.chat.chat_generator.OllamaChatGenerator",
"init_parameters": {
"timeout": 120,
"max_retries": 0,
"model": "llama2",
"url": "custom_url",
"keep_alive": "5m",
Expand Down Expand Up @@ -689,6 +694,7 @@ def test_from_dict(self):
"some_test_param": "test-params",
}
assert component.timeout == 120
assert component.max_retries == 0
assert component.tools == [tool]
assert component.response_format == {
"type": "object",
Expand Down Expand Up @@ -790,6 +796,97 @@ def test_run(self, mock_client):
assert result["replies"][0].text == "Fine. How can I help you today?"
assert result["replies"][0].role == "assistant"

@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
def test_run_retries_after_failure(self, mock_client):
    """A retryable 500 on the first attempt is retried and the second attempt succeeds."""
    component = OllamaChatGenerator(max_retries=1)

    success = ChatResponse(
        model="qwen3:0.6b",
        created_at="2023-12-12T14:13:43.416799Z",
        message={"role": "assistant", "content": "Recovered after retry"},
        done=True,
        prompt_eval_count=1,
        eval_count=2,
    )
    transient_error = ResponseError("temporary failure", status_code=500)

    chat_mock = mock_client.return_value.chat
    chat_mock.side_effect = [transient_error, success]

    result = component.run(messages=[ChatMessage.from_user("Hello!")])

    assert chat_mock.call_count == 2
    assert result["replies"][0].text == "Recovered after retry"

@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
def test_run_raises_after_retry_exhausted(self, mock_client):
    """A persistent retryable error is re-raised once max_retries is exhausted."""
    component = OllamaChatGenerator(max_retries=1)
    chat_mock = mock_client.return_value.chat
    chat_mock.side_effect = ResponseError("persistent failure", status_code=503)

    with pytest.raises(ResponseError, match="persistent failure"):
        component.run(messages=[ChatMessage.from_user("Hello!")])

    # one initial attempt plus one retry
    assert chat_mock.call_count == 2

@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
def test_run_does_not_retry_non_retryable_error(self, mock_client):
    """A 4xx client error is raised immediately, without any retry."""
    component = OllamaChatGenerator(max_retries=2)
    chat_mock = mock_client.return_value.chat
    chat_mock.side_effect = ResponseError("bad request", status_code=400)

    with pytest.raises(ResponseError, match="bad request"):
        component.run(messages=[ChatMessage.from_user("Hello!")])

    # only the initial attempt was made
    assert chat_mock.call_count == 1

@pytest.mark.asyncio
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.AsyncClient")
async def test_run_async_does_not_retry_non_retryable_error(self, mock_async_client):
    """A 4xx client error is raised immediately, without any retry (async path)."""
    component = OllamaChatGenerator(max_retries=2)
    chat_mock = AsyncMock(side_effect=ResponseError("bad request", status_code=400))
    mock_async_client.return_value.chat = chat_mock

    with pytest.raises(ResponseError, match="bad request"):
        await component.run_async(messages=[ChatMessage.from_user("Hello!")])

    # only the initial attempt was made
    assert chat_mock.call_count == 1

@pytest.mark.asyncio
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.AsyncClient")
async def test_run_async_retries_after_failure(self, mock_async_client):
    """A retryable 500 on the first attempt is retried and the second attempt succeeds (async path)."""
    component = OllamaChatGenerator(max_retries=1)

    success = ChatResponse(
        model="qwen3:0.6b",
        created_at="2023-12-12T14:13:43.416799Z",
        message={"role": "assistant", "content": "Recovered after retry"},
        done=True,
        prompt_eval_count=1,
        eval_count=2,
    )
    chat_mock = AsyncMock(side_effect=[ResponseError("temporary failure", status_code=500), success])
    mock_async_client.return_value.chat = chat_mock

    result = await component.run_async(messages=[ChatMessage.from_user("Hello!")])

    assert chat_mock.call_count == 2
    assert result["replies"][0].text == "Recovered after retry"

@pytest.mark.asyncio
@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.AsyncClient")
async def test_run_async_raises_after_retry_exhausted(self, mock_async_client):
    """A persistent retryable error is re-raised once max_retries is exhausted (async path)."""
    component = OllamaChatGenerator(max_retries=1)
    chat_mock = AsyncMock(side_effect=ResponseError("persistent failure", status_code=503))
    mock_async_client.return_value.chat = chat_mock

    with pytest.raises(ResponseError, match="persistent failure"):
        await component.run_async(messages=[ChatMessage.from_user("Hello!")])

    # one initial attempt plus one retry
    assert chat_mock.call_count == 2

@patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
def test_run_streaming(self, mock_client):
collected_chunks = []
Expand Down