10 changes: 5 additions & 5 deletions datafast/examples/mcq_example.py
@@ -11,12 +11,12 @@
def main():
# 1. Define the configuration
config = MCQDatasetConfig(
# hf_dataset_name="patrickfleith/space_engineering_environment_effects_texts",
hf_dataset_name="patrickfleith/space_engineering_environment_effects_texts",
# local_file_path="datafast/examples/data/mcq/sample.csv",
# local_file_path="datafast/examples/data/mcq/sample.txt",
local_file_path="datafast/examples/data/mcq/sample.jsonl",
# local_file_path="datafast/examples/data/mcq/sample.jsonl",
text_column="text", # Column containing the text to generate questions from
sample_count=3, # Process only 3 samples for testing
sample_count=2, # Process only 2 samples for testing
num_samples_per_prompt=2, # Generate 2 questions per document
min_document_length=100, # Skip documents shorter than 100 chars
max_document_length=20000,# Skip documents longer than 20000 chars
@@ -32,7 +32,7 @@ def main():

# 3. Generate the dataset
dataset = MCQDataset(config)
num_expected_rows = dataset.get_num_expected_rows(providers, source_data_num_rows=3)
num_expected_rows = dataset.get_num_expected_rows(providers, source_data_num_rows=2)
print(f"\nExpected number of rows: {num_expected_rows}")
dataset.generate(providers)

@@ -55,5 +55,5 @@ def main():
if __name__ == "__main__":
from dotenv import load_dotenv

load_dotenv("secrets.env")
load_dotenv()
main()
191 changes: 180 additions & 11 deletions datafast/llms.py
@@ -9,6 +9,7 @@
import os
import time
import traceback
import warnings

# Pydantic
from pydantic import BaseModel
@@ -249,9 +250,11 @@ def generate(
response: list[ModelResponse] = litellm.batch_completion(
**completion_params)

# Record timestamp for rate limiting
# Record timestamp for rate limiting (one timestamp per batch item)
if self.rpm_limit is not None:
self._request_timestamps.append(time.monotonic())
current_time = time.monotonic()
for _ in range(len(batch_to_send)):
self._request_timestamps.append(current_time)

# Extract content from each response
results = []
@@ -280,7 +283,15 @@


class OpenAIProvider(LLMProvider):
"""OpenAI provider using litellm."""
"""OpenAI provider using litellm.responses endpoint.

Note: This provider uses the new responses endpoint which has different
parameter support compared to the standard completion endpoint:
- temperature, top_p, and frequency_penalty are not supported
- Uses text_format instead of response_format
- Supports reasoning parameter for controlling reasoning effort
- Does not support batch operations (will process sequentially with warning)
"""

@property
def provider_name(self) -> str:
@@ -294,29 +305,187 @@ def __init__(
self,
model_id: str = "gpt-5-mini-2025-08-07",
api_key: str | None = None,
temperature: float | None = None,
max_completion_tokens: int | None = None,
reasoning_effort: str = "low",
temperature: float | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
):
"""Initialize the OpenAI provider.

Args:
model_id: The model ID (defaults to gpt-5-mini-2025-08-07)
model_id: The model ID (defaults to gpt-5-mini)
api_key: API key (if None, will get from environment)
temperature: The sampling temperature to be used, between 0 and 2. Higher values like 0.8 produce more random outputs, while lower values like 0.2 make outputs more focused and deterministic
max_completion_tokens: An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
top_p: Nucleus sampling parameter (0.0 to 1.0)
frequency_penalty: Penalty for token frequency (-2.0 to 2.0)
reasoning_effort: Reasoning effort level - "low", "medium", or "high" (defaults to "low")
temperature: DEPRECATED - Not supported by responses endpoint
top_p: DEPRECATED - Not supported by responses endpoint
frequency_penalty: DEPRECATED - Not supported by responses endpoint
"""
# Warn about deprecated parameters
if temperature is not None:
warnings.warn(
"temperature parameter is not supported by OpenAI responses endpoint and will be ignored",
UserWarning,
stacklevel=2
)
if top_p is not None:
warnings.warn(
"top_p parameter is not supported by OpenAI responses endpoint and will be ignored",
UserWarning,
stacklevel=2
)
if frequency_penalty is not None:
warnings.warn(
"frequency_penalty parameter is not supported by OpenAI responses endpoint and will be ignored",
UserWarning,
stacklevel=2
)

# Store reasoning effort
self.reasoning_effort = reasoning_effort

# Call parent init with None for unsupported params
super().__init__(
model_id=model_id,
api_key=api_key,
temperature=temperature,
temperature=None,
max_completion_tokens=max_completion_tokens,
top_p=top_p,
frequency_penalty=frequency_penalty,
top_p=None,
frequency_penalty=None,
)

def generate(
self,
prompt: str | list[str] | None = None,
messages: list[Messages] | Messages | None = None,
response_format: Type[T] | None = None,
) -> str | list[str] | T | list[T]:
"""
Generate responses from the LLM using the responses endpoint.

Note: Batch operations are processed sequentially as the responses endpoint
does not support native batching.

Args:
prompt: Single text prompt (str) or list of text prompts for batch processing
messages: Single message list or list of message lists for batch processing
response_format: Optional Pydantic model class for structured output

Returns:
Single string/model or list of strings/models depending on input type.

Raises:
ValueError: If neither prompt nor messages is provided, or if both are provided.
RuntimeError: If there's an error during generation.
"""
# Validate inputs
if prompt is None and messages is None:
raise ValueError("Either prompts or messages must be provided")
if prompt is not None and messages is not None:
raise ValueError("Provide either prompts or messages, not both")

# Determine if this is a single input or batch input
single_input = False
batch_prompts = None
batch_messages = None

if prompt is not None:
if isinstance(prompt, str):
# Single prompt - convert to batch
batch_prompts = [prompt]
single_input = True
elif isinstance(prompt, list):
# Already a list of prompts
batch_prompts = prompt
single_input = False
else:
raise ValueError("prompt must be a string or list of strings")

if messages is not None:
if isinstance(messages, list) and len(messages) > 0:
# Check if it's a single message list or batch
if isinstance(messages[0], dict):
# Single message list - convert to batch
batch_messages = [messages]
single_input = True
elif isinstance(messages[0], list):
# Already a batch of message lists
batch_messages = messages
single_input = False
else:
raise ValueError("Invalid messages format")
else:
raise ValueError("messages cannot be empty")

try:
# Convert batch prompts to messages if needed
batch_to_send = []
if batch_prompts is not None:
for one_prompt in batch_prompts:
batch_to_send.append([{"role": "user", "content": one_prompt}])
else:
batch_to_send = batch_messages

# Warn if batch processing is being used
if len(batch_to_send) > 1:
warnings.warn(
f"OpenAI responses endpoint does not support batch operations. "
f"Processing {len(batch_to_send)} requests sequentially.",
UserWarning,
stacklevel=2
)

# Process each request sequentially
results = []
for message_list in batch_to_send:
# Enforce rate limit per request
self._respect_rate_limit()

# Prepare completion parameters
completion_params = {
"model": self._get_model_string(),
"input": message_list,
"reasoning": {"effort": self.reasoning_effort},
}

# Add max_output_tokens if specified
if self.max_completion_tokens is not None:
completion_params["max_output_tokens"] = self.max_completion_tokens

# Add text_format if response_format is provided
if response_format is not None:
completion_params["text_format"] = response_format

# Call LiteLLM responses endpoint
response = litellm.responses(**completion_params)

# Record timestamp for rate limiting
if self.rpm_limit is not None:
self._request_timestamps.append(time.monotonic())

# Extract content from response
# Response structure: response.output[1].content[0].text
content = response.output[1].content[0].text

if response_format is not None:
# Strip code fences before validation
content = self._strip_code_fences(content)
results.append(response_format.model_validate_json(content))
else:
# Strip leading/trailing whitespace for text responses
results.append(content.strip() if content else content)

# Return single result for backward compatibility
if single_input and len(results) == 1:
return results[0]
return results

except Exception as e:
error_trace = traceback.format_exc()
raise RuntimeError(
f"Error generating response with {self.provider_name}:\n{error_trace}"
)


class AnthropicProvider(LLMProvider):
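For orientation, a minimal usage sketch of the refactored `OpenAIProvider` is shown below. It is a sketch only: it assumes the `datafast.llms` import path, a hypothetical `QAItem` Pydantic model, and `OPENAI_API_KEY` set in the environment.

```python
from pydantic import BaseModel

from datafast.llms import OpenAIProvider  # assumed import path


class QAItem(BaseModel):
    """Hypothetical schema used only for this illustration."""
    question: str
    answer: str


# reasoning_effort replaces temperature/top_p/frequency_penalty for this provider
llm = OpenAIProvider(model_id="gpt-5-mini-2025-08-07", reasoning_effort="low")

# A single prompt returns a single parsed object; a list of prompts is
# processed sequentially (with a warning) because the responses endpoint
# has no native batching.
item = llm.generate(
    prompt="Write one question and answer about orbital debris mitigation.",
    response_format=QAItem,  # forwarded as text_format to litellm.responses
)
print(item.question)
print(item.answer)
```

Passing a list of prompts returns a list of parsed objects in the same order as the inputs.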
2 changes: 1 addition & 1 deletion docs/index.md
@@ -95,7 +95,7 @@ config = ClassificationDatasetConfig(
providers = [
OpenAIProvider(model_id="gpt-5-mini-2025-08-07"),
AnthropicProvider(model_id="claude-haiku-4-5-20251001"),
GeminiProvider(model_id="gemini-2.5-flash"),
GeminiProvider(model_id="gemini-2.0-flash"),
OpenRouterProvider(model_id="z-ai/glm-4.6")
]
```
36 changes: 31 additions & 5 deletions docs/llms.md
@@ -33,7 +33,7 @@ gemini_llm = GeminiProvider()
# Ollama (default: gemma3:4b)
ollama_llm = OllamaProvider()

# OpenRouter (default: openai/gpt-4.1-mini)
# OpenRouter (default: openai/gpt-5-mini)
openrouter_llm = OpenRouterProvider()
```

@@ -42,12 +42,38 @@ openrouter_llm = OpenRouterProvider()
```python
openai_llm = OpenAIProvider(
model_id="gpt-5-mini-2025-08-07", # Custom model
temperature=0.2, # Lower temperature for more deterministic outputs
max_completion_tokens=100, # Limit token generation
top_p=0.9, # Nucleus sampling parameter
frequency_penalty=0.1 # Penalty for frequent tokens
max_completion_tokens=1000, # Limit token generation (don't set this too low for reasoning models)
reasoning_effort="medium" # Reasoning effort: "low", "medium", or "high"
)
```

!!! warning "OpenAI Provider Changes"
`OpenAIProvider` now uses the `responses` endpoint. The following parameters are **deprecated** and will trigger warnings:
- `temperature`
- `top_p`
- `frequency_penalty`

Use `reasoning_effort` ("low", "medium", "high") instead to control generation behavior; the warning behavior is illustrated in the sketch below.
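A minimal sketch of that deprecation warning (assuming the `datafast.llms` import path and `OPENAI_API_KEY` set in the environment):

```python
import warnings

from datafast.llms import OpenAIProvider  # assumed import path

# Passing a deprecated parameter still constructs the provider, but the
# value is ignored and a UserWarning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    OpenAIProvider(model_id="gpt-5-mini-2025-08-07", temperature=0.2)

assert any(issubclass(w.category, UserWarning) for w in caught)
```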

```python
# Anthropic with custom parameters
anthropic_llm = AnthropicProvider(
model_id="claude-haiku-4-5-20251001",
temperature=0.7,
max_completion_tokens=1000
)
```

!!! warning "Anthropic Provider Limitations"
`AnthropicProvider` only supports the following parameters:
- `temperature` (0.0 to 1.0)
- `max_completion_tokens`

The following parameters are **not supported** by Anthropic Claude 4.5 models:
- `top_p`
- `frequency_penalty`

```python
# Ollama with custom API endpoint
ollama_llm = OllamaProvider(
model_id="llama3.2:latest",