diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index f7fbfbd..25f9d2f 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -59,5 +59,12 @@ jobs: nohup uvicorn app.app:app & sleep 10 curl -I http://127.0.0.1:8000 + + # Test batch prediction endpoint + echo "ra,dec,redshift,psfMag_r,u,g,r,i,z" > dummy_test.csv + echo "0.1,0.2,0.3,1.0,2.0,3.0,4.0,5.0,6.0" >> dummy_test.csv + curl -X POST http://127.0.0.1:8000/predict/file -H "accept: application/json" -F "payload=@dummy_test.csv;type=text/csv" + rm dummy_test.csv + pkill -f uvicorn diff --git a/Datasets/sample_generator.py b/Datasets/sample_generator.py new file mode 100644 index 0000000..bc943b3 --- /dev/null +++ b/Datasets/sample_generator.py @@ -0,0 +1,50 @@ +""" +sample_generator.py + +This script is used to generate samples directly from the main dataset. +These samples are used to test the `/predict/file` route. +To run it, enter this in your command line: +``` +python -m Datasets.sample_generator.py +``` +""" + +import pandas as pd +from pathlib import Path + +DATASET_PATH = Path("Datasets","SDSS_DR18.csv") +OUTPUT_PATH = Path("Datasets","samples.csv") + +def preprocess(df: pd.DataFrame) -> pd.DataFrame: + drop_cols = ["objid", "specobjid", "run", "rerun", "camcol", + "field", "plate", "mjd", "fiberid"] + df = df.drop(columns=[c for c in drop_cols if c in df.columns]) + + df["class"] = df["class"].map({"GALAXY": 0, "STAR": 1, "QSO": 2}) + + df = df[["ra", "dec", "redshift","psfMag_r", "u", "g", "r", "i", "z", "class"]].copy() + + return df + + +def stratified_sample(df: pd.DataFrame, total_samples: int, class_col: str = "class", random_state: int = 42) -> pd.DataFrame: + class_counts = df[class_col].value_counts(normalize=True) + class_n = (class_counts * total_samples).round().astype(int) + + diff = total_samples - class_n.sum() + class_n.iloc[0] += diff + + return df.groupby(class_col, group_keys=False).apply( + lambda x: x.sample(n=min(class_n[x.name], len(x)), random_state=random_state) + ) + + +df_raw = pd.read_csv(DATASET_PATH) + +df_processed = preprocess(df_raw) + +sample = stratified_sample(df_processed, total_samples=100) + +sample = sample.drop(columns=["class"]) + +sample.to_csv(OUTPUT_PATH,index=False) \ No newline at end of file diff --git a/app/app.py b/app/app.py index d9569a8..377f312 100644 --- a/app/app.py +++ b/app/app.py @@ -1,17 +1,20 @@ -from fastapi import FastAPI, Depends, Request +from pydantic import BaseModel +from fastapi import FastAPI, Depends, Request, UploadFile from fastapi.responses import JSONResponse, FileResponse from fastapi.staticfiles import StaticFiles -from fastapi.exceptions import RequestValidationError +from fastapi.exceptions import RequestValidationError, HTTPException from typing import Tuple,List from sklearn.pipeline import Pipeline from models.fit import main from .schema.validation import UserInput from pathlib import Path from contextlib import asynccontextmanager -import joblib import numpy as np +import pandas as pd +import joblib import os +# Helper for loading and self-healing the artifacts def load_or_create_models() -> Tuple[Pipeline,np.ndarray]: model_path = Path("models","estimator.pkl") columns_path = Path("models","column_names.pkl") @@ -34,6 +37,36 @@ def load_or_create_models() -> Tuple[Pipeline,np.ndarray]: return pipe,column_names except Exception as e: raise RuntimeError(f"Artifacts could not be loaded: {e}") + +# Helper for performing feature engineering +def preprocess_data(value:BaseModel) -> dict: + # Preprocessing + value:dict = value.model_dump(mode="json") + kick = ["u","g","r","i","z"] + final_value = {key:val for key,val in value.items() if key not in kick} + final_value["u_g_color"] = safe_sub("u","g",value) + final_value["g_r_color"] = safe_sub("g","r",value) + final_value["r_i_color"] = safe_sub("r","i",value) + final_value["i_z_color"] = safe_sub("i","z",value) + + return final_value + +# Helper for validating user-provided csv files +def upload_validator(df:pd.DataFrame,col_names:List[str]) -> pd.DataFrame: + if df.columns.tolist() != col_names: + raise HTTPException( + status_code=422, detail="Uploaded csv file does not match the expected " \ + "columns or their order" + ) + + try: + df = df.astype(float) + except Exception as e: + raise HTTPException( + status_code=422, + detail="All values must be numeric (float-compatible)" + ) + return df @asynccontextmanager async def lifespan(app:FastAPI): @@ -91,15 +124,7 @@ def home(): def prediction_ops(value:UserInput, dep:Tuple[Pipeline,np.ndarray] = Depends(get_model)): pipe, column_names = dep column_names:List[str] = column_names.tolist() - - # Preprocessing - value:dict = value.model_dump(mode="json") - kick = ["u","g","r","i","z"] - final_value = {key:val for key,val in value.items() if key not in kick} - final_value["u_g_color"] = safe_sub("u","g",value) - final_value["g_r_color"] = safe_sub("g","r",value) - final_value["r_i_color"] = safe_sub("r","i",value) - final_value["i_z_color"] = safe_sub("i","z",value) + final_value:dict = preprocess_data(value) # Order Check and running prediction final_res = [] @@ -127,3 +152,49 @@ def prediction_ops(value:UserInput, dep:Tuple[Pipeline,np.ndarray] = Depends(get status_code=201, content=msg ) +@app.post("/predict/file") +async def prediction_via_file_ops(payload:UploadFile, dep: Tuple[Pipeline,np.ndarray] = Depends(get_model)): + pipe, column_names = dep + expected_upload_cols = ['ra', 'dec', 'redshift', 'psfMag_r', 'u', 'g', 'r', 'i', 'z'] + + accepted_exts = [".csv"] + extension = Path(payload.filename).suffix + if extension.lower() not in accepted_exts: + raise HTTPException( + status_code=422, + detail=f"Uploaded data must be in '.csv' format, got {extension} instead" + ) + + try: + df = pd.read_csv(payload.file) + except Exception as e: + raise HTTPException(status_code=422, detail=f"Failed to parse CSV file tracking: {str(e)}") + df = upload_validator(df, expected_upload_cols) + + # Feature Engineering (Vectorized) + df['u_g_color'] = df['u'] - df['g'] + df['g_r_color'] = df['g'] - df['r'] + df['r_i_color'] = df['r'] - df['i'] + df['i_z_color'] = df['i'] - df['z'] + df = df.drop(columns=['u', 'g', 'r', 'i', 'z']) + + # Reorder columns to match the pipeline's expected order (excluding 'class') + model_features = [col for col in column_names.tolist() if col != "class"] + df = df[model_features] + + pred_label:list[float] = pipe.predict(df).tolist() + pred_proba:list[list[float]] = pipe.predict_proba(df).tolist() + + # Postprocessing + label_map = {0: "GALAXY", 1: "STAR", 2: "QSO"} + pred_label:list[str] = [label_map.get(pred) for pred in pred_label] + pred_proba = [[round(r, 3) for r in pred] for pred in pred_proba] + + msg = { + "message": "batch prediction successful", + "prediction": pred_label, + "probabilities": pred_proba + } + return JSONResponse( + status_code=201, content=msg + ) \ No newline at end of file diff --git a/app/static/script.js b/app/static/script.js index ccd2f7a..1f5e1f0 100644 --- a/app/static/script.js +++ b/app/static/script.js @@ -7,6 +7,7 @@ document.addEventListener("DOMContentLoaded", function() { // Initialize components initTabNavigation(); initPredictForm(); + initBatchPredictForm(); initKeyboardShortcuts(); }); @@ -65,6 +66,14 @@ function initPredictForm() { const formData = new FormData(form); const dataObj = Object.fromEntries(formData.entries()); + // Manual validation check to ensure no empty values + const requiredFields = ["ra", "dec", "redshift", "psfMag_r", "u", "g", "r", "i", "z"]; + for (let field of requiredFields) { + if (dataObj[field] === undefined || dataObj[field] === "") { + throw new Error("Please fill in all inputs before analyzing."); + } + } + // Convert numeric strings to numbers for (let key in dataObj) { if (!isNaN(dataObj[key]) && dataObj[key] !== "") { @@ -286,3 +295,179 @@ window.CosmoClassifier = { formatNumber, debounce }; + +/** + * Batch Prediction Form Handler + */ +function initBatchPredictForm() { + const fileInput = document.getElementById("batchFile"); + const dropzone = document.getElementById("fileUploadDropzone"); + const fileInfo = document.getElementById("fileInfo"); + const selectedFileName = document.getElementById("selectedFileName"); + const selectedFileSize = document.getElementById("selectedFileSize"); + const removeFileBtn = document.getElementById("removeFileBtn"); + const submitBtn = document.getElementById("batchPredictBtn"); + const form = document.getElementById("batchPredictForm"); + + if(!fileInput) return; + + const awaitingBatch = document.getElementById("awaiting-batch"); + const batchResult = document.getElementById("batch-result"); + const tbody = document.getElementById("batchTableBody"); + + const MAX_SIZE = 5 * 1024 * 1024; // 5MB + + function handleFile(file) { + if (!file) return; + + if (!file.name.toLowerCase().endsWith('.csv')) { + showToast("Only .csv files are allowed", "error"); + fileInput.value = ""; + return; + } + + if (file.size > MAX_SIZE) { + showToast("File size exceeds 5MB limit", "error"); + fileInput.value = ""; + return; + } + + selectedFileName.textContent = file.name; + selectedFileSize.textContent = (file.size / 1024 / 1024).toFixed(2) + " MB"; + + dropzone.classList.add("hidden"); + fileInfo.classList.remove("hidden"); + submitBtn.disabled = false; + } + + fileInput.addEventListener("change", (e) => { + handleFile(e.target.files[0]); + }); + + ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => { + dropzone.addEventListener(eventName, preventDefaults, false); + }); + + function preventDefaults(e) { + e.preventDefault(); + e.stopPropagation(); + } + + ['dragenter', 'dragover'].forEach(eventName => { + dropzone.addEventListener(eventName, () => { + dropzone.style.borderColor = "var(--accent-secondary)"; + }, false); + }); + + ['dragleave', 'drop'].forEach(eventName => { + dropzone.addEventListener(eventName, () => { + dropzone.style.borderColor = ""; + }, false); + }); + + dropzone.addEventListener('drop', (e) => { + let dt = e.dataTransfer; + let files = dt.files; + if (files.length) { + fileInput.files = files; + handleFile(files[0]); + } + }, false); + + removeFileBtn.addEventListener("click", () => { + fileInput.value = ""; + dropzone.classList.remove("hidden"); + fileInfo.classList.add("hidden"); + submitBtn.disabled = true; + }); + + form.addEventListener("submit", async (e) => { + e.preventDefault(); + + if (!fileInput.files[0]) return; + + setLoadingState(submitBtn, true); + awaitingBatch.style.display = 'none'; + batchResult.classList.add("hidden"); + + try { + const formData = new FormData(); + formData.append("payload", fileInput.files[0]); + + const response = await fetch("/predict/file", { + method: "POST", + body: formData + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.detail || errorData.message || "Batch prediction failed"); + } + + const data = await response.json(); + + renderBatchResults(data); + showToast('Batch classification complete!', 'success'); + + } catch (error) { + console.error(error); + showToast(`Error: ${error.message}`, 'error'); + awaitingBatch.style.display = 'flex'; + } finally { + setLoadingState(submitBtn, false); + } + }); + + let batchChart = null; + + function renderBatchResults(data) { + const preds = data.prediction; + const probs = data.probabilities; + + tbody.innerHTML = ""; + + let counts = { "GALAXY": 0, "STAR": 0, "QSO": 0 }; + + preds.forEach((pred, index) => { + counts[pred] = (counts[pred] || 0) + 1; + + const maxProb = (Math.max(...probs[index]) * 100).toFixed(1); + + const tr = document.createElement("tr"); + tr.innerHTML = ` + Row ${index + 1} + ${pred} + ${maxProb}% + `; + tbody.appendChild(tr); + }); + + batchResult.classList.remove("hidden"); + + // Render Chart + const ctx = document.getElementById('batchPieChart').getContext('2d'); + if (batchChart) batchChart.destroy(); + + batchChart = new Chart(ctx, { + type: 'doughnut', + data: { + labels: ['GALAXY', 'STAR', 'QSO'], + datasets: [{ + data: [counts['GALAXY'], counts['STAR'], counts['QSO']], + backgroundColor: ['#7000ff', '#00d4ff', '#ff6b6b'], + borderWidth: 0 + }] + }, + options: { + responsive: true, + maintainAspectRatio: false, + plugins: { + legend: { + position: 'bottom', + labels: { color: '#ffffff' } + } + } + } + }); + } +} diff --git a/app/static/style.css b/app/static/style.css index 58ff40c..3447a6f 100644 --- a/app/static/style.css +++ b/app/static/style.css @@ -1034,6 +1034,224 @@ input:focus ~ .input-focus-line { text-shadow: 0 0 10px rgba(0, 212, 255, 0.3); } +/* ======================================== + Batch UI Enhancements + ======================================== */ + +.file-upload-container { + border: 2px dashed rgba(255, 255, 255, 0.2); + border-radius: 12px; + padding: 2rem 1rem; + text-align: center; + transition: all 0.3s ease; + background: rgba(0, 0, 0, 0.2); + position: relative; + cursor: pointer; + margin-bottom: 1.5rem; +} + +.file-upload-container:hover { + border-color: var(--accent-secondary); + background: rgba(0, 212, 255, 0.05); +} + +.file-input-hidden { + position: absolute; + width: 0; + height: 0; + opacity: 0; + visibility: hidden; +} + +.file-upload-label { + display: flex; + flex-direction: column; + align-items: center; + cursor: pointer; + gap: 1rem; +} + +.upload-icon { + width: 48px; + height: 48px; + color: var(--accent-secondary); + background: rgba(0, 212, 255, 0.1); + border-radius: 50%; + padding: 12px; +} + +.upload-title { + font-size: 1.1rem; + font-weight: 600; + color: var(--text-primary); + display: block; +} + +.upload-subtitle { + font-size: 0.85rem; + color: var(--text-muted); +} + +.file-info { + display: flex; + align-items: center; + justify-content: center; + gap: 1rem; + margin-top: 1rem; + padding: 0.5rem 1rem; + background: rgba(0, 212, 255, 0.1); + border-radius: 8px; + border: 1px solid rgba(0, 212, 255, 0.2); +} + +.file-name { + font-family: var(--font-mono); + color: var(--text-primary); + font-size: 0.9rem; + max-width: 200px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.file-size { + color: var(--text-secondary); + font-size: 0.8rem; +} + +.btn-remove-file { + background: none; + border: none; + color: var(--error); + cursor: pointer; + font-size: 1.2rem; + line-height: 1; +} + +.btn-remove-file:hover { + color: #ff8a80; +} + +.batch-table-container { + max-height: 300px; + overflow-y: auto; + border-radius: 8px; + border: 1px solid var(--border-color); + background: rgba(0, 0, 0, 0.2); +} + +.batch-table { + width: 100%; + border-collapse: collapse; + text-align: left; + font-size: 0.85rem; +} + +.batch-table th, .batch-table td { + padding: 0.5rem; + border-bottom: 1px solid rgba(255, 255, 255, 0.05); +} + +.batch-badge { + font-size: 0.85rem; + padding: 2px 8px; + border-radius: 4px; + font-weight: 600; + color: #fff; +} + +.batch-galaxy { background: var(--galaxy); } +.batch-star { background: var(--star); } +.batch-qso { background: var(--qso); } + +.batch-table th { + background: rgba(255, 255, 255, 0.05); + color: var(--text-secondary); + font-weight: 600; + position: sticky; + top: 0; + z-index: 10; +} + +.mt-4 { + margin-top: 1rem; +} + +.chart-container { + width: 100%; + max-width: 200px; + margin: 0 auto; +} + +.batch-guidelines { + margin-top: 1.5rem; + padding: 1rem; + background: rgba(68, 138, 255, 0.05); + border: 1px solid rgba(68, 138, 255, 0.2); + border-radius: 12px; +} + +.guidelines-header { + display: flex; + align-items: center; + gap: 0.5rem; + margin-bottom: 0.75rem; + color: var(--info); +} + +.info-icon { + width: 18px; + height: 18px; +} + +.guidelines-header h4 { + font-size: 0.9rem; + font-weight: 600; +} + +.guideline-list { + list-style: none; + display: flex; + flex-direction: column; + gap: 0.5rem; +} + +.guideline-list li { + font-size: 0.8rem; + color: var(--text-secondary); + line-height: 1.4; + padding-left: 1.2rem; + position: relative; +} + +.guideline-list li::before { + content: '•'; + position: absolute; + left: 0; + color: var(--info); +} + +.guideline-list code { + background: rgba(255, 255, 255, 0.1); + padding: 0.1rem 0.3rem; + border-radius: 4px; + font-family: var(--font-mono); + font-size: 0.75rem; + color: var(--accent-secondary); +} + +.guideline-warning { + color: #ff8a80 !important; + background: rgba(255, 82, 82, 0.05); + padding: 0.5rem 0.5rem 0.5rem 1.2rem !important; + border-radius: 6px; + margin-top: 0.25rem; +} + +.guideline-warning::before { + display: none; +} + /* ======================================== Responsive Design ======================================== */ diff --git a/app/templates/index.html b/app/templates/index.html index b306408..d76cbbf 100644 --- a/app/templates/index.html +++ b/app/templates/index.html @@ -39,6 +39,12 @@ Predict + + + + +
+ +
+ + + +
+
+ + + + + +

Dataset Format Guidelines

+
+ +
+ + + +
+
+
+

Batch Results

+
+
+ +
+
+ + + + + +
+

Awaiting data...

+ Upload a CSV file to see predictions +
+ + + +
+
+
+ + + +
@@ -293,6 +415,8 @@

~100K Training Samples

+ + diff --git a/requirements.txt b/requirements.txt index 609c7d4..b53b9df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,8 @@ Flask==3.1.2 fonttools==4.61.1 gunicorn==23.0.0 h11==0.16.0 +httpcore==1.0.9 +httpx==0.28.1 idna==3.11 imbalanced-learn==0.14.1 itsdangerous==2.2.0 @@ -34,6 +36,7 @@ pydantic_core==2.41.5 pyparsing==3.3.1 python-dateutil==2.9.0.post0 python-dotenv==1.2.1 +python-multipart==0.0.24 pytz==2025.2 requests==2.32.5 scikit-learn==1.8.0