Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pip install datafast

### 1. Environment Setup

Make sure you have created a `secrets.env` file with your API keys.
Make sure you have created a `.env` file with your API keys.
HF token is needed if you want to push the dataset to your HF hub.
Other keys depends on which LLM providers you use.
```
Expand All @@ -64,7 +64,7 @@ from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider
from dotenv import load_dotenv

# Load environment variables
load_dotenv("secrets.env") # <--- your API keys
load_dotenv() # <--- your API keys
```

### 3. Configure Dataset
Expand Down
3 changes: 3 additions & 0 deletions datafast/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Datafast - A Python package for synthetic text dataset generation"""

import importlib.metadata
from datafast.logger_config import configure_logger

try:
__version__ = importlib.metadata.version("datafast")
Expand All @@ -11,3 +12,5 @@
def get_version():
"""Return the current version of the datafast package."""
return __version__

__all__ = ["configure_logger", "get_version"]
282 changes: 252 additions & 30 deletions datafast/datasets.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion datafast/examples/classification_trail_conditions_example.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from datafast.datasets import ClassificationDataset
from datafast.schema.config import ClassificationDatasetConfig, PromptExpansionConfig
from datafast.llms import OpenAIProvider, AnthropicProvider
from datafast.logger_config import configure_logger
from dotenv import load_dotenv

# Load API keys
load_dotenv("secrets.env")
load_dotenv()

# Configure logger
configure_logger()

# Configure dataset
config = ClassificationDatasetConfig(
Expand Down
7 changes: 4 additions & 3 deletions datafast/examples/generic_pipeline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from datafast.schema.config import GenericPipelineDatasetConfig
from datafast.datasets import GenericPipelineDataset
from datafast.llms import OpenAIProvider, GeminiProvider, OllamaProvider
from datafast.logger_config import configure_logger
from dotenv import load_dotenv


PROMPT_TEMPLATE = """I will give you a persona.
Expand Down Expand Up @@ -80,7 +82,6 @@ def main():


if __name__ == "__main__":
from dotenv import load_dotenv

load_dotenv("secrets.env")
load_dotenv()
configure_logger()
main()
4 changes: 4 additions & 0 deletions datafast/examples/generic_pipeline_response_format_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from datafast.schema.config import GenericPipelineDatasetConfig
from datafast.utils import create_response_model
from datafast.logger_config import configure_logger

# Configure logger
configure_logger()

# Test with multiple columns and num_samples_per_prompt = 3
config = GenericPipelineDatasetConfig(
Expand Down
4 changes: 4 additions & 0 deletions datafast/examples/generic_pipeline_row_model_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from datafast.schema.config import GenericPipelineDatasetConfig
from datafast.utils import create_generic_pipeline_row_model
from datafast.logger_config import configure_logger

# Configure logger
configure_logger()

# Test with multiple input, forward, and output columns
config = GenericPipelineDatasetConfig(
Expand Down
10 changes: 7 additions & 3 deletions datafast/examples/inspect_dataset_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@
python -m datafast.examples.inspect_dataset_example

Requires:
- OpenAI API key in secrets.env or environment
- OpenAI API key in .env or environment
- gradio package (pip install gradio)
"""
from datafast.datasets import ClassificationDataset
from datafast.schema.config import ClassificationDatasetConfig, PromptExpansionConfig
from datafast.logger_config import configure_logger
from dotenv import load_dotenv

# Load API keys from environment or secrets.env
load_dotenv("secrets.env")
# Load API keys from environment or .env
load_dotenv()

# Configure logger
configure_logger()

# Configure the dataset generation
config = ClassificationDatasetConfig(
Expand Down
9 changes: 6 additions & 3 deletions datafast/examples/keywords_extraction_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
from datafast.schema.config import GenericPipelineDatasetConfig
from datafast.datasets import GenericPipelineDataset
from datafast.llms import OpenAIProvider, GeminiProvider, OllamaProvider, OpenRouterProvider
from datafast.logger_config import configure_logger
from dotenv import load_dotenv

# Load environment variables and configure logger
load_dotenv()
configure_logger()

PROMPT_TEMPLATE = """I will give you a tweet.
Generate a comma separated list of 3 keywords for the tweet. Avoid multi-word keywords.
Expand Down Expand Up @@ -63,7 +69,4 @@ def main():


if __name__ == "__main__":
from dotenv import load_dotenv

load_dotenv("secrets.env")
main()
6 changes: 5 additions & 1 deletion datafast/examples/mcq_contextual_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@
import pandas as pd
from dotenv import load_dotenv

load_dotenv("secrets.env")
load_dotenv()

from datafast.schema.config import MCQDatasetConfig
from datafast.datasets import MCQDataset
from datafast.llms import OpenAIProvider
from datafast.logger_config import configure_logger

# Configure logger
configure_logger()

def main():
# 1. Create a temporary filtered version of the dataset
Expand Down
20 changes: 11 additions & 9 deletions datafast/examples/mcq_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
import os
from datafast.schema.config import MCQDatasetConfig
from datafast.datasets import MCQDataset
from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider
from datafast.llms import OpenAIProvider, OpenRouterProvider
from datafast.logger_config import configure_logger
from dotenv import load_dotenv

# Load environment variables and configure logger
load_dotenv()
configure_logger()


def main():
Expand All @@ -16,8 +22,8 @@ def main():
# local_file_path="datafast/examples/data/mcq/sample.txt",
#local_file_path="datafast/examples/data/mcq/sample.jsonl",
text_column="text", # Column containing the text to generate questions from
sample_count=2, # Process only 3 samples for testing
num_samples_per_prompt=2,# Generate 2 questions per document
sample_count=20, # Process only 3 samples for testing
num_samples_per_prompt=3,# Generate 2 questions per document
min_document_length=100, # Skip documents shorter than 100 chars
max_document_length=20000,# Skip documents longer than 20000 chars
output_file="mcq_test_dataset.jsonl",
Expand All @@ -26,13 +32,12 @@ def main():
# 2. Initialize LLM providers
providers = [
OpenAIProvider(model_id="gpt-5-mini-2025-08-07"),
# AnthropicProvider(model_id="claude-haiku-4-5-20251001"),
# GeminiProvider(model_id="gemini-2.0-flash"),
OpenRouterProvider(model_id="qwen/qwen3-next-80b-a3b-instruct"),
]

# 3. Generate the dataset
dataset = MCQDataset(config)
num_expected_rows = dataset.get_num_expected_rows(providers, source_data_num_rows=2)
num_expected_rows = dataset.get_num_expected_rows(providers, source_data_num_rows=20)
print(f"\nExpected number of rows: {num_expected_rows}")
dataset.generate(providers)

Expand All @@ -53,7 +58,4 @@ def main():


if __name__ == "__main__":
from dotenv import load_dotenv

load_dotenv()
main()
2 changes: 1 addition & 1 deletion datafast/examples/preference_dataset_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,5 @@ def main():

if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv("secrets.env")
load_dotenv()
main()
5 changes: 4 additions & 1 deletion datafast/examples/quickstart_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from datafast.datasets import ClassificationDataset
from datafast.schema.config import ClassificationDatasetConfig, PromptExpansionConfig
from dotenv import load_dotenv
load_dotenv("secrets.env")
from datafast.logger_config import configure_logger
load_dotenv()

configure_logger()

config = ClassificationDatasetConfig(
classes=[
Expand Down
21 changes: 12 additions & 9 deletions datafast/examples/raw_text_space_engineering_example.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
from datafast.datasets import RawDataset
from datafast.schema.config import RawDatasetConfig, PromptExpansionConfig
from datafast.llms import OpenAIProvider, AnthropicProvider
from datafast.logger_config import configure_logger

configure_logger()


def main():
# 1. Configure the dataset generation
config = RawDatasetConfig(
document_types=[
"space engineering textbook",
"spacecraft design justification document",
# "space engineering textbook",
# "spacecraft design justification document",
"personal blog of a space engineer"
],
topics=[
"Microgravity",
"Vacuum",
"Heavy Ions",
"Thermal Extremes",
"Atomic Oxygen",
"Debris Impact",
# "Microgravity",
# "Vacuum",
# "Heavy Ions",
# "Thermal Extremes",
# "Atomic Oxygen",
# "Debris Impact",
"Electrostatic Charging",
"Propellant Boil-off",
# ... You can pour hundreds of topics here. 8 is enough for this example
Expand Down Expand Up @@ -67,5 +70,5 @@ def main():
if __name__ == "__main__":
from dotenv import load_dotenv

load_dotenv("secrets.env")
load_dotenv()
main()
2 changes: 1 addition & 1 deletion datafast/examples/ultrachat_materials_science.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ def main():

if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv("secrets.env")
load_dotenv()
main()
53 changes: 49 additions & 4 deletions datafast/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import time
import traceback
import warnings
from loguru import logger

# Pydantic
from pydantic import BaseModel
Expand Down Expand Up @@ -65,6 +66,9 @@ def __init__(

# Configure environment with API key if needed
self._configure_env()

# Log successful initialization
logger.info(f"Initialized {self.provider_name} | Model: {self.model_id}")

@property
@abstractmethod
Expand All @@ -82,6 +86,9 @@ def _get_api_key(self) -> str:
"""Get API key from environment variables."""
api_key = os.getenv(self.env_key_name)
if not api_key:
logger.error(
f"Missing API key | Set {self.env_key_name} environment variable"
)
raise ValueError(
f"{self.env_key_name} environment variable not set. "
f"Please set it or provide an API key when initializing the provider."
Expand Down Expand Up @@ -112,7 +119,7 @@ def _respect_rate_limit(self) -> None:
# Add a 1s margin to avoid accidental rate limit exceedance
sleep_time = 61 - (current - earliest)
if sleep_time > 0:
print("Waiting for rate limit...")
logger.warning(f"Rate limit reached | Waiting {sleep_time:.1f}s")
time.sleep(sleep_time)

@staticmethod
Expand Down Expand Up @@ -264,8 +271,23 @@ def generate(
if response_format is not None:
# Strip code fences before validation
content = self._strip_code_fences(content)
results.append(
response_format.model_validate_json(content))
try:
results.append(
response_format.model_validate_json(content))
except Exception as validation_error:
# Show the content that failed to parse for debugging
content_preview = content[:200] + "..." if len(content) > 200 else content
logger.warning(
f"JSON parsing failed, skipping response | "
f"Model: {self.model_id} | "
f"Format: {response_format.__name__} | "
f"Content preview: {content_preview}"
)
raise ValueError(
f"Failed to parse JSON response into {response_format.__name__}.\n"
f"Validation error: {validation_error}\n"
f"Content received (first 200 chars):\n{content_preview}"
) from validation_error
else:
# Strip leading/trailing whitespace for text responses
results.append(content.strip() if content else content)
Expand All @@ -277,6 +299,10 @@ def generate(

except Exception as e:
error_trace = traceback.format_exc()
logger.error(
f"Generation failed | Provider: {self.provider_name} | "
f"Model: {self.model_id} | Error: {str(e)}"
)
raise RuntimeError(
f"Error generating batch response with {self.provider_name}:\n{error_trace}"
)
Expand Down Expand Up @@ -471,7 +497,22 @@ def generate(
if response_format is not None:
# Strip code fences before validation
content = self._strip_code_fences(content)
results.append(response_format.model_validate_json(content))
try:
results.append(response_format.model_validate_json(content))
except Exception as validation_error:
# Show the content that failed to parse for debugging
content_preview = content[:200] + "..." if len(content) > 200 else content
logger.warning(
f"JSON parsing failed, skipping response | "
f"Model: {self.model_id} | "
f"Format: {response_format.__name__} | "
f"Content preview: {content_preview}"
)
raise ValueError(
f"Failed to parse JSON response into {response_format.__name__}.\n"
f"Validation error: {validation_error}\n"
f"Content received (first 200 chars):\n{content_preview}"
) from validation_error
else:
# Strip leading/trailing whitespace for text responses
results.append(content.strip() if content else content)
Expand All @@ -483,6 +524,10 @@ def generate(

except Exception as e:
error_trace = traceback.format_exc()
logger.error(
f"Generation failed | Provider: {self.provider_name} | "
f"Model: {self.model_id} | Error: {str(e)}"
)
raise RuntimeError(
f"Error generating response with {self.provider_name}:\n{error_trace}"
)
Expand Down
Loading
Loading