diff --git a/datafast/datasets.py b/datafast/datasets.py index cedc2fe..6d50fb5 100644 --- a/datafast/datasets.py +++ b/datafast/datasets.py @@ -416,7 +416,7 @@ def generate(self, llms: list[LLMProvider]) -> "ClassificationDataset": label=label["name"], model_id=llm.model_id, label_source=LabelSource.SYNTHETIC, - metadata={"language": lang_code}, + language=lang_code, ) self.data_rows.append(row) new_rows.append(row) @@ -510,8 +510,8 @@ def generate(self, llms: list[LLMProvider]) -> "RawDataset": text=example.text, text_source=TextSource.SYNTHETIC, model_id=llm.model_id, + language=lang_code, metadata={ - "language": lang_code, "document_type": document_type, "topic": topic, }, @@ -686,8 +686,8 @@ def generate(self, llms: list[LLMProvider]) -> "UltrachatDataset": opening_question=messages[0]["content"], messages=messages, model_id=llm.model_id, + language=lang_code, metadata={ - "language": lang_code, "domain": self.config.domain, "topic": topic, "subtopic": subtopic, @@ -915,8 +915,8 @@ def generate(self, llms: list[LLMProvider]) -> "MCQDataset": incorrect_answer_3=incorrect_answers[2], model_id=llm.model_id, mcq_source=MCQSource.SYNTHETIC, + language=lang_code, metadata={ - "language": lang_code, "source_dataset": self._get_source_dataset_name(), }, ) @@ -1092,8 +1092,8 @@ def generate(self, "preference_source": PreferenceSource.SYNTHETIC, "chosen_model_id": chosen_model_id, "rejected_model_id": rejected_model_id, + "language": lang_code, "metadata": { - "language": lang_code, "instruction_model": question_gen_llm.model_id, } } diff --git a/datafast/examples/quickstart_example.py b/datafast/examples/quickstart_example.py index 6cff340..7565adf 100644 --- a/datafast/examples/quickstart_example.py +++ b/datafast/examples/quickstart_example.py @@ -25,7 +25,7 @@ ], expansion=PromptExpansionConfig( placeholders={ - "context": ["hike review", "speedboat tour review", "outdoor climbing experience"], + "context": ["hike review", "speedboat tour review"], }, combinatorial=True ) @@ -34,7 +34,7 @@ from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider, OllamaProvider providers = [ - OpenAIProvider(model_id="gpt-4.1-nano"), + OpenAIProvider(model_id="gpt-5-nano-2025-08-07"), # AnthropicProvider(model_id="claude-3-5-haiku-latest"), # GeminiProvider(model_id="gemini-2.0-flash"), # OllamaProvider(model_id="gemma3:12b") @@ -47,9 +47,9 @@ dataset.generate(providers) # Optional: Push to Hugging Face Hub -USERNAME = "username" # <--- Your hugging face username -DATASET_NAME = "dataset_name" # <--- Your hugging face dataset name +USERNAME = "patrickfleith" # <--- Your hugging face username +DATASET_NAME = "datafast_quickstart_no_train_test_split" # <--- Your hugging face dataset name dataset.push_to_hub( repo_id=f"{USERNAME}/{DATASET_NAME}", - train_size=0.6 + # train_size=0.6 ) diff --git a/datafast/examples/show_dataset_examples.py b/datafast/examples/show_dataset_examples.py index c58bd0b..a8ea3a2 100644 --- a/datafast/examples/show_dataset_examples.py +++ b/datafast/examples/show_dataset_examples.py @@ -38,7 +38,7 @@ text="The trail is blocked by a fallen tree.", label="trail_obstruction", model_id="gpt-4.1-nano", - metadata={"language": "en"}, + language="en", ) classification_dataset = ClassificationDataset( ClassificationDatasetConfig(classes=[{"name": "trail_obstruction", "description": "Obstruction on the trail."}]) @@ -47,7 +47,7 @@ text="The trail is well maintained and easy to follow.", label="positive_conditions", model_id="claude-3-5-haiku-latest", - metadata={"language": "en"}, + language="en", ) classification_dataset.data_rows = [classification_row, classification_row2] @@ -60,7 +60,7 @@ incorrect_answer_2="Berlin", incorrect_answer_3="Rome", model_id="gemini-2.0-flash", - metadata={"language": "en"}, + language="en", ) mcq_config = MCQDatasetConfig( text_column="source_document", @@ -75,7 +75,7 @@ incorrect_answer_2="Yangtze River", incorrect_answer_3="Mississippi River", model_id="gpt-4.1-nano", - metadata={"language": "en"}, + language="en", ) mcq_dataset.data_rows = [mcq_row, mcq_row2] @@ -91,7 +91,7 @@ rejected_response_score=3, chosen_response_assessment="Accurate and detailed.", rejected_response_assessment="Too generic.", - metadata={"language": "en"}, + language="en", ) preference_dataset = PreferenceDataset(PreferenceDatasetConfig(input_documents=["Describe a recent Mars mission."])) preference_row2 = PreferenceRow( @@ -105,7 +105,7 @@ rejected_response_score=2, chosen_response_assessment="Factually correct and detailed.", rejected_response_assessment="Incorrect mission.", - metadata={"language": "en"}, + language="en", ) preference_dataset.data_rows = [preference_row, preference_row2] @@ -134,20 +134,21 @@ persona="space policy expert", messages=[{"role": "user", "content": "What are current efforts to clean up space debris?"}, {"role": "assistant", "content": "There are several ongoing projects, such as RemoveDEBRIS and ClearSpace-1."}], model_id="gpt-4.1-nano", - metadata={"language": "en"} + language="en" ) ultrachat_row2 = ChatRow( opening_question="What is the importance of the Moon missions?", persona="lunar geologist", messages=[{"role": "user", "content": "Why do we keep returning to the Moon?"}, {"role": "assistant", "content": "The Moon offers scientific insights and is a stepping stone for Mars exploration."}], model_id="gemini-2.0-flash", - metadata={"language": "en"} + language="en" ) ultrachat_config = UltrachatDatasetConfig() ultrachat_dataset = UltrachatDataset(ultrachat_config) ultrachat_dataset.data_rows = [ultrachat_row1, ultrachat_row2] if __name__ == "__main__": + # pass # print("Showing ClassificationDataset example...") # inspect_classification_dataset(classification_dataset) # print("Showing MCQDataset example...") @@ -156,5 +157,5 @@ # inspect_preference_dataset(preference_dataset) # print("Showing RawDataset example...") # inspect_raw_dataset(raw_dataset) - # print("Showing UltrachatDataset example...") - # inspect_ultrachat_dataset(ultrachat_dataset) + print("Showing UltrachatDataset example...") + inspect_ultrachat_dataset(ultrachat_dataset) \ No newline at end of file diff --git a/datafast/inspectors.py b/datafast/inspectors.py index f2bd81f..517f957 100644 --- a/datafast/inspectors.py +++ b/datafast/inspectors.py @@ -353,8 +353,8 @@ def format_conversation(self, conversation: List[Dict]) -> str: def show_example(self, idx: int) -> Tuple[str, str, str, Dict]: """Extract data from the example to display.""" row = self.get_example(idx) - conversation = row.get("conversation", []) - formatted_convo = self.format_conversation(conversation) + messages = row.get("messages", []) + formatted_convo = self.format_conversation(messages) return ( self.get_index_label(idx), formatted_convo, diff --git a/datafast/schema/data_rows.py b/datafast/schema/data_rows.py index b5fdd40..922a1e0 100644 --- a/datafast/schema/data_rows.py +++ b/datafast/schema/data_rows.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, Field from uuid import UUID, uuid4 -from typing import Union, Optional +from typing import Union from enum import Enum @@ -33,7 +33,8 @@ class TextRow(BaseModel): text: str text_source: TextSource = TextSource.SYNTHETIC - model_id: Optional[str] = None + model_id: str | None = None + language: str | None = None uuid: UUID = Field(default_factory=uuid4) metadata: dict[str, str] = Field(default_factory=dict) @@ -43,7 +44,8 @@ class ChatRow(BaseModel): opening_question: str messages: list[dict[str, str]] - model_id: Optional[str] = None + model_id: str | None = None + language: str | None = None uuid: UUID = Field(default_factory=uuid4) metadata: dict[str, str] = Field(default_factory=dict) persona: str @@ -52,9 +54,10 @@ class ChatRow(BaseModel): class TextClassificationRow(BaseModel): text: str label: LabelType # Must be either str, list[str], or list[int] - model_id: Optional[str] = None + model_id: str | None = None label_source: LabelSource = LabelSource.SYNTHETIC - confidence_scores: Optional[dict[str, float]] = Field(default_factory=dict) + confidence_scores: dict[str, float] | None = Field(default_factory=dict) + language: str | None = None # System and metadata fields uuid: UUID = Field(default_factory=uuid4) @@ -68,8 +71,9 @@ class MCQRow(BaseModel): incorrect_answer_1: str incorrect_answer_2: str incorrect_answer_3: str - model_id: Optional[str] = None + model_id: str | None = None mcq_source: MCQSource = MCQSource.SYNTHETIC + language: str | None = None uuid: UUID = Field(default_factory=uuid4) metadata: dict[str, str] = Field(default_factory=dict) @@ -89,14 +93,15 @@ class PreferenceRow(BaseModel): chosen_response: str rejected_response: str preference_source: PreferenceSource = PreferenceSource.SYNTHETIC - chosen_model_id: Optional[str] = None - rejected_model_id: Optional[str] = None + chosen_model_id: str | None = None + rejected_model_id: str | None = None + language: str | None = None # Optional judge-related fields - chosen_response_score: Optional[int] = None - rejected_response_score: Optional[int] = None - chosen_response_assessment: Optional[str] = None - rejected_response_assessment: Optional[str] = None + chosen_response_score: int | None = None + rejected_response_score: int | None = None + chosen_response_assessment: str | None = None + rejected_response_assessment: str | None = None uuid: UUID = Field(default_factory=uuid4) metadata: dict[str, str] = Field(default_factory=dict) \ No newline at end of file