diff --git a/db.py b/db.py index 9738a46..8a23062 100644 --- a/db.py +++ b/db.py @@ -1,14 +1,19 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker +from models import Base # ensure models are imported so metadata knows all tables +import os -DATABASE_URL = "mysql+mysqlconnector://noteflow:NoteFlow123!@localhost/noteflow" -engine = create_engine(DATABASE_URL) +DATABASE_URL = os.getenv("DATABASE_URL", "mysql+mysqlconnector://noteflow:NoteFlow123!@localhost/noteflow") + +engine = create_engine(DATABASE_URL, pool_pre_ping=True) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -# 데이터베이스 의존성 +def init_db(): + Base.metadata.create_all(bind=engine) + def get_db(): db = SessionLocal() try: yield db finally: - db.close() \ No newline at end of file + db.close() diff --git a/main.py b/main.py index a9765fd..01d5288 100644 --- a/main.py +++ b/main.py @@ -1,56 +1,48 @@ -# src/main.py +# Backend/main.py import os from dotenv import load_dotenv from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + +from db import init_db from routers.auth import router as auth_router from routers.note import router as note_router -from routers.folder import router as folder_router -from fastapi.staticfiles import StaticFiles -from routers.file import router as file_router -import logging -import uvicorn +from routers.folder import router as folder_router +from routers.checklist import router as checklist_router +from routers.file import router as file_router -# 1) 환경변수 로드 -load_dotenv() +import uvicorn -# 2) 로깅 설정 -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +load_dotenv() -# 3) FastAPI 앱 생성 app = FastAPI() -# 4) CORS 설정 -origins = [ - "http://localhost:5174", -] - app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=["*"], # 개발 중 전체 허용 allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) -# 5) 라우터 등록 +# 정적 파일(업로드) 서빙 +os.makedirs(os.path.join(os.path.dirname(__file__), "uploads"), exist_ok=True) +app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "uploads")), name="static") + +# 라우터 등록 app.include_router(auth_router) app.include_router(note_router) -app.include_router(folder_router) +app.include_router(folder_router) app.include_router(file_router) +app.include_router(checklist_router) -# 6) 루트 엔드포인트 @app.get("/") -def read_root(): +def root(): return {"message": "mini"} -# 7) 실행 설정 +# 앱 시작 시(uvicorn main:app) 한 번만 테이블 생성 (개발용) +init_db() + if __name__ == "__main__": - uvicorn.run( - "main:app", - host="0.0.0.0", - port=8080, - reload=True, - env_file=".env" - ) + uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=True) diff --git a/models/__init__.py b/models/__init__.py index b7bb9be..114f8c4 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1 +1,8 @@ -from .user import User \ No newline at end of file +from .base import Base +from .user import User +from .folder import Folder +from .note import Note +from .file import File +from .checklist import Checklist + +__all__ = ["Base", "User", "Folder", "Note", "File", "Checklist"] diff --git a/models/base.py b/models/base.py index c1da040..59be703 100644 --- a/models/base.py +++ b/models/base.py @@ -1,2 +1,3 @@ -from sqlalchemy.ext.declarative import declarative_base -Base = declarative_base() \ No newline at end of file +from sqlalchemy.orm import declarative_base + +Base = declarative_base() diff --git a/models/checklist.py b/models/checklist.py index e7e2399..4872506 100644 --- a/models/checklist.py +++ b/models/checklist.py @@ -1,26 +1,17 @@ -from sqlalchemy import Column, Integer, String, Enum, TIMESTAMP, text +from sqlalchemy import Column, Integer, String, Boolean, TIMESTAMP, ForeignKey, text from sqlalchemy.orm import relationship +from .base import Base -class User(Base): - __tablename__ = "user" +class Checklist(Base): + __tablename__ = "checklist" - u_id = Column(Integer, primary_key=True, autoincrement=True) # PK - id = Column(String(50), unique=True, nullable=False) # 로그인 ID - email = Column(String(150), unique=True, nullable=False) - password = Column(String(255), nullable=False) - provider = Column( - Enum("local", "google", "kakao", "naver", name="provider_enum"), - nullable=False, - server_default="local", - ) - created_at = Column( - TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP") - ) - updated_at = Column( - TIMESTAMP, - nullable=False, - server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), - ) + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) + title = Column(String(255), nullable=False) + is_clear = Column(Boolean, nullable=False, server_default=text("0")) + created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) + updated_at = Column(TIMESTAMP, nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + onupdate=text("CURRENT_TIMESTAMP")) - # 역참조: checklist 목록 - checklists = relationship("Checklist", back_populates="user", cascade="all, delete-orphan") \ No newline at end of file + user = relationship("User", back_populates="checklists") diff --git a/models/file.py b/models/file.py index 573971b..bef1d99 100644 --- a/models/file.py +++ b/models/file.py @@ -1,14 +1,19 @@ from sqlalchemy import Column, Integer, String, ForeignKey, TIMESTAMP, text +from sqlalchemy.orm import relationship from .base import Base class File(Base): - __tablename__ = 'file' + __tablename__ = "file" id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey('user.u_id', ondelete='CASCADE'), nullable=False) - folder_id = Column(Integer, ForeignKey('folder.id', ondelete='SET NULL'), nullable=True) - note_id = Column(Integer, ForeignKey('note.id', ondelete='SET NULL'), nullable=True) # ✅ 첨부된 노트 ID - original_name = Column(String(255), nullable=False) # 유저가 업로드한 원본 파일 이름 - saved_path = Column(String(512), nullable=False) # 서버에 저장된(실제) 경로 - content_type = Column(String(100), nullable=False) # MIME 타입 - created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) + user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) + folder_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) + note_id = Column(Integer, ForeignKey("note.id", ondelete="SET NULL"), nullable=True) + original_name = Column(String(255), nullable=False) + saved_path = Column(String(512), nullable=False) + content_type = Column(String(100), nullable=False) + created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) + + # relations + user = relationship("User", back_populates="files") + note = relationship("Note", back_populates="files") diff --git a/models/folder.py b/models/folder.py index 1d6f2eb..3105616 100644 --- a/models/folder.py +++ b/models/folder.py @@ -1,14 +1,20 @@ from sqlalchemy import Column, Integer, String, ForeignKey, TIMESTAMP, text +from sqlalchemy.orm import relationship from .base import Base class Folder(Base): - __tablename__ = 'folder' + __tablename__ = "folder" id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey('user.u_id', ondelete='CASCADE'), nullable=False) + user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) name = Column(String(100), nullable=False) - parent_id = Column(Integer, ForeignKey('folder.id', ondelete='SET NULL'), nullable=True) - created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) + parent_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) + created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) updated_at = Column(TIMESTAMP, nullable=False, - server_default=text('CURRENT_TIMESTAMP'), - onupdate=text('CURRENT_TIMESTAMP')) + server_default=text("CURRENT_TIMESTAMP"), + onupdate=text("CURRENT_TIMESTAMP")) + + # relations + user = relationship("User") + parent = relationship("Folder", remote_side=[id], backref="children") + notes = relationship("Note", back_populates="folder", cascade="all, delete") diff --git a/models/note.py b/models/note.py index 8ff2489..f52ecca 100644 --- a/models/note.py +++ b/models/note.py @@ -1,17 +1,23 @@ from sqlalchemy import Column, Integer, String, Text, Boolean, ForeignKey, TIMESTAMP, text +from sqlalchemy.orm import relationship from .base import Base class Note(Base): - __tablename__ = 'note' + __tablename__ = "note" id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey('user.u_id', ondelete='CASCADE'), nullable=False) - folder_id = Column(Integer, ForeignKey('folder.id', ondelete='SET NULL'), nullable=True) + user_id = Column(Integer, ForeignKey("user.u_id", ondelete="CASCADE"), nullable=False) + folder_id = Column(Integer, ForeignKey("folder.id", ondelete="SET NULL"), nullable=True) title = Column(String(255), nullable=False) content = Column(Text) - is_favorite = Column(Boolean, nullable=False, server_default=text('FALSE')) + is_favorite = Column(Boolean, nullable=False, server_default=text("FALSE")) last_accessed = Column(TIMESTAMP, nullable=True) - created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) + created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) updated_at = Column(TIMESTAMP, nullable=False, - server_default=text('CURRENT_TIMESTAMP'), - onupdate=text('CURRENT_TIMESTAMP')) + server_default=text("CURRENT_TIMESTAMP"), + onupdate=text("CURRENT_TIMESTAMP")) + + # relations + user = relationship("User", back_populates="notes") + folder = relationship("Folder", back_populates="notes") + files = relationship("File", back_populates="note", cascade="all, delete") diff --git a/models/user.py b/models/user.py index 3811b46..96f341f 100644 --- a/models/user.py +++ b/models/user.py @@ -1,19 +1,22 @@ from sqlalchemy import Column, Integer, String, Enum, TIMESTAMP, text +from sqlalchemy.orm import relationship from .base import Base class User(Base): - __tablename__ = 'user' + __tablename__ = "user" u_id = Column(Integer, primary_key=True, autoincrement=True) - id = Column(String(50), nullable=False, unique=True) + id = Column(String(50), nullable=False, unique=True) # 로그인 ID 또는 소셜 ID email = Column(String(150), nullable=False, unique=True) password = Column(String(255), nullable=False) - provider = Column( - Enum('local','google','kakao','naver', name='provider_enum'), - nullable=False, - server_default=text("'local'") - ) - created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) + provider = Column(Enum("local", "google", "kakao", "naver", name="provider_enum"), + nullable=False, server_default=text("'local'")) + created_at = Column(TIMESTAMP, nullable=False, server_default=text("CURRENT_TIMESTAMP")) updated_at = Column(TIMESTAMP, nullable=False, - server_default=text('CURRENT_TIMESTAMP'), - onupdate=text('CURRENT_TIMESTAMP')) + server_default=text("CURRENT_TIMESTAMP"), + onupdate=text("CURRENT_TIMESTAMP")) + + # relations + notes = relationship("Note", back_populates="user", cascade="all, delete-orphan") + files = relationship("File", back_populates="user", cascade="all, delete-orphan") + checklists = relationship("Checklist", back_populates="user", cascade="all, delete-orphan") diff --git a/routers/__init__.py b/routers/__init__.py index 0e2ebc8..1cdd000 100644 --- a/routers/__init__.py +++ b/routers/__init__.py @@ -1,5 +1,7 @@ from .auth import router as auth_router +from .checklist import router as checklist_router routers = [ auth_router, + checklist_router ] \ No newline at end of file diff --git a/routers/checklist.py b/routers/checklist.py index 26a41b1..376d0a4 100644 --- a/routers/checklist.py +++ b/routers/checklist.py @@ -1,47 +1,26 @@ -# 변경/설명: -# - POST /checklists : 생성 전용 -# - PATCH /checklists/{id}/clear : is_clear 0/1 설정 -# - get_current_user는 user.u_id 제공 가정 +# Backend/routers/checklist.py from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session +from db import get_db +from models.checklist import Checklist +from utils.jwt_utils import get_current_user -from app.schemas.checklist import ChecklistCreate, ChecklistClearUpdate -from app.models import Checklist # Checklist(u_id, checklist_title, is_clear) -from app.dependencies import get_db, get_current_user - -router = APIRouter(prefix="/checklists", tags=["checklists"]) +router = APIRouter(prefix="/api/v1/checklists", tags=["Checklists"]) @router.post("", status_code=status.HTTP_201_CREATED) -def create_checklist( - req: ChecklistCreate, - db: Session = Depends(get_db), - user=Depends(get_current_user), -): - obj = Checklist( - u_id=user.u_id, # ← 프로젝트의 사용자 키에 맞게 - checklist_title=req.checklist_title, - is_clear=0 # 기본 0(미완) - ) +def create_checklist(title: str, db: Session = Depends(get_db), user=Depends(get_current_user)): + obj = Checklist(user_id=user.u_id, title=title, is_clear=False) db.add(obj) db.commit() db.refresh(obj) - return {"id": obj.id, "checklist_title": obj.checklist_title, "is_clear": obj.is_clear} + return {"id": obj.id, "title": obj.title, "is_clear": obj.is_clear} @router.patch("/{checklist_id}/clear") -def set_clear_state( - checklist_id: int, - req: ChecklistClearUpdate, # {"is_clear": 0 | 1} - db: Session = Depends(get_db), - user=Depends(get_current_user), -): - obj = ( - db.query(Checklist) - .filter(Checklist.id == checklist_id, Checklist.u_id == user.u_id) - .first() - ) +def set_clear_state(checklist_id: int, is_clear: bool, db: Session = Depends(get_db), user=Depends(get_current_user)): + obj = db.query(Checklist).filter(Checklist.id == checklist_id, Checklist.user_id == user.u_id).first() if not obj: raise HTTPException(status_code=404, detail="Checklist not found") - obj.is_clear = int(req.is_clear) # 0/1 저장 + obj.is_clear = bool(is_clear) db.commit() db.refresh(obj) - return {"id": obj.id, "is_clear": obj.is_clear} \ No newline at end of file + return {"id": obj.id, "is_clear": obj.is_clear} diff --git a/routers/file.py b/routers/file.py index 602d489..6f37d6c 100644 --- a/routers/file.py +++ b/routers/file.py @@ -1,4 +1,5 @@ import os +import re from datetime import datetime from typing import Optional, List @@ -11,29 +12,24 @@ from models.note import Note as NoteModel from utils.jwt_utils import get_current_user -# 추가/변경: 공통 OCR 파이프라인(thin wrapper) +# 공통 OCR 파이프라인 from utils.ocr import run_pipeline, detect_type from schemas.file import OCRResponse -# 추가: 허용 확장자 상수 (불일치 시 200 + warnings 응답) +# 허용 확장자 ALLOWED_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"} ALLOWED_PDF_EXTS = {".pdf"} ALLOWED_DOC_EXTS = {".doc", ".docx"} ALLOWED_HWP_EXTS = {".hwp"} -ALLOWED_ALL_EXTS = ( - ALLOWED_IMAGE_EXTS | ALLOWED_PDF_EXTS | ALLOWED_DOC_EXTS | ALLOWED_HWP_EXTS -) +ALLOWED_ALL_EXTS = (ALLOWED_IMAGE_EXTS | ALLOWED_PDF_EXTS | ALLOWED_DOC_EXTS | ALLOWED_HWP_EXTS) -# 업로드 디렉토리 설정 -BASE_UPLOAD_DIR = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "uploads" -) +# 업로드 디렉토리 +BASE_UPLOAD_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "uploads") os.makedirs(BASE_UPLOAD_DIR, exist_ok=True) router = APIRouter(prefix="/api/v1/files", tags=["Files"]) + @router.get("/ocr/diag", summary="OCR 런타임 의존성 진단") def ocr_dependency_diag(): import shutil, subprocess @@ -67,6 +63,7 @@ def run(cmd: list[str]): "hwp5txt": hwp5txt_ok, } + @router.post( "/upload", summary="폴더/노트에 파일 업로드 (note_id 있으면 노트 본문에도 삽입)", @@ -86,7 +83,7 @@ async def upload_file( user_dir = os.path.join(BASE_UPLOAD_DIR, str(current_user.u_id)) os.makedirs(user_dir, exist_ok=True) - # 원본 파일명 그대로 저장 (중복 시 _1, _2 붙임) + # 원본 파일명 유지 (중복 방지) saved_filename = orig_filename saved_path = os.path.join(user_dir, saved_filename) if os.path.exists(saved_path): @@ -101,7 +98,7 @@ async def upload_file( break counter += 1 - # 파일 저장 + # 저장 try: with open(saved_path, "wb") as buffer: content = await upload_file.read() @@ -120,7 +117,7 @@ async def upload_file( if not note_obj: raise HTTPException(status_code=404, detail="해당 노트를 찾을 수 없습니다.") - # DB에 메타데이터 기록 + # DB 메타 기록 new_file = FileModel( user_id=current_user.u_id, folder_id=None if note_id else folder_id, @@ -136,7 +133,7 @@ async def upload_file( base_url = os.getenv("BASE_API_URL", "http://localhost:8000") download_url = f"{base_url}/api/v1/files/download/{new_file.id}" - # note_id가 있으면 content에도 삽입 + # note_id가 있으면 노트 본문에 첨부 링크 삽입 if note_obj: if content_type.startswith("image/"): embed = f"\n\n![{new_file.original_name}]({download_url})\n\n" @@ -213,6 +210,36 @@ def download_file( ) +# --------------------------- +# 언어코드 유연 처리 유틸 +# --------------------------- +LANG_ALIAS = { + "ko": "kor", "kr": "kor", "korean": "kor", + "en": "eng", "english": "eng" +} +def normalize_langs(raw: str) -> str: + """ + 입력 예: 'koreng', 'kor,eng', 'ko+en', 'ko,en', 'korean+english' + 출력 예: 'kor+eng' + """ + if not raw: + return "kor+eng" + s = raw.strip().lower().replace(" ", "") + s = re.sub(r"[,_;]+", "+", s) + if "+" not in s: + s = s.replace("koreng", "kor+eng").replace("koen", "ko+en") + parts = [p for p in s.split("+") if p] + norm: list[str] = [] + for p in parts: + norm.append(LANG_ALIAS.get(p, p)) + # 중복 제거(순서 보존) + seen, out = set(), [] + for p in norm: + if p not in seen: + seen.add(p); out.append(p) + return "+".join(out) if out else "kor+eng" + + @router.post( "/ocr", summary="이미지/PDF/DOC/DOCX/HWP OCR → 텍스트 변환 후 노트 생성", @@ -222,18 +249,23 @@ async def ocr_and_create_note( file: Optional[UploadFile] = File(None, description="기본 업로드 필드명"), ocr_file: Optional[UploadFile] = File(None, description="과거 호환 업로드 필드명"), folder_id: Optional[int] = Form(None), - langs: str = Query("kor+eng", description="Tesseract 언어코드(예: kor+eng)"), + langs: str = Query("kor+eng", description="Tesseract 언어코드(유연 입력 허용: koreng, ko+en 등)"), max_pages: int = Query(50, ge=1, le=500, description="최대 처리 페이지 수(기본 50)"), db: Session = Depends(get_db), current_user = Depends(get_current_user) ): + # 422 방지: 파일 필드명 유연 처리 upload = file or ocr_file if upload is None: - raise HTTPException(status_code=400, detail="업로드 파일이 필요합니다. 필드명은 'file' 또는 'ocr_file'을 사용하세요.") + raise HTTPException( + status_code=400, + detail="업로드 파일이 필요합니다. 필드명은 'file' 또는 'ocr_file'을 사용하세요." + ) filename = upload.filename or "uploaded" mime = upload.content_type + # 확장자 검사 _, ext = os.path.splitext(filename) ext = ext.lower() if ext and ext not in ALLOWED_ALL_EXTS: @@ -247,6 +279,7 @@ async def ocr_and_create_note( text=None, ) + # 타입 판별 ftype = detect_type(filename, mime) if ftype == "unknown": return OCRResponse( @@ -261,6 +294,10 @@ async def ocr_and_create_note( data = await upload.read() + # 언어코드 정규화 + langs = normalize_langs(langs) + + # 멀티엔진 앙상블 OCR 파이프라인 실행 pipe = run_pipeline( filename=filename, mime=mime, @@ -269,6 +306,7 @@ async def ocr_and_create_note( max_pages=max_pages, ) + # 페이지 텍스트 합치기 merged_text = "\n\n".join([ item.get("text", "") for item in (pipe.get("results") or []) if item.get("text") ]).strip() @@ -365,6 +403,7 @@ async def upload_audio_and_transcribe( "transcript": transcript } + @router.options("/ocr") def ocr_cors_preflight() -> Response: return Response(status_code=200) diff --git a/utils/jwt_utils.py b/utils/jwt_utils.py index 87e16c2..19fd64c 100644 --- a/utils/jwt_utils.py +++ b/utils/jwt_utils.py @@ -1,5 +1,3 @@ -# utils/jwt_utils.py - import os from datetime import datetime, timedelta, timezone from jose import JWTError, jwt @@ -10,11 +8,12 @@ from models.user import User from db import get_db -# 환경변수에서 불러오기 (없으면 기본 60분) -SECRET_KEY = os.getenv("SECRET_KEY", "change_this_in_production") -ALGORITHM = "HS256" +# 환경 변수 기반 설정 +SECRET_KEY = os.getenv("SECRET_KEY", "change_this_in_production") +ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "60")) +# OAuth2 스킴 (Bearer 토큰) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/login") @@ -26,23 +25,15 @@ def hash_password(password: str) -> str: return bcrypt.hashpw(password.encode(), bcrypt.gensalt(rounds=12)).decode() -def create_access_token( - user_id: int, - expires_delta: timedelta | None = None -) -> str: +def create_access_token(user_id: int, expires_delta: timedelta | None = None) -> str: """ - 유저 ID를 sub 페이로드로 담아 JWT 생성. - expires_delta 없으면 ACCESS_TOKEN_EXPIRE_MINUTES 환경변수 사용. + 유저 ID를 sub 페이로드로 담아 JWT 생성 """ if expires_delta is None: expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) expire = datetime.now(timezone.utc) + expires_delta - payload = { - "sub": str(user_id), - "exp": expire - } - token = jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) - return token + payload = {"sub": str(user_id), "exp": expire} + return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM) async def get_current_user( @@ -50,7 +41,7 @@ async def get_current_user( db: Session = Depends(get_db) ) -> User: """ - Authorization: Bearer + Authorization: Bearer → User 반환 """ credentials_exception = HTTPException( status_code=401, diff --git a/utils/ocr/__init__.py b/utils/ocr/__init__.py index 5db369a..fc8447c 100644 --- a/utils/ocr/__init__.py +++ b/utils/ocr/__init__.py @@ -1,15 +1,6 @@ -""" -utils.ocr 패키지 - -추가/변경 요약 -- 공통 OCR 파이프라인 진입점(run_pipeline)을 외부에 노출 -- 이미지/PDF/DOC/DOCX/HWP를 단일 인터페이스로 처리 -""" - from .ocr_core import run_pipeline, detect_type __all__ = [ "run_pipeline", "detect_type", ] - diff --git a/utils/ocr/converters.py b/utils/ocr/converters.py index 7455d16..7b3f0ff 100644 --- a/utils/ocr/converters.py +++ b/utils/ocr/converters.py @@ -1,11 +1,3 @@ -""" -utils/ocr/converters.py - -추가/변경 -- PDF/DOC/DOCX/HWP를 파이프라인에서 재사용할 수 있도록 변환/추출 유틸 제공 -- 외부 의존(soffice, hwp5txt)이 없을 수 있으므로 항상 예외를 던지지 말고 상위에서 warnings에 기록 -""" - from __future__ import annotations import os @@ -18,9 +10,7 @@ def save_bytes_to_temp(data: bytes, suffix: str = "") -> str: - """바이트를 임시 파일로 저장하고 경로를 반환. - 호출자가 삭제를 책임짐. - """ + """바이트를 임시 파일로 저장하고 경로를 반환 (호출자가 삭제 책임).""" fd, path = tempfile.mkstemp(suffix=suffix) with os.fdopen(fd, "wb") as f: f.write(data) @@ -28,8 +18,9 @@ def save_bytes_to_temp(data: bytes, suffix: str = "") -> str: def pdf_to_images(pdf_path: str, dpi: int = 200) -> List[Image.Image]: - """pdf2image.convert_from_path로 PDF를 PIL 이미지 리스트로 변환. - 주: 시스템에 poppler가 필요할 수 있음. + """ + pdf2image.convert_from_path로 PDF를 PIL 이미지 리스트로 변환. + (poppler 설치 필요) """ from pdf2image import convert_from_path # 지연 임포트 images = convert_from_path(pdf_path, dpi=dpi) @@ -37,7 +28,8 @@ def pdf_to_images(pdf_path: str, dpi: int = 200) -> List[Image.Image]: def office_to_pdf(input_path: str, outdir: str) -> str: - """LibreOffice(soffice)를 사용하여 DOC/DOCX를 PDF로 변환. + """ + LibreOffice(soffice)를 사용하여 DOC/DOCX를 PDF로 변환. 반환: 변환된 PDF 경로 실패 시 예외 발생(상위에서 warnings 처리) """ @@ -45,20 +37,10 @@ def office_to_pdf(input_path: str, outdir: str) -> str: if not soffice: raise RuntimeError("LibreOffice(soffice) 실행 파일을 찾을 수 없습니다.") - cmd = [ - soffice, - "--headless", - "--convert-to", - "pdf", - "--outdir", - outdir, - input_path, - ] + cmd = [soffice, "--headless", "--convert-to", "pdf", "--outdir", outdir, input_path] proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) if proc.returncode != 0: - raise RuntimeError( - f"LibreOffice 변환 실패: {proc.stderr.decode(errors='ignore')[:300]}" - ) + raise RuntimeError(f"LibreOffice 변환 실패: {proc.stderr.decode(errors='ignore')[:300]}") base = os.path.splitext(os.path.basename(input_path))[0] pdf_path = os.path.join(outdir, f"{base}.pdf") @@ -73,16 +55,14 @@ def office_to_pdf(input_path: str, outdir: str) -> str: def hwp_to_text(input_path: str) -> str: - """hwp5txt(또는 pyhwp)로 HWP 텍스트를 추출. - 주: hwp5txt CLI가 설치되어 있어야 함. 없으면 예외. + """ + hwp5txt(또는 pyhwp)로 HWP 텍스트 추출. + hwp5txt CLI가 없으면 예외. """ hwp5txt = shutil.which("hwp5txt") if not hwp5txt: raise RuntimeError("hwp5txt 실행 파일을 찾을 수 없습니다.") proc = subprocess.run([hwp5txt, input_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE) if proc.returncode != 0: - raise RuntimeError( - f"hwp5txt 추출 실패: {proc.stderr.decode(errors='ignore')[:300]}" - ) + raise RuntimeError(f"hwp5txt 추출 실패: {proc.stderr.decode(errors='ignore')[:300]}") return proc.stdout.decode(errors="ignore") - diff --git a/utils/ocr/ocr_core.py b/utils/ocr/ocr_core.py index c15a315..14fba00 100644 --- a/utils/ocr/ocr_core.py +++ b/utils/ocr/ocr_core.py @@ -1,22 +1,19 @@ """ utils/ocr/ocr_core.py -추가/변경 -- 파일 타입 판별(확장자 우선, MIME 보조) 및 통합 OCR 파이프라인(run_pipeline) 구현 -- 이미지: pytesseract 기본 OCR, (기존) EasyOCR/TrOCR는 가능 시 보조로 시도하여 최적 텍스트 선택 -- PDF: pdf2image(convert_from_path, dpi=200)로 페이지 이미지를 생성하여 페이지별 OCR -- DOC/DOCX: LibreOffice(soffice --headless)로 PDF로 변환 후 PDF 파이프라인 재사용 -- HWP: hwp5txt로 텍스트 추출(성공 시 page=1로 results에 추가), 실패 시 warnings 기록 -- 대용량 제어: MAX_PAGES(기본 50)까지 처리하고 잘린 경우 warnings 기록 -- 예외는 raise하지 않고 results=[], warnings로 사유를 담아 상위가 200으로 응답할 수 있게 함 +- 파일 타입 판별(detect_type) +- 통합 OCR 파이프라인(run_pipeline) +- 이미지: Tesseract + EasyOCR + HuggingFace TrOCR 조합 (긴 텍스트 선택) +- PDF: PyMuPDF → 네이티브 텍스트 추출, 없으면 이미지 렌더링 후 OCR +- DOCX: python-docx, DOC: LibreOffice 변환 후 OCR +- HWP: hwp5txt 사용 """ from __future__ import annotations import io import os -from typing import Dict, List, Optional, Tuple - +from typing import Dict, List, Optional from PIL import Image from .converters import ( @@ -26,70 +23,69 @@ hwp_to_text, ) - -# 지원 확장자 세트 +# 지원 확장자 IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"} PDF_EXTS = {".pdf"} DOC_EXTS = {".doc", ".docx"} HWP_EXTS = {".hwp"} +# ───────────────────────────────────────────── +# 타입 판별 +# ───────────────────────────────────────────── def detect_type(filename: str, content_type: Optional[str]) -> str: - """확장자 기반 타입 판별, MIME은 보조. - 반환: "image" | "pdf" | "docx" | "hwp" | "unknown" - """ + """확장자 기반 타입 판별, MIME은 보조.""" ext = os.path.splitext(filename or "")[1].lower() if ext in IMAGE_EXTS: return "image" if ext in PDF_EXTS: return "pdf" if ext in DOC_EXTS: - return "docx" # 내부적으로 DOC/DOCX를 동일 경로로 처리 + return "docx" if ext in HWP_EXTS: return "hwp" - # MIME 보조 판단(간단히) if content_type: ct = content_type.lower() if ct.startswith("image/"): return "image" if ct == "application/pdf": return "pdf" - if ct in ("application/msword", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"): + if ct in ( + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ): return "docx" if "hwp" in ct: return "hwp" return "unknown" +# ───────────────────────────────────────────── +# OCR Backends +# ───────────────────────────────────────────── def _ocr_image_pytesseract(img: Image.Image, langs: str, warnings: List[str]) -> str: - """pytesseract를 사용하여 이미지에서 텍스트 추출. - 주: 시스템에 tesseract OCR 엔진 및 언어 데이터가 설치되어 있어야 함. - """ + """pytesseract OCR""" try: import pytesseract - text = pytesseract.image_to_string(img, lang=langs) - return text.strip() + return pytesseract.image_to_string(img, lang=langs).strip() except Exception as e: warnings.append(f"pytesseract OCR 실패: {e}") return "" def _ocr_image_legacy(img: Image.Image, warnings: List[str]) -> str: - """기존 이미지 OCR(EasyOCR + TrOCR) 로직 재사용. - - 환경/의존성에 따라 실패할 수 있으므로 예외는 warnings에만 기록. - - 기존 구현과 동일하게 가장 긴 텍스트를 선택. - """ + """EasyOCR + HuggingFace TrOCR""" try: import numpy as np import easyocr from transformers import pipeline except Exception as e: - warnings.append(f"기존 OCR 모듈(EasyOCR/TrOCR) 사용 불가: {e}") + warnings.append(f"EasyOCR/TrOCR import 실패: {e}") return "" + # EasyOCR try: - # EasyOCR reader = easyocr.Reader(["ko", "en"], gpu=False) image_np = np.array(img.convert("RGB")) easy_results = reader.readtext(image_np) @@ -98,44 +94,35 @@ def _ocr_image_legacy(img: Image.Image, warnings: List[str]) -> str: warnings.append(f"EasyOCR 실패: {e}") easy_text = "" + # HuggingFace TrOCR hf_texts: List[str] = [] - try: - for model_name in ( - "microsoft/trocr-base-printed", - "microsoft/trocr-base-handwritten", - "microsoft/trocr-small-printed", - "microsoft/trocr-large-printed", - ): - try: - pipe = pipeline("image-to-text", model=model_name, trust_remote_code=True) - out = pipe(img) - if isinstance(out, list) and out and isinstance(out[0], dict) and "generated_text" in out[0]: - hf_texts.append(out[0]["generated_text"].strip()) - except Exception as e: - warnings.append(f"TrOCR({model_name}) 실패: {e}") - except Exception as e: - warnings.append(f"TrOCR 파이프라인 초기화 실패: {e}") - - candidates = [t for t in [easy_text] + hf_texts if t and t.strip()] - if not candidates: - return "" - return max(candidates, key=len) + for model_name in [ + "microsoft/trocr-base-printed", + "microsoft/trocr-base-handwritten", + ]: + try: + pipe = pipeline("image-to-text", model=model_name, trust_remote_code=True) + out = pipe(img) + if isinstance(out, list) and out and "generated_text" in out[0]: + hf_texts.append(out[0]["generated_text"].strip()) + except Exception as e: + warnings.append(f"TrOCR({model_name}) 실패: {e}") + + candidates = [t for t in [easy_text] + hf_texts if t] + return max(candidates, key=len) if candidates else "" def _ocr_image_best(img: Image.Image, langs: str, warnings: List[str]) -> str: - """ - 변경(모델 우선): 기존(EasyOCR/TrOCR) → pytesseract 순으로 시도하고 더 긴 텍스트 선택. - - 서버에 Tesseract가 없어도 동작하도록 모델 기반 경로를 우선. - """ - legacy_text = _ocr_image_legacy(img, warnings) - tesseract_text = _ocr_image_pytesseract(img, langs, warnings) - - candidates = [t for t in [legacy_text, tesseract_text] if t] - if not candidates: - return "" - return max(candidates, key=len) + """모든 OCR 경로 실행 후 가장 긴 텍스트 선택""" + legacy = _ocr_image_legacy(img, warnings) + tess = _ocr_image_pytesseract(img, langs, warnings) + candidates = [t for t in [legacy, tess] if t] + return max(candidates, key=len) if candidates else "" +# ───────────────────────────────────────────── +# Main pipeline +# ───────────────────────────────────────────── def run_pipeline( filename: str, mime: Optional[str], @@ -143,18 +130,16 @@ def run_pipeline( langs: str = "kor+eng", max_pages: int = 50, ) -> Dict: - """공통 OCR 파이프라인 - - 반환 JSON 스키마: + """ + OCR 실행 파이프라인 + 반환 JSON: { "filename": str, - "mime": str | null, + "mime": str, "page_count": int, "results": [{"page": int, "text": str}], "warnings": [str] } - - 예외는 raise하지 않고 warnings에만 기록 후 results를 비워서 반환. """ warnings: List[str] = [] results: List[Dict] = [] @@ -164,155 +149,102 @@ def run_pipeline( try: if ftype == "image": - # 단일 이미지 → 페이지 1로 간주 try: img = Image.open(io.BytesIO(data)).convert("RGB") except Exception as e: warnings.append(f"이미지 열기 실패: {e}") img = None - - if img is not None: + if img: text = _ocr_image_best(img, langs, warnings) + results.append({"page": 1, "text": text}) page_count = 1 - results.append({"page": 1, "text": text or ""}) elif ftype == "pdf": - # 변경: PyMuPDF(fitz) 우선 사용 → 네이티브 텍스트, 없으면 렌더링 후 모델 OCR - images: List[Image.Image] = [] try: import fitz # PyMuPDF doc = fitz.open(stream=data, filetype="pdf") total = doc.page_count - if total > max_pages: - warnings.append(f"페이지가 {max_pages}장을 초과하여 앞 {max_pages}페이지만 처리합니다.") limit = min(total, max_pages) + if total > max_pages: + warnings.append(f"앞 {max_pages}페이지만 처리") for i in range(limit): page = doc.load_page(i) txt = (page.get_text("text") or "").strip() if txt: results.append({"page": i + 1, "text": txt}) else: - # 이미지 렌더링 후 모델 OCR - try: - mat = fitz.Matrix(2, 2) # ~144 DPI 정도 - pix = page.get_pixmap(matrix=mat) - mode = "RGBA" if pix.alpha else "RGB" - img = Image.frombytes(mode, [pix.width, pix.height], pix.samples) - if mode == "RGBA": - img = img.convert("RGB") - images.append(img) - except Exception as e: - warnings.append(f"PDF 페이지 렌더링 실패(page {i+1}): {e}") + pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) + mode = "RGBA" if pix.alpha else "RGB" + img = Image.frombytes(mode, [pix.width, pix.height], pix.samples) + if mode == "RGBA": + img = img.convert("RGB") + t = _ocr_image_best(img, langs, warnings) + results.append({"page": i + 1, "text": t}) page_count = limit except Exception as e: - warnings.append(f"PyMuPDF 처리 실패: {e}") - # 대체 경로: pdf2image(poppler 필요) + warnings.append(f"PyMuPDF 실패: {e}") + # fallback → pdf2image pdf_path = save_bytes_to_temp(data, suffix=".pdf") try: images = pdf_to_images(pdf_path, dpi=200) - except Exception as ee: - warnings.append(f"PDF를 이미지로 변환 실패: {ee}") - images = [] + if len(images) > max_pages: + warnings.append(f"앞 {max_pages}페이지만 처리") + images = images[:max_pages] + for idx, img in enumerate(images, 1): + t = _ocr_image_best(img, langs, warnings) + results.append({"page": idx, "text": t}) + page_count = len(images) finally: try: os.remove(pdf_path) - except Exception: + except: pass - total = len(images) - if total > max_pages: - warnings.append(f"페이지가 {max_pages}장을 초과하여 앞 {max_pages}페이지만 처리합니다.") - images = images[:max_pages] - page_count = len(images) - - # 이미지에 대해 모델 OCR 수행 (필요한 페이지만) - for idx, img in enumerate(images, start=1): - text = _ocr_image_best(img, langs, warnings) - results.append({"page": idx, "text": text or ""}) elif ftype == "docx": - # 변경: .docx는 python-docx로 네이티브 텍스트 추출 우선, .doc는 LibreOffice 변환 ext = os.path.splitext(filename or "")[1].lower() if ext == ".docx": try: - from docx import Document # python-docx + from docx import Document doc = Document(io.BytesIO(data)) - paras = [] - for p in doc.paragraphs: - if p.text: - paras.append(p.text) + paras = [p.text for p in doc.paragraphs if p.text] text = "\n".join(paras).strip() - if text: - results.append({"page": 1, "text": text}) - page_count = 1 - else: - warnings.append("DOCX에서 추출된 텍스트가 없습니다.") + results.append({"page": 1, "text": text}) + page_count = 1 except Exception as e: - warnings.append(f"python-docx 처리 실패: {e}") + warnings.append(f"python-docx 실패: {e}") else: - # 구형 .doc → LibreOffice로 PDF 변환 후 OCR - in_path = save_bytes_to_temp(data, suffix=ext or ".doc") - outdir = os.path.dirname(in_path) - pdf_path: Optional[str] = None + in_path = save_bytes_to_temp(data, suffix=".doc") try: - pdf_path = office_to_pdf(in_path, outdir) - # PDF 처리 동일 (PyMuPDF 경로 우선) - try: - import fitz - doc = fitz.open(pdf_path) - total = doc.page_count - if total > max_pages: - warnings.append(f"페이지가 {max_pages}장을 초과하여 앞 {max_pages}페이지만 처리합니다.") - limit = min(total, max_pages) - for i in range(limit): - page = doc.load_page(i) - txt = (page.get_text("text") or "").strip() - if txt: - results.append({"page": i + 1, "text": txt}) - else: - mat = fitz.Matrix(2, 2) - pix = page.get_pixmap(matrix=mat) - mode = "RGBA" if pix.alpha else "RGB" - img = Image.frombytes(mode, [pix.width, pix.height], pix.samples) - if mode == "RGBA": - img = img.convert("RGB") - t = _ocr_image_best(img, langs, warnings) - results.append({"page": i + 1, "text": t or ""}) - page_count = limit - except Exception as e: - warnings.append(f"DOC→PDF 처리 후 읽기 실패: {e}") + pdf_path = office_to_pdf(in_path, os.path.dirname(in_path)) + with open(pdf_path, "rb") as f: + pdf_bytes = f.read() + return run_pipeline(os.path.basename(pdf_path), "application/pdf", pdf_bytes, langs, max_pages) except Exception as e: warnings.append(f"DOC 변환 실패: {e}") finally: try: os.remove(in_path) - except Exception: + except: pass - if pdf_path: - try: - os.remove(pdf_path) - except Exception: - pass elif ftype == "hwp": - # HWP → hwp5txt 1차 시도. 성공 시 page=1 in_path = save_bytes_to_temp(data, suffix=".hwp") try: text = hwp_to_text(in_path) - results.append({"page": 1, "text": (text or "").strip()}) + results.append({"page": 1, "text": text.strip()}) page_count = 1 except Exception as e: - warnings.append(f"HWP 텍스트 추출 실패: {e}") + warnings.append(f"HWP 처리 실패: {e}") finally: try: os.remove(in_path) - except Exception: + except: pass else: - warnings.append("지원되지 않는 파일 형식입니다.") + warnings.append("지원되지 않는 파일 형식") except Exception as e: - # 상위에서 200으로 내려줄 수 있도록 전체 예외 흡수 warnings.append(f"파이프라인 실행 오류: {e}") return {