diff --git a/.gitignore b/.gitignore index dc908e6..da3db69 100644 --- a/.gitignore +++ b/.gitignore @@ -189,3 +189,6 @@ data/test_uploads/ # Training pipeline artifacts training_data/ + +# Superpowers runtime state +.superpowers/ diff --git a/docs/superpowers/specs/2026-04-02-fine-tuning-student-explainability-design.md b/docs/superpowers/specs/2026-04-02-fine-tuning-student-explainability-design.md index dd81ea1..c54d09d 100644 --- a/docs/superpowers/specs/2026-04-02-fine-tuning-student-explainability-design.md +++ b/docs/superpowers/specs/2026-04-02-fine-tuning-student-explainability-design.md @@ -1,6 +1,6 @@ # Design Spec: Fine-Tuning for Student Explainability -**Date:** 2026-04-02 +**Date:** 2026-04-02 (updated 2026-04-03) **Epic label:** `fine-tuning: student-explainability` **Epic branch:** `fine-tuning/student-explainability` **Status:** Draft @@ -9,7 +9,7 @@ ## 1. Goal -Fine-tune a small language model (Qwen 3.5) on Bishop State domain data to replace GPT-4o-mini for three inference tasks in the dashboard. The primary value is improved explainability: advisors get SHAP-grounded, institution-aware narratives instead of templated rule-engine output. Secondary benefits include FERPA compliance (all inference on-premises), offline deployment, and institutional scalability. +Fine-tune a small language model on Bishop State domain data to replace GPT-4o-mini for three inference tasks in the dashboard. Two candidate models will be evaluated head-to-head: **Qwen 3.5-4B** (proven by D4BL) and **Gemma 4 E4B** (native structured JSON output). The primary value is improved explainability: advisors get SHAP-grounded, institution-aware narratives instead of templated rule-engine output. Secondary benefits include FERPA compliance (all inference on-premises), offline deployment, and institutional scalability. 
### Tasks to Fine-Tune @@ -24,6 +24,34 @@ Fine-tune a small language model (Qwen 3.5) on Bishop State domain data to repla - Query Analyzer (NL → SQL) — high risk, deferred to future epic - Model serving infrastructure (RunPod, dedicated GPU hosting) — use local Ollama for now +### Model Selection: Qwen 3.5-4B vs Gemma 4 E4B + +Two candidate models will be trained and evaluated. The winner is selected based on ship criteria metrics. + +| | **Qwen 3.5-4B** | **Gemma 4 E4B** | +|---|---|---| +| Effective params | 4B | 4.5B (8B with embeddings) | +| GGUF size (q4_k_m) | ~2.7 GB | ~5 GB | +| Context window | 32K | 128K | +| Native JSON output | No | Yes — built-in structured function calling | +| Ollama support | Text-only (vision mmproj broken) | Full support | +| Unsloth LoRA | Yes, but QLoRA discouraged — use bf16 LoRA | Yes, bf16 LoRA recommended | +| VRAM for training (bf16) | 10 GB | ~10 GB | +| License | Apache 2.0 | Apache 2.0 | +| D4BL proven | Yes (5 experiments, 98.77% schema validity) | No (new model) | + +**Why these two:** +- Qwen 3.5-4B is the known quantity — D4BL proved the full pipeline (distill → train → GGUF → Ollama) with 5 experiments and achieved 98.77% schema validity on structured output tasks. +- Gemma 4 E4B has native structured JSON output before fine-tuning, 128K context (headroom for SHAP-heavy narrator prompts), and full Ollama GGUF support without the mmproj workaround. + +**Why not others:** +- Qwen 3.5-9B: 22 GB VRAM for training, larger GGUF (~5.5 GB), marginal quality gain over 4B for our task complexity. +- Qwen 3-4B: #1 fine-tuning benchmark, but lacks 3.5's architecture improvements. +- Llama 3.2-3B: Best tunability gain, but Meta's community license (700M MAU limit) is restrictive for educational institutions. +- Phi-4-mini: Strong on math, but less proven for structured output and limited Unsloth support. + +**Training note:** Unsloth explicitly discourages QLoRA (4-bit) on Qwen 3.5 due to quantization artifacts. 
Both models will use bf16 LoRA on A100 (40 GB VRAM is sufficient for either). + ## 2. Prerequisites Before the epic branch is created: @@ -87,7 +115,7 @@ Before the epic branch is created: | 3 | Build Colab training notebook (Unsloth + LoRA) | Single "Run All" notebook, parameterized config, 3-phase training, GGUF export. Replace `training/finetune.py` (MLX) with Unsloth wrapper. | #1 | `type:feature`, `area:ai` | | 4 | Distill training pairs for summarizer and explainer | Run distillation for both existing tasks (~1,500 pairs each via Claude API). Prepare datasets. | #1 | `type:feature`, `area:ai` | | 5 | Distill training pairs for SHAP narrator | Generate ~1,500 SHAP narrator pairs from student data + SHAP values. Requires SHAP data in DB. | #2 | `type:feature`, `area:ai` | -| 6 | Train and evaluate 4B + 9B models | Run Colab notebook for both model sizes. Evaluate via ship criteria. Compare metrics, pick winner. | #3, #4, #5 | `type:spike`, `area:ai` | +| 6 | Train and evaluate Qwen 3.5-4B + Gemma 4 E4B | Run Colab notebook for both models. Evaluate via ship criteria. Compare metrics, pick winner. | #3, #4, #5 | `type:spike`, `area:ai` | | 7 | Export models and wire into dashboard | GGUF export, Ollama registration, wire `model-client.ts` into consumer routes, update `enrich_with_llm` model string. | #6 | `type:feature`, `area:ai`, `area:frontend` | | 8 | Update documentation and feasibility report | Update feasibility report with actual results, update README and CLAUDE.md. | #6 | `type:documentation` | @@ -110,7 +138,10 @@ Issues #2, #3, and #4 can proceed concurrently after #1. 
Issue #5 waits only on Cell 1: Configuration (ONLY cell the user edits) ------------------------------------------------- SCHOOL = "bishop-state" -MODEL_SIZES = ["4b", "9b"] +MODELS = [ + {"name": "qwen3.5-4b", "hf_id": "Qwen/Qwen3.5-4B"}, + {"name": "gemma4-e4b", "hf_id": "google/gemma-4-e4b-it"}, +] REPO_URL = "https://github.com/codebenders/datathon.git" REPO_BRANCH = "fine-tuning/student-explainability" HF_TOKEN = "" # or userdata.get('HF_TOKEN') @@ -123,9 +154,9 @@ Cell 2+: Fully autonomous - GPU detection + validation (assert A100/T4/L4) - pip install unsloth, trl, peft - Clone repo, load schools/{SCHOOL}/config.yaml -- For each model size: +- For each model in MODELS: - Phase 1: Domain adaptation - - Load base Qwen model via Unsloth (4-bit NF4) + - Load base model via Unsloth (bf16 LoRA — no QLoRA for Qwen 3.5) - Train on training_data/{school}/domain.jsonl - LoRA rank 16, all modules, 1 epoch, lr 2e-4, effective batch 32 - Save merged checkpoint @@ -139,13 +170,13 @@ Cell 2+: Fully autonomous - Phase 3: GGUF export - Quantize each task adapter to q4_k_m - Upload to Google Drive (or HF Hub if HF_TOKEN provided) -- Print comparison table: 4B vs 9B metrics across all tasks +- Print comparison table: Qwen 3.5-4B vs Gemma 4 E4B metrics across all tasks - Recommend winner based on ship criteria ``` ### Training Hyperparameters -Based on D4BL's proven configurations: +Based on D4BL's proven configurations. Both models use bf16 LoRA (not QLoRA) — Unsloth discourages 4-bit QLoRA on Qwen 3.5 due to quantization artifacts, and Gemma 4 also recommends bf16 LoRA. 
| Parameter | Phase 1 (Domain) | Phase 2 (Tasks) | |-----------|------------------|-----------------| @@ -158,6 +189,7 @@ Based on D4BL's proven configurations: | Max sequence length | 4096 | 4096-8192 | | Optimizer | AdamW 8-bit | AdamW 8-bit | | Precision | bf16 (A100) | bf16 (A100) | +| Quantization during training | None (bf16 LoRA) | None (bf16 LoRA) | ### What the Notebook Does NOT Do @@ -253,12 +285,12 @@ This is the highest-value task — it transforms per-student SHAP attribution da ### Ollama Model Naming ``` -bishop-state-narrator:{size} # SHAP narrator -bishop-state-summarizer:{size} # Query summary -bishop-state-explainer:{size} # Course pairing +bishop-state-narrator:{model} # SHAP narrator +bishop-state-summarizer:{model} # Query summary +bishop-state-explainer:{model} # Course pairing ``` -Where `{size}` is `4b` or `9b` based on evaluation results. +Where `{model}` is the winning model identifier (e.g., `qwen3.5-4b` or `gemma4-e4b`) based on evaluation results. ### SHAP Narrator Integration Point @@ -268,8 +300,8 @@ Where `{size}` is `4b` or `9b` based on evaluation results. # Before (OpenAI) python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model gpt-4o-mini -# After (fine-tuned) -python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model ollama/bishop-state-narrator:4b +# After (fine-tuned, winner TBD after evaluation) +python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model ollama/bishop-state-narrator ``` ### Environment Variables @@ -277,7 +309,7 @@ python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model ollam ```env MODEL_BACKEND=ollama # or "openai" (fallback) OLLAMA_BASE_URL=http://localhost:11434 -MODEL_SIZE=4b # set after evaluation picks winner +MODEL_TAG=qwen3.5-4b # or gemma4-e4b, set after evaluation picks winner SCHOOL_CODE=bishop-state ``` @@ -290,16 +322,19 @@ The operator sets `MODEL_BACKEND` to either `ollama` or `openai`. 
There is no au | Item | Cost | |------|------| | Claude API distillation (~4,500 pairs across 3 tasks) | $5-10 | -| Colab A100 compute (~4 hours for 2 model sizes) | $8-16 | -| **Total per training run** | **$13-26** | -| Iteration runs (subsequent) | $8-16 each | +| Colab A100 compute (~4-5 hours for 2 models, bf16 LoRA) | $8-20 | +| **Total per training run** | **$13-30** | +| Iteration runs (subsequent) | $8-20 each | + +Note: bf16 LoRA (required for Qwen 3.5, recommended for Gemma 4) uses more VRAM than QLoRA but fits comfortably on A100 40GB. Training time may be slightly longer than D4BL's QLoRA runs. ## 8. Success Criteria The epic is complete when: -1. All three tasks pass ship criteria on the winning model size +1. All three tasks pass ship criteria on the winning model (Qwen 3.5-4B or Gemma 4 E4B) 2. `MODEL_BACKEND=ollama` serves all three tasks in the dashboard without OpenAI 3. SHAP narrator produces grounded narratives that cite specific feature attributions -4. Feasibility report is updated with actual metrics and model selection rationale +4. Feasibility report is updated with actual metrics, model comparison, and selection rationale 5. Colab notebook is documented and reproducible (clone + Run All) +6. 
Model selection decision is documented with head-to-head metrics for both candidates diff --git a/tests/training/test_distill.py b/tests/training/test_distill.py index ccaba2f..0dcfeda 100644 --- a/tests/training/test_distill.py +++ b/tests/training/test_distill.py @@ -7,8 +7,7 @@ from training.distill import ( validate_json, call_teacher, - generate_explainer_pairs, - generate_summarizer_pairs, + generate_pairs, ) @@ -80,10 +79,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_cours }) with patch("training.distill.call_teacher", return_value=mock_response): - pairs = generate_explainer_pairs( + pairs = generate_pairs( config=sample_school_config, seed_data=[sample_course_pairing_data], count=2, + task="explainer", ) assert len(pairs) == 2 @@ -92,10 +92,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_cours def test_skips_invalid_responses(self, sample_school_config, sample_course_pairing_data): with patch("training.distill.call_teacher", return_value="not json"): - pairs = generate_explainer_pairs( + pairs = generate_pairs( config=sample_school_config, seed_data=[sample_course_pairing_data], count=3, + task="explainer", ) assert len(pairs) == 0 @@ -112,10 +113,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_query }) with patch("training.distill.call_teacher", return_value=mock_response): - pairs = generate_summarizer_pairs( + pairs = generate_pairs( config=sample_school_config, seed_data=[sample_query_result_data], count=2, + task="summarizer", ) assert len(pairs) == 2 diff --git a/tests/training/test_seed.py b/tests/training/test_seed.py index 77b0a3e..f72d240 100644 --- a/tests/training/test_seed.py +++ b/tests/training/test_seed.py @@ -35,7 +35,7 @@ def test_loads_valid_yaml(self, tmp_path): def test_returns_empty_on_missing_file(self, tmp_path): with patch("training.seed.get_school_dir", return_value=tmp_path): result = load_seed_queries("test-school") - assert 
result == {"explainer": [], "summarizer": []} + assert result == {"narrator": [], "explainer": [], "summarizer": []} class TestGenerateSyntheticCoursePairings: diff --git a/training/distill.py b/training/distill.py index bd8f80c..97ad679 100644 --- a/training/distill.py +++ b/training/distill.py @@ -21,8 +21,10 @@ from training.config import get_training_data_dir, load_school_config, write_jsonl from training.prompts import ( EXPLAINER_STUDENT_SYSTEM, + NARRATOR_STUDENT_SYSTEM, SUMMARIZER_STUDENT_SYSTEM, build_explainer_prompt, + build_narrator_prompt, build_summarizer_prompt, build_system_prompt, ) @@ -30,6 +32,7 @@ format_as_chatml, generate_synthetic_course_pairings, generate_synthetic_query_results, + generate_synthetic_student_profiles, ) # Cost tracking @@ -136,6 +139,11 @@ def call_teacher(system: str, user: str, backend: str, model: str) -> str: _FLUSH_INTERVAL = 25 _TASK_CONFIG = { + "narrator": { + "prompt_builder": build_narrator_prompt, + "student_system": NARRATOR_STUDENT_SYSTEM, + "format_user": lambda config, data: json.dumps(data, ensure_ascii=False, default=str), + }, "explainer": { "prompt_builder": build_explainer_prompt, "student_system": EXPLAINER_STUDENT_SYSTEM, @@ -163,7 +171,7 @@ def generate_pairs( config: Parsed school config. seed_data: List of seed data dicts. count: Number of pairs to generate. - task: "explainer" or "summarizer". + task: "narrator", "explainer", or "summarizer". outfile: If provided, pairs are written incrementally. system_prompt: Pre-built system prompt (avoids recomputation). 
""" @@ -213,24 +221,6 @@ def generate_pairs( return pairs -def generate_explainer_pairs( - config: dict[str, Any], seed_data: list[dict[str, Any]], - count: int, outfile: Path | None = None, - system_prompt: str | None = None, -) -> list[dict]: - """Generate explainer training pairs via teacher model distillation.""" - return generate_pairs(config, seed_data, count, "explainer", outfile, system_prompt) - - -def generate_summarizer_pairs( - config: dict[str, Any], seed_data: list[dict[str, Any]], - count: int, outfile: Path | None = None, - system_prompt: str | None = None, -) -> list[dict]: - """Generate summarizer training pairs via teacher model distillation.""" - return generate_pairs(config, seed_data, count, "summarizer", outfile, system_prompt) - - def main(school: str, local: bool = False) -> None: """Run distillation for a school.""" config = load_school_config(school) @@ -245,28 +235,43 @@ def main(school: str, local: bool = False) -> None: data_dir = get_training_data_dir(school) pairs_dir = data_dir / "pairs" - synthetic_pairings = generate_synthetic_course_pairings(config, count=pairs_per_task) - synthetic_results = generate_synthetic_query_results(config, count=pairs_per_task) - system_prompt = build_system_prompt(config) + all_counts: dict[str, int] = {} + + # Narrator + print(f"\n{'='*60}\nNARRATOR — generating {pairs_per_task} pairs\n{'='*60}") + synthetic_profiles = generate_synthetic_student_profiles(config, count=pairs_per_task) + narrator_pairs = generate_pairs( + config=config, seed_data=synthetic_profiles, + count=pairs_per_task, task="narrator", outfile=pairs_dir / "narrator.jsonl", + system_prompt=system_prompt, + ) + all_counts["narrator"] = len(narrator_pairs) + + # Explainer print(f"\n{'='*60}\nEXPLAINER — generating {pairs_per_task} pairs\n{'='*60}") - explainer_pairs = generate_explainer_pairs( + synthetic_pairings = generate_synthetic_course_pairings(config, count=pairs_per_task) + explainer_pairs = generate_pairs( config=config, 
seed_data=synthetic_pairings, - count=pairs_per_task, outfile=pairs_dir / "explainer.jsonl", + count=pairs_per_task, task="explainer", outfile=pairs_dir / "explainer.jsonl", system_prompt=system_prompt, ) + all_counts["explainer"] = len(explainer_pairs) + # Summarizer print(f"\n{'='*60}\nSUMMARIZER — generating {pairs_per_task} pairs\n{'='*60}") - summarizer_pairs = generate_summarizer_pairs( + synthetic_results = generate_synthetic_query_results(config, count=pairs_per_task) + summarizer_pairs = generate_pairs( config=config, seed_data=synthetic_results, - count=pairs_per_task, outfile=pairs_dir / "summarizer.jsonl", + count=pairs_per_task, task="summarizer", outfile=pairs_dir / "summarizer.jsonl", system_prompt=system_prompt, ) + all_counts["summarizer"] = len(summarizer_pairs) print(f"\n{'='*60}\nDISTILLATION COMPLETE\n{'='*60}") - print(f" Explainer: {len(explainer_pairs)} pairs") - print(f" Summarizer: {len(summarizer_pairs)} pairs") + for task_name, count in all_counts.items(): + print(f" {task_name.capitalize()}: {count} pairs") _print_cost_summary() diff --git a/training/eval.py b/training/eval.py index 23bffb8..5343060 100644 --- a/training/eval.py +++ b/training/eval.py @@ -17,33 +17,27 @@ from typing import Any from training.config import get_message_content, get_training_data_dir, read_jsonl +from training.prompts import EXPLAINER_SCHEMA, NARRATOR_SCHEMA, SUMMARIZER_SCHEMA # --------------------------------------------------------------------------- -# Required keys per task +# Required keys per task — derived from schema definitions in prompts.py # --------------------------------------------------------------------------- -_EXPLAINER_REQUIRED_KEYS: set[str] = { - "explanation", - "structural_factors", - "student_impact", - "advisor_recommendation", - "data_limitations", - "related_intervention", -} - -_SUMMARIZER_REQUIRED_KEYS: set[str] = { - "summary", - "key_insights", - "context", - "action_items", - "caveats", -} +_EXPLAINER_REQUIRED_KEYS: 
set[str] = set(EXPLAINER_SCHEMA.keys()) +_NARRATOR_REQUIRED_KEYS: set[str] = set(NARRATOR_SCHEMA.keys()) +_SUMMARIZER_REQUIRED_KEYS: set[str] = set(SUMMARIZER_SCHEMA.keys()) # --------------------------------------------------------------------------- # Ship criteria — minimum thresholds per task # --------------------------------------------------------------------------- SHIP_CRITERIA: dict[str, dict[str, float]] = { + "narrator": { + "json_validity": 0.95, + "schema_adherence": 0.90, + "shap_grounding": 0.80, + "caveat_inclusion": 0.85, + }, "explainer": { "json_validity": 0.95, "schema_adherence": 0.90, @@ -120,9 +114,11 @@ def check_schema_adherence(outputs: list[str], task: str) -> float: """Fraction of valid JSON outputs that contain all required keys.""" if not outputs: return 0.0 - required = ( - _EXPLAINER_REQUIRED_KEYS if task == "explainer" else _SUMMARIZER_REQUIRED_KEYS - ) + required = { + "narrator": _NARRATOR_REQUIRED_KEYS, + "explainer": _EXPLAINER_REQUIRED_KEYS, + "summarizer": _SUMMARIZER_REQUIRED_KEYS, + }.get(task, _SUMMARIZER_REQUIRED_KEYS) passing = 0 total = 0 for text in outputs: @@ -147,7 +143,7 @@ def check_caveat_inclusion(outputs: list[str], task: str) -> float: """ if not outputs: return 0.0 - caveat_key = "data_limitations" if task == "explainer" else "caveats" + caveat_key = "caveats" if task == "summarizer" else "data_limitations" passing = 0 total = 0 for text in outputs: @@ -169,6 +165,44 @@ def check_caveat_inclusion(outputs: list[str], task: str) -> float: return passing / total if total else 0.0 +def check_shap_grounding(outputs: list[str], inputs: list[dict[str, Any]], min_features: int = 2) -> float: + """Fraction of narrator outputs that mention at least `min_features` of the top-3 SHAP features. + + Extracts feature names from the input's SHAP data and checks whether the + narrative text references them (case-insensitive, underscore-tolerant). 
+ """ + if not outputs: + return 0.0 + passing = 0 + total = 0 + for output_text, input_data in zip(outputs, inputs): + total += 1 + # Collect top SHAP feature names from all models in the input + shap_data = input_data.get("shap", {}) + top_features: list[str] = [] + for model_attrs in shap_data.values(): + for entry in model_attrs.get("top_positive", [])[:3]: + top_features.append(entry["feature"]) + for entry in model_attrs.get("top_negative", [])[:3]: + top_features.append(entry["feature"]) + top_features = list(dict.fromkeys(top_features))[:6] + + if not top_features: + passing += 1 # no SHAP data to ground against + continue + + # Check how many features appear in the output (case-insensitive, underscores → spaces) + output_lower = output_text.lower().replace("_", " ") + mentioned = sum( + 1 for f in top_features + if f.lower().replace("_", " ") in output_lower + ) + if mentioned >= min_features: + passing += 1 + + return passing / total if total else 0.0 + + def check_factual_grounding(outputs: list[str], inputs: list[dict[str, Any]]) -> float: """Fraction of outputs that contain numeric values referenced in their input. 
@@ -207,19 +241,24 @@ def check_ship_criteria(metrics: dict[str, float], task: str) -> ShipDecision: blocking_failures: list[CriterionFailure] = [] warnings: list[str] = [] + # Check all required criteria — missing metrics are blocking failures + for metric, threshold in criteria.items(): + value = metrics.get(metric) + if value is None: + blocking_failures.append( + CriterionFailure(metric=metric, threshold=threshold, actual=0.0) + ) + elif value < threshold: + blocking_failures.append( + CriterionFailure(metric=metric, threshold=threshold, actual=value) + ) + + # Check informational metrics (present in metrics but not in criteria) for metric, value in metrics.items(): - threshold = criteria.get(metric) - if threshold is not None: - if value < threshold: - blocking_failures.append( - CriterionFailure(metric=metric, threshold=threshold, actual=value) - ) - else: - # Informational metric — warn if very low - if value < 0.5: - warnings.append( - f"{metric} is low ({value:.3f}) — consider improving before deploying" - ) + if metric not in criteria and value < 0.5: + warnings.append( + f"{metric} is low ({value:.3f}) — consider improving before deploying" + ) if blocking_failures: decision = "no_ship" @@ -314,8 +353,11 @@ def run_eval(school: str, task: str) -> ShipDecision: "json_validity": check_json_validity(outputs), "schema_adherence": check_schema_adherence(outputs, task), "caveat_inclusion": check_caveat_inclusion(outputs, task), - "factual_grounding": check_factual_grounding(outputs, inputs), } + if task == "narrator": + metrics["shap_grounding"] = check_shap_grounding(outputs, inputs) + else: + metrics["factual_grounding"] = check_factual_grounding(outputs, inputs) print(f"\n[eval] Results for {school}/{task}:") for k, v in metrics.items(): @@ -337,13 +379,13 @@ def main() -> None: parser.add_argument("--school", required=True, help="School directory name (e.g. 
bishop-state)") parser.add_argument( "--task", - choices=["explainer", "summarizer"], + choices=["narrator", "explainer", "summarizer"], default=None, - help="Task to evaluate (default: both)", + help="Task to evaluate (default: all)", ) args = parser.parse_args() - tasks = [args.task] if args.task else ["explainer", "summarizer"] + tasks = [args.task] if args.task else ["narrator", "explainer", "summarizer"] results: dict[str, ShipDecision] = {} for task in tasks: print(f"\n{'='*60}\nEVAL: {task.upper()}\n{'='*60}") diff --git a/training/prepare.py b/training/prepare.py index 78e1eb9..c1c41af 100644 --- a/training/prepare.py +++ b/training/prepare.py @@ -133,7 +133,7 @@ def process_task(school: str, task: str) -> dict[str, int]: def main(school: str) -> None: """Run preparation for all tasks.""" - for task in ("explainer", "summarizer"): + for task in ("narrator", "explainer", "summarizer"): try: process_task(school, task) except FileNotFoundError as e: diff --git a/training/prompts.py b/training/prompts.py index 47e7716..3b0a786 100644 --- a/training/prompts.py +++ b/training/prompts.py @@ -26,6 +26,13 @@ "caveats": ["data limitations relevant to this specific query"], } +NARRATOR_SCHEMA = { + "narrative": "2-3 sentence explanation grounded in SHAP feature attribution", + "key_drivers": ["ranked list of factors with direction and magnitude"], + "recommended_actions": ["3-5 specific, actionable interventions"], + "data_limitations": ["caveats about the prediction"], +} + EXPLAINER_STUDENT_SYSTEM = ( "You are a student success analyst. Given course pairing data, generate a " "structured JSON explanation. Include: explanation, structural_factors, " @@ -33,6 +40,14 @@ "related_intervention. Respond with ONLY valid JSON." ) +NARRATOR_STUDENT_SYSTEM = ( + "You are a student success analyst. Given a student profile with ML prediction " + "attribution (SHAP values), generate a structured JSON explanation. Include: " + "narrative, key_drivers, recommended_actions, and data_limitations. 
" + "Ground your narrative in the SHAP values — cite specific features by name " + "and magnitude. Respond with ONLY valid JSON." +) + SUMMARIZER_STUDENT_SYSTEM = ( "You are a student success analyst. Given a query and its results, generate " "a structured JSON summary. Include: summary, key_insights, context, " @@ -195,6 +210,62 @@ def build_system_prompt(config: dict[str, Any]) -> str: return "\n\n".join(sections) +def build_narrator_prompt( + config: dict[str, Any], + student_data: dict[str, Any], +) -> str: + """Build the teacher prompt for generating a SHAP-grounded student narrative.""" + schema_str = json.dumps(NARRATOR_SCHEMA, indent=2) + profile = student_data.get("student_profile", {}) + shap_data = student_data.get("shap", {}) + risk_factors = student_data.get("risk_factors", []) + readiness_score = student_data.get("readiness_score", "N/A") + readiness_level = student_data.get("readiness_level", "unknown") + + # Format SHAP attribution section + shap_lines = [] + for model_name, attrs in shap_data.items(): + shap_lines.append(f"\n {model_name} model (base prediction: {attrs.get('base_value', 'N/A')}):") + for f in attrs.get("top_positive", []): + shap_lines.append(f" + {f['feature']} = {f['value']} (pushes prediction UP by {f['shap_value']})") + for f in attrs.get("top_negative", []): + shap_lines.append(f" - {f['feature']} = {f['value']} (pushes prediction DOWN by {abs(f['shap_value'])})") + + profile_str = json.dumps(profile, indent=2, default=str) + risk_str = "\n".join(f"- {r}" for r in risk_factors) if risk_factors else "None identified" + + interventions = config.get("school", {}).get("interventions", {}).get("active", []) + intervention_lines = [] + for i in interventions: + intervention_lines.append(f"- {i['name']} ({i['type']}): {i.get('effectiveness', 'unknown')}") + interventions_str = "\n".join(intervention_lines) if intervention_lines else "None listed" + + return f"""A student at this institution has a readiness score of {readiness_score} 
({readiness_level}). +Analyze their ML prediction factors and write an advisor-facing explanation. + +STUDENT PROFILE: +{profile_str} + +RISK FACTORS (rule-engine identified): +{risk_str} + +ML MODEL FEATURE ATTRIBUTION (SHAP values — what drives each prediction): +{''.join(shap_lines) if shap_lines else 'No SHAP data available'} + +AVAILABLE INTERVENTIONS: +{interventions_str} + +Generate a JSON response with this exact schema: +{schema_str} + +Guidelines: +- Ground the narrative in SHAP values. Cite at least 2 of the top contributing features by name and magnitude. +- Explain in plain language what each factor means for this student's likelihood of success. +- Make recommended actions specific to this institution — reference active interventions by name when relevant. +- Include at least one data limitation or caveat about the prediction. +- Do NOT speculate beyond what the SHAP values and profile data show.""" + + def build_explainer_prompt( config: dict[str, Any], course_data: dict[str, Any], diff --git a/training/seed.py b/training/seed.py index e2d3b66..82f8a03 100644 --- a/training/seed.py +++ b/training/seed.py @@ -52,14 +52,123 @@ _RACES = ["Black", "White", "Hispanic", "Asian", "Two or More", "Unknown"] +_ENROLLMENT_INTENSITIES = ["Full-Time", "Part-Time"] +_MATH_PLACEMENTS = ["C", "R", "N"] +_ALERT_LEVELS = ["LOW", "MODERATE", "HIGH", "URGENT"] +_READINESS_LEVELS = ["high", "medium", "low"] +_FEATURE_NAMES_RETENTION = [ + "GPA_Group_Year_1", "course_completion_rate", "CompletedGatewayMathYear1", + "CompletedGatewayEnglishYear1", "Enrollment_Intensity_First_Term", + "total_credits_attempted", "Math_Placement", "Pell_Status_First_Year", + "Student_Age", "Number_of_Credits_Earned_Year_1", +] + + +def generate_synthetic_student_profiles( + config: dict[str, Any], + count: int, +) -> list[dict[str, Any]]: + """Generate synthetic student profiles with SHAP data for narrator training.""" + if count == 0: + return [] + results = [] + for _ in range(count): + gpa 
= round(random.uniform(0.5, 4.0), 1) + completion_rate = round(random.uniform(0.3, 1.0), 2) + retention_prob = round(random.uniform(0.1, 0.9), 2) + readiness_score = round(random.uniform(0.15, 0.85), 2) + intensity = random.choice(_ENROLLMENT_INTENSITIES) + math_placement = random.choice(_MATH_PLACEMENTS) + gateway_math = random.choice([True, False]) + gateway_english = random.choice([True, False]) + credits_earned = random.randint(3, 36) + alert = random.choice(_ALERT_LEVELS) + + if readiness_score >= 0.65: + readiness_level = "high" + elif readiness_score >= 0.40: + readiness_level = "medium" + else: + readiness_level = "low" + + # Build risk factors based on profile + risk_factors = [] + if gpa < 2.0: + risk_factors.append(f"Low first-year GPA ({gpa} / 4.0)") + if not gateway_math: + risk_factors.append("Gateway math not completed in Year 1") + if not gateway_english: + risk_factors.append("Gateway English not completed in Year 1") + if intensity == "Part-Time": + risk_factors.append("Part-time enrollment reduces success probability") + if credits_earned < 12: + risk_factors.append(f"Below 12-credit Year 1 milestone ({credits_earned} credits earned)") + if alert in ("URGENT", "HIGH"): + risk_factors.append(f"Retention model flags as {alert.capitalize()} risk") + + # Generate synthetic SHAP values + features = random.sample(_FEATURE_NAMES_RETENTION, min(8, len(_FEATURE_NAMES_RETENTION))) + shap_values = [round(random.uniform(-0.25, 0.25), 4) for _ in features] + feature_values = { + "GPA_Group_Year_1": gpa, + "course_completion_rate": completion_rate, + "CompletedGatewayMathYear1": 1.0 if gateway_math else 0.0, + "CompletedGatewayEnglishYear1": 1.0 if gateway_english else 0.0, + "Enrollment_Intensity_First_Term": 1.0 if intensity == "Full-Time" else 0.0, + "total_credits_attempted": float(credits_earned + random.randint(0, 6)), + "Math_Placement": {"C": 2.0, "R": 1.0, "N": 0.0}[math_placement], + "Pell_Status_First_Year": float(random.randint(0, 1)), + 
"Student_Age": float(random.randint(18, 45)), + "Number_of_Credits_Earned_Year_1": float(credits_earned), + } + + top_positive = sorted( + [{"feature": f, "shap_value": sv, "value": feature_values.get(f, 0.0)} + for f, sv in zip(features, shap_values) if sv > 0], + key=lambda x: x["shap_value"], reverse=True, + )[:5] + + top_negative = sorted( + [{"feature": f, "shap_value": sv, "value": feature_values.get(f, 0.0)} + for f, sv in zip(features, shap_values) if sv < 0], + key=lambda x: x["shap_value"], + )[:5] + + results.append({ + "student_profile": { + "enrollment_intensity": intensity, + "gpa_year1": gpa, + "math_placement": math_placement, + "course_completion_rate": completion_rate, + "gateway_math_completed": gateway_math, + "gateway_english_completed": gateway_english, + "credits_earned_y1": credits_earned, + "at_risk_alert": alert, + "retention_probability": retention_prob, + }, + "readiness_score": readiness_score, + "readiness_level": readiness_level, + "risk_factors": risk_factors, + "shap": { + "retention": { + "base_value": round(random.uniform(0.4, 0.6), 4), + "top_positive": top_positive, + "top_negative": top_negative, + }, + }, + }) + return results + + def load_seed_queries(school: str) -> dict[str, list[dict]]: """Load seed queries from a school's seed_queries.yaml.""" seed_path = get_school_dir(school) / "seed_queries.yaml" if not seed_path.exists(): - return {"explainer": [], "summarizer": []} + return {"narrator": [], "explainer": [], "summarizer": []} with seed_path.open("r", encoding="utf-8") as fh: data = yaml.safe_load(fh) or {} return { + "narrator": data.get("narrator", []), "explainer": data.get("explainer", []), "summarizer": data.get("summarizer", []), }