205 changes: 152 additions & 53 deletions datafast/datasets.py
@@ -18,6 +18,7 @@
UltrachatDatasetConfig,
MCQDatasetConfig,
PreferenceDatasetConfig,
GenericPipelineDatasetConfig,
)
from datafast.schema.data_rows import (
ChatRow,
@@ -29,10 +30,14 @@
MCQSource,
PreferenceRow,
PreferenceSource,
GenericPipelineRow,
GenericPipelineSource,
)
from datafast.expanders import expand_prompts
import os
from datafast import utils
from loguru import logger


### Model for Raw Text Examples Generation

@@ -759,60 +764,14 @@ def generate(self, llms: list[LLMProvider]) -> "MCQDataset":
if not llms:
raise ValueError("At least one LLM provider must be supplied")

# Load the dataset from Hugging Face or local file
# Load the dataset using shared utility
try:
if self.config.hf_dataset_name:
# Load from Hugging Face
hf_dataset = load_dataset(self.config.hf_dataset_name)
# Most datasets have a 'train' split, but fallback to first available split
split_names = list(hf_dataset.keys())
if not split_names:
raise ValueError(f"No splits found in dataset {self.config.hf_dataset_name}")

main_split = "train" if "train" in split_names else split_names[0]
dataset = hf_dataset[main_split]

elif self.config.local_file_path:
# Load from local file based on extension
file_path = self.config.local_file_path
file_ext = file_path.lower().split('.')[-1]

if file_ext == 'csv':
# Load CSV file
import pandas as pd
df = pd.read_csv(file_path)
dataset = df.to_dict('records')

elif file_ext == 'txt':
# For TXT files, create a dataset with one record per line
# and use the text_column as the key
with open(file_path, 'r', encoding='utf-8') as f:
lines = [line.strip() for line in f if line.strip()]
dataset = [{self.config.text_column: line} for line in lines]

elif file_ext == 'parquet':
# Load Parquet file
import pandas as pd
df = pd.read_parquet(file_path)
dataset = df.to_dict('records')

elif file_ext in ['jsonl', 'json']:
# Load JSONL file (one JSON object per line)
import json
dataset = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
dataset.append(json.loads(line))

else:
raise ValueError(f"Unsupported file extension: {file_ext}. Supported extensions are: csv, txt, parquet, jsonl, json")
else:
raise ValueError("Either hf_dataset_name or local_file_path must be specified")

# Limit the number of samples if specified
if self.config.sample_count is not None:
dataset = dataset[:min(self.config.sample_count, len(dataset))]
dataset = utils.load_dataset_from_source(
hf_dataset_name=self.config.hf_dataset_name,
local_file_path=self.config.local_file_path,
sample_count=self.config.sample_count,
text_column=self.config.text_column
)

except Exception as e:
source = self.config.hf_dataset_name or self.config.local_file_path
@@ -1300,3 +1259,143 @@ def _get_default_judge_prompt(self) -> str:
from datafast.prompts import preference_prompts
return preference_prompts.JUDGE_PROMPT


class GenericPipelineDataset(DatasetBase):
def __init__(self, config: GenericPipelineDatasetConfig):
super().__init__(config)
self.config = config

def get_num_expected_rows(self, llms: list[LLMProvider]) -> int:
"""Calculate the expected number of rows that will be generated.

Args:
llms: List of LLM providers that will be used for generation.

Returns:
int: The expected number of rows that will be generated.
"""
if not llms:
raise ValueError("At least one LLM provider must be supplied")
return utils._get_generic_pipeline_num_expected_rows(self.config, llms)

def _load_source_dataset(self):
"""Load dataset from Hugging Face or local file using shared utility."""
return utils.load_dataset_from_source(
hf_dataset_name=self.config.hf_dataset_name,
local_file_path=self.config.local_file_path,
sample_count=self.config.sample_count
)

def generate(self, llms: list[LLMProvider]) -> "GenericPipelineDataset":
"""Generate data by processing source dataset through custom prompts.

Args:
llms: List of LLM providers to use for generation.

Returns:
Self for method chaining.
"""
if not llms:
raise ValueError("At least one LLM provider must be supplied")

# Load source dataset
source_dataset = self._load_source_dataset()
print(f"Loaded source dataset with {len(source_dataset)} rows")

# Apply sample limit if specified
if self.config.sample_count:
source_dataset = source_dataset[:min(self.config.sample_count, len(source_dataset))]
print(f"Limited to {len(source_dataset)} rows")

# Get languages from config
languages = self.config.languages or {"en": "English"}

# Process each row in the source dataset
for row_idx, source_row in enumerate(source_dataset):
# Apply skip function if provided
if self.config.skip_function and self.config.skip_function(source_row):
print(f"Skipping row {row_idx} due to skip_function")
continue

# Extract input data based on input_columns
input_data = {col: str(source_row.get(col, "")) for col in self.config.input_columns}

# Extract forward data if specified
forward_data = {}
if self.config.forward_columns:
forward_data = {col: str(source_row.get(col, "")) for col in self.config.forward_columns}

# Process for each language
for lang_code, language_name in languages.items():
# Process each prompt
for prompt_idx, prompt_template in enumerate(self.config.prompts):
# Format prompt with input data and required placeholders
formatted_prompt = prompt_template.format(
num_samples=self.config.num_samples_per_prompt,
language=language_name,
**input_data
)

# Expand prompts with configured variations
expansions = expand_prompts(
prompt_templates=[formatted_prompt],
**self.config.expansion.model_dump()
)

# Process each expanded prompt
for expanded_prompt, meta in expansions:
# Process with each LLM
for llm in llms:
try:
# Create dynamic response model based on output_columns configuration
response_model = utils.create_response_model(self.config)

# Create dynamic row model based on output_columns configuration
row_model = utils.create_generic_pipeline_row_model(self.config)

# Generate response using the LLM with proper response format
response = llm.generate(expanded_prompt, response_format=response_model)

# Create rows for each generated sample
new_rows = []
for entry in response.entries:
# Prepare row data with all columns as separate top-level fields
row_data = {
"model_id": llm.model_id,
"pipeline_source": GenericPipelineSource.SYNTHETIC,
"language": lang_code,
"metadata": {
"prompt_index": str(prompt_idx),
"source_row_index": str(row_idx),
}
}

# Add input data as individual top-level fields
for column, value in input_data.items():
row_data[column] = value

# Add forward data as individual top-level fields
for column, value in forward_data.items():
row_data[column] = value

# Add each output column as a separate field
if self.config.output_columns:
for column in self.config.output_columns:
row_data[column] = getattr(entry, column, "")
else:
row_data["generated_text"] = getattr(entry, "generated_text", "")

# Create the dynamic row
row = row_model(**row_data)
self.data_rows.append(row)
new_rows.append(row)

# Save this batch
self.to_jsonl(self.config.output_file, new_rows, append=True)
logger.success(f"Generated and saved {len(self.data_rows)} examples total")

except Exception as e:
logger.error(f"Error with llm provider {llm.provider_name} on row {row_idx}: {e}")
continue

return self
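
Note on the refactor above: the inline Hugging Face / csv / txt / parquet / jsonl loading branches removed from MCQDataset.generate are consolidated into utils.load_dataset_from_source, which GenericPipelineDataset._load_source_dataset also calls. A minimal sketch of what that helper presumably does, reconstructed from the removed branches; the actual signature and behaviour live in datafast/utils.py.

# Hypothetical sketch, not the actual implementation in datafast/utils.py.
# Assumes `from datasets import load_dataset` is available, as in the removed inline code.
def load_dataset_from_source(hf_dataset_name=None, local_file_path=None,
                             sample_count=None, text_column=None):
    if hf_dataset_name:
        hf_dataset = load_dataset(hf_dataset_name)
        if not hf_dataset:
            raise ValueError(f"No splits found in dataset {hf_dataset_name}")
        split = "train" if "train" in hf_dataset else list(hf_dataset.keys())[0]
        # Normalize to a list of row dicts so slicing and iteration behave uniformly
        dataset = list(hf_dataset[split])
    elif local_file_path:
        ext = local_file_path.lower().split('.')[-1]
        if ext in ('csv', 'parquet'):
            import pandas as pd
            reader = pd.read_csv if ext == 'csv' else pd.read_parquet
            dataset = reader(local_file_path).to_dict('records')
        elif ext == 'txt':
            # One record per non-empty line, keyed by text_column
            with open(local_file_path, 'r', encoding='utf-8') as f:
                dataset = [{text_column: line.strip()} for line in f if line.strip()]
        elif ext in ('jsonl', 'json'):
            import json
            with open(local_file_path, 'r', encoding='utf-8') as f:
                dataset = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError(f"Unsupported file extension: {ext}")
    else:
        raise ValueError("Either hf_dataset_name or local_file_path must be specified")
    if sample_count is not None:
        dataset = dataset[:min(sample_count, len(dataset))]
    return dataset

Normalising every source to a list of row dicts keeps the per-row iteration in generate() agnostic to whether the data came from the Hub or a local file.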
86 changes: 86 additions & 0 deletions datafast/examples/generic_pipeline_example.py
@@ -0,0 +1,86 @@
"""
Example script for generating a dataset using GenericPipelineDataset.
This example uses the patrickfleith/FinePersonas-v0.1-100k-space-filtered dataset to generate tweets and CVs for different personas.
"""

import os
from datafast.schema.config import GenericPipelineDatasetConfig
from datafast.datasets import GenericPipelineDataset
from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider, OllamaProvider


PROMPT_TEMPLATE = """I will give you a persona.
Generate {num_samples} texts in {language} with:
1. A tweet that this person might write (engaging, authentic to their character)
2. A short CV highlighting their background

Make sure the content reflects their personality and background authentically.
The CV should include a higher education degree (and the school/university it was obtained from), work experience (if any), relevant skills, and a hobby.\

Here is the persona:
{persona}

Your response should be formatted in valid JSON with {num_samples} entries and all required fields."""

def main():
# 1. Define the configuration
config = GenericPipelineDatasetConfig(
hf_dataset_name="patrickfleith/FinePersonas-v0.1-100k-space-filtered",
input_columns=["persona"], # Input data for generation
forward_columns=["summary_label"], # Data to forward through
output_columns=["tweet", "cv"], # Generated content columns
sample_count=5, # Process only 5 samples for testing
num_samples_per_prompt=2, # Generate 2 samples per persona per prompt
prompts=[PROMPT_TEMPLATE], # Use the prompt template
output_file="generic_pipeline_test_dataset.jsonl",
languages={"en": "English", "fr": "French"}
)

# 2. Initialize LLM providers
providers = [
OpenAIProvider(
model_id="gpt-5-mini-2025-08-07",
temperature=1
),
# AnthropicProvider(model_id="claude-3-5-haiku-latest"),
# GeminiProvider(model_id="gemini-2.5-flash-lite", rpm_limit=15),
# OllamaProvider(model_id="gemma3:4b"),
]

# 3. Generate the dataset
dataset = GenericPipelineDataset(config)
num_expected_rows = dataset.get_num_expected_rows(providers)
print(f"\nExpected number of rows: {num_expected_rows}")
dataset.generate(providers)

# 4. Print results summary
print(f"\nGenerated {len(dataset.data_rows)} examples")
print(f"Results saved to {config.output_file}")

# # 5. Show a sample of the generated data
# if dataset.data_rows:
# print("\nSample generated row:")
# sample_row = dataset.data_rows[0]
# print(f"UUID: {sample_row.uuid}")
# print(f"Tweet: {getattr(sample_row, 'tweet', 'N/A')}")
# print(f"CV: {getattr(sample_row, 'cv', 'N/A')}")
# print(f"Persona: {getattr(sample_row, 'persona', 'N/A')}")
# print(f"Summary Label: {getattr(sample_row, 'summary_label', 'N/A')}")
# print(f"Model ID: {sample_row.model_id}")

# 6. Optional: Push to HF hub
USERNAME = "username" # <--- Your hugging face username
DATASET_NAME = "generic_pipeline_test_dataset_2" # <--- Your hugging face dataset name
url = dataset.push_to_hub(
repo_id=f"{USERNAME}/{DATASET_NAME}",
seed=20250816,
shuffle=True,
)
print(f"\nDataset pushed to Hugging Face Hub: {url}")


if __name__ == "__main__":
from dotenv import load_dotenv

load_dotenv("secrets.env")
main()
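
GenericPipelineDataset.generate also honours an optional skip_function on the config: rows for which it returns True are skipped before any prompting. A hypothetical variant of the configuration above showing that hook, assuming skip_function is a config field that receives the raw source-row dict, as the row-level check in generate() suggests:

# Hypothetical: skip personas with an empty summary_label before any LLM calls.
config_with_skip = GenericPipelineDatasetConfig(
    hf_dataset_name="patrickfleith/FinePersonas-v0.1-100k-space-filtered",
    input_columns=["persona"],
    forward_columns=["summary_label"],
    output_columns=["tweet", "cv"],
    sample_count=5,
    num_samples_per_prompt=2,
    prompts=[PROMPT_TEMPLATE],
    output_file="generic_pipeline_skip_example.jsonl",
    languages={"en": "English"},
    skip_function=lambda row: not row.get("summary_label"),  # assumed callable field
)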
17 changes: 17 additions & 0 deletions datafast/examples/generic_pipeline_response_format_example.py
@@ -0,0 +1,17 @@
"""Simple test for create_response_model function."""

from datafast.schema.config import GenericPipelineDatasetConfig
from datafast.utils import create_response_model

# Test with multiple columns and num_samples_per_prompt = 3
config = GenericPipelineDatasetConfig(
hf_dataset_name="imdb",
input_columns=["text"],
output_columns=["summary", "sentiment"],
prompts=["Analyze: {text}. Language: {language}. Generate {num_samples} responses."],
num_samples_per_prompt=3
)

ResponseModel = create_response_model(config)

print(ResponseModel.model_json_schema())
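
For a rough sense of the shape: generate() iterates response.entries and reads each configured output column off the entry, so the dynamic model presumably wraps a list of entry objects carrying those fields. A hedged check of that assumption against the config above (payload values are illustrative):

# Assumption: the dynamic model exposes `entries`, each item carrying the output columns,
# matching how GenericPipelineDataset.generate() consumes the LLM response.
example_payload = {
    "entries": [
        {"summary": "Short recap of the review.", "sentiment": "positive"},
        {"summary": "Another recap.", "sentiment": "negative"},
        {"summary": "A third recap.", "sentiment": "neutral"},
    ]
}
parsed = ResponseModel.model_validate(example_payload)
print(parsed.entries[0].summary)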