20 changes: 20 additions & 0 deletions datafast/llms.py
@@ -40,6 +40,7 @@ def __init__(
top_p: float | None = None,
frequency_penalty: float | None = None,
rpm_limit: int | None = None,
timeout: int | None = None,
):
"""Initialize the LLM provider with common parameters.

@@ -64,6 +65,9 @@ def __init__(
self.rpm_limit = rpm_limit
self._request_timestamps: list[float] = []

# Request timeout in seconds (forwarded to the completion call)
self.timeout = timeout

# Configure environment with API key if needed
self._configure_env()

@@ -249,6 +253,7 @@ def generate(
"max_tokens": self.max_completion_tokens,
"top_p": self.top_p,
"frequency_penalty": self.frequency_penalty,
"timeout": self.timeout,
}
if response_format is not None:
completion_params["response_format"] = response_format
@@ -336,6 +341,7 @@ def __init__(
temperature: float | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
timeout: int | None = None,
):
"""Initialize the OpenAI provider.

@@ -347,6 +353,7 @@
temperature: DEPRECATED - Not supported by responses endpoint
top_p: DEPRECATED - Not supported by responses endpoint
frequency_penalty: DEPRECATED - Not supported by responses endpoint
timeout: Request timeout in seconds
"""
# Warn about deprecated parameters
if temperature is not None:
@@ -379,6 +386,7 @@ def __init__(
max_completion_tokens=max_completion_tokens,
top_p=None,
frequency_penalty=None,
timeout=timeout,
)

def generate(
@@ -550,6 +558,7 @@ def __init__(
api_key: str | None = None,
temperature: float | None = None,
max_completion_tokens: int | None = None,
timeout: int | None = None,
# top_p: float | None = None, # Not properly supported by anthropic models 4.5
# frequency_penalty: float | None = None, # Not supported by anthropic models 4.5
):
@@ -560,13 +569,15 @@ def __init__(
api_key: API key (if None, will get from environment)
temperature: Temperature for generation (0.0 to 1.0)
max_completion_tokens: Maximum tokens to generate
timeout: Request timeout in seconds
"""
super().__init__(
model_id=model_id,
api_key=api_key,
temperature=temperature,
max_completion_tokens=max_completion_tokens,
timeout=timeout,
)


@@ -590,6 +601,7 @@ def __init__(
top_p: float | None = None,
frequency_penalty: float | None = None,
rpm_limit: int | None = None,
timeout: int | None = None,
):
"""Initialize the Gemini provider.

@@ -600,6 +612,7 @@ def __init__(
max_completion_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter (0.0 to 1.0)
frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
timeout: Request timeout in seconds
"""
super().__init__(
model_id=model_id,
@@ -609,6 +622,7 @@ def __init__(
top_p=top_p,
frequency_penalty=frequency_penalty,
rpm_limit=rpm_limit,
timeout=timeout,
)


@@ -643,6 +657,7 @@ def __init__(
frequency_penalty: float | None = None,
api_base: str | None = None,
rpm_limit: int | None = None,
timeout: int | None = None,
):
"""Initialize the Ollama provider.

@@ -653,6 +668,7 @@ def __init__(
top_p: Nucleus sampling parameter (0.0 to 1.0)
frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
api_base: Base URL for Ollama API (e.g., "http://localhost:11434")
timeout: Request timeout in seconds
"""
# Set API base URL if provided
if api_base:
@@ -666,6 +682,7 @@ def __init__(
top_p=top_p,
frequency_penalty=frequency_penalty,
rpm_limit=rpm_limit,
timeout=timeout,
)


@@ -688,6 +705,7 @@ def __init__(
max_completion_tokens: int | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
timeout: int | None = None,
):
"""Initialize the OpenRouter provider.

@@ -698,6 +716,7 @@ def __init__(
max_completion_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter (0.0 to 1.0)
frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
timeout: Request timeout in seconds
"""
super().__init__(
model_id=model_id,
@@ -706,4 +725,5 @@ def __init__(
max_completion_tokens=max_completion_tokens,
top_p=top_p,
frequency_penalty=frequency_penalty,
timeout=timeout,
)
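A minimal usage sketch of the new timeout parameter, assuming the providers are importable from datafast.llms; the gemma3:4b model id and the 30-second value mirror the integration tests added below and are illustrative only:

from datafast.llms import OllamaProvider  # assumed import path for the providers defined above

# Cap each generation request at 30 seconds; the base class forwards this value
# as "timeout" in the completion parameters, and failures surface as RuntimeError
# (as exercised by the timeout test below).
provider = OllamaProvider(
    model_id="gemma3:4b",
    temperature=0.7,
    max_completion_tokens=50,
    timeout=30,
)

response = provider.generate(prompt="What is the capital of France? Answer in one word.")
print(response)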
48 changes: 48 additions & 0 deletions tests/test_ollama.py
@@ -91,6 +91,54 @@ def test_ollama_with_all_parameters():
assert "Paris" in response


@pytest.mark.integration
def test_ollama_timeout():
"""Test that the timeout parameter works correctly with Ollama provider."""
# Create provider with a very short timeout (1 second)
provider = OllamaProvider(
model_id="gemma3:4b",
temperature=0.7,
max_completion_tokens=500,
timeout=1 # 1 second - should timeout
)

# Try to generate a response that would normally take longer
prompt = """Write a detailed essay about the history of artificial intelligence,
covering major milestones from the 1950s to present day. Include information about
key researchers, breakthrough algorithms, and the evolution of neural networks."""

# Expect a RuntimeError (which wraps the timeout/connection error)
with pytest.raises(RuntimeError) as exc_info:
provider.generate(prompt=prompt)

# Verify that the error is related to timeout or connection issues
error_message = str(exc_info.value).lower()
# The error could mention timeout, connection errors, or API errors
assert any(keyword in error_message for keyword in [
"timeout", "timed out", "time out", "deadline",
"apiconnectionerror", "connection", "api"
]), f"Expected timeout/connection error, but got: {exc_info.value}"


@pytest.mark.integration
def test_ollama_with_reasonable_timeout():
"""Test that a reasonable timeout allows successful completion."""
# Create provider with a reasonable timeout (30 seconds)
provider = OllamaProvider(
model_id="gemma3:4b",
temperature=0.7,
max_completion_tokens=50,
timeout=30 # 30 seconds - should be enough
)

# Generate a simple response
prompt = "What is the capital of France? Answer in one word."
response = provider.generate(prompt=prompt)

# Should complete successfully
assert "Paris" in response


@pytest.mark.integration
def test_ollama_structured_landmark_info():
"""Test Ollama with a structured landmark info response."""