
Commit c377553

Merge pull request #138 from patrickfleith/feat/add-llm-timeout
Adding timeout parameter support to LLM providers
2 parents f9b5366 + 7559efe commit c377553
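For context, the new parameter is passed at provider construction time. A minimal usage sketch (class, model id, and keyword names are taken from the diff and tests below; the import path is an assumption based on the file name datafast/llms.py):

from datafast.llms import OllamaProvider  # import path assumed from datafast/llms.py

# Request timeout in seconds; stored on the base provider and forwarded
# with every completion request.
provider = OllamaProvider(
    model_id="gemma3:4b",
    temperature=0.7,
    max_completion_tokens=50,
    timeout=30,
)
response = provider.generate(prompt="What is the capital of France? Answer in one word.")
print(response)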

File tree

2 files changed: +68 -0 lines changed

datafast/llms.py

Lines changed: 20 additions & 0 deletions
@@ -40,6 +40,7 @@ def __init__(
         top_p: float | None = None,
         frequency_penalty: float | None = None,
         rpm_limit: int | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the LLM provider with common parameters.

@@ -64,6 +65,9 @@ def __init__(
         self.rpm_limit = rpm_limit
         self._request_timestamps: list[float] = []

+        # timeout
+        self.timeout = timeout
+
         # Configure environment with API key if needed
         self._configure_env()

@@ -249,6 +253,7 @@ def generate(
             "max_tokens": self.max_completion_tokens,
             "top_p": self.top_p,
             "frequency_penalty": self.frequency_penalty,
+            "timeout": self.timeout,
         }
         if response_format is not None:
             completion_params["response_format"] = response_format

@@ -336,6 +341,7 @@ def __init__(
         temperature: float | None = None,
         top_p: float | None = None,
         frequency_penalty: float | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the OpenAI provider.

@@ -347,6 +353,7 @@ def __init__(
             temperature: DEPRECATED - Not supported by responses endpoint
             top_p: DEPRECATED - Not supported by responses endpoint
             frequency_penalty: DEPRECATED - Not supported by responses endpoint
+            timeout: Request timeout in seconds
         """
         # Warn about deprecated parameters
         if temperature is not None:

@@ -379,6 +386,7 @@ def __init__(
             max_completion_tokens=max_completion_tokens,
             top_p=None,
             frequency_penalty=None,
+            timeout=timeout,
         )

     def generate(

@@ -550,6 +558,7 @@ def __init__(
         api_key: str | None = None,
         temperature: float | None = None,
         max_completion_tokens: int | None = None,
+        timeout: int | None = None,
         # top_p: float | None = None, # Not properly supported by anthropic models 4.5
         # frequency_penalty: float | None = None, # Not supported by anthropic models 4.5
     ):

@@ -560,13 +569,15 @@ def __init__(
             api_key: API key (if None, will get from environment)
             temperature: Temperature for generation (0.0 to 1.0)
             max_completion_tokens: Maximum tokens to generate
+            timeout: Request timeout in seconds
             top_p: Nucleus sampling parameter (0.0 to 1.0)
         """
         super().__init__(
             model_id=model_id,
             api_key=api_key,
             temperature=temperature,
             max_completion_tokens=max_completion_tokens,
+            timeout=timeout,
         )

@@ -590,6 +601,7 @@ def __init__(
         top_p: float | None = None,
         frequency_penalty: float | None = None,
         rpm_limit: int | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the Gemini provider.

@@ -600,6 +612,7 @@ def __init__(
             max_completion_tokens: Maximum tokens to generate
             top_p: Nucleus sampling parameter (0.0 to 1.0)
             frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
+            timeout: Request timeout in seconds
         """
         super().__init__(
             model_id=model_id,

@@ -609,6 +622,7 @@ def __init__(
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             rpm_limit=rpm_limit,
+            timeout=timeout,
         )

@@ -643,6 +657,7 @@ def __init__(
         frequency_penalty: float | None = None,
         api_base: str | None = None,
         rpm_limit: int | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the Ollama provider.

@@ -653,6 +668,7 @@ def __init__(
             top_p: Nucleus sampling parameter (0.0 to 1.0)
             frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
             api_base: Base URL for Ollama API (e.g., "http://localhost:11434")
+            timeout: Request timeout in seconds
         """
         # Set API base URL if provided
         if api_base:

@@ -666,6 +682,7 @@ def __init__(
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             rpm_limit=rpm_limit,
+            timeout=timeout,
         )

@@ -688,6 +705,7 @@ def __init__(
         max_completion_tokens: int | None = None,
         top_p: float | None = None,
         frequency_penalty: float | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the OpenRouter provider.

@@ -698,6 +716,7 @@ def __init__(
             max_completion_tokens: Maximum tokens to generate
             top_p: Nucleus sampling parameter (0.0 to 1.0)
             frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
+            timeout: Request timeout in seconds
         """
         super().__init__(
             model_id = model_id,

@@ -706,4 +725,5 @@ def __init__(
             max_completion_tokens = max_completion_tokens,
             top_p = top_p,
             frequency_penalty = frequency_penalty,
+            timeout = timeout,
         )
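The llms.py changes follow a single pattern: the base class stores the timeout once, every subclass forwards it through super().__init__(), and generate() places it in completion_params alongside the other per-request options. A self-contained sketch of that pattern (the class names below are illustrative stand-ins, not the datafast classes):

class BaseProvider:
    """Illustrative stand-in for the shared LLM provider base class."""

    def __init__(self, model_id: str, timeout: int | None = None):
        self.model_id = model_id
        self.timeout = timeout  # stored once on the base class

    def completion_params(self) -> dict:
        # The timeout rides along with the other per-request parameters.
        return {"model": self.model_id, "timeout": self.timeout}


class DemoProvider(BaseProvider):
    """Illustrative subclass: forwards timeout exactly like the real providers."""

    def __init__(self, model_id: str, timeout: int | None = None):
        super().__init__(model_id=model_id, timeout=timeout)


provider = DemoProvider(model_id="gemma3:4b", timeout=30)
print(provider.completion_params())  # {'model': 'gemma3:4b', 'timeout': 30}

Because only the base class touches self.timeout, providers that deprecate other knobs (as the OpenAI subclass does with temperature, top_p, and frequency_penalty) still accept and forward the timeout uniformly.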

tests/test_ollama.py

Lines changed: 48 additions & 0 deletions
@@ -91,6 +91,54 @@ def test_ollama_with_all_parameters():
     assert "Paris" in response


+@pytest.mark.integration
+def test_ollama_timeout():
+    """Test that the timeout parameter works correctly with Ollama provider."""
+    # Create provider with a very short timeout (1 second)
+    provider = OllamaProvider(
+        model_id="gemma3:4b",
+        temperature=0.7,
+        max_completion_tokens=500,
+        timeout=1  # 1 second - should time out
+    )
+
+    # Try to generate a response that would normally take longer
+    prompt = """Write a detailed essay about the history of artificial intelligence,
+    covering major milestones from the 1950s to present day. Include information about
+    key researchers, breakthrough algorithms, and the evolution of neural networks."""
+
+    # Expect a RuntimeError (which wraps the timeout/connection error)
+    with pytest.raises(RuntimeError) as exc_info:
+        provider.generate(prompt=prompt)
+
+    # Verify that the error is related to timeout or connection issues
+    error_message = str(exc_info.value).lower()
+    # The error could mention timeout, connection errors, or API errors
+    assert any(keyword in error_message for keyword in [
+        "timeout", "timed out", "time out", "deadline",
+        "apiconnectionerror", "connection", "api"
+    ]), f"Expected timeout/connection error, but got: {exc_info.value}"
+
+
+@pytest.mark.integration
+def test_ollama_with_reasonable_timeout():
+    """Test that a reasonable timeout allows successful completion."""
+    # Create provider with a reasonable timeout (30 seconds)
+    provider = OllamaProvider(
+        model_id="gemma3:4b",
+        temperature=0.7,
+        max_completion_tokens=50,
+        timeout=30  # 30 seconds - should be enough
+    )
+
+    # Generate a simple response
+    prompt = "What is the capital of France? Answer in one word."
+    response = provider.generate(prompt=prompt)
+
+    # Should complete successfully
+    assert "Paris" in response
+
+
 @pytest.mark.integration
 def test_ollama_structured_landmark_info():
     """Test Ollama with a structured landmark info response."""
