1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@ __pycache__/
*.py[cod]
*$py.class
LAUNCH.md
projects/*

# C extensions
*.so
38 changes: 29 additions & 9 deletions datafast/datasets.py
@@ -833,17 +833,37 @@ def generate(self, llms: list[LLMProvider]) -> "MCQDataset":
            if len(document.strip()) > self.config.max_document_length: # Skip very long documents
                continue

            # Check if context_column exists and extract context if available
            context = None
            if self.config.context_column and self.config.context_column in sample:
                context = sample[self.config.context_column]

            for lang_code, language_name in languages.items():
                # 1. First call: Generate questions and correct answers
                question_prompts = self.config.prompts or self._get_default_prompts()
                question_prompts = [
                    prompt.format(
                        num_samples=self.config.num_samples_per_prompt,
                        language_name=language_name,
                        document=document,
                    )
                    for prompt in question_prompts
                ]
                if context and isinstance(context, str):
                    # Use contextualized templates if context is available
                    from datafast.prompts.mcq_prompts import CONTEXTUALISED_TEMPLATES
                    question_prompts = self.config.prompts or CONTEXTUALISED_TEMPLATES
                    question_prompts = [
                        prompt.format(
                            num_samples=self.config.num_samples_per_prompt,
                            language_name=language_name,
                            document=document,
                            context=context
                        )
                        for prompt in question_prompts
                    ]
                else:
                    # Use default templates if no context is available
                    question_prompts = self.config.prompts or self._get_default_prompts()
                    question_prompts = [
                        prompt.format(
                            num_samples=self.config.num_samples_per_prompt,
                            language_name=language_name,
                            document=document,
                        )
                        for prompt in question_prompts
                    ]

                # Expand prompts with configured variations
                question_expansions = expand_prompts(
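For orientation, a minimal sketch of what the new contextualised branch above produces for a single prompt; the document and context strings are invented placeholders, and only the template fields introduced in this PR (num_samples, language_name, document, context) are used:

# Illustrative only: format one contextualised template the same way generate() does.
from datafast.prompts.mcq_prompts import CONTEXTUALISED_TEMPLATES

prompt = CONTEXTUALISED_TEMPLATES[0].format(
    num_samples=2,
    language_name="English",
    document="Alpine glaciers lost roughly a third of their volume between 2000 and 2020.",  # placeholder text
    context="Chapter summary: observed impacts of climate change on mountain regions.",  # placeholder context
)
print(prompt)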
85 changes: 85 additions & 0 deletions datafast/examples/mcq_contextual_example.py
@@ -0,0 +1,85 @@
"""
Example script for generating MCQ questions from the AR6 dataset using context information.
This script demonstrates the use of the context_column parameter to enhance question generation.
"""

import os
import json
import random
from pathlib import Path
import pandas as pd
from dotenv import load_dotenv

load_dotenv("secrets.env")

from datafast.schema.config import MCQDatasetConfig
from datafast.datasets import MCQDataset
from datafast.llms import OpenAIProvider

def main():
    # 1. Create a temporary filtered version of the dataset
    ar6_file_path = Path("datafast/examples/data/mcq/ar6.jsonl")
    filtered_file_path = Path("datafast/examples/data/mcq/ar6_filtered.jsonl")

    # Read the ar6.jsonl file
    with open(ar6_file_path, "r") as f:
        data = [json.loads(line) for line in f if line.strip()]

    # Filter for rows where chunk_grade is "OK" or "GREAT"
    filtered_data = [row for row in data if row.get("chunk_grade") in ["OK", "GREAT"]]

    # Randomly select 10 examples
    selected_data = random.sample(filtered_data, min(10, len(filtered_data)))

    # Write the selected data to a temporary file
    with open(filtered_file_path, "w") as f:
        for row in selected_data:
            f.write(json.dumps(row) + "\n")

    print(f"Selected {len(selected_data)} examples from AR6 dataset")

    # 2. Create MCQ dataset config
    config = MCQDatasetConfig(
        local_file_path=str(filtered_file_path),
        text_column="chunk_text", # Column containing the text to generate questions from
        context_column="document_summary", # Column containing context information
        num_samples_per_prompt=2, # Generate 2 questions per document
        min_document_length=100, # Skip documents shorter than 100 chars
        max_document_length=20000, # Skip documents longer than 20000 chars
        sample_count=len(selected_data), # Number of samples to process
        output_file="mcq_ar6_contextual_dataset.jsonl",
    )

    # 3. Initialize OpenAI provider with gpt-4.1-mini
    providers = [
        OpenAIProvider(model_id="gpt-4.1-mini"),
    ]

    # 4. Generate the dataset
    dataset = MCQDataset(config)
    num_expected_rows = dataset.get_num_expected_rows(providers, source_data_num_rows=len(selected_data))
    print(f"\nExpected number of rows: {num_expected_rows}")
    dataset.generate(providers)

    # 5. Print results summary
    print(f"\nGenerated {len(dataset.data_rows)} MCQs")
    print(f"Results saved to {config.output_file}")

    # 6. Cleanup temporary file
    os.remove(filtered_file_path)
    print(f"Cleaned up temporary file {filtered_file_path}")
    # # 7. Optional: Push to HF hub
    # USERNAME = "your_username" # <--- Your hugging face username
    # DATASET_NAME = "your_dataset_name" # <--- Your hugging face dataset name
    # url = dataset.push_to_hub(
    #     repo_id=f"{USERNAME}/{DATASET_NAME}",
    #     train_size=0.7,
    #     shuffle=True,
    #     upload_card=True,
    # )
    # print(f"\nDataset pushed to Hugging Face Hub: {url}")

    dataset.inspect()

if __name__ == "__main__":
    main()
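Once the script has run, the output file can be inspected offline as well; a minimal sketch, assuming the output is JSON Lines with one MCQ per line (the exact field names depend on datafast's MCQRow schema):

import json

# Read back the generated dataset written by the example above.
with open("mcq_ar6_contextual_dataset.jsonl") as f:
    mcqs = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(mcqs)} MCQs")
print(mcqs[0])  # inspect one row; exact keys come from datafast's MCQRow schema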
40 changes: 36 additions & 4 deletions datafast/llms.py
@@ -7,6 +7,7 @@
from typing import Any, Type, TypeVar
from abc import ABC, abstractmethod
import os
import time
import traceback

# Pydantic
@@ -35,6 +36,7 @@ def __init__(
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
        rpm_limit: int | None = None,
    ):
        """Initialize the LLM provider with common parameters.

@@ -54,6 +56,10 @@ def __init__(
        self.max_completion_tokens = max_completion_tokens
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty

        # Rate limiting
        self.rpm_limit = rpm_limit
        self._request_timestamps: list[float] = []

        # Configure environment with API key if needed
        self._configure_env()
@@ -88,6 +94,23 @@ def _configure_env(self) -> None:
    def _get_model_string(self) -> str:
        """Get the full model string for LiteLLM."""
        return f"{self.provider_name}/{self.model_id}"

    def _respect_rate_limit(self) -> None:
        """Block execution to ensure we do not exceed the rpm_limit."""
        if self.rpm_limit is None:
            return
        current = time.monotonic()
        # Keep only timestamps within the last minute
        self._request_timestamps = [ts for ts in self._request_timestamps if current - ts < 60]
        if len(self._request_timestamps) < self.rpm_limit:
            return
        # Need to wait until the earliest request is outside the 60-second window
        earliest = self._request_timestamps[0]
        # Add a 1s margin to avoid accidental rate limit exceedance
        sleep_time = 61 - (current - earliest)
        if sleep_time > 0:
            print("Waiting for rate limit...")
            time.sleep(sleep_time)

    def generate(
        self,
@@ -122,6 +145,8 @@ def generate(
        else:
            messages_to_send = messages

        # Enforce rate limit if set
        self._respect_rate_limit()
        # Prepare completion parameters
        completion_params = {
            "model": self._get_model_string(),
Expand All @@ -138,6 +163,9 @@ def generate(

# Call LiteLLM completion
response: ModelResponse = litellm.completion(**completion_params)
# Record timestamp for rate limiting
if self.rpm_limit is not None:
self._request_timestamps.append(time.monotonic())

# Extract content from response
content = response.choices[0].message.content
@@ -172,7 +200,7 @@ def __init__(
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
    ):
    ):
        """Initialize the OpenAI provider.

        Args:
@@ -212,7 +240,7 @@ def __init__(
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        # frequency_penalty: float | None = None, # Not supported by anthropic
    ):
    ):
        """Initialize the Anthropic provider.

        Args:
@@ -250,7 +278,8 @@ def __init__(
        max_completion_tokens: int | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
    ):
        rpm_limit: int | None = None,
    ):
        """Initialize the Gemini provider.

        Args:
@@ -268,6 +297,7 @@
            max_completion_tokens=max_completion_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            rpm_limit=rpm_limit,
        )


@@ -301,7 +331,8 @@ def __init__(
        top_p: float | None = None,
        frequency_penalty: float | None = None,
        api_base: str | None = None,
    ):
        rpm_limit: int | None = None,
    ):
        """Initialize the Ollama provider.

        Args:
@@ -323,4 +354,5 @@
            max_completion_tokens=max_completion_tokens,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            rpm_limit=rpm_limit,
        )
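For reference, a minimal usage sketch of the new rpm_limit parameter (the model_id mirrors the one used in the new test below; the 15 requests/minute figure is an assumed quota, not something this PR enforces on its own):

from datafast.llms import GeminiProvider

# Throttle to an assumed quota of 15 requests per minute; omit rpm_limit to disable throttling.
provider = GeminiProvider(model_id="gemini-2.5-flash-lite-preview-06-17", rpm_limit=15)

for i in range(20):
    # Requests beyond the 15th inside any 60-second window block in _respect_rate_limit()
    # before the LiteLLM call is made, so this loop takes at least ~60 seconds overall.
    provider.generate(prompt=f"Test request {i}")

Timestamps are only recorded when rpm_limit is set, so providers constructed without it behave exactly as before.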
25 changes: 22 additions & 3 deletions datafast/prompts/mcq_prompts.py
@@ -1,5 +1,4 @@
DEFAULT_TEMPLATES = [
"""You are an expert at creating exam questions. Your task is to come up with {num_samples} \
DEFAULT_TEMPLATES = ["""You are an expert at creating exam questions. Your task is to come up with {num_samples} \
difficult multiple choice questions written in {language_name} in relation to the following document along with the correct answer.
The question should be self-contained, short and answerable.
It is very important to have unique questions. No questions should be like 'what is X and what about Y?' or 'what is X and when did Y happen?'.
@@ -11,7 +10,27 @@

Now come up with {num_samples} questions in relation to the document.
Make sure the questions are difficult, but answerable with a short answer.
Provide the correct answer for each question."""
Provide the correct answer for each question."""]

# Template for more contextualized questions
CONTEXTUALISED_TEMPLATES = ["""You are an expert at creating exam questions. Your task is to come up with {num_samples} \
multiple choice questions written in {language_name} in relation to the following document along with the correct answer.
The question should be self-contained, short and answerable.
It is very important to have unique questions. No questions should be like 'what is X and what about Y?' or 'what is X and when did Y happen?'.
The answer must be short.
It must relate to the details of the document. However, questions should never contain wording that refers to the document itself, such as 'according to the report', 'in this paper', 'in the document', etc.
Make sure to write the questions so that they include some very brief context, as if the person asking the question were concisely explaining the situation in which it arises. This is only to remove ambiguity, as if the question were given in an exam.

### Context
{context}

### Document
{document}

Now come up with {num_samples} contextualized questions in relation to the document.
Make sure the questions are difficult, but answerable with a short answer.
Provide the correct answer for each question.
"""
]

DISTRACTOR_TEMPLATE = """
6 changes: 6 additions & 0 deletions datafast/schema/config.py
@@ -350,6 +350,12 @@ class MCQDatasetConfig(BaseModel):
        description="Column name containing the text to generate questions from"
    )

    context_column: str | None = Field(
        default=None,
        description="Optional column name containing contextual information to enhance question generation. \
            When provided, questions will be generated with this contextual information."
    )

    # MCQ Generation parameters
    num_samples_per_prompt: int = Field(
        default=3,
3 changes: 3 additions & 0 deletions docs/guides/generating_mcq_datasets.md
@@ -72,6 +72,7 @@ The `MCQDatasetConfig` class defines all parameters for your MCQ dataset generation
- **`hf_dataset_name`**: (Optional) Name of a Hugging Face dataset to use as source material
- **`local_file_path`**: (Optional) Path to a local file to use as source material
- **`text_column`**: (Required) Column name containing the text to generate questions from
- **`context_column`**: (Optional) Column name containing contextual information to enhance question generation with domain-specific context
- **`num_samples_per_prompt`**: Number of questions to generate for each text document
- **`sample_count`**: (Optional) Number of samples to process from the source text dataset (useful for testing)
- **`min_document_length`**: Minimum text length (in characters) for processing (skips shorter documents)
@@ -295,6 +296,7 @@ def main():
# train_size=0.7,
# seed=42,
# shuffle=True,
# upload_card=True,
# )
# print(f"\nDataset pushed to Hugging Face Hub: {url}")

@@ -322,3 +324,4 @@ Each generated question is stored as an `MCQRow` with these properties:
3. **Model Selection**: Larger, more capable models generally produce better questions and answers.
4. **Validation**: Review a sample of the generated questions to ensure quality and accuracy, then edit prompt.
5. **Start Small**: Begin with a small sample_count to test the configuration before scaling up.
6. **Use Context**: When available, use the `context_column` parameter to provide additional domain-specific context that helps generate more self-contained questions. Good contexts include document summaries and topic descriptions.
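A rough illustration of that practice, using the column names from the AR6 example in this PR (the row content and file path are made up):

from datafast.schema.config import MCQDatasetConfig

# One illustrative JSONL source row (invented content):
# {"chunk_text": "Observed warming reached about 1.1°C above 1850-1900 levels...",
#  "document_summary": "IPCC AR6 Synthesis Report, section on observed surface temperature change."}

config = MCQDatasetConfig(
    local_file_path="data/ar6_chunks.jsonl",  # hypothetical path
    text_column="chunk_text",
    context_column="document_summary",  # enables the contextualised templates
    num_samples_per_prompt=3,
    output_file="mcq_with_context.jsonl",
)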
17 changes: 16 additions & 1 deletion tests/test_llms.py
@@ -66,7 +66,22 @@ def test_gemini_provider():
    response = provider.generate(prompt="What is the capital of France? Answer in one word.")
    assert "Paris" in response


@pytest.mark.slow
@pytest.mark.integration
def test_gemini_rpm_limit_real():
    """Test GeminiProvider RPM limit (15 requests/minute) is enforced with real waiting."""
    import time
    prompts_count = 17
    rpm = 15
    provider = GeminiProvider(model_id="gemini-2.5-flash-lite-preview-06-17", rpm_limit=rpm)
    prompts = [f"Test request {i}" for i in range(prompts_count)]
    start = time.monotonic()
    for prompt in prompts:
        provider.generate(prompt=prompt)
    elapsed = time.monotonic() - start
    # 17 requests with rpm=15, so we must wait at least ~60s for the 2 requests beyond the limit
    assert elapsed >= 59, f"Elapsed time too short for RPM limit: {elapsed:.2f}s for {prompts_count} requests with rpm={rpm}"

@pytest.mark.integration
def test_openai_structured_output():
"""Test the OpenAI provider with structured output."""