Skip to content

Commit a8e3799

Browse files
Merge pull request #101 from patrickfleith/feature/moving-language-code-in-a-column
Feature/moving language code in a column
2 parents 672d44c + f7c2052 commit a8e3799

File tree

5 files changed

+40
-34
lines changed

5 files changed

+40
-34
lines changed

datafast/datasets.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def generate(self, llms: list[LLMProvider]) -> "ClassificationDataset":
416416
label=label["name"],
417417
model_id=llm.model_id,
418418
label_source=LabelSource.SYNTHETIC,
419-
metadata={"language": lang_code},
419+
language=lang_code,
420420
)
421421
self.data_rows.append(row)
422422
new_rows.append(row)
@@ -510,8 +510,8 @@ def generate(self, llms: list[LLMProvider]) -> "RawDataset":
510510
text=example.text,
511511
text_source=TextSource.SYNTHETIC,
512512
model_id=llm.model_id,
513+
language=lang_code,
513514
metadata={
514-
"language": lang_code,
515515
"document_type": document_type,
516516
"topic": topic,
517517
},
@@ -686,8 +686,8 @@ def generate(self, llms: list[LLMProvider]) -> "UltrachatDataset":
686686
opening_question=messages[0]["content"],
687687
messages=messages,
688688
model_id=llm.model_id,
689+
language=lang_code,
689690
metadata={
690-
"language": lang_code,
691691
"domain": self.config.domain,
692692
"topic": topic,
693693
"subtopic": subtopic,
@@ -915,8 +915,8 @@ def generate(self, llms: list[LLMProvider]) -> "MCQDataset":
915915
incorrect_answer_3=incorrect_answers[2],
916916
model_id=llm.model_id,
917917
mcq_source=MCQSource.SYNTHETIC,
918+
language=lang_code,
918919
metadata={
919-
"language": lang_code,
920920
"source_dataset": self._get_source_dataset_name(),
921921
},
922922
)
@@ -1092,8 +1092,8 @@ def generate(self,
10921092
"preference_source": PreferenceSource.SYNTHETIC,
10931093
"chosen_model_id": chosen_model_id,
10941094
"rejected_model_id": rejected_model_id,
1095+
"language": lang_code,
10951096
"metadata": {
1096-
"language": lang_code,
10971097
"instruction_model": question_gen_llm.model_id,
10981098
}
10991099
}

datafast/examples/quickstart_example.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
],
2626
expansion=PromptExpansionConfig(
2727
placeholders={
28-
"context": ["hike review", "speedboat tour review", "outdoor climbing experience"],
28+
"context": ["hike review", "speedboat tour review"],
2929
},
3030
combinatorial=True
3131
)
@@ -34,7 +34,7 @@
3434
from datafast.llms import OpenAIProvider, AnthropicProvider, GeminiProvider, OllamaProvider
3535

3636
providers = [
37-
OpenAIProvider(model_id="gpt-4.1-nano"),
37+
OpenAIProvider(model_id="gpt-5-nano-2025-08-07"),
3838
# AnthropicProvider(model_id="claude-3-5-haiku-latest"),
3939
# GeminiProvider(model_id="gemini-2.0-flash"),
4040
# OllamaProvider(model_id="gemma3:12b")
@@ -47,9 +47,9 @@
4747
dataset.generate(providers)
4848

4949
# Optional: Push to Hugging Face Hub
50-
USERNAME = "username" # <--- Your hugging face username
51-
DATASET_NAME = "dataset_name" # <--- Your hugging face dataset name
50+
USERNAME = "patrickfleith" # <--- Your hugging face username
51+
DATASET_NAME = "datafast_quickstart_no_train_test_split" # <--- Your hugging face dataset name
5252
dataset.push_to_hub(
5353
repo_id=f"{USERNAME}/{DATASET_NAME}",
54-
train_size=0.6
54+
# train_size=0.6
5555
)

datafast/examples/show_dataset_examples.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
text="The trail is blocked by a fallen tree.",
3939
label="trail_obstruction",
4040
model_id="gpt-4.1-nano",
41-
metadata={"language": "en"},
41+
language="en",
4242
)
4343
classification_dataset = ClassificationDataset(
4444
ClassificationDatasetConfig(classes=[{"name": "trail_obstruction", "description": "Obstruction on the trail."}])
@@ -47,7 +47,7 @@
4747
text="The trail is well maintained and easy to follow.",
4848
label="positive_conditions",
4949
model_id="claude-3-5-haiku-latest",
50-
metadata={"language": "en"},
50+
language="en",
5151
)
5252
classification_dataset.data_rows = [classification_row, classification_row2]
5353

@@ -60,7 +60,7 @@
6060
incorrect_answer_2="Berlin",
6161
incorrect_answer_3="Rome",
6262
model_id="gemini-2.0-flash",
63-
metadata={"language": "en"},
63+
language="en",
6464
)
6565
mcq_config = MCQDatasetConfig(
6666
text_column="source_document",
@@ -75,7 +75,7 @@
7575
incorrect_answer_2="Yangtze River",
7676
incorrect_answer_3="Mississippi River",
7777
model_id="gpt-4.1-nano",
78-
metadata={"language": "en"},
78+
language="en",
7979
)
8080
mcq_dataset.data_rows = [mcq_row, mcq_row2]
8181

@@ -91,7 +91,7 @@
9191
rejected_response_score=3,
9292
chosen_response_assessment="Accurate and detailed.",
9393
rejected_response_assessment="Too generic.",
94-
metadata={"language": "en"},
94+
language="en",
9595
)
9696
preference_dataset = PreferenceDataset(PreferenceDatasetConfig(input_documents=["Describe a recent Mars mission."]))
9797
preference_row2 = PreferenceRow(
@@ -105,7 +105,7 @@
105105
rejected_response_score=2,
106106
chosen_response_assessment="Factually correct and detailed.",
107107
rejected_response_assessment="Incorrect mission.",
108-
metadata={"language": "en"},
108+
language="en",
109109
)
110110
preference_dataset.data_rows = [preference_row, preference_row2]
111111

@@ -134,20 +134,21 @@
134134
persona="space policy expert",
135135
messages=[{"role": "user", "content": "What are current efforts to clean up space debris?"}, {"role": "assistant", "content": "There are several ongoing projects, such as RemoveDEBRIS and ClearSpace-1."}],
136136
model_id="gpt-4.1-nano",
137-
metadata={"language": "en"}
137+
language="en"
138138
)
139139
ultrachat_row2 = ChatRow(
140140
opening_question="What is the importance of the Moon missions?",
141141
persona="lunar geologist",
142142
messages=[{"role": "user", "content": "Why do we keep returning to the Moon?"}, {"role": "assistant", "content": "The Moon offers scientific insights and is a stepping stone for Mars exploration."}],
143143
model_id="gemini-2.0-flash",
144-
metadata={"language": "en"}
144+
language="en"
145145
)
146146
ultrachat_config = UltrachatDatasetConfig()
147147
ultrachat_dataset = UltrachatDataset(ultrachat_config)
148148
ultrachat_dataset.data_rows = [ultrachat_row1, ultrachat_row2]
149149

150150
if __name__ == "__main__":
151+
# pass
151152
# print("Showing ClassificationDataset example...")
152153
# inspect_classification_dataset(classification_dataset)
153154
# print("Showing MCQDataset example...")
@@ -156,5 +157,5 @@
156157
# inspect_preference_dataset(preference_dataset)
157158
# print("Showing RawDataset example...")
158159
# inspect_raw_dataset(raw_dataset)
159-
# print("Showing UltrachatDataset example...")
160-
# inspect_ultrachat_dataset(ultrachat_dataset)
160+
print("Showing UltrachatDataset example...")
161+
inspect_ultrachat_dataset(ultrachat_dataset)

datafast/inspectors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ def format_conversation(self, conversation: List[Dict]) -> str:
353353
def show_example(self, idx: int) -> Tuple[str, str, str, Dict]:
354354
"""Extract data from the example to display."""
355355
row = self.get_example(idx)
356-
conversation = row.get("conversation", [])
357-
formatted_convo = self.format_conversation(conversation)
356+
messages = row.get("messages", [])
357+
formatted_convo = self.format_conversation(messages)
358358
return (
359359
self.get_index_label(idx),
360360
formatted_convo,

datafast/schema/data_rows.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pydantic import BaseModel, Field
22
from uuid import UUID, uuid4
3-
from typing import Union, Optional
3+
from typing import Union
44
from enum import Enum
55

66

@@ -33,7 +33,8 @@ class TextRow(BaseModel):
3333

3434
text: str
3535
text_source: TextSource = TextSource.SYNTHETIC
36-
model_id: Optional[str] = None
36+
model_id: str | None = None
37+
language: str | None = None
3738
uuid: UUID = Field(default_factory=uuid4)
3839
metadata: dict[str, str] = Field(default_factory=dict)
3940

@@ -43,7 +44,8 @@ class ChatRow(BaseModel):
4344

4445
opening_question: str
4546
messages: list[dict[str, str]]
46-
model_id: Optional[str] = None
47+
model_id: str | None = None
48+
language: str | None = None
4749
uuid: UUID = Field(default_factory=uuid4)
4850
metadata: dict[str, str] = Field(default_factory=dict)
4951
persona: str
@@ -52,9 +54,10 @@ class ChatRow(BaseModel):
5254
class TextClassificationRow(BaseModel):
5355
text: str
5456
label: LabelType # Must be either str, list[str], or list[int]
55-
model_id: Optional[str] = None
57+
model_id: str | None = None
5658
label_source: LabelSource = LabelSource.SYNTHETIC
57-
confidence_scores: Optional[dict[str, float]] = Field(default_factory=dict)
59+
confidence_scores: dict[str, float] | None = Field(default_factory=dict)
60+
language: str | None = None
5861

5962
# System and metadata fields
6063
uuid: UUID = Field(default_factory=uuid4)
@@ -68,8 +71,9 @@ class MCQRow(BaseModel):
6871
incorrect_answer_1: str
6972
incorrect_answer_2: str
7073
incorrect_answer_3: str
71-
model_id: Optional[str] = None
74+
model_id: str | None = None
7275
mcq_source: MCQSource = MCQSource.SYNTHETIC
76+
language: str | None = None
7377
uuid: UUID = Field(default_factory=uuid4)
7478
metadata: dict[str, str] = Field(default_factory=dict)
7579

@@ -89,14 +93,15 @@ class PreferenceRow(BaseModel):
8993
chosen_response: str
9094
rejected_response: str
9195
preference_source: PreferenceSource = PreferenceSource.SYNTHETIC
92-
chosen_model_id: Optional[str] = None
93-
rejected_model_id: Optional[str] = None
96+
chosen_model_id: str | None = None
97+
rejected_model_id: str | None = None
98+
language: str | None = None
9499

95100
# Optional judge-related fields
96-
chosen_response_score: Optional[int] = None
97-
rejected_response_score: Optional[int] = None
98-
chosen_response_assessment: Optional[str] = None
99-
rejected_response_assessment: Optional[str] = None
101+
chosen_response_score: int | None = None
102+
rejected_response_score: int | None = None
103+
chosen_response_assessment: str | None = None
104+
rejected_response_assessment: str | None = None
100105

101106
uuid: UUID = Field(default_factory=uuid4)
102107
metadata: dict[str, str] = Field(default_factory=dict)

0 commit comments

Comments
 (0)