From f770ff807668fedec92b64c5799a7715b602fc75 Mon Sep 17 00:00:00 2001 From: Junseo1026 Date: Mon, 9 Jun 2025 23:03:48 +0900 Subject: [PATCH 1/2] feat/1st_presentation --- routers/file.py | 68 +++++++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/routers/file.py b/routers/file.py index d530446..9f5c94d 100644 --- a/routers/file.py +++ b/routers/file.py @@ -1,5 +1,4 @@ -# ~/noteflow/Backend/routers/file.py - +# routers/file.py import os import io import whisper @@ -7,7 +6,6 @@ from datetime import datetime import numpy as np from typing import Optional, List -from urllib.parse import quote from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, status from fastapi.responses import FileResponse @@ -19,6 +17,9 @@ from models.note import Note as NoteModel from utils.jwt_utils import get_current_user +# 추가: 파일명 인코딩용 +import urllib.parse + # ------------------------------- # 1) EasyOCR 라이브러리 임포트 (GPU 모드 활성화) # ------------------------------- @@ -55,7 +56,6 @@ trust_remote_code=True ) -# 업로드 디렉토리 설정 BASE_UPLOAD_DIR = os.path.join( os.path.dirname(os.path.abspath(__file__)), "..", @@ -80,11 +80,9 @@ async def upload_file( orig_filename: str = upload_file.filename or "unnamed" content_type: str = upload_file.content_type or "application/octet-stream" - # 사용자별 디렉토리 생성 user_dir = os.path.join(BASE_UPLOAD_DIR, str(current_user.u_id)) os.makedirs(user_dir, exist_ok=True) - # 원본 파일명 그대로 저장 (동명이인 방지) saved_filename = orig_filename saved_path = os.path.join(user_dir, saved_filename) if os.path.exists(saved_path): @@ -99,7 +97,6 @@ async def upload_file( break counter += 1 - # 파일 저장 try: with open(saved_path, "wb") as buffer: content = await upload_file.read() @@ -107,7 +104,6 @@ async def upload_file( except Exception as e: raise HTTPException(status_code=500, detail=f"파일 저장 실패: {e}") - # DB에 메타데이터 기록 new_file = FileModel( user_id=current_user.u_id, folder_id=folder_id, @@ -177,9 +173,9 @@ def download_file( if not os.path.exists(file_path): raise HTTPException(status_code=404, detail="서버에 파일이 존재하지 않습니다.") - # original_name 을 percent-encoding 해서 ASCII 만으로 헤더 구성 - filename_quoted = quote(file_obj.original_name) - content_disposition = f"inline; filename*=UTF-8''{filename_quoted}" + # 원본 파일명 UTF-8 URL 인코딩 처리 + quoted_name = urllib.parse.quote(file_obj.original_name, safe='') + content_disposition = f"inline; filename*=UTF-8''{quoted_name}" return FileResponse( path=file_path, @@ -200,52 +196,64 @@ async def ocr_and_create_note( current_user = Depends(get_current_user) ): """ - • EasyOCR + TrOCR 모델로 이미지에서 텍스트 추출 - • 가장 긴 결과를 선택해 새 노트로 저장 + • ocr_file: 이미지 파일(UploadFile) + • 1) EasyOCR로 기본 텍스트 추출 (GPU 모드) + • 2) TrOCR 4개 모델로 OCR 수행 (모두 GPU) + • 3) 가장 긴 결과를 최종 OCR 결과로 선택 + • 4) Note로 저장 및 결과 반환 """ - # 1) 이미지 로드 + + # 1) 이미지 로드 (PIL) contents = await ocr_file.read() try: image = Image.open(io.BytesIO(contents)).convert("RGB") except Exception as e: raise HTTPException(status_code=400, detail=f"이미지 처리 실패: {e}") - # 2) EasyOCR + # 2) EasyOCR로 텍스트 추출 try: image_np = np.array(image) - easy_results = reader.readtext(image_np) + easy_results = reader.readtext(image_np) # GPU 모드 사용 easy_text = " ".join([res[1] for res in easy_results]) except Exception: easy_text = "" - # 3) TrOCR 4개 모델 + # 3) TrOCR 모델 4개로 OCR 수행 (모두 GPU input) hf_texts: List[str] = [] try: - for pipe in ( - hf_trocr_printed, - hf_trocr_handwritten, - hf_trocr_small_printed, - hf_trocr_large_printed - ): - out = pipe(image) - if isinstance(out, list) and "generated_text" in out[0]: - hf_texts.append(out[0]["generated_text"].strip()) + out1 = hf_trocr_printed(image) + if isinstance(out1, list) and "generated_text" in out1[0]: + hf_texts.append(out1[0]["generated_text"].strip()) + + out2 = hf_trocr_handwritten(image) + if isinstance(out2, list) and "generated_text" in out2[0]: + hf_texts.append(out2[0]["generated_text"].strip()) + + out3 = hf_trocr_small_printed(image) + if isinstance(out3, list) and "generated_text" in out3[0]: + hf_texts.append(out3[0]["generated_text"].strip()) + + out4 = hf_trocr_large_printed(image) + if isinstance(out4, list) and "generated_text" in out4[0]: + hf_texts.append(out4[0]["generated_text"].strip()) except Exception: + # TrOCR 중 오류 발생 시 무시하고 계속 진행 pass - # 4) 가장 긴 결과 선택 + # 4) 여러 OCR 결과 병합: 가장 긴 문자열을 최종 ocr_text로 선택 candidates = [t for t in [easy_text] + hf_texts if t and t.strip()] if not candidates: raise HTTPException(status_code=500, detail="텍스트를 인식할 수 없습니다.") - ocr_text = max(candidates, key=len) - # 5) Note 생성 + ocr_text = max(candidates, key=lambda s: len(s)) + + # 5) 새 노트 생성 및 DB에 저장 try: new_note = NoteModel( user_id=current_user.u_id, folder_id=folder_id, title="OCR 결과", - content=ocr_text + content=ocr_text # **원본 OCR 텍스트만 저장** ) db.add(new_note) db.commit() From 57517fa970e99786c56d7d1500eff225d52bc004 Mon Sep 17 00:00:00 2001 From: Junseo1026 Date: Sun, 26 Oct 2025 18:19:32 +0900 Subject: [PATCH 2/2] 1026 18:19 --- main.py | 5 +++-- models/base.py | 3 ++- models/file.py | 14 ++++++++++---- models/folder.py | 7 +++++++ models/note.py | 6 ++++++ models/user.py | 14 ++++++++++---- 6 files changed, 38 insertions(+), 11 deletions(-) diff --git a/main.py b/main.py index a9765fd..638f37e 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,8 @@ # src/main.py import os from dotenv import load_dotenv +# 환경 변수를 최대한 빨리 로드하여 GPU 설정(CUDA_VISIBLE_DEVICES)이 라우터 임포트 전에 적용되도록 함 +load_dotenv() from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from routers.auth import router as auth_router @@ -11,8 +13,7 @@ import logging import uvicorn -# 1) 환경변수 로드 -load_dotenv() +# 1) 환경변수 로드 (상단에서 선 로드됨) # 2) 로깅 설정 logging.basicConfig(level=logging.INFO) diff --git a/models/base.py b/models/base.py index c1da040..860e542 100644 --- a/models/base.py +++ b/models/base.py @@ -1,2 +1,3 @@ from sqlalchemy.ext.declarative import declarative_base -Base = declarative_base() \ No newline at end of file + +Base = declarative_base() diff --git a/models/file.py b/models/file.py index 7bef388..aacd2ac 100644 --- a/models/file.py +++ b/models/file.py @@ -1,4 +1,4 @@ -# Backend/models/file.py +from sqlalchemy.orm import relationship from sqlalchemy import Column, Integer, String, ForeignKey, TIMESTAMP, text from .base import Base @@ -8,7 +8,13 @@ class File(Base): id = Column(Integer, primary_key=True, autoincrement=True) user_id = Column(Integer, ForeignKey('user.u_id', ondelete='CASCADE'), nullable=False) folder_id = Column(Integer, ForeignKey('folder.id', ondelete='SET NULL'), nullable=True) - original_name = Column(String(255), nullable=False) # 유저가 업로드한 원본 파일 이름 - saved_path = Column(String(512), nullable=False) # 서버에 저장된(실제) 경로 - content_type = Column(String(100), nullable=False) # MIME 타입 + note_id = Column(Integer, ForeignKey('note.id', ondelete='CASCADE'), nullable=True) + original_name = Column(String(255), nullable=False) + saved_path = Column(String(512), nullable=False) + content_type = Column(String(100), nullable=False) created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) + + # ✅ 관계 + user = relationship("User", back_populates="files") + folder = relationship("Folder", back_populates="files") + note = relationship("Note", back_populates="files") diff --git a/models/folder.py b/models/folder.py index 1d6f2eb..92f4842 100644 --- a/models/folder.py +++ b/models/folder.py @@ -1,4 +1,5 @@ from sqlalchemy import Column, Integer, String, ForeignKey, TIMESTAMP, text +from sqlalchemy.orm import relationship from .base import Base class Folder(Base): @@ -12,3 +13,9 @@ class Folder(Base): updated_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'), onupdate=text('CURRENT_TIMESTAMP')) + + # ✅ 관계 + user = relationship("User", back_populates="folders") + parent = relationship("Folder", remote_side=[id], backref="children") + notes = relationship("Note", back_populates="folder", cascade="all, delete") + files = relationship("File", back_populates="folder", cascade="all, delete") diff --git a/models/note.py b/models/note.py index 8ff2489..67a0fb2 100644 --- a/models/note.py +++ b/models/note.py @@ -1,4 +1,5 @@ from sqlalchemy import Column, Integer, String, Text, Boolean, ForeignKey, TIMESTAMP, text +from sqlalchemy.orm import relationship from .base import Base class Note(Base): @@ -15,3 +16,8 @@ class Note(Base): updated_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'), onupdate=text('CURRENT_TIMESTAMP')) + + # ✅ 관계 + user = relationship("User", back_populates="notes") + folder = relationship("Folder", back_populates="notes") + files = relationship("File", back_populates="note", cascade="all, delete") diff --git a/models/user.py b/models/user.py index 3811b46..295df81 100644 --- a/models/user.py +++ b/models/user.py @@ -1,4 +1,5 @@ from sqlalchemy import Column, Integer, String, Enum, TIMESTAMP, text +from sqlalchemy.orm import relationship from .base import Base class User(Base): @@ -9,11 +10,16 @@ class User(Base): email = Column(String(150), nullable=False, unique=True) password = Column(String(255), nullable=False) provider = Column( - Enum('local','google','kakao','naver', name='provider_enum'), - nullable=False, - server_default=text("'local'") - ) + Enum('local','google','kakao','naver', name='provider_enum'), + nullable=False, + server_default=text("'local'") + ) created_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP')) updated_at = Column(TIMESTAMP, nullable=False, server_default=text('CURRENT_TIMESTAMP'), onupdate=text('CURRENT_TIMESTAMP')) + + # ✅ 관계 + folders = relationship("Folder", back_populates="user", cascade="all, delete") + notes = relationship("Note", back_populates="user", cascade="all, delete") + files = relationship("File", back_populates="user", cascade="all, delete")