From ac42cfb4fce780b6b536d3baab092bd6dd4ea1ed Mon Sep 17 00:00:00 2001 From: William Hill Date: Thu, 2 Apr 2026 22:23:52 -0400 Subject: [PATCH 1/4] feat(#97): add SHAP narrator task type to training pipeline Add narrator as a third task alongside explainer and summarizer. The narrator takes per-student SHAP values + profile and generates advisor-facing narratives grounded in ML feature attribution. - prompts.py: NARRATOR_SCHEMA, NARRATOR_STUDENT_SYSTEM, build_narrator_prompt() - seed.py: generate_synthetic_student_profiles() with SHAP data - distill.py: narrator in _TASK_CONFIG, included in main() distillation loop - eval.py: _NARRATOR_REQUIRED_KEYS, shap_grounding ship criterion (>= 80%), check_shap_grounding() metric (counts feature name mentions in narrative) - prepare.py: narrator added to task iteration --- training/distill.py | 41 +++++++++++++---- training/eval.py | 77 ++++++++++++++++++++++++++++--- training/prepare.py | 2 +- training/prompts.py | 71 +++++++++++++++++++++++++++++ training/seed.py | 108 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 282 insertions(+), 17 deletions(-) diff --git a/training/distill.py b/training/distill.py index bd8f80c..9cb6564 100644 --- a/training/distill.py +++ b/training/distill.py @@ -21,8 +21,10 @@ from training.config import get_training_data_dir, load_school_config, write_jsonl from training.prompts import ( EXPLAINER_STUDENT_SYSTEM, + NARRATOR_STUDENT_SYSTEM, SUMMARIZER_STUDENT_SYSTEM, build_explainer_prompt, + build_narrator_prompt, build_summarizer_prompt, build_system_prompt, ) @@ -30,6 +32,7 @@ format_as_chatml, generate_synthetic_course_pairings, generate_synthetic_query_results, + generate_synthetic_student_profiles, ) # Cost tracking @@ -136,6 +139,11 @@ def call_teacher(system: str, user: str, backend: str, model: str) -> str: _FLUSH_INTERVAL = 25 _TASK_CONFIG = { + "narrator": { + "prompt_builder": build_narrator_prompt, + "student_system": NARRATOR_STUDENT_SYSTEM, + "format_user": lambda 
config, data: json.dumps(data, ensure_ascii=False, default=str), + }, "explainer": { "prompt_builder": build_explainer_prompt, "student_system": EXPLAINER_STUDENT_SYSTEM, @@ -245,28 +253,43 @@ def main(school: str, local: bool = False) -> None: data_dir = get_training_data_dir(school) pairs_dir = data_dir / "pairs" - synthetic_pairings = generate_synthetic_course_pairings(config, count=pairs_per_task) - synthetic_results = generate_synthetic_query_results(config, count=pairs_per_task) - system_prompt = build_system_prompt(config) + all_counts: dict[str, int] = {} + + # Narrator + print(f"\n{'='*60}\nNARRATOR — generating {pairs_per_task} pairs\n{'='*60}") + synthetic_profiles = generate_synthetic_student_profiles(config, count=pairs_per_task) + narrator_pairs = generate_pairs( + config=config, seed_data=synthetic_profiles, + count=pairs_per_task, task="narrator", outfile=pairs_dir / "narrator.jsonl", + system_prompt=system_prompt, + ) + all_counts["narrator"] = len(narrator_pairs) + + # Explainer print(f"\n{'='*60}\nEXPLAINER — generating {pairs_per_task} pairs\n{'='*60}") - explainer_pairs = generate_explainer_pairs( + synthetic_pairings = generate_synthetic_course_pairings(config, count=pairs_per_task) + explainer_pairs = generate_pairs( config=config, seed_data=synthetic_pairings, - count=pairs_per_task, outfile=pairs_dir / "explainer.jsonl", + count=pairs_per_task, task="explainer", outfile=pairs_dir / "explainer.jsonl", system_prompt=system_prompt, ) + all_counts["explainer"] = len(explainer_pairs) + # Summarizer print(f"\n{'='*60}\nSUMMARIZER — generating {pairs_per_task} pairs\n{'='*60}") - summarizer_pairs = generate_summarizer_pairs( + synthetic_results = generate_synthetic_query_results(config, count=pairs_per_task) + summarizer_pairs = generate_pairs( config=config, seed_data=synthetic_results, - count=pairs_per_task, outfile=pairs_dir / "summarizer.jsonl", + count=pairs_per_task, task="summarizer", outfile=pairs_dir / "summarizer.jsonl", 
system_prompt=system_prompt, ) + all_counts["summarizer"] = len(summarizer_pairs) print(f"\n{'='*60}\nDISTILLATION COMPLETE\n{'='*60}") - print(f" Explainer: {len(explainer_pairs)} pairs") - print(f" Summarizer: {len(summarizer_pairs)} pairs") + for task_name, count in all_counts.items(): + print(f" {task_name.capitalize()}: {count} pairs") _print_cost_summary() diff --git a/training/eval.py b/training/eval.py index 23bffb8..3739129 100644 --- a/training/eval.py +++ b/training/eval.py @@ -31,6 +31,13 @@ "related_intervention", } +_NARRATOR_REQUIRED_KEYS: set[str] = { + "narrative", + "key_drivers", + "recommended_actions", + "data_limitations", +} + _SUMMARIZER_REQUIRED_KEYS: set[str] = { "summary", "key_insights", @@ -44,6 +51,12 @@ # --------------------------------------------------------------------------- SHIP_CRITERIA: dict[str, dict[str, float]] = { + "narrator": { + "json_validity": 0.95, + "schema_adherence": 0.90, + "shap_grounding": 0.80, + "caveat_inclusion": 0.85, + }, "explainer": { "json_validity": 0.95, "schema_adherence": 0.90, @@ -120,9 +133,11 @@ def check_schema_adherence(outputs: list[str], task: str) -> float: """Fraction of valid JSON outputs that contain all required keys.""" if not outputs: return 0.0 - required = ( - _EXPLAINER_REQUIRED_KEYS if task == "explainer" else _SUMMARIZER_REQUIRED_KEYS - ) + required = { + "narrator": _NARRATOR_REQUIRED_KEYS, + "explainer": _EXPLAINER_REQUIRED_KEYS, + "summarizer": _SUMMARIZER_REQUIRED_KEYS, + }.get(task, _SUMMARIZER_REQUIRED_KEYS) passing = 0 total = 0 for text in outputs: @@ -147,7 +162,7 @@ def check_caveat_inclusion(outputs: list[str], task: str) -> float: """ if not outputs: return 0.0 - caveat_key = "data_limitations" if task == "explainer" else "caveats" + caveat_key = "caveats" if task == "summarizer" else "data_limitations" passing = 0 total = 0 for text in outputs: @@ -169,6 +184,51 @@ def check_caveat_inclusion(outputs: list[str], task: str) -> float: return passing / total if total 
else 0.0 +def check_shap_grounding(outputs: list[str], inputs: list[dict[str, Any]], min_features: int = 2) -> float: + """Fraction of narrator outputs that mention at least `min_features` of the top-3 SHAP features. + + Extracts feature names from the input's SHAP data and checks whether the + narrative text references them (case-insensitive, underscore-tolerant). + """ + if not outputs: + return 0.0 + passing = 0 + total = 0 + for output_text, input_data in zip(outputs, inputs): + total += 1 + # Collect top SHAP feature names from all models in the input + shap_data = input_data.get("shap", {}) + top_features: list[str] = [] + for model_attrs in shap_data.values(): + for entry in model_attrs.get("top_positive", [])[:3]: + top_features.append(entry["feature"]) + for entry in model_attrs.get("top_negative", [])[:3]: + top_features.append(entry["feature"]) + # Deduplicate while preserving order + seen = set() + unique_features = [] + for f in top_features: + if f not in seen: + seen.add(f) + unique_features.append(f) + top_features = unique_features[:6] # top 3 per direction, deduplicated + + if not top_features: + passing += 1 # no SHAP data to ground against + continue + + # Check how many features appear in the output (case-insensitive, underscores → spaces) + output_lower = output_text.lower().replace("_", " ") + mentioned = sum( + 1 for f in top_features + if f.lower().replace("_", " ") in output_lower + ) + if mentioned >= min_features: + passing += 1 + + return passing / total if total else 0.0 + + def check_factual_grounding(outputs: list[str], inputs: list[dict[str, Any]]) -> float: """Fraction of outputs that contain numeric values referenced in their input. 
@@ -314,8 +374,11 @@ def run_eval(school: str, task: str) -> ShipDecision: "json_validity": check_json_validity(outputs), "schema_adherence": check_schema_adherence(outputs, task), "caveat_inclusion": check_caveat_inclusion(outputs, task), - "factual_grounding": check_factual_grounding(outputs, inputs), } + if task == "narrator": + metrics["shap_grounding"] = check_shap_grounding(outputs, inputs) + else: + metrics["factual_grounding"] = check_factual_grounding(outputs, inputs) print(f"\n[eval] Results for {school}/{task}:") for k, v in metrics.items(): @@ -337,13 +400,13 @@ def main() -> None: parser.add_argument("--school", required=True, help="School directory name (e.g. bishop-state)") parser.add_argument( "--task", - choices=["explainer", "summarizer"], + choices=["narrator", "explainer", "summarizer"], default=None, help="Task to evaluate (default: both)", ) args = parser.parse_args() - tasks = [args.task] if args.task else ["explainer", "summarizer"] + tasks = [args.task] if args.task else ["narrator", "explainer", "summarizer"] results: dict[str, ShipDecision] = {} for task in tasks: print(f"\n{'='*60}\nEVAL: {task.upper()}\n{'='*60}") diff --git a/training/prepare.py b/training/prepare.py index 78e1eb9..c1c41af 100644 --- a/training/prepare.py +++ b/training/prepare.py @@ -133,7 +133,7 @@ def process_task(school: str, task: str) -> dict[str, int]: def main(school: str) -> None: """Run preparation for all tasks.""" - for task in ("explainer", "summarizer"): + for task in ("narrator", "explainer", "summarizer"): try: process_task(school, task) except FileNotFoundError as e: diff --git a/training/prompts.py b/training/prompts.py index 47e7716..3b0a786 100644 --- a/training/prompts.py +++ b/training/prompts.py @@ -26,6 +26,13 @@ "caveats": ["data limitations relevant to this specific query"], } +NARRATOR_SCHEMA = { + "narrative": "2-3 sentence explanation grounded in SHAP feature attribution", + "key_drivers": ["ranked list of factors with direction and 
magnitude"], + "recommended_actions": ["3-5 specific, actionable interventions"], + "data_limitations": ["caveats about the prediction"], +} + EXPLAINER_STUDENT_SYSTEM = ( "You are a student success analyst. Given course pairing data, generate a " "structured JSON explanation. Include: explanation, structural_factors, " @@ -33,6 +40,14 @@ "related_intervention. Respond with ONLY valid JSON." ) +NARRATOR_STUDENT_SYSTEM = ( + "You are a student success analyst. Given a student profile with ML prediction " + "attribution (SHAP values), generate a structured JSON explanation. Include: " + "narrative, key_drivers, recommended_actions, and data_limitations. " + "Ground your narrative in the SHAP values — cite specific features by name " + "and magnitude. Respond with ONLY valid JSON." +) + SUMMARIZER_STUDENT_SYSTEM = ( "You are a student success analyst. Given a query and its results, generate " "a structured JSON summary. Include: summary, key_insights, context, " @@ -195,6 +210,62 @@ def build_system_prompt(config: dict[str, Any]) -> str: return "\n\n".join(sections) +def build_narrator_prompt( + config: dict[str, Any], + student_data: dict[str, Any], +) -> str: + """Build the teacher prompt for generating a SHAP-grounded student narrative.""" + schema_str = json.dumps(NARRATOR_SCHEMA, indent=2) + profile = student_data.get("student_profile", {}) + shap_data = student_data.get("shap", {}) + risk_factors = student_data.get("risk_factors", []) + readiness_score = student_data.get("readiness_score", "N/A") + readiness_level = student_data.get("readiness_level", "unknown") + + # Format SHAP attribution section + shap_lines = [] + for model_name, attrs in shap_data.items(): + shap_lines.append(f"\n {model_name} model (base prediction: {attrs.get('base_value', 'N/A')}):") + for f in attrs.get("top_positive", []): + shap_lines.append(f" + {f['feature']} = {f['value']} (pushes prediction UP by {f['shap_value']})") + for f in attrs.get("top_negative", []): + 
shap_lines.append(f" - {f['feature']} = {f['value']} (pushes prediction DOWN by {abs(f['shap_value'])})") + + profile_str = json.dumps(profile, indent=2, default=str) + risk_str = "\n".join(f"- {r}" for r in risk_factors) if risk_factors else "None identified" + + interventions = config.get("school", {}).get("interventions", {}).get("active", []) + intervention_lines = [] + for i in interventions: + intervention_lines.append(f"- {i['name']} ({i['type']}): {i.get('effectiveness', 'unknown')}") + interventions_str = "\n".join(intervention_lines) if intervention_lines else "None listed" + + return f"""A student at this institution has a readiness score of {readiness_score} ({readiness_level}). +Analyze their ML prediction factors and write an advisor-facing explanation. + +STUDENT PROFILE: +{profile_str} + +RISK FACTORS (rule-engine identified): +{risk_str} + +ML MODEL FEATURE ATTRIBUTION (SHAP values — what drives each prediction): +{''.join(shap_lines) if shap_lines else 'No SHAP data available'} + +AVAILABLE INTERVENTIONS: +{interventions_str} + +Generate a JSON response with this exact schema: +{schema_str} + +Guidelines: +- Ground the narrative in SHAP values. Cite at least 2 of the top contributing features by name and magnitude. +- Explain in plain language what each factor means for this student's likelihood of success. +- Make recommended actions specific to this institution — reference active interventions by name when relevant. +- Include at least one data limitation or caveat about the prediction. 
+- Do NOT speculate beyond what the SHAP values and profile data show.""" + + def build_explainer_prompt( config: dict[str, Any], course_data: dict[str, Any], diff --git a/training/seed.py b/training/seed.py index e2d3b66..5a6bbba 100644 --- a/training/seed.py +++ b/training/seed.py @@ -52,6 +52,114 @@ _RACES = ["Black", "White", "Hispanic", "Asian", "Two or More", "Unknown"] +_ENROLLMENT_INTENSITIES = ["Full-Time", "Part-Time"] +_MATH_PLACEMENTS = ["C", "R", "N"] +_ALERT_LEVELS = ["LOW", "MODERATE", "HIGH", "URGENT"] +_READINESS_LEVELS = ["high", "medium", "low"] +_FEATURE_NAMES_RETENTION = [ + "GPA_Group_Year_1", "course_completion_rate", "CompletedGatewayMathYear1", + "CompletedGatewayEnglishYear1", "Enrollment_Intensity_First_Term", + "total_credits_attempted", "Math_Placement", "Pell_Status_First_Year", + "Student_Age", "Number_of_Credits_Earned_Year_1", +] + + +def generate_synthetic_student_profiles( + config: dict[str, Any], + count: int, +) -> list[dict[str, Any]]: + """Generate synthetic student profiles with SHAP data for narrator training.""" + if count == 0: + return [] + results = [] + for _ in range(count): + gpa = round(random.uniform(0.5, 4.0), 1) + completion_rate = round(random.uniform(0.3, 1.0), 2) + retention_prob = round(random.uniform(0.1, 0.9), 2) + readiness_score = round(random.uniform(0.15, 0.85), 2) + intensity = random.choice(_ENROLLMENT_INTENSITIES) + math_placement = random.choice(_MATH_PLACEMENTS) + gateway_math = random.choice([True, False]) + gateway_english = random.choice([True, False]) + credits_earned = random.randint(3, 36) + alert = random.choice(_ALERT_LEVELS) + + if readiness_score >= 0.65: + readiness_level = "high" + elif readiness_score >= 0.40: + readiness_level = "medium" + else: + readiness_level = "low" + + # Build risk factors based on profile + risk_factors = [] + if gpa < 2.0: + risk_factors.append(f"Low first-year GPA ({gpa} / 4.0)") + if not gateway_math: + risk_factors.append("Gateway math not completed in Year 
1") + if not gateway_english: + risk_factors.append("Gateway English not completed in Year 1") + if intensity == "Part-Time": + risk_factors.append("Part-time enrollment reduces success probability") + if credits_earned < 12: + risk_factors.append(f"Below 12-credit Year 1 milestone ({credits_earned} credits earned)") + if alert in ("URGENT", "HIGH"): + risk_factors.append(f"Retention model flags as {alert.capitalize()} risk") + + # Generate synthetic SHAP values + features = random.sample(_FEATURE_NAMES_RETENTION, min(8, len(_FEATURE_NAMES_RETENTION))) + shap_values = [round(random.uniform(-0.25, 0.25), 4) for _ in features] + feature_values = { + "GPA_Group_Year_1": gpa, + "course_completion_rate": completion_rate, + "CompletedGatewayMathYear1": 1.0 if gateway_math else 0.0, + "CompletedGatewayEnglishYear1": 1.0 if gateway_english else 0.0, + "Enrollment_Intensity_First_Term": 1.0 if intensity == "Full-Time" else 0.0, + "total_credits_attempted": float(credits_earned + random.randint(0, 6)), + "Math_Placement": {"C": 2.0, "R": 1.0, "N": 0.0}[math_placement], + "Pell_Status_First_Year": float(random.randint(0, 1)), + "Student_Age": float(random.randint(18, 45)), + "Number_of_Credits_Earned_Year_1": float(credits_earned), + } + + top_positive = sorted( + [{"feature": f, "shap_value": sv, "value": feature_values.get(f, 0.0)} + for f, sv in zip(features, shap_values) if sv > 0], + key=lambda x: x["shap_value"], reverse=True, + )[:5] + + top_negative = sorted( + [{"feature": f, "shap_value": sv, "value": feature_values.get(f, 0.0)} + for f, sv in zip(features, shap_values) if sv < 0], + key=lambda x: x["shap_value"], + )[:5] + + results.append({ + "student_profile": { + "enrollment_intensity": intensity, + "gpa_year1": gpa, + "math_placement": math_placement, + "course_completion_rate": completion_rate, + "gateway_math_completed": gateway_math, + "gateway_english_completed": gateway_english, + "credits_earned_y1": credits_earned, + "at_risk_alert": alert, + 
"retention_probability": retention_prob, + }, + "readiness_score": readiness_score, + "readiness_level": readiness_level, + "risk_factors": risk_factors, + "shap": { + "retention": { + "base_value": round(random.uniform(0.4, 0.6), 4), + "top_positive": top_positive, + "top_negative": top_negative, + }, + }, + }) + return results + + def load_seed_queries(school: str) -> dict[str, list[dict]]: """Load seed queries from a school's seed_queries.yaml.""" seed_path = get_school_dir(school) / "seed_queries.yaml" From e2911ca8c5af43dbeeffde90728cee96dfa0b7f0 Mon Sep 17 00:00:00 2001 From: William Hill Date: Thu, 2 Apr 2026 22:30:42 -0400 Subject: [PATCH 2/4] refactor: simplify narrator integration per code review - Remove dead generate_explainer_pairs/generate_summarizer_pairs wrappers - Derive _REQUIRED_KEYS sets from schema dicts in prompts.py - Simplify dedup loop in check_shap_grounding with dict.fromkeys() - Add narrator key to load_seed_queries() return - Update test_distill.py to use generate_pairs() directly - Fix stale docstring in generate_pairs() --- tests/training/test_distill.py | 12 ++++++----- training/distill.py | 20 +----------------- training/eval.py | 38 ++++++---------------------------- training/seed.py | 1 + 4 files changed, 15 insertions(+), 56 deletions(-) diff --git a/tests/training/test_distill.py b/tests/training/test_distill.py index ccaba2f..0dcfeda 100644 --- a/tests/training/test_distill.py +++ b/tests/training/test_distill.py @@ -7,8 +7,7 @@ from training.distill import ( validate_json, call_teacher, - generate_explainer_pairs, - generate_summarizer_pairs, + generate_pairs, ) @@ -80,10 +79,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_cours }) with patch("training.distill.call_teacher", return_value=mock_response): - pairs = generate_explainer_pairs( + pairs = generate_pairs( config=sample_school_config, seed_data=[sample_course_pairing_data], count=2, + task="explainer", ) assert len(pairs) == 2 @@ -92,10 
+92,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_cours def test_skips_invalid_responses(self, sample_school_config, sample_course_pairing_data): with patch("training.distill.call_teacher", return_value="not json"): - pairs = generate_explainer_pairs( + pairs = generate_pairs( config=sample_school_config, seed_data=[sample_course_pairing_data], count=3, + task="explainer", ) assert len(pairs) == 0 @@ -112,10 +113,11 @@ def test_generates_pairs_from_seed_data(self, sample_school_config, sample_query }) with patch("training.distill.call_teacher", return_value=mock_response): - pairs = generate_summarizer_pairs( + pairs = generate_pairs( config=sample_school_config, seed_data=[sample_query_result_data], count=2, + task="summarizer", ) assert len(pairs) == 2 diff --git a/training/distill.py b/training/distill.py index 9cb6564..97ad679 100644 --- a/training/distill.py +++ b/training/distill.py @@ -171,7 +171,7 @@ def generate_pairs( config: Parsed school config. seed_data: List of seed data dicts. count: Number of pairs to generate. - task: "explainer" or "summarizer". + task: "narrator", "explainer", or "summarizer". outfile: If provided, pairs are written incrementally. system_prompt: Pre-built system prompt (avoids recomputation). 
""" @@ -221,24 +221,6 @@ def generate_pairs( return pairs -def generate_explainer_pairs( - config: dict[str, Any], seed_data: list[dict[str, Any]], - count: int, outfile: Path | None = None, - system_prompt: str | None = None, -) -> list[dict]: - """Generate explainer training pairs via teacher model distillation.""" - return generate_pairs(config, seed_data, count, "explainer", outfile, system_prompt) - - -def generate_summarizer_pairs( - config: dict[str, Any], seed_data: list[dict[str, Any]], - count: int, outfile: Path | None = None, - system_prompt: str | None = None, -) -> list[dict]: - """Generate summarizer training pairs via teacher model distillation.""" - return generate_pairs(config, seed_data, count, "summarizer", outfile, system_prompt) - - def main(school: str, local: bool = False) -> None: """Run distillation for a school.""" config = load_school_config(school) diff --git a/training/eval.py b/training/eval.py index 3739129..05706ad 100644 --- a/training/eval.py +++ b/training/eval.py @@ -17,34 +17,15 @@ from typing import Any from training.config import get_message_content, get_training_data_dir, read_jsonl +from training.prompts import EXPLAINER_SCHEMA, NARRATOR_SCHEMA, SUMMARIZER_SCHEMA # --------------------------------------------------------------------------- -# Required keys per task +# Required keys per task — derived from schema definitions in prompts.py # --------------------------------------------------------------------------- -_EXPLAINER_REQUIRED_KEYS: set[str] = { - "explanation", - "structural_factors", - "student_impact", - "advisor_recommendation", - "data_limitations", - "related_intervention", -} - -_NARRATOR_REQUIRED_KEYS: set[str] = { - "narrative", - "key_drivers", - "recommended_actions", - "data_limitations", -} - -_SUMMARIZER_REQUIRED_KEYS: set[str] = { - "summary", - "key_insights", - "context", - "action_items", - "caveats", -} +_EXPLAINER_REQUIRED_KEYS: set[str] = set(EXPLAINER_SCHEMA.keys()) +_NARRATOR_REQUIRED_KEYS: 
set[str] = set(NARRATOR_SCHEMA.keys()) +_SUMMARIZER_REQUIRED_KEYS: set[str] = set(SUMMARIZER_SCHEMA.keys()) # --------------------------------------------------------------------------- # Ship criteria — minimum thresholds per task @@ -204,14 +185,7 @@ def check_shap_grounding(outputs: list[str], inputs: list[dict[str, Any]], min_f top_features.append(entry["feature"]) for entry in model_attrs.get("top_negative", [])[:3]: top_features.append(entry["feature"]) - # Deduplicate while preserving order - seen = set() - unique_features = [] - for f in top_features: - if f not in seen: - seen.add(f) - unique_features.append(f) - top_features = unique_features[:6] # top 3 per direction, deduplicated + top_features = list(dict.fromkeys(top_features))[:6] if not top_features: passing += 1 # no SHAP data to ground against diff --git a/training/seed.py b/training/seed.py index 5a6bbba..ae5880c 100644 --- a/training/seed.py +++ b/training/seed.py @@ -168,6 +168,7 @@ def load_seed_queries(school: str) -> dict[str, list[dict]]: with seed_path.open("r", encoding="utf-8") as fh: data = yaml.safe_load(fh) or {} return { + "narrator": data.get("narrator", []), "explainer": data.get("explainer", []), "summarizer": data.get("summarizer", []), } From 88f5bbaa60cb3c80939b7c486079d340e73d9225 Mon Sep 17 00:00:00 2001 From: William Hill Date: Thu, 2 Apr 2026 23:34:18 -0400 Subject: [PATCH 3/4] fix: address CodeRabbit review findings - check_ship_criteria() now iterates required criteria first; missing metrics are blocking failures instead of silently passing - load_seed_queries() missing-file fallback includes narrator key - Add .superpowers/ to .gitignore (runtime state files) - Update test expectation for narrator in seed query fallback --- .gitignore | 3 +++ tests/training/test_seed.py | 2 +- training/eval.py | 29 +++++++++++++++++------------ training/seed.py | 2 +- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index dc908e6..da3db69 100644 --- 
a/.gitignore +++ b/.gitignore @@ -189,3 +189,6 @@ data/test_uploads/ # Training pipeline artifacts training_data/ + +# Superpowers runtime state +.superpowers/ diff --git a/tests/training/test_seed.py b/tests/training/test_seed.py index 77b0a3e..f72d240 100644 --- a/tests/training/test_seed.py +++ b/tests/training/test_seed.py @@ -35,7 +35,7 @@ def test_loads_valid_yaml(self, tmp_path): def test_returns_empty_on_missing_file(self, tmp_path): with patch("training.seed.get_school_dir", return_value=tmp_path): result = load_seed_queries("test-school") - assert result == {"explainer": [], "summarizer": []} + assert result == {"narrator": [], "explainer": [], "summarizer": []} class TestGenerateSyntheticCoursePairings: diff --git a/training/eval.py b/training/eval.py index 05706ad..5343060 100644 --- a/training/eval.py +++ b/training/eval.py @@ -241,19 +241,24 @@ def check_ship_criteria(metrics: dict[str, float], task: str) -> ShipDecision: blocking_failures: list[CriterionFailure] = [] warnings: list[str] = [] + # Check all required criteria — missing metrics are blocking failures + for metric, threshold in criteria.items(): + value = metrics.get(metric) + if value is None: + blocking_failures.append( + CriterionFailure(metric=metric, threshold=threshold, actual=0.0) + ) + elif value < threshold: + blocking_failures.append( + CriterionFailure(metric=metric, threshold=threshold, actual=value) + ) + + # Check informational metrics (present in metrics but not in criteria) for metric, value in metrics.items(): - threshold = criteria.get(metric) - if threshold is not None: - if value < threshold: - blocking_failures.append( - CriterionFailure(metric=metric, threshold=threshold, actual=value) - ) - else: - # Informational metric — warn if very low - if value < 0.5: - warnings.append( - f"{metric} is low ({value:.3f}) — consider improving before deploying" - ) + if metric not in criteria and value < 0.5: + warnings.append( + f"{metric} is low ({value:.3f}) — consider 
improving before deploying" + ) if blocking_failures: decision = "no_ship" diff --git a/training/seed.py b/training/seed.py index ae5880c..82f8a03 100644 --- a/training/seed.py +++ b/training/seed.py @@ -164,7 +164,7 @@ def load_seed_queries(school: str) -> dict[str, list[dict]]: """Load seed queries from a school's seed_queries.yaml.""" seed_path = get_school_dir(school) / "seed_queries.yaml" if not seed_path.exists(): - return {"explainer": [], "summarizer": []} + return {"narrator": [], "explainer": [], "summarizer": []} with seed_path.open("r", encoding="utf-8") as fh: data = yaml.safe_load(fh) or {} return { From 42c8f4ac7a0576181024706f396b396bf1b67027 Mon Sep 17 00:00:00 2001 From: William Hill Date: Fri, 3 Apr 2026 10:31:31 -0400 Subject: [PATCH 4/4] =?UTF-8?q?docs:=20update=20design=20spec=20=E2=80=94?= =?UTF-8?q?=20Qwen=203.5-4B=20vs=20Gemma=204=20E4B=20evaluation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace single-family model comparison (Qwen 4B vs 9B) with cross-family comparison (Qwen 3.5-4B vs Gemma 4 E4B). Gemma 4 E4B has native structured JSON output and 128K context; Qwen 3.5-4B is the proven D4BL baseline. 
Key changes: - Add model selection section with head-to-head comparison table - Update notebook config from MODEL_SIZES to MODELS list - Switch from QLoRA to bf16 LoRA (Unsloth discourages QLoRA on Qwen 3.5) - Update Ollama naming, env vars, cost estimates, success criteria --- ...ne-tuning-student-explainability-design.md | 75 ++++++++++++++----- 1 file changed, 55 insertions(+), 20 deletions(-) diff --git a/docs/superpowers/specs/2026-04-02-fine-tuning-student-explainability-design.md b/docs/superpowers/specs/2026-04-02-fine-tuning-student-explainability-design.md index dd81ea1..c54d09d 100644 --- a/docs/superpowers/specs/2026-04-02-fine-tuning-student-explainability-design.md +++ b/docs/superpowers/specs/2026-04-02-fine-tuning-student-explainability-design.md @@ -1,6 +1,6 @@ # Design Spec: Fine-Tuning for Student Explainability -**Date:** 2026-04-02 +**Date:** 2026-04-02 (updated 2026-04-03) **Epic label:** `fine-tuning: student-explainability` **Epic branch:** `fine-tuning/student-explainability` **Status:** Draft @@ -9,7 +9,7 @@ ## 1. Goal -Fine-tune a small language model (Qwen 3.5) on Bishop State domain data to replace GPT-4o-mini for three inference tasks in the dashboard. The primary value is improved explainability: advisors get SHAP-grounded, institution-aware narratives instead of templated rule-engine output. Secondary benefits include FERPA compliance (all inference on-premises), offline deployment, and institutional scalability. +Fine-tune a small language model on Bishop State domain data to replace GPT-4o-mini for three inference tasks in the dashboard. Two candidate models will be evaluated head-to-head: **Qwen 3.5-4B** (proven by D4BL) and **Gemma 4 E4B** (native structured JSON output). The primary value is improved explainability: advisors get SHAP-grounded, institution-aware narratives instead of templated rule-engine output. 
Secondary benefits include FERPA compliance (all inference on-premises), offline deployment, and institutional scalability. ### Tasks to Fine-Tune @@ -24,6 +24,34 @@ Fine-tune a small language model (Qwen 3.5) on Bishop State domain data to repla - Query Analyzer (NL → SQL) — high risk, deferred to future epic - Model serving infrastructure (RunPod, dedicated GPU hosting) — use local Ollama for now +### Model Selection: Qwen 3.5-4B vs Gemma 4 E4B + +Two candidate models will be trained and evaluated. The winner is selected based on ship criteria metrics. + +| | **Qwen 3.5-4B** | **Gemma 4 E4B** | +|---|---|---| +| Effective params | 4B | 4.5B (8B with embeddings) | +| GGUF size (q4_k_m) | ~2.7 GB | ~5 GB | +| Context window | 32K | 128K | +| Native JSON output | No | Yes — built-in structured function calling | +| Ollama support | Text-only (vision mmproj broken) | Full support | +| Unsloth LoRA | Yes, but QLoRA discouraged — use bf16 LoRA | Yes, bf16 LoRA recommended | +| VRAM for training (bf16) | 10 GB | ~10 GB | +| License | Apache 2.0 | Apache 2.0 | +| D4BL proven | Yes (5 experiments, 98.77% schema validity) | No (new model) | + +**Why these two:** +- Qwen 3.5-4B is the known quantity — D4BL proved the full pipeline (distill → train → GGUF → Ollama) with 5 experiments and achieved 98.77% schema validity on structured output tasks. +- Gemma 4 E4B has native structured JSON output before fine-tuning, 128K context (headroom for SHAP-heavy narrator prompts), and full Ollama GGUF support without the mmproj workaround. + +**Why not others:** +- Qwen 3.5-9B: 22 GB VRAM for training, larger GGUF (~5.5 GB), marginal quality gain over 4B for our task complexity. +- Qwen 3-4B: #1 fine-tuning benchmark, but lacks 3.5's architecture improvements. +- Llama 3.2-3B: Best tunability gain, but Meta's community license (700M MAU limit) is restrictive for educational institutions. +- Phi-4-mini: Strong on math, but less proven for structured output and limited Unsloth support. 
+ +**Training note:** Unsloth explicitly discourages QLoRA (4-bit) on Qwen 3.5 due to quantization artifacts. Both models will use bf16 LoRA on A100 (40 GB VRAM is sufficient for either). + ## 2. Prerequisites Before the epic branch is created: @@ -87,7 +115,7 @@ Before the epic branch is created: | 3 | Build Colab training notebook (Unsloth + LoRA) | Single "Run All" notebook, parameterized config, 3-phase training, GGUF export. Replace `training/finetune.py` (MLX) with Unsloth wrapper. | #1 | `type:feature`, `area:ai` | | 4 | Distill training pairs for summarizer and explainer | Run distillation for both existing tasks (~1,500 pairs each via Claude API). Prepare datasets. | #1 | `type:feature`, `area:ai` | | 5 | Distill training pairs for SHAP narrator | Generate ~1,500 SHAP narrator pairs from student data + SHAP values. Requires SHAP data in DB. | #2 | `type:feature`, `area:ai` | -| 6 | Train and evaluate 4B + 9B models | Run Colab notebook for both model sizes. Evaluate via ship criteria. Compare metrics, pick winner. | #3, #4, #5 | `type:spike`, `area:ai` | +| 6 | Train and evaluate Qwen 3.5-4B + Gemma 4 E4B | Run Colab notebook for both models. Evaluate via ship criteria. Compare metrics, pick winner. | #3, #4, #5 | `type:spike`, `area:ai` | | 7 | Export models and wire into dashboard | GGUF export, Ollama registration, wire `model-client.ts` into consumer routes, update `enrich_with_llm` model string. | #6 | `type:feature`, `area:ai`, `area:frontend` | | 8 | Update documentation and feasibility report | Update feasibility report with actual results, update README and CLAUDE.md. | #6 | `type:documentation` | @@ -110,7 +138,10 @@ Issues #2, #3, and #4 can proceed concurrently after #1. 
Issue #5 waits only on Cell 1: Configuration (ONLY cell the user edits) ------------------------------------------------- SCHOOL = "bishop-state" -MODEL_SIZES = ["4b", "9b"] +MODELS = [ + {"name": "qwen3.5-4b", "hf_id": "Qwen/Qwen3.5-4B"}, + {"name": "gemma4-e4b", "hf_id": "google/gemma-4-e4b-it"}, +] REPO_URL = "https://github.com/codebenders/datathon.git" REPO_BRANCH = "fine-tuning/student-explainability" HF_TOKEN = "" # or userdata.get('HF_TOKEN') @@ -123,9 +154,9 @@ Cell 2+: Fully autonomous - GPU detection + validation (assert A100/T4/L4) - pip install unsloth, trl, peft - Clone repo, load schools/{SCHOOL}/config.yaml -- For each model size: +- For each model in MODELS: - Phase 1: Domain adaptation - - Load base Qwen model via Unsloth (4-bit NF4) + - Load base model via Unsloth (bf16 LoRA — no QLoRA for Qwen 3.5) - Train on training_data/{school}/domain.jsonl - LoRA rank 16, all modules, 1 epoch, lr 2e-4, effective batch 32 - Save merged checkpoint @@ -139,13 +170,13 @@ Cell 2+: Fully autonomous - Phase 3: GGUF export - Quantize each task adapter to q4_k_m - Upload to Google Drive (or HF Hub if HF_TOKEN provided) -- Print comparison table: 4B vs 9B metrics across all tasks +- Print comparison table: Qwen 3.5-4B vs Gemma 4 E4B metrics across all tasks - Recommend winner based on ship criteria ``` ### Training Hyperparameters -Based on D4BL's proven configurations: +Based on D4BL's proven configurations. Both models use bf16 LoRA (not QLoRA) — Unsloth discourages 4-bit QLoRA on Qwen 3.5 due to quantization artifacts, and Gemma 4 also recommends bf16 LoRA. 
| Parameter | Phase 1 (Domain) | Phase 2 (Tasks) | |-----------|------------------|-----------------| @@ -158,6 +189,7 @@ Based on D4BL's proven configurations: | Max sequence length | 4096 | 4096-8192 | | Optimizer | AdamW 8-bit | AdamW 8-bit | | Precision | bf16 (A100) | bf16 (A100) | +| Quantization during training | None (bf16 LoRA) | None (bf16 LoRA) | ### What the Notebook Does NOT Do @@ -253,12 +285,12 @@ This is the highest-value task — it transforms per-student SHAP attribution da ### Ollama Model Naming ``` -bishop-state-narrator:{size} # SHAP narrator -bishop-state-summarizer:{size} # Query summary -bishop-state-explainer:{size} # Course pairing +bishop-state-narrator:{model} # SHAP narrator +bishop-state-summarizer:{model} # Query summary +bishop-state-explainer:{model} # Course pairing ``` -Where `{size}` is `4b` or `9b` based on evaluation results. +Where `{model}` is the winning model identifier (e.g., `qwen3.5-4b` or `gemma4-e4b`) based on evaluation results. ### SHAP Narrator Integration Point @@ -268,8 +300,8 @@ Where `{size}` is `4b` or `9b` based on evaluation results. ``` # Before (OpenAI) python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model gpt-4o-mini -# After (fine-tuned) -python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model ollama/bishop-state-narrator:4b +# After (fine-tuned; {model} is the winning tag, TBD after evaluation) +python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model ollama/bishop-state-narrator:{model} ``` ### Environment Variables @@ -277,7 +309,7 @@ python ai_model/generate_readiness_scores.py --enrich-with-llm --llm-model ollam ```env MODEL_BACKEND=ollama # or "openai" (fallback) OLLAMA_BASE_URL=http://localhost:11434 -MODEL_SIZE=4b # set after evaluation picks winner +MODEL_TAG=qwen3.5-4b # or gemma4-e4b, set after evaluation picks winner SCHOOL_CODE=bishop-state ``` @@ -290,16 +322,19 @@ The operator sets `MODEL_BACKEND` to either `ollama` or `openai`.
There is no au | Item | Cost | |------|------| | Claude API distillation (~4,500 pairs across 3 tasks) | $5-10 | -| Colab A100 compute (~4 hours for 2 model sizes) | $8-16 | -| **Total per training run** | **$13-26** | -| Iteration runs (subsequent) | $8-16 each | +| Colab A100 compute (~4-5 hours for 2 models, bf16 LoRA) | $8-20 | +| **Total per training run** | **$13-30** | +| Iteration runs (subsequent) | $8-20 each | + +Note: bf16 LoRA (required for Qwen 3.5, recommended for Gemma 4) uses more VRAM than QLoRA but fits comfortably on A100 40GB. Training time may be slightly longer than D4BL's QLoRA runs. ## 8. Success Criteria The epic is complete when: -1. All three tasks pass ship criteria on the winning model size +1. All three tasks pass ship criteria on the winning model (Qwen 3.5-4B or Gemma 4 E4B) 2. `MODEL_BACKEND=ollama` serves all three tasks in the dashboard without OpenAI 3. SHAP narrator produces grounded narratives that cite specific feature attributions -4. Feasibility report is updated with actual metrics and model selection rationale +4. Feasibility report is updated with actual metrics, model comparison, and selection rationale 5. Colab notebook is documented and reproducible (clone + Run All) +6. Model selection decision is documented with head-to-head metrics for both candidates