diff --git a/datafast/llms.py b/datafast/llms.py
index 7d0bcee..72f309e 100644
--- a/datafast/llms.py
+++ b/datafast/llms.py
@@ -40,6 +40,7 @@ def __init__(
         top_p: float | None = None,
         frequency_penalty: float | None = None,
         rpm_limit: int | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the LLM provider with common parameters.
@@ -64,6 +65,9 @@ def __init__(
         self.rpm_limit = rpm_limit
         self._request_timestamps: list[float] = []
 
+        # timeout
+        self.timeout = timeout
+
         # Configure environment with API key if needed
         self._configure_env()
@@ -249,6 +253,7 @@ def generate(
             "max_tokens": self.max_completion_tokens,
             "top_p": self.top_p,
             "frequency_penalty": self.frequency_penalty,
+            "timeout": self.timeout,
         }
         if response_format is not None:
             completion_params["response_format"] = response_format
@@ -336,6 +341,7 @@ def __init__(
         temperature: float | None = None,
         top_p: float | None = None,
         frequency_penalty: float | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the OpenAI provider.
@@ -347,6 +353,7 @@ def __init__(
             temperature: DEPRECATED - Not supported by responses endpoint
             top_p: DEPRECATED - Not supported by responses endpoint
             frequency_penalty: DEPRECATED - Not supported by responses endpoint
+            timeout: Request timeout in seconds
         """
         # Warn about deprecated parameters
         if temperature is not None:
@@ -379,6 +386,7 @@ def __init__(
             max_completion_tokens=max_completion_tokens,
             top_p=None,
             frequency_penalty=None,
+            timeout=timeout,
         )
 
     def generate(
@@ -550,6 +558,7 @@ def __init__(
         api_key: str | None = None,
         temperature: float | None = None,
         max_completion_tokens: int | None = None,
+        timeout: int | None = None,
         # top_p: float | None = None,  # Not properly supported by anthropic models 4.5
         # frequency_penalty: float | None = None,  # Not supported by anthropic models 4.5
     ):
@@ -560,6 +569,7 @@ def __init__(
             api_key: API key (if None, will get from environment)
             temperature: Temperature for generation (0.0 to 1.0)
             max_completion_tokens: Maximum tokens to generate
+            timeout: Request timeout in seconds
             top_p: Nucleus sampling parameter (0.0 to 1.0)
         """
         super().__init__(
@@ -567,6 +577,7 @@ def __init__(
             api_key=api_key,
             temperature=temperature,
             max_completion_tokens=max_completion_tokens,
+            timeout=timeout,
         )
@@ -590,6 +601,7 @@ def __init__(
         top_p: float | None = None,
         frequency_penalty: float | None = None,
         rpm_limit: int | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the Gemini provider.
@@ -600,6 +612,7 @@ def __init__(
             max_completion_tokens: Maximum tokens to generate
             top_p: Nucleus sampling parameter (0.0 to 1.0)
             frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
+            timeout: Request timeout in seconds
         """
         super().__init__(
             model_id=model_id,
@@ -609,6 +622,7 @@ def __init__(
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             rpm_limit=rpm_limit,
+            timeout=timeout,
         )
@@ -643,6 +657,7 @@ def __init__(
         frequency_penalty: float | None = None,
         api_base: str | None = None,
         rpm_limit: int | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the Ollama provider.
@@ -653,6 +668,7 @@ def __init__(
             top_p: Nucleus sampling parameter (0.0 to 1.0)
             frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
             api_base: Base URL for Ollama API (e.g., "http://localhost:11434")
+            timeout: Request timeout in seconds
         """
         # Set API base URL if provided
         if api_base:
@@ -666,6 +682,7 @@ def __init__(
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             rpm_limit=rpm_limit,
+            timeout=timeout,
         )
@@ -688,6 +705,7 @@ def __init__(
         max_completion_tokens: int | None = None,
         top_p: float | None = None,
         frequency_penalty: float | None = None,
+        timeout: int | None = None,
     ):
         """Initialize the OpenRouter provider.
@@ -698,6 +716,7 @@ def __init__(
             max_completion_tokens: Maximum tokens to generate
             top_p: Nucleus sampling parameter (0.0 to 1.0)
             frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
+            timeout: Request timeout in seconds
         """
         super().__init__(
             model_id = model_id,
@@ -706,4 +725,5 @@ def __init__(
             max_completion_tokens = max_completion_tokens,
             top_p = top_p,
             frequency_penalty = frequency_penalty,
+            timeout = timeout,
         )
\ No newline at end of file
diff --git a/tests/test_ollama.py b/tests/test_ollama.py
index ecf26af..cfae548 100644
--- a/tests/test_ollama.py
+++ b/tests/test_ollama.py
@@ -91,6 +91,54 @@ def test_ollama_with_all_parameters():
     assert "Paris" in response
 
 
+@pytest.mark.integration
+def test_ollama_timeout():
+    """Test that the timeout parameter works correctly with Ollama provider."""
+    # Create provider with a very short timeout (1 second)
+    provider = OllamaProvider(
+        model_id="gemma3:4b",
+        temperature=0.7,
+        max_completion_tokens=500,
+        timeout=1  # 1 second - should timeout
+    )
+
+    # Try to generate a response that would normally take longer
+    prompt = """Write a detailed essay about the history of artificial intelligence,
+    covering major milestones from the 1950s to present day. Include information about
+    key researchers, breakthrough algorithms, and the evolution of neural networks."""
+
+    # Expect a RuntimeError (which wraps the timeout/connection error)
+    with pytest.raises(RuntimeError) as exc_info:
+        provider.generate(prompt=prompt)
+
+    # Verify that the error is related to timeout or connection issues
+    error_message = str(exc_info.value).lower()
+    # The error could mention timeout, connection errors, or API errors
+    assert any(keyword in error_message for keyword in [
+        "timeout", "timed out", "time out", "deadline",
+        "apiconnectionerror", "connection", "api"
+    ]), f"Expected timeout/connection error, but got: {exc_info.value}"
+
+
+@pytest.mark.integration
+def test_ollama_with_reasonable_timeout():
+    """Test that a reasonable timeout allows successful completion."""
+    # Create provider with a reasonable timeout (30 seconds)
+    provider = OllamaProvider(
+        model_id="gemma3:4b",
+        temperature=0.7,
+        max_completion_tokens=50,
+        timeout=30  # 30 seconds - should be enough
+    )
+
+    # Generate a simple response
+    prompt = "What is the capital of France? Answer in one word."
+    response = provider.generate(prompt=prompt)
+
+    # Should complete successfully
+    assert "Paris" in response
+
+
 @pytest.mark.integration
 def test_ollama_structured_landmark_info():
     """Test Ollama with a structured landmark info response."""
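For reference, a minimal usage sketch of the new `timeout` parameter, mirroring what the integration tests above exercise. The model ID and timeout value are illustrative, and the error-wrapping behaviour described in the comments is the one asserted in `test_ollama_timeout`; this is a sketch, not a definitive example from the patch.

```python
from datafast.llms import OllamaProvider

# Cap the request at 30 seconds. If the limit is exceeded, the provider is
# expected to surface the underlying timeout/connection error as a
# RuntimeError, as exercised by test_ollama_timeout above.
provider = OllamaProvider(
    model_id="gemma3:4b",       # local model, same as in the tests
    max_completion_tokens=50,
    timeout=30,                 # request timeout in seconds
)

print(provider.generate(prompt="What is the capital of France? Answer in one word."))
```

The same keyword is accepted by the other providers touched in this diff (OpenAI, Anthropic, Gemini, OpenRouter), since the value is stored on the shared base class and forwarded via `completion_params` in `generate`.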