3 changes: 3 additions & 0 deletions .gitignore
@@ -189,3 +189,6 @@ data/test_uploads/

# Training pipeline artifacts
training_data/

# Superpowers runtime state
.superpowers/
@@ -1,6 +1,6 @@
# Design Spec: Fine-Tuning for Student Explainability

**Date:** 2026-04-02
**Date:** 2026-04-02 (updated 2026-04-03)
**Epic label:** `fine-tuning: student-explainability`
**Epic branch:** `fine-tuning/student-explainability`
**Status:** Draft
@@ -9,7 +9,7 @@

## 1. Goal

Fine-tune a small language model (Qwen 3.5) on Bishop State domain data to replace GPT-4o-mini for three inference tasks in the dashboard. The primary value is improved explainability: advisors get SHAP-grounded, institution-aware narratives instead of templated rule-engine output. Secondary benefits include FERPA compliance (all inference on-premises), offline deployment, and institutional scalability.
Fine-tune a small language model on Bishop State domain data to replace GPT-4o-mini for three inference tasks in the dashboard. Two candidate models will be evaluated head-to-head: **Qwen 3.5-4B** (proven by D4BL) and **Gemma 4 E4B** (native structured JSON output). The primary value is improved explainability: advisors get SHAP-grounded, institution-aware narratives instead of templated rule-engine output. Secondary benefits include FERPA compliance (all inference on-premises), offline deployment, and institutional scalability.

### Tasks to Fine-Tune

@@ -24,6 +24,34 @@ Fine-tune a small language model (Qwen 3.5) on Bishop State domain data to repla
- Query Analyzer (NL → SQL) — high risk, deferred to future epic
- Model serving infrastructure (RunPod, dedicated GPU hosting) — use local Ollama for now

### Model Selection: Qwen 3.5-4B vs Gemma 4 E4B

Two candidate models will be trained and evaluated. The winner is selected based on ship criteria metrics.

| | **Qwen 3.5-4B** | **Gemma 4 E4B** |
|---|---|---|
| Effective params | 4B | 4.5B (8B with embeddings) |
| GGUF size (q4_k_m) | ~2.7 GB | ~5 GB |
| Context window | 32K | 128K |
| Native JSON output | No | Yes — built-in structured function calling |
| Ollama support | Text-only (vision mmproj broken) | Full support |
| Unsloth LoRA | Yes, but QLoRA discouraged — use bf16 LoRA | Yes, bf16 LoRA recommended |
| VRAM for training (bf16) | 10 GB | ~10 GB |
| License | Apache 2.0 | Apache 2.0 |
| D4BL proven | Yes (5 experiments, 98.77% schema validity) | No (new model) |

**Why these two:**
- Qwen 3.5-4B is the known quantity — D4BL proved the full pipeline (distill → train → GGUF → Ollama) with 5 experiments and achieved 98.77% schema validity on structured output tasks.
- Gemma 4 E4B has native structured JSON output before fine-tuning, 128K context (headroom for SHAP-heavy narrator prompts), and full Ollama GGUF support without the mmproj workaround.

**Why not others:**
- Qwen 3.5-9B: 22 GB VRAM for training, larger GGUF (~5.5 GB), marginal quality gain over 4B for our task complexity.
- Qwen 3-4B: tops fine-tuning benchmarks, but lacks 3.5's architecture improvements.
- Llama 3.2-3B: Best tunability gain, but Meta's community license (700M MAU limit) is restrictive for educational institutions.
- Phi-4-mini: Strong on math, but less proven for structured output and limited Unsloth support.

**Training note:** Unsloth explicitly discourages QLoRA (4-bit) on Qwen 3.5 due to quantization artifacts. Both models will use bf16 LoRA on A100 (40 GB VRAM is sufficient for either).
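This constraint is worth enforcing in code rather than by convention. A minimal guard for the notebook config (a sketch; `load_kwargs` and its argument names are illustrative, not Unsloth's actual API):

```python
# Both candidates train in bf16 LoRA; 4-bit QLoRA is discouraged for Qwen 3.5
# and not planned for Gemma 4, so refuse it outright for either model.
NO_QLORA = {"qwen3.5-4b", "gemma4-e4b"}

def load_kwargs(model_name: str, four_bit: bool = False) -> dict:
    """Build model-load kwargs, rejecting QLoRA for flagged models."""
    if four_bit and model_name in NO_QLORA:
        raise ValueError(f"QLoRA (4-bit) is discouraged for {model_name}; use bf16 LoRA")
    return {"load_in_4bit": four_bit, "dtype": None if four_bit else "bfloat16"}
```

Failing fast here keeps a mistyped config from silently burning an A100 run on a quantization mode the spec rules out.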

## 2. Prerequisites

Before the epic branch is created:
@@ -87,7 +115,7 @@ Before the epic branch is created:
| 3 | Build Colab training notebook (Unsloth + LoRA) | Single "Run All" notebook, parameterized config, 3-phase training, GGUF export. Replace `training/finetune.py` (MLX) with Unsloth wrapper. | #1 | `type:feature`, `area:ai` |
| 4 | Distill training pairs for summarizer and explainer | Run distillation for both existing tasks (~1,500 pairs each via Claude API). Prepare datasets. | #1 | `type:feature`, `area:ai` |
| 5 | Distill training pairs for SHAP narrator | Generate ~1,500 SHAP narrator pairs from student data + SHAP values. Requires SHAP data in DB. | #2 | `type:feature`, `area:ai` |
| 6 | Train and evaluate 4B + 9B models | Run Colab notebook for both model sizes. Evaluate via ship criteria. Compare metrics, pick winner. | #3, #4, #5 | `type:spike`, `area:ai` |
| 6 | Train and evaluate Qwen 3.5-4B + Gemma 4 E4B | Run Colab notebook for both models. Evaluate via ship criteria. Compare metrics, pick winner. | #3, #4, #5 | `type:spike`, `area:ai` |
| 7 | Export models and wire into dashboard | GGUF export, Ollama registration, wire `model-client.ts` into consumer routes, update `enrich_with_llm` model string. | #6 | `type:feature`, `area:ai`, `area:frontend` |
| 8 | Update documentation and feasibility report | Update feasibility report with actual results, update README and CLAUDE.md. | #6 | `type:documentation` |

@@ -110,7 +138,10 @@ Issues #2, #3, and #4 can proceed concurrently after #1. Issue #5 waits only on
Cell 1: Configuration (ONLY cell the user edits)
-------------------------------------------------
SCHOOL = "bishop-state"
MODEL_SIZES = ["4b", "9b"]
MODELS = [
{"name": "qwen3.5-4b", "hf_id": "Qwen/Qwen3.5-4B"},
{"name": "gemma4-e4b", "hf_id": "google/gemma-4-e4b-it"},
]
REPO_URL = "https://github.com/codebenders/datathon.git"
REPO_BRANCH = "fine-tuning/student-explainability"
HF_TOKEN = "" # or userdata.get('HF_TOKEN')
@@ -123,9 +154,9 @@ Cell 2+: Fully autonomous
- GPU detection + validation (assert A100/T4/L4)
- pip install unsloth, trl, peft
- Clone repo, load schools/{SCHOOL}/config.yaml
- For each model size:
- For each model in MODELS:
- Phase 1: Domain adaptation
- Load base Qwen model via Unsloth (4-bit NF4)
- Load base model via Unsloth (bf16 LoRA — no QLoRA for Qwen 3.5)
- Train on training_data/{school}/domain.jsonl
- LoRA rank 16, all modules, 1 epoch, lr 2e-4, effective batch 32
- Save merged checkpoint
@@ -139,13 +170,13 @@ Cell 2+: Fully autonomous
- Phase 3: GGUF export
- Quantize each task adapter to q4_k_m
- Upload to Google Drive (or HF Hub if HF_TOKEN provided)
- Print comparison table: 4B vs 9B metrics across all tasks
- Print comparison table: Qwen 3.5-4B vs Gemma 4 E4B metrics across all tasks
- Recommend winner based on ship criteria
```
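The Cell 2+ outline reduces to iterating `MODELS` through the three phases in order. A schematic of the plan it executes (the phase names are labels for the steps above, not real functions in the repo):

```python
MODELS = [
    {"name": "qwen3.5-4b", "hf_id": "Qwen/Qwen3.5-4B"},
    {"name": "gemma4-e4b", "hf_id": "google/gemma-4-e4b-it"},
]
PHASES = ["domain_adaptation", "task_training", "gguf_export"]

def build_plan(models: list[dict]) -> list[tuple[str, str]]:
    """Expand the model list into the ordered (model, phase) steps Cell 2+ runs."""
    return [(m["name"], phase) for m in models for phase in PHASES]
```

With both candidates configured, the plan is six steps: all three phases for Qwen 3.5-4B, then all three for Gemma 4 E4B, followed by the comparison table.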

### Training Hyperparameters

Based on D4BL's proven configurations:
Based on D4BL's proven configurations. Both models use bf16 LoRA (not QLoRA): Unsloth discourages 4-bit QLoRA on Qwen 3.5 due to quantization artifacts, and bf16 LoRA is likewise the recommended setup for Gemma 4.

| Parameter | Phase 1 (Domain) | Phase 2 (Tasks) |
|-----------|------------------|-----------------|
@@ -158,6 +189,7 @@ Based on D4BL's proven configurations:
| Max sequence length | 4096 | 4096-8192 |
| Optimizer | AdamW 8-bit | AdamW 8-bit |
| Precision | bf16 (A100) | bf16 (A100) |
| Quantization during training | None (bf16 LoRA) | None (bf16 LoRA) |
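The Phase 1 column, combined with the notebook outline above (rank 16, all modules, 1 epoch, lr 2e-4, effective batch 32), pins down a concrete configuration. As plain dicts (key names are TRL/PEFT-style assumptions, not verbatim library arguments):

```python
# Phase 1 (domain adaptation) hyperparameters, as stated in the spec.
LORA_PHASE1 = {
    "r": 16,                     # LoRA rank
    "target_modules": "all",     # all linear modules
}
TRAIN_PHASE1 = {
    "num_train_epochs": 1,
    "learning_rate": 2e-4,
    "effective_batch_size": 32,  # e.g. per-device batch x gradient accumulation
    "max_seq_length": 4096,
    "optim": "adamw_8bit",
    "bf16": True,                # bf16 LoRA: no 4-bit quantization during training
}
```

Keeping these in one dict per phase (rather than scattered literals) is what lets the notebook stay "edit Cell 1 only."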

### What the Notebook Does NOT Do

@@ -253,12 +285,12 @@ This is the highest-value task — it transforms per-student SHAP attribution da
### Ollama Model Naming

```
bishop-state-narrator:{size} # SHAP narrator
bishop-state-summarizer:{size} # Query summary
bishop-state-explainer:{size} # Course pairing
bishop-state-narrator:{model} # SHAP narrator
bishop-state-summarizer:{model} # Query summary
bishop-state-explainer:{model} # Course pairing
```

Where `{size}` is `4b` or `9b` based on evaluation results.
Where `{model}` is the winning model identifier (e.g., `qwen3.5-4b` or `gemma4-e4b`) based on evaluation results.
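The scheme is mechanical enough to generate rather than hand-type (a sketch; `ollama_model_name` is a hypothetical helper, not existing project code):

```python
TASKS = ("narrator", "summarizer", "explainer")

def ollama_model_name(school: str, task: str, model_tag: str) -> str:
    """Compose an Ollama model name like bishop-state-narrator:qwen3.5-4b."""
    if task not in TASKS:
        raise ValueError(f"unknown task: {task!r}")
    return f"{school}-{task}:{model_tag}"
```

For example, `ollama_model_name("bishop-state", "narrator", "qwen3.5-4b")` yields `bishop-state-narrator:qwen3.5-4b`; validating the task name catches typos before they become silently missing Ollama models.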

### SHAP Narrator Integration Point

@@ -268,16 +300,16 @@ Where `{size}` is `4b` or `9b` based on evaluation results.
# Before (OpenAI)
python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model gpt-4o-mini

# After (fine-tuned)
python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model ollama/bishop-state-narrator:4b
# After (fine-tuned, winner TBD after evaluation)
python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model ollama/bishop-state-narrator
```

### Environment Variables

```env
MODEL_BACKEND=ollama # or "openai" (fallback)
OLLAMA_BASE_URL=http://localhost:11434
MODEL_SIZE=4b # set after evaluation picks winner
MODEL_TAG=qwen3.5-4b # or gemma4-e4b, set after evaluation picks winner
SCHOOL_CODE=bishop-state
```

@@ -290,16 +322,19 @@ The operator sets `MODEL_BACKEND` to either `ollama` or `openai`. There is no au
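Resolving these variables into a model string could look like the following (a sketch under the spec's env-var names; the real `model-client.ts` wiring may differ):

```python
import os

def resolve_llm_model(task: str) -> str:
    """Map MODEL_BACKEND / SCHOOL_CODE / MODEL_TAG onto a model string for a task.

    The operator flips backends explicitly via MODEL_BACKEND; this function
    performs no health checks or fallback logic on its own.
    """
    if os.environ.get("MODEL_BACKEND", "openai") == "ollama":
        school = os.environ.get("SCHOOL_CODE", "bishop-state")
        tag = os.environ.get("MODEL_TAG", "qwen3.5-4b")
        return f"ollama/{school}-{task}:{tag}"
    return "gpt-4o-mini"  # OpenAI fallback backend
```

Centralizing the lookup means flipping one environment variable switches all three tasks at once, matching the operator-controlled design.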
| Item | Cost |
|------|------|
| Claude API distillation (~4,500 pairs across 3 tasks) | $5-10 |
| Colab A100 compute (~4 hours for 2 model sizes) | $8-16 |
| **Total per training run** | **$13-26** |
| Iteration runs (subsequent) | $8-16 each |
| Colab A100 compute (~4-5 hours for 2 models, bf16 LoRA) | $8-20 |
| **Total per training run** | **$13-30** |
| Iteration runs (subsequent) | $8-20 each |

Note: bf16 LoRA (required for Qwen 3.5, recommended for Gemma 4) uses more VRAM than QLoRA but fits comfortably on A100 40GB. Training time may be slightly longer than D4BL's QLoRA runs.

## 8. Success Criteria

The epic is complete when:

1. All three tasks pass ship criteria on the winning model size
1. All three tasks pass ship criteria on the winning model (Qwen 3.5-4B or Gemma 4 E4B)
2. `MODEL_BACKEND=ollama` serves all three tasks in the dashboard without OpenAI
3. SHAP narrator produces grounded narratives that cite specific feature attributions
4. Feasibility report is updated with actual metrics and model selection rationale
4. Feasibility report is updated with actual metrics, model comparison, and selection rationale
5. Colab notebook is documented and reproducible (clone + Run All)
6. Model selection decision is documented with head-to-head metrics for both candidates
12 changes: 7 additions & 5 deletions tests/training/test_distill.py
@@ -7,8 +7,7 @@
from training.distill import (
validate_json,
call_teacher,
generate_explainer_pairs,
generate_summarizer_pairs,
generate_pairs,
)


@@ -80,10 +79,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_cours
})

with patch("training.distill.call_teacher", return_value=mock_response):
pairs = generate_explainer_pairs(
pairs = generate_pairs(
config=sample_school_config,
seed_data=[sample_course_pairing_data],
count=2,
task="explainer",
)

assert len(pairs) == 2
@@ -92,10 +92,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_cours

def test_skips_invalid_responses(self, sample_school_config, sample_course_pairing_data):
with patch("training.distill.call_teacher", return_value="not json"):
pairs = generate_explainer_pairs(
pairs = generate_pairs(
config=sample_school_config,
seed_data=[sample_course_pairing_data],
count=3,
task="explainer",
)

assert len(pairs) == 0
@@ -112,10 +113,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_query
})

with patch("training.distill.call_teacher", return_value=mock_response):
pairs = generate_summarizer_pairs(
pairs = generate_pairs(
config=sample_school_config,
seed_data=[sample_query_result_data],
count=2,
task="summarizer",
)

assert len(pairs) == 2
2 changes: 1 addition & 1 deletion tests/training/test_seed.py
@@ -35,7 +35,7 @@ def test_loads_valid_yaml(self, tmp_path):
def test_returns_empty_on_missing_file(self, tmp_path):
with patch("training.seed.get_school_dir", return_value=tmp_path):
result = load_seed_queries("test-school")
assert result == {"explainer": [], "summarizer": []}
assert result == {"narrator": [], "explainer": [], "summarizer": []}


class TestGenerateSyntheticCoursePairings:
61 changes: 33 additions & 28 deletions training/distill.py
@@ -21,15 +21,18 @@
from training.config import get_training_data_dir, load_school_config, write_jsonl
from training.prompts import (
EXPLAINER_STUDENT_SYSTEM,
NARRATOR_STUDENT_SYSTEM,
SUMMARIZER_STUDENT_SYSTEM,
build_explainer_prompt,
build_narrator_prompt,
build_summarizer_prompt,
build_system_prompt,
)
from training.seed import (
format_as_chatml,
generate_synthetic_course_pairings,
generate_synthetic_query_results,
generate_synthetic_student_profiles,
)

# Cost tracking
@@ -136,6 +139,11 @@ def call_teacher(system: str, user: str, backend: str, model: str) -> str:
_FLUSH_INTERVAL = 25

_TASK_CONFIG = {
"narrator": {
"prompt_builder": build_narrator_prompt,
"student_system": NARRATOR_STUDENT_SYSTEM,
"format_user": lambda config, data: json.dumps(data, ensure_ascii=False, default=str),
},
"explainer": {
"prompt_builder": build_explainer_prompt,
"student_system": EXPLAINER_STUDENT_SYSTEM,
@@ -163,7 +171,7 @@ def generate_pairs(
config: Parsed school config.
seed_data: List of seed data dicts.
count: Number of pairs to generate.
task: "explainer" or "summarizer".
task: "narrator", "explainer", or "summarizer".
outfile: If provided, pairs are written incrementally.
system_prompt: Pre-built system prompt (avoids recomputation).
"""
@@ -213,24 +221,6 @@ def generate_pairs(
return pairs


def generate_explainer_pairs(
config: dict[str, Any], seed_data: list[dict[str, Any]],
count: int, outfile: Path | None = None,
system_prompt: str | None = None,
) -> list[dict]:
"""Generate explainer training pairs via teacher model distillation."""
return generate_pairs(config, seed_data, count, "explainer", outfile, system_prompt)


def generate_summarizer_pairs(
config: dict[str, Any], seed_data: list[dict[str, Any]],
count: int, outfile: Path | None = None,
system_prompt: str | None = None,
) -> list[dict]:
"""Generate summarizer training pairs via teacher model distillation."""
return generate_pairs(config, seed_data, count, "summarizer", outfile, system_prompt)


def main(school: str, local: bool = False) -> None:
"""Run distillation for a school."""
config = load_school_config(school)
@@ -245,28 +235,43 @@ def main(school: str, local: bool = False) -> None:
data_dir = get_training_data_dir(school)
pairs_dir = data_dir / "pairs"

synthetic_pairings = generate_synthetic_course_pairings(config, count=pairs_per_task)
synthetic_results = generate_synthetic_query_results(config, count=pairs_per_task)

system_prompt = build_system_prompt(config)

all_counts: dict[str, int] = {}

# Narrator
print(f"\n{'='*60}\nNARRATOR — generating {pairs_per_task} pairs\n{'='*60}")
synthetic_profiles = generate_synthetic_student_profiles(config, count=pairs_per_task)
narrator_pairs = generate_pairs(
config=config, seed_data=synthetic_profiles,
count=pairs_per_task, task="narrator", outfile=pairs_dir / "narrator.jsonl",
system_prompt=system_prompt,
)
all_counts["narrator"] = len(narrator_pairs)

# Explainer
print(f"\n{'='*60}\nEXPLAINER — generating {pairs_per_task} pairs\n{'='*60}")
explainer_pairs = generate_explainer_pairs(
synthetic_pairings = generate_synthetic_course_pairings(config, count=pairs_per_task)
explainer_pairs = generate_pairs(
config=config, seed_data=synthetic_pairings,
count=pairs_per_task, outfile=pairs_dir / "explainer.jsonl",
count=pairs_per_task, task="explainer", outfile=pairs_dir / "explainer.jsonl",
system_prompt=system_prompt,
)
all_counts["explainer"] = len(explainer_pairs)

# Summarizer
print(f"\n{'='*60}\nSUMMARIZER — generating {pairs_per_task} pairs\n{'='*60}")
summarizer_pairs = generate_summarizer_pairs(
synthetic_results = generate_synthetic_query_results(config, count=pairs_per_task)
summarizer_pairs = generate_pairs(
config=config, seed_data=synthetic_results,
count=pairs_per_task, outfile=pairs_dir / "summarizer.jsonl",
count=pairs_per_task, task="summarizer", outfile=pairs_dir / "summarizer.jsonl",
system_prompt=system_prompt,
)
all_counts["summarizer"] = len(summarizer_pairs)

print(f"\n{'='*60}\nDISTILLATION COMPLETE\n{'='*60}")
print(f" Explainer: {len(explainer_pairs)} pairs")
print(f" Summarizer: {len(summarizer_pairs)} pairs")
for task_name, count in all_counts.items():
print(f" {task_name.capitalize()}: {count} pairs")
_print_cost_summary()

