diff --git a/datafast/examples/mcq_example.py b/datafast/examples/mcq_example.py
index dba76f8..0a12312 100644
--- a/datafast/examples/mcq_example.py
+++ b/datafast/examples/mcq_example.py
@@ -11,12 +11,12 @@ def main():
     # 1. Define the configuration
     config = MCQDatasetConfig(
-        # hf_dataset_name="patrickfleith/space_engineering_environment_effects_texts",
+        hf_dataset_name="patrickfleith/space_engineering_environment_effects_texts",
         # local_file_path="datafast/examples/data/mcq/sample.csv",
         # local_file_path="datafast/examples/data/mcq/sample.txt",
-        local_file_path="datafast/examples/data/mcq/sample.jsonl",
+        # local_file_path="datafast/examples/data/mcq/sample.jsonl",
         text_column="text",  # Column containing the text to generate questions from
-        sample_count=3,  # Process only 3 samples for testing
+        sample_count=2,  # Process only 2 samples for testing
         num_samples_per_prompt=2,# Generate 2 questions per document
         min_document_length=100,  # Skip documents shorter than 100 chars
         max_document_length=20000,# Skip documents longer than 20000 chars
@@ -32,7 +32,7 @@ def main():
 
     # 3. Generate the dataset
     dataset = MCQDataset(config)
-    num_expected_rows = dataset.get_num_expected_rows(providers, source_data_num_rows=3)
+    num_expected_rows = dataset.get_num_expected_rows(providers, source_data_num_rows=2)
     print(f"\nExpected number of rows: {num_expected_rows}")
     dataset.generate(providers)
@@ -55,5 +55,5 @@ def main():
 
 if __name__ == "__main__":
     from dotenv import load_dotenv
-    load_dotenv("secrets.env")
+    load_dotenv()
     main()
diff --git a/datafast/llms.py b/datafast/llms.py
index dc4e46b..6220c05 100644
--- a/datafast/llms.py
+++ b/datafast/llms.py
@@ -9,6 +9,7 @@
 import os
 import time
 import traceback
+import warnings
 
 # Pydantic
 from pydantic import BaseModel
@@ -249,9 +250,11 @@ def generate(
             response: list[ModelResponse] = litellm.batch_completion(
                 **completion_params)
 
-            # Record timestamp for rate limiting
+            # Record timestamp for rate limiting (one timestamp per batch item)
             if self.rpm_limit is not None:
-                self._request_timestamps.append(time.monotonic())
+                current_time = time.monotonic()
+                for _ in range(len(batch_to_send)):
+                    self._request_timestamps.append(current_time)
 
             # Extract content from each response
             results = []
@@ -280,7 +283,15 @@ def generate(
 
 
 class OpenAIProvider(LLMProvider):
-    """OpenAI provider using litellm."""
+    """OpenAI provider using the litellm.responses endpoint.
+
+    Note: This provider uses the new responses endpoint, which has different
+    parameter support compared to the standard completion endpoint:
+    - temperature, top_p, and frequency_penalty are not supported
+    - Uses text_format instead of response_format
+    - Supports a reasoning parameter for controlling reasoning effort
+    - Does not support batch operations (requests are processed sequentially with a warning)
+    """
 
     @property
     def provider_name(self) -> str:
@@ -294,29 +305,187 @@ def __init__(
         self,
         model_id: str = "gpt-5-mini-2025-08-07",
         api_key: str | None = None,
-        temperature: float | None = None,
         max_completion_tokens: int | None = None,
+        reasoning_effort: str = "low",
+        temperature: float | None = None,
         top_p: float | None = None,
         frequency_penalty: float | None = None,
     ):
         """Initialize the OpenAI provider.
 
         Args:
-            model_id: The model ID (defaults to gpt-5-mini-2025-08-07)
+            model_id: The model ID (defaults to gpt-5-mini)
             api_key: API key (if None, will get from environment)
-            temperature: The sampling temperature to be used, between 0 and 2.
Higher values like 0.8 produce more random outputs, while lower values like 0.2 make outputs more focused and deterministic max_completion_tokens: An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens. - top_p: Nucleus sampling parameter (0.0 to 1.0) - frequency_penalty: Penalty for token frequency (-2.0 to 2.0) + reasoning_effort: Reasoning effort level - "low", "medium", or "high" (defaults to "low") + temperature: DEPRECATED - Not supported by responses endpoint + top_p: DEPRECATED - Not supported by responses endpoint + frequency_penalty: DEPRECATED - Not supported by responses endpoint """ + # Warn about deprecated parameters + if temperature is not None: + warnings.warn( + "temperature parameter is not supported by OpenAI responses endpoint and will be ignored", + UserWarning, + stacklevel=2 + ) + if top_p is not None: + warnings.warn( + "top_p parameter is not supported by OpenAI responses endpoint and will be ignored", + UserWarning, + stacklevel=2 + ) + if frequency_penalty is not None: + warnings.warn( + "frequency_penalty parameter is not supported by OpenAI responses endpoint and will be ignored", + UserWarning, + stacklevel=2 + ) + + # Store reasoning effort + self.reasoning_effort = reasoning_effort + + # Call parent init with None for unsupported params super().__init__( model_id=model_id, api_key=api_key, - temperature=temperature, + temperature=None, max_completion_tokens=max_completion_tokens, - top_p=top_p, - frequency_penalty=frequency_penalty, + top_p=None, + frequency_penalty=None, ) + + def generate( + self, + prompt: str | list[str] | None = None, + messages: list[Messages] | Messages | None = None, + response_format: Type[T] | None = None, + ) -> str | list[str] | T | list[T]: + """ + Generate responses from the LLM using the responses endpoint. + + Note: Batch operations are processed sequentially as the responses endpoint + does not support native batching. + + Args: + prompt: Single text prompt (str) or list of text prompts for batch processing + messages: Single message list or list of message lists for batch processing + response_format: Optional Pydantic model class for structured output + + Returns: + Single string/model or list of strings/models depending on input type. + + Raises: + ValueError: If neither prompt nor messages is provided, or if both are provided. + RuntimeError: If there's an error during generation. 
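+
+        Example (illustrative sketch; assumes OPENAI_API_KEY is set in the environment):
+            provider = OpenAIProvider(reasoning_effort="low", max_completion_tokens=1000)
+            answer = provider.generate(prompt="What is the capital of France? Answer in one word.")
+            # A single prompt returns a single str; a list of prompts returns a list of str.
+            # Passing response_format=<a Pydantic model class> returns parsed model instances instead.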
+ """ + # Validate inputs + if prompt is None and messages is None: + raise ValueError("Either prompts or messages must be provided") + if prompt is not None and messages is not None: + raise ValueError("Provide either prompts or messages, not both") + + # Determine if this is a single input or batch input + single_input = False + batch_prompts = None + batch_messages = None + + if prompt is not None: + if isinstance(prompt, str): + # Single prompt - convert to batch + batch_prompts = [prompt] + single_input = True + elif isinstance(prompt, list): + # Already a list of prompts + batch_prompts = prompt + single_input = False + else: + raise ValueError("prompt must be a string or list of strings") + + if messages is not None: + if isinstance(messages, list) and len(messages) > 0: + # Check if it's a single message list or batch + if isinstance(messages[0], dict): + # Single message list - convert to batch + batch_messages = [messages] + single_input = True + elif isinstance(messages[0], list): + # Already a batch of message lists + batch_messages = messages + single_input = False + else: + raise ValueError("Invalid messages format") + else: + raise ValueError("messages cannot be empty") + + try: + # Convert batch prompts to messages if needed + batch_to_send = [] + if batch_prompts is not None: + for one_prompt in batch_prompts: + batch_to_send.append([{"role": "user", "content": one_prompt}]) + else: + batch_to_send = batch_messages + + # Warn if batch processing is being used + if len(batch_to_send) > 1: + warnings.warn( + f"OpenAI responses endpoint does not support batch operations. " + f"Processing {len(batch_to_send)} requests sequentially.", + UserWarning, + stacklevel=2 + ) + + # Process each request sequentially + results = [] + for message_list in batch_to_send: + # Enforce rate limit per request + self._respect_rate_limit() + + # Prepare completion parameters + completion_params = { + "model": self._get_model_string(), + "input": message_list, + "reasoning": {"effort": self.reasoning_effort}, + } + + # Add max_output_tokens if specified + if self.max_completion_tokens is not None: + completion_params["max_output_tokens"] = self.max_completion_tokens + + # Add text_format if response_format is provided + if response_format is not None: + completion_params["text_format"] = response_format + + # Call LiteLLM responses endpoint + response = litellm.responses(**completion_params) + + # Record timestamp for rate limiting + if self.rpm_limit is not None: + self._request_timestamps.append(time.monotonic()) + + # Extract content from response + # Response structure: response.output[1].content[0].text + content = response.output[1].content[0].text + + if response_format is not None: + # Strip code fences before validation + content = self._strip_code_fences(content) + results.append(response_format.model_validate_json(content)) + else: + # Strip leading/trailing whitespace for text responses + results.append(content.strip() if content else content) + + # Return single result for backward compatibility + if single_input and len(results) == 1: + return results[0] + return results + + except Exception as e: + error_trace = traceback.format_exc() + raise RuntimeError( + f"Error generating response with {self.provider_name}:\n{error_trace}" + ) class AnthropicProvider(LLMProvider): diff --git a/docs/index.md b/docs/index.md index ef53fc7..c11ac6d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -95,7 +95,7 @@ config = ClassificationDatasetConfig( providers = [ 
OpenAIProvider(model_id="gpt-5-mini-2025-08-07"), AnthropicProvider(model_id="claude-haiku-4-5-20251001"), - GeminiProvider(model_id="gemini-2.5-flash"), + GeminiProvider(model_id="gemini-2.0-flash"), OpenRouterProvider(model_id="z-ai/glm-4.6") ] ``` diff --git a/docs/llms.md b/docs/llms.md index e5d1939..f07534a 100644 --- a/docs/llms.md +++ b/docs/llms.md @@ -33,7 +33,7 @@ gemini_llm = GeminiProvider() # Ollama (default: gemma3:4b) ollama_llm = OllamaProvider() -# OpenRouter (default: openai/gpt-4.1-mini) +# OpenRouter (default: openai/gpt-5-mini) openrouter_llm = OpenRouterProvider() ``` @@ -42,12 +42,38 @@ openrouter_llm = OpenRouterProvider() ```python openai_llm = OpenAIProvider( model_id="gpt-5-mini-2025-08-07", # Custom model - temperature=0.2, # Lower temperature for more deterministic outputs - max_completion_tokens=100, # Limit token generation - top_p=0.9, # Nucleus sampling parameter - frequency_penalty=0.1 # Penalty for frequent tokens + max_completion_tokens=1000, # Limit token generation (don't set this too low for reasoning models) + reasoning_effort="medium" # Reasoning effort: "low", "medium", or "high" ) +``` + +!!! warning "OpenAI Provider Changes" + `OpenAIProvider` now uses the `responses` endpoint. The following parameters are **deprecated** and will trigger warnings: + - `temperature` + - `top_p` + - `frequency_penalty` + + Use `reasoning_effort` ("low", "medium", "high") instead to control generation behavior. +```python +# Anthropic with custom parameters +anthropic_llm = AnthropicProvider( + model_id="claude-haiku-4-5-20251001", + temperature=0.7, + max_completion_tokens=1000 +) +``` + +!!! warning "Anthropic Provider Limitations" + `AnthropicProvider` only supports the following parameters: + - `temperature` (0.0 to 1.0) + - `max_completion_tokens` + + The following parameters are **not supported** by Anthropic Claude 4.5 models: + - `top_p` + - `frequency_penalty` + +```python # Ollama with custom API endpoint ollama_llm = OllamaProvider( model_id="llama3.2:latest", diff --git a/docs/models.md b/docs/models.md new file mode 100644 index 0000000..eded220 --- /dev/null +++ b/docs/models.md @@ -0,0 +1,114 @@ +# Models + +Datafast supports multiple LLM providers through a unified interface. Since model evolve fast, it is not uncommon for things to break. +Please find below a list of my favoriate models to use in `datafast` for each LLMProvider which provide a balance of cost, performance and stability. + +See [LLM Providers](llms.md) for more details about supported arguments for each provider. + +## Recommended Models by Provider + +### OpenAI + +**Default**: `gpt-5-mini-2025-08-07` + +**Recommended Models**: + +- **gpt-5-2025-08-07** - Most intelligent, capable, but also expensive. Only use for the most complex tasks. + *Pricing: \$1.25/\$10 per million I/O token* + +- **gpt-5-mini-2025-08-07** - Intelligent, capable, and affordable. + *Pricing: \$0.25/\$2 per million I/O token* + +- **gpt-5-nano-2025-08-07** - Tiny and cheap. Only use for simple tasks or testing. + *Pricing: \$0.05/\$0.4 per million I/O token* + +```python +from datafast.llms import OpenAIProvider + +# Using default model +llm = OpenAIProvider() + +# Using a specific model +llm = OpenAIProvider(model_id="gpt-5-2025-08-07") +``` + +### Anthropic + +**Default**: `claude-haiku-4-5-20251001` + +**Recommended Models**: + +- **claude-haiku-4-5-20251001** - Fast, efficient for most tasks. 
+  *Pricing: \$1/\$5 per million I/O tokens*
+
+- **claude-sonnet-4-5-20250929** - The most powerful model, but also the most expensive.
+  *Pricing: \$3/\$15 per million I/O tokens*
+
+```python
+from datafast.llms import AnthropicProvider
+
+# Using default model
+llm = AnthropicProvider()
+
+# Using a specific model
+llm = AnthropicProvider(model_id="claude-sonnet-4-5-20250929")
+```
+
+### Google Gemini
+
+**Recommended and default**: `gemini-2.5-flash-lite`
+
+```python
+from datafast.llms import GeminiProvider
+
+# Using default model
+llm = GeminiProvider()
+
+# Using a specific model
+llm = GeminiProvider(model_id="gemini-2.5-flash-lite")
+```
+
+### Ollama (Local Models)
+
+**Recommended**: `gemma3:27b-it-qat`
+
+Fast, capable, reliable, and does not take up too much vRAM.
+
+```python
+from datafast.llms import OllamaProvider
+
+# Using recommended model
+llm = OllamaProvider(model_id="gemma3:27b-it-qat")
+
+# Custom API endpoint
+llm = OllamaProvider(
+    model_id="gemma3:27b-it-qat",
+    api_base="http://localhost:11434"
+)
+```
+
+### OpenRouter
+
+There are many models available on OpenRouter, but here are some of our favorites:
+
+- **qwen/qwen3-next-80b-a3b-instruct** - High capability
+- **deepseek/deepseek-r1-0528** - Strong reasoning, cost-effective
+- **z-ai/glm-4.6** - Balanced performance
+- **meta-llama/llama-3.3-70b-instruct** - Versatile, open-source
+
+```python
+from datafast.llms import OpenRouterProvider
+
+# Using a specific model
+llm = OpenRouterProvider(model_id="deepseek/deepseek-r1-0528")
+
+# Another example
+llm = OpenRouterProvider(model_id="qwen/qwen3-next-80b-a3b-instruct")
+```
+
+!!! warning
+    Avoid using `gpt-oss:20b` or `gpt-oss:120b` as they do not work well with structured output.
+
+## More Details
+
+For comprehensive information about LLM providers, API keys, generation methods, and advanced usage, see the [LLM Providers](llms.md) page.
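+
+## Mixing Providers
+
+All providers expose the same `generate()` interface, so the recommended models above can be mixed in a single `providers` list and passed to a dataset's `generate()` method. A minimal sketch, reusing the model IDs recommended on this page:
+
+```python
+from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider
+
+providers = [
+    OpenAIProvider(model_id="gpt-5-mini-2025-08-07"),
+    AnthropicProvider(model_id="claude-haiku-4-5-20251001"),
+    GeminiProvider(model_id="gemini-2.5-flash-lite"),
+]
+```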
diff --git a/mkdocs.yml b/mkdocs.yml index 4a0def5..3b5912f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -59,6 +59,9 @@ markdown_extensions: # Navigation structure nav: - Home: index.md + - Models: + - Recommended Models: models.md + - LLM Providers: llms.md # - Getting Started: getting_started.md - Guides: - guides/index.md @@ -71,7 +74,6 @@ nav: - Concepts: - Core Concepts: concepts.md - Prompt Expansion: guides/prompt_expansion.md - - LLM Providers: llms.md - API: api.md # Plugins diff --git a/tests/test_anthropic.py b/tests/test_anthropic.py index ff219c0..c5d55dc 100644 --- a/tests/test_anthropic.py +++ b/tests/test_anthropic.py @@ -1,115 +1,17 @@ from datafast.llms import AnthropicProvider from dotenv import load_dotenv import pytest -from typing import List, Optional -from pydantic import BaseModel, Field, field_validator +from tests.test_schemas import ( + SimpleResponse, + LandmarkInfo, + PersonaContent, + QASet, + MCQSet, +) load_dotenv() -class SimpleResponse(BaseModel): - """Simple response model for testing structured output.""" - answer: str = Field(description="The answer to the question") - reasoning: str = Field(description="The reasoning behind the answer") - - -class Attribute(BaseModel): - """Attribute of a landmark with value and importance.""" - name: str = Field(description="Name of the attribute") - value: str = Field(description="Value of the attribute") - importance: float = Field(description="Importance score between 0 and 1") - - @field_validator('importance') - @classmethod - def check_importance(cls, v: float) -> float: - """Validate importance is between 0 and 1.""" - if not 0 <= v <= 1: - raise ValueError("Importance must be between 0 and 1") - return v - - -class LandmarkInfo(BaseModel): - """Information about a landmark with attributes.""" - name: str = Field(description="The name of the landmark") - location: str = Field(description="Where the landmark is located") - description: str = Field(description="A brief description of the landmark") - year_built: Optional[int] = Field( - None, description="Year when the landmark was built") - attributes: List[Attribute] = Field( - description="List of attributes about the landmark") - visitor_rating: float = Field( - description="Average visitor rating from 0 to 5") - - @field_validator('visitor_rating') - @classmethod - def check_rating(cls, v: float) -> float: - """Validate rating is between 0 and 5.""" - if not 0 <= v <= 5: - raise ValueError("Rating must be between 0 and 5") - return v - - -class PersonaContent(BaseModel): - """Generated content for a persona including tweets and bio.""" - tweets: List[str] = Field(description="List of 5 tweets for the persona") - bio: str = Field(description="Biography for the persona") - - @field_validator('tweets') - @classmethod - def check_tweets_count(cls, v: List[str]) -> List[str]: - """Validate that exactly 5 tweets are provided.""" - if len(v) != 5: - raise ValueError("Must provide exactly 5 tweets") - return v - - -class QAItem(BaseModel): - """Question and answer pair.""" - question: str = Field(description="The question") - answer: str = Field(description="The correct answer") - - -class QASet(BaseModel): - """Set of questions and answers.""" - questions: List[QAItem] = Field(description="List of question-answer pairs") - - @field_validator('questions') - @classmethod - def check_qa_count(cls, v: List[QAItem]) -> List[QAItem]: - """Validate that exactly 5 Q&A pairs are provided.""" - if len(v) != 5: - raise ValueError("Must provide exactly 5 question-answer 
pairs") - return v - - -class MCQQuestion(BaseModel): - """Multiple choice question with one correct and three incorrect answers.""" - question: str = Field(description="The question") - correct_answer: str = Field(description="The correct answer") - incorrect_answers: List[str] = Field(description="List of 3 incorrect answers") - - @field_validator('incorrect_answers') - @classmethod - def check_incorrect_count(cls, v: List[str]) -> List[str]: - """Validate that exactly 3 incorrect answers are provided.""" - if len(v) != 3: - raise ValueError("Must provide exactly 3 incorrect answers") - return v - - -class MCQSet(BaseModel): - """Set of multiple choice questions.""" - questions: List[MCQQuestion] = Field(description="List of MCQ questions") - - @field_validator('questions') - @classmethod - def check_questions_count(cls, v: List[MCQQuestion]) -> List[MCQQuestion]: - """Validate that exactly 3 questions are provided.""" - if len(v) != 3: - raise ValueError("Must provide exactly 3 questions") - return v - - @pytest.mark.integration class TestAnthropicSonnet45: """Anthropic tests for claude-sonnet-4-5-20250929.""" diff --git a/tests/test_gemini.py b/tests/test_gemini.py new file mode 100644 index 0000000..52e2f0b --- /dev/null +++ b/tests/test_gemini.py @@ -0,0 +1,351 @@ +from datafast.llms import GeminiProvider +from dotenv import load_dotenv +import pytest +from tests.test_schemas import ( + SimpleResponse, + LandmarkInfo, + PersonaContent, + QASet, + MCQSet, +) +import time + +load_dotenv() + +@pytest.mark.integration +def test_gemini_provider(): + """Test the Gemini provider with text response.""" + provider = GeminiProvider() + response = provider.generate( + prompt="What is the capital of France? Answer in one word.") + assert "Paris" in response + + +@pytest.mark.slow +@pytest.mark.integration +def test_gemini_rpm_limit_real(): + """Test GeminiProvider RPM limit (15 requests/minute) is enforced with real waiting.""" + import time + prompts_count = 17 + rpm = 15 + provider = GeminiProvider( + model_id="gemini-2.5-flash-lite-preview-06-17", rpm_limit=rpm) + prompt = [f"Test request {i}" for i in range(prompts_count)] + start = time.monotonic() + for prompt in prompt: + provider.generate(prompt=prompt) + elapsed = time.monotonic() - start + # 17 requests, rpm=15, donc on doit attendre au moins ~60s pour les 2 requêtes au-delà de la limite + assert elapsed >= 59, f"Elapsed time too short for RPM limit: {elapsed:.2f}s for {prompts_count} requests with rpm={rpm}" + + +@pytest.mark.integration +def test_gemini_structured_output(): + """Test the Gemini provider with structured output.""" + provider = GeminiProvider() + prompt = """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital. + Format your response as JSON with 'answer' and 'reasoning' fields.""" + + response = provider.generate( + prompt=prompt, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + +@pytest.mark.integration +def test_gemini_with_messages(): + """Test Gemini provider with messages input instead of prompt.""" + provider = GeminiProvider() + messages = [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is the capital of France? 
Answer in one word."} + ] + + response = provider.generate(messages=messages) + assert "Paris" in response + +@pytest.mark.integration +def test_gemini_messages_with_structured_output(): + """Test the Gemini provider with messages input and structured output.""" + provider = GeminiProvider() + messages = [ + {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, + {"role": "user", "content": """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital. + Format your response as JSON with 'answer' and 'reasoning' fields."""} + ] + + response = provider.generate( + messages=messages, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + +@pytest.mark.integration +def test_gemini_with_all_parameters(): + """Test Gemini provider with all optional parameters specified.""" + provider = GeminiProvider( + model_id="gemini-2.0-flash", + temperature=0.4, + max_completion_tokens=150, + top_p=0.85, + frequency_penalty=0.15 + ) + + prompt = "What is the capital of France? Answer in one word." + response = provider.generate(prompt=prompt) + + assert "Paris" in response + + +@pytest.mark.integration +def test_gemini_structured_landmark_info(): + """Test Gemini with a structured landmark info response.""" + provider = GeminiProvider(temperature=0.1, max_completion_tokens=800) + + prompt = """ + Provide detailed information about the Great Wall of China. + + Return your response as a structured JSON object with the following elements: + - name: The name of the landmark (Great Wall of China) + - location: Where it's located (Northern China) + - description: A brief description of the landmark (2-3 sentences) + - year_built: The year when construction began (as a number) + - attributes: A list of at least 3 attribute objects, each containing: + - name: The name of the attribute (e.g., "length", "material", "dynasties") + - value: The value of the attribute (e.g., "13,171 miles", "stone, brick, wood, etc.", "multiple including Qin, Han, Ming") + - importance: An importance score between 0 and 1 + - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.7) + + Make sure your response is properly structured and can be parsed as valid JSON. + """ + + response = provider.generate(prompt=prompt, response_format=LandmarkInfo) + + # Verify the structure was correctly generated and parsed + assert isinstance(response, LandmarkInfo) + assert "Great Wall" in response.name + assert "China" in response.location + assert len(response.description) > 20 + assert response.year_built is not None + assert len(response.attributes) >= 3 + + # Verify nested objects + for attr in response.attributes: + assert 0 <= attr.importance <= 1 + assert len(attr.name) > 0 + assert len(attr.value) > 0 + + # Verify rating field + assert 0 <= response.visitor_rating <= 5 + + +@pytest.mark.integration +def test_gemini_batch_prompts(): + """Test the Gemini provider with batch prompts.""" + provider = GeminiProvider() + prompt = [ + "What is 2+2? Answer with just the number.", + "What is 3+3? Answer with just the number.", + "What is 4+4? Answer with just the number." 
+ ] + + responses = provider.generate(prompt=prompt) + + assert len(responses) == 3 + assert isinstance(responses, list) + assert all(isinstance(r, str) for r in responses) + assert "4" in responses[0] + assert "6" in responses[1] + assert "8" in responses[2] + + +@pytest.mark.integration +def test_gemini_batch_messages(): + """Test Gemini provider with batch messages.""" + provider = GeminiProvider() + messages = [ + [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is 5+5? Just the number."} + ], + [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is 7+3? Just the number."} + ] + ] + + responses = provider.generate(messages=messages) + + assert len(responses) == 2 + assert isinstance(responses, list) + assert all(isinstance(r, str) for r in responses) + assert "10" in responses[0] + assert "10" in responses[1] + + +@pytest.mark.integration +def test_gemini_batch_structured_output(): + """Test Gemini provider with batch structured output.""" + provider = GeminiProvider() + prompt = [ + """What is 8*3? Provide the answer and show your work. + Format as JSON with 'answer' and 'reasoning' fields.""", + """What is 9*4? Provide the answer and show your work. + Format as JSON with 'answer' and 'reasoning' fields.""" + ] + + responses = provider.generate( + prompt=prompt, + response_format=SimpleResponse + ) + + assert len(responses) == 2 + assert all(isinstance(r, SimpleResponse) for r in responses) + assert "24" in responses[0].answer + assert "36" in responses[1].answer + assert len(responses[0].reasoning) > 5 + assert len(responses[1].reasoning) > 5 + + +@pytest.mark.integration +def test_gemini_batch_messages_with_structured_output(): + """Test Gemini provider with batch messages and structured output.""" + provider = GeminiProvider() + messages = [ + [ + {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, + {"role": "user", "content": """What is 12/3? Provide the answer and show your work. + Format as JSON with 'answer' and 'reasoning' fields."""} + ], + [ + {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, + {"role": "user", "content": """What is 15/5? Provide the answer and show your work. + Format as JSON with 'answer' and 'reasoning' fields."""} + ] + ] + + responses = provider.generate( + messages=messages, + response_format=SimpleResponse + ) + + assert len(responses) == 2 + assert all(isinstance(r, SimpleResponse) for r in responses) + assert "4" in responses[0].answer + assert "3" in responses[1].answer + assert len(responses[0].reasoning) > 5 + assert len(responses[1].reasoning) > 5 + + +@pytest.mark.integration +def test_gemini_batch_with_all_parameters(): + """Test Gemini provider with batch processing and all optional parameters.""" + provider = GeminiProvider( + model_id="gemini-2.0-flash", + temperature=0.1, + max_completion_tokens=50, + top_p=0.9, + frequency_penalty=0.1 + ) + + prompt = [ + "What is the capital of Belgium? Answer in one word.", + "What is the capital of Netherlands? Answer in one word." 
+ ] + + responses = provider.generate(prompt=prompt) + + assert len(responses) == 2 + assert "Brussels" in responses[0] + assert "Amsterdam" in responses[1] + + +@pytest.mark.integration +def test_gemini_persona_content_generation(): + """Test generating tweets and bio for a persona using Gemini.""" + provider = GeminiProvider( + temperature=0.7, + max_completion_tokens=1000 + ) + + prompt = """ + Generate social media content for the following persona: + + Persona: A passionate environmental scientist who loves hiking and photography, + advocates for climate action, and enjoys sharing nature facts with humor. + + Create exactly 5 tweets and 1 bio for this persona. + """ + time.sleep(60) + response = provider.generate(prompt=prompt, response_format=PersonaContent) + + assert isinstance(response, PersonaContent) + assert len(response.tweets) == 5 + assert all(len(tweet) > 0 for tweet in response.tweets) + assert len(response.bio) > 20 + + +@pytest.mark.integration +def test_gemini_qa_generation(): + """Test generating Q&A pairs on machine learning using Gemini.""" + provider = GeminiProvider( + temperature=0.5, + max_completion_tokens=1500 + ) + + prompt = """ + Generate exactly 5 questions and their correct answers about machine learning topics. + + Topics to cover: supervised learning, neural networks, overfitting, gradient descent, and cross-validation. + + Each question should be clear and the answer should be concise but complete. + """ + + response = provider.generate(prompt=prompt, response_format=QASet) + + assert isinstance(response, QASet) + assert len(response.questions) == 5 + for qa in response.questions: + assert len(qa.question) > 10 + assert len(qa.answer) > 10 + + +@pytest.mark.integration +def test_gemini_mcq_generation(): + """Test generating multiple choice questions using Gemini.""" + provider = GeminiProvider( + temperature=0.5, + max_completion_tokens=1500 + ) + + prompt = """ + Generate exactly 3 multiple choice questions about machine learning. + + For each question, provide: + - The question itself + - One correct answer + - Three plausible but incorrect answers + + Topics: neural networks, decision trees, and ensemble methods. 
+ """ + + response = provider.generate(prompt=prompt, response_format=MCQSet) + + assert isinstance(response, MCQSet) + assert len(response.questions) == 3 + for mcq in response.questions: + assert len(mcq.question) > 10 + assert len(mcq.correct_answer) > 0 + assert len(mcq.incorrect_answers) == 3 + assert all(len(ans) > 0 for ans in mcq.incorrect_answers) + diff --git a/tests/test_llms.py b/tests/test_llms.py deleted file mode 100644 index 5271296..0000000 --- a/tests/test_llms.py +++ /dev/null @@ -1,645 +0,0 @@ -from datafast.llms import OpenAIProvider, GeminiProvider -from dotenv import load_dotenv -import pytest -from typing import List, Optional -from pydantic import BaseModel, Field, field_validator - -load_dotenv() - - -class SimpleResponse(BaseModel): - """Simple response model for testing structured output.""" - answer: str = Field(description="The answer to the question") - reasoning: str = Field(description="The reasoning behind the answer") - - -class Attribute(BaseModel): - """Attribute of a landmark with value and importance.""" - name: str = Field(description="Name of the attribute") - value: str = Field(description="Value of the attribute") - importance: float = Field(description="Importance score between 0 and 1") - - @field_validator('importance') - @classmethod - def check_importance(cls, v: float) -> float: - """Validate importance is between 0 and 1.""" - if not 0 <= v <= 1: - raise ValueError("Importance must be between 0 and 1") - return v - - -class LandmarkInfo(BaseModel): - """Information about a landmark with attributes.""" - name: str = Field(description="The name of the landmark") - location: str = Field(description="Where the landmark is located") - description: str = Field(description="A brief description of the landmark") - year_built: Optional[int] = Field( - None, description="Year when the landmark was built") - attributes: List[Attribute] = Field( - description="List of attributes about the landmark") - visitor_rating: float = Field( - description="Average visitor rating from 0 to 5") - - @field_validator('visitor_rating') - @classmethod - def check_rating(cls, v: float) -> float: - """Validate rating is between 0 and 5.""" - if not 0 <= v <= 5: - raise ValueError("Rating must be between 0 and 5") - return v - - -@pytest.mark.integration -def test_openai_provider(): - """Test the OpenAI provider with text response.""" - provider = OpenAIProvider() - response = provider.generate( - prompt="What is the capital of France? Answer in one word.") - assert "Paris" in response - - -@pytest.mark.integration -def test_gemini_provider(): - """Test the Gemini provider with text response.""" - provider = GeminiProvider() - response = provider.generate( - prompt="What is the capital of France? 
Answer in one word.") - assert "Paris" in response - - -@pytest.mark.slow -@pytest.mark.integration -def test_gemini_rpm_limit_real(): - """Test GeminiProvider RPM limit (15 requests/minute) is enforced with real waiting.""" - import time - prompts_count = 17 - rpm = 15 - provider = GeminiProvider( - model_id="gemini-2.5-flash-lite-preview-06-17", rpm_limit=rpm) - prompt = [f"Test request {i}" for i in range(prompts_count)] - start = time.monotonic() - for prompt in prompt: - provider.generate(prompt=prompt) - elapsed = time.monotonic() - start - # 17 requests, rpm=15, donc on doit attendre au moins ~60s pour les 2 requêtes au-delà de la limite - assert elapsed >= 59, f"Elapsed time too short for RPM limit: {elapsed:.2f}s for {prompts_count} requests with rpm={rpm}" - - -@pytest.mark.integration -def test_openai_structured_output(): - """Test the OpenAI provider with structured output.""" - provider = OpenAIProvider() - prompt = """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. - Format your response as JSON with 'answer' and 'reasoning' fields.""" - - response = provider.generate( - prompt=prompt, - response_format=SimpleResponse - ) - - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - # Make sure we have some reasoning text - assert len(response.reasoning) > 10 - - - - - -@pytest.mark.integration -def test_gemini_structured_output(): - """Test the Gemini provider with structured output.""" - provider = GeminiProvider() - prompt = """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. - Format your response as JSON with 'answer' and 'reasoning' fields.""" - - response = provider.generate( - prompt=prompt, - response_format=SimpleResponse - ) - - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - assert len(response.reasoning) > 10 - - - -@pytest.mark.integration -def test_openai_with_messages(): - """Test OpenAI provider with messages input instead of prompt.""" - provider = OpenAIProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is the capital of France? Answer in one word."} - ] - - response = provider.generate(messages=messages) - assert "Paris" in response - - -@pytest.mark.integration -def test_gemini_with_messages(): - """Test Gemini provider with messages input instead of prompt.""" - provider = GeminiProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is the capital of France? Answer in one word."} - ] - - response = provider.generate(messages=messages) - assert "Paris" in response - - - -@pytest.mark.integration -def test_openai_messages_with_structured_output(): - """Test OpenAI provider with messages input and structured output.""" - provider = OpenAIProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. 
- Format your response as JSON with 'answer' and 'reasoning' fields."""} - ] - - response = provider.generate( - messages=messages, - response_format=SimpleResponse - ) - - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - assert len(response.reasoning) > 10 - - - -@pytest.mark.integration -def test_openai_with_all_parameters(): - """Test OpenAI provider with all optional parameters specified.""" - provider = OpenAIProvider( - model_id="gpt-5-mini-2025-08-07", - temperature=0.2, - max_completion_tokens=100, - top_p=0.9, - frequency_penalty=0.1 - ) - - prompt = "What is the capital of France? Answer in one word." - response = provider.generate(prompt=prompt) - - assert "Paris" in response - - -@pytest.mark.integration -def test_gemini_messages_with_structured_output(): - """Test the Gemini provider with messages input and structured output.""" - provider = GeminiProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. - Format your response as JSON with 'answer' and 'reasoning' fields."""} - ] - - response = provider.generate( - messages=messages, - response_format=SimpleResponse - ) - - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - assert len(response.reasoning) > 10 - - -@pytest.mark.integration -def test_gemini_with_all_parameters(): - """Test Gemini provider with all optional parameters specified.""" - provider = GeminiProvider( - model_id="gemini-2.0-flash", - temperature=0.4, - max_completion_tokens=150, - top_p=0.85, - frequency_penalty=0.15 - ) - - prompt = "What is the capital of France? Answer in one word." - response = provider.generate(prompt=prompt) - - assert "Paris" in response - - - -@pytest.mark.integration -def test_openai_structured_landmark_info(): - """Test OpenAI with a structured landmark info response.""" - provider = OpenAIProvider(temperature=0.1, max_completion_tokens=800) - - prompt = """ - Provide detailed information about the Eiffel Tower in Paris. - - Return your response as a structured JSON object with the following elements: - - name: The name of the landmark (Eiffel Tower) - - location: Where it's located (Paris, France) - - description: A brief description of the landmark (2-3 sentences) - - year_built: The year when it was built (as a number) - - attributes: A list of at least 3 attribute objects, each containing: - - name: The name of the attribute (e.g., "height", "material", "architect") - - value: The value of the attribute (e.g., "330 meters", "wrought iron", "Gustave Eiffel") - - importance: An importance score between 0 and 1 - - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.5) - - Make sure your response is properly structured and can be parsed as valid JSON. 
- """ - - response = provider.generate(prompt=prompt, response_format=LandmarkInfo) - - # Verify the structure was correctly generated and parsed - assert isinstance(response, LandmarkInfo) - assert "Eiffel Tower" in response.name - assert "Paris" in response.location - assert len(response.description) > 20 - assert response.year_built is not None and response.year_built > 1800 - assert len(response.attributes) >= 3 - - # Verify nested objects - for attr in response.attributes: - assert 0 <= attr.importance <= 1 - assert len(attr.name) > 0 - assert len(attr.value) > 0 - - # Verify rating field - assert 0 <= response.visitor_rating <= 5 - - - - - -@pytest.mark.integration -def test_gemini_structured_landmark_info(): - """Test Gemini with a structured landmark info response.""" - provider = GeminiProvider(temperature=0.1, max_completion_tokens=800) - - prompt = """ - Provide detailed information about the Great Wall of China. - - Return your response as a structured JSON object with the following elements: - - name: The name of the landmark (Great Wall of China) - - location: Where it's located (Northern China) - - description: A brief description of the landmark (2-3 sentences) - - year_built: The year when construction began (as a number) - - attributes: A list of at least 3 attribute objects, each containing: - - name: The name of the attribute (e.g., "length", "material", "dynasties") - - value: The value of the attribute (e.g., "13,171 miles", "stone, brick, wood, etc.", "multiple including Qin, Han, Ming") - - importance: An importance score between 0 and 1 - - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.7) - - Make sure your response is properly structured and can be parsed as valid JSON. - """ - - response = provider.generate(prompt=prompt, response_format=LandmarkInfo) - - # Verify the structure was correctly generated and parsed - assert isinstance(response, LandmarkInfo) - assert "Great Wall" in response.name - assert "China" in response.location - assert len(response.description) > 20 - assert response.year_built is not None - assert len(response.attributes) >= 3 - - # Verify nested objects - for attr in response.attributes: - assert 0 <= attr.importance <= 1 - assert len(attr.name) > 0 - assert len(attr.value) > 0 - - # Verify rating field - assert 0 <= response.visitor_rating <= 5 - - -"******* Batch Inference Tests *******" -"Similar to previous tests but for batch inputs" - -# Batch tests to add to your existing test file - - -@pytest.mark.integration -def test_openai_batch_prompts(): - """Test the OpenAI provider with batch prompts.""" - provider = OpenAIProvider() - prompt = [ - "What is the capital of France? Answer in one word.", - "What is the capital of Germany? Answer in one word.", - "What is the capital of Italy? Answer in one word." - ] - - responses = provider.generate(prompt=prompt) - - assert len(responses) == 3 - assert isinstance(responses, list) - assert all(isinstance(r, str) for r in responses) - assert "Paris" in responses[0] - assert "Berlin" in responses[1] - assert "Rome" in responses[2] - - -@pytest.mark.integration -def test_gemini_batch_prompts(): - """Test the Gemini provider with batch prompts.""" - provider = GeminiProvider() - prompt = [ - "What is 2+2? Answer with just the number.", - "What is 3+3? Answer with just the number.", - "What is 4+4? Answer with just the number." 
- ] - - responses = provider.generate(prompt=prompt) - - assert len(responses) == 3 - assert isinstance(responses, list) - assert all(isinstance(r, str) for r in responses) - assert "4" in responses[0] - assert "6" in responses[1] - assert "8" in responses[2] - - -@pytest.mark.integration -def test_openai_batch_messages(): - """Test OpenAI provider with batch messages.""" - provider = OpenAIProvider() - messages = [ - [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is the capital of France? One word."} - ], - [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is the capital of Japan? One word."} - ] - ] - - responses = provider.generate(messages=messages) - - assert len(responses) == 2 - assert isinstance(responses, list) - assert all(isinstance(r, str) for r in responses) - assert "Paris" in responses[0] - assert "Tokyo" in responses[1] - - -@pytest.mark.integration -def test_gemini_batch_messages(): - """Test Gemini provider with batch messages.""" - provider = GeminiProvider() - messages = [ - [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is 5+5? Just the number."} - ], - [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is 7+3? Just the number."} - ] - ] - - responses = provider.generate(messages=messages) - - assert len(responses) == 2 - assert isinstance(responses, list) - assert all(isinstance(r, str) for r in responses) - assert "10" in responses[0] - assert "10" in responses[1] - - -@pytest.mark.integration -def test_openai_batch_structured_output(): - """Test OpenAI provider with batch structured output.""" - provider = OpenAIProvider() - prompt = [ - """What is the capital of France? - Provide a short answer and brief reasoning. - Format as JSON with 'answer' and 'reasoning' fields.""", - """What is the capital of Japan? - Provide a short answer and brief reasoning. - Format as JSON with 'answer' and 'reasoning' fields.""" - ] - - responses = provider.generate( - prompt=prompt, - response_format=SimpleResponse - ) - - assert len(responses) == 2 - assert all(isinstance(r, SimpleResponse) for r in responses) - assert "Paris" in responses[0].answer - assert "Tokyo" in responses[1].answer - assert len(responses[0].reasoning) > 5 - assert len(responses[1].reasoning) > 5 - - -@pytest.mark.integration -def test_gemini_batch_structured_output(): - """Test Gemini provider with batch structured output.""" - provider = GeminiProvider() - prompt = [ - """What is 8*3? Provide the answer and show your work. - Format as JSON with 'answer' and 'reasoning' fields.""", - """What is 9*4? Provide the answer and show your work. 
- Format as JSON with 'answer' and 'reasoning' fields.""" - ] - - responses = provider.generate( - prompt=prompt, - response_format=SimpleResponse - ) - - assert len(responses) == 2 - assert all(isinstance(r, SimpleResponse) for r in responses) - assert "24" in responses[0].answer - assert "36" in responses[1].answer - assert len(responses[0].reasoning) > 5 - assert len(responses[1].reasoning) > 5 - - -@pytest.mark.integration -def test_openai_batch_messages_with_structured_output(): - """Test OpenAI provider with batch messages and structured output.""" - provider = OpenAIProvider() - messages = [ - [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is the capital of Brazil? - Provide a short answer and brief reasoning. - Format as JSON with 'answer' and 'reasoning' fields."""} - ], - [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is the capital of Argentina? - Provide a short answer and brief reasoning. - Format as JSON with 'answer' and 'reasoning' fields."""} - ] - ] - - responses = provider.generate( - messages=messages, - response_format=SimpleResponse - ) - - assert len(responses) == 2 - assert all(isinstance(r, SimpleResponse) for r in responses) - assert "Brasília" in responses[0].answer or "Brasilia" in responses[0].answer - assert "Buenos Aires" in responses[1].answer - assert len(responses[0].reasoning) > 5 - assert len(responses[1].reasoning) > 5 - - -@pytest.mark.integration -def test_gemini_batch_messages_with_structured_output(): - """Test Gemini provider with batch messages and structured output.""" - provider = GeminiProvider() - messages = [ - [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is 12/3? Provide the answer and show your work. - Format as JSON with 'answer' and 'reasoning' fields."""} - ], - [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is 15/5? Provide the answer and show your work. - Format as JSON with 'answer' and 'reasoning' fields."""} - ] - ] - - responses = provider.generate( - messages=messages, - response_format=SimpleResponse - ) - - assert len(responses) == 2 - assert all(isinstance(r, SimpleResponse) for r in responses) - assert "4" in responses[0].answer - assert "3" in responses[1].answer - assert len(responses[0].reasoning) > 5 - assert len(responses[1].reasoning) > 5 - - -@pytest.mark.integration -def test_openai_batch_with_all_parameters(): - """Test OpenAI provider with batch processing and all optional parameters.""" - provider = OpenAIProvider( - model_id="gpt-5-mini-2025-08-07", - temperature=0.1, - max_completion_tokens=50, - top_p=0.9, - frequency_penalty=0.1 - ) - - prompt = [ - "What is the capital of Sweden? Answer in one word.", - "What is the capital of Norway? Answer in one word." 
- ] - - responses = provider.generate(prompt=prompt) - - assert len(responses) == 2 - assert "Stockholm" in responses[0] - assert "Oslo" in responses[1] - - - - - -@pytest.mark.integration -def test_gemini_batch_with_all_parameters(): - """Test Gemini provider with batch processing and all optional parameters.""" - provider = GeminiProvider( - model_id="gemini-2.0-flash", - temperature=0.1, - max_completion_tokens=50, - top_p=0.9, - frequency_penalty=0.1 - ) - - prompt = [ - "What is the capital of Belgium? Answer in one word.", - "What is the capital of Netherlands? Answer in one word." - ] - - responses = provider.generate(prompt=prompt) - - assert len(responses) == 2 - assert "Brussels" in responses[0] - assert "Amsterdam" in responses[1] - - -@pytest.mark.integration -def test_openai_batch_landmark_info(): - """Test OpenAI with batch structured landmark info responses.""" - provider = OpenAIProvider(temperature=0.1, max_completion_tokens=800) - - prompt = [ - """ - Provide detailed information about the Statue of Liberty. - - Return your response as a structured JSON object with the following elements: - - name: The name of the landmark (Statue of Liberty) - - location: Where it's located (New York, USA) - - description: A brief description of the landmark (2-3 sentences) - - year_built: The year when it was completed (as a number) - - attributes: A list of at least 3 attribute objects, each containing: - - name: The name of the attribute (e.g., "height", "material", "sculptor") - - value: The value of the attribute (e.g., "93 meters", "copper", "Frédéric Auguste Bartholdi") - - importance: An importance score between 0 and 1 - - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.6) - - Make sure your response is properly structured and can be parsed as valid JSON. - """, - """ - Provide detailed information about Big Ben in London. - - Return your response as a structured JSON object with the following elements: - - name: The name of the landmark (Big Ben) - - location: Where it's located (London, UK) - - description: A brief description of the landmark (2-3 sentences) - - year_built: The year when it was completed (as a number) - - attributes: A list of at least 3 attribute objects, each containing: - - name: The name of the attribute (e.g., "height", "clock", "architect") - - value: The value of the attribute (e.g., "96 meters", "Great Clock", "Augustus Pugin") - - importance: An importance score between 0 and 1 - - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.4) - - Make sure your response is properly structured and can be parsed as valid JSON. 
- """ - ] - - responses = provider.generate( - prompt=prompt, response_format=LandmarkInfo) - - # Verify we got 2 responses - assert len(responses) == 2 - assert all(isinstance(r, LandmarkInfo) for r in responses) - - # Verify first response (Statue of Liberty) - assert "Statue of Liberty" in responses[0].name - assert "New York" in responses[0].location - assert len(responses[0].description) > 20 - assert responses[0].year_built is not None and responses[0].year_built > 1800 - assert len(responses[0].attributes) >= 3 - - # Verify second response (Big Ben) - assert "Big Ben" in responses[1].name - assert "London" in responses[1].location - assert len(responses[1].description) > 20 - assert responses[1].year_built is not None and responses[1].year_built > 1800 - assert len(responses[1].attributes) >= 3 - - # Verify nested objects for both responses - for response in responses: - for attr in response.attributes: - assert 0 <= attr.importance <= 1 - assert len(attr.name) > 0 - assert len(attr.value) > 0 - assert 0 <= response.visitor_rating <= 5 - - diff --git a/tests/test_ollama.py b/tests/test_ollama.py index cade168..ecf26af 100644 --- a/tests/test_ollama.py +++ b/tests/test_ollama.py @@ -1,115 +1,17 @@ from datafast.llms import OllamaProvider from dotenv import load_dotenv import pytest -from typing import List, Optional -from pydantic import BaseModel, Field, field_validator +from tests.test_schemas import ( + SimpleResponse, + LandmarkInfo, + PersonaContent, + QASet, + MCQSet, +) load_dotenv() -class SimpleResponse(BaseModel): - """Simple response model for testing structured output.""" - answer: str = Field(description="The answer to the question") - reasoning: str = Field(description="The reasoning behind the answer") - - -class Attribute(BaseModel): - """Attribute of a landmark with value and importance.""" - name: str = Field(description="Name of the attribute") - value: str = Field(description="Value of the attribute") - importance: float = Field(description="Importance score between 0 and 1") - - @field_validator('importance') - @classmethod - def check_importance(cls, v: float) -> float: - """Validate importance is between 0 and 1.""" - if not 0 <= v <= 1: - raise ValueError("Importance must be between 0 and 1") - return v - - -class LandmarkInfo(BaseModel): - """Information about a landmark with attributes.""" - name: str = Field(description="The name of the landmark") - location: str = Field(description="Where the landmark is located") - description: str = Field(description="A brief description of the landmark") - year_built: Optional[int] = Field( - None, description="Year when the landmark was built") - attributes: List[Attribute] = Field( - description="List of attributes about the landmark") - visitor_rating: float = Field( - description="Average visitor rating from 0 to 5") - - @field_validator('visitor_rating') - @classmethod - def check_rating(cls, v: float) -> float: - """Validate rating is between 0 and 5.""" - if not 0 <= v <= 5: - raise ValueError("Rating must be between 0 and 5") - return v - - -class PersonaContent(BaseModel): - """Generated content for a persona including tweets and bio.""" - tweets: List[str] = Field(description="List of 5 tweets for the persona") - bio: str = Field(description="Biography for the persona") - - @field_validator('tweets') - @classmethod - def check_tweets_count(cls, v: List[str]) -> List[str]: - """Validate that exactly 5 tweets are provided.""" - if len(v) != 5: - raise ValueError("Must provide exactly 5 tweets") - return v - - 
-class QAItem(BaseModel): - """Question and answer pair.""" - question: str = Field(description="The question") - answer: str = Field(description="The correct answer") - - -class QASet(BaseModel): - """Set of questions and answers.""" - questions: List[QAItem] = Field(description="List of question-answer pairs") - - @field_validator('questions') - @classmethod - def check_qa_count(cls, v: List[QAItem]) -> List[QAItem]: - """Validate that exactly 5 Q&A pairs are provided.""" - if len(v) != 5: - raise ValueError("Must provide exactly 5 question-answer pairs") - return v - - -class MCQQuestion(BaseModel): - """Multiple choice question with one correct and three incorrect answers.""" - question: str = Field(description="The question") - correct_answer: str = Field(description="The correct answer") - incorrect_answers: List[str] = Field(description="List of 3 incorrect answers") - - @field_validator('incorrect_answers') - @classmethod - def check_incorrect_count(cls, v: List[str]) -> List[str]: - """Validate that exactly 3 incorrect answers are provided.""" - if len(v) != 3: - raise ValueError("Must provide exactly 3 incorrect answers") - return v - - -class MCQSet(BaseModel): - """Set of multiple choice questions.""" - questions: List[MCQQuestion] = Field(description="List of MCQ questions") - - @field_validator('questions') - @classmethod - def check_questions_count(cls, v: List[MCQQuestion]) -> List[MCQQuestion]: - """Validate that exactly 3 questions are provided.""" - if len(v) != 3: - raise ValueError("Must provide exactly 3 questions") - return v - - @pytest.mark.integration def test_ollama_provider(): """Test the Ollama provider with text response.""" diff --git a/tests/test_openai.py b/tests/test_openai.py new file mode 100644 index 0000000..54c157c --- /dev/null +++ b/tests/test_openai.py @@ -0,0 +1,363 @@ +from datafast.llms import OpenAIProvider +from dotenv import load_dotenv +import pytest +from tests.test_schemas import ( + SimpleResponse, + LandmarkInfo, + PersonaContent, + QASet, + MCQSet, +) + +load_dotenv() + + +@pytest.mark.integration +class TestOpenAIProvider: + """OpenAI provider tests using the default model gpt-5-mini-2025-08-07.""" + + def test_basic_text_response(self): + provider = OpenAIProvider() + response = provider.generate( + prompt="What is the capital of France? Answer in one word.") + assert "Paris" in response + + def test_structured_output(self): + provider = OpenAIProvider() + prompt = """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital. + Format your response as JSON with 'answer' and 'reasoning' fields.""" + + response = provider.generate( + prompt=prompt, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + def test_with_messages(self): + provider = OpenAIProvider() + messages = [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is the capital of France? Answer in one word."} + ] + + response = provider.generate(messages=messages) + assert "Paris" in response + + def test_messages_with_structured_output(self): + provider = OpenAIProvider() + messages = [ + {"role": "user", "content": """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital. 
+ Format your response as JSON with 'answer' and 'reasoning' fields."""} + ] + + response = provider.generate( + messages=messages, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + def test_with_all_parameters(self): + provider = OpenAIProvider( + model_id="gpt-5-mini-2025-08-07", + max_completion_tokens=1000, + reasoning_effort="low" + ) + + prompt = "What is the capital of France? Answer in one word." + response = provider.generate(prompt=prompt) + + assert "Paris" in response + + def test_structured_landmark_info(self): + provider = OpenAIProvider(max_completion_tokens=1000) + + prompt = """ + Provide detailed information about the Eiffel Tower in Paris. + + Return your response as a structured JSON object with the following elements: + - name: The name of the landmark (Eiffel Tower) + - location: Where it's located (Paris, France) + - description: A brief description of the landmark (2-3 sentences) + - year_built: The year when it was built (as a number) + - attributes: A list of at least 3 attribute objects, each containing: + - name: The name of the attribute (e.g., "height", "material", "architect") + - value: The value of the attribute (e.g., "330 meters", "wrought iron", "Gustave Eiffel") + - importance: An importance score between 0 and 1 + - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.5) + + Make sure your response is properly structured and can be parsed as valid JSON. + """ + + response = provider.generate(prompt=prompt, response_format=LandmarkInfo) + + assert isinstance(response, LandmarkInfo) + assert "Eiffel Tower" in response.name + assert "Paris" in response.location + assert len(response.description) > 20 + assert response.year_built is not None and response.year_built > 1800 + assert len(response.attributes) >= 3 + + for attr in response.attributes: + assert 0 <= attr.importance <= 1 + assert len(attr.name) > 0 + assert len(attr.value) > 0 + + assert 0 <= response.visitor_rating <= 5 + + def test_batch_prompts(self): + provider = OpenAIProvider() + prompt = [ + "What is the capital of France? Answer in one word.", + "What is the capital of Germany? Answer in one word.", + "What is the capital of Italy? Answer in one word." + ] + + responses = provider.generate(prompt=prompt) + + assert len(responses) == 3 + assert isinstance(responses, list) + assert all(isinstance(r, str) for r in responses) + assert "Paris" in responses[0] + assert "Berlin" in responses[1] + assert "Rome" in responses[2] + + def test_batch_messages(self): + provider = OpenAIProvider() + messages = [ + [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is the capital of France? One word."} + ], + [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is the capital of Japan? One word."} + ] + ] + + responses = provider.generate(messages=messages) + + assert len(responses) == 2 + assert isinstance(responses, list) + assert all(isinstance(r, str) for r in responses) + assert "Paris" in responses[0] + assert "Tokyo" in responses[1] + + def test_batch_structured_output(self): + provider = OpenAIProvider() + prompt = [ + """What is the capital of France? + Provide a short answer and brief reasoning. + Format as JSON with 'answer' and 'reasoning' fields.""", + """What is the capital of Japan? 
+ Provide a short answer and brief reasoning. + Format as JSON with 'answer' and 'reasoning' fields.""" + ] + + responses = provider.generate( + prompt=prompt, + response_format=SimpleResponse + ) + + assert len(responses) == 2 + assert all(isinstance(r, SimpleResponse) for r in responses) + assert "Paris" in responses[0].answer + assert "Tokyo" in responses[1].answer + assert len(responses[0].reasoning) > 5 + assert len(responses[1].reasoning) > 5 + + def test_batch_messages_with_structured_output(self): + provider = OpenAIProvider() + messages = [ + [ + {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, + {"role": "user", "content": """What is the capital of Brazil? + Provide a short answer and brief reasoning. + Format as JSON with 'answer' and 'reasoning' fields."""} + ], + [ + {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, + {"role": "user", "content": """What is the capital of Argentina? + Provide a short answer and brief reasoning. + Format as JSON with 'answer' and 'reasoning' fields."""} + ] + ] + + responses = provider.generate( + messages=messages, + response_format=SimpleResponse + ) + + assert len(responses) == 2 + assert all(isinstance(r, SimpleResponse) for r in responses) + assert "Brasília" in responses[0].answer or "Brasilia" in responses[0].answer + assert "Buenos Aires" in responses[1].answer + assert len(responses[0].reasoning) > 5 + assert len(responses[1].reasoning) > 5 + + def test_batch_with_all_parameters(self): + provider = OpenAIProvider( + model_id="gpt-5-mini-2025-08-07", + max_completion_tokens=1000, + reasoning_effort="low" + ) + + prompt = [ + "What is the capital of Sweden? Answer in one word.", + "What is the capital of Norway? Answer in one word." + ] + + responses = provider.generate(prompt=prompt) + + assert len(responses) == 2 + assert "Stockholm" in responses[0] + assert "Oslo" in responses[1] + + def test_batch_landmark_info(self): + provider = OpenAIProvider(max_completion_tokens=1000) + + prompt = [ + """ + Provide detailed information about the Statue of Liberty. + + Return your response as a structured JSON object with the following elements: + - name: The name of the landmark (Statue of Liberty) + - location: Where it's located (New York, USA) + - description: A brief description of the landmark (2-3 sentences) + - year_built: The year when it was completed (as a number) + - attributes: A list of at least 3 attribute objects, each containing: + - name: The name of the attribute (e.g., "height", "material", "sculptor") + - value: The value of the attribute (e.g., "93 meters", "copper", "Frédéric Auguste Bartholdi") + - importance: An importance score between 0 and 1 + - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.6) + + Make sure your response is properly structured and can be parsed as valid JSON. + """, + """ + Provide detailed information about Big Ben in London. 
+ + Return your response as a structured JSON object with the following elements: + - name: The name of the landmark (Big Ben) + - location: Where it's located (London, UK) + - description: A brief description of the landmark (2-3 sentences) + - year_built: The year when it was completed (as a number) + - attributes: A list of at least 3 attribute objects, each containing: + - name: The name of the attribute (e.g., "height", "clock", "architect") + - value: The value of the attribute (e.g., "96 meters", "Great Clock", "Augustus Pugin") + - importance: An importance score between 0 and 1 + - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.4) + + Make sure your response is properly structured and can be parsed as valid JSON. + """ + ] + + responses = provider.generate( + prompt=prompt, + response_format=LandmarkInfo + ) + + assert len(responses) == 2 + assert all(isinstance(r, LandmarkInfo) for r in responses) + + assert "Statue of Liberty" in responses[0].name + assert "New York" in responses[0].location + assert len(responses[0].description) > 20 + assert responses[0].year_built is not None and responses[0].year_built > 1800 + assert len(responses[0].attributes) >= 3 + + assert "Big Ben" in responses[1].name + assert "London" in responses[1].location + assert len(responses[1].description) > 20 + assert responses[1].year_built is not None and responses[1].year_built > 1800 + assert len(responses[1].attributes) >= 3 + + for response in responses: + for attr in response.attributes: + assert 0 <= attr.importance <= 1 + assert len(attr.name) > 0 + assert len(attr.value) > 0 + assert 0 <= response.visitor_rating <= 5 + + def test_batch_validation_errors(self): + provider = OpenAIProvider() + + with pytest.raises(ValueError, match="Either prompts or messages must be provided"): + provider.generate() + + with pytest.raises(ValueError, match="Provide either prompts or messages, not both"): + provider.generate( + prompt=["test"], + messages=[[{"role": "user", "content": "test"}]] + ) + + def test_persona_content_generation(self): + """Test generating tweets and bio for a persona using OpenAI.""" + provider = OpenAIProvider(max_completion_tokens=1000) + + prompt = """ + Generate social media content for the following persona: + + Persona: A passionate environmental scientist who loves hiking and photography, + advocates for climate action, and enjoys sharing nature facts with humor. + + Create exactly 5 tweets and 1 bio for this persona. + """ + + response = provider.generate(prompt=prompt, response_format=PersonaContent) + + assert isinstance(response, PersonaContent) + assert len(response.tweets) == 5 + assert all(len(tweet) > 0 for tweet in response.tweets) + assert len(response.bio) > 20 + + def test_qa_generation(self): + """Test generating Q&A pairs on machine learning using OpenAI.""" + provider = OpenAIProvider(max_completion_tokens=1500) + + prompt = """ + Generate exactly 5 questions and their correct answers about machine learning topics. + + Topics to cover: supervised learning, neural networks, overfitting, gradient descent, and cross-validation. + + Each question should be clear and the answer should be concise but complete. 
+ """ + + response = provider.generate(prompt=prompt, response_format=QASet) + + assert isinstance(response, QASet) + assert len(response.questions) == 5 + for qa in response.questions: + assert len(qa.question) > 10 + assert len(qa.answer) > 10 + + def test_mcq_generation(self): + """Test generating multiple choice questions using OpenAI.""" + provider = OpenAIProvider(max_completion_tokens=1500) + + prompt = """ + Generate exactly 3 multiple choice questions about machine learning. + + For each question, provide: + - The question itself + - One correct answer + - Three plausible but incorrect answers + + Topics: neural networks, decision trees, and ensemble methods. + """ + + response = provider.generate(prompt=prompt, response_format=MCQSet) + + assert isinstance(response, MCQSet) + assert len(response.questions) == 3 + for mcq in response.questions: + assert len(mcq.question) > 10 + assert len(mcq.correct_answer) > 0 + assert len(mcq.incorrect_answers) == 3 + assert all(len(ans) > 0 for ans in mcq.incorrect_answers) diff --git a/tests/test_openrouter.py b/tests/test_openrouter.py index a537b76..84127c8 100644 --- a/tests/test_openrouter.py +++ b/tests/test_openrouter.py @@ -1,109 +1,17 @@ from datafast.llms import OpenRouterProvider from dotenv import load_dotenv import pytest -from typing import List, Optional -from pydantic import BaseModel, Field, field_validator +from tests.test_schemas import ( + SimpleResponse, + LandmarkInfo, + PersonaContent, + QASet, + MCQSet, +) load_dotenv() -class SimpleResponse(BaseModel): - """Simple response model for testing structured output.""" - answer: str = Field(description="The answer to the question") - reasoning: str = Field(description="The reasoning behind the answer") - - -class Attribute(BaseModel): - """Attribute of a landmark with value and importance.""" - name: str = Field(description="Name of the attribute") - value: str = Field(description="Value of the attribute") - importance: float = Field(description="Importance score between 0 and 1") - - @field_validator('importance') - def check_importance(cls, v: float) -> float: - """Validate importance is between 0 and 1.""" - if not 0 <= v <= 1: - raise ValueError("Importance must be between 0 and 1") - return v - - -class LandmarkInfo(BaseModel): - """Information about a landmark with attributes.""" - name: str = Field(description="The name of the landmark") - location: str = Field(description="Where the landmark is located") - description: str = Field(description="A brief description of the landmark") - year_built: Optional[int] = Field( - None, description="Year when the landmark was built") - attributes: List[Attribute] = Field( - description="List of attributes about the landmark") - visitor_rating: float = Field( - description="Average visitor rating from 0 to 5") - - @field_validator('visitor_rating') - def check_rating(cls, v: float) -> float: - """Validate rating is between 0 and 5.""" - if not 0 <= v <= 5: - raise ValueError("Rating must be between 0 and 5") - return v - - -class PersonaContent(BaseModel): - """Generated content for a persona including tweets and bio.""" - tweets: List[str] = Field(description="List of 5 tweets for the persona") - bio: str = Field(description="Biography for the persona") - - @field_validator('tweets') - def check_tweets_count(cls, v: List[str]) -> List[str]: - """Validate that exactly 5 tweets are provided.""" - if len(v) != 5: - raise ValueError("Must provide exactly 5 tweets") - return v - - -class QAItem(BaseModel): - """Question and answer 
pair.""" - question: str = Field(description="The question") - answer: str = Field(description="The correct answer") - - -class QASet(BaseModel): - """Set of questions and answers.""" - questions: List[QAItem] = Field(description="List of question-answer pairs") - - @field_validator('questions') - def check_qa_count(cls, v: List[QAItem]) -> List[QAItem]: - """Validate that exactly 5 Q&A pairs are provided.""" - if len(v) != 5: - raise ValueError("Must provide exactly 5 question-answer pairs") - return v - - -class MCQQuestion(BaseModel): - """Multiple choice question with one correct and three incorrect answers.""" - question: str = Field(description="The question") - correct_answer: str = Field(description="The correct answer") - incorrect_answers: List[str] = Field(description="List of 3 incorrect answers") - - @field_validator('incorrect_answers') - def check_incorrect_count(cls, v: List[str]) -> List[str]: - """Validate that exactly 3 incorrect answers are provided.""" - if len(v) != 3: - raise ValueError("Must provide exactly 3 incorrect answers") - return v - - -class MCQSet(BaseModel): - """Set of multiple choice questions.""" - questions: List[MCQQuestion] = Field(description="List of MCQ questions") - - @field_validator('questions') - def check_questions_count(cls, v: List[MCQQuestion]) -> List[MCQQuestion]: - """Validate that exactly 3 questions are provided.""" - if len(v) != 3: - raise ValueError("Must provide exactly 3 questions") - return v - - @pytest.mark.integration class TestOpenRouterProvider: """Test suite for OpenRouter provider with various input types and configurations.""" diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..fc96ee5 --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,106 @@ +"""Shared Pydantic schemas for LLM provider tests.""" +from typing import List +from pydantic import BaseModel, Field, field_validator + + +class SimpleResponse(BaseModel): + """Simple response model for testing structured output.""" + answer: str = Field(description="The answer to the question") + reasoning: str = Field(description="The reasoning behind the answer") + + +class Attribute(BaseModel): + """Attribute of a landmark with value and importance.""" + name: str = Field(description="Name of the attribute") + value: str = Field(description="Value of the attribute") + importance: float = Field(description="Importance score between 0 and 1") + + @field_validator('importance') + @classmethod + def check_importance(cls, v: float) -> float: + """Validate importance is between 0 and 1.""" + if not 0 <= v <= 1: + raise ValueError("Importance must be between 0 and 1") + return v + + +class LandmarkInfo(BaseModel): + """Information about a landmark with attributes.""" + name: str = Field(description="The name of the landmark") + location: str = Field(description="Where the landmark is located") + description: str = Field(description="A brief description of the landmark") + year_built: int | None = Field( + None, description="Year when the landmark was built") + attributes: List[Attribute] = Field( + description="List of attributes about the landmark") + visitor_rating: float = Field( + description="Average visitor rating from 0 to 5") + + @field_validator('visitor_rating') + @classmethod + def check_rating(cls, v: float) -> float: + """Validate rating is between 0 and 5.""" + if not 0 <= v <= 5: + raise ValueError("Rating must be between 0 and 5") + return v + + +class PersonaContent(BaseModel): + """Generated content for a persona 
including tweets and bio.""" + tweets: List[str] = Field(description="List of 5 tweets for the persona") + bio: str = Field(description="Biography for the persona") + + @field_validator('tweets') + @classmethod + def check_tweets_count(cls, v: List[str]) -> List[str]: + """Validate that exactly 5 tweets are provided.""" + if len(v) != 5: + raise ValueError("Must provide exactly 5 tweets") + return v + + +class QAItem(BaseModel): + """Question and answer pair.""" + question: str = Field(description="The question") + answer: str = Field(description="The correct answer") + + +class QASet(BaseModel): + """Set of questions and answers.""" + questions: List[QAItem] = Field(description="List of question-answer pairs") + + @field_validator('questions') + @classmethod + def check_qa_count(cls, v: List[QAItem]) -> List[QAItem]: + """Validate that exactly 5 Q&A pairs are provided.""" + if len(v) != 5: + raise ValueError("Must provide exactly 5 question-answer pairs") + return v + + +class MCQQuestion(BaseModel): + """Multiple choice question with one correct and three incorrect answers.""" + question: str = Field(description="The question") + correct_answer: str = Field(description="The correct answer") + incorrect_answers: List[str] = Field(description="List of 3 incorrect answers") + + @field_validator('incorrect_answers') + @classmethod + def check_incorrect_count(cls, v: List[str]) -> List[str]: + """Validate that exactly 3 incorrect answers are provided.""" + if len(v) != 3: + raise ValueError("Must provide exactly 3 incorrect answers") + return v + + +class MCQSet(BaseModel): + """Set of multiple choice questions.""" + questions: List[MCQQuestion] = Field(description="List of MCQ questions") + + @field_validator('questions') + @classmethod + def check_questions_count(cls, v: List[MCQQuestion]) -> List[MCQQuestion]: + """Validate that exactly 3 questions are provided.""" + if len(v) != 3: + raise ValueError("Must provide exactly 3 questions") + return v
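
Usage note: with the schemas consolidated in tests/test_schemas.py, every provider test module follows the same pattern: import the shared models, mark the test as an integration test, and pass the Pydantic class via response_format. The sketch below is illustrative only; it is not part of the diff, the test name is hypothetical, and it assumes the relevant API key is supplied through the environment loaded by load_dotenv().

from datafast.llms import OpenAIProvider
from dotenv import load_dotenv
import pytest
from tests.test_schemas import SimpleResponse

load_dotenv()


@pytest.mark.integration
def test_openai_structured_sketch():
    # Hypothetical test name, shown only to illustrate the shared-schema pattern.
    """Minimal structured-output check reusing the shared SimpleResponse schema."""
    # Provider construction mirrors the tests above; reasoning_effort defaults to "low".
    provider = OpenAIProvider(max_completion_tokens=1000, reasoning_effort="low")
    response = provider.generate(
        prompt=(
            "What is the capital of France? "
            "Provide a short answer and brief reasoning. "
            "Format your response as JSON with 'answer' and 'reasoning' fields."
        ),
        response_format=SimpleResponse,
    )
    # The provider returns an instance of the requested schema, as in the tests above.
    assert isinstance(response, SimpleResponse)
    assert "Paris" in response.answer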