diff --git a/datafast/llms.py b/datafast/llms.py
index 8949686..dc4e46b 100644
--- a/datafast/llms.py
+++ b/datafast/llms.py
@@ -16,7 +16,6 @@
 # LiteLLM
 import litellm
 from litellm.utils import ModelResponse
-from litellm import batch_completion
 
 # Internal imports
 from .llm_utils import get_messages
@@ -115,6 +114,37 @@ def _respect_rate_limit(self) -> None:
             print("Waiting for rate limit...")
             time.sleep(sleep_time)
 
+    @staticmethod
+    def _strip_code_fences(content: str) -> str:
+        """Strip markdown code fences from content if present.
+
+        Args:
+            content: The content string that may contain code fences
+
+        Returns:
+            Content with code fences removed
+        """
+        if not content:
+            return content
+
+        content = content.strip()
+
+        # Check for code fences with optional language identifier
+        if content.startswith('```'):
+            # Find the end of the first line (language identifier)
+            first_newline = content.find('\n')
+            if first_newline != -1:
+                content = content[first_newline + 1:]
+            else:
+                # No newline after opening fence, remove just the fence
+                content = content[3:]
+
+        # Remove closing fence
+        if content.endswith('```'):
+            content = content[:-3]
+
+        return content.strip()
+
     def generate(
         self,
         prompt: str | list[str] | None = None,
@@ -176,13 +206,29 @@ def generate(
             raise ValueError("messages cannot be empty")
 
         try:
+            # Append JSON formatting instructions if response_format is provided
+            json_instructions = (
+                "\nReturn only valid JSON. To do so, don't include ```json ``` markdown "
+                "or code fences around the JSON. Use double quotes for all keys and values. "
+                "Escape internal quotes and newlines (use \\n). Do not include trailing commas."
+            )
+
             # Convert batch prompts to messages if needed
             batch_to_send = []
             if batch_prompts is not None:
                 for one_prompt in batch_prompts:
-                    batch_to_send.append(get_messages(one_prompt))
+                    # Append JSON instructions to prompt if response_format is provided
+                    modified_prompt = one_prompt + json_instructions if response_format is not None else one_prompt
+                    batch_to_send.append(get_messages(modified_prompt))
             else:
                 batch_to_send = batch_messages
+                # Append JSON instructions to the last user message if response_format is provided
+                if response_format is not None:
+                    for message_list in batch_to_send:
+                        for msg in reversed(message_list):
+                            if msg.get("role") == "user":
+                                msg["content"] += json_instructions
+                                break
 
             # Enforce rate limit per batch
             self._respect_rate_limit()
@@ -211,11 +257,15 @@ def generate(
             results = []
             for one_response in response:
                 content = one_response.choices[0].message.content
+
                 if response_format is not None:
+                    # Strip code fences before validation
+                    content = self._strip_code_fences(content)
                     results.append(
                         response_format.model_validate_json(content))
                 else:
-                    results.append(content)
+                    # Strip leading/trailing whitespace for text responses
+                    results.append(content.strip() if content else content)
 
             # Return single result for backward compatibility
             if single_input and len(results) == 1:
@@ -286,8 +336,8 @@ def __init__(
         api_key: str | None = None,
         temperature: float | None = None,
         max_completion_tokens: int | None = None,
-        top_p: float | None = None,
-        # frequency_penalty: float | None = None,  # Not supported by anthropic
+        # top_p: float | None = None,  # Not properly supported by Anthropic 4.5 models
+        # frequency_penalty: float | None = None,  # Not supported by Anthropic 4.5 models
     ):
         """Initialize the Anthropic provider.
 
@@ -303,7 +353,6 @@ def __init__( api_key=api_key, temperature=temperature, max_completion_tokens=max_completion_tokens, - top_p=top_p, ) diff --git a/pytest.ini b/pytest.ini index 366dd81..798f789 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,7 @@ [pytest] markers = integration: marks tests that require API connectivity (deselect with '-m "not integration"') + slow: marks tests that are slow to run # Other pytest configurations testpaths = tests @@ -16,6 +17,8 @@ log_cli_level = INFO filterwarnings = # Ignore Pydantic deprecation warnings ignore::DeprecationWarning:pydantic.*: + # Ignore Pydantic serializer warnings during tests + ignore::UserWarning:pydantic.main # Ignore LiteLLM deprecation warnings ignore::DeprecationWarning:litellm.*: # Ignore HTTPX deprecation warnings diff --git a/tests/test_anthropic.py b/tests/test_anthropic.py new file mode 100644 index 0000000..ff219c0 --- /dev/null +++ b/tests/test_anthropic.py @@ -0,0 +1,479 @@ +from datafast.llms import AnthropicProvider +from dotenv import load_dotenv +import pytest +from typing import List, Optional +from pydantic import BaseModel, Field, field_validator + +load_dotenv() + + +class SimpleResponse(BaseModel): + """Simple response model for testing structured output.""" + answer: str = Field(description="The answer to the question") + reasoning: str = Field(description="The reasoning behind the answer") + + +class Attribute(BaseModel): + """Attribute of a landmark with value and importance.""" + name: str = Field(description="Name of the attribute") + value: str = Field(description="Value of the attribute") + importance: float = Field(description="Importance score between 0 and 1") + + @field_validator('importance') + @classmethod + def check_importance(cls, v: float) -> float: + """Validate importance is between 0 and 1.""" + if not 0 <= v <= 1: + raise ValueError("Importance must be between 0 and 1") + return v + + +class LandmarkInfo(BaseModel): + """Information about a landmark with attributes.""" + name: str = Field(description="The name of the landmark") + location: str = Field(description="Where the landmark is located") + description: str = Field(description="A brief description of the landmark") + year_built: Optional[int] = Field( + None, description="Year when the landmark was built") + attributes: List[Attribute] = Field( + description="List of attributes about the landmark") + visitor_rating: float = Field( + description="Average visitor rating from 0 to 5") + + @field_validator('visitor_rating') + @classmethod + def check_rating(cls, v: float) -> float: + """Validate rating is between 0 and 5.""" + if not 0 <= v <= 5: + raise ValueError("Rating must be between 0 and 5") + return v + + +class PersonaContent(BaseModel): + """Generated content for a persona including tweets and bio.""" + tweets: List[str] = Field(description="List of 5 tweets for the persona") + bio: str = Field(description="Biography for the persona") + + @field_validator('tweets') + @classmethod + def check_tweets_count(cls, v: List[str]) -> List[str]: + """Validate that exactly 5 tweets are provided.""" + if len(v) != 5: + raise ValueError("Must provide exactly 5 tweets") + return v + + +class QAItem(BaseModel): + """Question and answer pair.""" + question: str = Field(description="The question") + answer: str = Field(description="The correct answer") + + +class QASet(BaseModel): + """Set of questions and answers.""" + questions: List[QAItem] = Field(description="List of question-answer pairs") + + @field_validator('questions') + 
@classmethod + def check_qa_count(cls, v: List[QAItem]) -> List[QAItem]: + """Validate that exactly 5 Q&A pairs are provided.""" + if len(v) != 5: + raise ValueError("Must provide exactly 5 question-answer pairs") + return v + + +class MCQQuestion(BaseModel): + """Multiple choice question with one correct and three incorrect answers.""" + question: str = Field(description="The question") + correct_answer: str = Field(description="The correct answer") + incorrect_answers: List[str] = Field(description="List of 3 incorrect answers") + + @field_validator('incorrect_answers') + @classmethod + def check_incorrect_count(cls, v: List[str]) -> List[str]: + """Validate that exactly 3 incorrect answers are provided.""" + if len(v) != 3: + raise ValueError("Must provide exactly 3 incorrect answers") + return v + + +class MCQSet(BaseModel): + """Set of multiple choice questions.""" + questions: List[MCQQuestion] = Field(description="List of MCQ questions") + + @field_validator('questions') + @classmethod + def check_questions_count(cls, v: List[MCQQuestion]) -> List[MCQQuestion]: + """Validate that exactly 3 questions are provided.""" + if len(v) != 3: + raise ValueError("Must provide exactly 3 questions") + return v + + +@pytest.mark.integration +class TestAnthropicSonnet45: + """Anthropic tests for claude-sonnet-4-5-20250929.""" + + def test_persona_content_generation(self): + provider = AnthropicProvider( + model_id="claude-sonnet-4-5-20250929", + temperature=0.5, + max_completion_tokens=2000, + ) + prompt = """ + Generate social media content for the following persona: + + Persona: A tech entrepreneur who is passionate about AI ethics, loves reading sci-fi novels, + practices meditation, and frequently shares insights about startup culture. + + Create exactly 5 tweets and 1 bio for this persona. + """ + response = provider.generate(prompt=prompt, response_format=PersonaContent) + assert isinstance(response, PersonaContent) + assert len(response.tweets) == 5 + assert all(len(tweet) > 0 for tweet in response.tweets) + assert len(response.bio) > 20 + + def test_qa_generation(self): + provider = AnthropicProvider( + model_id="claude-sonnet-4-5-20250929", + temperature=0.5, + max_completion_tokens=1500, + ) + prompt = """ + Generate exactly 5 questions and their correct answers about machine learning topics. + + Topics to cover: reinforcement learning, convolutional neural networks, regularization, + backpropagation, and feature engineering. + + Each question should be clear and the answer should be concise but complete. + """ + response = provider.generate(prompt=prompt, response_format=QASet) + assert isinstance(response, QASet) + assert len(response.questions) == 5 + for qa in response.questions: + assert len(qa.question) > 10 + assert len(qa.answer) > 10 + + def test_mcq_generation(self): + provider = AnthropicProvider( + model_id="claude-sonnet-4-5-20250929", + temperature=0.5, + max_completion_tokens=1500, + ) + prompt = """ + Generate exactly 3 multiple choice questions about machine learning. + + For each question, provide: + - The question itself + - One correct answer + - Three plausible but incorrect answers + + Topics: recurrent neural networks, k-means clustering, and support vector machines. 
+ """ + response = provider.generate(prompt=prompt, response_format=MCQSet) + assert isinstance(response, MCQSet) + assert len(response.questions) == 3 + for mcq in response.questions: + assert len(mcq.question) > 10 + assert len(mcq.correct_answer) > 0 + assert len(mcq.incorrect_answers) == 3 + assert all(len(ans) > 0 for ans in mcq.incorrect_answers) + + +@pytest.mark.integration +class TestAnthropicHaiku45: + """Anthropic tests for claude-haiku-4-5-20251001.""" + + def test_persona_content_generation(self): + provider = AnthropicProvider( + model_id="claude-haiku-4-5-20251001", + temperature=0.5, + max_completion_tokens=2000, + ) + prompt = """ + Generate social media content for the following persona: + + Persona: A tech entrepreneur who is passionate about AI ethics, loves reading sci-fi novels, + practices meditation, and frequently shares insights about startup culture. + + Create exactly 5 tweets and 1 bio for this persona. + """ + response = provider.generate(prompt=prompt, response_format=PersonaContent) + assert isinstance(response, PersonaContent) + assert len(response.tweets) == 5 + assert all(len(tweet) > 0 for tweet in response.tweets) + assert len(response.bio) > 20 + + def test_qa_generation(self): + provider = AnthropicProvider( + model_id="claude-haiku-4-5-20251001", + temperature=0.5, + max_completion_tokens=1500, + ) + prompt = """ + Generate exactly 5 questions and their correct answers about machine learning topics. + + Topics to cover: reinforcement learning, convolutional neural networks, regularization, + backpropagation, and feature engineering. + + Each question should be clear and the answer should be concise but complete. + """ + response = provider.generate(prompt=prompt, response_format=QASet) + assert isinstance(response, QASet) + assert len(response.questions) == 5 + for qa in response.questions: + assert len(qa.question) > 10 + assert len(qa.answer) > 10 + + def test_mcq_generation(self): + provider = AnthropicProvider( + model_id="claude-haiku-4-5-20251001", + temperature=0.5, + max_completion_tokens=1500, + ) + prompt = """ + Generate exactly 3 multiple choice questions about machine learning. + + For each question, provide: + - The question itself + - One correct answer + - Three plausible but incorrect answers + + Topics: transformers, random forests, and principal component analysis. + """ + response = provider.generate(prompt=prompt, response_format=MCQSet) + assert isinstance(response, MCQSet) + assert len(response.questions) == 3 + for mcq in response.questions: + assert len(mcq.question) > 10 + assert len(mcq.correct_answer) > 0 + assert len(mcq.incorrect_answers) == 3 + assert all(len(ans) > 0 for ans in mcq.incorrect_answers) + + +@pytest.mark.integration +class TestAnthropicProvider: + """General Anthropic provider tests mirroring OpenRouter structure.""" + + def test_basic_text_response(self): + provider = AnthropicProvider() + response = provider.generate( + prompt="What is the capital of France? Answer in one word.") + assert "Paris" in response + + def test_structured_output(self): + provider = AnthropicProvider() + prompt = """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital. 
+ Format your response as JSON with 'answer' and 'reasoning' fields.""" + + response = provider.generate( + prompt=prompt, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + def test_with_messages(self): + provider = AnthropicProvider() + messages = [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is the capital of France? Answer in one word."} + ] + + response = provider.generate(messages=messages) + assert "Paris" in response + + def test_messages_with_structured_output(self): + provider = AnthropicProvider() + messages = [ + {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, + {"role": "user", "content": """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital. + Format your response as JSON with 'answer' and 'reasoning' fields."""} + ] + + response = provider.generate( + messages=messages, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + def test_with_all_parameters(self): + provider = AnthropicProvider( + model_id="claude-haiku-4-5-20251001", + temperature=0.3, + max_completion_tokens=100, + ) + + prompt = "What is the capital of France? Answer in one word" + response = provider.generate(prompt=prompt) + + assert "Paris" in response + + def test_structured_landmark_info(self): + provider = AnthropicProvider(temperature=0.1, max_completion_tokens=800) + + prompt = """ + Provide detailed information about the Golden Gate Bridge in San Francisco. + + Return your response as a structured JSON object with the following elements: + - name: The name of the landmark (Golden Gate Bridge) + - location: Where it's located (San Francisco, USA) + - description: A brief description of the landmark (2-3 sentences) + - year_built: The year when it was built (as a number) + - attributes: A list of at least 3 attribute objects, each containing: + - name: The name of the attribute (e.g., "length", "color", "architect") + - value: The value of the attribute (e.g., "1.7 miles", "International Orange", "Joseph Strauss") + - importance: An importance score between 0 and 1 + - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.8) + + Make sure your response is properly structured and can be parsed as valid JSON. + """ + + response = provider.generate(prompt=prompt, response_format=LandmarkInfo) + + # Verify the structure was correctly generated and parsed + assert isinstance(response, LandmarkInfo) + assert "Golden Gate Bridge" in response.name + assert "Francisco" in response.location + assert len(response.description) > 20 + assert response.year_built is not None and response.year_built > 1900 + assert len(response.attributes) >= 3 + + # Verify nested objects + for attr in response.attributes: + assert 0 <= attr.importance <= 1 + assert len(attr.name) > 0 + assert len(attr.value) > 0 + + # Verify rating field + assert 0 <= response.visitor_rating <= 5 + + def test_batch_prompts(self): + provider = AnthropicProvider() + prompt = [ + "What is the capital of France? Answer in one word.", + "What is the capital of Spain? Answer in one word.", + "What is the capital of Portugal? Answer in one word." 
+ ] + + responses = provider.generate(prompt=prompt) + + assert len(responses) == 3 + assert isinstance(responses, list) + assert all(isinstance(r, str) for r in responses) + assert "Paris" in responses[0] + assert "Madrid" in responses[1] + assert "Lisbon" in responses[2] + + def test_batch_messages(self): + provider = AnthropicProvider() + messages = [ + [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is the capital of Canada? One word."} + ], + [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is the capital of Australia? One word."} + ] + ] + + responses = provider.generate(messages=messages) + + assert len(responses) == 2 + assert isinstance(responses, list) + assert all(isinstance(r, str) for r in responses) + assert "Ottawa" in responses[0] + assert "Canberra" in responses[1] + + def test_batch_structured_output(self): + provider = AnthropicProvider() + prompt = [ + """What is the capital of Germany? + Provide a short answer and brief reasoning. + Format as JSON with 'answer' and 'reasoning' fields.""", + """What is the capital of Italy? + Provide a short answer and brief reasoning. + Format as JSON with 'answer' and 'reasoning' fields.""" + ] + + responses = provider.generate( + prompt=prompt, + response_format=SimpleResponse + ) + + assert len(responses) == 2 + assert all(isinstance(r, SimpleResponse) for r in responses) + assert "Berlin" in responses[0].answer + assert "Rome" in responses[1].answer + assert len(responses[0].reasoning) > 5 + assert len(responses[1].reasoning) > 5 + + def test_batch_messages_with_structured_output(self): + provider = AnthropicProvider() + messages = [ + [ + {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, + {"role": "user", "content": """What is the capital of Egypt? + Provide a short answer and brief reasoning. + Format as JSON with 'answer' and 'reasoning' fields."""} + ], + [ + {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, + {"role": "user", "content": """What is the capital of Morocco? + Provide a short answer and brief reasoning. + Format as JSON with 'answer' and 'reasoning' fields."""} + ] + ] + + responses = provider.generate( + messages=messages, + response_format=SimpleResponse + ) + + assert len(responses) == 2 + assert all(isinstance(r, SimpleResponse) for r in responses) + assert "Cairo" in responses[0].answer + assert "Rabat" in responses[1].answer + assert len(responses[0].reasoning) > 5 + assert len(responses[1].reasoning) > 5 + + def test_batch_with_all_parameters(self): + provider = AnthropicProvider( + model_id="claude-haiku-4-5-20251001", + temperature=0.1, + max_completion_tokens=50 + ) + + prompt = [ + "What is the capital of Denmark? Answer in one word.", + "What is the capital of Finland? Answer in one word." 
+ ] + + responses = provider.generate(prompt=prompt) + + assert len(responses) == 2 + assert "Copenhagen" in responses[0] + assert "Helsinki" in responses[1] + + def test_batch_validation_errors(self): + provider = AnthropicProvider() + + # Test no inputs provided + with pytest.raises(ValueError, match="Either prompts or messages must be provided"): + provider.generate() + + # Test both inputs provided + with pytest.raises(ValueError, match="Provide either prompts or messages, not both"): + provider.generate( + prompt=["test"], + messages=[[{"role": "user", "content": "test"}]] + ) \ No newline at end of file diff --git a/tests/test_llms.py b/tests/test_llms.py index 93e01ae..5271296 100644 --- a/tests/test_llms.py +++ b/tests/test_llms.py @@ -1,10 +1,10 @@ -from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider, OllamaProvider, OpenRouterProvider +from datafast.llms import OpenAIProvider, GeminiProvider from dotenv import load_dotenv import pytest from typing import List, Optional from pydantic import BaseModel, Field, field_validator -load_dotenv('secrets.env') +load_dotenv() class SimpleResponse(BaseModel): @@ -58,15 +58,6 @@ def test_openai_provider(): assert "Paris" in response -@pytest.mark.integration -def test_anthropic_provider(): - """Test the Anthropic provider with text response.""" - provider = AnthropicProvider() - response = provider.generate( - prompt="What is the capital of France? Answer in one word.") - assert "Paris" in response - - @pytest.mark.integration def test_gemini_provider(): """Test the Gemini provider with text response.""" @@ -75,12 +66,6 @@ def test_gemini_provider(): prompt="What is the capital of France? Answer in one word.") assert "Paris" in response -@pytest.mark.integration -def test_openrouter_provider(): - """Test the OpenRouter provider with text response.""" - provider = OpenRouterProvider() - response = provider.generate(prompt="What is the capital of France? Answer in one word.") - assert "Paris" in response @pytest.mark.slow @pytest.mark.integration @@ -119,22 +104,7 @@ def test_openai_structured_output(): assert len(response.reasoning) > 10 -@pytest.mark.integration -def test_anthropic_structured_output(): - """Test the Anthropic provider with structured output.""" - provider = AnthropicProvider() - prompt = """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. - Format your response as JSON with 'answer' and 'reasoning' fields.""" - - response = provider.generate( - prompt=prompt, - response_format=SimpleResponse - ) - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - assert len(response.reasoning) > 10 @pytest.mark.integration @@ -154,22 +124,6 @@ def test_gemini_structured_output(): assert "Paris" in response.answer assert len(response.reasoning) > 10 -@pytest.mark.integration -def test_openrouter_structured_output(): - """Test the OpenRouter provider with structured output.""" - provider = OpenRouterProvider() - prompt = """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. 
- Format your response as JSON with 'answer' and 'reasoning' fields.""" - - response = provider.generate( - prompt=prompt, - response_format=SimpleResponse - ) - - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - assert len(response.reasoning) > 10 @pytest.mark.integration @@ -185,19 +139,6 @@ def test_openai_with_messages(): assert "Paris" in response -@pytest.mark.integration -def test_anthropic_with_messages(): - """Test Anthropic provider with messages input instead of prompt.""" - provider = AnthropicProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is the capital of France? Answer in one word."} - ] - - response = provider.generate(messages=messages) - assert "Paris" in response - - @pytest.mark.integration def test_gemini_with_messages(): """Test Gemini provider with messages input instead of prompt.""" @@ -210,17 +151,6 @@ def test_gemini_with_messages(): response = provider.generate(messages=messages) assert "Paris" in response -@pytest.mark.integration -def test_openrouter_with_messages(): - """Test OpenRouter provider with messages input instead of prompt.""" - provider = OpenRouterProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is the capital of France? Answer in one word."} - ] - - response = provider.generate(messages=messages) - assert "Paris" in response @pytest.mark.integration @@ -243,25 +173,6 @@ def test_openai_messages_with_structured_output(): assert "Paris" in response.answer assert len(response.reasoning) > 10 -@pytest.mark.integration -def test_openrouter_messages_with_structured_output(): - """Test OpenRouter provider with messages input and structured output.""" - provider = OpenRouterProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. - Format your response as JSON with 'answer' and 'reasoning' fields."""} - ] - - response = provider.generate( - messages=messages, - response_format=SimpleResponse - ) - - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - assert len(response.reasoning) > 10 @pytest.mark.integration @@ -281,27 +192,6 @@ def test_openai_with_all_parameters(): assert "Paris" in response -@pytest.mark.integration -def test_anthropic_messages_with_structured_output(): - """Test the Anthropic provider with messages input and structured output.""" - provider = AnthropicProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. 
- Format your response as JSON with 'answer' and 'reasoning' fields."""} - ] - - response = provider.generate( - messages=messages, - response_format=SimpleResponse - ) - - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - assert len(response.reasoning) > 10 - - @pytest.mark.integration def test_gemini_messages_with_structured_output(): """Test the Gemini provider with messages input and structured output.""" @@ -323,22 +213,6 @@ def test_gemini_messages_with_structured_output(): assert len(response.reasoning) > 10 -@pytest.mark.integration -def test_anthropic_with_all_parameters(): - """Test Anthropic provider with all optional parameters specified.""" - provider = AnthropicProvider( - model_id="claude-haiku-4-5-20251001", - temperature=0.3, - max_completion_tokens=200, - top_p=0.95, - ) - - prompt = "What is the capital of France? Answer in one word." - response = provider.generate(prompt=prompt) - - assert "Paris" in response - - @pytest.mark.integration def test_gemini_with_all_parameters(): """Test Gemini provider with all optional parameters specified.""" @@ -355,21 +229,6 @@ def test_gemini_with_all_parameters(): assert "Paris" in response -@pytest.mark.integration -def test_openrouter_with_all_parameters(): - """Test OpenRouter provider with all optional parameters specified.""" - provider = OpenRouterProvider( - model_id="openai/gpt-3.5-turbo", - temperature=0.4, - max_completion_tokens=150, - top_p=0.85, - frequency_penalty=0.15 - ) - - prompt = "What is the capital of France? Answer in one word." - response = provider.generate(prompt=prompt) - - assert "Paris" in response @pytest.mark.integration @@ -414,46 +273,7 @@ def test_openai_structured_landmark_info(): assert 0 <= response.visitor_rating <= 5 -@pytest.mark.integration -def test_anthropic_structured_landmark_info(): - """Test Anthropic with a structured landmark info response.""" - provider = AnthropicProvider(temperature=0.1, max_completion_tokens=800) - - prompt = """ - Provide detailed information about the Golden Gate Bridge in San Francisco. - - Return your response as a structured JSON object with the following elements: - - name: The name of the landmark (Golden Gate Bridge) - - location: Where it's located (San Francisco, USA) - - description: A brief description of the landmark (2-3 sentences) - - year_built: The year when it was built (as a number) - - attributes: A list of at least 3 attribute objects, each containing: - - name: The name of the attribute (e.g., "length", "color", "architect") - - value: The value of the attribute (e.g., "1.7 miles", "International Orange", "Joseph Strauss") - - importance: An importance score between 0 and 1 - - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.8) - - Make sure your response is properly structured and can be parsed as valid JSON. 
- """ - - response = provider.generate(prompt=prompt, response_format=LandmarkInfo) - - # Verify the structure was correctly generated and parsed - assert isinstance(response, LandmarkInfo) - assert "Golden Gate Bridge" in response.name - assert "Francisco" in response.location - assert len(response.description) > 20 - assert response.year_built is not None and response.year_built > 1900 - assert len(response.attributes) >= 3 - - # Verify nested objects - for attr in response.attributes: - assert 0 <= attr.importance <= 1 - assert len(attr.name) > 0 - assert len(attr.value) > 0 - # Verify rating field - assert 0 <= response.visitor_rating <= 5 @pytest.mark.integration @@ -497,172 +317,6 @@ def test_gemini_structured_landmark_info(): # Verify rating field assert 0 <= response.visitor_rating <= 5 -# import litellm -# litellm._turn_on_debug() # turn on debug to see the request - -@pytest.mark.integration -def test_openrouter_structured_landmark_info(): - """Test OpenRouter with a structured landmark info response.""" - provider = OpenRouterProvider(temperature=0.1, max_completion_tokens=800) - - prompt = """ - Provide detailed information about the Great Wall of China. - - Return your response as a structured JSON object with the following elements: - - name: The name of the landmark (Great Wall of China) - - location: Where it's located (Northern China) - - description: A brief description of the landmark (2-3 sentences) - - year_built: The year when construction began (as a number) - - attributes: A list of at least 3 attribute objects, each containing: - - name: The name of the attribute (e.g., "length", "material", "dynasties") - - value: The value of the attribute (e.g., "13,171 miles", "stone, brick, wood, etc.", "multiple including Qin, Han, Ming") - - importance: An importance score between 0 and 1 - - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.7) - - Make sure your response is properly structured and can be parsed as valid JSON. - """ - - response = provider.generate(prompt=prompt, response_format=LandmarkInfo) - - # Verify the structure was correctly generated and parsed - assert isinstance(response, LandmarkInfo) - assert "Great Wall" in response.name - assert "China" in response.location - assert len(response.description) > 20 - assert response.year_built is not None - assert len(response.attributes) >= 3 - - # Verify nested objects - for attr in response.attributes: - assert 0 <= attr.importance <= 1 - assert len(attr.name) > 0 - assert len(attr.value) > 0 - - # Verify rating field - assert 0 <= response.visitor_rating <= 5 - - - -@pytest.mark.integration -def test_ollama_provider(): - """Test the Ollama provider with text response.""" - provider = OllamaProvider(model_id="gemma3:4b") - response = provider.generate( - prompt="What is the capital of France? Answer in one word.") - assert "Paris" in response - - -@pytest.mark.integration -def test_ollama_structured_output(): - """Test the Ollama provider with structured output.""" - provider = OllamaProvider(model_id="gemma3:4b") - prompt = """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. 
- Format your response as JSON with 'answer' and 'reasoning' fields.""" - - response = provider.generate( - prompt=prompt, - response_format=SimpleResponse - ) - - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - assert len(response.reasoning) > 10 - - -@pytest.mark.integration -def test_ollama_with_messages(): - """Test Ollama provider with messages input instead of prompt.""" - provider = OllamaProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is the capital of France? Answer in one word."} - ] - - response = provider.generate(messages=messages) - assert "Paris" in response - - -@pytest.mark.integration -def test_ollama_messages_with_structured_output(): - """Test the Ollama provider with messages input and structured output.""" - provider = OllamaProvider() - messages = [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is the capital of France? - Provide a short answer and a brief explanation of why Paris is the capital. - Format your response as JSON with 'answer' and 'reasoning' fields."""} - ] - - response = provider.generate( - messages=messages, - response_format=SimpleResponse - ) - - assert isinstance(response, SimpleResponse) - assert "Paris" in response.answer - assert len(response.reasoning) > 10 - - -@pytest.mark.integration -def test_ollama_with_all_parameters(): - """Test Ollama provider with all optional parameters specified.""" - provider = OllamaProvider( - model_id="gemma3:4b", - temperature=0.4, - max_completion_tokens=150, - top_p=0.85, - frequency_penalty=0.15, - api_base="http://localhost:11434" - ) - - prompt = "What is the capital of France? Answer in one word." - response = provider.generate(prompt=prompt) - - assert "Paris" in response - - -@pytest.mark.integration -def test_ollama_structured_landmark_info(): - """Test Ollama with a structured landmark info response.""" - provider = OllamaProvider(temperature=0.1, max_completion_tokens=800) - - prompt = """ - Provide detailed information about the Sydney Opera House. - - Return your response as a structured JSON object with the following elements: - - name: The name of the landmark (Sydney Opera House) - - location: Where it's located (Sydney, Australia) - - description: A brief description of the landmark (2-3 sentences) - - year_built: The year when it was completed (as a number) - - attributes: A list of at least 3 attribute objects, each containing: - - name: The name of the attribute (e.g., "architect", "style", "height") - - value: The value of the attribute (e.g., "Jørn Utzon", "Expressionist", "65 meters") - - importance: An importance score between 0 and 1 - - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.9) - - Make sure your response is properly structured and can be parsed as valid JSON. 
- """ - - response = provider.generate(prompt=prompt, response_format=LandmarkInfo) - - # Verify the structure was correctly generated and parsed - assert isinstance(response, LandmarkInfo) - assert "Opera House" in response.name - assert "Sydney" in response.location - assert len(response.description) > 20 - assert response.year_built is not None and response.year_built > 1900 - assert len(response.attributes) >= 3 - - # Verify nested objects - for attr in response.attributes: - assert 0 <= attr.importance <= 1 - assert len(attr.name) > 0 - assert len(attr.value) > 0 - - # Verify rating field - assert 0 <= response.visitor_rating <= 5 - "******* Batch Inference Tests *******" "Similar to previous tests but for batch inputs" @@ -690,26 +344,6 @@ def test_openai_batch_prompts(): assert "Rome" in responses[2] -@pytest.mark.integration -def test_anthropic_batch_prompts(): - """Test the Anthropic provider with batch prompts.""" - provider = AnthropicProvider() - prompt = [ - "What is the capital of France? Answer in one word.", - "What is the capital of Spain? Answer in one word.", - "What is the capital of Portugal? Answer in one word." - ] - - responses = provider.generate(prompt=prompt) - - assert len(responses) == 3 - assert isinstance(responses, list) - assert all(isinstance(r, str) for r in responses) - assert "Paris" in responses[0] - assert "Madrid" in responses[1] - assert "Lisbon" in responses[2] - - @pytest.mark.integration def test_gemini_batch_prompts(): """Test the Gemini provider with batch prompts.""" @@ -754,30 +388,6 @@ def test_openai_batch_messages(): assert "Tokyo" in responses[1] -@pytest.mark.integration -def test_anthropic_batch_messages(): - """Test Anthropic provider with batch messages.""" - provider = AnthropicProvider() - messages = [ - [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is the capital of Canada? One word."} - ], - [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is the capital of Australia? One word."} - ] - ] - - responses = provider.generate(messages=messages) - - assert len(responses) == 2 - assert isinstance(responses, list) - assert all(isinstance(r, str) for r in responses) - assert "Ottawa" in responses[0] - assert "Canberra" in responses[1] - - @pytest.mark.integration def test_gemini_batch_messages(): """Test Gemini provider with batch messages.""" @@ -828,32 +438,6 @@ def test_openai_batch_structured_output(): assert len(responses[1].reasoning) > 5 -@pytest.mark.integration -def test_anthropic_batch_structured_output(): - """Test Anthropic provider with batch structured output.""" - provider = AnthropicProvider() - prompt = [ - """What is the capital of Germany? - Provide a short answer and brief reasoning. - Format as JSON with 'answer' and 'reasoning' fields.""", - """What is the capital of Italy? - Provide a short answer and brief reasoning. 
- Format as JSON with 'answer' and 'reasoning' fields.""" - ] - - responses = provider.generate( - prompt=prompt, - response_format=SimpleResponse - ) - - assert len(responses) == 2 - assert all(isinstance(r, SimpleResponse) for r in responses) - assert "Berlin" in responses[0].answer - assert "Rome" in responses[1].answer - assert len(responses[0].reasoning) > 5 - assert len(responses[1].reasoning) > 5 - - @pytest.mark.integration def test_gemini_batch_structured_output(): """Test Gemini provider with batch structured output.""" @@ -910,38 +494,6 @@ def test_openai_batch_messages_with_structured_output(): assert len(responses[1].reasoning) > 5 -@pytest.mark.integration -def test_anthropic_batch_messages_with_structured_output(): - """Test Anthropic provider with batch messages and structured output.""" - provider = AnthropicProvider() - messages = [ - [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is the capital of Egypt? - Provide a short answer and brief reasoning. - Format as JSON with 'answer' and 'reasoning' fields."""} - ], - [ - {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, - {"role": "user", "content": """What is the capital of Morocco? - Provide a short answer and brief reasoning. - Format as JSON with 'answer' and 'reasoning' fields."""} - ] - ] - - responses = provider.generate( - messages=messages, - response_format=SimpleResponse - ) - - assert len(responses) == 2 - assert all(isinstance(r, SimpleResponse) for r in responses) - assert "Cairo" in responses[0].answer - assert "Rabat" in responses[1].answer - assert len(responses[0].reasoning) > 5 - assert len(responses[1].reasoning) > 5 - - @pytest.mark.integration def test_gemini_batch_messages_with_structured_output(): """Test Gemini provider with batch messages and structured output.""" @@ -995,26 +547,7 @@ def test_openai_batch_with_all_parameters(): assert "Oslo" in responses[1] -@pytest.mark.integration -def test_anthropic_batch_with_all_parameters(): - """Test Anthropic provider with batch processing and all optional parameters.""" - provider = AnthropicProvider( - model_id="claude-haiku-4-5-20251001", - temperature=0.1, - max_completion_tokens=50, - top_p=0.9 - ) - - prompt = [ - "What is the capital of Denmark? Answer in one word.", - "What is the capital of Finland? Answer in one word." 
- ] - - responses = provider.generate(prompt=prompt) - assert len(responses) == 2 - assert "Copenhagen" in responses[0] - assert "Helsinki" in responses[1] @pytest.mark.integration @@ -1040,23 +573,6 @@ def test_gemini_batch_with_all_parameters(): assert "Amsterdam" in responses[1] -@pytest.mark.integration -def test_batch_validation_errors(): - """Test that batch generate properly validates inputs.""" - provider = AnthropicProvider() - - # Test no inputs provided - with pytest.raises(ValueError, match="Either prompts or messages must be provided"): - provider.generate() - - # Test both inputs provided - with pytest.raises(ValueError, match="Provide either prompts or messages, not both"): - provider.generate( - prompt=["test"], - messages=[[{"role": "user", "content": "test"}]] - ) - - @pytest.mark.integration def test_openai_batch_landmark_info(): """Test OpenAI with batch structured landmark info responses.""" @@ -1127,69 +643,3 @@ def test_openai_batch_landmark_info(): assert 0 <= response.visitor_rating <= 5 -@pytest.mark.integration -def test_ollama_batch_prompts(): - """Test Ollama provider with batch prompts.""" - provider = OllamaProvider(model_id="gemma3:4b") - prompt = [ - "What is the capital of France? Answer in one word.", - "What is the capital of Germany? Answer in one word." - ] - - responses = provider.generate(prompt=prompt) - - assert len(responses) == 2 - assert isinstance(responses, list) - assert all(isinstance(r, str) for r in responses) - assert "Paris" in responses[0] - assert "Berlin" in responses[1] - - -@pytest.mark.integration -def test_ollama_batch_messages(): - """Test Ollama provider with batch messages.""" - provider = OllamaProvider() - messages = [ - [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is 6+4? Just the number."} - ], - [ - {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, - {"role": "user", "content": "What is 8+2? Just the number."} - ] - ] - - responses = provider.generate(messages=messages) - - assert len(responses) == 2 - assert isinstance(responses, list) - assert all(isinstance(r, str) for r in responses) - assert "10" in responses[0] - assert "10" in responses[1] - - -@pytest.mark.integration -def test_ollama_batch_structured_output(): - """Test Ollama provider with batch structured output.""" - provider = OllamaProvider() - prompt = [ - """What is the capital of Spain? - Provide a short answer and brief reasoning. - Format as JSON with 'answer' and 'reasoning' fields.""", - """What is the capital of Portugal? - Provide a short answer and brief reasoning. 
- Format as JSON with 'answer' and 'reasoning' fields.""" - ] - - responses = provider.generate( - prompt=prompt, - response_format=SimpleResponse - ) - - assert len(responses) == 2 - assert all(isinstance(r, SimpleResponse) for r in responses) - assert "Madrid" in responses[0].answer - assert "Lisbon" in responses[1].answer - assert len(responses[0].reasoning) > 5 - assert len(responses[1].reasoning) > 5 diff --git a/tests/test_ollama.py b/tests/test_ollama.py new file mode 100644 index 0000000..cade168 --- /dev/null +++ b/tests/test_ollama.py @@ -0,0 +1,361 @@ +from datafast.llms import OllamaProvider +from dotenv import load_dotenv +import pytest +from typing import List, Optional +from pydantic import BaseModel, Field, field_validator + +load_dotenv() + + +class SimpleResponse(BaseModel): + """Simple response model for testing structured output.""" + answer: str = Field(description="The answer to the question") + reasoning: str = Field(description="The reasoning behind the answer") + + +class Attribute(BaseModel): + """Attribute of a landmark with value and importance.""" + name: str = Field(description="Name of the attribute") + value: str = Field(description="Value of the attribute") + importance: float = Field(description="Importance score between 0 and 1") + + @field_validator('importance') + @classmethod + def check_importance(cls, v: float) -> float: + """Validate importance is between 0 and 1.""" + if not 0 <= v <= 1: + raise ValueError("Importance must be between 0 and 1") + return v + + +class LandmarkInfo(BaseModel): + """Information about a landmark with attributes.""" + name: str = Field(description="The name of the landmark") + location: str = Field(description="Where the landmark is located") + description: str = Field(description="A brief description of the landmark") + year_built: Optional[int] = Field( + None, description="Year when the landmark was built") + attributes: List[Attribute] = Field( + description="List of attributes about the landmark") + visitor_rating: float = Field( + description="Average visitor rating from 0 to 5") + + @field_validator('visitor_rating') + @classmethod + def check_rating(cls, v: float) -> float: + """Validate rating is between 0 and 5.""" + if not 0 <= v <= 5: + raise ValueError("Rating must be between 0 and 5") + return v + + +class PersonaContent(BaseModel): + """Generated content for a persona including tweets and bio.""" + tweets: List[str] = Field(description="List of 5 tweets for the persona") + bio: str = Field(description="Biography for the persona") + + @field_validator('tweets') + @classmethod + def check_tweets_count(cls, v: List[str]) -> List[str]: + """Validate that exactly 5 tweets are provided.""" + if len(v) != 5: + raise ValueError("Must provide exactly 5 tweets") + return v + + +class QAItem(BaseModel): + """Question and answer pair.""" + question: str = Field(description="The question") + answer: str = Field(description="The correct answer") + + +class QASet(BaseModel): + """Set of questions and answers.""" + questions: List[QAItem] = Field(description="List of question-answer pairs") + + @field_validator('questions') + @classmethod + def check_qa_count(cls, v: List[QAItem]) -> List[QAItem]: + """Validate that exactly 5 Q&A pairs are provided.""" + if len(v) != 5: + raise ValueError("Must provide exactly 5 question-answer pairs") + return v + + +class MCQQuestion(BaseModel): + """Multiple choice question with one correct and three incorrect answers.""" + question: str = Field(description="The question") + 
correct_answer: str = Field(description="The correct answer") + incorrect_answers: List[str] = Field(description="List of 3 incorrect answers") + + @field_validator('incorrect_answers') + @classmethod + def check_incorrect_count(cls, v: List[str]) -> List[str]: + """Validate that exactly 3 incorrect answers are provided.""" + if len(v) != 3: + raise ValueError("Must provide exactly 3 incorrect answers") + return v + + +class MCQSet(BaseModel): + """Set of multiple choice questions.""" + questions: List[MCQQuestion] = Field(description="List of MCQ questions") + + @field_validator('questions') + @classmethod + def check_questions_count(cls, v: List[MCQQuestion]) -> List[MCQQuestion]: + """Validate that exactly 3 questions are provided.""" + if len(v) != 3: + raise ValueError("Must provide exactly 3 questions") + return v + + +@pytest.mark.integration +def test_ollama_provider(): + """Test the Ollama provider with text response.""" + provider = OllamaProvider(model_id="gemma3:4b") + response = provider.generate( + prompt="What is the capital of France? Answer in one word.") + assert "Paris" in response + + +@pytest.mark.integration +def test_ollama_structured_output(): + """Test the Ollama provider with structured output.""" + provider = OllamaProvider(model_id="gemma3:4b") + prompt = """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital. + Format your response as JSON with 'answer' and 'reasoning' fields.""" + + response = provider.generate( + prompt=prompt, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + +@pytest.mark.integration +def test_ollama_with_messages(): + """Test Ollama provider with messages input instead of prompt.""" + provider = OllamaProvider() + messages = [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is the capital of France? Answer in one word."} + ] + + response = provider.generate(messages=messages) + assert "Paris" in response + + +@pytest.mark.integration +def test_ollama_messages_with_structured_output(): + """Test the Ollama provider with messages input and structured output.""" + provider = OllamaProvider() + messages = [ + {"role": "system", "content": "You are a helpful assistant that provides answers in JSON format."}, + {"role": "user", "content": """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital. + Format your response as JSON with 'answer' and 'reasoning' fields."""} + ] + + response = provider.generate( + messages=messages, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + +@pytest.mark.integration +def test_ollama_with_all_parameters(): + """Test Ollama provider with all optional parameters specified.""" + provider = OllamaProvider( + model_id="gemma3:4b", + temperature=0.4, + max_completion_tokens=150, + top_p=0.85, + frequency_penalty=0.15, + api_base="http://localhost:11434" + ) + + prompt = "What is the capital of France? Answer in one word." 
+ response = provider.generate(prompt=prompt) + + assert "Paris" in response + + +@pytest.mark.integration +def test_ollama_structured_landmark_info(): + """Test Ollama with a structured landmark info response.""" + provider = OllamaProvider(temperature=0.1, max_completion_tokens=800) + + prompt = """ + Provide detailed information about the Sydney Opera House. + + Return your response as a structured JSON object with the following elements: + - name: The name of the landmark (Sydney Opera House) + - location: Where it's located (Sydney, Australia) + - description: A brief description of the landmark (2-3 sentences) + - year_built: The year when it was completed (as a number) + - attributes: A list of at least 3 attribute objects, each containing: + - name: The name of the attribute (e.g., "architect", "style", "height") + - value: The value of the attribute (e.g., "Jørn Utzon", "Expressionist", "65 meters") + - importance: An importance score between 0 and 1 + - visitor_rating: Average visitor rating from 0 to 5 (e.g., 4.9) + + Make sure your response is properly structured and can be parsed as valid JSON. + """ + + response = provider.generate(prompt=prompt, response_format=LandmarkInfo) + + # Verify the structure was correctly generated and parsed + assert isinstance(response, LandmarkInfo) + assert "Opera House" in response.name + assert "Sydney" in response.location + assert len(response.description) > 20 + assert response.year_built is not None and response.year_built > 1900 + assert len(response.attributes) >= 3 + + # Verify nested objects + for attr in response.attributes: + assert 0 <= attr.importance <= 1 + assert len(attr.name) > 0 + assert len(attr.value) > 0 + + # Verify rating field + assert 0 <= response.visitor_rating <= 5 + + +@pytest.mark.integration +def test_ollama_batch_prompts(): + """Test Ollama provider with batch prompts.""" + provider = OllamaProvider(model_id="gemma3:4b") + prompt = [ + "What is the capital of France? Answer in one word.", + "What is the capital of Germany? Answer in one word." + ] + + responses = provider.generate(prompt=prompt) + + assert len(responses) == 2 + assert isinstance(responses, list) + assert all(isinstance(r, str) for r in responses) + assert "Paris" in responses[0] + assert "Berlin" in responses[1] + + +@pytest.mark.integration +def test_ollama_batch_messages(): + """Test Ollama provider with batch messages.""" + provider = OllamaProvider() + messages = [ + [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is 6+4? Just the number."} + ], + [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is 8+2? Just the number."} + ] + ] + + responses = provider.generate(messages=messages) + + assert len(responses) == 2 + assert isinstance(responses, list) + assert all(isinstance(r, str) for r in responses) + assert "10" in responses[0] + assert "10" in responses[1] + + +@pytest.mark.integration +def test_ollama_batch_structured_output(): + """Test Ollama provider with batch structured output.""" + provider = OllamaProvider() + prompt = [ + """What is the capital of Spain? + Provide a short answer and brief reasoning. + Format as JSON with 'answer' and 'reasoning' fields.""", + """What is the capital of Portugal? + Provide a short answer and brief reasoning. 
+ Format as JSON with 'answer' and 'reasoning' fields.""" + ] + + responses = provider.generate( + prompt=prompt, + response_format=SimpleResponse + ) + + assert len(responses) == 2 + assert all(isinstance(r, SimpleResponse) for r in responses) + assert "Madrid" in responses[0].answer + assert "Lisbon" in responses[1].answer + assert len(responses[0].reasoning) > 5 + assert len(responses[1].reasoning) > 5 + + +@pytest.mark.integration +class TestOllama: + """Tests mirroring OpenRouter structure using Ollama with structured outputs.""" + + def test_persona_content_generation(self): + """Generate tweets and a bio for a persona using Ollama.""" + provider = OllamaProvider(model_id="gemma3:4b", temperature=0.5, max_completion_tokens=2000) + prompt = """ + Generate social media content for the following persona: + + Persona: A tech entrepreneur who is passionate about AI ethics, loves reading sci-fi novels, + practices meditation, and frequently shares insights about startup culture. + + Create exactly 5 tweets and 1 bio for this persona. + """ + response = provider.generate(prompt=prompt, response_format=PersonaContent) + assert isinstance(response, PersonaContent) + assert len(response.tweets) == 5 + assert all(len(tweet) > 0 for tweet in response.tweets) + assert len(response.bio) > 20 + + def test_qa_generation(self): + """Generate 5 Q&A pairs on machine learning topics using Ollama.""" + provider = OllamaProvider(model_id="gemma3:4b", temperature=0.5, max_completion_tokens=1500) + prompt = """ + Generate exactly 5 questions and their correct answers about machine learning topics. + + Topics to cover: reinforcement learning, convolutional neural networks, regularization, + backpropagation, and feature engineering. + + Each question should be clear and the answer should be concise but complete. + """ + response = provider.generate(prompt=prompt, response_format=QASet) + assert isinstance(response, QASet) + assert len(response.questions) == 5 + for qa in response.questions: + assert len(qa.question) > 10 + assert len(qa.answer) > 10 + + def test_mcq_generation(self): + """Generate 3 MCQs on ML topics using Ollama.""" + provider = OllamaProvider(model_id="gemma3:4b", temperature=0.5, max_completion_tokens=1500) + prompt = """ + Generate exactly 3 multiple choice questions about machine learning. + + For each question, provide: + - The question itself + - One correct answer + - Three plausible but incorrect answers + + Topics: recurrent neural networks, k-means clustering, and support vector machines. 
+ """ + response = provider.generate(prompt=prompt, response_format=MCQSet) + assert isinstance(response, MCQSet) + assert len(response.questions) == 3 + for mcq in response.questions: + assert len(mcq.question) > 10 + assert len(mcq.correct_answer) > 0 + assert len(mcq.incorrect_answers) == 3 + assert all(len(ans) > 0 for ans in mcq.incorrect_answers) diff --git a/tests/test_openrouter.py b/tests/test_openrouter.py new file mode 100644 index 0000000..a537b76 --- /dev/null +++ b/tests/test_openrouter.py @@ -0,0 +1,538 @@ +from datafast.llms import OpenRouterProvider +from dotenv import load_dotenv +import pytest +from typing import List, Optional +from pydantic import BaseModel, Field, field_validator + +load_dotenv() + + +class SimpleResponse(BaseModel): + """Simple response model for testing structured output.""" + answer: str = Field(description="The answer to the question") + reasoning: str = Field(description="The reasoning behind the answer") + + +class Attribute(BaseModel): + """Attribute of a landmark with value and importance.""" + name: str = Field(description="Name of the attribute") + value: str = Field(description="Value of the attribute") + importance: float = Field(description="Importance score between 0 and 1") + + @field_validator('importance') + def check_importance(cls, v: float) -> float: + """Validate importance is between 0 and 1.""" + if not 0 <= v <= 1: + raise ValueError("Importance must be between 0 and 1") + return v + + +class LandmarkInfo(BaseModel): + """Information about a landmark with attributes.""" + name: str = Field(description="The name of the landmark") + location: str = Field(description="Where the landmark is located") + description: str = Field(description="A brief description of the landmark") + year_built: Optional[int] = Field( + None, description="Year when the landmark was built") + attributes: List[Attribute] = Field( + description="List of attributes about the landmark") + visitor_rating: float = Field( + description="Average visitor rating from 0 to 5") + + @field_validator('visitor_rating') + def check_rating(cls, v: float) -> float: + """Validate rating is between 0 and 5.""" + if not 0 <= v <= 5: + raise ValueError("Rating must be between 0 and 5") + return v + + +class PersonaContent(BaseModel): + """Generated content for a persona including tweets and bio.""" + tweets: List[str] = Field(description="List of 5 tweets for the persona") + bio: str = Field(description="Biography for the persona") + + @field_validator('tweets') + def check_tweets_count(cls, v: List[str]) -> List[str]: + """Validate that exactly 5 tweets are provided.""" + if len(v) != 5: + raise ValueError("Must provide exactly 5 tweets") + return v + + +class QAItem(BaseModel): + """Question and answer pair.""" + question: str = Field(description="The question") + answer: str = Field(description="The correct answer") + + +class QASet(BaseModel): + """Set of questions and answers.""" + questions: List[QAItem] = Field(description="List of question-answer pairs") + + @field_validator('questions') + def check_qa_count(cls, v: List[QAItem]) -> List[QAItem]: + """Validate that exactly 5 Q&A pairs are provided.""" + if len(v) != 5: + raise ValueError("Must provide exactly 5 question-answer pairs") + return v + + +class MCQQuestion(BaseModel): + """Multiple choice question with one correct and three incorrect answers.""" + question: str = Field(description="The question") + correct_answer: str = Field(description="The correct answer") + incorrect_answers: List[str] = 
Field(description="List of 3 incorrect answers") + + @field_validator('incorrect_answers') + def check_incorrect_count(cls, v: List[str]) -> List[str]: + """Validate that exactly 3 incorrect answers are provided.""" + if len(v) != 3: + raise ValueError("Must provide exactly 3 incorrect answers") + return v + + +class MCQSet(BaseModel): + """Set of multiple choice questions.""" + questions: List[MCQQuestion] = Field(description="List of MCQ questions") + + @field_validator('questions') + def check_questions_count(cls, v: List[MCQQuestion]) -> List[MCQQuestion]: + """Validate that exactly 3 questions are provided.""" + if len(v) != 3: + raise ValueError("Must provide exactly 3 questions") + return v + + +@pytest.mark.integration +class TestOpenRouterProvider: + """Test suite for OpenRouter provider with various input types and configurations.""" + + def test_basic_text_response(self): + """Test the OpenRouter provider with text response.""" + provider = OpenRouterProvider() + response = provider.generate(prompt="What is the capital of France? Answer in one word.") + assert "Paris" in response + + def test_structured_output(self): + """Test the OpenRouter provider with structured output.""" + provider = OpenRouterProvider() + prompt = """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital.""" + + response = provider.generate( + prompt=prompt, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + def test_with_messages(self): + """Test OpenRouter provider with messages input instead of prompt.""" + provider = OpenRouterProvider() + messages = [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": "What is the capital of France? Answer in one word."} + ] + + response = provider.generate(messages=messages) + assert "Paris" in response + + def test_messages_with_structured_output(self): + """Test OpenRouter provider with messages input and structured output.""" + provider = OpenRouterProvider() + messages = [ + {"role": "system", "content": "You are a helpful assistant that provides brief, accurate answers."}, + {"role": "user", "content": """What is the capital of France? + Provide a short answer and a brief explanation of why Paris is the capital."""} + ] + + response = provider.generate( + messages=messages, + response_format=SimpleResponse + ) + + assert isinstance(response, SimpleResponse) + assert "Paris" in response.answer + assert len(response.reasoning) > 10 + + def test_with_all_parameters(self): + """Test OpenRouter provider with all optional parameters specified.""" + provider = OpenRouterProvider( + model_id="meta-llama/llama-3.3-70b-instruct", + max_completion_tokens=300, + top_p=0.85, + ) + + response = provider.generate(prompt="What is the capital of France? Answer in one word.") + + assert "Paris" in response + + def test_structured_landmark_info(self): + """Test OpenRouter with a structured landmark info response.""" + provider = OpenRouterProvider(temperature=0.6, max_completion_tokens=2000) + + prompt = """ + Extract structured landmark details about the Great Wall of China from the passage below. + + Passage: + "The Great Wall of China stands across northern China, originally begun in 220 BCE to guard imperial borders. 
+        Spanning roughly 13,171 miles, it threads over mountains and deserts, symbolising centuries of engineering prowess and cultural unity.
+        Construction and major reinforcement during the Ming dynasty in the 14th century gave the wall its iconic form, using stone and brick to fortify older earthen ramparts.
+        Key attributes include: overall length of about 13,171 miles (importance 0.9), primary materials of stone and brick with tamped earth cores (importance 0.7), and critical Ming dynasty stewardship that restored and expanded the fortifications (importance 0.8).
+        Today's visitors typically rate the experience around 4.6 out of 5, citing sweeping views and the wall's historical resonance."
+        """
+
+        response = provider.generate(prompt=prompt, response_format=LandmarkInfo)
+
+        # Verify the structure was correctly generated and parsed
+        assert isinstance(response, LandmarkInfo)
+        assert "Great Wall" in response.name
+        assert "China" in response.location
+        assert len(response.description) > 20
+        assert response.year_built is not None
+        assert len(response.attributes) >= 3
+
+        # Verify nested objects
+        for attr in response.attributes:
+            assert 0 <= attr.importance <= 1
+            assert len(attr.name) > 0
+            assert len(attr.value) > 0
+
+        # Verify rating field
+        assert 0 <= response.visitor_rating <= 5
+
+
+@pytest.mark.integration
+class TestOpenRouterGLM46:
+    """Test suite for z-ai/glm-4.6 model via OpenRouter."""
+
+    def test_persona_content_generation(self):
+        """Test generating tweets and bio for a persona using GLM-4.6."""
+        provider = OpenRouterProvider(
+            model_id="z-ai/glm-4.6",
+            temperature=0.5,
+            max_completion_tokens=2000
+        )
+
+        prompt = """
+        Generate social media content for the following persona:
+
+        Persona: A passionate environmental scientist who loves hiking and photography,
+        advocates for climate action, and enjoys sharing nature facts with humor.
+
+        Create exactly 5 tweets and 1 bio for this persona.
+        """
+
+        response = provider.generate(prompt=prompt, response_format=PersonaContent)
+
+        assert isinstance(response, PersonaContent)
+        assert len(response.tweets) == 5
+        assert all(len(tweet) > 0 for tweet in response.tweets)
+        assert len(response.bio) > 20
+
+    def test_qa_generation(self):
+        """Test generating Q&A pairs on machine learning using GLM-4.6."""
+        provider = OpenRouterProvider(
+            model_id="z-ai/glm-4.6",
+            temperature=0.5,
+            max_completion_tokens=1500
+        )
+
+        prompt = """
+        Generate exactly 5 questions and their correct answers about machine learning topics.
+
+        Topics to cover: supervised learning, neural networks, overfitting, gradient descent, and cross-validation.
+
+        Each question should be clear and the answer should be concise but complete.
+        """
+
+        response = provider.generate(prompt=prompt, response_format=QASet)
+
+        assert isinstance(response, QASet)
+        assert len(response.questions) == 5
+        for qa in response.questions:
+            assert len(qa.question) > 10
+            assert len(qa.answer) > 10
+
+    def test_mcq_generation(self):
+        """Test generating multiple choice questions using GLM-4.6."""
+        provider = OpenRouterProvider(
+            model_id="z-ai/glm-4.6",
+            temperature=0.5,
+            max_completion_tokens=1500
+        )
+
+        prompt = """
+        Generate exactly 3 multiple choice questions about machine learning.
+
+        For each question, provide:
+        - The question itself
+        - One correct answer
+        - Three plausible but incorrect answers
+
+        Topics: neural networks, decision trees, and ensemble methods.
+ """ + + response = provider.generate(prompt=prompt, response_format=MCQSet) + + assert isinstance(response, MCQSet) + assert len(response.questions) == 3 + for mcq in response.questions: + assert len(mcq.question) > 10 + assert len(mcq.correct_answer) > 0 + assert len(mcq.incorrect_answers) == 3 + assert all(len(ans) > 0 for ans in mcq.incorrect_answers) + + +@pytest.mark.integration +class TestOpenRouterQwen3: + """Test suite for qwen/qwen3-next-80b-a3b-instruct model via OpenRouter.""" + + def test_persona_content_generation(self): + """Test generating tweets and bio for a persona using Qwen3.""" + provider = OpenRouterProvider( + model_id="qwen/qwen3-next-80b-a3b-instruct", + temperature=0.5, + max_completion_tokens=2000 + ) + + prompt = """ + Generate social media content for the following persona: + + Persona: A tech entrepreneur who is passionate about AI ethics, loves reading sci-fi novels, + practices meditation, and frequently shares insights about startup culture. + + Create exactly 5 tweets and 1 bio for this persona. + """ + + response = provider.generate(prompt=prompt, response_format=PersonaContent) + + assert isinstance(response, PersonaContent) + assert len(response.tweets) == 5 + assert all(len(tweet) > 0 for tweet in response.tweets) + assert len(response.bio) > 20 + + def test_qa_generation(self): + """Test generating Q&A pairs on machine learning using Qwen3.""" + provider = OpenRouterProvider( + model_id="qwen/qwen3-next-80b-a3b-instruct", + temperature=0.5, + max_completion_tokens=1500 + ) + + prompt = """ + Generate exactly 5 questions and their correct answers about machine learning topics. + + Topics to cover: reinforcement learning, convolutional neural networks, regularization, + backpropagation, and feature engineering. + + Each question should be clear and the answer should be concise but complete. + """ + + response = provider.generate(prompt=prompt, response_format=QASet) + + assert isinstance(response, QASet) + assert len(response.questions) == 5 + for qa in response.questions: + assert len(qa.question) > 10 + assert len(qa.answer) > 10 + + def test_mcq_generation(self): + """Test generating multiple choice questions using Qwen3.""" + provider = OpenRouterProvider( + model_id="qwen/qwen3-next-80b-a3b-instruct", + temperature=0.5, + max_completion_tokens=1500 + ) + + prompt = """ + Generate exactly 3 multiple choice questions about machine learning. + + For each question, provide: + - The question itself + - One correct answer + - Three plausible but incorrect answers + + Topics: recurrent neural networks, k-means clustering, and support vector machines. 
+ """ + + response = provider.generate(prompt=prompt, response_format=MCQSet) + + assert isinstance(response, MCQSet) + assert len(response.questions) == 3 + for mcq in response.questions: + assert len(mcq.question) > 10 + assert len(mcq.correct_answer) > 0 + assert len(mcq.incorrect_answers) == 3 + assert all(len(ans) > 0 for ans in mcq.incorrect_answers) + + +@pytest.mark.integration +class TestOpenRouterLlama33: + """Test suite for meta-llama/llama-3.3-70b-instruct model via OpenRouter.""" + + def test_persona_content_generation(self): + """Test generating tweets and bio for a persona using Llama 3.3.""" + provider = OpenRouterProvider( + model_id="meta-llama/llama-3.3-70b-instruct", + temperature=0.7, + max_completion_tokens=1000 + ) + + prompt = """ + Generate social media content for the following persona: + + Persona: A professional chef who specializes in fusion cuisine, loves traveling to discover + new ingredients, teaches cooking classes, and shares culinary tips with enthusiasm. + + Create exactly 5 tweets and 1 bio for this persona. + """ + + response = provider.generate(prompt=prompt, response_format=PersonaContent) + + assert isinstance(response, PersonaContent) + assert len(response.tweets) == 5 + assert all(len(tweet) > 0 for tweet in response.tweets) + assert len(response.bio) > 20 + + def test_qa_generation(self): + """Test generating Q&A pairs on machine learning using Llama 3.3.""" + provider = OpenRouterProvider( + model_id="meta-llama/llama-3.3-70b-instruct", + temperature=0.5, + max_completion_tokens=1500 + ) + + prompt = """ + Generate exactly 5 questions and their correct answers about machine learning topics. + + Topics to cover: transfer learning, attention mechanisms, batch normalization, + dropout, and hyperparameter tuning. + + Each question should be clear and the answer should be concise but complete. + """ + + response = provider.generate(prompt=prompt, response_format=QASet) + + assert isinstance(response, QASet) + assert len(response.questions) == 5 + for qa in response.questions: + assert len(qa.question) > 10 + assert len(qa.answer) > 10 + + def test_mcq_generation(self): + """Test generating multiple choice questions using Llama 3.3.""" + provider = OpenRouterProvider( + model_id="meta-llama/llama-3.3-70b-instruct", + temperature=0.5, + max_completion_tokens=1500 + ) + + prompt = """ + Generate exactly 3 multiple choice questions about machine learning. + + For each question, provide: + - The question itself + - One correct answer + - Three plausible but incorrect answers + + Topics: transformers, random forests, and principal component analysis. 
+ """ + + response = provider.generate(prompt=prompt, response_format=MCQSet) + + assert isinstance(response, MCQSet) + assert len(response.questions) == 3 + for mcq in response.questions: + assert len(mcq.question) > 10 + assert len(mcq.correct_answer) > 0 + assert len(mcq.incorrect_answers) == 3 + assert all(len(ans) > 0 for ans in mcq.incorrect_answers) + + +@pytest.mark.integration +class TestOpenRouterGemini25Flash: + """Test suite for google/gemini-2.5-flash model via OpenRouter.""" + + def test_persona_content_generation(self): + """Test generating tweets and bio for a persona using Gemini 2.5 Flash.""" + provider = OpenRouterProvider( + model_id="google/gemini-2.5-flash", + temperature=0.7, + max_completion_tokens=1000 + ) + + prompt = """ + Generate social media content for the following persona: + + Persona: A data scientist who is passionate about open source, enjoys playing chess, + contributes to educational content, and advocates for diversity in tech. + + Create exactly 5 tweets and 1 bio for this persona. + """ + + response = provider.generate(prompt=prompt, response_format=PersonaContent) + + assert isinstance(response, PersonaContent) + assert len(response.tweets) == 5 + assert all(len(tweet) > 0 for tweet in response.tweets) + assert len(response.bio) > 20 + + def test_qa_generation(self): + """Test generating Q&A pairs on machine learning using Gemini 2.5 Flash.""" + provider = OpenRouterProvider( + model_id="google/gemini-2.5-flash", + temperature=0.5, + max_completion_tokens=1500 + ) + + prompt = """ + Generate exactly 5 questions and their correct answers about machine learning topics. + + Topics to cover: generative adversarial networks, autoencoders, dimensionality reduction, + bias-variance tradeoff, and model evaluation metrics. + + Each question should be clear and the answer should be concise but complete. + """ + + response = provider.generate(prompt=prompt, response_format=QASet) + + assert isinstance(response, QASet) + assert len(response.questions) == 5 + for qa in response.questions: + assert len(qa.question) > 10 + assert len(qa.answer) > 10 + + def test_mcq_generation(self): + """Test generating multiple choice questions using Gemini 2.5 Flash.""" + provider = OpenRouterProvider( + model_id="google/gemini-2.5-flash", + temperature=0.5, + max_completion_tokens=1500 + ) + + prompt = """ + Generate exactly 3 multiple choice questions about machine learning. + + For each question, provide: + - The question itself + - One correct answer + - Three plausible but incorrect answers + + Topics: LSTM networks, gradient boosting, and model interpretability. + """ + + response = provider.generate(prompt=prompt, response_format=MCQSet) + + assert isinstance(response, MCQSet) + assert len(response.questions) == 3 + for mcq in response.questions: + assert len(mcq.question) > 10 + assert len(mcq.correct_answer) > 0 + assert len(mcq.incorrect_answers) == 3 + assert all(len(ans) > 0 for ans in mcq.incorrect_answers) +