diff --git a/models/checklist.py b/models/checklist.py index 4872506..c22ce52 100644 --- a/models/checklist.py +++ b/models/checklist.py @@ -1,17 +1,16 @@ -from sqlalchemy import Column, Integer, String, Boolean, TIMESTAMP, ForeignKey, text +from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, text from sqlalchemy.orm import relationship from .base import Base + class Checklist(Base): __tablename__ = "checklist" - id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) - title = Column(String(255), nullable=False) - is_clear = Column(Boolean, nullable=False, server_default=text("0")) - created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) - updated_at = Column(TIMESTAMP, nullable=False, - server_default=text("CURRENT_TIMESTAMP"), - onupdate=text("CURRENT_TIMESTAMP")) + id = Column(Integer, primary_key=True, autoincrement=True) + # DB 컬럼명은 u_id 이므로 명시적으로 매핑 + user_id = Column("u_id", Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) + # DB 컬럼명은 checklist_title 이므로 명시적으로 매핑 + title = Column("checklist_title", String(255), nullable=False) + is_clear = Column(Boolean, nullable=False, server_default=text("0")) user = relationship("User", back_populates="checklists") diff --git a/models/file.py b/models/file.py index 55eb21b..437736b 100644 --- a/models/file.py +++ b/models/file.py @@ -1,37 +1,21 @@ - -from sqlalchemy.orm import relationship - from sqlalchemy import Column, Integer, String, ForeignKey, TIMESTAMP, text from sqlalchemy.orm import relationship from .base import Base + class File(Base): __tablename__ = "file" - id = Column(Integer, primary_key=True, autoincrement=True) - - user_id = Column(Integer, ForeignKey('user.u_id', ondelete='CASCADE'), nullable=False) - folder_id = Column(Integer, ForeignKey('folder.id', ondelete='SET NULL'), nullable=True) - note_id = Column(Integer, ForeignKey('note.id', ondelete='CASCADE'), nullable=True) + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) + folder_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) + note_id = Column(Integer, ForeignKey("note.id", ondelete="CASCADE"), nullable=True) original_name = Column(String(255), nullable=False) - saved_path = Column(String(512), nullable=False) - content_type = Column(String(100), nullable=False) - created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) + saved_path = Column(String(512), nullable=False) + content_type = Column(String(100), nullable=False) + created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) - # ✅ 관계 - user = relationship("User", back_populates="files") + # relationships + user = relationship("User", back_populates="files") folder = relationship("Folder", back_populates="files") - note = relationship("Note", back_populates="files") - - user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) - folder_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) - note_id = Column(Integer, ForeignKey("note.id", ondelete="SET NULL"), nullable=True) - original_name = Column(String(255), nullable=False) - saved_path = Column(String(512), nullable=False) - content_type = Column(String(100), nullable=False) - created_at = Column(TIMESTAMP, nullable=False, 
server_default=text("CURRENT_TIMESTAMP")) - - # relations - user = relationship("User", back_populates="files") - note = relationship("Note", back_populates="files") - + note = relationship("Note", back_populates="files") diff --git a/models/folder.py b/models/folder.py index 92d6bf0..35949f4 100644 --- a/models/folder.py +++ b/models/folder.py @@ -2,30 +2,24 @@ from sqlalchemy.orm import relationship from .base import Base + class Folder(Base): __tablename__ = "folder" - id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) - name = Column(String(100), nullable=False) - parent_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) + name = Column(String(100), nullable=False) + parent_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) - updated_at = Column(TIMESTAMP, nullable=False, - - server_default=text('CURRENT_TIMESTAMP'), - onupdate=text('CURRENT_TIMESTAMP')) - - # ✅ 관계 - user = relationship("User", back_populates="folders") - parent = relationship("Folder", remote_side=[id], backref="children") - notes = relationship("Note", back_populates="folder", cascade="all, delete") - files = relationship("File", back_populates="folder", cascade="all, delete") - - server_default=text("CURRENT_TIMESTAMP"), - onupdate=text("CURRENT_TIMESTAMP")) - - # relations - user = relationship("User") - parent = relationship("Folder", remote_side=[id], backref="children") - notes = relationship("Note", back_populates="folder", cascade="all, delete") + updated_at = Column( + TIMESTAMP, + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + onupdate=text("CURRENT_TIMESTAMP"), + ) + # relationships + user = relationship("User", back_populates="folders") + parent = relationship("Folder", remote_side=[id], backref="children") + notes = relationship("Note", back_populates="folder", cascade="all, delete") + files = relationship("File", back_populates="folder", cascade="all, delete") diff --git a/models/note.py b/models/note.py index 680a666..66770c3 100644 --- a/models/note.py +++ b/models/note.py @@ -2,29 +2,26 @@ from sqlalchemy.orm import relationship from .base import Base + class Note(Base): __tablename__ = "note" - id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) - folder_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) - title = Column(String(255), nullable=False) - content = Column(Text) - is_favorite = Column(Boolean, nullable=False, server_default=text("FALSE")) + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) + folder_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) + title = Column(String(255), nullable=False) + content = Column(Text) + is_favorite = Column(Boolean, nullable=False, server_default=text("FALSE")) last_accessed = Column(TIMESTAMP, nullable=True) - created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) - updated_at = Column(TIMESTAMP, nullable=False, - - server_default=text('CURRENT_TIMESTAMP'), - 
onupdate=text('CURRENT_TIMESTAMP')) - - # ✅ 관계 - - server_default=text("CURRENT_TIMESTAMP"), - onupdate=text("CURRENT_TIMESTAMP")) - - # relations + created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) + updated_at = Column( + TIMESTAMP, + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + onupdate=text("CURRENT_TIMESTAMP"), + ) - user = relationship("User", back_populates="notes") + # relationships + user = relationship("User", back_populates="notes") folder = relationship("Folder", back_populates="notes") - files = relationship("File", back_populates="note", cascade="all, delete") + files = relationship("File", back_populates="note", cascade="all, delete") diff --git a/routers/checklist.py b/routers/checklist.py index 376d0a4..8e6f654 100644 --- a/routers/checklist.py +++ b/routers/checklist.py @@ -7,13 +7,27 @@ router = APIRouter(prefix="/api/v1/checklists", tags=["Checklists"]) +@router.get("") +def list_checklists(db: Session = Depends(get_db), user=Depends(get_current_user)): + """현재 사용자 체크리스트 전체 목록 반환(최신순).""" + items = ( + db.query(Checklist) + .filter(Checklist.user_id == user.u_id) + .order_by(Checklist.id.desc()) + .all() + ) + return [ + {"id": it.id, "title": it.title, "is_clear": int(bool(it.is_clear))} + for it in items + ] + @router.post("", status_code=status.HTTP_201_CREATED) def create_checklist(title: str, db: Session = Depends(get_db), user=Depends(get_current_user)): obj = Checklist(user_id=user.u_id, title=title, is_clear=False) db.add(obj) db.commit() db.refresh(obj) - return {"id": obj.id, "title": obj.title, "is_clear": obj.is_clear} + return {"id": obj.id, "title": obj.title, "is_clear": int(bool(obj.is_clear))} @router.patch("/{checklist_id}/clear") def set_clear_state(checklist_id: int, is_clear: bool, db: Session = Depends(get_db), user=Depends(get_current_user)): @@ -23,4 +37,4 @@ def set_clear_state(checklist_id: int, is_clear: bool, db: Session = Depends(get obj.is_clear = bool(is_clear) db.commit() db.refresh(obj) - return {"id": obj.id, "is_clear": obj.is_clear} + return {"id": obj.id, "is_clear": int(bool(obj.is_clear))} diff --git a/routers/file.py b/routers/file.py index 2b9d7ef..0590742 100644 --- a/routers/file.py +++ b/routers/file.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Optional, List -from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, status, Query, Response +from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, status, Query, Response, Request from fastapi.responses import FileResponse from sqlalchemy.orm import Session @@ -11,55 +11,17 @@ from models.file import File as FileModel from models.note import Note as NoteModel from utils.jwt_utils import get_current_user - -cuda_gpu -# 추가: 파일명 인코딩용 import urllib.parse -# ------------------------------- -# 1) EasyOCR 라이브러리 임포트 (GPU 모드 활성화) -# ------------------------------- -import easyocr -reader = easyocr.Reader(["ko", "en"], gpu=True) - -# ------------------------------- -# 2) Hugging Face TrOCR 모델용 파이프라인 (GPU 사용) -# ------------------------------- -from transformers import pipeline - -hf_trocr_printed = pipeline( - "image-to-text", - model="microsoft/trocr-base-printed", - device=0, - trust_remote_code=True -) -hf_trocr_handwritten = pipeline( - "image-to-text", - model="microsoft/trocr-base-handwritten", - device=0, - trust_remote_code=True -) -hf_trocr_small_printed = pipeline( - "image-to-text", - model="microsoft/trocr-small-printed", - device=0, - 
trust_remote_code=True -) -hf_trocr_large_printed = pipeline( - "image-to-text", - model="microsoft/trocr-large-printed", - device=0, - trust_remote_code=True -) - -BASE_UPLOAD_DIR = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "uploads" -) - -# 공통 OCR 파이프라인 -from utils.ocr import run_pipeline, detect_type +# 공통 OCR 파이프라인 (내부 유틸로 위임) — 안전 임포트 +try: + from utils.ocr import run_pipeline, detect_type +except ModuleNotFoundError: + import sys as _sys, os as _os + _ROOT = _os.path.abspath(_os.path.join(_os.path.dirname(__file__), "..")) + if _ROOT not in _sys.path: + _sys.path.insert(0, _ROOT) + from utils.ocr import run_pipeline, detect_type from schemas.file import OCRResponse # 허용 확장자 @@ -96,9 +58,9 @@ def run(cmd: list[str]): langs = None tess_ver = None if tesseract_ok: - tess_ver = run(["tesseract", "--version"]).splitlines()[0] if tesseract_ok else None + tess_ver = run(["tesseract", "--version"]).splitlines()[0] langs_out = run(["tesseract", "--list-langs"]) - langs = [l.strip() for l in langs_out.splitlines() if l and not l.lower().startswith("list of available")] if langs_out and not langs_out.startswith("ERR:") else None + langs = [l.strip() for l in langs_out.splitlines() if l and not l.lower().startswith("list of available")] return { "tesseract": tesseract_ok, @@ -120,7 +82,8 @@ async def upload_file( note_id: Optional[int] = Form(None), upload_file: UploadFile = File(...), db: Session = Depends(get_db), - current_user = Depends(get_current_user) + current_user = Depends(get_current_user), + request: Request = None, ): orig_filename: str = upload_file.filename or "unnamed" content_type: str = upload_file.content_type or "application/octet-stream" @@ -128,7 +91,6 @@ async def upload_file( user_dir = os.path.join(BASE_UPLOAD_DIR, str(current_user.u_id)) os.makedirs(user_dir, exist_ok=True) - saved_filename = orig_filename saved_path = os.path.join(user_dir, saved_filename) if os.path.exists(saved_path): @@ -143,9 +105,7 @@ async def upload_file( break counter += 1 - # 저장 - try: with open(saved_path, "wb") as buffer: content = await upload_file.read() @@ -153,7 +113,6 @@ async def upload_file( except Exception as e: raise HTTPException(status_code=500, detail=f"파일 저장 실패: {e}") - # note_id가 있으면 해당 노트 확인 note_obj = None if note_id is not None: @@ -166,7 +125,6 @@ async def upload_file( raise HTTPException(status_code=404, detail="해당 노트를 찾을 수 없습니다.") # DB 메타 기록 - new_file = FileModel( user_id=current_user.u_id, folder_id=None if note_id else folder_id, @@ -179,8 +137,12 @@ async def upload_file( db.commit() db.refresh(new_file) - base_url = os.getenv("BASE_API_URL", "http://localhost:8000") + # Prefer request base_url if available; fallback to env + base_url = (str(request.base_url).rstrip('/') if request else None) or os.getenv("BASE_API_URL", "http://localhost:8000") download_url = f"{base_url}/api/v1/files/download/{new_file.id}" + public_url = None + if content_type.startswith("image/"): + public_url = f"{base_url}/static/{current_user.u_id}/{urllib.parse.quote(saved_filename)}" # note_id가 있으면 노트 본문에 첨부 링크 삽입 if note_obj: @@ -190,7 +152,6 @@ async def upload_file( embed = f"\n\n[{new_file.original_name}]({download_url}) (PDF 보기)\n\n" else: embed = f"\n\n[{new_file.original_name}]({download_url})\n\n" - note_obj.content = (note_obj.content or "") + embed db.commit() db.refresh(note_obj) @@ -198,6 +159,7 @@ async def upload_file( return { "file_id": new_file.id, "url": download_url, + "public_url": public_url, "original_name": new_file.original_name, 
"folder_id": new_file.folder_id, "note_id": new_file.note_id, @@ -251,10 +213,6 @@ def download_file( if not os.path.exists(file_path): raise HTTPException(status_code=404, detail="서버에 파일이 존재하지 않습니다.") - # 원본 파일명 UTF-8 URL 인코딩 처리 - quoted_name = urllib.parse.quote(file_obj.original_name, safe='') - content_disposition = f"inline; filename*=UTF-8''{quoted_name}" - return FileResponse( path=file_path, media_type=file_obj.content_type, @@ -271,10 +229,6 @@ def download_file( "en": "eng", "english": "eng" } def normalize_langs(raw: str) -> str: - """ - 입력 예: 'koreng', 'kor,eng', 'ko+en', 'ko,en', 'korean+english' - 출력 예: 'kor+eng' - """ if not raw: return "kor+eng" s = raw.strip().lower().replace(" ", "") @@ -285,7 +239,6 @@ def normalize_langs(raw: str) -> str: norm: list[str] = [] for p in parts: norm.append(LANG_ALIAS.get(p, p)) - # 중복 제거(순서 보존) seen, out = set(), [] for p in norm: if p not in seen: @@ -302,84 +255,17 @@ async def ocr_and_create_note( file: Optional[UploadFile] = File(None, description="기본 업로드 필드명"), ocr_file: Optional[UploadFile] = File(None, description="과거 호환 업로드 필드명"), folder_id: Optional[int] = Form(None), - langs: str = Query("kor+eng", description="Tesseract 언어코드(유연 입력 허용: koreng, ko+en 등)"), - max_pages: int = Query(50, ge=1, le=500, description="최대 처리 페이지 수(기본 50)"), + langs: str = Query("kor+eng", description="언어코드"), + max_pages: int = Query(50, ge=1, le=500, description="최대 페이지 수"), db: Session = Depends(get_db), current_user = Depends(get_current_user) ): - - """ - • ocr_file: 이미지 파일(UploadFile) - • 1) EasyOCR로 기본 텍스트 추출 (GPU 모드) - • 2) TrOCR 4개 모델로 OCR 수행 (모두 GPU) - • 3) 가장 긴 결과를 최종 OCR 결과로 선택 - • 4) Note로 저장 및 결과 반환 - """ - - # 1) 이미지 로드 (PIL) - contents = await ocr_file.read() - try: - image = Image.open(io.BytesIO(contents)).convert("RGB") - except Exception as e: - raise HTTPException(status_code=400, detail=f"이미지 처리 실패: {e}") - - # 2) EasyOCR로 텍스트 추출 - try: - image_np = np.array(image) - easy_results = reader.readtext(image_np) # GPU 모드 사용 - easy_text = " ".join([res[1] for res in easy_results]) - except Exception: - easy_text = "" - - # 3) TrOCR 모델 4개로 OCR 수행 (모두 GPU input) - hf_texts: List[str] = [] - try: - out1 = hf_trocr_printed(image) - if isinstance(out1, list) and "generated_text" in out1[0]: - hf_texts.append(out1[0]["generated_text"].strip()) - - out2 = hf_trocr_handwritten(image) - if isinstance(out2, list) and "generated_text" in out2[0]: - hf_texts.append(out2[0]["generated_text"].strip()) - - out3 = hf_trocr_small_printed(image) - if isinstance(out3, list) and "generated_text" in out3[0]: - hf_texts.append(out3[0]["generated_text"].strip()) - - out4 = hf_trocr_large_printed(image) - if isinstance(out4, list) and "generated_text" in out4[0]: - hf_texts.append(out4[0]["generated_text"].strip()) - except Exception: - # TrOCR 중 오류 발생 시 무시하고 계속 진행 - pass - - # 4) 여러 OCR 결과 병합: 가장 긴 문자열을 최종 ocr_text로 선택 - candidates = [t for t in [easy_text] + hf_texts if t and t.strip()] - if not candidates: - raise HTTPException(status_code=500, detail="텍스트를 인식할 수 없습니다.") - - ocr_text = max(candidates, key=lambda s: len(s)) - - # 5) 새 노트 생성 및 DB에 저장 - try: - new_note = NoteModel( - user_id=current_user.u_id, - folder_id=folder_id, - title="OCR 결과", - content=ocr_text # **원본 OCR 텍스트만 저장** - - # 422 방지: 파일 필드명 유연 처리 upload = file or ocr_file if upload is None: - raise HTTPException( - status_code=400, - detail="업로드 파일이 필요합니다. 필드명은 'file' 또는 'ocr_file'을 사용하세요." 
- ) + raise HTTPException(status_code=400, detail="파일이 필요합니다.") filename = upload.filename or "uploaded" - mime = upload.content_type - - # 확장자 검사 + mime = upload.content_type or "application/octet-stream" _, ext = os.path.splitext(filename) ext = ext.lower() if ext and ext not in ALLOWED_ALL_EXTS: @@ -388,31 +274,14 @@ async def ocr_and_create_note( mime=mime, page_count=0, results=[], - warnings=[f"허용되지 않는 확장자({ext}). 허용: {sorted(ALLOWED_ALL_EXTS)}"], - note_id=None, - text=None, - - ) - - # 타입 판별 - ftype = detect_type(filename, mime) - if ftype == "unknown": - return OCRResponse( - filename=filename, - mime=mime, - page_count=0, - results=[], - warnings=["지원되지 않는 파일 형식입니다."], + warnings=[f"허용되지 않는 확장자({ext})."], note_id=None, text=None, ) data = await upload.read() - - # 언어코드 정규화 langs = normalize_langs(langs) - # 멀티엔진 앙상블 OCR 파이프라인 실행 pipe = run_pipeline( filename=filename, mime=mime, @@ -421,7 +290,6 @@ async def ocr_and_create_note( max_pages=max_pages, ) - # 페이지 텍스트 합치기 merged_text = "\n\n".join([ item.get("text", "") for item in (pipe.get("results") or []) if item.get("text") ]).strip() @@ -445,7 +313,6 @@ async def ocr_and_create_note( pipe["note_id"] = note_id pipe["text"] = merged_text or None - return pipe @@ -493,15 +360,12 @@ async def upload_audio_and_transcribe( NoteModel.id == note_id, NoteModel.user_id == user.u_id ).first() - if not note: raise HTTPException(status_code=404, detail="해당 노트를 찾을 수 없습니다.") - note.content = (note.content or "") + "\n\n" + transcript note.updated_at = datetime.utcnow() db.commit() db.refresh(note) - else: new_note = NoteModel( user_id=user.u_id, @@ -513,10 +377,7 @@ async def upload_audio_and_transcribe( db.commit() db.refresh(new_note) - return { - "message": "STT 및 노트 저장 완료", - "transcript": transcript - } + return {"message": "STT 및 노트 저장 완료", "transcript": transcript} @router.options("/ocr") diff --git a/routers/folder.py b/routers/folder.py index 960c06f..22ca762 100644 --- a/routers/folder.py +++ b/routers/folder.py @@ -1,11 +1,13 @@ # Backend/routers/folder.py -from fastapi import APIRouter, Depends, HTTPException, status +import os +from fastapi import APIRouter, Depends, HTTPException, status, Request from sqlalchemy.orm import Session from typing import List, Optional from db import get_db from models.folder import Folder from models.note import Note +from models.file import File as FileModel from schemas.folder import FolderCreate, FolderResponse, FolderUpdate from schemas.note import NoteResponse from utils.jwt_utils import get_current_user @@ -40,6 +42,7 @@ def get_all_descendant_folder_ids(db: Session, parent_id: int, user_id: int) -> summary="유저의 모든 폴더(트리 구조) 및 폴더별 노트 리스트 반환" ) def list_folders( + request: Request, db: Session = Depends(get_db), user = Depends(get_current_user) ): @@ -71,7 +74,49 @@ def list_folders( else: roots.append(f) - return roots + BASE_API_URL = os.getenv("BASE_API_URL") or str(request.base_url).rstrip('/') + + def serialize_note(note: Note) -> NoteResponse: + files = ( + db.query(FileModel) + .filter(FileModel.note_id == note.id, FileModel.user_id == note.user_id) + .order_by(FileModel.created_at.desc()) + .all() + ) + file_items = [ + { + "file_id": f.id, + "original_name": f.original_name, + "content_type": f.content_type, + "url": f"{BASE_API_URL}/api/v1/files/download/{f.id}", + "created_at": f.created_at, + } + for f in files + ] + return NoteResponse( + id=note.id, + user_id=note.user_id, + folder_id=note.folder_id, + title=note.title, + content=note.content, + is_favorite=bool(note.is_favorite), 
+ last_accessed=note.last_accessed, + created_at=note.created_at, + updated_at=note.updated_at, + files=file_items, + ) + + def to_folder_response(f: Folder) -> FolderResponse: + return FolderResponse( + id=f.id, + user_id=f.user_id, + name=f.name, + parent_id=f.parent_id, + children=[to_folder_response(c) for c in getattr(f, 'children', [])], + notes=[serialize_note(n) for n in folder_note_map.get(f.id, [])], + ) + + return [to_folder_response(r) for r in roots] @@ @router.post( diff --git a/routers/note.py b/routers/note.py index b2038c3..446e3a6 100644 --- a/routers/note.py +++ b/routers/note.py @@ -1,18 +1,21 @@ import os from dotenv import load_dotenv -from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query +from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query, Request from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from typing import List from datetime import datetime import traceback +import re +import json from db import get_db, SessionLocal from models.note import Note from models.file import File as FileModel from schemas.note import NoteCreate, NoteUpdate, NoteResponse, FavoriteUpdate, NoteFile from utils.jwt_utils import get_current_user -from utils.llm import stream_summary_with_langchain +from utils.llm import stream_summary_with_langchain, _strip_top_level_h1_outside_code, _hf_generate_once, _system_prompt load_dotenv() HF_TOKEN = os.getenv("HF_API_TOKEN") @@ -26,57 +29,124 @@ # ───────────────────────────────────────────── # Common: serialize Note → NoteResponse and populate files # ───────────────────────────────────────────── -def serialize_note(db: Session, note: Note) -> NoteResponse: +def serialize_note(db: Session, note: Note, base_url: str) -> NoteResponse: """ - Converts the Note ORM object into a NoteResponse, - looking up the Files linked via note_id and filling them into the files array. + Manual Note ORM → NoteResponse mapping. + Fills only the basic scalar fields directly and builds files from a separate query, + to avoid Pydantic failing when it tries to validate the ORM object through the note.files relationship. """ - # fetch the list of files attached to the note files = ( db.query(FileModel) .filter(FileModel.note_id == note.id, FileModel.user_id == note.user_id) .order_by(FileModel.created_at.desc()) .all() ) - - file_items: List[NoteFile] = [] - for f in files: - file_items.append( - NoteFile( - file_id=f.id, - original_name=f.original_name, - content_type=f.content_type, - url=f"{BASE_API_URL}/api/v1/files/download/{f.id}", - created_at=f.created_at, - ) + file_items: List[NoteFile] = [ + NoteFile( + file_id=f.id, + original_name=f.original_name, + content_type=f.content_type, + url=f"{base_url}/api/v1/files/download/{f.id}", + created_at=f.created_at, ) + for f in files + ] + + return NoteResponse( + id=note.id, + user_id=note.user_id, + folder_id=note.folder_id, + title=note.title, + content=note.content, + is_favorite=bool(note.is_favorite), + last_accessed=note.last_accessed, + created_at=note.created_at, + updated_at=note.updated_at, + files=file_items, + ) + + +def _fallback_extractive_summary(text: str) -> str: + """Simple extractive fallback: pick leading sentences and format as TL;DR + bullets.""" + if not text: + return "## TL;DR\n요약할 내용이 없습니다." + sents = re.split(r"(?<=[.!?。])\s+|\n+", text) + sents = [s.strip() for s in sents if s.strip()] + if not sents: + return "## TL;DR\n요약할 내용이 없습니다."
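+    # First sentence -> TL;DR (capped at 400 chars); the next five sentences -> bullets (capped at 200 chars each).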
+ tl = sents[0][:400] + bullets = [] + for s in sents[1:6]: + short = s[:200] + bullets.append(f"- {short}") + body = "\n\n## 핵심 요점\n" + "\n".join(bullets) if bullets else "" + return f"## TL;DR\n{tl}{body}" + + +def _is_summary_complete(s: str) -> bool: + """Heuristic: check presence of key sections and reasonable length.""" + if not s or not s.strip(): + return False + low = s.lower() + # require TL;DR or 핵심 요점 and some detail + if ('## tl;dr' in low or '## 핵심' in low or '## 핵심 요점' in low) and len(s) > 300: + return True + # if contains multiple section headers, consider complete + headers = len(re.findall(r"^##\s+", s, flags=re.M)) + if headers >= 2 and len(s) > 200: + return True + # otherwise likely incomplete + return False - # Pydantic 모델 생성 (from_attributes=True로 기본 필드 매핑) - data = NoteResponse.model_validate(note, from_attributes=True) - # files 필드 교체 - data.files = file_items - return data + +async def _ensure_completion(full: str, domain: str | None = None, length: str = 'long') -> str: + """If `full` looks truncated, attempt up to 3 continuation passes to complete it.""" + try: + for i in range(3): + if _is_summary_complete(full) and re.search(r"[\.\!\?]\s*$", full.strip()): + return full + # build continuation prompt + sys_prompt = _system_prompt(domain or 'general', phase='final', output_format='md', length=length) + cont_prompt = "The following summary appears incomplete. Continue and finish the summary without repeating previous text:\n\n" + full + "\n\nContinue:" + try: + cont = await _hf_generate_once(sys_prompt, cont_prompt, max_new_tokens=int(os.getenv('HF_MAX_NEW_TOKENS_LONG', '32000'))) + except Exception: + cont = '' + if cont and cont.strip(): + # append continuation + full = (full + "\n\n" + cont.strip()).strip() + else: + break + except Exception: + pass + return full # 1) 모든 노트 조회 @router.get("/notes", response_model=List[NoteResponse]) def list_notes( + request: Request, + q: str | None = Query(default=None, description="Optional search query (title or content)"), db: Session = Depends(get_db), user = Depends(get_current_user) ): - notes = ( - db.query(Note) - .filter(Note.user_id == user.u_id) - .order_by(Note.created_at.desc()) - .all() - ) + """List notes for the current user. If `q` is provided, filter by title or content (case-insensitive). 
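+    Example (illustrative): GET /notes?q=ocr matches any note whose title or content contains "ocr" via SQL ILIKE.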
+ """ + query = db.query(Note).filter(Note.user_id == user.u_id) + if q and q.strip(): + like = f"%{q.strip()}%" + query = query.filter((Note.title.ilike(like)) | (Note.content.ilike(like))) + + notes = query.order_by(Note.created_at.desc()).all() # 각 노트의 files도 채워 반환 - return [serialize_note(db, n) for n in notes] + base_url = os.getenv("BASE_API_URL") or str(request.base_url).rstrip('/') + return [serialize_note(db, n, base_url) for n in notes] # 2) 최근 접근한 노트 조회 (상위 10개) @router.get("/notes/recent", response_model=List[NoteResponse]) def recent_notes( + request: Request, db: Session = Depends(get_db), user = Depends(get_current_user) ): @@ -87,12 +157,14 @@ def recent_notes( .limit(10) .all() ) - return [serialize_note(db, n) for n in notes] + base_url = os.getenv("BASE_API_URL") or str(request.base_url).rstrip('/') + return [serialize_note(db, n, base_url) for n in notes] # 3) 노트 생성 @router.post("/notes", response_model=NoteResponse) def create_note( + request: Request, req: NoteCreate, db: Session = Depends(get_db), user = Depends(get_current_user) @@ -106,12 +178,14 @@ def create_note( db.add(note) db.commit() db.refresh(note) - return serialize_note(db, note) + base_url = os.getenv("BASE_API_URL") or str(request.base_url).rstrip('/') + return serialize_note(db, note, base_url) # 4) 노트 수정 (제목/내용/폴더) @router.patch("/notes/{note_id}", response_model=NoteResponse) def update_note( + request: Request, note_id: int, req: NoteUpdate, db: Session = Depends(get_db), @@ -133,12 +207,14 @@ def update_note( note.updated_at = datetime.utcnow() db.commit() db.refresh(note) - return serialize_note(db, note) + base_url = os.getenv("BASE_API_URL") or str(request.base_url).rstrip('/') + return serialize_note(db, note, base_url) # 5) 노트 단일 조회 (마지막 접근 시간 업데이트 포함) @router.get("/notes/{note_id}", response_model=NoteResponse) def get_note( + request: Request, note_id: int, db: Session = Depends(get_db), user = Depends(get_current_user) @@ -150,7 +226,8 @@ def get_note( note.last_accessed = datetime.utcnow() db.commit() db.refresh(note) - return serialize_note(db, note) + base_url = os.getenv("BASE_API_URL") or str(request.base_url).rstrip('/') + return serialize_note(db, note, base_url) # 6) 노트 삭제 @@ -172,6 +249,7 @@ def delete_note( # 7) 즐겨찾기 토글 @router.patch("/notes/{note_id}/favorite", response_model=NoteResponse) def toggle_favorite( + request: Request, note_id: int, req: FavoriteUpdate, db: Session = Depends(get_db), @@ -185,7 +263,8 @@ def toggle_favorite( note.updated_at = datetime.utcnow() db.commit() db.refresh(note) - return serialize_note(db, note) + base_url = os.getenv("BASE_API_URL") or str(request.base_url).rstrip('/') + return serialize_note(db, note, base_url) # ───────────────────────────────────────────── @@ -205,15 +284,85 @@ async def summarize_stream_langchain( if not note or not (note.content or "").strip(): raise HTTPException(status_code=404, detail="요약 대상 없음") + + async def event_gen(): parts = [] - async for sse in stream_summary_with_langchain(note.content, domain=domain, longdoc=longdoc): + # Default to a comprehensive (long) summary when called without explicit options + async for sse in stream_summary_with_langchain(note.content, domain=domain, longdoc=longdoc, length='long', tone='neutral', output_format='md'): parts.append(sse.removeprefix("data: ").strip()) yield sse.encode() full = "".join(parts).strip() + # attempt to complete if truncated + try: + full = await _ensure_completion(full, domain=domain, length='long') + except Exception: + pass + # If streamed output looks 
incomplete, attempt a single-shot completion pass + try: + if not _is_summary_complete(full): + try: + print('[summarize] partial output detected, performing completion pass') + sys_prompt = _system_prompt(domain or 'general', phase='final', output_format='md', length='long') + cont = await _hf_generate_once(sys_prompt, "Existing partial summary:\n\n" + full + "\n\nPlease expand and complete the summary, preserving facts and following the output format.", max_new_tokens=int(os.getenv('HF_MAX_NEW_TOKENS_LONG', '20000'))) + if cont and cont.strip(): + full = (full + "\n\n" + cont.strip()).strip() + print('[summarize] completion pass appended, new length=', len(full)) + except Exception as e: + print('[summarize] completion pass failed:', e) + except Exception: + pass + # If model produced empty output, fall back to a simple extractive summary + if not (full or "").strip(): + try: + sents = re.split(r"(?<=[.!?。])\s+|\n+", note.content or "") + sents = [p.strip() for p in sents if p.strip()] + head = sents[:6] + tl = head[0] if head else (note.content or "")[:200] + bullets = [f"- {p}" for p in head[1:5]] + fb = "## TL;DR\n" + tl + "\n\n## 핵심 요점\n" + "\n".join(bullets) + full = fb + except Exception: + full = (note.content or "")[:800] + try: + print(f"[summarize] generated full length={len(full)} preview={repr(full[:200])}") + except Exception: + pass + # Remove local temp file paths (e.g. macOS /var/... or file://...) which shouldn't be persisted + try: + # remove explicit file://... patterns + full = re.sub(r"file://\S+", "", full) + # remove absolute tmp paths like /var/... (up to whitespace or closing paren) + full = re.sub(r"/var/[^\s)]+", "", full) + # remove parenthesis-wrapped local paths in markdown images: ![alt](/path/to/file.png) + full = re.sub(r"!\[([^\]]*)\]\([^)]*(/var/[^)\s]+)[)]", r"![\1]()", full) + except Exception: + pass + # Strip any top-level H1 headings that the model may have added (outside code fences) + try: + full = _strip_top_level_h1_outside_code(full) + except Exception: + # fallback: naive removal of a single leading H1 + full = re.sub(r"^\s*#\s.*?\n+", "", full, count=1) + # Ensure non-empty summary; if model produced nothing, use extractive fallback + if not (full or "").strip(): + try: + full = _fallback_extractive_summary(note.content) + print(f"[summarize] fallback summary used length={len(full)}") + except Exception: + full = (note.content or '')[:800] + if full: - # Create a new summary note in the same folder with title '요약' - title = (note.title or "").strip() + "요약" + # Create a new summary note in the same folder with title ' — 요약' + title = (note.title or "").strip() + " — 요약" if len(title) > 255: title = title[:255] new_note = Note( @@ -226,8 +375,42 @@ db.commit() db.refresh(new_note) try: - # Optional: notify created note id - yield f"data: SUMMARY_NOTE_ID:{new_note.id}\n\n".encode() + # log created summary id and content preview for debugging + print(f"[summarize] created summary note id={new_note.id} for note_id={note_id}") + try: + print("[summarize] saved content length=", len(new_note.content or "")) + print("[summarize] saved content preview=", repr((new_note.content or "")[:400])) + except Exception: + pass + except
Exception: + pass + # normal streaming path: notify created note via SSE + try: + # notify created note: include serialized note JSON so client can render immediately + base_url = os.getenv("BASE_API_URL") or BASE_API_URL + note_obj = serialize_note(db, new_note, base_url) + payload = {"summary_note": note_obj.dict()} + yield f"data: {json.dumps(payload, default=str)}\n\n".encode() + except Exception: + # fallback to ID-only message + try: + yield f"data: SUMMARY_NOTE_ID:{new_note.id}\n\n".encode() + except Exception: + pass + else: + # As an extra fallback, aggregate streamed parts (if any) to ensure coverage + try: + agg = "\n\n".join(parts) if parts else (note.content or '')[:4000] + fallback_full = "## Aggregated streamed parts\n\n" + agg + title = (note.title or "").strip() + " — 요약" + new_note2 = Note(user_id=user.u_id, folder_id=note.folder_id, title=title, content=fallback_full) + db.add(new_note2) + db.commit() + db.refresh(new_note2) + try: + yield f"data: SUMMARY_NOTE_ID:{new_note2.id}\n\n".encode() + except Exception: + pass except Exception: pass @@ -236,3 +419,127 @@ async def event_gen(): media_type="text/event-stream", headers={"Cache-Control": "no-cache"} ) + + + +@router.post("/notes/{note_id}/generate-quiz") +def generate_quiz( + note_id: int, + count: int = Query(default=5, ge=1, le=20), + db: Session = Depends(get_db), + user = Depends(get_current_user) +): + """간단한 규칙 기반 퀴즈 생성(대형 모델 없이 동작).""" + note = db.query(Note).filter(Note.id == note_id, Note.user_id == user.u_id).first() + if not note or not (note.content or "").strip(): + raise HTTPException(status_code=404, detail="퀴즈를 생성할 노트가 없습니다") + + text = (note.content or "").strip() + # 문장 단위 분할 + import re, random + sents = re.split(r"(?<=[.!?。])\s+|\n+", text) + sents = [s.strip() for s in sents if len(s.strip()) >= 8] + random.seed(note_id) + random.shuffle(sents) + + quizzes = [] + for s in sents: + if len(quizzes) >= count: + break + # 공백 기준 토큰화 후, 길이 4 이상인 토큰을 빈칸으로 + toks = s.split() + cand = [i for i, t in enumerate(toks) if len(re.sub(r"\W+", "", t)) >= 4] + if not cand: + continue + idx = cand[0] + answer = re.sub(r"^[\W_]+|[\W_]+$", "", toks[idx]) + toks[idx] = "_____" + q = " ".join(toks) + quizzes.append({ + "type": "cloze", + "question": q, + "answer": answer, + "source": s, + }) + + # 보강: 부족하면 참/거짓 생성 + i = 0 + while len(quizzes) < count and i < len(sents): + stmt = sents[i] + i += 1 + if len(stmt) < 12: + continue + false_stmt = stmt.replace("이다", "아니다").replace("다.", "가 아니다.") + quizzes.append({ + "type": "boolean", + "question": stmt, + "answer": True, + }) + if len(quizzes) >= count: + break + quizzes.append({ + "type": "boolean", + "question": false_stmt, + "answer": False, + }) + + return {"note_id": note.id, "count": len(quizzes), "items": quizzes} + + +# Convenience synchronous summarization endpoint (returns created note JSON). 
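+# Example invocation (sketch; the /api/v1 mount prefix, port, and note id are assumptions):
+#   curl -X POST "http://localhost:8000/api/v1/notes/42/summarize_sync?domain=meeting&longdoc=true" \
+#        -H "Authorization: Bearer <token>"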
+@router.post("/notes/{note_id}/summarize_sync", response_model=NoteResponse) +async def summarize_sync( + note_id: int, + domain: str | None = Query(default=None, description="meeting | code | paper | general | auto(None)"), + longdoc: bool = Query(default=True, description="Enable long-document map→reduce"), + db: Session = Depends(get_db), + user = Depends(get_current_user) +): + note = db.query(Note).filter(Note.id == note_id, Note.user_id == user.u_id).first() + if not note or not (note.content or "").strip(): + raise HTTPException(status_code=404, detail="요약 대상 없음") + + parts = [] + async for sse in stream_summary_with_langchain(note.content, domain=domain, longdoc=longdoc, length='long', tone='neutral', output_format='md'): + parts.append(sse.removeprefix("data: ").strip()) + full = "".join(parts).strip() + + # sanitize local paths and strip top-level H1 + try: + full = re.sub(r"file://\S+", "", full) + full = re.sub(r"/var/[^\s)]+", "", full) + full = _strip_top_level_h1_outside_code(full) + except Exception: + try: + full = re.sub(r"^\s*#\s.*?\n+", "", full, count=1) + except Exception: + pass + + # If model produced empty output, use extractive fallback + if not (full or "").strip(): + try: + full = _fallback_extractive_summary(note.content) + print(f"[summarize_sync] fallback used length={len(full)}") + except Exception: + full = (note.content or '')[:800] + + title = (note.title or "").strip() + " — 요약" + if len(title) > 255: + title = title[:255] + new_note = Note( + user_id=user.u_id, + folder_id=note.folder_id, + title=title, + content=full, + ) + db.add(new_note) + db.commit() + db.refresh(new_note) + try: + print(f"[summarize_sync] created summary note id={new_note.id} for note_id={note_id}") + print("[summarize_sync] saved content length=", len(new_note.content or "")) + print("[summarize_sync] saved content preview=", repr((new_note.content or "")[:400])) + except Exception: + pass + base_url = os.getenv("BASE_API_URL") or "http://localhost:8000" + return serialize_note(db, new_note, base_url) diff --git a/routers/note_sync.py b/routers/note_sync.py new file mode 100644 index 0000000..38416cf --- /dev/null +++ b/routers/note_sync.py @@ -0,0 +1,117 @@ +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session +from typing import Optional +import os, re + +from db import get_db +from models.note import Note +from schemas.note import NoteResponse +from utils.jwt_utils import get_current_user +from utils.llm import ( + stream_summary_with_langchain, + _strip_top_level_h1_outside_code, + _hf_generate_once, + _system_prompt, + _chunk_text, + _compose_user_payload, +) + +router = APIRouter(prefix="/api/v1", tags=["Notes"]) + + +@router.post("/notes/{note_id}/summarize_sync", response_model=NoteResponse) +async def summarize_sync( + note_id: int, + domain: Optional[str] = Query(default=None, description="meeting | code | paper | general | auto(None)"), + longdoc: bool = Query(default=True), + db: Session = Depends(get_db), + user = Depends(get_current_user) +): + note = db.query(Note).filter(Note.id == note_id, Note.user_id == user.u_id).first() + if not note or not (note.content or "").strip(): + raise HTTPException(status_code=404, detail="요약 대상 없음") + + # Use map->reduce single-shot generation to avoid streaming truncation. 
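+    # Flow sketch (helpers as defined in utils/llm.py):
+    #   chunks   = _chunk_text(content, ...)               # overlapping character windows
+    #   partials = one short _hf_generate_once(map_sys, ...) summary per chunk (map)
+    #   full     = one long _hf_generate_once(reduce_sys, joined_partials) pass (reduce)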
+ # 1) chunk text + chunks = _chunk_text(note.content or "", chunk_chars=int(os.getenv('SUMMARY_CHUNK_CHARS','20000')), overlap=int(os.getenv('SUMMARY_CHUNK_OVERLAP','2000'))) + map_sys = _system_prompt(domain or 'general', phase='map', output_format='md', length='long') + partials = [] + for idx, ch in enumerate(chunks, 1): + try: + map_input = _compose_user_payload(ch, "", "md", length='short', tone='neutral') + part = await _hf_generate_once(map_sys, map_input, max_new_tokens=int(os.getenv('HF_MAP_MAX_NEW_TOKENS','12000'))) + except Exception: + part = (ch or '')[:800] + partials.append(f"[Chunk {idx}]\n{part.strip()}") + + reduce_text = "\n\n".join(partials) + reduce_sys = _system_prompt(domain or 'general', phase='reduce', output_format='md', length='long') + reduce_input = _compose_user_payload(reduce_text, "", "md", length='long', tone='neutral') + try: + full = await _hf_generate_once(reduce_sys, reduce_input, max_new_tokens=int(os.getenv('HF_MAX_NEW_TOKENS_LONG','32000'))) + except Exception: + full = (note.content or '')[:4000] + # If partial/short, try a completion pass + try: + # local completeness heuristic + def is_complete(s: str) -> bool: + if not s or not s.strip(): + return False + low = s.lower() + if ('## tl;dr' in low or '## 핵심' in low or '## 핵심 요점' in low) and len(s) > 300: + return True + headers = len(re.findall(r"^##\s+", s, flags=re.M)) + return headers >= 2 and len(s) > 200 + + if not is_complete(full): + sys_prompt = _system_prompt(domain or 'general', phase='final', output_format='md', length='long') + cont = await _hf_generate_once(sys_prompt, "Existing partial summary:\n\n" + (full or "") + "\n\nPlease expand and complete the summary, preserving facts and following the output format.", max_new_tokens=int(os.getenv('HF_MAX_NEW_TOKENS_LONG', '20000'))) + if cont and cont.strip(): + full = (full + "\n\n" + cont.strip()).strip() + except Exception: + pass + # sanitize + try: + full = re.sub(r"file://\S+", "", full) + full = re.sub(r"/var/[^\s)]+", "", full) + except Exception: + pass + try: + full = _strip_top_level_h1_outside_code(full) + except Exception: + full = re.sub(r"^\s*#\s.*?\n+", "", full, count=1) + + # If still incomplete or suspiciously short, fall back to aggregated chunk summaries to ensure coverage + try: + if not is_complete(full) or len(full) < 300: + agg = "\n\n".join(partials) + full = "## Aggregated chunk summaries\n\n" + agg + print('[summarize_sync] using aggregated chunk summaries, length=', len(full)) + except Exception: + pass + + title = (note.title or "").strip() + " — 요약" + if len(title) > 255: + title = title[:255] + new_note = Note( + user_id=user.u_id, + folder_id=note.folder_id, + title=title, + content=full, + ) + db.add(new_note) + db.commit() + db.refresh(new_note) + base_url = os.getenv("BASE_API_URL") or "http://localhost:8000" + return NoteResponse( + id=new_note.id, + user_id=new_note.user_id, + folder_id=new_note.folder_id, + title=new_note.title, + content=new_note.content, + is_favorite=bool(new_note.is_favorite), + last_accessed=new_note.last_accessed, + created_at=new_note.created_at, + updated_at=new_note.updated_at, + files=[], + ) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..960b82a --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,8 @@ +"""Utility package for backend. + +Exposes submodules like `ocr`, `jwt_utils`, etc.
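+Typical import (illustrative): `from utils.ocr import run_pipeline, detect_type`.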
+""" + +# Re-export common names if needed +# from .ocr import run_pipeline, detect_type # optional convenience + diff --git a/utils/llm.py b/utils/llm.py index fb39e40..5406ebc 100644 --- a/utils/llm.py +++ b/utils/llm.py @@ -12,7 +12,8 @@ # =============== 필터: 사고과정 유사 문장 =============== _THOUGHT_PAT = re.compile( - r"^\s*(okay|let\s*me|i\s*need\s*to|first[, ]|then[, ]|next[, ]|in summary|먼저|그\s*다음|요약하면)", + # only filter a few clear English meta-intro phrases to avoid removing valid Korean sentences + r"^\s*(okay|let\s+me|i\s+need\s+to|in summary)\b", re.I, ) @@ -71,6 +72,8 @@ async def stream_summary_with_langchain( # [CHANGED] 출력/보강 옵션 output_format: str = "md", # "md" | "html" augment_web: bool = False, # True면 위키 요약 보강 + length: str = "medium", # short | medium | long + tone: str = "neutral", # neutral | formal | casual | concise ): """ 한국어/영어 자동 인지 후 '동일 언어'로 요약을 스트리밍합니다. @@ -97,48 +100,83 @@ async def stream_summary_with_langchain( enable_long = longdoc and len(text or "") > int(os.getenv("SUMMARY_LONGDOC_CHAR_LIMIT", "3500")) backend = os.getenv("SUMMARY_BACKEND", "hf").lower() + + # choose token budget by requested length + # Increase defaults to allow richer, more complete summaries. Can be overridden via env vars. + # Aggressively increase defaults to allow very long, comprehensive summaries. + # These can still be tuned via environment variables if needed. + # Aggressively increase defaults to allow very long, comprehensive summaries. + # These can still be tuned via environment variables if needed. + if length == 'short': + token_budget = int(os.getenv("HF_MAX_NEW_TOKENS_SHORT", "8000")) + elif length == 'medium': + token_budget = int(os.getenv("HF_MAX_NEW_TOKENS_MEDIUM", "16000")) + else: + token_budget = int(os.getenv("HF_MAX_NEW_TOKENS_LONG", "32000")) + if not enable_long: - sys_txt = _system_prompt(dom, phase="final", output_format=output_format) - user_payload = _compose_user_payload(text, extra_context, output_format) # [CHANGED] - if backend == "ollama": - async for s in _stream_with_ollama(user_payload, system_text=sys_txt, output_format=output_format): - yield s - else: - async for s in _stream_with_hf(user_payload, system_text=sys_txt, output_format=output_format): - yield s + sys_txt = _system_prompt(dom, phase="final", output_format=output_format, length=length) + user_payload = _compose_user_payload(text, extra_context, output_format, length=length, tone=tone) # [CHANGED] + # Temporarily set HF token budget env so downstream generator respects it + old_budget = os.environ.get('HF_MAX_NEW_TOKENS') + os.environ['HF_MAX_NEW_TOKENS'] = str(token_budget) + try: + if backend == "ollama": + async for s in _stream_with_ollama(user_payload, system_text=sys_txt, output_format=output_format): + yield s + else: + async for s in _stream_with_hf(user_payload, system_text=sys_txt, output_format=output_format): + yield s + finally: + if old_budget is None: + os.environ.pop('HF_MAX_NEW_TOKENS', None) + else: + os.environ['HF_MAX_NEW_TOKENS'] = old_budget return # Long-doc: Map (chunk summaries) → Reduce (final synthesis streamed) chunks = _chunk_text( text, - chunk_chars=int(os.getenv("SUMMARY_CHUNK_CHARS", "2000")), - overlap=int(os.getenv("SUMMARY_CHUNK_OVERLAP", "200")), + chunk_chars=int(os.getenv("SUMMARY_CHUNK_CHARS", "20000")), + overlap=int(os.getenv("SUMMARY_CHUNK_OVERLAP", "2000")), ) - map_sys = _system_prompt(dom, phase="map", output_format=output_format) + map_sys = _system_prompt(dom, phase="map", output_format=output_format, length=length) partials: list[str] = [] 
for idx, ch in enumerate(chunks, 1): try: - map_input = _compose_user_payload(ch, "", output_format) # [CHANGED] - part = await _hf_generate_once(map_sys, map_input, max_new_tokens=int(os.getenv("HF_MAP_MAX_NEW_TOKENS", "220"))) + map_input = _compose_user_payload(ch, "", output_format, length=length, tone=tone) # [CHANGED] + part = await _hf_generate_once(map_sys, map_input, max_new_tokens=int(os.getenv("HF_MAP_MAX_NEW_TOKENS", "12000"))) except Exception: part = ch[:500] partials.append(f"[Chunk {idx}]\n{part.strip()}") reduce_text = "\n\n".join(partials) - reduce_sys = _system_prompt(dom, phase="reduce", output_format=output_format) - reduce_input = _compose_user_payload(reduce_text, extra_context, output_format) # [CHANGED] + reduce_sys = _system_prompt(dom, phase="reduce", output_format=output_format, length=length) + reduce_input = _compose_user_payload(reduce_text, extra_context, output_format, length=length, tone=tone) # [CHANGED] - if backend == "ollama": - async for s in _stream_with_ollama(reduce_input, system_text=reduce_sys, output_format=output_format): - yield s - else: - async for s in _stream_with_hf(reduce_input, system_text=reduce_sys, output_format=output_format): - yield s + # For reduce/final stage, also apply token budget + old_budget = os.environ.get('HF_MAX_NEW_TOKENS') + os.environ['HF_MAX_NEW_TOKENS'] = str(token_budget) + try: + if backend == "ollama": + async for s in _stream_with_ollama(reduce_input, system_text=reduce_sys, output_format=output_format): + yield s + else: + async for s in _stream_with_hf(reduce_input, system_text=reduce_sys, output_format=output_format): + yield s + finally: + if old_budget is None: + os.environ.pop('HF_MAX_NEW_TOKENS', None) + else: + os.environ['HF_MAX_NEW_TOKENS'] = old_budget # =============== 도메인/언어 감지 =============== def _detect_domain(t: str) -> str: s = (t or "").lower() + # lecture / slides signals + if re.search(r"\blecture\b|강의|슬라이드|ppt|slide|강의자료|강의록", s): + return "lecture" # code-like signals if re.search(r"\b(def |class |import |#include|public\s+class|function\s|=>|:=)", s) or re.search(r"```|\bdiff --git\b|\bcommit\b", s): return "code" @@ -160,16 +198,18 @@ def _detect_lang(t: str) -> str: # =============== 시스템 프롬프트 =============== # [CHANGED] 출력 포맷(MD/HTML) 지원 + 마크다운 간격 규칙 + 도메인별 포함 요소 힌트 -def _system_prompt(domain: str, phase: str = "final", output_format: str = "md") -> str: +def _system_prompt(domain: str, phase: str = "final", output_format: str = "md", length: str = "medium") -> str: # phase: map | reduce | final fmt = output_format.lower() base_rules = ( - "역할: 너는 사실 보존에 강한 전문 요약가다. 입력 텍스트의 언어를 감지하고, 반드시 동일한 언어로 작성한다. " - "핵심 정보(주요 주장/결론, 인물·기관·수치·날짜·지표·범위, 원인↔결과·조건·한계)를 빠짐없이 담되 군더더기와 반복을 제거한다. " - "추정·가치판단·조언은 덧붙이지 않는다. 사고과정(Chain-of-Thought), 단계 나열, 메타 코멘트는 출력하지 않는다." + "역할: 너는 사실 보존에 강한 전문 요약가다. 입력 텍스트의 언어(Korean/English)를 감지하고, 반드시 동일한 언어로 작성한다. " + "요약의 목적은 읽기 쉬운, GPT 스타일의 명확한 결과물을 만드는 것이다. 핵심 정보(주요 주장/결론, 인물·기관·수치·날짜·지표·범위, 원인↔결과·조건·한계)를 빠짐없이 담되 군더더기와 반복을 제거한다. " + "추정·가치판단·조언은 입력에서 명시적 근거가 없으면 추가하지 말고, 사고과정(Chain-of-Thought)이나 중간 추론은 출력하지 마라." ) if domain == "meeting": include_hint = "결정 사항, 책임자/기한이 명시된 액션, 남은 이슈, 다음 단계" + elif domain == "lecture": + include_hint = "핵심 개념, 주요 정의/공식, 예제/응용, 학습 포인트(요약된 학습 목표), 참고 자료" elif domain == "code": include_hint = "변경 목적/범위, 주요 API/모듈 영향, 호환성/리스크, 마이그레이션 포인트" elif domain == "paper": @@ -179,10 +219,9 @@ def _system_prompt(domain: str, phase: str = "final", output_format: str = "md") if fmt == "md": format_rule = ( - "출력 형식: Markdown. 
다음 섹션을 사용하라 — " - "# 제목, ## 개요, ## 핵심 요점(불릿 3–7개), ## 상세 설명(문단 분리), " - "## 용어 정리(필요시), ## 한계/주의, ## 할 일(있다면), ## 참고/추가자료(있다면). " - "각 섹션 제목만 출력하고, 불필요한 프리앰블은 쓰지 마라." + "출력 형식: Markdown. 반드시 다음 섹션으로 구성하라(필요시 일부 생략 가능): " + "## TL;DR, ## 핵심 요점(불릿 3–8개), ## 상세 설명(문단), ## 용어 정리(선택), ## 한계/주의, ## 할 일(액션), ## 참고(선택). " + "절대 H1('# ')로 시작하지 말고, 불필요한 전언/사고과정/추론 과정을 출력하지 마라." ) else: format_rule = ( @@ -191,7 +230,12 @@ def _system_prompt(domain: str, phase: str = "final", output_format: str = "md") "
<h2>용어 정리</h2>, <h2>한계/주의</h2>, <h2>할 일</h2>, <h2>참고/추가자료</h2>
의 순서." ) - length_rule = "분량: 원문 대비 약 15–30%. 각 문단은 2–5문장." + if length == 'long': + length_rule = "분량: 충분히 상세하게, 원문 전반의 주요 포인트·예시·숫자·결론을 모두 포함하라(원문 대비 30–70% 분량 권장, 또는 토큰 예산 한도 내 최대로)." + elif length == 'short': + length_rule = "분량: 한두 문장 TL;DR 중심(간결)." + else: + length_rule = "분량: 원문 대비 약 15–30%. 각 문단은 2–5문장." # [CHANGED] 마크다운 간격 규칙 추가 md_spacing_rule = ( "마크다운 간격 규칙: 모든 헤더(#, ##, ### 등) 뒤에는 한 칸 공백을 두고, 헤더의 앞뒤에는 빈 줄 1줄을 둔다. " @@ -205,11 +249,53 @@ def _system_prompt(domain: str, phase: str = "final", output_format: str = "md") if phase == "map": scope = "이 청크만 대상으로 섹션 골격을 간략히 채워라. 과도한 요약 금지." elif phase == "reduce": - scope = "아래 청크 요약들을 중복 없이 통합해 일관된 섹션 구성을 완성하라. 흐름(원인→과정→결과)을 유지." + scope = "아래 청크 요약들을 중복 없이 통합해 일관된 섹션 구성을 완성하라. 흐름(원인→과정→결과)을 유지. 최종 요약은 누락이 없도록 모든 청크의 핵심을 포함하라." else: - scope = "전체 텍스트를 위 섹션 구조에 맞춰 응집력 있게 작성하라. 첫 줄부터 본문만 출력." + # Do not force a top-level H1; many clients render H1 differently. + scope = "전체 텍스트를 위 섹션 구조에 맞춰 응집력 있게 작성하라. 출력은 반드시 Markdown만 사용하라(원시 HTML 금지). 최상단 제목(H1)은 생략하거나 필요시만 사용하고, 주요 요약은 '## TL;DR' 또는 '## 핵심 요점'로 시작하라." + + # 명시적 예시 추가: (Korean short example) + example = ( + "\n\n--- 예시 출력 (한국어, medium) ---\n" + "## TL;DR\n" + "프로젝트 A의 기능 X가 2주 지연되어 배포 일정이 11/10로 변경됨. 주요 리스크는 외부 API 응답 지연.\n\n" + "## 핵심 요점\n" + "- 기능 X 구현 지연: 2주\n" + "- 배포 일정: 11/10\n" + "- QA 담당: 민수\n\n" + "## 할 일\n" + "- [개발팀] API 응답 문제 원인 분석 — 11/1\n" + "-------------------------------\n\n" + ) + + # 추가 지침: 출력은 반드시 위 섹션 구조를 따르고(필요시 일부 섹션은 생략 가능), 끝에 JSON 메타데이터 블록을 추가하라. + # 이 블록은 분석용이며, ```json로 fenced 되어야 한다. + # MUST (강제) 요건: + # - 출력은 절대 H1('# ')로 시작하지 말고, 반드시 '## TL;DR'로 시작하라. + # - 포함 필수 항목: Setting(배경), Inciting Incident(발단), Protagonist(주인공), Goal(목적), + # Stakes(위험/중요성), Key Events(핵심 사건), Next Steps/Actions(권장 조치). + # - 사실 기반(Fact preservation): 숫자, 날짜, 고유명사는 원문 그대로 보존. 입력에 없는 정보를 생성하지 마라(허위 생성 금지). + # - 길이 보장: 요청된 length가 'long'일 경우, 충분한 상세(증거·인용·핵심 문장 포함)를 제공하라. + # + # CHECKLIST: The following checklist MUST be present and satisfied in the summary (model must ensure each item is covered): + # 1) Setting/Background — where and in what context the document/event occurs + # 2) Inciting Incident — what triggered the situation or main event + # 3) Protagonist/Actors — who are the main people/agents involved + # 4) Goal/Purpose — what is being attempted or investigated + # 5) Stakes/Importance — why this matters, consequences if unresolved + # 6) Key Events/Findings — sequence of core events or main findings (with key numbers/dates) + # 7) Next Steps/Actions — recommended actions, owners and deadlines if present + # + # END-OF-SUMMARY REQUIREMENT: At the end of the Markdown output, include a fenced JSON block with keys: + # {"tl_dr":"...","tags":[...],"actions":[...],"language":"ko|en","missing":[...]}. + # The "missing" array must list any checklist items that could not be filled from the input (use empty array [] when all present). + # 메타데이터 키: tl_dr (string), tags (array of strings), actions (array of { assignee?, task, due? }), language (ko/en) + meta_hint = ( + "\n\n출력 후 반드시 JSON 메타데이터 블록을 추가하라. 
" + "형식: ```json\n{ \"tl_dr\": \"...\", \"tags\": [\"t1\",\"t2\"], \"actions\": [{\"assignee\": \"name\", \"task\": \"...\", \"due\": \"YYYY-MM-DD\"}], \"language\": \"ko\" }\n```\n" + ) - return f"{base_rules}\n포함 우선: {include_hint}\n{format_rule}\n{length_rule}\n{md_spacing_rule}\n{web_rule}\n{scope}" + return f"{base_rules}\n포함 우선: {include_hint}\n{format_rule}\n{length_rule}\n{md_spacing_rule}\n{web_rule}\n{scope}{example}{meta_hint}" # =============== HF backend (Transformers) =============== @@ -282,10 +368,16 @@ def load_model(with_token: bool): model, tok = try_load(primary) _HF_MODEL, _HF_TOKENIZER, _HF_NAME = model, tok, primary return _HF_MODEL, _HF_TOKENIZER - except Exception: - model, tok = try_load(fallback) - _HF_MODEL, _HF_TOKENIZER, _HF_NAME = model, tok, fallback - return _HF_MODEL, _HF_TOKENIZER + except Exception as e: + # 디스크 부족/네트워크 이슈 등으로 대형 모델 로딩 실패 시, 환경변수로 폴백 비활성화 가능 + if os.getenv("HF_DISABLE_FALLBACK", "1").lower() in ("1", "true", "yes"): + raise RuntimeError("HF_DISABLED") from e + try: + model, tok = try_load(fallback) + _HF_MODEL, _HF_TOKENIZER, _HF_NAME = model, tok, fallback + return _HF_MODEL, _HF_TOKENIZER + except Exception as e2: + raise RuntimeError("HF_DISABLED") from e2 def _build_prompt(tokenizer, system_text: str, user_text: str) -> str: @@ -305,8 +397,29 @@ def _build_prompt(tokenizer, system_text: str, user_text: str) -> str: ) +def _simple_fallback_summary(text: str, output_format: str = "md") -> list[str]: + """모델 로딩 실패 시 사용할 초경량 요약: 앞부분 일부와 불릿을 구성.""" + s = (text or "").strip() + if not s: + return ["요약할 내용이 없습니다."] + # 문장 단위로 잘라 앞부분 3~6문장을 사용 + parts = re.split(r"(?<=[.!?。])\s+|\n+", s) + parts = [p.strip() for p in parts if p.strip()] + head = parts[:6] + bullets = [f"- {p[:120]}{'…' if len(p) > 120 else ''}" for p in head[1:6]] + if output_format.lower() == "md": + out = [f"# 요약", head[0][:160] + ("…" if len(head[0]) > 160 else ""), "\n" ] + bullets + else: + out = [f"
<h1>요약</h1>", f"<p>{head[0][:160]}{'…' if len(head[0])>160 else ''}</p>"] + [f"<li>• {b[2:]}</li>" for b in bullets] return out  async def _stream_with_hf(text: str, system_text: str | None = None, output_format: str = "md"): - model, tokenizer = _load_hf_model() + try: + model, tokenizer = _load_hf_model() + except Exception: + for line in _simple_fallback_summary(text, output_format=output_format): + yield f"data: {line}\n\n" + return sys_msg = system_text or ( # [CHANGED] default system prompt: Markdown sections + same language @@ -326,9 +439,9 @@ pass gen_kwargs = dict( - max_new_tokens=int(os.getenv("HF_MAX_NEW_TOKENS", "320")), # [CHANGED] raised to match the added sections + max_new_tokens=int(os.getenv("HF_MAX_NEW_TOKENS", "32000")), # very generous default to allow extremely long summaries do_sample=False, - repetition_penalty=1.05, + repetition_penalty=float(os.getenv("HF_REPETITION_PENALTY", "1.02")), eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id or tokenizer.pad_token_id, streamer=streamer, @@ -448,12 +561,13 @@ def make_llm(model_name: str) -> ChatOllama: kwargs["num_thread"] = int(num_thread) except ValueError: pass + temp = float(os.getenv("OLLAMA_TEMPERATURE", "0.2")) return ChatOllama( base_url=base_url, model=model_name, streaming=True, callbacks=[cb], - temperature=0.4, + temperature=temp, **kwargs, ) @@ -531,21 +645,24 @@ # =============== helper utilities =============== # [CHANGED] combined payload passing the source text plus optional extra context to the model -def _compose_user_payload(main_text: str, extra_context: str, output_format: str) -> str: +def _compose_user_payload(main_text: str, extra_context: str, output_format: str, length: str = "medium", tone: str = "neutral") -> str: fmt = output_format.lower() + # Include user preferences (length, tone) to guide the summarizer + pref = f"요약 길이: {length}. 톤: {tone}." if fmt == "md": if extra_context: return ( "## 원문\n" f"{main_text}\n\n" "## 추가자료(요약)\n" - f"{extra_context}\n" + f"{extra_context}\n\n" + f"\n" ) - return main_text + return f"{main_text}\n\n" else: if extra_context: - return f"
<h2>원문</h2>\n{main_text}\n\n<h2>추가자료(요약)</h2>\n{extra_context}\n" - return main_text + return f"<h2>원문</h2>\n{main_text}\n\n<h2>추가자료(요약)</h2>\n{extra_context}\n\n" + return f"{main_text}\n" def _is_augmentation_allowed() -> bool: """Toggle web augmentation on/off via environment variable. Default: False.""" @@ -602,3 +719,16 @@ except Exception: continue return "\n".join(out) + + +def _strip_top_level_h1_outside_code(s: str) -> str: + """Remove top-level H1 lines (lines starting with '# ') outside of code fences. + Preserves content inside ```code fences```. + """ + if not s: + return s + parts = re.split(r'(```[\s\S]*?```)', s) + for i in range(0, len(parts), 2): + # only operate on non-code parts (even indices) + parts[i] = re.sub(r'(?m)^[ \t]*#\s+.*\n?', '', parts[i]) + return ''.join(parts)