Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions datafast/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def generate(self, llms: list[LLMProvider]) -> "ClassificationDataset":
label=label["name"],
model_id=llm.model_id,
label_source=LabelSource.SYNTHETIC,
metadata={"language": lang_code},
language=lang_code,
)
self.data_rows.append(row)
new_rows.append(row)
Expand Down Expand Up @@ -510,8 +510,8 @@ def generate(self, llms: list[LLMProvider]) -> "RawDataset":
text=example.text,
text_source=TextSource.SYNTHETIC,
model_id=llm.model_id,
language=lang_code,
metadata={
"language": lang_code,
"document_type": document_type,
"topic": topic,
},
Expand Down Expand Up @@ -686,8 +686,8 @@ def generate(self, llms: list[LLMProvider]) -> "UltrachatDataset":
opening_question=messages[0]["content"],
messages=messages,
model_id=llm.model_id,
language=lang_code,
metadata={
"language": lang_code,
"domain": self.config.domain,
"topic": topic,
"subtopic": subtopic,
Expand Down Expand Up @@ -915,8 +915,8 @@ def generate(self, llms: list[LLMProvider]) -> "MCQDataset":
incorrect_answer_3=incorrect_answers[2],
model_id=llm.model_id,
mcq_source=MCQSource.SYNTHETIC,
language=lang_code,
metadata={
"language": lang_code,
"source_dataset": self._get_source_dataset_name(),
},
)
Expand Down Expand Up @@ -1092,8 +1092,8 @@ def generate(self,
"preference_source": PreferenceSource.SYNTHETIC,
"chosen_model_id": chosen_model_id,
"rejected_model_id": rejected_model_id,
"language": lang_code,
"metadata": {
"language": lang_code,
"instruction_model": question_gen_llm.model_id,
}
}
Expand Down
10 changes: 5 additions & 5 deletions datafast/examples/quickstart_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
],
expansion=PromptExpansionConfig(
placeholders={
"context": ["hike review", "speedboat tour review", "outdoor climbing experience"],
"context": ["hike review", "speedboat tour review"],
},
combinatorial=True
)
Expand All @@ -34,7 +34,7 @@
from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider, OllamaProvider

providers = [
OpenAIProvider(model_id="gpt-4.1-nano"),
OpenAIProvider(model_id="gpt-5-nano-2025-08-07"),
# AnthropicProvider(model_id="claude-3-5-haiku-latest"),
# GeminiProvider(model_id="gemini-2.0-flash"),
# OllamaProvider(model_id="gemma3:12b")
Expand All @@ -47,9 +47,9 @@
dataset.generate(providers)

# Optional: Push to Hugging Face Hub
USERNAME = "username" # <--- Your hugging face username
DATASET_NAME = "dataset_name" # <--- Your hugging face dataset name
USERNAME = "patrickfleith" # <--- Your hugging face username
DATASET_NAME = "datafast_quickstart_no_train_test_split" # <--- Your hugging face dataset name
dataset.push_to_hub(
repo_id=f"{USERNAME}/{DATASET_NAME}",
train_size=0.6
# train_size=0.6
)
21 changes: 11 additions & 10 deletions datafast/examples/show_dataset_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
text="The trail is blocked by a fallen tree.",
label="trail_obstruction",
model_id="gpt-4.1-nano",
metadata={"language": "en"},
language="en",
)
classification_dataset = ClassificationDataset(
ClassificationDatasetConfig(classes=[{"name": "trail_obstruction", "description": "Obstruction on the trail."}])
Expand All @@ -47,7 +47,7 @@
text="The trail is well maintained and easy to follow.",
label="positive_conditions",
model_id="claude-3-5-haiku-latest",
metadata={"language": "en"},
language="en",
)
classification_dataset.data_rows = [classification_row, classification_row2]

Expand All @@ -60,7 +60,7 @@
incorrect_answer_2="Berlin",
incorrect_answer_3="Rome",
model_id="gemini-2.0-flash",
metadata={"language": "en"},
language="en",
)
mcq_config = MCQDatasetConfig(
text_column="source_document",
Expand All @@ -75,7 +75,7 @@
incorrect_answer_2="Yangtze River",
incorrect_answer_3="Mississippi River",
model_id="gpt-4.1-nano",
metadata={"language": "en"},
language="en",
)
mcq_dataset.data_rows = [mcq_row, mcq_row2]

Expand All @@ -91,7 +91,7 @@
rejected_response_score=3,
chosen_response_assessment="Accurate and detailed.",
rejected_response_assessment="Too generic.",
metadata={"language": "en"},
language="en",
)
preference_dataset = PreferenceDataset(PreferenceDatasetConfig(input_documents=["Describe a recent Mars mission."]))
preference_row2 = PreferenceRow(
Expand All @@ -105,7 +105,7 @@
rejected_response_score=2,
chosen_response_assessment="Factually correct and detailed.",
rejected_response_assessment="Incorrect mission.",
metadata={"language": "en"},
language="en",
)
preference_dataset.data_rows = [preference_row, preference_row2]

Expand Down Expand Up @@ -134,20 +134,21 @@
persona="space policy expert",
messages=[{"role": "user", "content": "What are current efforts to clean up space debris?"}, {"role": "assistant", "content": "There are several ongoing projects, such as RemoveDEBRIS and ClearSpace-1."}],
model_id="gpt-4.1-nano",
metadata={"language": "en"}
language="en"
)
ultrachat_row2 = ChatRow(
opening_question="What is the importance of the Moon missions?",
persona="lunar geologist",
messages=[{"role": "user", "content": "Why do we keep returning to the Moon?"}, {"role": "assistant", "content": "The Moon offers scientific insights and is a stepping stone for Mars exploration."}],
model_id="gemini-2.0-flash",
metadata={"language": "en"}
language="en"
)
ultrachat_config = UltrachatDatasetConfig()
ultrachat_dataset = UltrachatDataset(ultrachat_config)
ultrachat_dataset.data_rows = [ultrachat_row1, ultrachat_row2]

if __name__ == "__main__":
# pass
# print("Showing ClassificationDataset example...")
# inspect_classification_dataset(classification_dataset)
# print("Showing MCQDataset example...")
Expand All @@ -156,5 +157,5 @@
# inspect_preference_dataset(preference_dataset)
# print("Showing RawDataset example...")
# inspect_raw_dataset(raw_dataset)
# print("Showing UltrachatDataset example...")
# inspect_ultrachat_dataset(ultrachat_dataset)
print("Showing UltrachatDataset example...")
inspect_ultrachat_dataset(ultrachat_dataset)
4 changes: 2 additions & 2 deletions datafast/inspectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,8 @@ def format_conversation(self, conversation: List[Dict]) -> str:
def show_example(self, idx: int) -> Tuple[str, str, str, Dict]:
"""Extract data from the example to display."""
row = self.get_example(idx)
conversation = row.get("conversation", [])
formatted_convo = self.format_conversation(conversation)
messages = row.get("messages", [])
formatted_convo = self.format_conversation(messages)
return (
self.get_index_label(idx),
formatted_convo,
Expand Down
29 changes: 17 additions & 12 deletions datafast/schema/data_rows.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from uuid import UUID, uuid4
from typing import Union, Optional
from typing import Union
from enum import Enum


Expand Down Expand Up @@ -33,7 +33,8 @@ class TextRow(BaseModel):

text: str
text_source: TextSource = TextSource.SYNTHETIC
model_id: Optional[str] = None
model_id: str | None = None
language: str | None = None
uuid: UUID = Field(default_factory=uuid4)
metadata: dict[str, str] = Field(default_factory=dict)

Expand All @@ -43,7 +44,8 @@ class ChatRow(BaseModel):

opening_question: str
messages: list[dict[str, str]]
model_id: Optional[str] = None
model_id: str | None = None
language: str | None = None
uuid: UUID = Field(default_factory=uuid4)
metadata: dict[str, str] = Field(default_factory=dict)
persona: str
Expand All @@ -52,9 +54,10 @@ class ChatRow(BaseModel):
class TextClassificationRow(BaseModel):
text: str
label: LabelType # Must be either str, list[str], or list[int]
model_id: Optional[str] = None
model_id: str | None = None
label_source: LabelSource = LabelSource.SYNTHETIC
confidence_scores: Optional[dict[str, float]] = Field(default_factory=dict)
confidence_scores: dict[str, float] | None = Field(default_factory=dict)
language: str | None = None

# System and metadata fields
uuid: UUID = Field(default_factory=uuid4)
Expand All @@ -68,8 +71,9 @@ class MCQRow(BaseModel):
incorrect_answer_1: str
incorrect_answer_2: str
incorrect_answer_3: str
model_id: Optional[str] = None
model_id: str | None = None
mcq_source: MCQSource = MCQSource.SYNTHETIC
language: str | None = None
uuid: UUID = Field(default_factory=uuid4)
metadata: dict[str, str] = Field(default_factory=dict)

Expand All @@ -89,14 +93,15 @@ class PreferenceRow(BaseModel):
chosen_response: str
rejected_response: str
preference_source: PreferenceSource = PreferenceSource.SYNTHETIC
chosen_model_id: Optional[str] = None
rejected_model_id: Optional[str] = None
chosen_model_id: str | None = None
rejected_model_id: str | None = None
language: str | None = None

# Optional judge-related fields
chosen_response_score: Optional[int] = None
rejected_response_score: Optional[int] = None
chosen_response_assessment: Optional[str] = None
rejected_response_assessment: Optional[str] = None
chosen_response_score: int | None = None
rejected_response_score: int | None = None
chosen_response_assessment: str | None = None
rejected_response_assessment: str | None = None

uuid: UUID = Field(default_factory=uuid4)
metadata: dict[str, str] = Field(default_factory=dict)
Loading