diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index f7fbfbd..25f9d2f 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -59,5 +59,12 @@ jobs: nohup uvicorn app.app:app & sleep 10 curl -I http://127.0.0.1:8000 + + # Test batch prediction endpoint + echo "ra,dec,redshift,psfMag_r,u,g,r,i,z" > dummy_test.csv + echo "0.1,0.2,0.3,1.0,2.0,3.0,4.0,5.0,6.0" >> dummy_test.csv + curl -X POST http://127.0.0.1:8000/predict/file -H "accept: application/json" -F "payload=@dummy_test.csv;type=text/csv" + rm dummy_test.csv + pkill -f uvicorn diff --git a/Datasets/sample_generator.py b/Datasets/sample_generator.py new file mode 100644 index 0000000..bc943b3 --- /dev/null +++ b/Datasets/sample_generator.py @@ -0,0 +1,50 @@ +""" +sample_generator.py + +This script is used to generate samples directly from the main dataset. +These samples are used to test the `/predict/file` route. +To run it, enter this in your command line: +``` +python -m Datasets.sample_generator.py +``` +""" + +import pandas as pd +from pathlib import Path + +DATASET_PATH = Path("Datasets","SDSS_DR18.csv") +OUTPUT_PATH = Path("Datasets","samples.csv") + +def preprocess(df: pd.DataFrame) -> pd.DataFrame: + drop_cols = ["objid", "specobjid", "run", "rerun", "camcol", + "field", "plate", "mjd", "fiberid"] + df = df.drop(columns=[c for c in drop_cols if c in df.columns]) + + df["class"] = df["class"].map({"GALAXY": 0, "STAR": 1, "QSO": 2}) + + df = df[["ra", "dec", "redshift","psfMag_r", "u", "g", "r", "i", "z", "class"]].copy() + + return df + + +def stratified_sample(df: pd.DataFrame, total_samples: int, class_col: str = "class", random_state: int = 42) -> pd.DataFrame: + class_counts = df[class_col].value_counts(normalize=True) + class_n = (class_counts * total_samples).round().astype(int) + + diff = total_samples - class_n.sum() + class_n.iloc[0] += diff + + return df.groupby(class_col, group_keys=False).apply( + lambda x: x.sample(n=min(class_n[x.name], len(x)), random_state=random_state) + ) + + +df_raw = pd.read_csv(DATASET_PATH) + +df_processed = preprocess(df_raw) + +sample = stratified_sample(df_processed, total_samples=100) + +sample = sample.drop(columns=["class"]) + +sample.to_csv(OUTPUT_PATH,index=False) \ No newline at end of file diff --git a/app/app.py b/app/app.py index d9569a8..377f312 100644 --- a/app/app.py +++ b/app/app.py @@ -1,17 +1,20 @@ -from fastapi import FastAPI, Depends, Request +from pydantic import BaseModel +from fastapi import FastAPI, Depends, Request, UploadFile from fastapi.responses import JSONResponse, FileResponse from fastapi.staticfiles import StaticFiles -from fastapi.exceptions import RequestValidationError +from fastapi.exceptions import RequestValidationError, HTTPException from typing import Tuple,List from sklearn.pipeline import Pipeline from models.fit import main from .schema.validation import UserInput from pathlib import Path from contextlib import asynccontextmanager -import joblib import numpy as np +import pandas as pd +import joblib import os +# Helper for loading and self-healing the artifacts def load_or_create_models() -> Tuple[Pipeline,np.ndarray]: model_path = Path("models","estimator.pkl") columns_path = Path("models","column_names.pkl") @@ -34,6 +37,36 @@ def load_or_create_models() -> Tuple[Pipeline,np.ndarray]: return pipe,column_names except Exception as e: raise RuntimeError(f"Artifacts could not be loaded: {e}") + +# Helper for performing feature engineering +def preprocess_data(value:BaseModel) -> dict: + # Preprocessing + value:dict = value.model_dump(mode="json") + kick = ["u","g","r","i","z"] + final_value = {key:val for key,val in value.items() if key not in kick} + final_value["u_g_color"] = safe_sub("u","g",value) + final_value["g_r_color"] = safe_sub("g","r",value) + final_value["r_i_color"] = safe_sub("r","i",value) + final_value["i_z_color"] = safe_sub("i","z",value) + + return final_value + +# Helper for validating user-provided csv files +def upload_validator(df:pd.DataFrame,col_names:List[str]) -> pd.DataFrame: + if df.columns.tolist() != col_names: + raise HTTPException( + status_code=422, detail="Uploaded csv file does not match the expected " \ + "columns or their order" + ) + + try: + df = df.astype(float) + except Exception as e: + raise HTTPException( + status_code=422, + detail="All values must be numeric (float-compatible)" + ) + return df @asynccontextmanager async def lifespan(app:FastAPI): @@ -91,15 +124,7 @@ def home(): def prediction_ops(value:UserInput, dep:Tuple[Pipeline,np.ndarray] = Depends(get_model)): pipe, column_names = dep column_names:List[str] = column_names.tolist() - - # Preprocessing - value:dict = value.model_dump(mode="json") - kick = ["u","g","r","i","z"] - final_value = {key:val for key,val in value.items() if key not in kick} - final_value["u_g_color"] = safe_sub("u","g",value) - final_value["g_r_color"] = safe_sub("g","r",value) - final_value["r_i_color"] = safe_sub("r","i",value) - final_value["i_z_color"] = safe_sub("i","z",value) + final_value:dict = preprocess_data(value) # Order Check and running prediction final_res = [] @@ -127,3 +152,49 @@ def prediction_ops(value:UserInput, dep:Tuple[Pipeline,np.ndarray] = Depends(get status_code=201, content=msg ) +@app.post("/predict/file") +async def prediction_via_file_ops(payload:UploadFile, dep: Tuple[Pipeline,np.ndarray] = Depends(get_model)): + pipe, column_names = dep + expected_upload_cols = ['ra', 'dec', 'redshift', 'psfMag_r', 'u', 'g', 'r', 'i', 'z'] + + accepted_exts = [".csv"] + extension = Path(payload.filename).suffix + if extension.lower() not in accepted_exts: + raise HTTPException( + status_code=422, + detail=f"Uploaded data must be in '.csv' format, got {extension} instead" + ) + + try: + df = pd.read_csv(payload.file) + except Exception as e: + raise HTTPException(status_code=422, detail=f"Failed to parse CSV file tracking: {str(e)}") + df = upload_validator(df, expected_upload_cols) + + # Feature Engineering (Vectorized) + df['u_g_color'] = df['u'] - df['g'] + df['g_r_color'] = df['g'] - df['r'] + df['r_i_color'] = df['r'] - df['i'] + df['i_z_color'] = df['i'] - df['z'] + df = df.drop(columns=['u', 'g', 'r', 'i', 'z']) + + # Reorder columns to match the pipeline's expected order (excluding 'class') + model_features = [col for col in column_names.tolist() if col != "class"] + df = df[model_features] + + pred_label:list[float] = pipe.predict(df).tolist() + pred_proba:list[list[float]] = pipe.predict_proba(df).tolist() + + # Postprocessing + label_map = {0: "GALAXY", 1: "STAR", 2: "QSO"} + pred_label:list[str] = [label_map.get(pred) for pred in pred_label] + pred_proba = [[round(r, 3) for r in pred] for pred in pred_proba] + + msg = { + "message": "batch prediction successful", + "prediction": pred_label, + "probabilities": pred_proba + } + return JSONResponse( + status_code=201, content=msg + ) \ No newline at end of file diff --git a/app/static/script.js b/app/static/script.js index ccd2f7a..1f5e1f0 100644 --- a/app/static/script.js +++ b/app/static/script.js @@ -7,6 +7,7 @@ document.addEventListener("DOMContentLoaded", function() { // Initialize components initTabNavigation(); initPredictForm(); + initBatchPredictForm(); initKeyboardShortcuts(); }); @@ -65,6 +66,14 @@ function initPredictForm() { const formData = new FormData(form); const dataObj = Object.fromEntries(formData.entries()); + // Manual validation check to ensure no empty values + const requiredFields = ["ra", "dec", "redshift", "psfMag_r", "u", "g", "r", "i", "z"]; + for (let field of requiredFields) { + if (dataObj[field] === undefined || dataObj[field] === "") { + throw new Error("Please fill in all inputs before analyzing."); + } + } + // Convert numeric strings to numbers for (let key in dataObj) { if (!isNaN(dataObj[key]) && dataObj[key] !== "") { @@ -286,3 +295,179 @@ window.CosmoClassifier = { formatNumber, debounce }; + +/** + * Batch Prediction Form Handler + */ +function initBatchPredictForm() { + const fileInput = document.getElementById("batchFile"); + const dropzone = document.getElementById("fileUploadDropzone"); + const fileInfo = document.getElementById("fileInfo"); + const selectedFileName = document.getElementById("selectedFileName"); + const selectedFileSize = document.getElementById("selectedFileSize"); + const removeFileBtn = document.getElementById("removeFileBtn"); + const submitBtn = document.getElementById("batchPredictBtn"); + const form = document.getElementById("batchPredictForm"); + + if(!fileInput) return; + + const awaitingBatch = document.getElementById("awaiting-batch"); + const batchResult = document.getElementById("batch-result"); + const tbody = document.getElementById("batchTableBody"); + + const MAX_SIZE = 5 * 1024 * 1024; // 5MB + + function handleFile(file) { + if (!file) return; + + if (!file.name.toLowerCase().endsWith('.csv')) { + showToast("Only .csv files are allowed", "error"); + fileInput.value = ""; + return; + } + + if (file.size > MAX_SIZE) { + showToast("File size exceeds 5MB limit", "error"); + fileInput.value = ""; + return; + } + + selectedFileName.textContent = file.name; + selectedFileSize.textContent = (file.size / 1024 / 1024).toFixed(2) + " MB"; + + dropzone.classList.add("hidden"); + fileInfo.classList.remove("hidden"); + submitBtn.disabled = false; + } + + fileInput.addEventListener("change", (e) => { + handleFile(e.target.files[0]); + }); + + ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => { + dropzone.addEventListener(eventName, preventDefaults, false); + }); + + function preventDefaults(e) { + e.preventDefault(); + e.stopPropagation(); + } + + ['dragenter', 'dragover'].forEach(eventName => { + dropzone.addEventListener(eventName, () => { + dropzone.style.borderColor = "var(--accent-secondary)"; + }, false); + }); + + ['dragleave', 'drop'].forEach(eventName => { + dropzone.addEventListener(eventName, () => { + dropzone.style.borderColor = ""; + }, false); + }); + + dropzone.addEventListener('drop', (e) => { + let dt = e.dataTransfer; + let files = dt.files; + if (files.length) { + fileInput.files = files; + handleFile(files[0]); + } + }, false); + + removeFileBtn.addEventListener("click", () => { + fileInput.value = ""; + dropzone.classList.remove("hidden"); + fileInfo.classList.add("hidden"); + submitBtn.disabled = true; + }); + + form.addEventListener("submit", async (e) => { + e.preventDefault(); + + if (!fileInput.files[0]) return; + + setLoadingState(submitBtn, true); + awaitingBatch.style.display = 'none'; + batchResult.classList.add("hidden"); + + try { + const formData = new FormData(); + formData.append("payload", fileInput.files[0]); + + const response = await fetch("/predict/file", { + method: "POST", + body: formData + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error(errorData.detail || errorData.message || "Batch prediction failed"); + } + + const data = await response.json(); + + renderBatchResults(data); + showToast('Batch classification complete!', 'success'); + + } catch (error) { + console.error(error); + showToast(`Error: ${error.message}`, 'error'); + awaitingBatch.style.display = 'flex'; + } finally { + setLoadingState(submitBtn, false); + } + }); + + let batchChart = null; + + function renderBatchResults(data) { + const preds = data.prediction; + const probs = data.probabilities; + + tbody.innerHTML = ""; + + let counts = { "GALAXY": 0, "STAR": 0, "QSO": 0 }; + + preds.forEach((pred, index) => { + counts[pred] = (counts[pred] || 0) + 1; + + const maxProb = (Math.max(...probs[index]) * 100).toFixed(1); + + const tr = document.createElement("tr"); + tr.innerHTML = ` +