61 changes: 55 additions & 6 deletions datafast/llms.py
@@ -16,7 +16,6 @@
 # LiteLLM
 import litellm
 from litellm.utils import ModelResponse
-from litellm import batch_completion

 # Internal imports
 from .llm_utils import get_messages
@@ -115,6 +114,37 @@ def _respect_rate_limit(self) -> None:
             print("Waiting for rate limit...")
             time.sleep(sleep_time)

+    @staticmethod
+    def _strip_code_fences(content: str) -> str:
+        """Strip markdown code fences from content if present.
+
+        Args:
+            content: The content string that may contain code fences
+
+        Returns:
+            Content with code fences removed
+        """
+        if not content:
+            return content
+
+        content = content.strip()
+
+        # Check for code fences with optional language identifier
+        if content.startswith('```'):
+            # Find the end of the first line (language identifier)
+            first_newline = content.find('\n')
+            if first_newline != -1:
+                content = content[first_newline + 1:]
+            else:
+                # No newline after opening fence, remove just the fence
+                content = content[3:]
+
+        # Remove closing fence
+        if content.endswith('```'):
+            content = content[:-3]
+
+        return content.strip()
+
     def generate(
         self,
         prompt: str | list[str] | None = None,
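The new helper is self-contained, so its behavior is easy to check in isolation. A minimal sketch (function body condensed from the diff above; the asserts are illustrative and not part of the PR):

def _strip_code_fences(content: str) -> str:
    # Strip markdown code fences, including an optional language identifier
    if not content:
        return content
    content = content.strip()
    if content.startswith('```'):
        # Drop the opening fence line ('```' or '```json', etc.)
        first_newline = content.find('\n')
        content = content[first_newline + 1:] if first_newline != -1 else content[3:]
    if content.endswith('```'):
        content = content[:-3]
    return content.strip()

assert _strip_code_fences('```json\n{"a": 1}\n```') == '{"a": 1}'
assert _strip_code_fences('```\n[1, 2]\n```') == '[1, 2]'
assert _strip_code_fences('{"a": 1}') == '{"a": 1}'  # unfenced content passes through unchanged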
@@ -176,13 +206,29 @@ def generate(
             raise ValueError("messages cannot be empty")

         try:
+            # Append JSON formatting instructions if response_format is provided
+            json_instructions = (
+                "\nReturn only valid JSON. To do so, don't include ```json ``` markdown "
+                "or code fences around the JSON. Use double quotes for all keys and values. "
+                "Escape internal quotes and newlines (use \\n). Do not include trailing commas."
+            )
+
             # Convert batch prompts to messages if needed
             batch_to_send = []
             if batch_prompts is not None:
                 for one_prompt in batch_prompts:
-                    batch_to_send.append(get_messages(one_prompt))
+                    # Append JSON instructions to prompt if response_format is provided
+                    modified_prompt = one_prompt + json_instructions if response_format is not None else one_prompt
+                    batch_to_send.append(get_messages(modified_prompt))
             else:
                 batch_to_send = batch_messages
+                # Append JSON instructions to the last user message if response_format is provided
+                if response_format is not None:
+                    for message_list in batch_to_send:
+                        for msg in reversed(message_list):
+                            if msg.get("role") == "user":
+                                msg["content"] += json_instructions
+                                break

             # Enforce rate limit per batch
             self._respect_rate_limit()
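A small sketch of the "append to the last user message" branch above, run against a hypothetical two-turn conversation (the message contents are invented for illustration, and json_instructions is abbreviated):

json_instructions = "\nReturn only valid JSON."

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "List three colors."},
]
# Walk backwards so only the most recent user turn is modified
for msg in reversed(messages):
    if msg.get("role") == "user":
        msg["content"] += json_instructions
        break

assert messages[1]["content"] == "List three colors.\nReturn only valid JSON."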
@@ -211,11 +257,15 @@ def generate(
             results = []
             for one_response in response:
                 content = one_response.choices[0].message.content
+
                 if response_format is not None:
+                    # Strip code fences before validation
+                    content = self._strip_code_fences(content)
                     results.append(
                         response_format.model_validate_json(content))
                 else:
-                    results.append(content)
+                    # Strip leading/trailing whitespace for text responses
+                    results.append(content.strip() if content else content)

             # Return single result for backward compatibility
             if single_input and len(results) == 1:
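The fence-stripping plus validation path can be reproduced end to end. A sketch assuming a hypothetical ColorList schema as the response_format; the fenced string mimics the kind of model output that previously broke model_validate_json:

from pydantic import BaseModel

class ColorList(BaseModel):
    colors: list[str]

raw = '```json\n{"colors": ["red", "green", "blue"]}\n```'

# Same stripping logic as _strip_code_fences, inlined for the sketch
clean = raw.strip()
if clean.startswith('```'):
    newline = clean.find('\n')
    clean = clean[newline + 1:] if newline != -1 else clean[3:]
if clean.endswith('```'):
    clean = clean[:-3]

parsed = ColorList.model_validate_json(clean.strip())
assert parsed.colors == ["red", "green", "blue"]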
@@ -286,8 +336,8 @@ def __init__(
         api_key: str | None = None,
         temperature: float | None = None,
         max_completion_tokens: int | None = None,
-        top_p: float | None = None,
-        # frequency_penalty: float | None = None, # Not supported by anthropic
+        # top_p: float | None = None, # Not properly supported by anthropic models 4.5
+        # frequency_penalty: float | None = None, # Not supported by anthropic models 4.5
     ):
         """Initialize the Anthropic provider.

@@ -303,7 +353,6 @@ def __init__(
             api_key=api_key,
             temperature=temperature,
             max_completion_tokens=max_completion_tokens,
-            top_p=top_p,
         )

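With top_p dropped from the signature, passing it now fails fast at construction time. A hypothetical call site (the class name AnthropicProvider and the key value are assumptions, not taken from the diff):

provider = AnthropicProvider(
    api_key="sk-ant-...",  # placeholder key
    temperature=0.7,
    max_completion_tokens=1024,
    # top_p=0.9,  # would now raise TypeError: unexpected keyword argument
)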
3 changes: 3 additions & 0 deletions pytest.ini
@@ -1,6 +1,7 @@
 [pytest]
 markers =
     integration: marks tests that require API connectivity (deselect with '-m "not integration"')
+    slow: marks tests that are slow to run

 # Other pytest configurations
 testpaths = tests
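With the marker registered, slow tests can be deselected the same way as integration tests, e.g. pytest -m "not slow" or pytest -m "not slow and not integration".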
@@ -16,6 +17,8 @@ log_cli_level = INFO
 filterwarnings =
     # Ignore Pydantic deprecation warnings
     ignore::DeprecationWarning:pydantic.*:
+    # Ignore Pydantic serializer warnings during tests
+    ignore::UserWarning:pydantic.main
     # Ignore LiteLLM deprecation warnings
     ignore::DeprecationWarning:litellm.*:
     # Ignore HTTPX deprecation warnings