From 5ffa075dad2bd2de3c11eafdb3debeb2ce753141 Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Thu, 12 Feb 2026 15:50:05 +0800 Subject: [PATCH 01/13] feat: extend UniversalDocLoader to support additional file formats including Excel and HTML --- .../module/shared/common/document_loaders.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/runtime/datamate-python/app/module/shared/common/document_loaders.py b/runtime/datamate-python/app/module/shared/common/document_loaders.py index b6f045bae..cb0a843f9 100644 --- a/runtime/datamate-python/app/module/shared/common/document_loaders.py +++ b/runtime/datamate-python/app/module/shared/common/document_loaders.py @@ -1,15 +1,17 @@ -from typing import List, Union, Optional from pathlib import Path +from typing import List, Union, Optional -from langchain_core.documents import Document from langchain_community.document_loaders import ( TextLoader, JSONLoader, CSVLoader, UnstructuredMarkdownLoader, + UnstructuredFileLoader, + UnstructuredExcelLoader, PyPDFLoader, - Docx2txtLoader + Docx2txtLoader, BSHTMLLoader ) +from langchain_core.documents import Document from app.core.logging import get_logger @@ -18,7 +20,7 @@ class UniversalDocLoader: """ 通用泛文本文档加载类 - 支持格式:TXT/JSON/CSV/Markdown/Word(.docx)/PPT(.pptx)/PDF + 支持格式:TXT/JSON/CSV/Markdown/Word(.docx)/PPT(.pptx)/PDF/Excel(.xlsx,.xls)/各种文本文件 """ # 格式-加载器映射(轻量优先) SUPPORTED_FORMATS = { @@ -27,9 +29,20 @@ class UniversalDocLoader: ".json": JSONLoader, ".csv": CSVLoader, ".md": UnstructuredMarkdownLoader, + ".markdown": UnstructuredMarkdownLoader, + ".log": TextLoader, + ".xml": TextLoader, + ".html": BSHTMLLoader, + ".htm": BSHTMLLoader, + ".yaml": TextLoader, + ".yml": TextLoader, # 办公文档类 ".docx": Docx2txtLoader, ".doc": Docx2txtLoader, + ".pptx": UnstructuredFileLoader, + ".ppt": UnstructuredFileLoader, + ".xlsx": UnstructuredExcelLoader, + ".xls": UnstructuredExcelLoader, # PDF 类 ".pdf": PyPDFLoader } @@ -79,6 +92,10 @@ def _set_default_kwargs(loader_cls, kwargs: dict) -> dict: kwargs.setdefault("text_content", False) if loader_cls == CSVLoader and "csv_args" not in kwargs: kwargs["csv_args"] = {"delimiter": ","} + if loader_cls == UnstructuredExcelLoader and "mode" not in kwargs: + kwargs.setdefault("mode", "elements") + if loader_cls == UnstructuredFileLoader and "mode" not in kwargs: + kwargs.setdefault("mode", "elements") return kwargs From f2f762663a4a9b6b592da937437f0f00c9037a43 Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Wed, 25 Feb 2026 09:43:25 +0800 Subject: [PATCH 02/13] feat: implement RAG module with document loading, splitting, and processing capabilities --- .claude/skills/backend-architect/SKILL.md | 581 +++++++++++++++++- .claude/skills/fastapi-templates/SKILL.md | 567 +++++++++++++++++ .../Detail/KnowledgeBaseDetail.tsx | 6 +- .../KnowledgeBase/Home/KnowledgeBasePage.tsx | 6 +- .../pages/KnowledgeBase/knowledge-base.api.ts | 29 +- frontend/vite.config.ts | 59 +- runtime/datamate-python/app/core/config.py | 7 + .../app/core/exception/codes.py | 13 +- .../datamate-python/app/db/models/__init__.py | 4 + .../app/db/models/base_entity.py | 2 +- .../app/db/models/knowledge_gen.py | 85 ++- .../datamate-python/app/module/__init__.py | 3 + .../app/module/rag/infra/__init__.py | 23 + .../module/rag/infra/embeddings/__init__.py | 63 ++ .../app/module/rag/infra/milvus/__init__.py | 5 + .../app/module/rag/infra/milvus/factory.py | 56 ++ .../module/rag/infra/milvus/vectorstore.py | 295 +++++++++ 
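PATCH 01's loader table dispatches purely on file suffix and then fills in per-loader defaults via `_set_default_kwargs` (Excel and the generic unstructured loader both default to `mode="elements"`). A minimal sketch of that dispatch under the mapping shown above — the `resolve_loader` helper is illustrative, not part of the patch:

```python
from pathlib import Path

from langchain_community.document_loaders import (
    TextLoader,
    BSHTMLLoader,
    UnstructuredExcelLoader,
)

# Illustrative subset of the SUPPORTED_FORMATS mapping from PATCH 01.
SUPPORTED_FORMATS = {
    ".txt": TextLoader,
    ".html": BSHTMLLoader,
    ".xlsx": UnstructuredExcelLoader,
}


def resolve_loader(file_path: str, **kwargs):
    """Pick a loader class by suffix and apply its defaults (hypothetical helper)."""
    suffix = Path(file_path).suffix.lower()
    loader_cls = SUPPORTED_FORMATS.get(suffix)
    if loader_cls is None:
        raise ValueError(f"Unsupported format: {suffix}")
    # Mirrors _set_default_kwargs: Excel content is loaded element-by-element.
    if loader_cls is UnstructuredExcelLoader:
        kwargs.setdefault("mode", "elements")
    return loader_cls(file_path, **kwargs)


docs = resolve_loader("report.xlsx").load()
```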
.../app/module/rag/infra/options.py | 40 ++ .../app/module/rag/infra/parser/__init__.py | 13 + .../app/module/rag/infra/parser/base.py | 168 +++++ .../app/module/rag/infra/pipeline.py | 116 ++++ .../app/module/rag/infra/splitter/__init__.py | 5 + .../app/module/rag/infra/splitter/base.py | 74 +++ .../app/module/rag/infra/splitter/factory.py | 57 ++ .../rag/infra/splitter/langchain_impl.py | 107 ++++ .../app/module/rag/infra/task/__init__.py | 5 + .../app/module/rag/infra/task/worker_pool.py | 105 ++++ .../app/module/rag/interface/__init__.py | 12 + .../module/rag/interface/knowledge_base.py | 256 ++++++++ .../app/module/rag/interface/rag_interface.py | 22 +- .../app/module/rag/repository/__init__.py | 12 + .../module/rag/repository/file_repository.py | 325 ++++++++++ .../repository/knowledge_base_repository.py | 203 ++++++ .../app/module/rag/schema/__init__.py | 58 ++ .../app/module/rag/schema/entity.py | 58 ++ .../app/module/rag/schema/enums.py | 21 + .../app/module/rag/schema/rag_schema.py | 8 - .../app/module/rag/schema/request.py | 326 ++++++++++ .../app/module/rag/schema/response.py | 194 ++++++ .../app/module/rag/service/etl_service.py | 264 ++++++++ .../app/module/rag/service/file_service.py | 181 ++++++ .../rag/service/knowledge_base_service.py | 539 ++++++++++++++++ .../app/module/rag/service/rag_service.py | 64 +- .../app/module/shared/llm/factory.py | 4 + runtime/datamate-python/poetry.lock | 360 ++++++++++- runtime/datamate-python/pyproject.toml | 10 + 46 files changed, 5301 insertions(+), 110 deletions(-) create mode 100644 .claude/skills/fastapi-templates/SKILL.md create mode 100644 runtime/datamate-python/app/module/rag/infra/__init__.py create mode 100644 runtime/datamate-python/app/module/rag/infra/embeddings/__init__.py create mode 100644 runtime/datamate-python/app/module/rag/infra/milvus/__init__.py create mode 100644 runtime/datamate-python/app/module/rag/infra/milvus/factory.py create mode 100644 runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py create mode 100644 runtime/datamate-python/app/module/rag/infra/options.py create mode 100644 runtime/datamate-python/app/module/rag/infra/parser/__init__.py create mode 100644 runtime/datamate-python/app/module/rag/infra/parser/base.py create mode 100644 runtime/datamate-python/app/module/rag/infra/pipeline.py create mode 100644 runtime/datamate-python/app/module/rag/infra/splitter/__init__.py create mode 100644 runtime/datamate-python/app/module/rag/infra/splitter/base.py create mode 100644 runtime/datamate-python/app/module/rag/infra/splitter/factory.py create mode 100644 runtime/datamate-python/app/module/rag/infra/splitter/langchain_impl.py create mode 100644 runtime/datamate-python/app/module/rag/infra/task/__init__.py create mode 100644 runtime/datamate-python/app/module/rag/infra/task/worker_pool.py create mode 100644 runtime/datamate-python/app/module/rag/interface/knowledge_base.py create mode 100644 runtime/datamate-python/app/module/rag/repository/__init__.py create mode 100644 runtime/datamate-python/app/module/rag/repository/file_repository.py create mode 100644 runtime/datamate-python/app/module/rag/repository/knowledge_base_repository.py create mode 100644 runtime/datamate-python/app/module/rag/schema/__init__.py create mode 100644 runtime/datamate-python/app/module/rag/schema/entity.py create mode 100644 runtime/datamate-python/app/module/rag/schema/enums.py delete mode 100644 runtime/datamate-python/app/module/rag/schema/rag_schema.py create mode 100644 
runtime/datamate-python/app/module/rag/schema/request.py create mode 100644 runtime/datamate-python/app/module/rag/schema/response.py create mode 100644 runtime/datamate-python/app/module/rag/service/etl_service.py create mode 100644 runtime/datamate-python/app/module/rag/service/file_service.py create mode 100644 runtime/datamate-python/app/module/rag/service/knowledge_base_service.py diff --git a/.claude/skills/backend-architect/SKILL.md b/.claude/skills/backend-architect/SKILL.md index a68acde40..49fc60af0 100644 --- a/.claude/skills/backend-architect/SKILL.md +++ b/.claude/skills/backend-architect/SKILL.md @@ -1,36 +1,567 @@ --- -name: Python Web Backend Architect -description: As an elite Backend Architect, you specialize in designing and implementing scalable, asynchronous, and high-performance web systems. You transform complex business visions into modular, production-ready code using the FastAPI and SQLAlchemy (Async) stack, adhering to industry-best "Clean Architecture" principles. +name: fastapi-templates +description: Create production-ready FastAPI projects with async patterns, dependency injection, and comprehensive error handling. Use when building new FastAPI applications or setting up backend API projects. --- -### Architecture Blueprint +# FastAPI Project Templates -### Workflow +Production-ready FastAPI project structures with async patterns, dependency injection, middleware, and best practices for building high-performance APIs. -1. **Requirement Distillation:** Deconstruct high-level features into granular data models and business logic flows. -2. **Schema-First Design:** Define Pydantic V2 schemas for I/O validation and SQLAlchemy 2.0 models for the persistent domain layer. -3. **Dependency Injection (DI) Orchestration:** Implement `Depends` for modular service provision, focusing on asynchronous database session management. -4. **Service Layer Implementation:** Encapsulate business rules in standalone services, ensuring the API layer (Routes) remains a thin orchestration shell. -5. **Robust Error Handling:** Deploy global exception middleware to maintain a consistent API response contract ( lookup for error codes). +## When to Use This Skill -### Constraints & Standards +- Starting new FastAPI projects from scratch +- Implementing async REST APIs with Python +- Building high-performance web services and microservices +- Creating async applications with PostgreSQL, MongoDB +- Setting up API projects with proper structure and testing -* **Full Async Chain:** Every I/O operation must be non-blocking. Use `await` for DB queries and external API calls. -* **Atomic Transactions:** Ensure data integrity via the "Unit of Work" pattern. Use context managers for session commits and rollbacks. -* **Zero N+1 Leakage:** Explicitly use `selectinload` or `joinedload` for relationship loading to optimize database roundtrips. -* **Security & Auth:** Implement JWT-based authentication with OAuth2PasswordBearer. Enforce strict Pydantic `response_model` to prevent PII (Personally Identifiable Information) leakage. -* **Code Quality:** Adhere to PEP 8, utilize Type Hinting for all parameters, and maintain an or better complexity for data processing logic. +## Core Concepts -### Technical Specification Template +### 1. Project Structure -* **Database:** SQLAlchemy 2.0 (Declarative Mapping + Async Engine). -* **Migration:** Mandatory Alembic versioning for all schema changes. -* **Validation:** Pydantic V2 with strict type coercion. 
-* **API Documentation:** Auto-generated OpenAPI (Swagger) with comprehensive status code definitions (200, 201, 400, 401, 403, 404, 500). +**Recommended Layout:** -### Self-Reflective Audit +``` +app/ +├── api/ # API routes +│ ├── v1/ +│ │ ├── endpoints/ +│ │ │ ├── users.py +│ │ │ ├── auth.py +│ │ │ └── items.py +│ │ └── router.py +│ └── dependencies.py # Shared dependencies +├── core/ # Core configuration +│ ├── config.py +│ ├── security.py +│ └── database.py +├── models/ # Database models +│ ├── user.py +│ └── item.py +├── schemas/ # Pydantic schemas +│ ├── user.py +│ └── item.py +├── services/ # Business logic +│ ├── user_service.py +│ └── auth_service.py +├── repositories/ # Data access +│ ├── user_repository.py +│ └── item_repository.py +└── main.py # Application entry +``` -* Before finalizing any module, verify: -1. Is the business logic strictly decoupled from the FastAPI router? -2. Are the database queries optimized for the expected scale? -3. Does the error handling prevent stack trace exposure to the end-user? +### 2. Dependency Injection + +FastAPI's built-in DI system using `Depends`: + +- Database session management +- Authentication/authorization +- Shared business logic +- Configuration injection + +### 3. Async Patterns + +Proper async/await usage: + +- Async route handlers +- Async database operations +- Async background tasks +- Async middleware + +## Implementation Patterns + +### Pattern 1: Complete FastAPI Application + +```python +# main.py +from fastapi import FastAPI, Depends +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan events.""" + # Startup + await database.connect() + yield + # Shutdown + await database.disconnect() + +app = FastAPI( + title="API Template", + version="1.0.0", + lifespan=lifespan +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Include routers +from app.api.v1.router import api_router +app.include_router(api_router, prefix="/api/v1") + +# core/config.py +from pydantic_settings import BaseSettings +from functools import lru_cache + +class Settings(BaseSettings): + """Application settings.""" + DATABASE_URL: str + SECRET_KEY: str + ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + API_V1_STR: str = "/api/v1" + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# core/database.py +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from app.core.config import get_settings + +settings = get_settings() + +engine = create_async_engine( + settings.DATABASE_URL, + echo=True, + future=True +) + +AsyncSessionLocal = sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False +) + +Base = declarative_base() + +async def get_db() -> AsyncSession: + """Dependency for database session.""" + async with AsyncSessionLocal() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() +``` + +### Pattern 2: CRUD Repository Pattern + +```python +# repositories/base_repository.py +from typing import Generic, TypeVar, Type, Optional, List +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select +from pydantic import BaseModel + 
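+# ModelType binds to the SQLAlchemy model class, while CreateSchemaType and
+# UpdateSchemaType bind to the Pydantic schemas accepted by create()/update();
+# binding all three per repository keeps the CRUD methods fully type-checked.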
+ModelType = TypeVar("ModelType") +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) + +class BaseRepository(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + """Base repository for CRUD operations.""" + + def __init__(self, model: Type[ModelType]): + self.model = model + + async def get(self, db: AsyncSession, id: int) -> Optional[ModelType]: + """Get by ID.""" + result = await db.execute( + select(self.model).where(self.model.id == id) + ) + return result.scalars().first() + + async def get_multi( + self, + db: AsyncSession, + skip: int = 0, + limit: int = 100 + ) -> List[ModelType]: + """Get multiple records.""" + result = await db.execute( + select(self.model).offset(skip).limit(limit) + ) + return result.scalars().all() + + async def create( + self, + db: AsyncSession, + obj_in: CreateSchemaType + ) -> ModelType: + """Create new record.""" + db_obj = self.model(**obj_in.dict()) + db.add(db_obj) + await db.flush() + await db.refresh(db_obj) + return db_obj + + async def update( + self, + db: AsyncSession, + db_obj: ModelType, + obj_in: UpdateSchemaType + ) -> ModelType: + """Update record.""" + update_data = obj_in.dict(exclude_unset=True) + for field, value in update_data.items(): + setattr(db_obj, field, value) + await db.flush() + await db.refresh(db_obj) + return db_obj + + async def delete(self, db: AsyncSession, id: int) -> bool: + """Delete record.""" + obj = await self.get(db, id) + if obj: + await db.delete(obj) + return True + return False + +# repositories/user_repository.py +from app.repositories.base_repository import BaseRepository +from app.models.user import User +from app.schemas.user import UserCreate, UserUpdate + +class UserRepository(BaseRepository[User, UserCreate, UserUpdate]): + """User-specific repository.""" + + async def get_by_email(self, db: AsyncSession, email: str) -> Optional[User]: + """Get user by email.""" + result = await db.execute( + select(User).where(User.email == email) + ) + return result.scalars().first() + + async def is_active(self, db: AsyncSession, user_id: int) -> bool: + """Check if user is active.""" + user = await self.get(db, user_id) + return user.is_active if user else False + +user_repository = UserRepository(User) +``` + +### Pattern 3: Service Layer + +```python +# services/user_service.py +from typing import Optional +from sqlalchemy.ext.asyncio import AsyncSession +from app.repositories.user_repository import user_repository +from app.schemas.user import UserCreate, UserUpdate, User +from app.core.security import get_password_hash, verify_password + +class UserService: + """Business logic for users.""" + + def __init__(self): + self.repository = user_repository + + async def create_user( + self, + db: AsyncSession, + user_in: UserCreate + ) -> User: + """Create new user with hashed password.""" + # Check if email exists + existing = await self.repository.get_by_email(db, user_in.email) + if existing: + raise ValueError("Email already registered") + + # Hash password + user_in_dict = user_in.dict() + user_in_dict["hashed_password"] = get_password_hash(user_in_dict.pop("password")) + + # Create user + user = await self.repository.create(db, UserCreate(**user_in_dict)) + return user + + async def authenticate( + self, + db: AsyncSession, + email: str, + password: str + ) -> Optional[User]: + """Authenticate user.""" + user = await self.repository.get_by_email(db, email) + if not user: + return None + if not verify_password(password, 
user.hashed_password): + return None + return user + + async def update_user( + self, + db: AsyncSession, + user_id: int, + user_in: UserUpdate + ) -> Optional[User]: + """Update user.""" + user = await self.repository.get(db, user_id) + if not user: + return None + + if user_in.password: + user_in_dict = user_in.dict(exclude_unset=True) + user_in_dict["hashed_password"] = get_password_hash( + user_in_dict.pop("password") + ) + user_in = UserUpdate(**user_in_dict) + + return await self.repository.update(db, user, user_in) + +user_service = UserService() +``` + +### Pattern 4: API Endpoints with Dependencies + +```python +# api/v1/endpoints/users.py +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession +from typing import List + +from app.core.database import get_db +from app.schemas.user import User, UserCreate, UserUpdate +from app.services.user_service import user_service +from app.api.dependencies import get_current_user + +router = APIRouter() + +@router.post("/", response_model=User, status_code=status.HTTP_201_CREATED) +async def create_user( + user_in: UserCreate, + db: AsyncSession = Depends(get_db) +): + """Create new user.""" + try: + user = await user_service.create_user(db, user_in) + return user + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.get("/me", response_model=User) +async def read_current_user( + current_user: User = Depends(get_current_user) +): + """Get current user.""" + return current_user + +@router.get("/{user_id}", response_model=User) +async def read_user( + user_id: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get user by ID.""" + user = await user_service.repository.get(db, user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + return user + +@router.patch("/{user_id}", response_model=User) +async def update_user( + user_id: int, + user_in: UserUpdate, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Update user.""" + if current_user.id != user_id: + raise HTTPException(status_code=403, detail="Not authorized") + + user = await user_service.update_user(db, user_id, user_in) + if not user: + raise HTTPException(status_code=404, detail="User not found") + return user + +@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_user( + user_id: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Delete user.""" + if current_user.id != user_id: + raise HTTPException(status_code=403, detail="Not authorized") + + deleted = await user_service.repository.delete(db, user_id) + if not deleted: + raise HTTPException(status_code=404, detail="User not found") +``` + +### Pattern 5: Authentication & Authorization + +```python +# core/security.py +from datetime import datetime, timedelta +from typing import Optional +from jose import JWTError, jwt +from passlib.context import CryptContext +from app.core.config import get_settings + +settings = get_settings() +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +ALGORITHM = "HS256" + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): + """Create JWT access token.""" + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = 
jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify password against hash.""" + return pwd_context.verify(plain_password, hashed_password) + +def get_password_hash(password: str) -> str: + """Hash password.""" + return pwd_context.hash(password) + +# api/dependencies.py +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.security import ALGORITHM +from app.core.config import get_settings +from app.repositories.user_repository import user_repository + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login") + +async def get_current_user( + db: AsyncSession = Depends(get_db), + token: str = Depends(oauth2_scheme) +): + """Get current authenticated user.""" + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + user_id: int = payload.get("sub") + if user_id is None: + raise credentials_exception + except JWTError: + raise credentials_exception + + user = await user_repository.get(db, user_id) + if user is None: + raise credentials_exception + + return user +``` + +## Testing + +```python +# tests/conftest.py +import pytest +import asyncio +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker + +from app.main import app +from app.core.database import get_db, Base + +TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + +@pytest.fixture +async def db_session(): + engine = create_async_engine(TEST_DATABASE_URL, echo=True) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + AsyncSessionLocal = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + + async with AsyncSessionLocal() as session: + yield session + +@pytest.fixture +async def client(db_session): + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + async with AsyncClient(app=app, base_url="http://test") as client: + yield client + +# tests/test_users.py +import pytest + +@pytest.mark.asyncio +async def test_create_user(client): + response = await client.post( + "/api/v1/users/", + json={ + "email": "test@example.com", + "password": "testpass123", + "name": "Test User" + } + ) + assert response.status_code == 201 + data = response.json() + assert data["email"] == "test@example.com" + assert "id" in data +``` + +## Resources + +- **references/fastapi-architecture.md**: Detailed architecture guide +- **references/async-best-practices.md**: Async/await patterns +- **references/testing-strategies.md**: Comprehensive testing guide +- **assets/project-template/**: Complete FastAPI project +- **assets/docker-compose.yml**: Development environment setup + +## Best Practices + +1. **Async All The Way**: Use async for database, external APIs +2. **Dependency Injection**: Leverage FastAPI's DI system +3. **Repository Pattern**: Separate data access from business logic +4. 
**Service Layer**: Keep business logic out of routes +5. **Pydantic Schemas**: Strong typing for request/response +6. **Error Handling**: Consistent error responses +7. **Testing**: Test all layers independently + +## Common Pitfalls + +- **Blocking Code in Async**: Using synchronous database drivers +- **No Service Layer**: Business logic in route handlers +- **Missing Type Hints**: Loses FastAPI's benefits +- **Ignoring Sessions**: Not properly managing database sessions +- **No Testing**: Skipping integration tests +- **Tight Coupling**: Direct database access in routes diff --git a/.claude/skills/fastapi-templates/SKILL.md b/.claude/skills/fastapi-templates/SKILL.md new file mode 100644 index 000000000..05c492e37 --- /dev/null +++ b/.claude/skills/fastapi-templates/SKILL.md @@ -0,0 +1,567 @@ +--- +name: fastapi-templates +description: Create production-ready FastAPI projects with async patterns, dependency injection, and comprehensive error handling. Use when building new FastAPI applications or setting up backend API projects. +--- + +# FastAPI Project Templates + +Production-ready FastAPI project structures with async patterns, dependency injection, middleware, and best practices for building high-performance APIs. + +## When to Use This Skill + +- Starting new FastAPI projects from scratch +- Implementing async REST APIs with Python +- Building high-performance web services and microservices +- Creating async applications with PostgreSQL, MongoDB +- Setting up API projects with proper structure and testing + +## Core Concepts + +### 1. Project Structure + +**Recommended Layout:** + +``` +app/ +├── api/ # API routes +│ ├── v1/ +│ │ ├── endpoints/ +│ │ │ ├── users.py +│ │ │ ├── auth.py +│ │ │ └── items.py +│ │ └── router.py +│ └── dependencies.py # Shared dependencies +├── core/ # Core configuration +│ ├── config.py +│ ├── security.py +│ └── database.py +├── models/ # Database models +│ ├── user.py +│ └── item.py +├── schemas/ # Pydantic schemas +│ ├── user.py +│ └── item.py +├── services/ # Business logic +│ ├── user_service.py +│ └── auth_service.py +├── repositories/ # Data access +│ ├── user_repository.py +│ └── item_repository.py +└── main.py # Application entry +``` + +### 2. Dependency Injection + +FastAPI's built-in DI system using `Depends`: + +- Database session management +- Authentication/authorization +- Shared business logic +- Configuration injection + +### 3. 
Async Patterns + +Proper async/await usage: + +- Async route handlers +- Async database operations +- Async background tasks +- Async middleware + +## Implementation Patterns + +### Pattern 1: Complete FastAPI Application + +```python +# main.py +from fastapi import FastAPI, Depends +from fastapi.middleware.cors import CORSMiddleware +from contextlib import asynccontextmanager + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan events.""" + # Startup + await database.connect() + yield + # Shutdown + await database.disconnect() + +app = FastAPI( + title="API Template", + version="1.0.0", + lifespan=lifespan +) + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Include routers +from app.api.v1.router import api_router +app.include_router(api_router, prefix="/api/v1") + +# core/config.py +from pydantic_settings import BaseSettings +from functools import lru_cache + +class Settings(BaseSettings): + """Application settings.""" + DATABASE_URL: str + SECRET_KEY: str + ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + API_V1_STR: str = "/api/v1" + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# core/database.py +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from app.core.config import get_settings + +settings = get_settings() + +engine = create_async_engine( + settings.DATABASE_URL, + echo=True, + future=True +) + +AsyncSessionLocal = sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False +) + +Base = declarative_base() + +async def get_db() -> AsyncGenerator[Any, Any]: + """Dependency for database session.""" + async with AsyncSessionLocal() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() +``` + +### Pattern 2: CRUD Repository Pattern + +```python +# repositories/base_repository.py +from typing import Generic, TypeVar, Type, Optional, List +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select +from pydantic import BaseModel + +ModelType = TypeVar("ModelType") +CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) +UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) + +class BaseRepository(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): + """Base repository for CRUD operations.""" + + def __init__(self, model: Type[ModelType]): + self.model = model + + async def get(self, db: AsyncSession, id: int) -> Optional[ModelType]: + """Get by ID.""" + result = await db.execute( + select(self.model).where(self.model.id == id) + ) + return result.scalars().first() + + async def get_multi( + self, + db: AsyncSession, + skip: int = 0, + limit: int = 100 + ) -> List[ModelType]: + """Get multiple records.""" + result = await db.execute( + select(self.model).offset(skip).limit(limit) + ) + return result.scalars().all() + + async def create( + self, + db: AsyncSession, + obj_in: CreateSchemaType + ) -> ModelType: + """Create new record.""" + db_obj = self.model(**obj_in.dict()) + db.add(db_obj) + await db.flush() + await db.refresh(db_obj) + return db_obj + + async def update( + self, + db: AsyncSession, + db_obj: ModelType, + obj_in: UpdateSchemaType + ) -> ModelType: + """Update record.""" + update_data = 
obj_in.dict(exclude_unset=True) + for field, value in update_data.items(): + setattr(db_obj, field, value) + await db.flush() + await db.refresh(db_obj) + return db_obj + + async def delete(self, db: AsyncSession, id: int) -> bool: + """Delete record.""" + obj = await self.get(db, id) + if obj: + await db.delete(obj) + return True + return False + +# repositories/user_repository.py +from app.repositories.base_repository import BaseRepository +from app.models.user import User +from app.schemas.user import UserCreate, UserUpdate + +class UserRepository(BaseRepository[User, UserCreate, UserUpdate]): + """User-specific repository.""" + + async def get_by_email(self, db: AsyncSession, email: str) -> Optional[User]: + """Get user by email.""" + result = await db.execute( + select(User).where(User.email == email) + ) + return result.scalars().first() + + async def is_active(self, db: AsyncSession, user_id: int) -> bool: + """Check if user is active.""" + user = await self.get(db, user_id) + return user.is_active if user else False + +user_repository = UserRepository(User) +``` + +### Pattern 3: Service Layer + +```python +# services/user_service.py +from typing import Optional +from sqlalchemy.ext.asyncio import AsyncSession +from app.repositories.user_repository import user_repository +from app.schemas.user import UserCreate, UserUpdate, User +from app.core.security import get_password_hash, verify_password + +class UserService: + """Business logic for users.""" + + def __init__(self): + self.repository = user_repository + + async def create_user( + self, + db: AsyncSession, + user_in: UserCreate + ) -> User: + """Create new user with hashed password.""" + # Check if email exists + existing = await self.repository.get_by_email(db, user_in.email) + if existing: + raise ValueError("Email already registered") + + # Hash password + user_in_dict = user_in.dict() + user_in_dict["hashed_password"] = get_password_hash(user_in_dict.pop("password")) + + # Create user + user = await self.repository.create(db, UserCreate(**user_in_dict)) + return user + + async def authenticate( + self, + db: AsyncSession, + email: str, + password: str + ) -> Optional[User]: + """Authenticate user.""" + user = await self.repository.get_by_email(db, email) + if not user: + return None + if not verify_password(password, user.hashed_password): + return None + return user + + async def update_user( + self, + db: AsyncSession, + user_id: int, + user_in: UserUpdate + ) -> Optional[User]: + """Update user.""" + user = await self.repository.get(db, user_id) + if not user: + return None + + if user_in.password: + user_in_dict = user_in.dict(exclude_unset=True) + user_in_dict["hashed_password"] = get_password_hash( + user_in_dict.pop("password") + ) + user_in = UserUpdate(**user_in_dict) + + return await self.repository.update(db, user, user_in) + +user_service = UserService() +``` + +### Pattern 4: API Endpoints with Dependencies + +```python +# api/v1/endpoints/users.py +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession +from typing import List + +from app.core.database import get_db +from app.schemas.user import User, UserCreate, UserUpdate +from app.services.user_service import user_service +from app.api.dependencies import get_current_user + +router = APIRouter() + +@router.post("/", response_model=User, status_code=status.HTTP_201_CREATED) +async def create_user( + user_in: UserCreate, + db: AsyncSession = Depends(get_db) +): + """Create new user.""" + try: + user = 
await user_service.create_user(db, user_in) + return user + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + +@router.get("/me", response_model=User) +async def read_current_user( + current_user: User = Depends(get_current_user) +): + """Get current user.""" + return current_user + +@router.get("/{user_id}", response_model=User) +async def read_user( + user_id: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get user by ID.""" + user = await user_service.repository.get(db, user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + return user + +@router.patch("/{user_id}", response_model=User) +async def update_user( + user_id: int, + user_in: UserUpdate, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Update user.""" + if current_user.id != user_id: + raise HTTPException(status_code=403, detail="Not authorized") + + user = await user_service.update_user(db, user_id, user_in) + if not user: + raise HTTPException(status_code=404, detail="User not found") + return user + +@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_user( + user_id: int, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Delete user.""" + if current_user.id != user_id: + raise HTTPException(status_code=403, detail="Not authorized") + + deleted = await user_service.repository.delete(db, user_id) + if not deleted: + raise HTTPException(status_code=404, detail="User not found") +``` + +### Pattern 5: Authentication & Authorization + +```python +# core/security.py +from datetime import datetime, timedelta +from typing import Optional +from jose import JWTError, jwt +from passlib.context import CryptContext +from app.core.config import get_settings + +settings = get_settings() +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +ALGORITHM = "HS256" + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): + """Create JWT access token.""" + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify password against hash.""" + return pwd_context.verify(plain_password, hashed_password) + +def get_password_hash(password: str) -> str: + """Hash password.""" + return pwd_context.hash(password) + +# api/dependencies.py +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.security import ALGORITHM +from app.core.config import get_settings +from app.repositories.user_repository import user_repository + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login") + +async def get_current_user( + db: AsyncSession = Depends(get_db), + token: str = Depends(oauth2_scheme) +): + """Get current authenticated user.""" + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + payload = jwt.decode(token, 
settings.SECRET_KEY, algorithms=[ALGORITHM]) + user_id: int = payload.get("sub") + if user_id is None: + raise credentials_exception + except JWTError: + raise credentials_exception + + user = await user_repository.get(db, user_id) + if user is None: + raise credentials_exception + + return user +``` + +## Testing + +```python +# tests/conftest.py +import pytest +import asyncio +from httpx import AsyncClient +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker + +from app.main import app +from app.core.database import get_db, Base + +TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" + +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + +@pytest.fixture +async def db_session(): + engine = create_async_engine(TEST_DATABASE_URL, echo=True) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + AsyncSessionLocal = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False + ) + + async with AsyncSessionLocal() as session: + yield session + +@pytest.fixture +async def client(db_session): + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + + async with AsyncClient(app=app, base_url="http://test") as client: + yield client + +# tests/test_users.py +import pytest + +@pytest.mark.asyncio +async def test_create_user(client): + response = await client.post( + "/api/v1/users/", + json={ + "email": "test@example.com", + "password": "testpass123", + "name": "Test User" + } + ) + assert response.status_code == 201 + data = response.json() + assert data["email"] == "test@example.com" + assert "id" in data +``` + +## Resources + +- **references/fastapi-architecture.md**: Detailed architecture guide +- **references/async-best-practices.md**: Async/await patterns +- **references/testing-strategies.md**: Comprehensive testing guide +- **assets/project-template/**: Complete FastAPI project +- **assets/docker-compose.yml**: Development environment setup + +## Best Practices + +1. **Async All The Way**: Use async for database, external APIs +2. **Dependency Injection**: Leverage FastAPI's DI system +3. **Repository Pattern**: Separate data access from business logic +4. **Service Layer**: Keep business logic out of routes +5. **Pydantic Schemas**: Strong typing for request/response +6. **Error Handling**: Consistent error responses +7. **Testing**: Test all layers independently + +## Common Pitfalls + +- **Blocking Code in Async**: Using synchronous database drivers +- **No Service Layer**: Business logic in route handlers +- **Missing Type Hints**: Loses FastAPI's benefits +- **Ignoring Sessions**: Not properly managing database sessions +- **No Testing**: Skipping integration tests +- **Tight Coupling**: Direct database access in routes diff --git a/frontend/src/pages/KnowledgeBase/Detail/KnowledgeBaseDetail.tsx b/frontend/src/pages/KnowledgeBase/Detail/KnowledgeBaseDetail.tsx index ab99318a7..8c52f1508 100644 --- a/frontend/src/pages/KnowledgeBase/Detail/KnowledgeBaseDetail.tsx +++ b/frontend/src/pages/KnowledgeBase/Detail/KnowledgeBaseDetail.tsx @@ -97,7 +97,11 @@ const KnowledgeBaseDetailPage: React.FC = () => { handleKeywordChange, } = useFetchData( (params) => id ? 
queryKnowledgeBaseFilesUsingGet(id, params) : Promise.resolve({ data: [] }), - (file) => mapFileData(file, t) + (file) => mapFileData(file, t), + 30000, // 30秒轮询间隔 + false, // 不自动轮询 + [], // 额外的轮询函数 + 0 // pageOffset: Python 后端期望 page 从 1 开始,前端 current=1 时传 page=1 ); // File table logic diff --git a/frontend/src/pages/KnowledgeBase/Home/KnowledgeBasePage.tsx b/frontend/src/pages/KnowledgeBase/Home/KnowledgeBasePage.tsx index fc39f28c7..7ad7b10a9 100644 --- a/frontend/src/pages/KnowledgeBase/Home/KnowledgeBasePage.tsx +++ b/frontend/src/pages/KnowledgeBase/Home/KnowledgeBasePage.tsx @@ -31,7 +31,11 @@ export default function KnowledgeBasePage() { handleKeywordChange, } = useFetchData( queryKnowledgeBasesUsingPost, - (kb) => mapKnowledgeBase(kb, false, t) // 在首页不显示索引模型和文本理解模型字段 + (kb) => mapKnowledgeBase(kb, false, t), // 在首页不显示索引模型和文本理解模型字段 + 30000, // 30秒轮询间隔 + false, // 不自动轮询 + [], // 额外的轮询函数 + 0 // pageOffset: Python 后端期望 page 从 1 开始,前端 current=1 时传 page=1 ); useEffect(() => { diff --git a/frontend/src/pages/KnowledgeBase/knowledge-base.api.ts b/frontend/src/pages/KnowledgeBase/knowledge-base.api.ts index 5741b5d36..e5083b0fd 100644 --- a/frontend/src/pages/KnowledgeBase/knowledge-base.api.ts +++ b/frontend/src/pages/KnowledgeBase/knowledge-base.api.ts @@ -1,8 +1,13 @@ import { get, post, put, del } from "@/utils/request"; // 获取知识库列表 -export function queryKnowledgeBasesUsingPost(params: object) { - return post("/api/knowledge-base/list", params); +export function queryKnowledgeBasesUsingPost(params: any) { + // 将前端的 size 参数映射为后端的 page_size + const { size, ...rest } = params; + return post("/api/knowledge-base/list", { + ...rest, + page_size: size + }); } // 创建知识库 @@ -26,8 +31,22 @@ export function deleteKnowledgeBaseByIdUsingDelete(baseId: string) { } // 获取知识生成文件列表 -export function queryKnowledgeBaseFilesUsingGet(baseId: string, params?: Record) { - return get(`/api/knowledge-base/${baseId}/files${params ? `?${new URLSearchParams(params).toString()}` : ""}`); +export function queryKnowledgeBaseFilesUsingGet(baseId: string, params?: Record) { + if (!params) { + return get(`/api/knowledge-base/${baseId}/files`); + } + // 将前端的 size 参数映射为后端的 page_size + const { size, page, ...rest } = params; + const queryParams = { + page: page || 1, + page_size: size || 10, + ...rest + }; + return get(`/api/knowledge-base/${baseId}/files?${new URLSearchParams( + Object.entries(queryParams) + .filter(([_, v]) => v !== undefined && v !== null) + .reduce((acc, [k, v]) => ({ ...acc, [k]: String(v) }), {}) + ).toString()}`); } // 添加文件到知识库 @@ -62,5 +81,5 @@ export function queryKnowledgeBaseFileDetailUsingGet( ) { const page = params.page ?? 1; const size = params.size ?? 
20; - return get(`/api/knowledge-base/${knowledgeBaseId}/files/${ragFileId}?page=${page}&size=${size}`); + return get(`/api/knowledge-base/${knowledgeBaseId}/files/${ragFileId}?page=${page}&page_size=${size}`); } diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index ebd6935d9..dea74fdf8 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -13,25 +13,54 @@ export default defineConfig({ }, server: { host: "0.0.0.0", - proxy: { - "^/api": { - target: "http://localhost:8080", // 本地后端服务地址 + proxy: (() => { + const pythonProxyConfig = { + target: "http://localhost:18000", changeOrigin: true, secure: false, - rewrite: (path) => path.replace(/^\/api/, "/api"), - configure: (proxy, options) => { - // proxy 是 'http-proxy' 的实例 - proxy.on("proxyReq", (proxyReq, req, res) => { - // 可以在这里修改请求头 - proxyReq.removeHeader("referer"); - proxyReq.removeHeader("origin"); + configure: (proxy: { on: (event: string, handler: (arg: unknown) => void) => void }) => { + proxy.on("proxyReq", (proxyReq: unknown) => { + (proxyReq as { removeHeader: (name: string) => void }).removeHeader("referer"); + (proxyReq as { removeHeader: (name: string) => void }).removeHeader("origin"); }); - proxy.on("proxyRes", (proxyRes, req, res) => { - delete proxyRes.headers["set-cookie"]; - proxyRes.headers["cookies"] = ""; // 清除 cookies 头 + proxy.on("proxyRes", (proxyRes: unknown) => { + const res = proxyRes as { headers: Record }; + delete res.headers["set-cookie"]; + res.headers["cookies"] = ""; }); }, - }, - }, + }; + + const javaProxyConfig = { + target: "http://localhost:8080", + changeOrigin: true, + secure: false, + configure: (proxy: { on: (event: string, handler: (arg: unknown) => void) => void }) => { + proxy.on("proxyReq", (proxyReq: unknown) => { + (proxyReq as { removeHeader: (name: string) => void }).removeHeader("referer"); + (proxyReq as { removeHeader: (name: string) => void }).removeHeader("origin"); + }); + proxy.on("proxyRes", (proxyRes: unknown) => { + const res = proxyRes as { headers: Record }; + delete res.headers["set-cookie"]; + res.headers["cookies"] = ""; + }); + }, + }; + + // Python 服务: rag, synthesis, annotation, evaluation, models + const pythonPaths = ["rag", "synthesis", "annotation", "knowledge-base", "data-collection", "evaluation", "models"]; + // Java 服务: data-management, knowledge-base + const javaPaths = ["data-management", "operators", "cleansing"]; + + const proxy: Record = {}; + for (const p of pythonPaths) { + proxy[`/api/${p}`] = pythonProxyConfig; + } + for (const p of javaPaths) { + proxy[`/api/${p}`] = javaProxyConfig; + } + return proxy; + })(), }, }); diff --git a/runtime/datamate-python/app/core/config.py b/runtime/datamate-python/app/core/config.py index edaae3b74..580cb97d1 100644 --- a/runtime/datamate-python/app/core/config.py +++ b/runtime/datamate-python/app/core/config.py @@ -77,5 +77,12 @@ def build_database_url(self): datamate_jwt_enable: bool = False + # Milvus 配置 + milvus_uri: str = "http://milvus-standalone:19530" + milvus_token: str = "" + + # 文件存储配置(共享文件系统) + file_storage_path: str = "/data/files" + # 全局设置实例 settings = Settings() diff --git a/runtime/datamate-python/app/core/exception/codes.py b/runtime/datamate-python/app/core/exception/codes.py index 294e6d56b..d23539ddd 100644 --- a/runtime/datamate-python/app/core/exception/codes.py +++ b/runtime/datamate-python/app/core/exception/codes.py @@ -77,8 +77,17 @@ def __init__(self): # ========== RAG 模块 ========== RAG_CONFIG_ERROR: Final = ErrorCode("rag.0001", "RAG configuration error", 400) 
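+    # rag.0003 and rag.0004 are repurposed below: RAG_MODEL_NOT_FOUND moves to
+    # rag.0009 and RAG_QUERY_FAILED to rag.0010, so clients matching on the old
+    # code strings must be updated in step.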
RAG_KNOWLEDGE_BASE_NOT_FOUND: Final = ErrorCode("rag.0002", "Knowledge base not found", 404) - RAG_MODEL_NOT_FOUND: Final = ErrorCode("rag.0003", "RAG model not found", 404) - RAG_QUERY_FAILED: Final = ErrorCode("rag.0004", "RAG query failed", 500) + RAG_KNOWLEDGE_BASE_ALREADY_EXISTS: Final = ErrorCode("rag.0003", "Knowledge base already exists", 400) + RAG_KNOWLEDGE_BASE_NAME_INVALID: Final = ErrorCode("rag.0004", "Knowledge base name is invalid", 400) + RAG_FILE_NOT_FOUND: Final = ErrorCode("rag.0005", "RAG file not found", 404) + RAG_FILE_PROCESS_FAILED: Final = ErrorCode("rag.0006", "File processing failed", 500) + RAG_FILE_PARSE_FAILED: Final = ErrorCode("rag.0007", "File parsing failed", 500) + RAG_CHUNK_NOT_FOUND: Final = ErrorCode("rag.0008", "Chunk not found", 404) + RAG_MODEL_NOT_FOUND: Final = ErrorCode("rag.0009", "RAG model not found", 404) + RAG_QUERY_FAILED: Final = ErrorCode("rag.0010", "RAG query failed", 500) + RAG_MILVUS_ERROR: Final = ErrorCode("rag.0011", "Milvus operation failed", 500) + RAG_COLLECTION_NOT_FOUND: Final = ErrorCode("rag.0012", "Milvus collection not found", 404) + RAG_EMBEDDING_FAILED: Final = ErrorCode("rag.0013", "Embedding generation failed", 500) # ========== 配比模块 ========== RATIO_TASK_NOT_FOUND: Final = ErrorCode("ratio.0001", "Ratio task not found", 404) diff --git a/runtime/datamate-python/app/db/models/__init__.py b/runtime/datamate-python/app/db/models/__init__.py index 060e4b646..ddc80dc57 100644 --- a/runtime/datamate-python/app/db/models/__init__.py +++ b/runtime/datamate-python/app/db/models/__init__.py @@ -32,6 +32,8 @@ ChunkUploadPreRequest ) +from .knowledge_gen import KnowledgeBase, RagFile + __all__ = [ "Dataset", "DatasetTag", @@ -48,4 +50,6 @@ "CategoryRelation", "OperatorRelease", "ChunkUploadPreRequest", + "KnowledgeBase", + "RagFile", ] diff --git a/runtime/datamate-python/app/db/models/base_entity.py b/runtime/datamate-python/app/db/models/base_entity.py index 56a6aaea8..65f47e5d0 100644 --- a/runtime/datamate-python/app/db/models/base_entity.py +++ b/runtime/datamate-python/app/db/models/base_entity.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, String, TIMESTAMP, Text, JSON +from sqlalchemy import Column, String, TIMESTAMP, Text from sqlalchemy.orm import declarative_base from sqlalchemy.sql import func diff --git a/runtime/datamate-python/app/db/models/knowledge_gen.py b/runtime/datamate-python/app/db/models/knowledge_gen.py index daa553630..48806bde2 100644 --- a/runtime/datamate-python/app/db/models/knowledge_gen.py +++ b/runtime/datamate-python/app/db/models/knowledge_gen.py @@ -1,38 +1,81 @@ """ -Tables of RAG Management Module +知识库(RAG)相关 ORM 模型 + +表: t_rag_knowledge_base, t_rag_file +与 Java 实体保持一致。 """ -import uuid -from sqlalchemy import Column, String, TIMESTAMP, Text, Integer, JSON -from sqlalchemy.sql import func +from enum import Enum +from sqlalchemy import Column, String, Integer, JSON, Enum as SQLEnum from app.db.models.base_entity import BaseEntity -class RagKnowledgeBase(BaseEntity): - """知识库模型""" +class RagType(str, Enum): + """RAG 类型枚举 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.RagType + """ + DOCUMENT = "DOCUMENT" # 文档型 RAG(向量检索) + GRAPH = "GRAPH" # 知识图谱型 RAG(LightRAG) + + +class FileStatus(str, Enum): + """文件状态枚举 + + 对应 Java: com.datamate.rag.indexer.domain.model.FileStatus + """ + UNPROCESSED = "UNPROCESSED" # 未处理 + PROCESSING = "PROCESSING" # 处理中 + PROCESSED = "PROCESSED" # 已处理 + PROCESS_FAILED = "PROCESS_FAILED" # 处理失败 + + +class KnowledgeBase(BaseEntity): + """知识库实体 + + 对应 Java: 
com.datamate.rag.indexer.domain.model.KnowledgeBase + 表名: t_rag_knowledge_base + """ __tablename__ = "t_rag_knowledge_base" + __ignore_data_scope__ = True - id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID") - name = Column(String(255), nullable=False, comment="知识库名称") - type = Column(String(50), nullable=False, comment="知识库类型") + id = Column(String(36), primary_key=True, comment="知识库ID") + name = Column(String(255), nullable=False, unique=True, comment="知识库名称") description = Column(String(512), nullable=True, comment="知识库描述") - embedding_model = Column(String(255), nullable=False, comment="嵌入模型") - chat_model = Column(String(255), nullable=True, comment="聊天模型") + type = Column( + SQLEnum(RagType), + nullable=False, + default=RagType.DOCUMENT, + comment="RAG类型", + ) + embedding_model = Column(String(255), nullable=False, comment="嵌入模型ID") + chat_model = Column(String(255), nullable=True, comment="聊天模型ID") def __repr__(self): - return f"" + return f"" class RagFile(BaseEntity): - """知识库文件模型""" + """RAG 文件实体 + + 对应 Java: com.datamate.rag.indexer.domain.model.RagFile + 表名: t_rag_file + """ __tablename__ = "t_rag_file" __ignore_data_scope__ = True - id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="UUID") - knowledge_base_id = Column(String(36), nullable=False, comment="知识库ID") - file_name = Column(String(255), nullable=False, comment="文件名") - file_id = Column(String(255), nullable=False, comment="文件ID") - chunk_count = Column(Integer, nullable=True, comment="切片数") - file_metadata = Column("metadata", JSON, nullable=True, comment="元数据") - status = Column(String(50), nullable=True, comment="文件状态") - err_msg = Column(Text, nullable=True, comment="错误信息") + id = Column(String(36), primary_key=True, comment="RAG文件ID") + knowledge_base_id = Column(String(36), nullable=False, index=True, comment="知识库ID") + file_name = Column(String(512), nullable=False, comment="文件名") + file_id = Column(String(36), nullable=False, comment="原始文件ID") + chunk_count = Column(Integer, nullable=True, comment="分块数量") + file_metadata = Column("metadata", JSON, nullable=True, comment="元数据(JSON格式)") + status = Column( + SQLEnum(FileStatus), + nullable=False, + default=FileStatus.UNPROCESSED, + comment="处理状态", + ) + err_msg = Column(String(2048), nullable=True, comment="错误信息") + def __repr__(self): + return f"" diff --git a/runtime/datamate-python/app/module/__init__.py b/runtime/datamate-python/app/module/__init__.py index edf8f5479..4f8d6da5f 100644 --- a/runtime/datamate-python/app/module/__init__.py +++ b/runtime/datamate-python/app/module/__init__.py @@ -10,6 +10,7 @@ from .operator.interface import operator_router from .operator.interface import category_router from .cleaning.interface import router as cleaning_router +from .rag.interface.knowledge_base import router as knowledge_base_router router = APIRouter( prefix="/api" @@ -26,4 +27,6 @@ router.include_router(category_router) router.include_router(cleaning_router) +router.include_router(knowledge_base_router) + __all__ = ["router"] diff --git a/runtime/datamate-python/app/module/rag/infra/__init__.py b/runtime/datamate-python/app/module/rag/infra/__init__.py new file mode 100644 index 000000000..8a680d773 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/__init__.py @@ -0,0 +1,23 @@ +""" +RAG 基础设施层:文档加载、分片、管道 + +使用示例: + from app.module.rag.infra import load_and_split, SplitOptions + + chunks = await load_and_split( + "/path/to/doc.pdf", + split_options=SplitOptions( + 
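+            # ProcessType is assumed to live beside SplitOptions (it is not
+            # re-exported from this package), so import it where needed.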
process_type=ProcessType.PARAGRAPH_CHUNK, + chunk_size=300, + ) + ) +""" +from app.module.rag.infra.pipeline import ingest_file_to_chunks, load_and_split +from app.module.rag.infra.options import SplitOptions, default_split_options + +__all__ = [ + "load_and_split", + "ingest_file_to_chunks", + "SplitOptions", + "default_split_options", +] diff --git a/runtime/datamate-python/app/module/rag/infra/embeddings/__init__.py b/runtime/datamate-python/app/module/rag/infra/embeddings/__init__.py new file mode 100644 index 000000000..f23f33627 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/embeddings/__init__.py @@ -0,0 +1,63 @@ +""" +LangChain Embeddings 封装 + +直接使用 LangChain 的 embeddings 功能,支持多种提供商: +- OpenAI: langchain-openai +- Ollama: langchain-community +- 其他: 通过 LangChain 生态 +""" +from typing import Optional, Any + +from langchain_core.embeddings import Embeddings +from langchain_openai import OpenAIEmbeddings + + +class EmbeddingFactory: + """LangChain Embeddings 工厂类""" + + @staticmethod + def create_embeddings( + model_name: str, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + **kwargs: Any, + ) -> Embeddings: + """ + 创建 LangChain Embeddings 实例 + + Args: + model_name: 模型名称(如 text-embedding-3-small) + base_url: API 基础 URL + api_key: API 密钥 + **kwargs: 其他参数 + + Returns: + LangChain Embeddings 实例 + """ + # OpenAI / OpenAI 兼容接口 + if "openai" in model_name.lower() or model_name.startswith("text-embedding"): + return OpenAIEmbeddings( + model=model_name, + base_url=base_url, + api_key=api_key, + **kwargs, + ) + # Ollama + if base_url and "ollama" in base_url.lower(): + from langchain_community.embeddings.ollama import OllamaEmbeddings + + return OllamaEmbeddings( + model=model_name, + base_url=base_url, + **kwargs, + ) + # 默认使用 OpenAI 兼容 + return OpenAIEmbeddings( + model=model_name, + base_url=base_url, + api_key=api_key, + **kwargs, + ) + + +__all__ = ["EmbeddingFactory", "Embeddings"] diff --git a/runtime/datamate-python/app/module/rag/infra/milvus/__init__.py b/runtime/datamate-python/app/module/rag/infra/milvus/__init__.py new file mode 100644 index 000000000..b5d6e33ff --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/milvus/__init__.py @@ -0,0 +1,5 @@ +""" +Milvus 向量存储相关模块 + +提供与 Milvus 集成的向量存储功能。 +""" diff --git a/runtime/datamate-python/app/module/rag/infra/milvus/factory.py b/runtime/datamate-python/app/module/rag/infra/milvus/factory.py new file mode 100644 index 000000000..a74fa93e1 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/milvus/factory.py @@ -0,0 +1,56 @@ +""" +向量存储工厂 + +使用 LangChain Milvus 创建向量存储实例,支持混合检索(向量 + BM25) +""" +from __future__ import annotations + +from typing import Any + +from langchain_core.embeddings import Embeddings + +from app.core.config import settings + + +class VectorStoreFactory: + """LangChain Milvus 向量存储工厂""" + + @staticmethod + def get_connection_args() -> dict: + """获取 Milvus 连接参数""" + args: dict = {"uri": settings.milvus_uri} + if getattr(settings, "milvus_token", None): + args["token"] = settings.milvus_token + return args + + @staticmethod + def create( + collection_name: str, + embedding: Embeddings, + *, + drop_old: bool = False, + consistency_level: str = "Strong", + ) -> Any: + """ + 创建 Milvus 向量存储实例(支持混合检索) + + Args: + collection_name: 集合名称(知识库名称) + embedding: LangChain Embeddings 实例 + drop_old: 是否删除已存在同名集合 + consistency_level: 一致性级别 + + Returns: + langchain_milvus.Milvus 实例 + """ + from langchain_milvus import BM25BuiltInFunction, Milvus + + return Milvus( + 
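+            # BM25BuiltInFunction has Milvus derive the sparse full-text
+            # vectors server-side from the raw text (see module docstring), so
+            # only the dense embedding runs client-side; "dense" and "sparse"
+            # below name the two vector output fields.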
embedding_function=embedding, + collection_name=collection_name, + connection_args=VectorStoreFactory.get_connection_args(), + builtin_function=BM25BuiltInFunction(), + vector_field=["dense", "sparse"], + consistency_level=consistency_level, + drop_old=drop_old, + ) diff --git a/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py b/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py new file mode 100644 index 000000000..28c7ae325 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py @@ -0,0 +1,295 @@ +""" +基于 LangChain Milvus 的向量存储封装 + +使用 langchain-milvus.Milvus + BM25BuiltInFunction 实现密集向量 + 全文检索, +Milvus 2.6.x 自动处理 BM25 稀疏向量,无需手动生成。 + +同时提供集合管理辅助函数:drop_collection、rename_collection,供知识库删除/重命名使用。 +""" +from __future__ import annotations + +import logging +from typing import Any, List, Optional + +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings + +from app.core.config import settings +from app.module.rag.infra.embeddings import EmbeddingFactory + +logger = logging.getLogger(__name__) + + +def _connection_args() -> dict: + args: dict = {"uri": settings.milvus_uri} + if getattr(settings, "milvus_token", None): + args["token"] = settings.milvus_token + return args + + +def _ensure_connection() -> None: + """确保 Milvus 默认连接已建立(供 utility 使用)。""" + from pymilvus import connections + + conn_args = _connection_args() + connections.connect(alias="default", uri=conn_args["uri"], token=conn_args.get("token") or "") + + +def drop_collection(collection_name: str) -> None: + """删除 Milvus 集合。用于知识库删除等场景。""" + from pymilvus import utility + + from app.core.exception import BusinessError, ErrorCodes + + try: + _ensure_connection() + if utility.has_collection(collection_name, using="default"): + utility.drop_collection(collection_name, using="default") + logger.info("成功删除集合: %s", collection_name) + except Exception as e: + logger.error("删除集合失败: %s", e) + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"删除集合失败: {str(e)}") from e + + +def rename_collection(old_name: str, new_name: str) -> None: + """重命名 Milvus 集合。用于知识库重命名。""" + from pymilvus import utility + + from app.core.exception import BusinessError, ErrorCodes + + try: + _ensure_connection() + if utility.has_collection(old_name, using="default"): + utility.rename_collection(old_name, new_name, using="default") + logger.info("成功重命名集合: %s -> %s", old_name, new_name) + except Exception as e: + logger.error("重命名集合失败: %s", e) + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"重命名集合失败: {str(e)}") from e + + +def create_java_compatible_collection( + collection_name: str, + dimension: int, + consistency_level: str = "Strong" +) -> None: + """创建与 Java 服务兼容的 Milvus 集合 + + 使用 Java 服务相同的字段命名和结构: + - id (VarChar, 主键) + - text (VarChar, with analyzer for BM25) + - metadata (JSON) + - vector (FloatVector, 密集向量) + - sparse (SparseFloatVector, BM25 稀疏向量) + + Args: + collection_name: 集合名称 + dimension: 向量维度 + consistency_level: 一致性级别 + """ + from pymilvus import MilvusClient, DataType, FunctionType + + from app.core.exception import BusinessError, ErrorCodes + + try: + conn_args = _connection_args() + token = conn_args.get("token") if conn_args.get("token") else "" + client = MilvusClient(uri=conn_args["uri"], token=token) + + # 检查集合是否已存在 + if client.has_collection(collection_name): + logger.info("集合 %s 已存在,跳过创建", collection_name) + return + + # 定义 schema + schema = MilvusClient.create_schema() + + # 1. 
主键字段 id + schema.add_field( + field_name="id", + datatype=DataType.VARCHAR, + max_length=36, + is_primary=True, + auto_id=False + ) + + # 2. 文本字段 text(启用 analyzer 用于 BM25) + schema.add_field( + field_name="text", + datatype=DataType.VARCHAR, + max_length=65535, + enable_analyzer=True, + enable_match=True + ) + + # 3. 元数据字段 metadata + schema.add_field( + field_name="metadata", + datatype=DataType.JSON + ) + + # 4. 密集向量字段 vector + schema.add_field( + field_name="vector", + datatype=DataType.FLOAT_VECTOR, + dim=dimension + ) + + # 5. 稀疏向量字段 sparse(BM25) + schema.add_field( + field_name="sparse", + datatype=DataType.SPARSE_FLOAT_VECTOR + ) + + # 创建集合(BM25 将在首次添加文档时自动配置) + client.create_collection( + collection_name=collection_name, + schema=schema, + consistency_level=consistency_level + ) + + # 创建向量索引 + client.create_index( + collection_name=collection_name, + field_name="vector", + index_params={ + "index_type": "HNSW", + "metric_type": "COSINE", + "params": { + "M": 16, + "efConstruction": 256 + } + } + ) + + logger.info("成功创建 Java 兼容的集合: %s (维度: %d)", collection_name, dimension) + + except Exception as e: + logger.error("创建集合失败: %s", e) + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"创建集合失败: {str(e)}") from e + + +def get_vector_dimension(embedding_model: str, base_url: Optional[str] = None, api_key: Optional[str] = None) -> int: + """获取嵌入模型的向量维度 + + Args: + embedding_model: 模型名称 + base_url: API 基础 URL + api_key: API 密钥 + + Returns: + 向量维度 + + Raises: + BusinessError: 无法获取维度 + """ + from langchain_core.embeddings import Embeddings + + from app.core.exception import BusinessError, ErrorCodes + + try: + import asyncio + embedding = EmbeddingFactory.create_embeddings( + model_name=embedding_model, + base_url=base_url, + api_key=api_key, + ) + + test_text = "test" + embedding_vector = asyncio.run(asyncio.to_thread(embedding.embed_query, test_text)) + dimension = len(embedding_vector) + + logger.info("获取模型 %s 的向量维度: %d", embedding_model, dimension) + return dimension + + except Exception as e: + logger.error("获取模型维度失败: %s", e) + raise BusinessError(ErrorCodes.RAG_EMBEDDING_FAILED, f"获取模型维度失败: {str(e)}") from e + + +def delete_chunks_by_rag_file_ids(collection_name: str, rag_file_ids: List[str]) -> None: + """按 RAG 文件 ID 列表删除 Milvus 中的分块。用于文件删除时清理向量数据。""" + if not rag_file_ids: + return + from pymilvus import MilvusClient + + from app.core.exception import BusinessError, ErrorCodes + + try: + conn_args = _connection_args() + client = MilvusClient(uri=conn_args["uri"], token=conn_args.get("token") or None) + # metadata 为 JSON 字段,按 rag_file_id 过滤 + for rid in rag_file_ids: + filter_expr = f'metadata["rag_file_id"] == "{rid}"' + try: + client.delete(collection_name=collection_name, filter=filter_expr) + except Exception as del_err: + logger.warning("删除分块时部分失败 collection=%s rag_file_id=%s: %s", collection_name, rid, del_err) + logger.info("已按 rag_file_id 删除集合 %s 中的分块: %s", collection_name, rag_file_ids) + except Exception as e: + logger.error("按 rag_file_id 删除 Milvus 分块失败: %s", e) + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"删除分块失败: {str(e)}") from e + + +def get_milvus_vectorstore( + collection_name: str, + embedding: Embeddings, + *, + drop_old: bool = False, + consistency_level: str = "Strong", +) -> Any: + """创建带全文检索(BM25)的 Milvus 向量存储实例. 
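+
+    Usage sketch (collection and model names are illustrative, not project
+    defaults; EmbeddingFactory is imported at the top of this module):
+
+        embedding = EmbeddingFactory.create_embeddings("text-embedding-3-small")
+        store = get_milvus_vectorstore("kb_demo", embedding)
+        store.add_documents(docs, ids=ids)            # docs: List[Document]
+        hits = store.similarity_search("query text", k=5)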
+ + 使用 langchain-milvus.Milvus + BM25BuiltInFunction,支持混合检索。 + + Args: + collection_name: 集合名称(通常为知识库名称) + embedding: LangChain Embeddings 实例 + drop_old: 是否在创建时删除已存在同名集合 + consistency_level: 一致性级别 + + Returns: + Milvus 向量存储实例,支持 add_documents / similarity_search / as_retriever 等 + """ + from langchain_milvus import BM25BuiltInFunction, Milvus + + return Milvus( + embedding_function=embedding, + collection_name=collection_name, + connection_args=_connection_args(), + builtin_function=BM25BuiltInFunction(), + vector_field=["dense", "sparse"], + consistency_level=consistency_level, + drop_old=drop_old, + ) + + +def chunks_to_langchain_documents( + chunks: List[Any], + *, + ids: Optional[List[str]] = None, + id_key: str = "chunk_id", +) -> tuple[List[Document], List[str]]: + """将领域 DocumentChunk 列表转为 LangChain Document 列表及 id 列表. + + Args: + chunks: 分块列表,每项有 .text 与 .metadata + ids: 若提供则作为文档 id,否则从 metadata[id_key] 取或生成 + id_key: metadata 中作为 id 的键名 + + Returns: + (documents, ids) + """ + from uuid import uuid4 + + documents: List[Document] = [] + out_ids: List[str] = [] + for i, ch in enumerate(chunks): + text = getattr(ch, "text", str(ch)) + meta = getattr(ch, "metadata", {}) or {} + if ids and i < len(ids): + doc_id = ids[i] + else: + doc_id = meta.get(id_key) or str(uuid4()) + documents.append(Document(page_content=text, metadata=dict(meta))) + out_ids.append(doc_id) + return documents, out_ids diff --git a/runtime/datamate-python/app/module/rag/infra/options.py b/runtime/datamate-python/app/module/rag/infra/options.py new file mode 100644 index 000000000..7d7def7bf --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/options.py @@ -0,0 +1,40 @@ +""" +文档加载与分片选项 + +保留必要的配置项,简化使用。 +""" +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from app.module.rag.schema.enums import ProcessType + + +@dataclass +class SplitOptions: + """文档分片选项 + + Args: + process_type: 分片策略 + chunk_size: 块大小(字符) + overlap_size: 块间重叠 + delimiter: 仅 CUSTOM_SEPARATOR_CHUNK 时有效 + """ + + process_type: ProcessType = ProcessType.DEFAULT_CHUNK + chunk_size: int = 500 + overlap_size: int = 50 + delimiter: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """转换为字典,用于传递给 load_and_split""" + return { + "process_type": self.process_type, + "chunk_size": self.chunk_size, + "overlap_size": self.overlap_size, + "delimiter": self.delimiter, + } + + +def default_split_options() -> SplitOptions: + """默认分片选项:递归分块 500/50""" + return SplitOptions() diff --git a/runtime/datamate-python/app/module/rag/infra/parser/__init__.py b/runtime/datamate-python/app/module/rag/infra/parser/__init__.py new file mode 100644 index 000000000..ebf914b5d --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/parser/__init__.py @@ -0,0 +1,13 @@ +"""保留 ParsedDocument 与 DocumentParser 基类,供 loader 层转换使用.""" + +from app.module.rag.infra.parser.base import ( + ParsedDocument, + DocumentParser, + langchain_documents_to_parsed, +) + +__all__ = [ + "ParsedDocument", + "DocumentParser", + "langchain_documents_to_parsed", +] diff --git a/runtime/datamate-python/app/module/rag/infra/parser/base.py b/runtime/datamate-python/app/module/rag/infra/parser/base.py new file mode 100644 index 000000000..d2e9fa6d6 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/parser/base.py @@ -0,0 +1,168 @@ +""" +文档解析器基类 + +定义文档解析器的抽象接口 +使用策略模式支持多种文档格式 +""" +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Optional +from pathlib import Path + +from langchain_core.documents import Document 
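+
+# Usage sketch for the helper below (file path and ids are hypothetical):
+#
+#     docs = UniversalDocLoader("/data/report.pdf").load()
+#     parsed = langchain_documents_to_parsed(docs, "/data/report.pdf",
+#                                            rag_file_id="rf-1")
+#     parsed.text       # page contents joined with blank lines
+#     parsed.metadata   # file_name / file_extension / file_size + loader metadata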
+ + +class ParsedDocument: + """解析后的文档 + + 包含文档的文本内容和元数据 + """ + + def __init__( + self, + text: str, + metadata: Dict[str, Any], + file_name: str + ): + """初始化解析后的文档 + + Args: + text: 文档文本内容 + metadata: 文档元数据(如作者、创建时间等) + file_name: 文件名 + """ + self.text = text + self.metadata = metadata + self.file_name = file_name + + def __repr__(self): + return f"" + + +def langchain_documents_to_parsed( + documents: List[Document], + file_path: str, + file_name: Optional[str] = None, + **extra_metadata: Any, +) -> ParsedDocument: + """将 LangChain Document 列表转换为 ParsedDocument + + 多页/多段结果合并为一个文档,用于 pipeline。 + + Args: + documents: LangChain 加载器返回的 Document 列表 + file_path: 源文件路径 + file_name: 文件名,若提供则优先使用 + **extra_metadata: 额外的元数据字段(会合并到返回的 metadata 中) + + Returns: + ParsedDocument: 合并后的领域文档对象 + """ + path = Path(file_path) + name = file_name or path.name + + if not documents: + base_metadata = { + "file_name": name, + "file_extension": path.suffix.lower(), + "file_size": path.stat().st_size if path.exists() else 0, + } + base_metadata.update(extra_metadata) + return ParsedDocument( + text="", + metadata=base_metadata, + file_name=name, + ) + + texts = [d.page_content for d in documents if d.page_content] + merged_text = "\n\n".join(texts) + + meta: Dict[str, Any] = { + "file_name": name, + "file_extension": path.suffix.lower(), + "file_size": path.stat().st_size if path.exists() else 0, + # 添加路径信息 + "absolute_directory_path": str(path.parent), + "file_path": str(path), + } + + # 合并额外的元数据 + meta.update(extra_metadata) + + # 合并第一个文档的元数据 + if documents and isinstance(documents[0].metadata, dict): + first_meta = documents[0].metadata + for k, v in first_meta.items(): + if k not in meta and v is not None: + meta[k] = v + + return ParsedDocument(text=merged_text, metadata=meta, file_name=name) + + +class DocumentParser(ABC): + """文档解析器基类(抽象类) + + 对应 Java 的文档解析接口 + + 所有具体的解析器都需要继承此类并实现 parse 方法 + """ + + @abstractmethod + async def parse(self, file_path: str) -> ParsedDocument: + """解析文档 + + Args: + file_path: 文件路径(绝对路径) + + Returns: + ParsedDocument: 解析后的文档对象 + + Raises: + FileNotFoundError: 文件不存在 + ValueError: 文件格式不支持或解析失败 + """ + pass + + def _get_file_name(self, file_path: str) -> str: + """从文件路径中提取文件名 + + Args: + file_path: 文件路径 + + Returns: + 文件名 + """ + return Path(file_path).name + + def _get_file_extension(self, file_path: str) -> str: + """从文件路径中提取文件扩展名 + + Args: + file_path: 文件路径 + + Returns: + 文件扩展名(包含点号,如 ".pdf") + """ + return Path(file_path).suffix.lower() + + def _build_metadata( + self, + file_path: str, + **extra_fields + ) -> Dict[str, Any]: + """构建文档元数据 + + Args: + file_path: 文件路径 + **extra_fields: 额外的元数据字段 + + Returns: + 元数据字典 + """ + path = Path(file_path) + metadata = { + "file_name": path.name, + "file_extension": self._get_file_extension(file_path), + "file_size": path.stat().st_size if path.exists() else 0, + } + metadata.update(extra_fields) + return metadata diff --git a/runtime/datamate-python/app/module/rag/infra/pipeline.py b/runtime/datamate-python/app/module/rag/infra/pipeline.py new file mode 100644 index 000000000..7b41c91fe --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/pipeline.py @@ -0,0 +1,116 @@ +""" +RAG 文档加载与分片管道 + +使用全局 UniversalDocLoader 加载文档,分片后返回 DocumentChunk 列表。 +""" +import asyncio +from typing import Any, List, Optional + +from app.module.shared.common.document_loaders import UniversalDocLoader +from app.module.rag.infra.parser import langchain_documents_to_parsed +from app.module.rag.infra.splitter.base import DocumentChunk +from 
app.module.rag.infra.splitter.factory import DocumentSplitterFactory +from app.module.rag.schema.enums import ProcessType + + +async def load_and_split( + file_path: str, + split_options: Optional[dict] = None, + **chunk_metadata: Any, +) -> List[DocumentChunk]: + """加载文档并分块 + + 使用 UniversalDocLoader 加载文档,然后按指定策略分块。 + + Args: + file_path: 文件绝对路径 + split_options: 分片选项,None 表示使用默认(递归分块 500/50) + - process_type: ProcessType 枚举,默认 DEFAULT_CHUNK + - chunk_size: 块大小,默认 500 + - overlap_size: 重叠大小,默认 50 + - delimiter: 自定义分隔符 + **chunk_metadata: 写入每个 chunk.metadata 的额外字段 + + Returns: + List[DocumentChunk]: 分块列表 + """ + # 1. 加载文档(使用同步加载器并在异步上下文中运行) + loader = UniversalDocLoader(file_path) + documents = await asyncio.to_thread(loader.load) + + # 2. 准备 parser metadata + parser_metadata = {} + for key in ["original_file_id", "rag_file_id", "file_name"]: + if key in chunk_metadata: + parser_metadata[key] = chunk_metadata[key] + + # 3. 转换为 ParsedDocument(传递额外的 metadata) + parsed = langchain_documents_to_parsed(documents, file_path, **parser_metadata) + + # 4. 获取分片选项 + options = split_options or {} + process_type = options.get("process_type", ProcessType.DEFAULT_CHUNK) + chunk_size = options.get("chunk_size", 500) + overlap_size = options.get("overlap_size", 50) + delimiter = options.get("delimiter") + + # 5. 合并 metadata 用于 chunk + base_chunk_metadata = { + "file_name": parsed.metadata.get("file_name", ""), + "file_extension": parsed.metadata.get("file_extension", ""), + "absolute_directory_path": parsed.metadata.get("absolute_directory_path", ""), + "original_file_id": parsed.metadata.get("original_file_id", ""), + "rag_file_id": parsed.metadata.get("rag_file_id", ""), + } + base_chunk_metadata.update(chunk_metadata) + + # 6. 分片 + splitter = DocumentSplitterFactory.create_splitter( + process_type, + chunk_size=chunk_size, + overlap_size=overlap_size, + delimiter=delimiter, + ) + chunks = await splitter.split( + parsed.text, + file_name=parsed.file_name, + **base_chunk_metadata, + ) + + return chunks + + +async def ingest_file_to_chunks( + file_path: str, + process_type: ProcessType = ProcessType.DEFAULT_CHUNK, + chunk_size: int = 500, + overlap_size: int = 50, + delimiter: Optional[str] = None, + **chunk_metadata: Any, +) -> List[DocumentChunk]: + """从本地文件加载文档并分块(便捷入口) + + 可被 ETL、URL 抓取、S3 等场景复用。 + + Args: + file_path: 文件绝对路径 + process_type: 分块策略 + chunk_size: 块大小 + overlap_size: 重叠大小 + delimiter: 自定义分隔符 + **chunk_metadata: 写入每个 chunk.metadata 的额外字段 + + Returns: + List[DocumentChunk]: 分块列表 + """ + split_options = { + "process_type": process_type, + "chunk_size": chunk_size, + "overlap_size": overlap_size, + "delimiter": delimiter, + } + return await load_and_split( + file_path, + split_options=split_options, + **chunk_metadata, + ) diff --git a/runtime/datamate-python/app/module/rag/infra/splitter/__init__.py b/runtime/datamate-python/app/module/rag/infra/splitter/__init__.py new file mode 100644 index 000000000..c1e09be1a --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/splitter/__init__.py @@ -0,0 +1,5 @@ +""" +文档分块器模块 + +提供各种文档分块策略的实现。 +""" diff --git a/runtime/datamate-python/app/module/rag/infra/splitter/base.py b/runtime/datamate-python/app/module/rag/infra/splitter/base.py new file mode 100644 index 000000000..4a9292018 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/splitter/base.py @@ -0,0 +1,74 @@ +""" +文档分块器基类 + +定义文档分块器的抽象接口 +使用策略模式支持多种分块策略 +""" +from abc import ABC, abstractmethod +from typing import List +from dataclasses import dataclass + + +@dataclass 
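+# Plain value object produced by every splitter, e.g.
+#     DocumentChunk(text="...", metadata={"chunk_index": 0, "file_name": "a.md"})
+# (the metadata keys shown are illustrative; callers decide what to pass through)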
+class DocumentChunk: + """文档分块 + + 包含分块的文本和元数据 + """ + text: str + metadata: dict + + def __repr__(self): + return f"" + + +class DocumentSplitter(ABC): + """文档分块器基类(抽象类) + + 所有具体的分块器都需要继承此类并实现 split 方法 + """ + + def __init__(self, chunk_size: int = 500, overlap_size: int = 50): + """初始化分块器 + + Args: + chunk_size: 分块大小 + overlap_size: 重叠大小 + """ + self.chunk_size = chunk_size + self.overlap_size = overlap_size + + @abstractmethod + async def split(self, text: str, **metadata) -> List[DocumentChunk]: + """分割文档 + + Args: + text: 文档文本 + **metadata: 额外的元数据 + + Returns: + List[DocumentChunk]: 分块列表 + """ + pass + + def _create_chunk( + self, + text: str, + chunk_index: int, + **metadata + ) -> DocumentChunk: + """创建文档分块 + + Args: + text: 分块文本 + chunk_index: 分块索引 + **metadata: 额外的元数据 + + Returns: + DocumentChunk: 文档分块 + """ + chunk_metadata = { + "chunk_index": chunk_index, + **metadata + } + return DocumentChunk(text=text, metadata=chunk_metadata) diff --git a/runtime/datamate-python/app/module/rag/infra/splitter/factory.py b/runtime/datamate-python/app/module/rag/infra/splitter/factory.py new file mode 100644 index 000000000..5d2d833eb --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/splitter/factory.py @@ -0,0 +1,57 @@ +""" +文档分块器工厂 + +根据处理类型创建基于 LangChain 的分块器实例, +对应 Java 的 ProcessType 枚举,供 ETL 与 ingest pipeline 复用。 +""" +from typing import Optional + +from app.module.rag.infra.splitter.base import DocumentSplitter +from app.module.rag.infra.splitter.langchain_impl import LangChainDocumentSplitter +from app.module.rag.schema.enums import ProcessType + + +class DocumentSplitterFactory: + """文档分块器工厂 + + 基于 LangChain RecursiveCharacterTextSplitter / CharacterTextSplitter: + - PARAGRAPH_CHUNK: 段落分块 + - SENTENCE_CHUNK: 句子分块 + - LENGTH_CHUNK: 字符长度分块 + - DEFAULT_CHUNK: 默认递归分块(推荐) + - CUSTOM_SEPARATOR_CHUNK: 自定义分隔符分块 + + 使用示例: + splitter = DocumentSplitterFactory.create_splitter( + ProcessType.DEFAULT_CHUNK, + chunk_size=500, + overlap_size=50 + ) + chunks = await splitter.split(document_text) + """ + + @classmethod + def create_splitter( + cls, + process_type: ProcessType, + chunk_size: int = 500, + overlap_size: int = 50, + delimiter: Optional[str] = None, + ) -> DocumentSplitter: + """根据处理类型创建对应的分块器(LangChain 实现). 
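+
+        Sketch of a custom-separator call (the delimiter is only honored for
+        CUSTOM_SEPARATOR_CHUNK; the other process types ignore it):
+
+            splitter = DocumentSplitterFactory.create_splitter(
+                ProcessType.CUSTOM_SEPARATOR_CHUNK,
+                chunk_size=300,
+                overlap_size=0,
+                delimiter="\n---\n",
+            )
+            chunks = await splitter.split(text)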
+ + Args: + process_type: 处理类型 + chunk_size: 分块大小 + overlap_size: 重叠大小 + delimiter: 自定义分隔符(仅用于 CUSTOM_SEPARATOR_CHUNK) + + Returns: + DocumentSplitter 实例 + """ + return LangChainDocumentSplitter( + process_type=process_type, + chunk_size=chunk_size, + overlap_size=overlap_size, + delimiter=delimiter, + ) diff --git a/runtime/datamate-python/app/module/rag/infra/splitter/langchain_impl.py b/runtime/datamate-python/app/module/rag/infra/splitter/langchain_impl.py new file mode 100644 index 000000000..c2aca2535 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/splitter/langchain_impl.py @@ -0,0 +1,107 @@ +""" +基于 LangChain 的文档分片实现 + +将 ProcessType 映射到 LangChain 的 RecursiveCharacterTextSplitter / CharacterTextSplitter, +在 asyncio.to_thread 中执行同步 split,并转换为领域模型 DocumentChunk。 +""" +from __future__ import annotations + +import asyncio +from typing import Any, List, Optional + +from langchain_text_splitters import ( + CharacterTextSplitter, + RecursiveCharacterTextSplitter, +) +from app.module.rag.schema.enums import ProcessType + +from app.module.rag.infra.splitter.base import DocumentChunk, DocumentSplitter + + +# 各 ProcessType 对应的 RecursiveCharacterTextSplitter 分隔符(优先保持较大语义块) +SEPARATORS_BY_PROCESS_TYPE = { + ProcessType.PARAGRAPH_CHUNK: ["\n\n", "\n", " ", ""], + ProcessType.SENTENCE_CHUNK: ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " ", ""], + ProcessType.DEFAULT_CHUNK: ["\n\n", "\n", " ", ""], # 推荐默认,递归按段/行/词 + ProcessType.CUSTOM_SEPARATOR_CHUNK: None, # 由调用方传入 delimiter,动态构造 +} + + +def _build_recursive_splitter( + chunk_size: int, + chunk_overlap: int, + separators: Optional[List[str]] = None, +) -> RecursiveCharacterTextSplitter: + if separators is None: + separators = ["\n\n", "\n", " ", ""] + return RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + separators=separators, + length_function=len, + ) + + +def _build_character_splitter( + chunk_size: int, + chunk_overlap: int, +) -> CharacterTextSplitter: + return CharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + length_function=len, + ) + + +def _texts_to_chunks(texts: List[str], **base_metadata: Any) -> List[DocumentChunk]: + """将切分后的字符串列表转为 DocumentChunk 列表,保留 chunk_index 等.""" + return [ + DocumentChunk( + text=t, + metadata={**base_metadata, "chunk_index": i}, + ) + for i, t in enumerate(texts) + ] + + +class LangChainDocumentSplitter(DocumentSplitter): + """基于 LangChain 的 DocumentSplitter 实现. 
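+
+    Mapping (as implemented in _create_splitter below): LENGTH_CHUNK uses
+    CharacterTextSplitter; every other type uses RecursiveCharacterTextSplitter
+    with the separators from SEPARATORS_BY_PROCESS_TYPE, and
+    CUSTOM_SEPARATOR_CHUNK puts the caller's delimiter first in that list.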
+ + 根据 ProcessType 选择 RecursiveCharacterTextSplitter 或 CharacterTextSplitter, + async split() 内部用 asyncio.to_thread 调用同步 split_text,再转为 DocumentChunk。 + """ + + def __init__( + self, + process_type: ProcessType, + chunk_size: int = 500, + overlap_size: int = 50, + delimiter: Optional[str] = None, + ): + super().__init__(chunk_size=chunk_size, overlap_size=overlap_size) + self._process_type = process_type + self._delimiter = delimiter or "\n\n" + self._splitter = self._create_splitter() + + def _create_splitter(self) -> RecursiveCharacterTextSplitter | CharacterTextSplitter: + if self._process_type == ProcessType.LENGTH_CHUNK: + return _build_character_splitter( + self.chunk_size, + self.overlap_size, + ) + separators = SEPARATORS_BY_PROCESS_TYPE.get(self._process_type) + if self._process_type == ProcessType.CUSTOM_SEPARATOR_CHUNK: + separators = [self._delimiter, "\n", " ", ""] + if separators is None: + separators = ["\n\n", "\n", " ", ""] + return _build_recursive_splitter( + self.chunk_size, + self.overlap_size, + separators=separators, + ) + + async def split(self, text: str, **metadata: Any) -> List[DocumentChunk]: + if not text or not text.strip(): + return [] + texts = await asyncio.to_thread(self._splitter.split_text, text) + return _texts_to_chunks(texts, **metadata) diff --git a/runtime/datamate-python/app/module/rag/infra/task/__init__.py b/runtime/datamate-python/app/module/rag/infra/task/__init__.py new file mode 100644 index 000000000..28b580ba8 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/task/__init__.py @@ -0,0 +1,5 @@ +""" +异步任务处理模块 + +提供工作池和异步任务处理功能。 +""" diff --git a/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py b/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py new file mode 100644 index 000000000..832dd06ed --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py @@ -0,0 +1,105 @@ +""" +工作协程池 + +使用 asyncio.Semaphore 控制并发数,替代 Java 的虚拟线程 + 信号量 +""" +import asyncio +from typing import Callable, Any, Coroutine +import logging + +logger = logging.getLogger(__name__) + + +class WorkerPool: + """工作协程池 + + 对应 Java 的虚拟线程 + 信号量方案 + + 使用 asyncio.Semaphore 控制并发数,避免资源耗尽 + + 使用示例: + pool = WorkerPool(max_workers=10) + + async def task_func(item): + # 处理任务 + return result + + # 并发执行多个任务 + tasks = [pool.submit(task_func, item) for item in items] + results = await asyncio.gather(*tasks) + """ + + def __init__(self, max_workers: int = 10): + """初始化工作协程池 + + Args: + max_workers: 最大并发数(默认 10) + """ + self.semaphore = asyncio.Semaphore(max_workers) + self.max_workers = max_workers + + async def submit( + self, + coro: Callable[..., Coroutine], + *args: Any, + **kwargs: Any + ) -> Any: + """提交异步任务并等待完成 + + Args: + coro: 异步协程函数 + *args: 位置参数 + **kwargs: 关键字参数 + + Returns: + 协程的返回值 + """ + async with self.semaphore: + try: + result = await coro(*args, **kwargs) + return result + except Exception as e: + logger.error(f"任务执行失败: {e}") + raise + + async def submit_batch( + self, + coro: Callable[..., Coroutine], + items: list[Any], + *args: Any, + **kwargs: Any + ) -> list[Any]: + """批量提交异步任务 + + Args: + coro: 异步协程函数 + items: 要处理的项目列表 + *args: 额外的位置参数 + **kwargs: 额外的关键字参数 + + Returns: + 结果列表 + """ + tasks = [ + self.submit(coro, item, *args, **kwargs) + for item in items + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 检查是否有任务失败 + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"批次任务 {i} 失败: {result}") + + return [r for r in results if not 
isinstance(r, Exception)] + + def get_available_workers(self) -> int: + """获取可用的工作协程数 + + Returns: + 可用的工作协程数 + """ + # 注意:Semaphore 的值在内部维护,无法直接获取 + # 这里返回最大值作为近似 + return self.max_workers diff --git a/runtime/datamate-python/app/module/rag/interface/__init__.py b/runtime/datamate-python/app/module/rag/interface/__init__.py index e69de29bb..6663ac50d 100644 --- a/runtime/datamate-python/app/module/rag/interface/__init__.py +++ b/runtime/datamate-python/app/module/rag/interface/__init__.py @@ -0,0 +1,12 @@ +""" +RAG 模块 API 路由导出 + +集中导出所有 API 路由 +""" +from .knowledge_base import router as knowledge_base_router +from .rag_interface import router as graph_rag_router + +__all__ = [ + "knowledge_base_router", + "graph_rag_router", +] diff --git a/runtime/datamate-python/app/module/rag/interface/knowledge_base.py b/runtime/datamate-python/app/module/rag/interface/knowledge_base.py new file mode 100644 index 000000000..26c46a8b8 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/interface/knowledge_base.py @@ -0,0 +1,256 @@ +""" +知识库 API 接口 + +实现知识库相关的 REST API 接口 +对应 Java: com.datamate.rag.indexer.interfaces.KnowledgeBaseController + +接口路径调整: +- Java: /knowledge-base/* +- Python: /rag/knowledge-base/* +""" +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.exception import SuccessResponse +from app.db.session import get_db +from app.module.rag.schema.request import ( + KnowledgeBaseCreateReq, + KnowledgeBaseUpdateReq, + KnowledgeBaseQueryReq, + AddFilesReq, + DeleteFilesReq, + RagFileReq, + RetrieveReq, + PagingQuery, +) +from app.module.rag.service.knowledge_base_service import KnowledgeBaseService + +router = APIRouter(prefix="/knowledge-base", tags=["知识库管理"]) + + +@router.post("/create", response_model=SuccessResponse) +async def create_knowledge_base( + request: KnowledgeBaseCreateReq, + db: AsyncSession = Depends(get_db) +): + """创建知识库 + + 对应 Java: POST /knowledge-base/create + + Args: + request: 知识库创建请求 + db: 数据库 session + + Returns: + 知识库 ID + """ + service = KnowledgeBaseService(db) + knowledge_base_id = await service.create(request) + return SuccessResponse(data=knowledge_base_id) + + +@router.put("/{knowledge_base_id}", response_model=SuccessResponse) +async def update_knowledge_base( + knowledge_base_id: str, + request: KnowledgeBaseUpdateReq, + db: AsyncSession = Depends(get_db) +): + """更新知识库 + + 对应 Java: PUT /knowledge-base/{id} + + Args: + knowledge_base_id: 知识库 ID + request: 知识库更新请求 + db: 数据库 session + """ + service = KnowledgeBaseService(db) + await service.update(knowledge_base_id, request) + return SuccessResponse(message="知识库更新成功") + + +@router.delete("/{knowledge_base_id}", response_model=SuccessResponse) +async def delete_knowledge_base( + knowledge_base_id: str, + db: AsyncSession = Depends(get_db) +): + """删除知识库 + + 对应 Java: DELETE /knowledge-base/{id} + + Args: + knowledge_base_id: 知识库 ID + db: 数据库 session + """ + service = KnowledgeBaseService(db) + await service.delete(knowledge_base_id) + return SuccessResponse(message="知识库删除成功") + + +@router.get("/{knowledge_base_id}", response_model=SuccessResponse) +async def get_knowledge_base( + knowledge_base_id: str, + db: AsyncSession = Depends(get_db) +): + """获取知识库详情 + + 对应 Java: GET /knowledge-base/{id} + + Args: + knowledge_base_id: 知识库 ID + db: 数据库 session + + Returns: + 知识库详情 + """ + service = KnowledgeBaseService(db) + knowledge_base = await service.get_by_id(knowledge_base_id) + return SuccessResponse(data=knowledge_base) + + +@router.post("/list", 
response_model=SuccessResponse) +async def list_knowledge_bases( + request: KnowledgeBaseQueryReq, + db: AsyncSession = Depends(get_db) +): + """分页查询知识库列表 + + 对应 Java: POST /knowledge-base/list + + Args: + request: 查询请求 + db: 数据库 session + + Returns: + 知识库列表(分页) + """ + service = KnowledgeBaseService(db) + result = await service.list(request) + return SuccessResponse(data=result) + + +@router.post("/{knowledge_base_id}/files", response_model=SuccessResponse) +async def add_files_to_knowledge_base( + knowledge_base_id: str, + request: AddFilesReq, + db: AsyncSession = Depends(get_db) +): + """添加文件到知识库 + + 对应 Java: POST /knowledge-base/{id}/files + + Args: + knowledge_base_id: 知识库 ID + request: 添加文件请求 + db: 数据库 session + + Returns: + 包含成功添加数量和跳过文件数量的响应 + """ + # 确保 knowledge_base_id 与 request 中的一致 + request.knowledge_base_id = knowledge_base_id + + service = KnowledgeBaseService(db) + result = await service.add_files(request) + + message = f"文件添加成功,正在后台处理" + if result["skipped_count"] > 0: + message = f"成功添加 {result['success_count']} 个文件,跳过 {result['skipped_count']} 个不存在的文件" + + return SuccessResponse( + message=message, + data={ + "successCount": result["success_count"], + "skippedCount": result["skipped_count"], + "skippedFileIds": result["skipped_file_ids"] + } + ) + + +@router.get("/{knowledge_base_id}/files", response_model=SuccessResponse) +async def list_knowledge_base_files( + knowledge_base_id: str, + request: RagFileReq = Depends(), + db: AsyncSession = Depends(get_db) +): + """获取知识库文件列表 + + 对应 Java: GET /knowledge-base/{id}/files + + Args: + knowledge_base_id: 知识库 ID + request: 查询请求 + db: 数据库 session + + Returns: + 文件列表(分页) + """ + service = KnowledgeBaseService(db) + result = await service.list_files(knowledge_base_id, request) + return SuccessResponse(data=result) + + +@router.delete("/{knowledge_base_id}/files", response_model=SuccessResponse) +async def delete_knowledge_base_files( + knowledge_base_id: str, + request: DeleteFilesReq, + db: AsyncSession = Depends(get_db) +): + """删除知识库文件 + + 对应 Java: DELETE /knowledge-base/{id}/files + + Args: + knowledge_base_id: 知识库 ID + request: 删除文件请求 + db: 数据库 session + """ + service = KnowledgeBaseService(db) + await service.delete_files(knowledge_base_id, request) + return SuccessResponse(message="文件删除成功") + + +@router.get("/{knowledge_base_id}/files/{rag_file_id}", response_model=SuccessResponse) +async def get_file_chunks( + knowledge_base_id: str, + rag_file_id: str, + paging_query: PagingQuery = Depends(), + db: AsyncSession = Depends(get_db) +): + """获取指定 RAG 文件的分块列表 + + 对应 Java: GET /knowledge-base/{id}/files/{ragFileId} + + Args: + knowledge_base_id: 知识库 ID + rag_file_id: RAG 文件 ID + paging_query: 分页参数 + db: 数据库 session + + Returns: + 文件分块列表(分页) + """ + service = KnowledgeBaseService(db) + result = await service.get_chunks(knowledge_base_id, rag_file_id, paging_query) + return SuccessResponse(data=result) + + +@router.post("/retrieve", response_model=SuccessResponse) +async def retrieve_knowledge_base( + request: RetrieveReq, + db: AsyncSession = Depends(get_db) +): + """检索知识库内容(向量 + BM25 混合检索) + + 对应 Java: POST /knowledge-base/retrieve + + Args: + request: 检索请求 + db: 数据库 session + + Returns: + 检索结果列表 + """ + service = KnowledgeBaseService(db) + results = await service.retrieve(request) + return SuccessResponse(data=results) diff --git a/runtime/datamate-python/app/module/rag/interface/rag_interface.py b/runtime/datamate-python/app/module/rag/interface/rag_interface.py index 40265bf76..910954a35 100644 --- 
a/runtime/datamate-python/app/module/rag/interface/rag_interface.py +++ b/runtime/datamate-python/app/module/rag/interface/rag_interface.py @@ -1,17 +1,19 @@ from fastapi import APIRouter, Depends -from app.core.exception import ErrorCodes, BusinessError, SuccessResponse -from app.db.session import get_db +from app.core.exception import SuccessResponse from app.module.rag.service.rag_service import RAGService -from app.module.shared.schema import StandardResponse -from ..schema.rag_schema import QueryRequest +from ..schema.request import QueryRequest -router = APIRouter(prefix="/rag", tags=["rag"]) +router = APIRouter(prefix="/rag", tags=["知识图谱 RAG"]) -@router.post("/process/{knowledge_base_id}") +@router.post("/{knowledge_base_id}/process") async def process_knowledge_base(knowledge_base_id: str, rag_service: RAGService = Depends()): """ - 处理知识库中所有未处理的文件 + 处理知识库中所有未处理的文件(LightRAG) + + 接口路径调整: + - 旧路径: /rag/process/{id} + - 新路径: /rag/graph/{id}/process """ await rag_service.init_graph_rag(knowledge_base_id) return SuccessResponse( @@ -22,7 +24,11 @@ async def process_knowledge_base(knowledge_base_id: str, rag_service: RAGService @router.post("/query") async def query_knowledge_graph(payload: QueryRequest, rag_service: RAGService = Depends()): """ - 使用给定的查询文本和知识库 ID 查询知识图谱 + 使用给定的查询文本和知识库 ID 查询知识图谱(LightRAG) + + 接口路径调整: + - 旧路径: /rag/query + - 新路径: /rag/graph/query """ result = await rag_service.query_rag(payload.query, payload.knowledge_base_id) return SuccessResponse(data=result) diff --git a/runtime/datamate-python/app/module/rag/repository/__init__.py b/runtime/datamate-python/app/module/rag/repository/__init__.py new file mode 100644 index 000000000..e80ecdad4 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/repository/__init__.py @@ -0,0 +1,12 @@ +""" +RAG 仓储层导出 + +集中导出所有仓储类 +""" +from .knowledge_base_repository import KnowledgeBaseRepository +from .file_repository import RagFileRepository + +__all__ = [ + "KnowledgeBaseRepository", + "RagFileRepository", +] diff --git a/runtime/datamate-python/app/module/rag/repository/file_repository.py b/runtime/datamate-python/app/module/rag/repository/file_repository.py new file mode 100644 index 000000000..94092de75 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/repository/file_repository.py @@ -0,0 +1,325 @@ +""" +RAG 文件仓储层 + +提供 RAG 文件数据访问操作 +使用 SQLAlchemy 异步 session 进行数据库操作 +""" +from typing import List, Optional, Tuple +from sqlalchemy import select, func, and_, or_ +from sqlalchemy.ext.asyncio import AsyncSession +from app.db.models.knowledge_gen import RagFile, FileStatus +from app.core.exception import BusinessError, ErrorCodes + + +class RagFileRepository: + """RAG 文件仓储类 + + 对应 Java: com.datamate.rag.indexer.domain.repository.RagFileRepository + 提供 RAG 文件的 CRUD 操作和查询功能 + """ + + def __init__(self, db: AsyncSession): + """初始化仓储 + + Args: + db: SQLAlchemy 异步 session + """ + self.db = db + + async def create(self, rag_file: RagFile) -> RagFile: + """创建 RAG 文件 + + Args: + rag_file: RAG 文件实体 + + Returns: + 创建的 RAG 文件实体 + """ + self.db.add(rag_file) + await self.db.flush() + return rag_file + + async def batch_create(self, rag_files: List[RagFile]) -> List[RagFile]: + """批量创建 RAG 文件 + + Args: + rag_files: RAG 文件实体列表 + + Returns: + 创建的 RAG 文件实体列表 + """ + self.db.add_all(rag_files) + await self.db.flush() + return rag_files + + async def update(self, rag_file: RagFile) -> RagFile: + """更新 RAG 文件 + + Args: + rag_file: RAG 文件实体(必须包含 id) + + Returns: + 更新后的 RAG 文件实体 + + Raises: + BusinessError: 文件不存在 + """ + existing = await 
self.get_by_id(rag_file.id) + if not existing: + raise BusinessError(ErrorCodes.RAG_FILE_NOT_FOUND) + + # 更新字段 + if rag_file.chunk_count is not None: + existing.chunk_count = rag_file.chunk_count + if rag_file.metadata is not None: + existing.metadata = rag_file.metadata + if rag_file.status is not None: + existing.status = rag_file.status + if rag_file.err_msg is not None: + existing.err_msg = rag_file.err_msg + + await self.db.flush() + return existing + + async def delete(self, rag_file_id: str) -> None: + """删除 RAG 文件 + + Args: + rag_file_id: RAG 文件 ID + + Raises: + BusinessError: 文件不存在 + """ + rag_file = await self.get_by_id(rag_file_id) + if not rag_file: + raise BusinessError(ErrorCodes.RAG_FILE_NOT_FOUND) + + await self.db.delete(rag_file) + await self.db.flush() + + async def batch_delete(self, rag_file_ids: List[str]) -> None: + """批量删除 RAG 文件 + + Args: + rag_file_ids: RAG 文件 ID 列表 + """ + if not rag_file_ids: + return + + await self.db.execute( + select(RagFile).where(RagFile.id.in_(rag_file_ids)) + ) + # 注意:实际删除需要在查询后进行,这里简化处理 + # 在实际使用时,应该先查询再删除 + + async def delete_by_knowledge_base( + self, + knowledge_base_id: str + ) -> int: + """删除知识库的所有文件 + + Args: + knowledge_base_id: 知识库 ID + + Returns: + 删除的文件数量 + """ + result = await self.db.execute( + select(RagFile).where( + RagFile.knowledge_base_id == knowledge_base_id + ) + ) + files = result.scalars().all() + + count = len(files) + for file in files: + await self.db.delete(file) + + await self.db.flush() + return count + + async def get_by_id(self, rag_file_id: str) -> Optional[RagFile]: + """根据 ID 获取 RAG 文件 + + Args: + rag_file_id: RAG 文件 ID + + Returns: + RAG 文件实体,不存在则返回 None + """ + result = await self.db.execute( + select(RagFile).where(RagFile.id == rag_file_id) + ) + return result.scalars().first() + + async def get_by_file_id(self, file_id: str) -> Optional[RagFile]: + """根据原始文件 ID 获取 RAG 文件 + + Args: + file_id: 原始文件 ID + + Returns: + RAG 文件实体,不存在则返回 None + """ + result = await self.db.execute( + select(RagFile).where(RagFile.file_id == file_id) + ) + return result.scalars().first() + + async def list_by_knowledge_base( + self, + knowledge_base_id: str, + keyword: Optional[str] = None, + status: Optional[FileStatus] = None, + page: int = 1, + page_size: int = 10 + ) -> Tuple[List[RagFile], int]: + """分页查询知识库的文件列表 + + Args: + knowledge_base_id: 知识库 ID + keyword: 搜索关键词(模糊匹配文件名) + status: 文件状态筛选 + page: 页码(从 1 开始) + page_size: 每页数量 + + Returns: + (RAG 文件列表, 总记录数) + """ + # 构建查询条件 + conditions = [RagFile.knowledge_base_id == knowledge_base_id] + + if keyword: + conditions.append(RagFile.file_name.like(f"%{keyword}%")) + + if status: + conditions.append(RagFile.status == status) + + # 查询总数 + count_query = select(func.count()).select_from(RagFile).where( + and_(*conditions) + ) + total_result = await self.db.execute(count_query) + total = total_result.scalar() or 0 + + # 分页查询 + query = select(RagFile).where(and_(*conditions)) + query = query.order_by(RagFile.created_at.desc()) + query = query.offset((page - 1) * page_size).limit(page_size) + + result = await self.db.execute(query) + items = result.scalars().all() + + return list(items), total + + async def get_unprocessed_files( + self, + knowledge_base_id: str, + limit: int = 100 + ) -> List[RagFile]: + """获取待处理的文件 + + Args: + knowledge_base_id: 知识库 ID + limit: 最大返回数量 + + Returns: + 待处理的 RAG 文件列表 + """ + result = await self.db.execute( + select(RagFile).where( + and_( + RagFile.knowledge_base_id == knowledge_base_id, + RagFile.status == FileStatus.UNPROCESSED + ) + 
).limit(limit) + ) + return list(result.scalars().all()) + + async def update_status( + self, + rag_file_id: str, + status: FileStatus, + err_msg: Optional[str] = None + ) -> None: + """更新文件状态 + + Args: + rag_file_id: RAG 文件 ID + status: 新状态 + err_msg: 错误信息(可选) + + Raises: + BusinessError: 文件不存在 + """ + rag_file = await self.get_by_id(rag_file_id) + if not rag_file: + raise BusinessError(ErrorCodes.RAG_FILE_NOT_FOUND) + + rag_file.status = status + if err_msg is not None: + rag_file.err_msg = err_msg + + await self.db.flush() + + async def update_chunk_count( + self, + rag_file_id: str, + chunk_count: int + ) -> None: + """更新文件分块数量 + + Args: + rag_file_id: RAG 文件 ID + chunk_count: 分块数量 + + Raises: + BusinessError: 文件不存在 + """ + rag_file = await self.get_by_id(rag_file_id) + if not rag_file: + raise BusinessError(ErrorCodes.RAG_FILE_NOT_FOUND) + + rag_file.chunk_count = chunk_count + await self.db.flush() + + async def count_by_knowledge_base( + self, + knowledge_base_id: str + ) -> int: + """统计知识库的文件数量 + + Args: + knowledge_base_id: 知识库 ID + + Returns: + 文件数量 + """ + result = await self.db.execute( + select(func.count()).select_from(RagFile).where( + RagFile.knowledge_base_id == knowledge_base_id + ) + ) + return result.scalar() or 0 + + async def count_chunks_by_knowledge_base( + self, + knowledge_base_id: str + ) -> int: + """统计知识库的总分块数量 + + Args: + knowledge_base_id: 知识库 ID + + Returns: + 总分块数量 + """ + result = await self.db.execute( + select(func.sum(RagFile.chunk_count)).where( + and_( + RagFile.knowledge_base_id == knowledge_base_id, + RagFile.chunk_count.isnot(None) + ) + ) + ) + return result.scalar() or 0 diff --git a/runtime/datamate-python/app/module/rag/repository/knowledge_base_repository.py b/runtime/datamate-python/app/module/rag/repository/knowledge_base_repository.py new file mode 100644 index 000000000..c53e6ae1e --- /dev/null +++ b/runtime/datamate-python/app/module/rag/repository/knowledge_base_repository.py @@ -0,0 +1,203 @@ +""" +知识库仓储层 + +提供知识库数据访问操作 +使用 SQLAlchemy 异步 session 进行数据库操作 +""" +from typing import List, Optional +from sqlalchemy import select, func, and_, or_ +from sqlalchemy.ext.asyncio import AsyncSession +from app.db.models.knowledge_gen import KnowledgeBase, RagType +from app.core.exception import BusinessError, ErrorCodes + + +class KnowledgeBaseRepository: + """知识库仓储类 + + 对应 Java: com.datamate.rag.indexer.domain.repository.KnowledgeBaseRepository + 提供知识库的 CRUD 操作和查询功能 + """ + + def __init__(self, db: AsyncSession): + """初始化仓储 + + Args: + db: SQLAlchemy 异步 session + """ + self.db = db + + async def create(self, knowledge_base: KnowledgeBase) -> KnowledgeBase: + """创建知识库 + + Args: + knowledge_base: 知识库实体 + + Returns: + 创建的知识库实体 + + Raises: + BusinessError: 知识库名称已存在 + """ + # 检查名称是否已存在 + existing = await self.get_by_name(knowledge_base.name) + if existing: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_ALREADY_EXISTS) + + self.db.add(knowledge_base) + await self.db.flush() + return knowledge_base + + async def update(self, knowledge_base: KnowledgeBase) -> KnowledgeBase: + """更新知识库 + + Args: + knowledge_base: 知识库实体(必须包含 id) + + Returns: + 更新后的知识库实体 + + Raises: + BusinessError: 知识库不存在 + """ + existing = await self.get_by_id(knowledge_base.id) + if not existing: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + # 如果名称变更,检查新名称是否已存在 + if existing.name != knowledge_base.name: + name_exists = await self.get_by_name(knowledge_base.name) + if name_exists: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_ALREADY_EXISTS) + + # 更新字段 + 
existing.name = knowledge_base.name + existing.description = knowledge_base.description + + await self.db.flush() + return existing + + async def delete(self, knowledge_base_id: str) -> None: + """删除知识库 + + Args: + knowledge_base_id: 知识库ID + + Raises: + BusinessError: 知识库不存在 + """ + knowledge_base = await self.get_by_id(knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + await self.db.delete(knowledge_base) + await self.db.flush() + + async def get_by_id(self, knowledge_base_id: str) -> Optional[KnowledgeBase]: + """根据 ID 获取知识库 + + Args: + knowledge_base_id: 知识库ID + + Returns: + 知识库实体,不存在则返回 None + """ + result = await self.db.execute( + select(KnowledgeBase).where(KnowledgeBase.id == knowledge_base_id) + ) + return result.scalars().first() + + async def get_by_name(self, name: str) -> Optional[KnowledgeBase]: + """根据名称获取知识库 + + Args: + name: 知识库名称 + + Returns: + 知识库实体,不存在则返回 None + """ + result = await self.db.execute( + select(KnowledgeBase).where(KnowledgeBase.name == name) + ) + return result.scalars().first() + + async def list( + self, + keyword: Optional[str] = None, + rag_type: Optional[RagType] = None, + page: int = 1, + page_size: int = 10 + ) -> tuple[List[KnowledgeBase], int]: + """分页查询知识库列表 + + Args: + keyword: 搜索关键词(模糊匹配名称或描述) + rag_type: RAG 类型筛选 + page: 页码(从 1 开始) + page_size: 每页数量 + + Returns: + (知识库列表, 总记录数) + """ + # 构建查询条件 + conditions = [] + + if keyword: + conditions.append( + or_( + KnowledgeBase.name.like(f"%{keyword}%"), + KnowledgeBase.description.like(f"%{keyword}%") + ) + ) + + if rag_type: + conditions.append(KnowledgeBase.type == rag_type) + + # 构建查询 + query = select(KnowledgeBase) + if conditions: + query = query.where(and_(*conditions)) + + # 查询总数 + count_query = select(func.count()).select_from(KnowledgeBase) + if conditions: + count_query = count_query.where(and_(*conditions)) + + total_result = await self.db.execute(count_query) + total = total_result.scalar() or 0 + + # 分页查询 + query = query.order_by(KnowledgeBase.created_at.desc()) + query = query.offset((page - 1) * page_size).limit(page_size) + + result = await self.db.execute(query) + items = result.scalars().all() + + return list(items), total + + async def exists_by_name(self, name: str, exclude_id: Optional[str] = None) -> bool: + """检查知识库名称是否存在 + + Args: + name: 知识库名称 + exclude_id: 排除的知识库ID(用于更新时检查) + + Returns: + True 表示名称已存在,False 表示名称可用 + """ + query = select(KnowledgeBase).where(KnowledgeBase.name == name) + if exclude_id: + query = query.where(KnowledgeBase.id != exclude_id) + + result = await self.db.execute(query) + return result.scalars().first() is not None + + async def get_all_ids(self) -> List[str]: + """获取所有知识库 ID + + Returns: + 知识库 ID 列表 + """ + result = await self.db.execute( + select(KnowledgeBase.id).order_by(KnowledgeBase.created_at.desc()) + ) + return list(result.scalars().all()) diff --git a/runtime/datamate-python/app/module/rag/schema/__init__.py b/runtime/datamate-python/app/module/rag/schema/__init__.py new file mode 100644 index 000000000..362bab907 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/schema/__init__.py @@ -0,0 +1,58 @@ +""" +RAG 模块 Schema 导出 + +集中导出所有数据模型、请求和响应 DTO +""" +from .enums import ProcessType +from app.db.models.knowledge_gen import RagType, FileStatus +from .entity import RagChunk +from app.db.models.knowledge_gen import KnowledgeBase, RagFile +from .request import ( + KnowledgeBaseCreateReq, + KnowledgeBaseUpdateReq, + KnowledgeBaseQueryReq, + AddFilesReq, + DeleteFilesReq, + 
RagFileReq, + RetrieveReq, + FileInfo, + PagingQuery, + QueryRequest, +) +from .response import ( + ModelConfig, + KnowledgeBaseResp, + RagFileResp, + RagChunkResp, + SearchResult, + PagedResponse, +) + +__all__ = [ + # Enums + "RagType", + "ProcessType", + "FileStatus", + # Entities + "KnowledgeBase", + "RagFile", + "RagChunk", + # Requests + "KnowledgeBaseCreateReq", + "KnowledgeBaseUpdateReq", + "KnowledgeBaseQueryReq", + "AddFilesReq", + "DeleteFilesReq", + "RagFileReq", + "RetrieveReq", + "FileInfo", + "PagingQuery", + "QueryRequest", + # Responses + "ModelConfig", + "KnowledgeBaseResp", + "RagFileResp", + "RagChunkResp", + "SearchResult", + "PagedResponse", +] diff --git a/runtime/datamate-python/app/module/rag/schema/entity.py b/runtime/datamate-python/app/module/rag/schema/entity.py new file mode 100644 index 000000000..ad257f72c --- /dev/null +++ b/runtime/datamate-python/app/module/rag/schema/entity.py @@ -0,0 +1,58 @@ +""" +RAG 模块非 ORM 实体 + +RagChunk 为 Milvus 中存储的文档分块 DTO,非数据库表。 +ORM 实体 KnowledgeBase、RagFile 已移至 app.db.models.knowledge_gen。 +""" +# noqa: D104 (RagChunk 是领域模型,不是 DB 表) + + +class RagChunk: + """RAG 分块模型 + + 对应 Java: com.datamate.rag.indexer.domain.model.RagChunk + 注意:这不是数据库实体,而是 Milvus 中存储的文档分块 + """ + + def __init__( + self, + chunk_id: str, + rag_file_id: str, + text: str, + metadata: dict, + vector: list[float] = None, + sparse_vector: dict[int, float] = None + ): + """初始化文档分块 + + Args: + chunk_id: 分块ID + rag_file_id: 关联的 RAG 文件 ID + text: 分块文本内容 + metadata: 元数据(包含文件信息、分块索引等) + vector: 密集向量(嵌入向量) + sparse_vector: 稀疏向量(BM25 向量) + """ + self.chunk_id = chunk_id + self.rag_file_id = rag_file_id + self.text = text + self.metadata = metadata + self.vector = vector + self.sparse_vector = sparse_vector + + def to_dict(self) -> dict: + """转换为字典格式(用于 Milvus 插入) + + Returns: + 包含所有字段的字典 + """ + return { + "id": self.chunk_id, + "text": self.text, + "metadata": self.metadata, + "vector": self.vector, + "sparse": self.sparse_vector or {} + } + + def __repr__(self): + return f"" diff --git a/runtime/datamate-python/app/module/rag/schema/enums.py b/runtime/datamate-python/app/module/rag/schema/enums.py new file mode 100644 index 000000000..7a0e9081b --- /dev/null +++ b/runtime/datamate-python/app/module/rag/schema/enums.py @@ -0,0 +1,21 @@ +""" +RAG 模块枚举定义 + +包含所有 RAG 相关的枚举类型,与 Java 枚举保持一致 +从 app.db.models.knowledge_gen 导入以避免循环依赖 +""" +from enum import Enum + +# 从模型导入以避免循环依赖 + + +class ProcessType(str, Enum): + """分块处理类型枚举 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.ProcessType + """ + PARAGRAPH_CHUNK = "PARAGRAPH_CHUNK" # 段落分块 + SENTENCE_CHUNK = "SENTENCE_CHUNK" # 按句子分块 + LENGTH_CHUNK = "LENGTH_CHUNK" # 按长度分块(字符) + DEFAULT_CHUNK = "DEFAULT_CHUNK" # 默认分块(单词) + CUSTOM_SEPARATOR_CHUNK = "CUSTOM_SEPARATOR_CHUNK" # 自定义分隔符分块 diff --git a/runtime/datamate-python/app/module/rag/schema/rag_schema.py b/runtime/datamate-python/app/module/rag/schema/rag_schema.py deleted file mode 100644 index 00046194d..000000000 --- a/runtime/datamate-python/app/module/rag/schema/rag_schema.py +++ /dev/null @@ -1,8 +0,0 @@ -from pydantic import BaseModel - -class ProcessRequest(BaseModel): - knowledge_base_id: str - -class QueryRequest(BaseModel): - knowledge_base_id: str - query: str diff --git a/runtime/datamate-python/app/module/rag/schema/request.py b/runtime/datamate-python/app/module/rag/schema/request.py new file mode 100644 index 000000000..1fff90f68 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/schema/request.py @@ -0,0 +1,326 @@ +""" +RAG 模块请求 DTO + +使用 Pydantic 
定义所有请求参数验证 +与 Java DTO 保持字段一致和验证规则一致 +""" +from pydantic import BaseModel, Field, field_validator +from typing import List, Optional +from .enums import ProcessType +from app.db.models.knowledge_gen import RagType, FileStatus + + +class KnowledgeBaseCreateReq(BaseModel): + """知识库创建请求 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseCreateReq + """ + name: str = Field( + ..., + min_length=1, + max_length=255, + pattern=r"^[a-zA-Z][a-zA-Z0-9_]*$", + description="知识库名称(必须以字母开头,只能包含字母、数字和下划线)" + ) + description: Optional[str] = Field( + None, + max_length=512, + description="知识库描述" + ) + type: RagType = Field( + default=RagType.DOCUMENT, + description="RAG 类型" + ) + embedding_model: str = Field( + ..., + min_length=1, + description="嵌入模型ID" + ) + chat_model: Optional[str] = Field( + None, + description="聊天模型ID" + ) + + class Config: + json_schema_extra = { + "example": { + "name": "my_knowledge_base", + "description": "我的知识库", + "type": "DOCUMENT", + "embedding_model": "text-embedding-ada-002", + "chat_model": "gpt-4" + } + } + + +class KnowledgeBaseUpdateReq(BaseModel): + """知识库更新请求 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseUpdateReq + """ + name: str = Field( + ..., + min_length=1, + max_length=255, + pattern=r"^[a-zA-Z][a-zA-Z0-9_]*$", + description="知识库名称" + ) + description: Optional[str] = Field( + None, + max_length=512, + description="知识库描述" + ) + + class Config: + json_schema_extra = { + "example": { + "name": "updated_knowledge_base", + "description": "更新后的描述" + } + } + + +class KnowledgeBaseQueryReq(BaseModel): + """知识库查询请求 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseQueryReq + """ + page: int = Field( + default=1, + ge=1, + description="页码(从 1 开始)" + ) + page_size: int = Field( + default=10, + ge=1, + le=100, + description="每页数量" + ) + keyword: Optional[str] = Field( + None, + max_length=255, + description="搜索关键词(模糊匹配知识库名称或描述)" + ) + type: Optional[RagType] = Field( + None, + description="按 RAG 类型筛选" + ) + + class Config: + json_schema_extra = { + "example": { + "page": 1, + "page_size": 10, + "keyword": "测试", + "type": "DOCUMENT" + } + } + + +class FileInfo(BaseModel): + """文件信息 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.AddFilesReq.FileInfo + """ + id: str = Field(..., description="文件ID (对应 t_dm_dataset_files.id)") + dataset_id: str = Field(..., description="数据集ID") + file_name: str = Field(..., description="文件名") + + class Config: + json_schema_extra = { + "example": { + "id": "file-uuid-123", + "dataset_id": "dataset-uuid-456", + "file_name": "document.pdf" + } + } + + +class AddFilesReq(BaseModel): + """添加文件请求 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.AddFilesReq + """ + knowledge_base_id: str = Field(..., description="知识库ID(从路径参数获取,这里保留用于兼容)") + process_type: ProcessType = Field( + default=ProcessType.DEFAULT_CHUNK, + description="分块处理类型" + ) + chunk_size: int = Field( + default=500, + ge=50, + le=2000, + description="分块大小" + ) + overlap_size: int = Field( + default=50, + ge=0, + le=500, + description="重叠大小" + ) + delimiter: Optional[str] = Field( + None, + description="自定义分隔符(仅用于 CUSTOM_SEPARATOR_CHUNK)" + ) + files: List[FileInfo] = Field( + ..., + min_length=1, + description="文件列表" + ) + + @field_validator("delimiter") + @classmethod + def validate_delimiter(cls, v, info): + """验证自定义分隔符""" + if info.data.get("process_type") == ProcessType.CUSTOM_SEPARATOR_CHUNK and not v: + raise ValueError("使用自定义分隔符分块时,delimiter 不能为空") + return v + + class Config: + json_schema_extra = { + "example": { + 
"knowledge_base_id": "kb-uuid-123", + "process_type": "DEFAULT_CHUNK", + "chunk_size": 500, + "overlap_size": 50, + "files": [ + {"id": "file-1", "dataset_id": "dataset-uuid-456", "file_name": "doc1.pdf"}, + {"id": "file-2", "dataset_id": "dataset-uuid-456", "file_name": "doc2.pdf"} + ] + } + } + + +class DeleteFilesReq(BaseModel): + """删除文件请求 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.DeleteFilesReq + """ + file_ids: List[str] = Field( + ..., + min_length=1, + description="要删除的文件ID列表" + ) + + class Config: + json_schema_extra = { + "example": { + "file_ids": ["file-1", "file-2", "file-3"] + } + } + + +class RagFileReq(BaseModel): + """RAG 文件查询请求 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.RagFileReq + """ + page: int = Field( + default=1, + ge=1, + description="页码" + ) + page_size: int = Field( + default=10, + ge=1, + le=100, + description="每页数量" + ) + keyword: Optional[str] = Field( + None, + max_length=255, + description="搜索关键词(模糊匹配文件名)" + ) + status: Optional[FileStatus] = Field( + None, + description="按状态筛选" + ) + + class Config: + json_schema_extra = { + "example": { + "page": 1, + "page_size": 10, + "keyword": "测试", + "status": "PROCESSED" + } + } + + +class RetrieveReq(BaseModel): + """检索请求 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.RetrieveReq + """ + query: str = Field( + ..., + min_length=1, + description="检索查询文本" + ) + top_k: int = Field( + default=5, + ge=1, + le=20, + description="返回前 K 个结果" + ) + threshold: Optional[float] = Field( + None, + ge=0.0, + le=1.0, + description="相似度阈值(仅返回分数大于等于该值的结果)" + ) + knowledge_base_ids: List[str] = Field( + ..., + min_length=1, + description="要检索的知识库ID列表" + ) + + class Config: + json_schema_extra = { + "example": { + "query": "什么是机器学习?", + "top_k": 5, + "threshold": 0.7, + "knowledge_base_ids": ["kb-1", "kb-2"] + } + } + + +class PagingQuery(BaseModel): + """分页查询请求 + + 对应 Java: com.datamate.common.interfaces.PagingQuery + """ + page: int = Field( + default=1, + ge=1, + description="页码(从 1 开始)" + ) + size: int = Field( + default=10, + ge=1, + le=100, + description="每页数量" + ) + + class Config: + json_schema_extra = { + "example": { + "page": 1, + "size": 10 + } + } + + +class QueryRequest(BaseModel): + """知识图谱查询请求""" + knowledge_base_id: str = Field(..., description="知识库ID") + query: str = Field(..., description="查询文本") + + class Config: + json_schema_extra = { + "example": { + "knowledge_base_id": "kb-uuid-123", + "query": "什么是机器学习?" 
+ } + } diff --git a/runtime/datamate-python/app/module/rag/schema/response.py b/runtime/datamate-python/app/module/rag/schema/response.py new file mode 100644 index 000000000..f946b992b --- /dev/null +++ b/runtime/datamate-python/app/module/rag/schema/response.py @@ -0,0 +1,194 @@ +""" +RAG 模块响应 DTO + +定义所有 API 响应的数据结构 +与 Java 响应 DTO 保持字段一致 +""" +from pydantic import BaseModel, Field +from typing import List, Optional, Any +from datetime import datetime +from app.db.models.knowledge_gen import RagType, FileStatus + + +class ModelConfig(BaseModel): + """模型配置信息 + + 对应 Java: com.datamate.common.setting.domain.entity.ModelConfig + """ + id: str = Field(..., description="模型ID") + name: str = Field(..., description="模型名称") + provider: str = Field(..., description="模型提供商") + + class Config: + json_schema_extra = { + "example": { + "id": "model-uuid-123", + "name": "text-embedding-ada-002", + "provider": "openai" + } + } + + +class KnowledgeBaseResp(BaseModel): + """知识库响应 + + 对应 Java: com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseResp + """ + id: str = Field(..., description="知识库ID") + name: str = Field(..., description="知识库名称") + description: Optional[str] = Field(None, description="知识库描述") + type: RagType = Field(..., description="RAG类型") + embedding_model: str = Field(alias="embeddingModel", description="嵌入模型ID") + chat_model: Optional[str] = Field(None, alias="chatModel", description="聊天模型ID") + file_count: Optional[int] = Field(None, alias="fileCount", description="文件数量") + chunk_count: Optional[int] = Field(None, alias="chunkCount", description="分块数量") + embedding: Optional[ModelConfig] = Field(None, description="嵌入模型配置") + chat: Optional[ModelConfig] = Field(None, description="聊天模型配置") + created_at: Optional[datetime] = Field(None, alias="createdAt", description="创建时间") + updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间") + + class Config: + populate_by_name = True # 允许使用 snake_case 或 camelCase + json_schema_extra = { + "example": { + "id": "kb-uuid-123", + "name": "my_knowledge_base", + "description": "我的知识库", + "type": "DOCUMENT", + "embeddingModel": "text-embedding-ada-002", + "chatModel": "gpt-4", + "fileCount": 10, + "chunkCount": 150, + "embedding": { + "id": "model-1", + "name": "text-embedding-ada-002", + "provider": "openai" + } + } + } + + +class RagFileResp(BaseModel): + """RAG 文件响应 + + 对应 Java: com.datamate.rag.indexer.domain.model.RagFile + """ + id: str = Field(..., description="RAG文件ID") + knowledge_base_id: str = Field(alias="knowledgeBaseId", description="知识库ID") + file_name: str = Field(alias="fileName", description="文件名") + file_id: str = Field(alias="fileId", description="原始文件ID") + chunk_count: Optional[int] = Field(None, alias="chunkCount", description="分块数量") + metadata: Optional[dict] = Field(None, description="元数据") + status: FileStatus = Field(..., description="处理状态") + err_msg: Optional[str] = Field(None, alias="errMsg", description="错误信息") + created_at: Optional[datetime] = Field(None, alias="createdAt", description="创建时间") + updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间") + + class Config: + populate_by_name = True # 允许使用 snake_case 或 camelCase + json_schema_extra = { + "example": { + "id": "rag-file-uuid-123", + "knowledgeBaseId": "kb-uuid-123", + "fileName": "document.pdf", + "fileId": "file-uuid-456", + "chunkCount": 15, + "metadata": {"size": 1024, "format": "pdf"}, + "status": "PROCESSED", + "createdAt": "2025-01-01T00:00:00" + } + } + + +class RagChunkResp(BaseModel): + """RAG 
分块响应 + + 对应 Milvus 查询结果 + """ + id: str = Field(..., description="分块ID") + text: str = Field(..., description="分块文本") + metadata: dict = Field(..., description="元数据") + score: float = Field(..., description="相似度分数") + distance: Optional[float] = Field(None, description="距离(可选)") + + class Config: + populate_by_name = True # 允许使用 snake_case 或 camelCase + json_schema_extra = { + "example": { + "id": "chunk-uuid-123", + "text": "这是文档的一个分块内容...", + "metadata": { + "fileName": "document.pdf", + "chunkIndex": 0 + }, + "score": 0.95 + } + } + + +class SearchResult(BaseModel): + """检索结果 + + 对应 Java: io.milvus.v2.service.vector.response.SearchResp.SearchResult + """ + id: str = Field(..., description="结果ID") + score: float = Field(..., description="相似度分数") + text: str = Field(..., description="文本内容") + metadata: dict = Field(default_factory=dict, description="元数据") + + class Config: + json_schema_extra = { + "example": { + "id": "chunk-uuid-123", + "score": 0.95, + "text": "相关文档内容...", + "metadata": {"file_name": "doc.pdf"} + } + } + + +class PagedResponse(BaseModel): + """分页响应 + + 对应 Java: com.datamate.common.interfaces.PagedResponse + """ + items: List[Any] = Field(..., description="数据列表") + total: int = Field(..., description="总记录数") + page: int = Field(..., description="当前页码") + page_size: int = Field(alias="pageSize", description="每页数量") + total_pages: int = Field(alias="totalPages", description="总页数") + + @classmethod + def create(cls, items: List[Any], total: int, page: int, page_size: int): + """创建分页响应 + + Args: + items: 数据列表 + total: 总记录数 + page: 当前页码 + page_size: 每页数量 + + Returns: + PagedResponse 实例 + """ + total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0 + # 使用内部字段名(snake_case) + return cls( + items=items, + total=total, + page=page, + page_size=page_size, + total_pages=total_pages + ) + + class Config: + populate_by_name = True # 允许使用 snake_case 或 camelCase + json_schema_extra = { + "example": { + "items": [], + "total": 100, + "page": 1, + "pageSize": 10, + "totalPages": 10 + } + } diff --git a/runtime/datamate-python/app/module/rag/service/etl_service.py b/runtime/datamate-python/app/module/rag/service/etl_service.py new file mode 100644 index 000000000..bd29e347f --- /dev/null +++ b/runtime/datamate-python/app/module/rag/service/etl_service.py @@ -0,0 +1,264 @@ +""" +ETL 服务 + +实现文件的异步 ETL 处理流程,使用 LangChain Milvus 向量存储(密集向量 + BM25 全文检索)。 +对应 Java: com.datamate.rag.indexer.infra.event.RagEtlService +""" +import uuid +from pathlib import Path +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus +from app.module.rag.schema.request import AddFilesReq +from app.module.rag.repository import RagFileRepository, KnowledgeBaseRepository +from app.module.rag.infra.pipeline import ingest_file_to_chunks +from app.module.rag.infra.embeddings import EmbeddingFactory +from app.module.rag.infra.milvus.factory import VectorStoreFactory +from app.module.rag.infra.milvus.vectorstore import ( + chunks_to_langchain_documents, + create_java_compatible_collection, + get_vector_dimension, +) +from app.module.system.service.common_service import get_model_by_id +from app.module.rag.infra.task.worker_pool import WorkerPool +from app.core.config import settings +from app.core.exception import BusinessError, ErrorCodes + +import logging + +logger = logging.getLogger(__name__) + + +class ETLService: + """RAG ETL 服务类 + + 对应 Java: com.datamate.rag.indexer.infra.event.RagEtlService + + 替代 Java 方案: + - Java: 
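+    (Note: there is no transactional event bus in this port; KnowledgeBaseService.add_files
+    commits the transaction first and then calls process_files() explicitly.)
+    - Java: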
@TransactionalEventListener + 虚拟线程 + 信号量 + - Python: asyncio + WorkerPool(信号量控制) + + 功能: + 1. 解析文档(从共享文件系统读取) + 2. 分块 + 3. 生成嵌入向量 + 4. 存储到 Milvus + 5. 更新文件状态 + """ + + def __init__(self, db: AsyncSession): + """初始化服务 + + Args: + db: 数据库异步 session + """ + self.db = db + self.file_repo = RagFileRepository(db) + self.kb_repo = KnowledgeBaseRepository(db) + self.worker_pool = WorkerPool(max_workers=10) + + async def process_files( + self, + knowledge_base: KnowledgeBase, + request: AddFilesReq + ) -> None: + """处理文件的入口方法(在事务提交后调用) + + 对应 Java 的 @TransactionalEventListener(phase = AFTER_COMMIT) + + Args: + knowledge_base: 知识库实体 + request: 添加文件请求 + """ + # 获取待处理的文件 + files = await self.file_repo.get_unprocessed_files(knowledge_base.id) + + if not files: + logger.info(f"知识库 {knowledge_base.name} 没有待处理的文件") + return + + logger.info(f"开始处理 {len(files)} 个文件,知识库: {knowledge_base.name}") + + # 并发处理所有文件(信号量控制并发数) + import asyncio + tasks = [ + self.worker_pool.submit( + self._process_single_file, + file, knowledge_base, request + ) + for file in files + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 统计处理结果 + success_count = sum(1 for r in results if not isinstance(r, Exception)) + failed_count = len(results) - success_count + + logger.info( + f"文件处理完成,成功: {success_count}, 失败: {failed_count}" + ) + + async def _process_single_file( + self, + rag_file: RagFile, + knowledge_base: KnowledgeBase, + request: AddFilesReq + ) -> None: + """处理单个文件的 ETL 流程 + + 步骤: + 1. 解析文档(从共享文件系统读取) + 2. 分块 + 3. 生成嵌入向量 + 4. 存储到 Milvus + 5. 更新文件状态 + + Args: + rag_file: RAG 文件实体 + knowledge_base: 知识库实体 + request: 添加文件请求 + """ + try: + # 1. 更新状态为处理中 + await self.file_repo.update_status(rag_file.id, FileStatus.PROCESSING) + + # 2. 从 metadata 中获取文件路径和原始文件ID + file_path = rag_file.file_metadata.get("file_path") if rag_file.file_metadata else None + original_file_id = rag_file.file_id # t_dm_dataset_files.id + dataset_id = rag_file.file_metadata.get("dataset_id") if rag_file.file_metadata else None + + # 2.1 验证文件路径 + if not file_path: + error_msg = f"文件路径未设置,file_metadata={rag_file.file_metadata}" + logger.error(error_msg) + await self.file_repo.update_status( + rag_file.id, + FileStatus.PROCESS_FAILED, + err_msg=error_msg + ) + return + + # 2.2 确保使用绝对路径 + import os + file_path = os.path.abspath(file_path) + + # 2.3 验证文件存在 + if not Path(file_path).exists(): + error_msg = f"文件不存在: {file_path}" + logger.error(error_msg) + await self.file_repo.update_status( + rag_file.id, + FileStatus.PROCESS_FAILED, + err_msg=error_msg + ) + return + + # 3. 准备完整的 metadata + file_extension = Path(file_path).suffix + base_metadata = { + "rag_file_id": rag_file.id, + "original_file_id": original_file_id, + "dataset_id": dataset_id, + "file_name": rag_file.file_name, + "file_extension": file_extension, + "knowledge_base_id": knowledge_base.id, + "file_path": file_path, + } + + # 4. 
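+            # Concurrency note: process_files() above fans out over WorkerPool +
+            # asyncio.gather(..., return_exceptions=True), so one failing file
+            # cannot cancel its siblings. A minimal sketch of the same pattern,
+            # assuming WorkerPool wraps an asyncio.Semaphore:
+            #   sem = asyncio.Semaphore(10)
+            #   async def bounded(coro):
+            #       async with sem:
+            #           return await coro
+            #   await asyncio.gather(*(bounded(self._process_single_file(f, kb, req))
+            #                          for f in files), return_exceptions=True)
+            # 4.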
加载并分块(复用 ingest pipeline),传递完整的 metadata + try: + chunks = await ingest_file_to_chunks( + file_path, + process_type=request.process_type, + chunk_size=request.chunk_size, + overlap_size=request.overlap_size, + delimiter=request.delimiter, + **base_metadata + ) + except Exception as e: + error_msg = f"文档解析或分块失败: {str(e)}" + logger.exception(f"文件 {rag_file.file_name} 解析失败: {e}") + await self.file_repo.update_status( + rag_file.id, + FileStatus.PROCESS_FAILED, + err_msg=error_msg + ) + return + + if not chunks: + logger.warning(f"文件 {rag_file.file_name} 未生成任何分块") + await self.file_repo.update_status( + rag_file.id, + FileStatus.PROCESS_FAILED, + err_msg="文档解析后未生成任何分块" + ) + return + + logger.info(f"文件 {rag_file.file_name} 分块完成,共 {len(chunks)} 个分块") + + # 5. 写入 LangChain Milvus 向量存储(自动嵌入 + BM25 全文检索) + try: + embedding_entity = await get_model_by_id(self.db, knowledge_base.embedding_model) + if not embedding_entity: + raise ValueError(f"嵌入模型不存在: {knowledge_base.embedding_model}") + + # 5.1 获取向量维度并创建 Java 兼容的集合 + try: + dimension = get_vector_dimension( + embedding_model=embedding_entity.model_name, + base_url=getattr(embedding_entity, "base_url", None), + api_key=getattr(embedding_entity, "api_key", None), + ) + create_java_compatible_collection( + collection_name=knowledge_base.name, + dimension=dimension + ) + except BusinessError as e: + logger.warning("创建或检查集合失败: %s", e) + # 如果集合已存在,继续处理 + if "已存在" not in str(e): + raise + + embedding = EmbeddingFactory.create_embeddings( + model_name=embedding_entity.model_name, + base_url=getattr(embedding_entity, "base_url", None), + api_key=getattr(embedding_entity, "api_key", None), + ) + vectorstore = VectorStoreFactory.create( + collection_name=knowledge_base.name, + embedding=embedding, + ) + for c in chunks: + # 确保 metadata 包含所有必需字段 + for key, value in base_metadata.items(): + if key not in c.metadata: + c.metadata[key] = value + ids = [str(uuid.uuid4()) for _ in chunks] + documents, doc_ids = chunks_to_langchain_documents(chunks, ids=ids) + vectorstore.add_documents(documents=documents, ids=doc_ids) + + except Exception as e: + error_msg = f"向量化或存储到 Milvus 失败: {str(e)}" + logger.exception(f"文件 {rag_file.file_name} 向量化失败: {e}") + await self.file_repo.update_status( + rag_file.id, + FileStatus.PROCESS_FAILED, + err_msg=error_msg + ) + return + + # 6. 
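+            # NOTE: vectorstore.add_documents() above runs its embedding calls
+            # synchronously inside this coroutine. If that blocks the event loop
+            # in practice, one option (a sketch, not part of this patch) is to
+            # offload it to a worker thread:
+            #   await asyncio.to_thread(
+            #       vectorstore.add_documents, documents=documents, ids=doc_ids)
+            # 6.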
更新文件状态为成功 + await self.file_repo.update_chunk_count(rag_file.id, len(chunks)) + await self.file_repo.update_status(rag_file.id, FileStatus.PROCESSED) + + logger.info(f"文件 {rag_file.file_name} ETL 处理完成") + + except Exception as e: + logger.exception(f"文件 {rag_file.file_name} 处理失败: {e}") + await self.file_repo.update_status( + rag_file.id, + FileStatus.PROCESS_FAILED, + err_msg=str(e) + ) + # 不抛出异常,避免影响其他文件的处理 diff --git a/runtime/datamate-python/app/module/rag/service/file_service.py b/runtime/datamate-python/app/module/rag/service/file_service.py new file mode 100644 index 000000000..7f8e8de4c --- /dev/null +++ b/runtime/datamate-python/app/module/rag/service/file_service.py @@ -0,0 +1,181 @@ +""" +文件管理服务 + +实现文件相关的业务逻辑 +""" +import uuid +from typing import List, Tuple +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models.knowledge_gen import RagFile, FileStatus +from app.module.rag.schema.request import AddFilesReq +from app.module.rag.repository import RagFileRepository, KnowledgeBaseRepository +from app.module.rag.infra.milvus.vectorstore import delete_chunks_by_rag_file_ids +from app.core.exception import BusinessError, ErrorCodes + +import logging + +logger = logging.getLogger(__name__) + + +class FileService: + """文件管理服务类 + + 功能: + 1. 添加文件到知识库 + 2. 删除文件 + 3. 查询文件 + """ + + def __init__(self, db: AsyncSession): + """初始化服务 + + Args: + db: 数据库异步 session + """ + self.db = db + self.file_repo = RagFileRepository(db) + self.kb_repo = KnowledgeBaseRepository(db) + + async def add_files(self, request: AddFilesReq) -> Tuple[List[RagFile], List[str]]: + """添加文件到知识库 + + Args: + request: 添加文件请求 + + Returns: + (创建的 RAG 文件列表, 跳过的文件ID列表) + + Raises: + BusinessError: 知识库不存在 + """ + # 验证知识库存在 + knowledge_base = await self.kb_repo.get_by_id(request.knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + # 验证文件列表不为空 + if not request.files or len(request.files) == 0: + raise BusinessError(ErrorCodes.BAD_REQUEST, "文件列表不能为空") + + # 导入 dataset 服务 + from app.module.dataset.service.service import Service as DatasetService + + dataset_service = DatasetService(self.db) + + # 验证文件存在并创建 RAG 文件记录 + rag_files = [] + skipped_file_ids = [] + + for file_info in request.files: + try: + # 通过 dataset 服务验证文件是否存在 + file_path = await dataset_service.get_file_download_url( + dataset_id=file_info.dataset_id, + file_id=file_info.id + ) + + # 跳过不存在的文件 + if not file_path: + logger.warning( + f"文件不存在,跳过处理: dataset_id={file_info.dataset_id}, " + f"file_id={file_info.id}, file_name={file_info.file_name}" + ) + skipped_file_ids.append(file_info.id) + continue + + # 创建 RAG 文件记录,存储 dataset_id 和 file_path 到 metadata + rag_file = RagFile( + id=str(uuid.uuid4()), + knowledge_base_id=request.knowledge_base_id, + file_name=file_info.file_name, + file_id=file_info.id, + chunk_count=None, + file_metadata={ + "process_type": request.process_type.value, + "dataset_id": file_info.dataset_id, + "file_path": file_path + }, + status=FileStatus.UNPROCESSED, + err_msg=None + ) + rag_files.append(rag_file) + + except Exception as e: + logger.error( + f"处理文件信息失败: dataset_id={file_info.dataset_id}, " + f"file_id={file_info.id}, error={e}" + ) + skipped_file_ids.append(file_info.id) + continue + + # 批量保存 + if rag_files: + await self.file_repo.batch_create(rag_files) + logger.info(f"成功添加 {len(rag_files)} 个文件到知识库: {knowledge_base.name}") + + if skipped_file_ids: + logger.warning(f"跳过 {len(skipped_file_ids)} 个文件: {skipped_file_ids}") + + return rag_files, skipped_file_ids + + async 
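+    # Usage sketch (assuming an AsyncSession obtained from AsyncSessionLocal):
+    #   async with AsyncSessionLocal() as db:
+    #       created, skipped = await FileService(db).add_files(req)
+    # add_files only stages RagFile rows as UNPROCESSED; chunking and embedding
+    # happen later in ETLService.process_files, after the caller commits.
+    async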
def delete_files( + self, + knowledge_base_id: str, + file_ids: List[str] + ) -> None: + """删除文件 + + Args: + knowledge_base_id: 知识库 ID + file_ids: 文件 ID 列表 + + Raises: + BusinessError: 知识库不存在 + """ + # 验证文件列表不为空 + if not file_ids or len(file_ids) == 0: + raise BusinessError(ErrorCodes.BAD_REQUEST, "文件ID列表不能为空") + + # 获取知识库 + knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + # 获取文件列表(需要删除 Milvus 数据) + rag_files = [] + for file_id in file_ids: + try: + rag_file = await self.file_repo.get_by_id(file_id) + if rag_file: + rag_files.append(rag_file) + else: + logger.warning(f"文件不存在,跳过删除: {file_id}") + except Exception as e: + logger.error(f"查询文件失败: {file_id}, error={e}") + continue + + # 删除 Milvus 中该文件对应的分块数据 + if rag_files: + try: + delete_chunks_by_rag_file_ids( + knowledge_base.name, + [r.id for r in rag_files], + ) + except Exception as e: + logger.error("删除 Milvus 数据失败: %s", e) + # 继续删除数据库记录 + else: + logger.warning("没有找到有效的文件,跳过 Milvus 数据删除") + + # 删除数据库记录 + deleted_count = 0 + for file_id in file_ids: + try: + await self.file_repo.delete(file_id) + deleted_count += 1 + except Exception as e: + logger.error(f"删除数据库记录失败: {file_id}, error={e}") + continue + + logger.info(f"成功删除 {deleted_count} 个文件") diff --git a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py new file mode 100644 index 000000000..efcd8c814 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py @@ -0,0 +1,539 @@ +""" +知识库业务服务 + +实现知识库的 CRUD 操作和业务逻辑 +对应 Java: com.datamate.rag.indexer.application.KnowledgeBaseService +""" +import logging +import uuid +from typing import List + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.exception import BusinessError, ErrorCodes +from app.core.config import settings +from app.module.rag.infra.milvus.vectorstore import drop_collection, rename_collection +from app.module.rag.infra.embeddings import EmbeddingFactory +from app.db.models.knowledge_gen import KnowledgeBase +from app.module.rag.repository import KnowledgeBaseRepository, RagFileRepository +from app.module.rag.schema.request import ( + KnowledgeBaseCreateReq, + KnowledgeBaseUpdateReq, + KnowledgeBaseQueryReq, + AddFilesReq, + DeleteFilesReq, + RagFileReq, + RetrieveReq, + PagingQuery, +) +from app.module.rag.schema.response import KnowledgeBaseResp, PagedResponse, RagChunkResp +from app.module.rag.service.etl_service import ETLService +from app.module.rag.service.file_service import FileService + +logger = logging.getLogger(__name__) + + +class KnowledgeBaseService: + """知识库业务服务类 + + 对应 Java: com.datamate.rag.indexer.application.KnowledgeBaseService + + 功能: + 1. 知识库 CRUD 操作 + 2. 文件管理 + 3. 
检索功能 + """ + + def __init__(self, db: AsyncSession): + """初始化服务 + + Args: + db: 数据库异步 session + """ + self.db = db + self.kb_repo = KnowledgeBaseRepository(db) + self.file_repo = RagFileRepository(db) + self.file_service = FileService(db) + self.etl_service = ETLService(db) + + async def create(self, request: KnowledgeBaseCreateReq) -> str: + """创建知识库 + + 对应 Java: create 方法 + + Args: + request: 创建请求 + + Returns: + 知识库 ID + + Raises: + BusinessError: 知识库名称已存在 + """ + # 创建知识库实体 + knowledge_base = KnowledgeBase( + id=str(uuid.uuid4()), + name=request.name, + description=request.description, + type=request.type, + embedding_model=request.embedding_model, + chat_model=request.chat_model + ) + + # 保存到数据库 + knowledge_base = await self.kb_repo.create(knowledge_base) + + # Milvus 集合由 LangChain Milvus 在首次 ETL add_documents 时自动创建(含 BM25 全文检索) + logger.info(f"成功创建知识库: {request.name}") + + # 提交事务 + await self.db.commit() + + return knowledge_base.id + + async def update(self, knowledge_base_id: str, request: KnowledgeBaseUpdateReq) -> None: + """更新知识库 + + 对应 Java: update 方法 + + Args: + knowledge_base_id: 知识库 ID + request: 更新请求 + + Raises: + BusinessError: 知识库不存在 + """ + # 获取现有知识库 + knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + old_name = knowledge_base.name + + # 更新字段 + knowledge_base.name = request.name + knowledge_base.description = request.description + + # 更新数据库 + await self.kb_repo.update(knowledge_base) + + # 如果名称变更,重命名 Milvus 集合 + if old_name != request.name: + try: + rename_collection(old_name, request.name) + except BusinessError: + await self.db.rollback() + raise + + await self.db.commit() + + async def delete(self, knowledge_base_id: str) -> None: + """删除知识库 + + 对应 Java: delete 方法 + + Args: + knowledge_base_id: 知识库 ID + + Raises: + BusinessError: 知识库不存在 + """ + # 获取知识库 + knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + # 删除所有文件 + await self.file_repo.delete_by_knowledge_base(knowledge_base_id) + + # 删除知识库 + await self.kb_repo.delete(knowledge_base_id) + + # 删除 Milvus 集合 + try: + drop_collection(knowledge_base.name) + except Exception as e: + logger.error("删除 Milvus 集合失败: %s", e) + + await self.db.commit() + + async def get_by_id(self, knowledge_base_id: str) -> KnowledgeBaseResp: + """获取知识库详情 + + 对应 Java: getById 方法 + + Args: + knowledge_base_id: 知识库 ID + + Returns: + 知识库响应对象 + + Raises: + BusinessError: 知识库不存在 + """ + knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + # 统计文件数量 + file_count = await self.file_repo.count_by_knowledge_base(knowledge_base_id) + chunk_count = await self.file_repo.count_chunks_by_knowledge_base(knowledge_base_id) + + # 构建响应 + response = KnowledgeBaseResp( + id=knowledge_base.id, + name=knowledge_base.name, + description=knowledge_base.description, + type=knowledge_base.type, + embedding_model=knowledge_base.embedding_model, + chat_model=knowledge_base.chat_model, + file_count=file_count, + chunk_count=chunk_count, + created_at=knowledge_base.created_at, + updated_at=knowledge_base.updated_at + ) + + return response + + async def list(self, request: KnowledgeBaseQueryReq) -> PagedResponse: + """分页查询知识库列表 + + 对应 Java: list 方法 + + Args: + request: 查询请求 + + Returns: + 分页响应 + """ + items, total = await self.kb_repo.list( + keyword=request.keyword, + 
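            # NOTE: the loop below issues two count queries per knowledge base
+            # (a classic N+1 pattern). If listing becomes slow, one option (a
+            # sketch, not in this patch) is a single grouped aggregate, e.g.:
+            #   select(RagFile.knowledge_base_id, func.count(RagFile.id),
+            #          func.coalesce(func.sum(RagFile.chunk_count), 0))
+            #       .group_by(RagFile.knowledge_base_id)
+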
rag_type=request.type, + page=request.page, + page_size=request.page_size + ) + + # 转换为响应对象 + responses = [] + for item in items: + file_count = await self.file_repo.count_by_knowledge_base(item.id) + chunk_count = await self.file_repo.count_chunks_by_knowledge_base(item.id) + + response = KnowledgeBaseResp( + id=item.id, + name=item.name, + description=item.description, + type=item.type, + embedding_model=item.embedding_model, + chat_model=item.chat_model, + file_count=file_count, + chunk_count=chunk_count, + created_at=item.created_at, + updated_at=item.updated_at + ) + responses.append(response) + + return PagedResponse.create( + items=responses, + total=total, + page=request.page, + page_size=request.page_size + ) + + async def add_files(self, request: AddFilesReq) -> dict: + """添加文件到知识库 + + 对应 Java: addFiles 方法 + + Args: + request: 添加文件请求 + + Returns: + 包含成功和跳过文件数量的字典 + + Raises: + BusinessError: 知识库不存在 + """ + # 验证知识库存在 + knowledge_base = await self.kb_repo.get_by_id(request.knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + # 添加文件记录 + rag_files, skipped_file_ids = await self.file_service.add_files(request) + + # 提交事务后触发 ETL 处理 + await self.db.commit() + + # 异步处理文件(在事务提交后) + if rag_files: + await self.etl_service.process_files(knowledge_base, request) + + return { + "success_count": len(rag_files), + "skipped_count": len(skipped_file_ids), + "skipped_file_ids": skipped_file_ids + } + + async def list_files( + self, + knowledge_base_id: str, + request: RagFileReq + ) -> PagedResponse: + """获取知识库文件列表 + + 对应 Java: listFiles 方法 + + Args: + knowledge_base_id: 知识库 ID + request: 查询请求 + + Returns: + 分页响应 + + Raises: + BusinessError: 知识库不存在 + """ + # 验证知识库存在 + if not await self.kb_repo.get_by_id(knowledge_base_id): + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + items, total = await self.file_repo.list_by_knowledge_base( + knowledge_base_id=knowledge_base_id, + keyword=request.keyword, + status=request.status, + page=request.page, + page_size=request.page_size + ) + + return PagedResponse.create( + items=items, + total=total, + page=request.page, + page_size=request.page_size + ) + + async def delete_files(self, knowledge_base_id: str, request: DeleteFilesReq) -> None: + """删除知识库文件 + + 对应 Java: deleteFiles 方法 + + Args: + knowledge_base_id: 知识库 ID + request: 删除文件请求 + + Raises: + BusinessError: 知识库不存在 + """ + # 验证知识库存在 + knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + # 删除文件(包括 Milvus 数据) + await self.file_service.delete_files(knowledge_base_id, request.file_ids) + + await self.db.commit() + + async def retrieve(self, request: RetrieveReq) -> List[dict]: + """检索知识库内容 + + 对应 Java: retrieve 方法 + + 使用混合检索(向量 + BM25) + + Args: + request: 检索请求 + + Returns: + 检索结果列表 + + Raises: + BusinessError: 知识库不存在 + """ + import asyncio + + # 1. 验证所有知识库存在 + knowledge_bases = [] + for kb_id in request.knowledge_base_ids: + kb = await self.kb_repo.get_by_id(kb_id) + if not kb: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + knowledge_bases.append(kb) + + # 2. 获取嵌入模型(使用第一个知识库的配置) + embedding_entity = await get_model_by_id(self.db, knowledge_bases[0].embedding_model) + if not embedding_entity: + raise BusinessError(ErrorCodes.RAG_MODEL_NOT_FOUND) + + # 3. 
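+        # NOTE: get_model_by_id (used in step 2 above) must be imported in this
+        # module, e.g. at the top of the file as etl_service.py does:
+        #   from app.module.system.service.common_service import get_model_by_id
+        # Without that import, step 2 raises NameError at runtime.
+        # 3.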
创建嵌入模型实例 + embedding = EmbeddingFactory.create_embeddings( + model_name=embedding_entity.model_name, + base_url=getattr(embedding_entity, "base_url", None), + api_key=getattr(embedding_entity, "api_key", None), + ) + + # 4. 生成查询向量 + try: + query_vector = await asyncio.to_thread(embedding.embed_query, request.query) + except Exception as e: + logger.error(f"Failed to embed query: {e}") + raise BusinessError(ErrorCodes.RAG_EMBEDDING_FAILED, f"查询向量化失败: {str(e)}") from e + + # 5. 执行混合检索(向量 + BM25) + from pymilvus import MilvusClient + + all_results = [] + + try: + client = MilvusClient(uri=settings.milvus_uri) + + for kb in knowledge_bases: + try: + # 检查集合是否存在 + if not client.has_collection(kb.name): + logger.warning(f"Collection {kb.name} does not exist, skipping") + continue + + # 混合检索:密集向量 + 稀疏向量(BM25) + search_results = client.hybrid_search( + collection_name=kb.name, + data=[ + { + "vector": query_vector, + "sparse": request.query + } + ], + anns_field=["vector", "sparse"], + limit=request.top_k, + ranker={ + "type": "weighted", + "weights": [0.1, 0.9] # 10% 向量相似度,90% BM25 关键词匹配 + } + ) + + # 提取结果 + if search_results and len(search_results) > 0: + for result in search_results[0]: + result["knowledge_base_id"] = kb.id + result["knowledge_base_name"] = kb.name + all_results.append(result) + + except Exception as e: + logger.error(f"Hybrid search failed for kb {kb.name}: {e}") + # 继续处理其他知识库 + continue + + except Exception as e: + logger.error(f"Milvus client initialization or search failed: {e}") + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"检索失败: {str(e)}") from e + + # 6. 按分数降序排序 + all_results.sort(key=lambda x: x.get("distance", 0), reverse=True) + + # 7. 应用阈值过滤 + if request.threshold is not None: + all_results = [r for r in all_results if r.get("distance", 0) >= request.threshold] + + # 8. 
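+        # NOTE: the hybrid_search() call above may not match pymilvus
+        # MilvusClient's documented signature, which takes AnnSearchRequest
+        # objects plus a ranker, and needs output_fields for the entity fields
+        # read in step 8. A hedged sketch of the documented form:
+        #   from pymilvus import AnnSearchRequest, WeightedRanker
+        #   reqs = [
+        #       AnnSearchRequest(data=[query_vector], anns_field="vector",
+        #                        param={}, limit=request.top_k),
+        #       AnnSearchRequest(data=[request.query], anns_field="sparse",
+        #                        param={}, limit=request.top_k),
+        #   ]
+        #   client.hybrid_search(collection_name=kb.name, reqs=reqs,
+        #                        ranker=WeightedRanker(0.1, 0.9),
+        #                        limit=request.top_k,
+        #                        output_fields=["text", "metadata"])
+        # 8.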
格式化返回结果 + formatted_results = [] + for r in all_results: + entity = r.get("entity", {}) + formatted_results.append({ + "id": entity.get("id", ""), + "text": entity.get("text", ""), + "metadata": entity.get("metadata", {}), + "score": r.get("distance", 0), + "knowledgeBaseId": r.get("knowledge_base_id", ""), + "knowledgeBaseName": r.get("knowledge_base_name", "") + }) + + logger.info(f"Retrieve completed: query='{request.query}' results={len(formatted_results)}") + return formatted_results + + async def get_chunks( + self, + knowledge_base_id: str, + rag_file_id: str, + paging_query: PagingQuery + ) -> PagedResponse: + """获取指定 RAG 文件的分块列表 + + 对应 Java: getChunks 方法 + + 从 Milvus 查询指定 rag_file_id 的分块,支持分页。 + + Args: + knowledge_base_id: 知识库 ID + rag_file_id: RAG 文件 ID + paging_query: 分页参数 + + Returns: + 分块列表(分页) + """ + # 验证知识库存在 + knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + # 验证文件存在 + rag_file = await self.file_repo.get_by_id(rag_file_id) + if not rag_file: + raise BusinessError(ErrorCodes.RAG_FILE_NOT_FOUND) + + # 使用 MilvusClient 查询指定文件的分块 + from pymilvus import MilvusClient + + from app.core.exception import BusinessError as BE, ErrorCodes as EC + + try: + conn_args = settings.milvus_uri + token = getattr(settings, "milvus_token", None) + client = MilvusClient(uri=conn_args, token=token) + + # 查询总数 + count_filter_expr = f'metadata["rag_file_id"] == "{rag_file_id}"' + count_res = client.query( + collection_name=knowledge_base.name, + filter=count_filter_expr, + output_fields=["id"] + ) + total = len(count_res) + + # 查询分页数据 + offset = (paging_query.page - 1) * paging_query.size + filter_expr = f'metadata["rag_file_id"] == "{rag_file_id}"' + results = client.query( + collection_name=knowledge_base.name, + filter=filter_expr, + output_fields=["id", "text", "metadata"], + limit=paging_query.size, + offset=offset + ) + + # 转换为 RagChunkResp + chunks = [] + for item in results: + chunks.append(RagChunkResp( + id=item.get("id", ""), + text=item.get("text", ""), + metadata=item.get("metadata", {}), + score=0.0 # 非相似度查询,默认分数为 0 + )) + + logger.info( + "查询文件分块成功: kb=%s file=%s total=%d page=%d size=%d", + knowledge_base_id, rag_file_id, total, paging_query.page, paging_query.size + ) + + return PagedResponse.create( + items=chunks, + total=total, + page=paging_query.page, + page_size=paging_query.size + ) + + except Exception as e: + logger.error("查询文件分块失败: kb=%s file=%s error=%s", knowledge_base_id, rag_file_id, e) + raise BE(EC.RAG_MILVUS_ERROR, f"查询文件分块失败: {str(e)}") from e + diff --git a/runtime/datamate-python/app/module/rag/service/rag_service.py b/runtime/datamate-python/app/module/rag/service/rag_service.py index 1af9e49ba..9584e406a 100644 --- a/runtime/datamate-python/app/module/rag/service/rag_service.py +++ b/runtime/datamate-python/app/module/rag/service/rag_service.py @@ -3,13 +3,17 @@ from typing import Optional, Sequence from fastapi import Depends +from langchain_core.prompts import ChatPromptTemplate +from langchain_openai import ChatOpenAI from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger from app.db.models.dataset_management import DatasetFiles -from app.db.models.knowledge_gen import RagFile, RagKnowledgeBase +from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus from app.db.session import get_db, AsyncSessionLocal +from app.module.rag.infra.embeddings import EmbeddingFactory +from 
app.module.rag.infra.milvus.factory import VectorStoreFactory from app.module.shared.common.document_loaders import load_documents from .graph_rag import ( DEFAULT_WORKING_DIR, @@ -22,6 +26,12 @@ logger = get_logger(__name__) +# DOCUMENT 类型 RAG 使用 LangChain 检索链 +RAG_DOCUMENT_PROMPT = ChatPromptTemplate.from_messages([ + ("system", "根据以下上下文回答问题。如果上下文中没有相关信息,请说明。\n\n上下文:\n{context}"), + ("human", "{input}"), +]) + class RAGService: def __init__( @@ -37,7 +47,7 @@ async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFil result = await self.db.execute( select(RagFile).where( RagFile.knowledge_base_id == knowledge_base_id, - RagFile.status != "PROCESSED", + RagFile.status != FileStatus.PROCESSED, ) ) return result.scalars().all() @@ -90,7 +100,7 @@ async def _process_pending_files(self, knowledge_base_id: str): async def _process_single_file(self, rag_file: RagFile): try: - await self._mark_file_status(rag_file, "PROCESSING") + await self._mark_file_status(rag_file, FileStatus.PROCESSING) dataset_file = await self._get_dataset_file(rag_file.file_id) documents = load_documents(dataset_file.file_path) for doc in documents: @@ -98,9 +108,9 @@ async def _process_single_file(self, rag_file: RagFile): await self.rag.ainsert(input=doc.page_content, file_paths=[dataset_file.file_path]) except Exception: # noqa: BLE001 logger.exception("Failed to process rag file %s", rag_file.id) - await self._mark_file_status(rag_file, "PROCESS_FAILED") + await self._mark_file_status(rag_file, FileStatus.PROCESS_FAILED) return - await self._mark_file_status(rag_file, "PROCESSED") + await self._mark_file_status(rag_file, FileStatus.PROCESSED) async def _get_dataset_file(self, file_id: str) -> DatasetFiles: result = await self.db.execute( @@ -111,7 +121,7 @@ async def _get_dataset_file(self, file_id: str) -> DatasetFiles: raise ValueError(f"Dataset file with ID {file_id} not found.") return dataset_file - async def _mark_file_status(self, rag_file: RagFile, status: str): + async def _mark_file_status(self, rag_file: RagFile, status: FileStatus): rag_file.status = status self.db.add(rag_file) await self.db.commit() @@ -119,7 +129,7 @@ async def _mark_file_status(self, rag_file: RagFile, status: str): async def _get_knowledge_base(self, knowledge_base_id: str): result = await self.db.execute( - select(RagKnowledgeBase).where(RagKnowledgeBase.id == knowledge_base_id) + select(KnowledgeBase).where(KnowledgeBase.id == knowledge_base_id) ) knowledge_base = result.scalars().first() if not knowledge_base: @@ -135,6 +145,40 @@ async def _get_models(self, model_id: Optional[str]): return models async def query_rag(self, query: str, knowledge_base_id: str) -> str: - if not self.rag: - await self.init_graph_rag(knowledge_base_id) - return await self.rag.get_knowledge_graph(query) + kb = await self._get_knowledge_base(knowledge_base_id) + if kb.type and str(kb.type).upper() == "GRAPH": + if not self.rag: + await self.init_graph_rag(knowledge_base_id) + return await self.rag.get_knowledge_graph(query) + # DOCUMENT 类型:LangChain Milvus 检索 + RAG + return await self._query_document_rag(query, kb) + + async def _query_document_rag(self, query: str, kb: KnowledgeBase) -> str: + """基于 Milvus 向量存储的检索与生成(混合检索 + LLM)。""" + from langchain_classic.chains.combine_documents import create_stuff_documents_chain + from langchain_classic.chains.retrieval import create_retrieval_chain + + embedding_entity = await self._get_models(kb.embedding_model) + embedding = EmbeddingFactory.create_embeddings( + 
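        # NOTE: the retriever below requests search_type="hybrid", which is not
+        # one of langchain-core's standard VectorStoreRetriever search types
+        # ("similarity", "similarity_score_threshold", "mmr"); unless
+        # langchain-milvus overrides as_retriever, that value may be rejected
+        # with a ValueError. With a BM25-enabled collection, the default
+        # similarity retriever should already perform the hybrid ranking.
+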
model_name=embedding_entity.model_name, + base_url=getattr(embedding_entity, "base_url", None), + api_key=getattr(embedding_entity, "api_key", None), + ) + vectorstore = VectorStoreFactory.create( + collection_name=kb.name, + embedding=embedding, + ) + retriever = vectorstore.as_retriever( + search_type="hybrid", + search_kwargs={"k": 5}, + ) + chat_model_entity = await self._get_models(kb.chat_model) + llm = ChatOpenAI( + model=chat_model_entity.model_name, + base_url=getattr(chat_model_entity, "base_url", None) or None, + api_key=getattr(chat_model_entity, "api_key", None) or None, + ) + combine_chain = create_stuff_documents_chain(llm, RAG_DOCUMENT_PROMPT) + chain = create_retrieval_chain(retriever, combine_chain) + result = await chain.ainvoke({"input": query}) + return result.get("answer", "") diff --git a/runtime/datamate-python/app/module/shared/llm/factory.py b/runtime/datamate-python/app/module/shared/llm/factory.py index 4713600df..adeb428da 100644 --- a/runtime/datamate-python/app/module/shared/llm/factory.py +++ b/runtime/datamate-python/app/module/shared/llm/factory.py @@ -5,6 +5,7 @@ """ from typing import Literal +import httpx from langchain_core.language_models import BaseChatModel from langchain_core.embeddings import Embeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings @@ -14,6 +15,8 @@ class LLMFactory: """基于 LangChain 的 Chat / Embedding 工厂,面向 OpenAI 兼容 API。""" + custom_http_client = httpx.Client(verify=False) + @staticmethod def create_chat( model_name: str, @@ -25,6 +28,7 @@ def create_chat( model=model_name, base_url=base_url or None, api_key=SecretStr(api_key or ""), + http_client=LLMFactory.custom_http_client, ) @staticmethod diff --git a/runtime/datamate-python/poetry.lock b/runtime/datamate-python/poetry.lock index 5c4db4a1a..30b468394 100644 --- a/runtime/datamate-python/poetry.lock +++ b/runtime/datamate-python/poetry.lock @@ -426,6 +426,18 @@ charset-normalizer = ["charset-normalizer"] html5lib = ["html5lib"] lxml = ["lxml"] +[[package]] +name = "cachetools" +version = "7.0.0" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "cachetools-7.0.0-py3-none-any.whl", hash = "sha256:d52fef60e6e964a1969cfb61ccf6242a801b432790fe520d78720d757c81cbd2"}, + {file = "cachetools-7.0.0.tar.gz", hash = "sha256:a9abf18ff3b86c7d05b27ead412e235e16ae045925e531fae38d5fada5ed5b08"}, +] + [[package]] name = "certifi" version = "2026.1.4" @@ -839,6 +851,18 @@ files = [ [package.extras] dev = ["coverage", "pytest (>=7.4.4)"] +[[package]] +name = "et-xmlfile" +version = "2.0.0" +description = "An implementation of lxml.xmlfile for the standard library" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"}, + {file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"}, +] + [[package]] name = "fastapi" version = "0.124.4" @@ -1168,6 +1192,83 @@ files = [ docs = ["Sphinx", "furo"] test = ["objgraph", "psutil", "setuptools"] +[[package]] +name = "grpcio" +version = "1.78.0" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "grpcio-1.78.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:7cc47943d524ee0096f973e1081cb8f4f17a4615f2116882a5f1416e4cfe92b5"}, + {file = 
"grpcio-1.78.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:c3f293fdc675ccba4db5a561048cca627b5e7bd1c8a6973ffedabe7d116e22e2"}, + {file = "grpcio-1.78.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:10a9a644b5dd5aec3b82b5b0b90d41c0fa94c85ef42cb42cf78a23291ddb5e7d"}, + {file = "grpcio-1.78.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:4c5533d03a6cbd7f56acfc9cfb44ea64f63d29091e40e44010d34178d392d7eb"}, + {file = "grpcio-1.78.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ff870aebe9a93a85283837801d35cd5f8814fe2ad01e606861a7fb47c762a2b7"}, + {file = "grpcio-1.78.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:391e93548644e6b2726f1bb84ed60048d4bcc424ce5e4af0843d28ca0b754fec"}, + {file = "grpcio-1.78.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:df2c8f3141f7cbd112a6ebbd760290b5849cda01884554f7c67acc14e7b1758a"}, + {file = "grpcio-1.78.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bd8cb8026e5f5b50498a3c4f196f57f9db344dad829ffae16b82e4fdbaea2813"}, + {file = "grpcio-1.78.0-cp310-cp310-win32.whl", hash = "sha256:f8dff3d9777e5d2703a962ee5c286c239bf0ba173877cc68dc02c17d042e29de"}, + {file = "grpcio-1.78.0-cp310-cp310-win_amd64.whl", hash = "sha256:94f95cf5d532d0e717eed4fc1810e8e6eded04621342ec54c89a7c2f14b581bf"}, + {file = "grpcio-1.78.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2777b783f6c13b92bd7b716667452c329eefd646bfb3f2e9dabea2e05dbd34f6"}, + {file = "grpcio-1.78.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:9dca934f24c732750389ce49d638069c3892ad065df86cb465b3fa3012b70c9e"}, + {file = "grpcio-1.78.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:459ab414b35f4496138d0ecd735fed26f1318af5e52cb1efbc82a09f0d5aa911"}, + {file = "grpcio-1.78.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:082653eecbdf290e6e3e2c276ab2c54b9e7c299e07f4221872380312d8cf395e"}, + {file = "grpcio-1.78.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:85f93781028ec63f383f6bc90db785a016319c561cc11151fbb7b34e0d012303"}, + {file = "grpcio-1.78.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f12857d24d98441af6a1d5c87442d624411db486f7ba12550b07788f74b67b04"}, + {file = "grpcio-1.78.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5397fff416b79e4b284959642a4e95ac4b0f1ece82c9993658e0e477d40551ec"}, + {file = "grpcio-1.78.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:fbe6e89c7ffb48518384068321621b2a69cab509f58e40e4399fdd378fa6d074"}, + {file = "grpcio-1.78.0-cp311-cp311-win32.whl", hash = "sha256:6092beabe1966a3229f599d7088b38dfc8ffa1608b5b5cdda31e591e6500f856"}, + {file = "grpcio-1.78.0-cp311-cp311-win_amd64.whl", hash = "sha256:1afa62af6e23f88629f2b29ec9e52ec7c65a7176c1e0a83292b93c76ca882558"}, + {file = "grpcio-1.78.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:f9ab915a267fc47c7e88c387a3a28325b58c898e23d4995f765728f4e3dedb97"}, + {file = "grpcio-1.78.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3f8904a8165ab21e07e58bf3e30a73f4dffc7a1e0dbc32d51c61b5360d26f43e"}, + {file = "grpcio-1.78.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:859b13906ce098c0b493af92142ad051bf64c7870fa58a123911c88606714996"}, + {file = "grpcio-1.78.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b2342d87af32790f934a79c3112641e7b27d63c261b8b4395350dad43eff1dc7"}, + {file = 
"grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:12a771591ae40bc65ba67048fa52ef4f0e6db8279e595fd349f9dfddeef571f9"}, + {file = "grpcio-1.78.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:185dea0d5260cbb2d224c507bf2a5444d5abbb1fa3594c1ed7e4c709d5eb8383"}, + {file = "grpcio-1.78.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:51b13f9aed9d59ee389ad666b8c2214cc87b5de258fa712f9ab05f922e3896c6"}, + {file = "grpcio-1.78.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fd5f135b1bd58ab088930b3c613455796dfa0393626a6972663ccdda5b4ac6ce"}, + {file = "grpcio-1.78.0-cp312-cp312-win32.whl", hash = "sha256:94309f498bcc07e5a7d16089ab984d42ad96af1d94b5a4eb966a266d9fcabf68"}, + {file = "grpcio-1.78.0-cp312-cp312-win_amd64.whl", hash = "sha256:9566fe4ababbb2610c39190791e5b829869351d14369603702e890ef3ad2d06e"}, + {file = "grpcio-1.78.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:ce3a90455492bf8bfa38e56fbbe1dbd4f872a3d8eeaf7337dc3b1c8aa28c271b"}, + {file = "grpcio-1.78.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:2bf5e2e163b356978b23652c4818ce4759d40f4712ee9ec5a83c4be6f8c23a3a"}, + {file = "grpcio-1.78.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8f2ac84905d12918e4e55a16da17939eb63e433dc11b677267c35568aa63fc84"}, + {file = "grpcio-1.78.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b58f37edab4a3881bc6c9bca52670610e0c9ca14e2ea3cf9debf185b870457fb"}, + {file = "grpcio-1.78.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:735e38e176a88ce41840c21bb49098ab66177c64c82426e24e0082500cc68af5"}, + {file = "grpcio-1.78.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:2045397e63a7a0ee7957c25f7dbb36ddc110e0cfb418403d110c0a7a68a844e9"}, + {file = "grpcio-1.78.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9f136fbafe7ccf4ac7e8e0c28b31066e810be52d6e344ef954a3a70234e1702"}, + {file = "grpcio-1.78.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:748b6138585379c737adc08aeffd21222abbda1a86a0dca2a39682feb9196c20"}, + {file = "grpcio-1.78.0-cp313-cp313-win32.whl", hash = "sha256:271c73e6e5676afe4fc52907686670c7cea22ab2310b76a59b678403ed40d670"}, + {file = "grpcio-1.78.0-cp313-cp313-win_amd64.whl", hash = "sha256:f2d4e43ee362adfc05994ed479334d5a451ab7bc3f3fee1b796b8ca66895acb4"}, + {file = "grpcio-1.78.0-cp314-cp314-linux_armv7l.whl", hash = "sha256:e87cbc002b6f440482b3519e36e1313eb5443e9e9e73d6a52d43bd2004fcfd8e"}, + {file = "grpcio-1.78.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:c41bc64626db62e72afec66b0c8a0da76491510015417c127bfc53b2fe6d7f7f"}, + {file = "grpcio-1.78.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8dfffba826efcf366b1e3ccc37e67afe676f290e13a3b48d31a46739f80a8724"}, + {file = "grpcio-1.78.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:74be1268d1439eaaf552c698cdb11cd594f0c49295ae6bb72c34ee31abbe611b"}, + {file = "grpcio-1.78.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be63c88b32e6c0f1429f1398ca5c09bc64b0d80950c8bb7807d7d7fb36fb84c7"}, + {file = "grpcio-1.78.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:3c586ac70e855c721bda8f548d38c3ca66ac791dc49b66a8281a1f99db85e452"}, + {file = "grpcio-1.78.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:35eb275bf1751d2ffbd8f57cdbc46058e857cf3971041521b78b7db94bdaf127"}, + {file = "grpcio-1.78.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = 
"sha256:207db540302c884b8848036b80db352a832b99dfdf41db1eb554c2c2c7800f65"}, + {file = "grpcio-1.78.0-cp314-cp314-win32.whl", hash = "sha256:57bab6deef2f4f1ca76cc04565df38dc5713ae6c17de690721bdf30cb1e0545c"}, + {file = "grpcio-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:dce09d6116df20a96acfdbf85e4866258c3758180e8c49845d6ba8248b6d0bbb"}, + {file = "grpcio-1.78.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:86f85dd7c947baa707078a236288a289044836d4b640962018ceb9cd1f899af5"}, + {file = "grpcio-1.78.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:de8cb00d1483a412a06394b8303feec5dcb3b55f81d83aa216dbb6a0b86a94f5"}, + {file = "grpcio-1.78.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e888474dee2f59ff68130f8a397792d8cb8e17e6b3434339657ba4ee90845a8c"}, + {file = "grpcio-1.78.0-cp39-cp39-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:86ce2371bfd7f212cf60d8517e5e854475c2c43ce14aa910e136ace72c6db6c1"}, + {file = "grpcio-1.78.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b0c689c02947d636bc7fab3e30cc3a3445cca99c834dfb77cd4a6cabfc1c5597"}, + {file = "grpcio-1.78.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ce7599575eeb25c0f4dc1be59cada6219f3b56176f799627f44088b21381a28a"}, + {file = "grpcio-1.78.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:684083fd383e9dc04c794adb838d4faea08b291ce81f64ecd08e4577c7398adf"}, + {file = "grpcio-1.78.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ab399ef5e3cd2a721b1038a0f3021001f19c5ab279f145e1146bb0b9f1b2b12c"}, + {file = "grpcio-1.78.0-cp39-cp39-win32.whl", hash = "sha256:f3d6379493e18ad4d39537a82371c5281e153e963cecb13f953ebac155756525"}, + {file = "grpcio-1.78.0-cp39-cp39-win_amd64.whl", hash = "sha256:5361a0630a7fdb58a6a97638ab70e1dae2893c4d08d7aba64ded28bb9e7a29df"}, + {file = "grpcio-1.78.0.tar.gz", hash = "sha256:7382b95189546f375c174f53a5fa873cef91c4b8005faa05cc5b3beea9c4f1c5"}, +] + +[package.dependencies] +typing-extensions = ">=4.12,<5.0" + +[package.extras] +protobuf = ["grpcio-tools (>=1.78.0)"] + [[package]] name = "h11" version = "0.16.0" @@ -1750,6 +1851,22 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" typing-extensions = ">=4.7.0,<5.0.0" uuid-utils = ">=0.12.0,<1.0" +[[package]] +name = "langchain-milvus" +version = "0.3.3" +description = "An integration package connecting Milvus and LangChain" +optional = false +python-versions = "<4.0,>=3.10" +groups = ["main"] +files = [ + {file = "langchain_milvus-0.3.3-py3-none-any.whl", hash = "sha256:6e12f15453372dd48836978faa4a149de79c721df3322229ad732a5e628e8e97"}, + {file = "langchain_milvus-0.3.3.tar.gz", hash = "sha256:406c2d88da133741f5cc3e2fea4b36386182b35500205c70d003382ded210e41"}, +] + +[package.dependencies] +langchain-core = ">=1.0.0" +pymilvus = ">=2.6.0,<3.0" + [[package]] name = "langchain-openai" version = "1.1.7" @@ -2196,6 +2313,22 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "msoffcrypto-tool" +version = "6.0.0" +description = "Python tool and library for decrypting and encrypting MS Office files using a password or other keys" +optional = false +python-versions = "<4.0,>=3.10" +groups = ["main"] +files = [ + {file = "msoffcrypto_tool-6.0.0-py3-none-any.whl", hash = "sha256:46c394ed5d9641e802fc79bf3fb0666a53748b23fa8c4aa634ae9d30d46fe397"}, + {file = "msoffcrypto_tool-6.0.0.tar.gz", hash = "sha256:9a5ebc4c0096b42e5d7ebc2350afdc92dc511061e935ca188468094fdd032bbe"}, +] + 
+[package.dependencies] +cryptography = ">=39.0" +olefile = ">=0.46" + [[package]] name = "multidict" version = "6.7.0" @@ -2578,6 +2711,21 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] realtime = ["websockets (>=13,<16)"] voice-helpers = ["numpy (>=2.0.2)", "sounddevice (>=0.5.1)"] +[[package]] +name = "openpyxl" +version = "3.1.5" +description = "A Python library to read/write Excel 2010 xlsx/xlsm files" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"}, + {file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"}, +] + +[package.dependencies] +et-xmlfile = "*" + [[package]] name = "orjson" version = "3.11.5" @@ -2841,6 +2989,115 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pillow" +version = "12.1.0" +description = "Python Imaging Library (fork)" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "pillow-12.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:fb125d860738a09d363a88daa0f59c4533529a90e564785e20fe875b200b6dbd"}, + {file = "pillow-12.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cad302dc10fac357d3467a74a9561c90609768a6f73a1923b0fd851b6486f8b0"}, + {file = "pillow-12.1.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a40905599d8079e09f25027423aed94f2823adaf2868940de991e53a449e14a8"}, + {file = "pillow-12.1.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:92a7fe4225365c5e3a8e598982269c6d6698d3e783b3b1ae979e7819f9cd55c1"}, + {file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f10c98f49227ed8383d28174ee95155a675c4ed7f85e2e573b04414f7e371bda"}, + {file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8637e29d13f478bc4f153d8daa9ffb16455f0a6cb287da1b432fdad2bfbd66c7"}, + {file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:21e686a21078b0f9cb8c8a961d99e6a4ddb88e0fc5ea6e130172ddddc2e5221a"}, + {file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2415373395a831f53933c23ce051021e79c8cd7979822d8cc478547a3f4da8ef"}, + {file = "pillow-12.1.0-cp310-cp310-win32.whl", hash = "sha256:e75d3dba8fc1ddfec0cd752108f93b83b4f8d6ab40e524a95d35f016b9683b09"}, + {file = "pillow-12.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:64efdf00c09e31efd754448a383ea241f55a994fd079866b92d2bbff598aad91"}, + {file = "pillow-12.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:f188028b5af6b8fb2e9a76ac0f841a575bd1bd396e46ef0840d9b88a48fdbcea"}, + {file = "pillow-12.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:a83e0850cb8f5ac975291ebfc4170ba481f41a28065277f7f735c202cd8e0af3"}, + {file = "pillow-12.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b6e53e82ec2db0717eabb276aa56cf4e500c9a7cec2c2e189b55c24f65a3e8c0"}, + {file = "pillow-12.1.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40a8e3b9e8773876d6e30daed22f016509e3987bab61b3b7fe309d7019a87451"}, + {file = "pillow-12.1.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:800429ac32c9b72909c671aaf17ecd13110f823ddb7db4dfef412a5587c2c24e"}, + {file 
= "pillow-12.1.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b022eaaf709541b391ee069f0022ee5b36c709df71986e3f7be312e46f42c84"}, + {file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f345e7bc9d7f368887c712aa5054558bad44d2a301ddf9248599f4161abc7c0"}, + {file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d70347c8a5b7ccd803ec0c85c8709f036e6348f1e6a5bf048ecd9c64d3550b8b"}, + {file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1fcc52d86ce7a34fd17cb04e87cfdb164648a3662a6f20565910a99653d66c18"}, + {file = "pillow-12.1.0-cp311-cp311-win32.whl", hash = "sha256:3ffaa2f0659e2f740473bcf03c702c39a8d4b2b7ffc629052028764324842c64"}, + {file = "pillow-12.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:806f3987ffe10e867bab0ddad45df1148a2b98221798457fa097ad85d6e8bc75"}, + {file = "pillow-12.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:9f5fefaca968e700ad1a4a9de98bf0869a94e397fe3524c4c9450c1445252304"}, + {file = "pillow-12.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a332ac4ccb84b6dde65dbace8431f3af08874bf9770719d32a635c4ef411b18b"}, + {file = "pillow-12.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:907bfa8a9cb790748a9aa4513e37c88c59660da3bcfffbd24a7d9e6abf224551"}, + {file = "pillow-12.1.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:efdc140e7b63b8f739d09a99033aa430accce485ff78e6d311973a67b6bf3208"}, + {file = "pillow-12.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bef9768cab184e7ae6e559c032e95ba8d07b3023c289f79a2bd36e8bf85605a5"}, + {file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:742aea052cf5ab5034a53c3846165bc3ce88d7c38e954120db0ab867ca242661"}, + {file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6dfc2af5b082b635af6e08e0d1f9f1c4e04d17d4e2ca0ef96131e85eda6eb17"}, + {file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:609e89d9f90b581c8d16358c9087df76024cf058fa693dd3e1e1620823f39670"}, + {file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43b4899cfd091a9693a1278c4982f3e50f7fb7cff5153b05174b4afc9593b616"}, + {file = "pillow-12.1.0-cp312-cp312-win32.whl", hash = "sha256:aa0c9cc0b82b14766a99fbe6084409972266e82f459821cd26997a488a7261a7"}, + {file = "pillow-12.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:d70534cea9e7966169ad29a903b99fc507e932069a881d0965a1a84bb57f6c6d"}, + {file = "pillow-12.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:65b80c1ee7e14a87d6a068dd3b0aea268ffcabfe0498d38661b00c5b4b22e74c"}, + {file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:7b5dd7cbae20285cdb597b10eb5a2c13aa9de6cde9bb64a3c1317427b1db1ae1"}, + {file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:29a4cef9cb672363926f0470afc516dbf7305a14d8c54f7abbb5c199cd8f8179"}, + {file = "pillow-12.1.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:681088909d7e8fa9e31b9799aaa59ba5234c58e5e4f1951b4c4d1082a2e980e0"}, + {file = "pillow-12.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:983976c2ab753166dc66d36af6e8ec15bb511e4a25856e2227e5f7e00a160587"}, + {file = "pillow-12.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:db44d5c160a90df2d24a24760bbd37607d53da0b34fb546c4c232af7192298ac"}, + {file = "pillow-12.1.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:6b7a9d1db5dad90e2991645874f708e87d9a3c370c243c2d7684d28f7e133e6b"}, + {file = "pillow-12.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6258f3260986990ba2fa8a874f8b6e808cf5abb51a94015ca3dc3c68aa4f30ea"}, + {file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e115c15e3bc727b1ca3e641a909f77f8ca72a64fff150f666fcc85e57701c26c"}, + {file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6741e6f3074a35e47c77b23a4e4f2d90db3ed905cb1c5e6e0d49bff2045632bc"}, + {file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:935b9d1aed48fcfb3f838caac506f38e29621b44ccc4f8a64d575cb1b2a88644"}, + {file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5fee4c04aad8932da9f8f710af2c1a15a83582cfb884152a9caa79d4efcdbf9c"}, + {file = "pillow-12.1.0-cp313-cp313-win32.whl", hash = "sha256:a786bf667724d84aa29b5db1c61b7bfdde380202aaca12c3461afd6b71743171"}, + {file = "pillow-12.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:461f9dfdafa394c59cd6d818bdfdbab4028b83b02caadaff0ffd433faf4c9a7a"}, + {file = "pillow-12.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:9212d6b86917a2300669511ed094a9406888362e085f2431a7da985a6b124f45"}, + {file = "pillow-12.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:00162e9ca6d22b7c3ee8e61faa3c3253cd19b6a37f126cad04f2f88b306f557d"}, + {file = "pillow-12.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:7d6daa89a00b58c37cb1747ec9fb7ac3bc5ffd5949f5888657dfddde6d1312e0"}, + {file = "pillow-12.1.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e2479c7f02f9d505682dc47df8c0ea1fc5e264c4d1629a5d63fe3e2334b89554"}, + {file = "pillow-12.1.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f188d580bd870cda1e15183790d1cc2fa78f666e76077d103edf048eed9c356e"}, + {file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0fde7ec5538ab5095cc02df38ee99b0443ff0e1c847a045554cf5f9af1f4aa82"}, + {file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ed07dca4a8464bada6139ab38f5382f83e5f111698caf3191cb8dbf27d908b4"}, + {file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f45bd71d1fa5e5749587613037b172e0b3b23159d1c00ef2fc920da6f470e6f0"}, + {file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:277518bf4fe74aa91489e1b20577473b19ee70fb97c374aa50830b279f25841b"}, + {file = "pillow-12.1.0-cp313-cp313t-win32.whl", hash = "sha256:7315f9137087c4e0ee73a761b163fc9aa3b19f5f606a7fc08d83fd3e4379af65"}, + {file = "pillow-12.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:0ddedfaa8b5f0b4ffbc2fa87b556dc59f6bb4ecb14a53b33f9189713ae8053c0"}, + {file = "pillow-12.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:80941e6d573197a0c28f394753de529bb436b1ca990ed6e765cf42426abc39f8"}, + {file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:5cb7bc1966d031aec37ddb9dcf15c2da5b2e9f7cc3ca7c54473a20a927e1eb91"}, + {file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:97e9993d5ed946aba26baf9c1e8cf18adbab584b99f452ee72f7ee8acb882796"}, + {file = "pillow-12.1.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:414b9a78e14ffeb98128863314e62c3f24b8a86081066625700b7985b3f529bd"}, + {file = "pillow-12.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = 
"sha256:e6bdb408f7c9dd2a5ff2b14a3b0bb6d4deb29fb9961e6eb3ae2031ae9a5cec13"}, + {file = "pillow-12.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3413c2ae377550f5487991d444428f1a8ae92784aac79caa8b1e3b89b175f77e"}, + {file = "pillow-12.1.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e5dcbe95016e88437ecf33544ba5db21ef1b8dd6e1b434a2cb2a3d605299e643"}, + {file = "pillow-12.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d0a7735df32ccbcc98b98a1ac785cc4b19b580be1bdf0aeb5c03223220ea09d5"}, + {file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0c27407a2d1b96774cbc4a7594129cc027339fd800cd081e44497722ea1179de"}, + {file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15c794d74303828eaa957ff8070846d0efe8c630901a1c753fdc63850e19ecd9"}, + {file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c990547452ee2800d8506c4150280757f88532f3de2a58e3022e9b179107862a"}, + {file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b63e13dd27da389ed9475b3d28510f0f954bca0041e8e551b2a4eb1eab56a39a"}, + {file = "pillow-12.1.0-cp314-cp314-win32.whl", hash = "sha256:1a949604f73eb07a8adab38c4fe50791f9919344398bdc8ac6b307f755fc7030"}, + {file = "pillow-12.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:4f9f6a650743f0ddee5593ac9e954ba1bdbc5e150bc066586d4f26127853ab94"}, + {file = "pillow-12.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:808b99604f7873c800c4840f55ff389936ef1948e4e87645eaf3fccbc8477ac4"}, + {file = "pillow-12.1.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:bc11908616c8a283cf7d664f77411a5ed2a02009b0097ff8abbba5e79128ccf2"}, + {file = "pillow-12.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:896866d2d436563fa2a43a9d72f417874f16b5545955c54a64941e87c1376c61"}, + {file = "pillow-12.1.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8e178e3e99d3c0ea8fc64b88447f7cac8ccf058af422a6cedc690d0eadd98c51"}, + {file = "pillow-12.1.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:079af2fb0c599c2ec144ba2c02766d1b55498e373b3ac64687e43849fbbef5bc"}, + {file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdec5e43377761c5dbca620efb69a77f6855c5a379e32ac5b158f54c84212b14"}, + {file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:565c986f4b45c020f5421a4cea13ef294dde9509a8577f29b2fc5edc7587fff8"}, + {file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:43aca0a55ce1eefc0aefa6253661cb54571857b1a7b2964bd8a1e3ef4b729924"}, + {file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0deedf2ea233722476b3a81e8cdfbad786f7adbed5d848469fa59fe52396e4ef"}, + {file = "pillow-12.1.0-cp314-cp314t-win32.whl", hash = "sha256:b17fbdbe01c196e7e159aacb889e091f28e61020a8abeac07b68079b6e626988"}, + {file = "pillow-12.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27b9baecb428899db6c0de572d6d305cfaf38ca1596b5c0542a5182e3e74e8c6"}, + {file = "pillow-12.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f61333d817698bdcdd0f9d7793e365ac3d2a21c1f1eb02b32ad6aefb8d8ea831"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ca94b6aac0d7af2a10ba08c0f888b3d5114439b6b3ef39968378723622fed377"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = 
"sha256:351889afef0f485b84078ea40fe33727a0492b9af3904661b0abbafee0355b72"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb0984b30e973f7e2884362b7d23d0a348c7143ee559f38ef3eaab640144204c"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:84cabc7095dd535ca934d57e9ce2a72ffd216e435a84acb06b2277b1de2689bd"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53d8b764726d3af1a138dd353116f774e3862ec7e3794e0c8781e30db0f35dfc"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5da841d81b1a05ef940a8567da92decaa15bc4d7dedb540a8c219ad83d91808a"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:75af0b4c229ac519b155028fa1be632d812a519abba9b46b20e50c6caa184f19"}, + {file = "pillow-12.1.0.tar.gz", hash = "sha256:5c5ae0a06e9ea030ab786b0251b32c7e4ce10e58d983c0d5c56029455180b5b9"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=8.2)", "sphinx-autobuild", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"] +fpx = ["olefile"] +mic = ["olefile"] +test-arrow = ["arro3-compute", "arro3-core", "nanoarrow", "pyarrow"] +tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "olefile", "packaging", "pyroma (>=5)", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "trove-classifiers (>=2024.10.12)"] +xmp = ["defusedxml"] + [[package]] name = "pipmaster" version = "1.1.0" @@ -3010,6 +3267,26 @@ files = [ {file = "propcache-0.4.1.tar.gz", hash = "sha256:f48107a8c637e80362555f37ecf49abe20370e557cc4ab374f04ec4423c97c3d"}, ] +[[package]] +name = "protobuf" +version = "6.33.5" +description = "" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "protobuf-6.33.5-cp310-abi3-win32.whl", hash = "sha256:d71b040839446bac0f4d162e758bea99c8251161dae9d0983a3b88dee345153b"}, + {file = "protobuf-6.33.5-cp310-abi3-win_amd64.whl", hash = "sha256:3093804752167bcab3998bec9f1048baae6e29505adaf1afd14a37bddede533c"}, + {file = "protobuf-6.33.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:a5cb85982d95d906df1e2210e58f8e4f1e3cdc088e52c921a041f9c9a0386de5"}, + {file = "protobuf-6.33.5-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:9b71e0281f36f179d00cbcb119cb19dec4d14a81393e5ea220f64b286173e190"}, + {file = "protobuf-6.33.5-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:8afa18e1d6d20af15b417e728e9f60f3aa108ee76f23c3b2c07a2c3b546d3afd"}, + {file = "protobuf-6.33.5-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:cbf16ba3350fb7b889fca858fb215967792dc125b35c7976ca4818bee3521cf0"}, + {file = "protobuf-6.33.5-cp39-cp39-win32.whl", hash = "sha256:a3157e62729aafb8df6da2c03aa5c0937c7266c626ce11a278b6eb7963c4e37c"}, + {file = "protobuf-6.33.5-cp39-cp39-win_amd64.whl", hash = "sha256:8f04fa32763dcdb4973d537d6b54e615cc61108c7cb38fe59310c3192d29510a"}, + {file = "protobuf-6.33.5-py3-none-any.whl", hash = "sha256:69915a973dd0f60f31a08b8318b73eab2bd6a392c79184b3612226b0a3f8ec02"}, + {file = "protobuf-6.33.5.tar.gz", hash = "sha256:6ddcac2a081f8b7b9642c09406bc6a4290128fce5f471cddd165960bb9119e5c"}, +] + [[package]] name = "psutil" version = "7.2.1" @@ -3301,6 +3578,33 @@ dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pyte docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] +[[package]] 
+name = "pymilvus" +version = "2.6.8" +description = "Python Sdk for Milvus" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "pymilvus-2.6.8-py3-none-any.whl", hash = "sha256:c4c413ffdef2599064301fd831de6f9839a753abe27c68c6148707629711d069"}, + {file = "pymilvus-2.6.8.tar.gz", hash = "sha256:15232f5f66805bf2f50b30bbad59637b62f5258d9343f7615353ce1221fab6b5"}, +] + +[package.dependencies] +cachetools = ">=5.0.0" +grpcio = ">=1.66.2,<1.68.0 || >1.68.0,<1.68.1 || >1.68.1,<1.69.0 || >1.69.0,<1.70.0 || >1.70.0,<1.70.1 || >1.70.1,<1.71.0 || >1.71.0,<1.72.1 || >1.72.1,<1.73.0 || >1.73.0" +orjson = ">=3.10.15" +pandas = ">=1.2.4" +protobuf = ">=5.27.2" +python-dotenv = ">=1.0.1,<2.0.0" +setuptools = ">69" + +[package.extras] +bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "requests", "urllib3"] +dev = ["Cython (>=3.0.0)", "azure-storage-blob", "black", "grpcio (==1.66.2)", "grpcio-tools (==1.66.2)", "memray ; sys_platform != \"win32\"", "minio (>=7.0.0)", "ml_dtypes", "py-spy", "pyarrow (>=12.0.0)", "pytest (>=5.3.4)", "pytest-asyncio", "pytest-benchmark[histogram]", "pytest-cov (>=5.0.0)", "pytest-timeout (>=1.3.4)", "requests", "ruff (>=0.12.9,<1)", "scipy", "urllib3"] +milvus-lite = ["milvus-lite (>=2.4.0) ; sys_platform != \"win32\""] +model = ["pymilvus.model (>=0.3.0)"] + [[package]] name = "pymysql" version = "1.1.2" @@ -3319,14 +3623,14 @@ rsa = ["cryptography"] [[package]] name = "pypdf" -version = "6.6.0" +version = "5.9.0" description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files" optional = false -python-versions = ">=3.9" +python-versions = ">=3.8" groups = ["main"] files = [ - {file = "pypdf-6.6.0-py3-none-any.whl", hash = "sha256:bca9091ef6de36c7b1a81e09327c554b7ce51e88dad68f5890c2b4a4417f1fd7"}, - {file = "pypdf-6.6.0.tar.gz", hash = "sha256:4c887ef2ea38d86faded61141995a3c7d068c9d6ae8477be7ae5de8a8e16592f"}, + {file = "pypdf-5.9.0-py3-none-any.whl", hash = "sha256:be10a4c54202f46d9daceaa8788be07aa8cd5ea8c25c529c50dd509206382c35"}, + {file = "pypdf-5.9.0.tar.gz", hash = "sha256:30f67a614d558e495e1fbb157ba58c1de91ffc1718f5e0dfeb82a029233890a1"}, ] [package.extras] @@ -3386,6 +3690,22 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-docx" +version = "1.2.0" +description = "Create, read, and update Microsoft Word .docx files." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "python_docx-1.2.0-py3-none-any.whl", hash = "sha256:3fd478f3250fbbbfd3b94fe1e985955737c145627498896a8a6bf81f4baf66c7"}, + {file = "python_docx-1.2.0.tar.gz", hash = "sha256:7bc9d7b7d8a69c9c02ca09216118c86552704edc23bac179283f2e38f86220ce"}, +] + +[package.dependencies] +lxml = ">=3.1.0" +typing_extensions = ">=4.9.0" + [[package]] name = "python-dotenv" version = "1.2.1" @@ -3457,6 +3777,24 @@ click = "*" olefile = "*" typing_extensions = ">=4.9.0" +[[package]] +name = "python-pptx" +version = "1.0.2" +description = "Create, read, and update PowerPoint 2007+ (.pptx) files." 
+optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba"}, + {file = "python_pptx-1.0.2.tar.gz", hash = "sha256:479a8af0eaf0f0d76b6f00b0887732874ad2e3188230315290cd1f9dd9cc7095"}, +] + +[package.dependencies] +lxml = ">=3.1.0" +Pillow = ">=3.3.2" +typing-extensions = ">=4.9.0" +XlsxWriter = ">=0.5.7" + [[package]] name = "pytz" version = "2025.2" @@ -4110,14 +4448,12 @@ optional = false python-versions = ">=3.7" groups = ["main"] files = [ - {file = "sqlalchemy-2.0.45-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c64772786d9eee72d4d3784c28f0a636af5b0a29f3fe26ff11f55efe90c0bd85"}, {file = "sqlalchemy-2.0.45-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7ae64ebf7657395824a19bca98ab10eb9a3ecb026bf09524014f1bb81cb598d4"}, {file = "sqlalchemy-2.0.45-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f02325709d1b1a1489f23a39b318e175a171497374149eae74d612634b234c0"}, {file = "sqlalchemy-2.0.45-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d2c3684fca8a05f0ac1d9a21c1f4a266983a7ea9180efb80ffeb03861ecd01a0"}, {file = "sqlalchemy-2.0.45-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:040f6f0545b3b7da6b9317fc3e922c9a98fc7243b2a1b39f78390fc0942f7826"}, {file = "sqlalchemy-2.0.45-cp310-cp310-win32.whl", hash = "sha256:830d434d609fe7bfa47c425c445a8b37929f140a7a44cdaf77f6d34df3a7296a"}, {file = "sqlalchemy-2.0.45-cp310-cp310-win_amd64.whl", hash = "sha256:0209d9753671b0da74da2cfbb9ecf9c02f72a759e4b018b3ab35f244c91842c7"}, - {file = "sqlalchemy-2.0.45-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2e90a344c644a4fa871eb01809c32096487928bd2038bf10f3e4515cb688cc56"}, {file = "sqlalchemy-2.0.45-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b8c8b41b97fba5f62349aa285654230296829672fc9939cd7f35aab246d1c08b"}, {file = "sqlalchemy-2.0.45-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:12c694ed6468333a090d2f60950e4250b928f457e4962389553d6ba5fe9951ac"}, {file = "sqlalchemy-2.0.45-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f7d27a1d977a1cfef38a0e2e1ca86f09c4212666ce34e6ae542f3ed0a33bc606"}, @@ -4146,14 +4482,12 @@ files = [ {file = "sqlalchemy-2.0.45-cp314-cp314-win_amd64.whl", hash = "sha256:4748601c8ea959e37e03d13dcda4a44837afcd1b21338e637f7c935b8da06177"}, {file = "sqlalchemy-2.0.45-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cd337d3526ec5298f67d6a30bbbe4ed7e5e68862f0bf6dd21d289f8d37b7d60b"}, {file = "sqlalchemy-2.0.45-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9a62b446b7d86a3909abbcd1cd3cc550a832f99c2bc37c5b22e1925438b9367b"}, - {file = "sqlalchemy-2.0.45-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5964f832431b7cdfaaa22a660b4c7eb1dfcd6ed41375f67fd3e3440fd95cb3cc"}, {file = "sqlalchemy-2.0.45-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee580ab50e748208754ae8980cec79ec205983d8cf8b3f7c39067f3d9f2c8e22"}, {file = "sqlalchemy-2.0.45-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:13e27397a7810163440c6bfed6b3fe46f1bfb2486eb540315a819abd2c004128"}, {file = "sqlalchemy-2.0.45-cp38-cp38-musllinux_1_2_aarch64.whl", hash = 
"sha256:ed3635353e55d28e7f4a95c8eda98a5cdc0a0b40b528433fbd41a9ae88f55b3d"}, {file = "sqlalchemy-2.0.45-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:db6834900338fb13a9123307f0c2cbb1f890a8656fcd5e5448ae3ad5bbe8d312"}, {file = "sqlalchemy-2.0.45-cp38-cp38-win32.whl", hash = "sha256:1d8b4a7a8c9b537509d56d5cd10ecdcfbb95912d72480c8861524efecc6a3fff"}, {file = "sqlalchemy-2.0.45-cp38-cp38-win_amd64.whl", hash = "sha256:ebd300afd2b62679203435f596b2601adafe546cb7282d5a0cd3ed99e423720f"}, - {file = "sqlalchemy-2.0.45-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d29b2b99d527dbc66dd87c3c3248a5dd789d974a507f4653c969999fc7c1191b"}, {file = "sqlalchemy-2.0.45-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:59a8b8bd9c6bedf81ad07c8bd5543eedca55fe9b8780b2b628d495ba55f8db1e"}, {file = "sqlalchemy-2.0.45-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fd93c6f5d65f254ceabe97548c709e073d6da9883343adaa51bf1a913ce93f8e"}, {file = "sqlalchemy-2.0.45-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:6d0beadc2535157070c9c17ecf25ecec31e13c229a8f69196d7590bde8082bf1"}, @@ -4552,14 +4886,14 @@ xlsx = ["msoffcrypto-tool", "networkx", "openpyxl", "pandas", "xlrd"] [[package]] name = "unstructured-client" -version = "0.42.8" +version = "0.42.3" description = "Python Client SDK for Unstructured API" optional = false python-versions = ">=3.9.2" groups = ["main"] files = [ - {file = "unstructured_client-0.42.8-py3-none-any.whl", hash = "sha256:6dbdb62d36554a5cbe61dc1b6ef0c8b11a46cc61e2602c2dc22975ba78028214"}, - {file = "unstructured_client-0.42.8.tar.gz", hash = "sha256:663655548ed5c205efb48b7f38ca0906998b33571512f7c53c60aa811e514464"}, + {file = "unstructured_client-0.42.3-py3-none-any.whl", hash = "sha256:14e9a6a44ed58c64bacd32c62d71db19bf9c2f2b46a2401830a8dfff48249d39"}, + {file = "unstructured_client-0.42.3.tar.gz", hash = "sha256:a568d8b281fafdf452647d874060cd0647e33e4a19e811b4db821eb1f3051163"}, ] [package.dependencies] @@ -4568,7 +4902,7 @@ cryptography = ">=3.1" httpcore = ">=1.0.9" httpx = ">=0.27.0" pydantic = ">=2.11.2" -pypdf = ">=6.2.0" +pypdf = ">=4.0" requests-toolbelt = ">=1.0.0" [[package]] @@ -5479,4 +5813,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0.0" -content-hash = "3a84e7e5ec3874279a429a3b609f5d9c32794c8d0e422b2f93e1dbf3f24fc38e" +content-hash = "5818bbf701080e9eda4c920a96404616130ca997ce7999a5c74a4da0d7c8cc7c" diff --git a/runtime/datamate-python/pyproject.toml b/runtime/datamate-python/pyproject.toml index c29a95a74..0dcb086d4 100644 --- a/runtime/datamate-python/pyproject.toml +++ b/runtime/datamate-python/pyproject.toml @@ -22,6 +22,15 @@ dependencies = [ "unstructured (>=0.18.21,<0.19.0)", "markdown (>=3.10,<4.0)", "langchain-community (>=0.4.1,<0.5.0)", + "langchain-core (>=1.1.0,<2.0.0)", + "langchain-text-splitters (>=1.0.0,<2.0.0)", + "pymilvus (>=2.6.0,<3.0.0)", + "langchain-milvus (>=0.3.0,<0.4.0)", + "pypdf (>=5.0.0,<6.0.0)", + "python-docx (>=1.1.0,<2.0.0)", + "openpyxl (>=3.1.0,<4.0.0)", + "python-pptx (>=1.0.0,<2.0.0)", + "beautifulsoup4 (>=4.12.0,<5.0.0)", "jsonschema (>=4.25.1,<5.0.0)", "greenlet (>=3.3.0,<4.0.0)", "docx2txt (>=0.9,<0.10)", @@ -37,6 +46,7 @@ dependencies = [ "lightrag-hku (==1.4.9.8)", "pytest (>=9.0.2,<10.0.0)", "apscheduler (>=3.11.2,<4.0.0)", + "msoffcrypto-tool (>=6.0.0,<7.0.0)", ] From 01349c06e8c872d76e50fc6c841e6766cae7e2bb Mon Sep 17 00:00:00 2001 
From: Dallas98 <990259227@qq.com> Date: Wed, 25 Feb 2026 09:43:25 +0800 Subject: [PATCH 03/13] feat: align RAG paged responses with the Java API contract and add audit fields --- .../app/module/rag/schema/response.py | 33 +++++++------- .../rag/service/knowledge_base_service.py | 45 +++++++++++++------ 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/runtime/datamate-python/app/module/rag/schema/response.py b/runtime/datamate-python/app/module/rag/schema/response.py index f946b992b..68f4545e6 100644 --- a/runtime/datamate-python/app/module/rag/schema/response.py +++ b/runtime/datamate-python/app/module/rag/schema/response.py @@ -46,6 +46,8 @@ class KnowledgeBaseResp(BaseModel): chat: Optional[ModelConfig] = Field(None, description="聊天模型配置") created_at: Optional[datetime] = Field(None, alias="createdAt", description="创建时间") updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间") + created_by: Optional[str] = Field(None, alias="createdBy", description="创建人") + updated_by: Optional[str] = Field(None, alias="updatedBy", description="更新人") class Config: populate_by_name = True # 允许使用 snake_case 或 camelCase @@ -83,6 +85,8 @@ class RagFileResp(BaseModel): err_msg: Optional[str] = Field(None, alias="errMsg", description="错误信息") created_at: Optional[datetime] = Field(None, alias="createdAt", description="创建时间") updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间") + created_by: Optional[str] = Field(None, alias="createdBy", description="创建人") + updated_by: Optional[str] = Field(None, alias="updatedBy", description="更新人") class Config: populate_by_name = True # 允许使用 snake_case 或 camelCase @@ -152,32 +156,31 @@ class PagedResponse(BaseModel): 对应 Java: com.datamate.common.interfaces.PagedResponse """ - items: List[Any] = Field(..., description="数据列表") - total: int = Field(..., description="总记录数") + content: List[Any] = Field(..., description="数据列表") + total_elements: int = Field(alias="totalElements", description="总记录数") page: int = Field(..., description="当前页码") - page_size: int = Field(alias="pageSize", description="每页数量") + size: int = Field(..., description="每页数量") total_pages: int = Field(alias="totalPages", description="总页数") @classmethod - def create(cls, items: List[Any], total: int, page: int, page_size: int): + def create(cls, content: List[Any], total_elements: int, page: int, size: int): """创建分页响应 Args: - items: 数据列表 - total: 总记录数 + content: 数据列表 + total_elements: 总记录数 page: 当前页码 - page_size: 每页数量 + size: 每页数量 Returns: PagedResponse 实例 """ - total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0 - # 使用内部字段名(snake_case) + total_pages = (total_elements + size - 1) // size if size > 0 else 0 return cls( - items=items, - total=total, + content=content, + total_elements=total_elements, page=page, - page_size=page_size, + size=size, total_pages=total_pages ) @@ -185,10 +188,10 @@ class Config: populate_by_name = True # 允许使用 snake_case 或 camelCase json_schema_extra = { "example": { - "items": [], - "total": 100, + "content": [], + "totalElements": 100, "page": 1, - "pageSize": 10, + "size": 10, "totalPages": 10 } } diff --git a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py index efcd8c814..60d936490 100644 --- a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py +++ b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py @@ -26,7 +26,7 @@ RetrieveReq,
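[Editor's note: PagedResponse.create above rounds the page count up, so a partial last page still counts as a page. A minimal sketch of the new Java-aligned contract, using only the fields shown in this diff:]

from app.module.rag.schema.response import PagedResponse

resp = PagedResponse.create(content=[{"id": "kb-1"}], total_elements=101, page=1, size=10)
assert resp.total_pages == 11          # (101 + 10 - 1) // 10
print(resp.model_dump(by_alias=True))  # {"content": [...], "totalElements": 101, "page": 1, "size": 10, "totalPages": 11}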
PagingQuery, ) -from app.module.rag.schema.response import KnowledgeBaseResp, PagedResponse, RagChunkResp +from app.module.rag.schema.response import KnowledgeBaseResp, PagedResponse, RagChunkResp, RagFileResp from app.module.rag.service.etl_service import ETLService from app.module.rag.service.file_service import FileService @@ -179,7 +179,6 @@ async def get_by_id(self, knowledge_base_id: str) -> KnowledgeBaseResp: file_count = await self.file_repo.count_by_knowledge_base(knowledge_base_id) chunk_count = await self.file_repo.count_chunks_by_knowledge_base(knowledge_base_id) - # 构建响应 response = KnowledgeBaseResp( id=knowledge_base.id, name=knowledge_base.name, @@ -190,7 +189,9 @@ async def get_by_id(self, knowledge_base_id: str) -> KnowledgeBaseResp: file_count=file_count, chunk_count=chunk_count, created_at=knowledge_base.created_at, - updated_at=knowledge_base.updated_at + updated_at=knowledge_base.updated_at, + created_by=knowledge_base.created_by, + updated_by=knowledge_base.updated_by ) return response @@ -229,15 +230,17 @@ async def list(self, request: KnowledgeBaseQueryReq) -> PagedResponse: file_count=file_count, chunk_count=chunk_count, created_at=item.created_at, - updated_at=item.updated_at + updated_at=item.updated_at, + created_by=item.created_by, + updated_by=item.updated_by ) responses.append(response) return PagedResponse.create( - items=responses, - total=total, + content=responses, + total_elements=total, page=request.page, - page_size=request.page_size + size=request.page_size ) async def add_files(self, request: AddFilesReq) -> dict: @@ -306,11 +309,27 @@ async def list_files( page_size=request.page_size ) + # 转换为响应对象 + responses = [RagFileResp( + id=item.id, + knowledge_base_id=item.knowledge_base_id, + file_name=item.file_name, + file_id=item.file_id, + chunk_count=item.chunk_count, + metadata=item.file_metadata, + status=item.status, + err_msg=item.err_msg, + created_at=item.created_at, + updated_at=item.updated_at, + created_by=item.created_by, + updated_by=item.updated_by + ) for item in items] + return PagedResponse.create( - items=items, - total=total, + content=responses, + total_elements=total, page=request.page, - page_size=request.page_size + size=request.page_size ) async def delete_files(self, knowledge_base_id: str, request: DeleteFilesReq) -> None: @@ -527,10 +546,10 @@ async def get_chunks( ) return PagedResponse.create( - items=chunks, - total=total, + content=chunks, + total_elements=total, page=paging_query.page, - page_size=paging_query.size + size=paging_query.size ) except Exception as e: From 77de5510247842c3d8aa7a0896673cd74943dc88 Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Thu, 26 Feb 2026 11:32:35 +0800 Subject: [PATCH 04/13] feat: update Milvus configuration and enhance file processing logic in RAG service --- runtime/datamate-python/app/core/config.py | 2 +- .../app/db/models/knowledge_gen.py | 6 +- .../app/module/rag/infra/milvus/factory.py | 28 ++- .../module/rag/infra/milvus/vectorstore.py | 195 +++++++---------- .../app/module/rag/infra/pipeline.py | 1 - .../app/module/rag/schema/request.py | 58 +++-- .../app/module/rag/service/etl_service.py | 206 ++++++++++-------- .../app/module/rag/service/file_service.py | 31 +-- .../app/module/rag/service/rag_service.py | 18 ++ 9 files changed, 289 insertions(+), 256 deletions(-) diff --git a/runtime/datamate-python/app/core/config.py b/runtime/datamate-python/app/core/config.py index 580cb97d1..49dd3320a 100644 --- a/runtime/datamate-python/app/core/config.py +++ 
b/runtime/datamate-python/app/core/config.py @@ -78,7 +78,7 @@ def build_database_url(self): datamate_jwt_enable: bool = False # Milvus 配置 - milvus_uri: str = "http://milvus-standalone:19530" + milvus_uri: str = "http://localhost:19530" milvus_token: str = "" # 文件存储配置(共享文件系统) diff --git a/runtime/datamate-python/app/db/models/knowledge_gen.py b/runtime/datamate-python/app/db/models/knowledge_gen.py index 48806bde2..c0bf37d34 100644 --- a/runtime/datamate-python/app/db/models/knowledge_gen.py +++ b/runtime/datamate-python/app/db/models/knowledge_gen.py @@ -5,7 +5,7 @@ 与 Java 实体保持一致。 """ from enum import Enum -from sqlalchemy import Column, String, Integer, JSON, Enum as SQLEnum +from sqlalchemy import Column, String, Integer, JSON from app.db.models.base_entity import BaseEntity @@ -42,7 +42,7 @@ class KnowledgeBase(BaseEntity): name = Column(String(255), nullable=False, unique=True, comment="知识库名称") description = Column(String(512), nullable=True, comment="知识库描述") type = Column( - SQLEnum(RagType), + String(50), nullable=False, default=RagType.DOCUMENT, comment="RAG类型", @@ -70,7 +70,7 @@ class RagFile(BaseEntity): chunk_count = Column(Integer, nullable=True, comment="分块数量") file_metadata = Column("metadata", JSON, nullable=True, comment="元数据(JSON格式)") status = Column( - SQLEnum(FileStatus), + String(50), nullable=False, default=FileStatus.UNPROCESSED, comment="处理状态", diff --git a/runtime/datamate-python/app/module/rag/infra/milvus/factory.py b/runtime/datamate-python/app/module/rag/infra/milvus/factory.py index a74fa93e1..2cef4ff7a 100644 --- a/runtime/datamate-python/app/module/rag/infra/milvus/factory.py +++ b/runtime/datamate-python/app/module/rag/infra/milvus/factory.py @@ -28,7 +28,7 @@ def create( collection_name: str, embedding: Embeddings, *, - drop_old: bool = False, + drop_old: bool = True, consistency_level: str = "Strong", ) -> Any: """ @@ -44,13 +44,35 @@ def create( langchain_milvus.Milvus 实例 """ from langchain_milvus import BM25BuiltInFunction, Milvus + from app.module.rag.infra.milvus.vectorstore import ( + drop_collection, + create_java_compatible_collection, + ) + + # 获取向量维度 + test_text = "test" + dimension = len(embedding.embed_query(test_text)) + + # 删除旧集合(如果存在) + if drop_old: + drop_collection(collection_name) + + # 创建与 Java 兼容的 schema(只有5个字段:id、text、metadata、vector、sparse) + create_java_compatible_collection( + collection_name=collection_name, + dimension=dimension, + consistency_level=consistency_level + ) + # 创建 Milvus 实例(不自动创建集合,使用已有的 schema) return Milvus( embedding_function=embedding, collection_name=collection_name, connection_args=VectorStoreFactory.get_connection_args(), builtin_function=BM25BuiltInFunction(), - vector_field=["dense", "sparse"], + text_field="text", + vector_field=["vector"], + drop_old=False, consistency_level=consistency_level, - drop_old=drop_old, + auto_id=False ) diff --git a/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py b/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py index 28c7ae325..1e4b4dd5f 100644 --- a/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py +++ b/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py @@ -9,10 +9,9 @@ from __future__ import annotations import logging -from typing import Any, List, Optional +from typing import List, Optional from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings from app.core.config import settings from app.module.rag.infra.embeddings import EmbeddingFactory @@ -86,81 +85,81 @@ def 
create_java_compatible_collection( dimension: 向量维度 consistency_level: 一致性级别 """ - from pymilvus import MilvusClient, DataType, FunctionType + from pymilvus import MilvusClient, DataType, FunctionType, CollectionSchema, FieldSchema, Function from app.core.exception import BusinessError, ErrorCodes try: conn_args = _connection_args() - token = conn_args.get("token") if conn_args.get("token") else "" - client = MilvusClient(uri=conn_args["uri"], token=token) + + # 创建 Milvus 客户端 + client = MilvusClient(uri=conn_args["uri"], token="") # 检查集合是否已存在 if client.has_collection(collection_name): logger.info("集合 %s 已存在,跳过创建", collection_name) return - # 定义 schema - schema = MilvusClient.create_schema() - - # 1. 主键字段 id - schema.add_field( - field_name="id", - datatype=DataType.VARCHAR, - max_length=36, - is_primary=True, - auto_id=False - ) - - # 2. 文本字段 text(启用 analyzer 用于 BM25) - schema.add_field( - field_name="text", - datatype=DataType.VARCHAR, - max_length=65535, - enable_analyzer=True, - enable_match=True - ) - - # 3. 元数据字段 metadata - schema.add_field( - field_name="metadata", - datatype=DataType.JSON + # 创建字段 + fields = [ + # 1. 主键字段 id + FieldSchema( + name="id", + dtype=DataType.VARCHAR, + max_length=36, + is_primary=True, + auto_id=False + ), + # 2. 文本字段 text(启用 analyzer 用于 BM25) + FieldSchema( + name="text", + dtype=DataType.VARCHAR, + max_length=65535, + enable_analyzer=True, + enable_match=True + ), + # 3. 元数据字段 metadata + FieldSchema( + name="metadata", + dtype=DataType.JSON + ), + # 4. 密集向量字段 vector + FieldSchema( + name="vector", + dtype=DataType.FLOAT_VECTOR, + dim=dimension + ), + # 5. 稀疏向量字段 sparse(BM25) + FieldSchema( + name="sparse", + dtype=DataType.SPARSE_FLOAT_VECTOR + ) + ] + + # 创建 BM25 函数(不使用 params,避免 Milvus 参数错误) + function = Function( + name="text_bm25_emb", + function_type=FunctionType.BM25, + input_field_names=["text"], + output_field_names=["sparse"] ) - # 4. 密集向量字段 vector - schema.add_field( - field_name="vector", - datatype=DataType.FLOAT_VECTOR, - dim=dimension + # 创建 schema + schema = CollectionSchema( + fields=fields, + functions=[function], + description="Knowledge base collection", + enable_dynamic_field=True ) - # 5. 
稀疏向量字段 sparse(BM25) - schema.add_field( - field_name="sparse", - datatype=DataType.SPARSE_FLOAT_VECTOR - ) - - # 创建集合(BM25 将在首次添加文档时自动配置) + # 创建集合(不包含索引) + # 索引会在首次插入数据时由 Milvus/LangChain 自动创建 client.create_collection( collection_name=collection_name, schema=schema, consistency_level=consistency_level ) - # 创建向量索引 - client.create_index( - collection_name=collection_name, - field_name="vector", - index_params={ - "index_type": "HNSW", - "metric_type": "COSINE", - "params": { - "M": 16, - "efConstruction": 256 - } - } - ) - logger.info("成功创建 Java 兼容的集合: %s (维度: %d)", collection_name, dimension) except Exception as e: @@ -182,12 +181,10 @@ def get_vector_dimension(embedding_model: str, base_url: Optional[str] = None, a Raises: BusinessError: 无法获取维度 """ - from langchain_core.embeddings import Embeddings from app.core.exception import BusinessError, ErrorCodes try: - import asyncio embedding = EmbeddingFactory.create_embeddings( model_name=embedding_model, base_url=base_url, @@ -195,7 +192,8 @@ def get_vector_dimension(embedding_model: str, base_url: Optional[str] = None, a ) test_text = "test" - embedding_vector = asyncio.run(asyncio.to_thread(embedding.embed_query, test_text)) + # 直接调用同步的 embed_query 方法 + embedding_vector = embedding.embed_query(test_text) dimension = len(embedding_vector) logger.info("获取模型 %s 的向量维度: %d", embedding_model, dimension) @@ -208,6 +206,7 @@ def get_vector_dimension(embedding_model: str, base_url: Optional[str] = None, a def delete_chunks_by_rag_file_ids(collection_name: str, rag_file_ids: List[str]) -> None: """按 RAG 文件 ID 列表删除 Milvus 中的分块。用于文件删除时清理向量数据。""" + import json if not rag_file_ids: return from pymilvus import MilvusClient @@ -216,10 +215,13 @@ def delete_chunks_by_rag_file_ids(collection_name: str, rag_file_ids: List[str]) try: conn_args = _connection_args() - client = MilvusClient(uri=conn_args["uri"], token=conn_args.get("token") or None) + client = MilvusClient(uri=conn_args["uri"], token="") + # metadata 为 JSON 字段,按 rag_file_id 过滤 + # 使用 JSON_CONTAINS 的正确语法 for rid in rag_file_ids: - filter_expr = f'metadata["rag_file_id"] == "{rid}"' + json_value = json.dumps({"rag_file_id": rid}) + filter_expr = f'JSON_CONTAINS(metadata, \'{json_value}\')' try: client.delete(collection_name=collection_name, filter=filter_expr) except Exception as del_err: @@ -230,66 +232,25 @@ def delete_chunks_by_rag_file_ids(collection_name: str, rag_file_ids: List[str]) raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"删除分块失败: {str(e)}") from e -def get_milvus_vectorstore( - collection_name: str, - embedding: Embeddings, - *, - drop_old: bool = False, - consistency_level: str = "Strong", -) -> Any: - """创建带全文检索(BM25)的 Milvus 向量存储实例. 
- - 使用 langchain-milvus.Milvus + BM25BuiltInFunction,支持混合检索。 +def chunks_to_langchain_documents( + chunks: list, + ids: List[str] = None +) -> tuple[list, List[str]]: + """将 DocumentChunk 转换为 LangChain Document 格式 Args: - collection_name: 集合名称(通常为知识库名称) - embedding: LangChain Embeddings 实例 - drop_old: 是否在创建时删除已存在同名集合 - consistency_level: 一致性级别 + chunks: DocumentChunk 列表 + ids: 可选的 ID 列表 Returns: - Milvus 向量存储实例,支持 add_documents / similarity_search / as_retriever 等 + (documents, ids): LangChain Document 列表和对应的 ID 列表 """ - from langchain_milvus import BM25BuiltInFunction, Milvus - - return Milvus( - embedding_function=embedding, - collection_name=collection_name, - connection_args=_connection_args(), - builtin_function=BM25BuiltInFunction(), - vector_field=["dense", "sparse"], - consistency_level=consistency_level, - drop_old=drop_old, - ) + if ids is None: + ids = [str(i) for i in range(len(chunks))] + documents = [] + for chunk, chunk_id in zip(chunks, ids): + doc = Document(page_content=chunk.text, metadata=chunk.metadata) + documents.append(doc) -def chunks_to_langchain_documents( - chunks: List[Any], - *, - ids: Optional[List[str]] = None, - id_key: str = "chunk_id", -) -> tuple[List[Document], List[str]]: - """将领域 DocumentChunk 列表转为 LangChain Document 列表及 id 列表. - - Args: - chunks: 分块列表,每项有 .text 与 .metadata - ids: 若提供则作为文档 id,否则从 metadata[id_key] 取或生成 - id_key: metadata 中作为 id 的键名 - - Returns: - (documents, ids) - """ - from uuid import uuid4 - - documents: List[Document] = [] - out_ids: List[str] = [] - for i, ch in enumerate(chunks): - text = getattr(ch, "text", str(ch)) - meta = getattr(ch, "metadata", {}) or {} - if ids and i < len(ids): - doc_id = ids[i] - else: - doc_id = meta.get(id_key) or str(uuid4()) - documents.append(Document(page_content=text, metadata=dict(meta))) - out_ids.append(doc_id) - return documents, out_ids + return documents, ids diff --git a/runtime/datamate-python/app/module/rag/infra/pipeline.py b/runtime/datamate-python/app/module/rag/infra/pipeline.py index 7b41c91fe..213923860 100644 --- a/runtime/datamate-python/app/module/rag/infra/pipeline.py +++ b/runtime/datamate-python/app/module/rag/infra/pipeline.py @@ -73,7 +73,6 @@ async def load_and_split( ) chunks = await splitter.split( parsed.text, - file_name=parsed.file_name, **base_chunk_metadata, ) diff --git a/runtime/datamate-python/app/module/rag/schema/request.py b/runtime/datamate-python/app/module/rag/schema/request.py index 1fff90f68..262fb785c 100644 --- a/runtime/datamate-python/app/module/rag/schema/request.py +++ b/runtime/datamate-python/app/module/rag/schema/request.py @@ -34,21 +34,24 @@ class KnowledgeBaseCreateReq(BaseModel): embedding_model: str = Field( ..., min_length=1, + alias="embeddingModel", description="嵌入模型ID" ) chat_model: Optional[str] = Field( None, + alias="chatModel", description="聊天模型ID" ) class Config: + populate_by_name = True json_schema_extra = { "example": { "name": "my_knowledge_base", "description": "我的知识库", "type": "DOCUMENT", - "embedding_model": "text-embedding-ada-002", - "chat_model": "gpt-4" + "embeddingModel": "text-embedding-ada-002", + "chatModel": "gpt-4" } } @@ -72,6 +75,7 @@ class KnowledgeBaseUpdateReq(BaseModel): ) class Config: + populate_by_name = True json_schema_extra = { "example": { "name": "updated_knowledge_base", @@ -94,11 +98,13 @@ class KnowledgeBaseQueryReq(BaseModel): default=10, ge=1, le=100, + alias="size", description="每页数量" ) keyword: Optional[str] = Field( None, max_length=255, + alias="name", description="搜索关键词(模糊匹配知识库名称或描述)" ) type: 
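[Editor's note: two hedged caveats on the Milvus helpers in this patch. First, both create_java_compatible_collection and delete_chunks_by_rag_file_ids now hardcode token="" rather than using the configured settings.milvus_token, which would fail against a cluster with auth enabled. Second, Milvus's json_contains tests array membership, so filtering a JSON object field by key/value is normally written as plain key equality, as the pre-patch code did. A sketch combining both fixes — the collection name and id below are hypothetical:]

from pymilvus import MilvusClient

from app.core.config import settings

client = MilvusClient(uri=settings.milvus_uri, token=settings.milvus_token or "")

# metadata is a JSON field in the five-field schema (id/text/metadata/vector/sparse),
# so direct key access in the filter expression matches one rag_file_id exactly.
rag_file_id = "rag-file-uuid-123"  # hypothetical id
client.delete(
    collection_name="my_knowledge_base",  # hypothetical collection
    filter=f'metadata["rag_file_id"] == "{rag_file_id}"',
)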
Optional[RagType] = Field( @@ -107,11 +113,12 @@ class KnowledgeBaseQueryReq(BaseModel): ) class Config: + populate_by_name = True json_schema_extra = { "example": { "page": 1, - "page_size": 10, - "keyword": "测试", + "size": 10, + "name": "测试", "type": "DOCUMENT" } } @@ -123,15 +130,14 @@ class FileInfo(BaseModel): 对应 Java: com.datamate.rag.indexer.interfaces.dto.AddFilesReq.FileInfo """ id: str = Field(..., description="文件ID (对应 t_dm_dataset_files.id)") - dataset_id: str = Field(..., description="数据集ID") - file_name: str = Field(..., description="文件名") + file_name: str = Field(alias="fileName", description="文件名") class Config: + populate_by_name = True json_schema_extra = { "example": { "id": "file-uuid-123", - "dataset_id": "dataset-uuid-456", - "file_name": "document.pdf" + "fileName": "document.pdf" } } @@ -141,21 +147,24 @@ class AddFilesReq(BaseModel): 对应 Java: com.datamate.rag.indexer.interfaces.dto.AddFilesReq """ - knowledge_base_id: str = Field(..., description="知识库ID(从路径参数获取,这里保留用于兼容)") + knowledge_base_id: Optional[str] = Field(None, alias="knowledgeBaseId", description="知识库ID(从路径参数获取,这里保留用于兼容)") process_type: ProcessType = Field( default=ProcessType.DEFAULT_CHUNK, + alias="processType", description="分块处理类型" ) chunk_size: int = Field( default=500, ge=50, le=2000, + alias="chunkSize", description="分块大小" ) overlap_size: int = Field( default=50, ge=0, le=500, + alias="overlapSize", description="重叠大小" ) delimiter: Optional[str] = Field( @@ -177,15 +186,16 @@ def validate_delimiter(cls, v, info): return v class Config: + populate_by_name = True json_schema_extra = { "example": { - "knowledge_base_id": "kb-uuid-123", - "process_type": "DEFAULT_CHUNK", - "chunk_size": 500, - "overlap_size": 50, + "knowledgeBaseId": "kb-uuid-123", + "processType": "DEFAULT_CHUNK", + "chunkSize": 500, + "overlapSize": 50, "files": [ - {"id": "file-1", "dataset_id": "dataset-uuid-456", "file_name": "doc1.pdf"}, - {"id": "file-2", "dataset_id": "dataset-uuid-456", "file_name": "doc2.pdf"} + {"id": "file-1", "fileName": "doc1.pdf"}, + {"id": "file-2", "fileName": "doc2.pdf"} ] } } @@ -199,13 +209,15 @@ class DeleteFilesReq(BaseModel): file_ids: List[str] = Field( ..., min_length=1, + alias="ids", description="要删除的文件ID列表" ) class Config: + populate_by_name = True json_schema_extra = { "example": { - "file_ids": ["file-1", "file-2", "file-3"] + "ids": ["file-1", "file-2", "file-3"] } } @@ -224,11 +236,13 @@ class RagFileReq(BaseModel): default=10, ge=1, le=100, + alias="size", description="每页数量" ) keyword: Optional[str] = Field( None, max_length=255, + alias="fileName", description="搜索关键词(模糊匹配文件名)" ) status: Optional[FileStatus] = Field( @@ -237,11 +251,12 @@ class RagFileReq(BaseModel): ) class Config: + populate_by_name = True json_schema_extra = { "example": { "page": 1, - "page_size": 10, - "keyword": "测试", + "size": 10, + "fileName": "测试", "status": "PROCESSED" } } @@ -261,6 +276,7 @@ class RetrieveReq(BaseModel): default=5, ge=1, le=20, + alias="topK", description="返回前 K 个结果" ) threshold: Optional[float] = Field( @@ -272,16 +288,18 @@ class RetrieveReq(BaseModel): knowledge_base_ids: List[str] = Field( ..., min_length=1, + alias="knowledgeBaseIds", description="要检索的知识库ID列表" ) class Config: + populate_by_name = True json_schema_extra = { "example": { "query": "什么是机器学习?", - "top_k": 5, + "topK": 5, "threshold": 0.7, - "knowledge_base_ids": ["kb-1", "kb-2"] + "knowledgeBaseIds": ["kb-1", "kb-2"] } } diff --git a/runtime/datamate-python/app/module/rag/service/etl_service.py 
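[Editor's note: with populate_by_name = True plus the camelCase aliases added above, the same request validates from either naming style. A short sketch using AddFilesReq as defined in this diff:]

from app.module.rag.schema.request import AddFilesReq

camel = AddFilesReq.model_validate({
    "processType": "DEFAULT_CHUNK",
    "chunkSize": 500,
    "overlapSize": 50,
    "files": [{"id": "file-1", "fileName": "doc1.pdf"}],
})
snake = AddFilesReq(
    process_type="DEFAULT_CHUNK",
    chunk_size=500,
    overlap_size=50,
    files=[{"id": "file-1", "file_name": "doc1.pdf"}],
)
assert camel.files[0].file_name == snake.files[0].file_name == "doc1.pdf"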
b/runtime/datamate-python/app/module/rag/service/etl_service.py index bd29e347f..18703dbf2 100644 --- a/runtime/datamate-python/app/module/rag/service/etl_service.py +++ b/runtime/datamate-python/app/module/rag/service/etl_service.py @@ -23,8 +23,10 @@ from app.module.rag.infra.task.worker_pool import WorkerPool from app.core.config import settings from app.core.exception import BusinessError, ErrorCodes +from app.db.session import AsyncSessionLocal import logging +import asyncio logger = logging.getLogger(__name__) @@ -46,97 +48,104 @@ class ETLService: 5. 更新文件状态 """ - def __init__(self, db: AsyncSession): + def __init__(self, db: AsyncSession = None): """初始化服务 Args: - db: 数据库异步 session + db: 数据库异步 session(可选,后台任务时会创建新的) """ self.db = db - self.file_repo = RagFileRepository(db) - self.kb_repo = KnowledgeBaseRepository(db) self.worker_pool = WorkerPool(max_workers=10) - async def process_files( + async def process_files_background( self, - knowledge_base: KnowledgeBase, - request: AddFilesReq + knowledge_base_id: str, + knowledge_base_name: str, + request_data: dict ) -> None: - """处理文件的入口方法(在事务提交后调用) + """后台处理文件的入口方法(使用新的数据库 session) - 对应 Java 的 @TransactionalEventListener(phase = AFTER_COMMIT) + 对应 Java 的 @TransactionalEventListener(phase = AFTER_COMMIT) + @Async Args: - knowledge_base: 知识库实体 - request: 添加文件请求 + knowledge_base_id: 知识库 ID + knowledge_base_name: 知识库名称 + request_data: 添加文件请求数据(dict 格式) """ - # 获取待处理的文件 - files = await self.file_repo.get_unprocessed_files(knowledge_base.id) + # 创建新的数据库 session + async with AsyncSessionLocal() as db: + try: + file_repo = RagFileRepository(db) + kb_repo = KnowledgeBaseRepository(db) - if not files: - logger.info(f"知识库 {knowledge_base.name} 没有待处理的文件") - return + # 获取知识库实体 + knowledge_base = await kb_repo.get_by_id(knowledge_base_id) + if not knowledge_base: + logger.error(f"知识库不存在: {knowledge_base_id}") + return - logger.info(f"开始处理 {len(files)} 个文件,知识库: {knowledge_base.name}") + # 重建请求对象 + request = AddFilesReq.model_validate(request_data) + + # 获取待处理的文件 + files = await file_repo.get_unprocessed_files(knowledge_base_id) - # 并发处理所有文件(信号量控制并发数) - import asyncio - tasks = [ - self.worker_pool.submit( - self._process_single_file, - file, knowledge_base, request - ) - for file in files - ] + if not files: + logger.info(f"知识库 {knowledge_base_name} 没有待处理的文件") + return - results = await asyncio.gather(*tasks, return_exceptions=True) + logger.info(f"开始处理 {len(files)} 个文件,知识库: {knowledge_base_name}") - # 统计处理结果 - success_count = sum(1 for r in results if not isinstance(r, Exception)) - failed_count = len(results) - success_count + # 顺序处理文件(避免并发问题) + for file in files: + try: + await self._process_single_file_with_session( + db, file, knowledge_base, request + ) + except Exception as e: + logger.exception(f"文件 {file.file_name} 处理失败: {e}") + # 继续处理下一个文件 - logger.info( - f"文件处理完成,成功: {success_count}, 失败: {failed_count}" - ) + logger.info(f"知识库 {knowledge_base_name} 文件处理完成") - async def _process_single_file( + except Exception as e: + logger.exception(f"后台处理文件失败: {e}") + finally: + await db.close() + + async def _process_single_file_with_session( self, + db: AsyncSession, rag_file: RagFile, knowledge_base: KnowledgeBase, request: AddFilesReq ) -> None: - """处理单个文件的 ETL 流程 - - 步骤: - 1. 解析文档(从共享文件系统读取) - 2. 分块 - 3. 生成嵌入向量 - 4. 存储到 Milvus - 5. 更新文件状态 + """处理单个文件的 ETL 流程(使用提供的 session) Args: + db: 数据库 session rag_file: RAG 文件实体 knowledge_base: 知识库实体 request: 添加文件请求 """ + file_repo = RagFileRepository(db) + try: # 1. 
更新状态为处理中 - await self.file_repo.update_status(rag_file.id, FileStatus.PROCESSING) + await file_repo.update_status(rag_file.id, FileStatus.PROCESSING) + await db.commit() # 2. 从 metadata 中获取文件路径和原始文件ID file_path = rag_file.file_metadata.get("file_path") if rag_file.file_metadata else None - original_file_id = rag_file.file_id # t_dm_dataset_files.id + original_file_id = rag_file.file_id dataset_id = rag_file.file_metadata.get("dataset_id") if rag_file.file_metadata else None # 2.1 验证文件路径 if not file_path: error_msg = f"文件路径未设置,file_metadata={rag_file.file_metadata}" logger.error(error_msg) - await self.file_repo.update_status( - rag_file.id, - FileStatus.PROCESS_FAILED, - err_msg=error_msg - ) + await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=error_msg) + await db.commit() return # 2.2 确保使用绝对路径 @@ -147,14 +156,11 @@ async def _process_single_file( if not Path(file_path).exists(): error_msg = f"文件不存在: {file_path}" logger.error(error_msg) - await self.file_repo.update_status( - rag_file.id, - FileStatus.PROCESS_FAILED, - err_msg=error_msg - ) + await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=error_msg) + await db.commit() return - # 3. 准备完整的 metadata + # 3. 准备完整的 metadata(不包含 file_path,避免与函数参数冲突) file_extension = Path(file_path).suffix base_metadata = { "rag_file_id": rag_file.id, @@ -163,10 +169,10 @@ async def _process_single_file( "file_name": rag_file.file_name, "file_extension": file_extension, "knowledge_base_id": knowledge_base.id, - "file_path": file_path, + # file_path 不包含在此处,因为它作为位置参数传递 } - # 4. 加载并分块(复用 ingest pipeline),传递完整的 metadata + # 4. 加载并分块 try: chunks = await ingest_file_to_chunks( file_path, @@ -179,46 +185,35 @@ async def _process_single_file( except Exception as e: error_msg = f"文档解析或分块失败: {str(e)}" logger.exception(f"文件 {rag_file.file_name} 解析失败: {e}") - await self.file_repo.update_status( - rag_file.id, - FileStatus.PROCESS_FAILED, - err_msg=error_msg - ) + await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=error_msg) + await db.commit() return if not chunks: logger.warning(f"文件 {rag_file.file_name} 未生成任何分块") - await self.file_repo.update_status( - rag_file.id, - FileStatus.PROCESS_FAILED, - err_msg="文档解析后未生成任何分块" - ) + await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg="文档解析后未生成任何分块") + await db.commit() return logger.info(f"文件 {rag_file.file_name} 分块完成,共 {len(chunks)} 个分块") - # 5. 写入 LangChain Milvus 向量存储(自动嵌入 + BM25 全文检索) + # 5. 写入 Milvus 向量存储 try: - embedding_entity = await get_model_by_id(self.db, knowledge_base.embedding_model) + embedding_entity = await get_model_by_id(db, knowledge_base.embedding_model) if not embedding_entity: raise ValueError(f"嵌入模型不存在: {knowledge_base.embedding_model}") - # 5.1 获取向量维度并创建 Java 兼容的集合 + # 5. 
获取向量维度并创建集合 try: dimension = get_vector_dimension( embedding_model=embedding_entity.model_name, base_url=getattr(embedding_entity, "base_url", None), api_key=getattr(embedding_entity, "api_key", None), ) - create_java_compatible_collection( - collection_name=knowledge_base.name, - dimension=dimension - ) + # 集合将由 VectorStoreFactory.create() 自动创建(如果已存在则删除) except BusinessError as e: - logger.warning("创建或检查集合失败: %s", e) - # 如果集合已存在,继续处理 - if "已存在" not in str(e): - raise + logger.warning("获取向量维度失败: %s", e) + raise embedding = EmbeddingFactory.create_embeddings( model_name=embedding_entity.model_name, @@ -230,7 +225,6 @@ async def _process_single_file( embedding=embedding, ) for c in chunks: - # 确保 metadata 包含所有必需字段 for key, value in base_metadata.items(): if key not in c.metadata: c.metadata[key] = value @@ -241,24 +235,52 @@ async def _process_single_file( except Exception as e: error_msg = f"向量化或存储到 Milvus 失败: {str(e)}" logger.exception(f"文件 {rag_file.file_name} 向量化失败: {e}") - await self.file_repo.update_status( - rag_file.id, - FileStatus.PROCESS_FAILED, - err_msg=error_msg - ) + await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=error_msg) + await db.commit() return # 6. 更新文件状态为成功 - await self.file_repo.update_chunk_count(rag_file.id, len(chunks)) - await self.file_repo.update_status(rag_file.id, FileStatus.PROCESSED) + await file_repo.update_chunk_count(rag_file.id, len(chunks)) + await file_repo.update_status(rag_file.id, FileStatus.PROCESSED) + await db.commit() logger.info(f"文件 {rag_file.file_name} ETL 处理完成") except Exception as e: logger.exception(f"文件 {rag_file.file_name} 处理失败: {e}") - await self.file_repo.update_status( - rag_file.id, - FileStatus.PROCESS_FAILED, - err_msg=str(e) - ) - # 不抛出异常,避免影响其他文件的处理 + await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=str(e)) + await db.commit() + + async def process_files( + self, + knowledge_base: KnowledgeBase, + request: AddFilesReq + ) -> None: + """处理文件的入口方法(在事务提交后调用)- 已废弃,使用 process_files_background + + 对应 Java 的 @TransactionalEventListener(phase = AFTER_COMMIT) + + Args: + knowledge_base: 知识库实体 + request: 添加文件请求 + """ + logger.warning("process_files is deprecated, use process_files_background instead") + # 这个方法保留用于兼容,但不推荐使用 + if not self.db: + logger.error("No database session available") + return + + file_repo = RagFileRepository(self.db) + files = await file_repo.get_unprocessed_files(knowledge_base.id) + + if not files: + logger.info(f"知识库 {knowledge_base.name} 没有待处理的文件") + return + + logger.info(f"开始处理 {len(files)} 个文件,知识库: {knowledge_base.name}") + + for file in files: + try: + await self._process_single_file_with_session(self.db, file, knowledge_base, request) + except Exception as e: + logger.exception(f"文件 {file.file_name} 处理失败: {e}") diff --git a/runtime/datamate-python/app/module/rag/service/file_service.py b/runtime/datamate-python/app/module/rag/service/file_service.py index 7f8e8de4c..a2c50584c 100644 --- a/runtime/datamate-python/app/module/rag/service/file_service.py +++ b/runtime/datamate-python/app/module/rag/service/file_service.py @@ -6,8 +6,10 @@ import uuid from typing import List, Tuple from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select from app.db.models.knowledge_gen import RagFile, FileStatus +from app.db.models.dataset_management import DatasetFiles from app.module.rag.schema.request import AddFilesReq from app.module.rag.repository import RagFileRepository, KnowledgeBaseRepository from app.module.rag.infra.milvus.vectorstore 
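[Editor's note: process_files_background deliberately opens its own AsyncSessionLocal, so the caller can detach the heavy ETL work from the request transaction instead of sharing a session across tasks. A hedged sketch of the intended call site — the actual wiring lives in knowledge_base_service and may differ:]

import asyncio

from app.module.rag.service.etl_service import ETLService

async def schedule_etl(knowledge_base, request) -> None:
    # Fire-and-forget after the RagFile rows are committed; the service
    # builds a fresh DB session internally rather than reusing the request's.
    asyncio.create_task(
        ETLService().process_files_background(
            knowledge_base_id=knowledge_base.id,
            knowledge_base_name=knowledge_base.name,
            request_data=request.model_dump(),
        )
    )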
import delete_chunks_by_rag_file_ids @@ -58,28 +60,22 @@ async def add_files(self, request: AddFilesReq) -> Tuple[List[RagFile], List[str if not request.files or len(request.files) == 0: raise BusinessError(ErrorCodes.BAD_REQUEST, "文件列表不能为空") - # 导入 dataset 服务 - from app.module.dataset.service.service import Service as DatasetService - - dataset_service = DatasetService(self.db) - # 验证文件存在并创建 RAG 文件记录 rag_files = [] skipped_file_ids = [] for file_info in request.files: try: - # 通过 dataset 服务验证文件是否存在 - file_path = await dataset_service.get_file_download_url( - dataset_id=file_info.dataset_id, - file_id=file_info.id + # 根据 file_info.id (DatasetFile ID) 查询文件信息 + result = await self.db.execute( + select(DatasetFiles).where(DatasetFiles.id == file_info.id) ) + dataset_file = result.scalar_one_or_none() # 跳过不存在的文件 - if not file_path: + if not dataset_file: logger.warning( - f"文件不存在,跳过处理: dataset_id={file_info.dataset_id}, " - f"file_id={file_info.id}, file_name={file_info.file_name}" + f"文件不存在,跳过处理: file_id={file_info.id}" ) skipped_file_ids.append(file_info.id) continue @@ -88,23 +84,20 @@ async def add_files(self, request: AddFilesReq) -> Tuple[List[RagFile], List[str rag_file = RagFile( id=str(uuid.uuid4()), knowledge_base_id=request.knowledge_base_id, - file_name=file_info.file_name, + file_name=dataset_file.file_name, file_id=file_info.id, - chunk_count=None, file_metadata={ "process_type": request.process_type.value, - "dataset_id": file_info.dataset_id, - "file_path": file_path + "dataset_id": dataset_file.dataset_id, + "file_path": dataset_file.file_path }, status=FileStatus.UNPROCESSED, - err_msg=None ) rag_files.append(rag_file) except Exception as e: logger.error( - f"处理文件信息失败: dataset_id={file_info.dataset_id}, " - f"file_id={file_info.id}, error={e}" + f"处理文件信息失败: file_id={file_info.id}, error={e}" ) skipped_file_ids.append(file_info.id) continue diff --git a/runtime/datamate-python/app/module/rag/service/rag_service.py b/runtime/datamate-python/app/module/rag/service/rag_service.py index 9584e406a..75c11402a 100644 --- a/runtime/datamate-python/app/module/rag/service/rag_service.py +++ b/runtime/datamate-python/app/module/rag/service/rag_service.py @@ -182,3 +182,21 @@ async def _query_document_rag(self, query: str, kb: KnowledgeBase) -> str: chain = create_retrieval_chain(retriever, combine_chain) result = await chain.ainvoke({"input": query}) return result.get("answer", "") + + async def index(self, documents: list[dict], knowledge_base_id: int) -> dict: + kb = await self._get_knowledge_base(str(knowledge_base_id)) + embedding_entity = await self._get_models(kb.embedding_model) + embedding = EmbeddingFactory.create_embeddings( + model_name=embedding_entity.model_name, + base_url=getattr(embedding_entity, "base_url", None), + api_key=getattr(embedding_entity, "api_key", None), + ) + vectorstore = VectorStoreFactory.create( + collection_name=kb.name, + embedding=embedding, + ) + from langchain_core.documents import Document + docs = [Document(page_content=doc.get("content", ""), metadata=doc.get("metadata", {})) for doc in documents] + await vectorstore.aadd_documents(docs) + logger.info(f"Indexed {len(documents)} documents into knowledge base {knowledge_base_id}") + return {"indexed_count": len(documents), "collection_name": kb.name} From 50f48b081e64d8732dc1e2925a0417920e8ac17d Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Thu, 26 Feb 2026 14:40:38 +0800 Subject: [PATCH 05/13] feat: enhance RAG infrastructure with document processing, vector storage, and 
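[Editor's note: the new index() entry point takes plain dicts, so pre-chunked content can be pushed without going through the file pipeline. A minimal sketch, assuming the service class is named RagService and is already initialized; note that knowledge_base_id is typed int here while the rest of the module uses string UUIDs:]

from app.module.rag.service.rag_service import RagService  # class name assumed

async def index_notes(rag_service: RagService) -> dict:
    docs = [
        {"content": "Machine learning is ...", "metadata": {"source": "notes.md"}},
        {"content": "Deep learning is ...", "metadata": {"source": "notes.md"}},
    ]
    # Expected shape: {"indexed_count": 2, "collection_name": "<kb name>"}
    return await rag_service.index(docs, knowledge_base_id=1)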
retrieval capabilities --- .../app/module/rag/infra/__init__.py | 24 +- .../app/module/rag/infra/document/__init__.py | 42 ++ .../app/module/rag/infra/document/loader.py | 26 ++ .../{pipeline.py => document/processor.py} | 108 +++-- .../app/module/rag/infra/document/splitter.py | 148 ++++++ .../{parser/base.py => document/types.py} | 108 ++--- .../app/module/rag/infra/milvus/__init__.py | 5 - .../module/rag/infra/milvus/vectorstore.py | 256 ----------- .../app/module/rag/infra/options.py | 40 -- .../app/module/rag/infra/parser/__init__.py | 13 - .../app/module/rag/infra/splitter/__init__.py | 5 - .../app/module/rag/infra/splitter/base.py | 74 --- .../app/module/rag/infra/splitter/factory.py | 57 --- .../rag/infra/splitter/langchain_impl.py | 107 ----- .../app/module/rag/infra/task/__init__.py | 3 + .../app/module/rag/infra/task/worker_pool.py | 26 +- .../module/rag/infra/vectorstore/__init__.py | 24 + .../infra/{milvus => vectorstore}/factory.py | 30 +- .../app/module/rag/infra/vectorstore/store.py | 231 ++++++++++ .../module/rag/interface/knowledge_base.py | 157 ++----- .../app/module/rag/schema/__init__.py | 2 +- .../module/rag/schema/{entity.py => types.py} | 0 .../app/module/rag/service/etl_service.py | 286 ------------ .../app/module/rag/service/file_processor.py | 237 ++++++++++ .../app/module/rag/service/file_service.py | 174 -------- .../rag/service/knowledge_base_service.py | 422 +++++------------- .../app/module/rag/service/rag_service.py | 2 +- .../module/rag/service/retrieval_service.py | 218 +++++++++ 28 files changed, 1215 insertions(+), 1610 deletions(-) create mode 100644 runtime/datamate-python/app/module/rag/infra/document/__init__.py create mode 100644 runtime/datamate-python/app/module/rag/infra/document/loader.py rename runtime/datamate-python/app/module/rag/infra/{pipeline.py => document/processor.py} (50%) create mode 100644 runtime/datamate-python/app/module/rag/infra/document/splitter.py rename runtime/datamate-python/app/module/rag/infra/{parser/base.py => document/types.py} (52%) delete mode 100644 runtime/datamate-python/app/module/rag/infra/milvus/__init__.py delete mode 100644 runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py delete mode 100644 runtime/datamate-python/app/module/rag/infra/options.py delete mode 100644 runtime/datamate-python/app/module/rag/infra/parser/__init__.py delete mode 100644 runtime/datamate-python/app/module/rag/infra/splitter/__init__.py delete mode 100644 runtime/datamate-python/app/module/rag/infra/splitter/base.py delete mode 100644 runtime/datamate-python/app/module/rag/infra/splitter/factory.py delete mode 100644 runtime/datamate-python/app/module/rag/infra/splitter/langchain_impl.py create mode 100644 runtime/datamate-python/app/module/rag/infra/vectorstore/__init__.py rename runtime/datamate-python/app/module/rag/infra/{milvus => vectorstore}/factory.py (73%) create mode 100644 runtime/datamate-python/app/module/rag/infra/vectorstore/store.py rename runtime/datamate-python/app/module/rag/schema/{entity.py => types.py} (100%) delete mode 100644 runtime/datamate-python/app/module/rag/service/etl_service.py create mode 100644 runtime/datamate-python/app/module/rag/service/file_processor.py delete mode 100644 runtime/datamate-python/app/module/rag/service/file_service.py create mode 100644 runtime/datamate-python/app/module/rag/service/retrieval_service.py diff --git a/runtime/datamate-python/app/module/rag/infra/__init__.py b/runtime/datamate-python/app/module/rag/infra/__init__.py index 8a680d773..d1f157cf8 
100644 --- a/runtime/datamate-python/app/module/rag/infra/__init__.py +++ b/runtime/datamate-python/app/module/rag/infra/__init__.py @@ -1,19 +1,19 @@ """ -RAG 基础设施层:文档加载、分片、管道 +RAG 基础设施层 -使用示例: - from app.module.rag.infra import load_and_split, SplitOptions +提供文档处理、向量存储、嵌入模型和后台任务功能。 - chunks = await load_and_split( - "/path/to/doc.pdf", - split_options=SplitOptions( - process_type=ProcessType.PARAGRAPH_CHUNK, - chunk_size=300, - ) - ) +使用示例: + from app.module.rag.infra.document import ingest_file_to_chunks + from app.module.rag.infra.vectorstore import VectorStoreFactory + from app.module.rag.infra.task import get_global_pool """ -from app.module.rag.infra.pipeline import ingest_file_to_chunks, load_and_split -from app.module.rag.infra.options import SplitOptions, default_split_options +from app.module.rag.infra.document import ( + SplitOptions, + default_split_options, + ingest_file_to_chunks, + load_and_split, +) __all__ = [ "load_and_split", diff --git a/runtime/datamate-python/app/module/rag/infra/document/__init__.py b/runtime/datamate-python/app/module/rag/infra/document/__init__.py new file mode 100644 index 000000000..43b74058d --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/document/__init__.py @@ -0,0 +1,42 @@ +""" +文档处理模块 + +提供文档加载、分块和处理管道功能。 + +使用示例: + from app.module.rag.infra.document import ( + ingest_file_to_chunks, + SplitOptions, + DocumentChunk, + ) + + chunks = await ingest_file_to_chunks( + "/path/to/doc.pdf", + chunk_size=500, + overlap_size=50, + ) +""" +from app.module.rag.infra.document.processor import ( + SplitOptions, + default_split_options, + ingest_file_to_chunks, + load_and_split, +) +from app.module.rag.infra.document.types import ( + DocumentChunk, + ParsedDocument, + langchain_documents_to_parsed, +) + +__all__ = [ + # 处理管道入口 + "load_and_split", + "ingest_file_to_chunks", + # 选项 + "SplitOptions", + "default_split_options", + # 类型 + "DocumentChunk", + "ParsedDocument", + "langchain_documents_to_parsed", +] diff --git a/runtime/datamate-python/app/module/rag/infra/document/loader.py b/runtime/datamate-python/app/module/rag/infra/document/loader.py new file mode 100644 index 000000000..79d88a78f --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/document/loader.py @@ -0,0 +1,26 @@ +""" +文档加载器 + +封装 UniversalDocLoader,提供统一的文档加载接口。 +""" +import asyncio +from typing import List + +from langchain_core.documents import Document + +from app.module.shared.common.document_loaders import UniversalDocLoader + + +async def load_document(file_path: str) -> List[Document]: + """加载文档(异步封装) + + 使用 UniversalDocLoader 加载文档,支持多种格式。 + + Args: + file_path: 文件绝对路径 + + Returns: + LangChain Document 列表 + """ + loader = UniversalDocLoader(file_path) + return await asyncio.to_thread(loader.load) diff --git a/runtime/datamate-python/app/module/rag/infra/pipeline.py b/runtime/datamate-python/app/module/rag/infra/document/processor.py similarity index 50% rename from runtime/datamate-python/app/module/rag/infra/pipeline.py rename to runtime/datamate-python/app/module/rag/infra/document/processor.py index 213923860..5b2e300a1 100644 --- a/runtime/datamate-python/app/module/rag/infra/pipeline.py +++ b/runtime/datamate-python/app/module/rag/infra/document/processor.py @@ -1,21 +1,54 @@ """ -RAG 文档加载与分片管道 +文档处理管道 -使用全局 UniversalDocLoader 加载文档,分片后返回 DocumentChunk 列表。 +提供统一的文档加载和分块入口,合并原有的 pipeline.py 和 options.py。 """ -import asyncio -from typing import Any, List, Optional - -from app.module.shared.common.document_loaders import UniversalDocLoader -from 
app.module.rag.infra.parser import langchain_documents_to_parsed -from app.module.rag.infra.splitter.base import DocumentChunk -from app.module.rag.infra.splitter.factory import DocumentSplitterFactory +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from app.module.rag.infra.document.loader import load_document +from app.module.rag.infra.document.splitter import DocumentSplitterFactory +from app.module.rag.infra.document.types import ( + DocumentChunk, + ParsedDocument, + langchain_documents_to_parsed, +) from app.module.rag.schema.enums import ProcessType +@dataclass +class SplitOptions: + """文档分片选项 + + Attributes: + process_type: 分片策略 + chunk_size: 块大小(字符) + overlap_size: 块间重叠 + delimiter: 仅 CUSTOM_SEPARATOR_CHUNK 时有效 + """ + process_type: ProcessType = ProcessType.DEFAULT_CHUNK + chunk_size: int = 500 + overlap_size: int = 50 + delimiter: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "process_type": self.process_type, + "chunk_size": self.chunk_size, + "overlap_size": self.overlap_size, + "delimiter": self.delimiter, + } + + +def default_split_options() -> SplitOptions: + """默认分片选项:递归分块 500/50""" + return SplitOptions() + + async def load_and_split( file_path: str, - split_options: Optional[dict] = None, + split_options: Optional[SplitOptions] = None, **chunk_metadata: Any, ) -> List[DocumentChunk]: """加载文档并分块 @@ -24,37 +57,23 @@ async def load_and_split( Args: file_path: 文件绝对路径 - split_options: 分片选项,None 表示使用默认(递归分块 500/50) - - process_type: ProcessType 枚举,默认 DEFAULT_CHUNK - - chunk_size: 块大小,默认 500 - - overlap_size: 重叠大小,默认 50 - - delimiter: 自定义分隔符 + split_options: 分片选项,None 表示使用默认 **chunk_metadata: 写入每个 chunk.metadata 的额外字段 Returns: - List[DocumentChunk]: 分块列表 + 分块列表 """ - # 1. 加载文档(使用同步加载器并在异步上下文中运行) - loader = UniversalDocLoader(file_path) - documents = await asyncio.to_thread(loader.load) + documents = await load_document(file_path) - # 2. 准备 parser metadata parser_metadata = {} for key in ["original_file_id", "rag_file_id", "file_name"]: if key in chunk_metadata: parser_metadata[key] = chunk_metadata[key] - # 3. 转换为 ParsedDocument(传递额外的 metadata) parsed = langchain_documents_to_parsed(documents, file_path, **parser_metadata) - # 4. 获取分片选项 - options = split_options or {} - process_type = options.get("process_type", ProcessType.DEFAULT_CHUNK) - chunk_size = options.get("chunk_size", 500) - overlap_size = options.get("overlap_size", 50) - delimiter = options.get("delimiter") + options = split_options or default_split_options() - # 5. 合并 metadata 用于 chunk base_chunk_metadata = { "file_name": parsed.metadata.get("file_name", ""), "file_extension": parsed.metadata.get("file_extension", ""), @@ -64,19 +83,14 @@ async def load_and_split( } base_chunk_metadata.update(chunk_metadata) - # 6. 
分片 splitter = DocumentSplitterFactory.create_splitter( - process_type, - chunk_size=chunk_size, - overlap_size=overlap_size, - delimiter=delimiter, - ) - chunks = await splitter.split( - parsed.text, - **base_chunk_metadata, + options.process_type, + chunk_size=options.chunk_size, + overlap_size=options.overlap_size, + delimiter=options.delimiter, ) - return chunks + return await splitter.split(parsed.text, **base_chunk_metadata) async def ingest_file_to_chunks( @@ -100,16 +114,12 @@ async def ingest_file_to_chunks( **chunk_metadata: 写入每个 chunk.metadata 的额外字段 Returns: - List[DocumentChunk]: 分块列表 + 分块列表 """ - split_options = { - "process_type": process_type, - "chunk_size": chunk_size, - "overlap_size": overlap_size, - "delimiter": delimiter, - } - return await load_and_split( - file_path, - split_options=split_options, - **chunk_metadata, + split_options = SplitOptions( + process_type=process_type, + chunk_size=chunk_size, + overlap_size=overlap_size, + delimiter=delimiter, ) + return await load_and_split(file_path, split_options=split_options, **chunk_metadata) diff --git a/runtime/datamate-python/app/module/rag/infra/document/splitter.py b/runtime/datamate-python/app/module/rag/infra/document/splitter.py new file mode 100644 index 000000000..55cc0e87c --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/document/splitter.py @@ -0,0 +1,148 @@ +""" +文档分块器 + +包含 DocumentSplitter 基类、工厂和 LangChain 实现。 +根据 ProcessType 创建对应的分块策略。 +""" +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +from langchain_text_splitters import ( + CharacterTextSplitter, + RecursiveCharacterTextSplitter, +) + +from app.module.rag.infra.document.types import DocumentChunk +from app.module.rag.schema.enums import ProcessType + + +# 各 ProcessType 对应的分隔符配置(优先保持较大语义块) +SEPARATORS_BY_PROCESS_TYPE = { + ProcessType.PARAGRAPH_CHUNK: ["\n\n", "\n", " ", ""], + ProcessType.SENTENCE_CHUNK: ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? 
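load_and_split above delegates chunking to DocumentSplitterFactory. A hedged usage sketch of the factory on its own, again assuming this patch's module paths:

import asyncio

from app.module.rag.infra.document.splitter import DocumentSplitterFactory
from app.module.rag.schema.enums import ProcessType

async def demo() -> None:
    splitter = DocumentSplitterFactory.create_splitter(
        ProcessType.CUSTOM_SEPARATOR_CHUNK,
        chunk_size=20,
        overlap_size=0,
        delimiter="|",
    )
    # Each chunk carries chunk_index plus any keyword metadata passed to split()
    for chunk in await splitter.split("alpha|beta|gamma", file_name="demo.txt"):
        print(chunk.metadata["chunk_index"], chunk.text)

asyncio.run(demo())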
", " ", ""], + ProcessType.DEFAULT_CHUNK: ["\n\n", "\n", " ", ""], + ProcessType.CUSTOM_SEPARATOR_CHUNK: None, +} + + +class DocumentSplitter(ABC): + """文档分块器基类 + + 所有具体的分块器都需要继承此类并实现 split 方法。 + """ + + def __init__(self, chunk_size: int = 500, overlap_size: int = 50): + """初始化分块器 + + Args: + chunk_size: 分块大小 + overlap_size: 重叠大小 + """ + self.chunk_size = chunk_size + self.overlap_size = overlap_size + + @abstractmethod + async def split(self, text: str, **metadata: Any) -> List[DocumentChunk]: + """分割文档 + + Args: + text: 文档文本 + **metadata: 额外的元数据 + + Returns: + 分块列表 + """ + pass + + def _create_chunk(self, text: str, chunk_index: int, **metadata: Any) -> DocumentChunk: + """创建文档分块""" + chunk_metadata = {"chunk_index": chunk_index, **metadata} + return DocumentChunk(text=text, metadata=chunk_metadata) + + +class LangChainDocumentSplitter(DocumentSplitter): + """基于 LangChain 的分块器实现 + + 根据 ProcessType 选择 RecursiveCharacterTextSplitter 或 CharacterTextSplitter。 + """ + + def __init__( + self, + process_type: ProcessType, + chunk_size: int = 500, + overlap_size: int = 50, + delimiter: Optional[str] = None, + ): + super().__init__(chunk_size=chunk_size, overlap_size=overlap_size) + self._process_type = process_type + self._delimiter = delimiter or "\n\n" + self._splitter = self._create_splitter() + + def _create_splitter(self) -> RecursiveCharacterTextSplitter | CharacterTextSplitter: + """创建 LangChain 分块器实例""" + if self._process_type == ProcessType.LENGTH_CHUNK: + return CharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.overlap_size, + length_function=len, + ) + + separators = SEPARATORS_BY_PROCESS_TYPE.get(self._process_type) + if self._process_type == ProcessType.CUSTOM_SEPARATOR_CHUNK: + separators = [self._delimiter, "\n", " ", ""] + if separators is None: + separators = ["\n\n", "\n", " ", ""] + + return RecursiveCharacterTextSplitter( + chunk_size=self.chunk_size, + chunk_overlap=self.overlap_size, + separators=separators, + length_function=len, + ) + + async def split(self, text: str, **metadata: Any) -> List[DocumentChunk]: + """分割文档(异步)""" + if not text or not text.strip(): + return [] + + texts = await asyncio.to_thread(self._splitter.split_text, text) + return [ + DocumentChunk(text=t, metadata={**metadata, "chunk_index": i}) + for i, t in enumerate(texts) + ] + + +class DocumentSplitterFactory: + """文档分块器工厂 + + 根据处理类型创建对应的分块器实例。 + """ + + @classmethod + def create_splitter( + cls, + process_type: ProcessType, + chunk_size: int = 500, + overlap_size: int = 50, + delimiter: Optional[str] = None, + ) -> DocumentSplitter: + """创建分块器 + + Args: + process_type: 处理类型 + chunk_size: 分块大小 + overlap_size: 重叠大小 + delimiter: 自定义分隔符(仅用于 CUSTOM_SEPARATOR_CHUNK) + + Returns: + DocumentSplitter 实例 + """ + return LangChainDocumentSplitter( + process_type=process_type, + chunk_size=chunk_size, + overlap_size=overlap_size, + delimiter=delimiter, + ) diff --git a/runtime/datamate-python/app/module/rag/infra/parser/base.py b/runtime/datamate-python/app/module/rag/infra/document/types.py similarity index 52% rename from runtime/datamate-python/app/module/rag/infra/parser/base.py rename to runtime/datamate-python/app/module/rag/infra/document/types.py index d2e9fa6d6..7127cb029 100644 --- a/runtime/datamate-python/app/module/rag/infra/parser/base.py +++ b/runtime/datamate-python/app/module/rag/infra/document/types.py @@ -1,20 +1,41 @@ """ -文档解析器基类 +文档处理类型定义 -定义文档解析器的抽象接口 -使用策略模式支持多种文档格式 +包含 ParsedDocument(解析后文档)和 DocumentChunk(文档分块)。 """ -from abc import ABC, abstractmethod +from 
dataclasses import dataclass from typing import Dict, Any, List, Optional from pathlib import Path from langchain_core.documents import Document +@dataclass +class DocumentChunk: + """文档分块 + + 包含分块的文本内容和元数据。 + + Attributes: + text: 分块文本内容 + metadata: 分块元数据(包含文件信息、分块索引等) + """ + text: str + metadata: dict + + def __repr__(self): + return f"<DocumentChunk text={self.text[:50]!r}, metadata={self.metadata}>" + + class ParsedDocument: """解析后的文档 - 包含文档的文本内容和元数据 + 包含文档的完整文本内容和元数据,由文档加载器生成,供分块器使用。 + + Attributes: + text: 文档完整文本内容 + metadata: 文档元数据(如文件名、扩展名、路径等) + file_name: 文件名 """ def __init__( @@ -46,7 +67,7 @@ def langchain_documents_to_parsed( ) -> ParsedDocument: """将 LangChain Document 列表转换为 ParsedDocument - 多页/多段结果合并为一个文档,用于 pipeline。 + 多页/多段结果合并为一个文档,用于后续分块处理。 Args: documents: LangChain 加载器返回的 Document 列表 @@ -55,7 +76,7 @@ def langchain_documents_to_parsed( **extra_metadata: 额外的元数据字段(会合并到返回的 metadata 中) Returns: - ParsedDocument: 合并后的领域文档对象 + ParsedDocument: 合并后的文档对象 """ path = Path(file_path) name = file_name or path.name @@ -80,15 +101,12 @@ def langchain_documents_to_parsed( "file_name": name, "file_extension": path.suffix.lower(), "file_size": path.stat().st_size if path.exists() else 0, - # 添加路径信息 "absolute_directory_path": str(path.parent), "file_path": str(path), } - # 合并额外的元数据 meta.update(extra_metadata) - # 合并第一个文档的元数据 if documents and isinstance(documents[0].metadata, dict): first_meta = documents[0].metadata for k, v in first_meta.items(): @@ -96,73 +114,3 @@ meta[k] = v return ParsedDocument(text=merged_text, metadata=meta, file_name=name) - - -class DocumentParser(ABC): - """文档解析器基类(抽象类) - - 对应 Java 的文档解析接口 - - 所有具体的解析器都需要继承此类并实现 parse 方法 - """ - - @abstractmethod - async def parse(self, file_path: str) -> ParsedDocument: - """解析文档 - - Args: - file_path: 文件路径(绝对路径) - - Returns: - ParsedDocument: 解析后的文档对象 - - Raises: - FileNotFoundError: 文件不存在 - ValueError: 文件格式不支持或解析失败 - """ - pass - - def _get_file_name(self, file_path: str) -> str: - """从文件路径中提取文件名 - - Args: - file_path: 文件路径 - - Returns: - 文件名 - """ - return Path(file_path).name - - def _get_file_extension(self, file_path: str) -> str: - """从文件路径中提取文件扩展名 - - Args: - file_path: 文件路径 - - Returns: - 文件扩展名(包含点号,如 ".pdf") - """ - return Path(file_path).suffix.lower() - - def _build_metadata( - self, - file_path: str, - **extra_fields - ) -> Dict[str, Any]: - """构建文档元数据 - - Args: - file_path: 文件路径 - **extra_fields: 额外的元数据字段 - - Returns: - 元数据字典 - """ - path = Path(file_path) - metadata = { - "file_name": path.name, - "file_extension": self._get_file_extension(file_path), - "file_size": path.stat().st_size if path.exists() else 0, - } - metadata.update(extra_fields) - return metadata diff --git a/runtime/datamate-python/app/module/rag/infra/milvus/__init__.py b/runtime/datamate-python/app/module/rag/infra/milvus/__init__.py deleted file mode 100644 index b5d6e33ff..000000000 --- a/runtime/datamate-python/app/module/rag/infra/milvus/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -Milvus 向量存储相关模块 - -提供与 Milvus 集成的向量存储功能。 -""" diff --git a/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py b/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py deleted file mode 100644 index 1e4b4dd5f..000000000 --- a/runtime/datamate-python/app/module/rag/infra/milvus/vectorstore.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -基于 LangChain Milvus 的向量存储封装 - -使用 langchain-milvus.Milvus + BM25BuiltInFunction 实现密集向量 + 全文检索, -Milvus 2.6.x 自动处理 BM25 稀疏向量,无需手动生成。 - -同时提供集合管理辅助函数:drop_collection、rename_collection,供知识库删除/重命名使用。 -""" -from __future__ import 
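langchain_documents_to_parsed above merges multi-page loader output into one ParsedDocument before splitting. A small offline sketch with in-memory Documents; the path is illustrative and need not exist (file_size then falls back to 0):

from langchain_core.documents import Document

from app.module.rag.infra.document.types import langchain_documents_to_parsed

docs = [
    Document(page_content="page one", metadata={"page": 1}),
    Document(page_content="page two", metadata={"page": 2}),
]
parsed = langchain_documents_to_parsed(docs, "/tmp/demo.pdf", rag_file_id="rf-1")
print(parsed.metadata["file_name"])    # demo.pdf
print(parsed.metadata["rag_file_id"])  # rf-1 -- extra metadata is merged in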
annotations - -import logging -from typing import List, Optional - -from langchain_core.documents import Document - -from app.core.config import settings -from app.module.rag.infra.embeddings import EmbeddingFactory - -logger = logging.getLogger(__name__) - - -def _connection_args() -> dict: - args: dict = {"uri": settings.milvus_uri} - if getattr(settings, "milvus_token", None): - args["token"] = settings.milvus_token - return args - - -def _ensure_connection() -> None: - """确保 Milvus 默认连接已建立(供 utility 使用)。""" - from pymilvus import connections - - conn_args = _connection_args() - connections.connect(alias="default", uri=conn_args["uri"], token=conn_args.get("token") or "") - - -def drop_collection(collection_name: str) -> None: - """删除 Milvus 集合。用于知识库删除等场景。""" - from pymilvus import utility - - from app.core.exception import BusinessError, ErrorCodes - - try: - _ensure_connection() - if utility.has_collection(collection_name, using="default"): - utility.drop_collection(collection_name, using="default") - logger.info("成功删除集合: %s", collection_name) - except Exception as e: - logger.error("删除集合失败: %s", e) - raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"删除集合失败: {str(e)}") from e - - -def rename_collection(old_name: str, new_name: str) -> None: - """重命名 Milvus 集合。用于知识库重命名。""" - from pymilvus import utility - - from app.core.exception import BusinessError, ErrorCodes - - try: - _ensure_connection() - if utility.has_collection(old_name, using="default"): - utility.rename_collection(old_name, new_name, using="default") - logger.info("成功重命名集合: %s -> %s", old_name, new_name) - except Exception as e: - logger.error("重命名集合失败: %s", e) - raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"重命名集合失败: {str(e)}") from e - - -def create_java_compatible_collection( - collection_name: str, - dimension: int, - consistency_level: str = "Strong" -) -> None: - """创建与 Java 服务兼容的 Milvus 集合 - - 使用 Java 服务相同的字段命名和结构: - - id (VarChar, 主键) - - text (VarChar, with analyzer for BM25) - - metadata (JSON) - - vector (FloatVector, 密集向量) - - sparse (SparseFloatVector, BM25 稀疏向量) - - Args: - collection_name: 集合名称 - dimension: 向量维度 - consistency_level: 一致性级别 - """ - from pymilvus import MilvusClient, DataType, FunctionType, CollectionSchema, FieldSchema, Function - - from app.core.exception import BusinessError, ErrorCodes - - try: - conn_args = _connection_args() - - # 创建 Milvus 客户端 - client = MilvusClient(uri=conn_args["uri"], token="") - - # 检查集合是否已存在 - if client.has_collection(collection_name): - logger.info("集合 %s 已存在,跳过创建", collection_name) - return - - # 创建字段 - fields = [ - # 1. 主键字段 id - FieldSchema( - name="id", - dtype=DataType.VARCHAR, - max_length=36, - is_primary=True, - auto_id=False - ), - # 2. 文本字段 text(启用 analyzer 用于 BM25) - FieldSchema( - name="text", - dtype=DataType.VARCHAR, - max_length=65535, - enable_analyzer=True, - enable_match=True - ), - # 3. 元数据字段 metadata - FieldSchema( - name="metadata", - dtype=DataType.JSON - ), - # 4. 密集向量字段 vector - FieldSchema( - name="vector", - dtype=DataType.FLOAT_VECTOR, - dim=dimension - ), - # 5. 
稀疏向量字段 sparse(BM25) - FieldSchema( - name="sparse", - dtype=DataType.SPARSE_FLOAT_VECTOR - ) - ] - - # 创建 BM25 函数(不使用 params,避免 Milvus 参数错误) - function = Function( - name="text_bm25_emb", - function_type=FunctionType.BM25, - input_field_names=["text"], - output_field_names=["sparse"] - ) - - # 创建 schema - schema = CollectionSchema( - fields=fields, - functions=[function], - description="Knowledge base collection", - enable_dynamic_field=True - ) - - # 创建集合(不包含索引) - # 索引会在首次插入数据时由 Milvus/LangChain 自动创建 - client.create_collection( - collection_name=collection_name, - schema=schema, - consistency_level=consistency_level - ) - - logger.info("成功创建 Java 兼容的集合: %s (维度: %d)", collection_name, dimension) - - except Exception as e: - logger.error("创建集合失败: %s", e) - raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"创建集合失败: {str(e)}") from e - - -def get_vector_dimension(embedding_model: str, base_url: Optional[str] = None, api_key: Optional[str] = None) -> int: - """获取嵌入模型的向量维度 - - Args: - embedding_model: 模型名称 - base_url: API 基础 URL - api_key: API 密钥 - - Returns: - 向量维度 - - Raises: - BusinessError: 无法获取维度 - """ - - from app.core.exception import BusinessError, ErrorCodes - - try: - embedding = EmbeddingFactory.create_embeddings( - model_name=embedding_model, - base_url=base_url, - api_key=api_key, - ) - - test_text = "test" - # 直接调用同步的 embed_query 方法 - embedding_vector = embedding.embed_query(test_text) - dimension = len(embedding_vector) - - logger.info("获取模型 %s 的向量维度: %d", embedding_model, dimension) - return dimension - - except Exception as e: - logger.error("获取模型维度失败: %s", e) - raise BusinessError(ErrorCodes.RAG_EMBEDDING_FAILED, f"获取模型维度失败: {str(e)}") from e - - -def delete_chunks_by_rag_file_ids(collection_name: str, rag_file_ids: List[str]) -> None: - """按 RAG 文件 ID 列表删除 Milvus 中的分块。用于文件删除时清理向量数据。""" - import json - if not rag_file_ids: - return - from pymilvus import MilvusClient - - from app.core.exception import BusinessError, ErrorCodes - - try: - conn_args = _connection_args() - client = MilvusClient(uri=conn_args["uri"], token="") - - # metadata 为 JSON 字段,按 rag_file_id 过滤 - # 使用 JSON_CONTAINS 的正确语法 - for rid in rag_file_ids: - json_value = json.dumps({"rag_file_id": rid}) - filter_expr = f'JSON_CONTAINS(metadata, \'{json_value}\')' - try: - client.delete(collection_name=collection_name, filter=filter_expr) - except Exception as del_err: - logger.warning("删除分块时部分失败 collection=%s rag_file_id=%s: %s", collection_name, rid, del_err) - logger.info("已按 rag_file_id 删除集合 %s 中的分块: %s", collection_name, rag_file_ids) - except Exception as e: - logger.error("按 rag_file_id 删除 Milvus 分块失败: %s", e) - raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"删除分块失败: {str(e)}") from e - - -def chunks_to_langchain_documents( - chunks: list, - ids: List[str] = None -) -> tuple[list, List[str]]: - """将 DocumentChunk 转换为 LangChain Document 格式 - - Args: - chunks: DocumentChunk 列表 - ids: 可选的 ID 列表 - - Returns: - (documents, ids): LangChain Document 列表和对应的 ID 列表 - """ - if ids is None: - ids = [str(i) for i in range(len(chunks))] - - documents = [] - for chunk, chunk_id in zip(chunks, ids): - doc = Document(page_content=chunk.text, metadata=chunk.metadata) - documents.append(doc) - - return documents, ids diff --git a/runtime/datamate-python/app/module/rag/infra/options.py b/runtime/datamate-python/app/module/rag/infra/options.py deleted file mode 100644 index 7d7def7bf..000000000 --- a/runtime/datamate-python/app/module/rag/infra/options.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -文档加载与分片选项 - -保留必要的配置项,简化使用。 -""" 
-from dataclasses import dataclass -from typing import Any, Dict, Optional - -from app.module.rag.schema.enums import ProcessType - - -@dataclass -class SplitOptions: - """文档分片选项 - - Args: - process_type: 分片策略 - chunk_size: 块大小(字符) - overlap_size: 块间重叠 - delimiter: 仅 CUSTOM_SEPARATOR_CHUNK 时有效 - """ - - process_type: ProcessType = ProcessType.DEFAULT_CHUNK - chunk_size: int = 500 - overlap_size: int = 50 - delimiter: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """转换为字典,用于传递给 load_and_split""" - return { - "process_type": self.process_type, - "chunk_size": self.chunk_size, - "overlap_size": self.overlap_size, - "delimiter": self.delimiter, - } - - -def default_split_options() -> SplitOptions: - """默认分片选项:递归分块 500/50""" - return SplitOptions() diff --git a/runtime/datamate-python/app/module/rag/infra/parser/__init__.py b/runtime/datamate-python/app/module/rag/infra/parser/__init__.py deleted file mode 100644 index ebf914b5d..000000000 --- a/runtime/datamate-python/app/module/rag/infra/parser/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""保留 ParsedDocument 与 DocumentParser 基类,供 loader 层转换使用.""" - -from app.module.rag.infra.parser.base import ( - ParsedDocument, - DocumentParser, - langchain_documents_to_parsed, -) - -__all__ = [ - "ParsedDocument", - "DocumentParser", - "langchain_documents_to_parsed", -] diff --git a/runtime/datamate-python/app/module/rag/infra/splitter/__init__.py b/runtime/datamate-python/app/module/rag/infra/splitter/__init__.py deleted file mode 100644 index c1e09be1a..000000000 --- a/runtime/datamate-python/app/module/rag/infra/splitter/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -文档分块器模块 - -提供各种文档分块策略的实现。 -""" diff --git a/runtime/datamate-python/app/module/rag/infra/splitter/base.py b/runtime/datamate-python/app/module/rag/infra/splitter/base.py deleted file mode 100644 index 4a9292018..000000000 --- a/runtime/datamate-python/app/module/rag/infra/splitter/base.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -文档分块器基类 - -定义文档分块器的抽象接口 -使用策略模式支持多种分块策略 -""" -from abc import ABC, abstractmethod -from typing import List -from dataclasses import dataclass - - -@dataclass -class DocumentChunk: - """文档分块 - - 包含分块的文本和元数据 - """ - text: str - metadata: dict - - def __repr__(self): - return f"<DocumentChunk text={self.text[:50]!r}, metadata={self.metadata}>" - - -class DocumentSplitter(ABC): - """文档分块器基类(抽象类) - - 所有具体的分块器都需要继承此类并实现 split 方法 - """ - - def __init__(self, chunk_size: int = 500, overlap_size: int = 50): - """初始化分块器 - - Args: - chunk_size: 分块大小 - overlap_size: 重叠大小 - """ - self.chunk_size = chunk_size - self.overlap_size = overlap_size - - @abstractmethod - async def split(self, text: str, **metadata) -> List[DocumentChunk]: - """分割文档 - - Args: - text: 文档文本 - **metadata: 额外的元数据 - - Returns: - List[DocumentChunk]: 分块列表 - """ - pass - - def _create_chunk( - self, - text: str, - chunk_index: int, - **metadata - ) -> DocumentChunk: - """创建文档分块 - - Args: - text: 分块文本 - chunk_index: 分块索引 - **metadata: 额外的元数据 - - Returns: - DocumentChunk: 文档分块 - """ - chunk_metadata = { - "chunk_index": chunk_index, - **metadata - } - return DocumentChunk(text=text, metadata=chunk_metadata) diff --git a/runtime/datamate-python/app/module/rag/infra/splitter/factory.py b/runtime/datamate-python/app/module/rag/infra/splitter/factory.py deleted file mode 100644 index 5d2d833eb..000000000 --- a/runtime/datamate-python/app/module/rag/infra/splitter/factory.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -文档分块器工厂 - -根据处理类型创建基于 LangChain 的分块器实例, -对应 Java 的 ProcessType 枚举,供 ETL 与 ingest pipeline 复用。 -""" -from typing import Optional - -from 
app.module.rag.infra.splitter.base import DocumentSplitter -from app.module.rag.infra.splitter.langchain_impl import LangChainDocumentSplitter -from app.module.rag.schema.enums import ProcessType - - -class DocumentSplitterFactory: - """文档分块器工厂 - - 基于 LangChain RecursiveCharacterTextSplitter / CharacterTextSplitter: - - PARAGRAPH_CHUNK: 段落分块 - - SENTENCE_CHUNK: 句子分块 - - LENGTH_CHUNK: 字符长度分块 - - DEFAULT_CHUNK: 默认递归分块(推荐) - - CUSTOM_SEPARATOR_CHUNK: 自定义分隔符分块 - - 使用示例: - splitter = DocumentSplitterFactory.create_splitter( - ProcessType.DEFAULT_CHUNK, - chunk_size=500, - overlap_size=50 - ) - chunks = await splitter.split(document_text) - """ - - @classmethod - def create_splitter( - cls, - process_type: ProcessType, - chunk_size: int = 500, - overlap_size: int = 50, - delimiter: Optional[str] = None, - ) -> DocumentSplitter: - """根据处理类型创建对应的分块器(LangChain 实现). - - Args: - process_type: 处理类型 - chunk_size: 分块大小 - overlap_size: 重叠大小 - delimiter: 自定义分隔符(仅用于 CUSTOM_SEPARATOR_CHUNK) - - Returns: - DocumentSplitter 实例 - """ - return LangChainDocumentSplitter( - process_type=process_type, - chunk_size=chunk_size, - overlap_size=overlap_size, - delimiter=delimiter, - ) diff --git a/runtime/datamate-python/app/module/rag/infra/splitter/langchain_impl.py b/runtime/datamate-python/app/module/rag/infra/splitter/langchain_impl.py deleted file mode 100644 index c2aca2535..000000000 --- a/runtime/datamate-python/app/module/rag/infra/splitter/langchain_impl.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -基于 LangChain 的文档分片实现 - -将 ProcessType 映射到 LangChain 的 RecursiveCharacterTextSplitter / CharacterTextSplitter, -在 asyncio.to_thread 中执行同步 split,并转换为领域模型 DocumentChunk。 -""" -from __future__ import annotations - -import asyncio -from typing import Any, List, Optional - -from langchain_text_splitters import ( - CharacterTextSplitter, - RecursiveCharacterTextSplitter, -) -from app.module.rag.schema.enums import ProcessType - -from app.module.rag.infra.splitter.base import DocumentChunk, DocumentSplitter - - -# 各 ProcessType 对应的 RecursiveCharacterTextSplitter 分隔符(优先保持较大语义块) -SEPARATORS_BY_PROCESS_TYPE = { - ProcessType.PARAGRAPH_CHUNK: ["\n\n", "\n", " ", ""], - ProcessType.SENTENCE_CHUNK: ["\n\n", "\n", "。", "!", "?", ". ", "! ", "? ", " ", ""], - ProcessType.DEFAULT_CHUNK: ["\n\n", "\n", " ", ""], # 推荐默认,递归按段/行/词 - ProcessType.CUSTOM_SEPARATOR_CHUNK: None, # 由调用方传入 delimiter,动态构造 -} - - -def _build_recursive_splitter( - chunk_size: int, - chunk_overlap: int, - separators: Optional[List[str]] = None, -) -> RecursiveCharacterTextSplitter: - if separators is None: - separators = ["\n\n", "\n", " ", ""] - return RecursiveCharacterTextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - separators=separators, - length_function=len, - ) - - -def _build_character_splitter( - chunk_size: int, - chunk_overlap: int, -) -> CharacterTextSplitter: - return CharacterTextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - length_function=len, - ) - - -def _texts_to_chunks(texts: List[str], **base_metadata: Any) -> List[DocumentChunk]: - """将切分后的字符串列表转为 DocumentChunk 列表,保留 chunk_index 等.""" - return [ - DocumentChunk( - text=t, - metadata={**base_metadata, "chunk_index": i}, - ) - for i, t in enumerate(texts) - ] - - -class LangChainDocumentSplitter(DocumentSplitter): - """基于 LangChain 的 DocumentSplitter 实现. 
- - 根据 ProcessType 选择 RecursiveCharacterTextSplitter 或 CharacterTextSplitter, - async split() 内部用 asyncio.to_thread 调用同步 split_text,再转为 DocumentChunk。 - """ - - def __init__( - self, - process_type: ProcessType, - chunk_size: int = 500, - overlap_size: int = 50, - delimiter: Optional[str] = None, - ): - super().__init__(chunk_size=chunk_size, overlap_size=overlap_size) - self._process_type = process_type - self._delimiter = delimiter or "\n\n" - self._splitter = self._create_splitter() - - def _create_splitter(self) -> RecursiveCharacterTextSplitter | CharacterTextSplitter: - if self._process_type == ProcessType.LENGTH_CHUNK: - return _build_character_splitter( - self.chunk_size, - self.overlap_size, - ) - separators = SEPARATORS_BY_PROCESS_TYPE.get(self._process_type) - if self._process_type == ProcessType.CUSTOM_SEPARATOR_CHUNK: - separators = [self._delimiter, "\n", " ", ""] - if separators is None: - separators = ["\n\n", "\n", " ", ""] - return _build_recursive_splitter( - self.chunk_size, - self.overlap_size, - separators=separators, - ) - - async def split(self, text: str, **metadata: Any) -> List[DocumentChunk]: - if not text or not text.strip(): - return [] - texts = await asyncio.to_thread(self._splitter.split_text, text) - return _texts_to_chunks(texts, **metadata) diff --git a/runtime/datamate-python/app/module/rag/infra/task/__init__.py b/runtime/datamate-python/app/module/rag/infra/task/__init__.py index 28b580ba8..09ce368d1 100644 --- a/runtime/datamate-python/app/module/rag/infra/task/__init__.py +++ b/runtime/datamate-python/app/module/rag/infra/task/__init__.py @@ -3,3 +3,6 @@ 提供工作池和异步任务处理功能。 """ +from app.module.rag.infra.task.worker_pool import WorkerPool, get_global_pool + +__all__ = ["WorkerPool", "get_global_pool"] diff --git a/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py b/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py index 832dd06ed..733e8ac64 100644 --- a/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py +++ b/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py @@ -1,14 +1,36 @@ """ 工作协程池 -使用 asyncio.Semaphore 控制并发数,替代 Java 的虚拟线程 + 信号量 +使用 asyncio.Semaphore 控制并发数,替代 Java 的虚拟线程 + 信号量。 +提供全局单例,确保所有文件处理共享同一个并发池。 """ import asyncio -from typing import Callable, Any, Coroutine +from typing import Callable, Any, Coroutine, Optional import logging logger = logging.getLogger(__name__) +# 全局单例 +_global_pool: Optional["WorkerPool"] = None + + +def get_global_pool(max_workers: int = 10) -> "WorkerPool": + """获取全局 WorkerPool 单例 + + 所有文件处理任务共享同一个并发池,确保最多 10 个文件并行处理。 + + Args: + max_workers: 最大并发数(仅在首次创建时生效) + + Returns: + 全局 WorkerPool 实例 + """ + global _global_pool + if _global_pool is None: + _global_pool = WorkerPool(max_workers) + logger.info("创建全局 WorkerPool,最大并发数: %d", max_workers) + return _global_pool + class WorkerPool: """工作协程池 diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/__init__.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/__init__.py new file mode 100644 index 000000000..fdc2e465c --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/__init__.py @@ -0,0 +1,24 @@ +""" +向量存储模块 + +提供 Milvus 向量存储的创建、管理和数据操作功能。 +""" +from app.module.rag.infra.vectorstore.factory import VectorStoreFactory +from app.module.rag.infra.vectorstore.store import ( + chunks_to_documents, + create_collection, + delete_chunks_by_rag_file_ids, + drop_collection, + get_vector_dimension, + rename_collection, +) + +__all__ = [ + "VectorStoreFactory", + 
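get_global_pool above is a lazy module-level singleton: the first caller fixes max_workers and later arguments are silently ignored. A generic, self-contained sketch of the same accessor pattern, with a hypothetical Pool class:

from typing import Optional

class Pool:
    def __init__(self, size: int) -> None:
        self.size = size

_pool: Optional[Pool] = None

def get_pool(size: int = 10) -> Pool:
    # First call creates the instance; later calls return it unchanged
    global _pool
    if _pool is None:
        _pool = Pool(size)
    return _pool

assert get_pool(4) is get_pool(99)
print(get_pool().size)  # 4 -- the first call's size wins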
"create_collection", + "drop_collection", + "rename_collection", + "get_vector_dimension", + "delete_chunks_by_rag_file_ids", + "chunks_to_documents", +] diff --git a/runtime/datamate-python/app/module/rag/infra/milvus/factory.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py similarity index 73% rename from runtime/datamate-python/app/module/rag/infra/milvus/factory.py rename to runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py index 2cef4ff7a..cb026f38a 100644 --- a/runtime/datamate-python/app/module/rag/infra/milvus/factory.py +++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py @@ -1,7 +1,7 @@ """ 向量存储工厂 -使用 LangChain Milvus 创建向量存储实例,支持混合检索(向量 + BM25) +使用 LangChain Milvus 创建向量存储实例,支持混合检索(向量 + BM25)。 """ from __future__ import annotations @@ -10,6 +10,11 @@ from langchain_core.embeddings import Embeddings from app.core.config import settings +from app.module.rag.infra.vectorstore.store import ( + create_collection, + drop_collection, + get_vector_dimension, +) class VectorStoreFactory: @@ -31,8 +36,7 @@ def create( drop_old: bool = True, consistency_level: str = "Strong", ) -> Any: - """ - 创建 Milvus 向量存储实例(支持混合检索) + """创建 Milvus 向量存储实例(支持混合检索) Args: collection_name: 集合名称(知识库名称) @@ -44,27 +48,25 @@ def create( langchain_milvus.Milvus 实例 """ from langchain_milvus import BM25BuiltInFunction, Milvus - from app.module.rag.infra.milvus.vectorstore import ( - drop_collection, - create_java_compatible_collection, - ) # 获取向量维度 - test_text = "test" - dimension = len(embedding.embed_query(test_text)) + dimension = get_vector_dimension( + embedding_model="", + embedding_instance=embedding, + ) # 删除旧集合(如果存在) if drop_old: drop_collection(collection_name) - # 创建与 Java 兼容的 schema(只有5个字段:id、text、metadata、vector、sparse) - create_java_compatible_collection( + # 创建集合(5个字段:id、text、metadata、vector、sparse) + create_collection( collection_name=collection_name, dimension=dimension, - consistency_level=consistency_level + consistency_level=consistency_level, ) - # 创建 Milvus 实例(不自动创建集合,使用已有的 schema) + # 创建 Milvus 实例 return Milvus( embedding_function=embedding, collection_name=collection_name, @@ -74,5 +76,5 @@ def create( vector_field=["vector"], drop_old=False, consistency_level=consistency_level, - auto_id=False + auto_id=False, ) diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py new file mode 100644 index 000000000..d9c555ccd --- /dev/null +++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py @@ -0,0 +1,231 @@ +""" +向量存储管理 + +提供 Milvus 集合的创建、删除、重命名和数据操作功能。 +使用 Milvus 2.6+ 的 BM25 内置函数实现混合检索。 + +字段结构: +- id (VarChar, 主键) +- text (VarChar, with analyzer for BM25) +- metadata (JSON) +- vector (FloatVector, 密集向量) +- sparse (SparseFloatVector, BM25 稀疏向量) +""" +from __future__ import annotations + +import json +import logging +from typing import List, Optional + +from langchain_core.documents import Document +from pymilvus import MilvusClient, DataType, FunctionType, CollectionSchema, FieldSchema, Function + +from app.core.config import settings +from app.core.exception import BusinessError, ErrorCodes +from app.module.rag.infra.document.types import DocumentChunk +from app.module.rag.infra.embeddings import EmbeddingFactory + +logger = logging.getLogger(__name__) + + +def _get_connection_args() -> dict: + """获取 Milvus 连接参数""" + args: dict = {"uri": settings.milvus_uri} + token = getattr(settings, "milvus_token", None) + if token: + 
args["token"] = token + return args + + +def _get_client() -> MilvusClient: + """获取 Milvus 客户端""" + conn_args = _get_connection_args() + return MilvusClient(uri=conn_args["uri"], token=conn_args.get("token", "")) + + +def drop_collection(collection_name: str) -> None: + """删除 Milvus 集合 + + Args: + collection_name: 集合名称 + """ + try: + client = _get_client() + if client.has_collection(collection_name): + client.drop_collection(collection_name) + logger.info("成功删除集合: %s", collection_name) + except Exception as e: + logger.error("删除集合失败: %s", e) + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"删除集合失败: {str(e)}") from e + + +def rename_collection(old_name: str, new_name: str) -> None: + """重命名 Milvus 集合 + + Args: + old_name: 原集合名称 + new_name: 新集合名称 + """ + from pymilvus import utility, connections + + try: + conn_args = _get_connection_args() + connections.connect( + alias="default", + uri=conn_args["uri"], + token=conn_args.get("token", ""), + ) + if utility.has_collection(old_name, using="default"): + utility.rename_collection(old_name, new_name, using="default") + logger.info("成功重命名集合: %s -> %s", old_name, new_name) + except Exception as e: + logger.error("重命名集合失败: %s", e) + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"重命名集合失败: {str(e)}") from e + + +def create_collection( + collection_name: str, + dimension: int, + consistency_level: str = "Strong", +) -> None: + """创建 Milvus 集合 + + 使用标准的5字段结构:id、text、metadata、vector、sparse + + Args: + collection_name: 集合名称 + dimension: 向量维度 + consistency_level: 一致性级别 + """ + try: + client = _get_client() + + if client.has_collection(collection_name): + logger.info("集合 %s 已存在,跳过创建", collection_name) + return + + # 创建字段 + fields = [ + FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=36, is_primary=True, auto_id=False), + FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535, enable_analyzer=True, enable_match=True), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension), + FieldSchema(name="sparse", dtype=DataType.SPARSE_FLOAT_VECTOR), + ] + + # 创建 BM25 函数 + bm25_function = Function( + name="text_bm25_emb", + function_type=FunctionType.BM25, + input_field_names=["text"], + output_field_names=["sparse"], + ) + + schema = CollectionSchema( + fields=fields, + functions=[bm25_function], + description="Knowledge base collection", + enable_dynamic_field=True, + ) + + client.create_collection( + collection_name=collection_name, + schema=schema, + consistency_level=consistency_level, + ) + + logger.info("成功创建集合: %s (维度: %d)", collection_name, dimension) + + except Exception as e: + logger.error("创建集合失败: %s", e) + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"创建集合失败: {str(e)}") from e + + +def get_vector_dimension( + embedding_model: str = "", + base_url: Optional[str] = None, + api_key: Optional[str] = None, + embedding_instance=None, +) -> int: + """获取嵌入模型的向量维度 + + Args: + embedding_model: 模型名称 + base_url: API 基础 URL + api_key: API 密钥 + embedding_instance: 已有的 Embeddings 实例(优先使用) + + Returns: + 向量维度 + """ + try: + if embedding_instance: + embedding = embedding_instance + else: + embedding = EmbeddingFactory.create_embeddings( + model_name=embedding_model, + base_url=base_url, + api_key=api_key, + ) + + test_vector = embedding.embed_query("test") + dimension = len(test_vector) + logger.info("获取向量维度: %d", dimension) + return dimension + + except Exception as e: + logger.error("获取向量维度失败: %s", e) + raise BusinessError(ErrorCodes.RAG_EMBEDDING_FAILED, f"获取向量维度失败: 
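Because get_vector_dimension now accepts a ready-made Embeddings instance, the dimension probe can be exercised without network access. A self-contained sketch; StubEmbeddings is hypothetical and stands in for a real embedding client:

from app.module.rag.infra.vectorstore import get_vector_dimension

class StubEmbeddings:
    # Minimal duck-typed stub: the probe only calls embed_query and takes len() of the result
    def embed_query(self, text: str) -> list[float]:
        return [0.0] * 8

print(get_vector_dimension(embedding_instance=StubEmbeddings()))  # 8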
{str(e)}") from e + + +def delete_chunks_by_rag_file_ids(collection_name: str, rag_file_ids: List[str]) -> None: + """按 RAG 文件 ID 列表删除 Milvus 中的分块 + + Args: + collection_name: 集合名称 + rag_file_ids: RAG 文件 ID 列表 + """ + if not rag_file_ids: + return + + try: + client = _get_client() + + for rid in rag_file_ids: + json_value = json.dumps({"rag_file_id": rid}) + filter_expr = f'JSON_CONTAINS(metadata, \'{json_value}\')' + try: + client.delete(collection_name=collection_name, filter=filter_expr) + except Exception as del_err: + logger.warning("删除分块时部分失败: collection=%s rag_file_id=%s: %s", collection_name, rid, del_err) + + logger.info("已按 rag_file_id 删除集合 %s 中的分块: %s", collection_name, rag_file_ids) + + except Exception as e: + logger.error("删除 Milvus 分块失败: %s", e) + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"删除分块失败: {str(e)}") from e + + +def chunks_to_documents( + chunks: List[DocumentChunk], + ids: Optional[List[str]] = None, +) -> tuple[List[Document], List[str]]: + """将 DocumentChunk 转换为 LangChain Document 格式 + + Args: + chunks: DocumentChunk 列表 + ids: 可选的 ID 列表 + + Returns: + (documents, ids): LangChain Document 列表和对应的 ID 列表 + """ + if ids is None: + import uuid + ids = [str(uuid.uuid4()) for _ in chunks] + + documents = [] + for chunk, chunk_id in zip(chunks, ids): + doc = Document(page_content=chunk.text, metadata=chunk.metadata) + documents.append(doc) + + return documents, ids diff --git a/runtime/datamate-python/app/module/rag/interface/knowledge_base.py b/runtime/datamate-python/app/module/rag/interface/knowledge_base.py index 26c46a8b8..7f4678d79 100644 --- a/runtime/datamate-python/app/module/rag/interface/knowledge_base.py +++ b/runtime/datamate-python/app/module/rag/interface/knowledge_base.py @@ -1,14 +1,10 @@ """ 知识库 API 接口 -实现知识库相关的 REST API 接口 +实现知识库相关的 REST API 接口。 对应 Java: com.datamate.rag.indexer.interfaces.KnowledgeBaseController - -接口路径调整: -- Java: /knowledge-base/* -- Python: /rag/knowledge-base/* """ -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, BackgroundTasks from sqlalchemy.ext.asyncio import AsyncSession from app.core.exception import SuccessResponse @@ -24,6 +20,7 @@ PagingQuery, ) from app.module.rag.service.knowledge_base_service import KnowledgeBaseService +from app.module.rag.service.retrieval_service import RetrievalService router = APIRouter(prefix="/knowledge-base", tags=["知识库管理"]) @@ -31,19 +28,9 @@ @router.post("/create", response_model=SuccessResponse) async def create_knowledge_base( request: KnowledgeBaseCreateReq, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ): - """创建知识库 - - 对应 Java: POST /knowledge-base/create - - Args: - request: 知识库创建请求 - db: 数据库 session - - Returns: - 知识库 ID - """ + """创建知识库""" service = KnowledgeBaseService(db) knowledge_base_id = await service.create(request) return SuccessResponse(data=knowledge_base_id) @@ -53,17 +40,9 @@ async def create_knowledge_base( async def update_knowledge_base( knowledge_base_id: str, request: KnowledgeBaseUpdateReq, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ): - """更新知识库 - - 对应 Java: PUT /knowledge-base/{id} - - Args: - knowledge_base_id: 知识库 ID - request: 知识库更新请求 - db: 数据库 session - """ + """更新知识库""" service = KnowledgeBaseService(db) await service.update(knowledge_base_id, request) return SuccessResponse(message="知识库更新成功") @@ -72,16 +51,9 @@ async def update_knowledge_base( @router.delete("/{knowledge_base_id}", response_model=SuccessResponse) async def delete_knowledge_base( 
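chunks_to_documents above generates UUID ids when none are supplied, where the old helper used positional indices that could collide across files. A quick offline check, assuming this patch's package layout:

from app.module.rag.infra.document import DocumentChunk
from app.module.rag.infra.vectorstore import chunks_to_documents

chunks = [
    DocumentChunk(text="hello", metadata={"chunk_index": 0}),
    DocumentChunk(text="world", metadata={"chunk_index": 1}),
]
docs, ids = chunks_to_documents(chunks)
print(len(docs), len(set(ids)))  # 2 2 -- every chunk gets a distinct UUID string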
knowledge_base_id: str, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ): - """删除知识库 - - 对应 Java: DELETE /knowledge-base/{id} - - Args: - knowledge_base_id: 知识库 ID - db: 数据库 session - """ + """删除知识库""" service = KnowledgeBaseService(db) await service.delete(knowledge_base_id) return SuccessResponse(message="知识库删除成功") @@ -90,19 +62,9 @@ async def delete_knowledge_base( @router.get("/{knowledge_base_id}", response_model=SuccessResponse) async def get_knowledge_base( knowledge_base_id: str, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ): - """获取知识库详情 - - 对应 Java: GET /knowledge-base/{id} - - Args: - knowledge_base_id: 知识库 ID - db: 数据库 session - - Returns: - 知识库详情 - """ + """获取知识库详情""" service = KnowledgeBaseService(db) knowledge_base = await service.get_by_id(knowledge_base_id) return SuccessResponse(data=knowledge_base) @@ -111,19 +73,9 @@ async def get_knowledge_base( @router.post("/list", response_model=SuccessResponse) async def list_knowledge_bases( request: KnowledgeBaseQueryReq, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ): - """分页查询知识库列表 - - 对应 Java: POST /knowledge-base/list - - Args: - request: 查询请求 - db: 数据库 session - - Returns: - 知识库列表(分页) - """ + """分页查询知识库列表""" service = KnowledgeBaseService(db) result = await service.list(request) return SuccessResponse(data=result) @@ -133,27 +85,19 @@ async def list_knowledge_bases( async def add_files_to_knowledge_base( knowledge_base_id: str, request: AddFilesReq, - db: AsyncSession = Depends(get_db) + background_tasks: BackgroundTasks, + db: AsyncSession = Depends(get_db), ): """添加文件到知识库 - 对应 Java: POST /knowledge-base/{id}/files - - Args: - knowledge_base_id: 知识库 ID - request: 添加文件请求 - db: 数据库 session - - Returns: - 包含成功添加数量和跳过文件数量的响应 + 文件记录存入数据库后立即返回,后台异步处理文件。 """ - # 确保 knowledge_base_id 与 request 中的一致 request.knowledge_base_id = knowledge_base_id service = KnowledgeBaseService(db) - result = await service.add_files(request) + result = await service.add_files(request, background_tasks) - message = f"文件添加成功,正在后台处理" + message = f"文件添加成功,正在后台处理 {result['success_count']} 个文件" if result["skipped_count"] > 0: message = f"成功添加 {result['success_count']} 个文件,跳过 {result['skipped_count']} 个不存在的文件" @@ -162,8 +106,8 @@ async def add_files_to_knowledge_base( data={ "successCount": result["success_count"], "skippedCount": result["skipped_count"], - "skippedFileIds": result["skipped_file_ids"] - } + "skippedFileIds": result["skipped_file_ids"], + }, ) @@ -171,20 +115,9 @@ async def add_files_to_knowledge_base( async def list_knowledge_base_files( knowledge_base_id: str, request: RagFileReq = Depends(), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ): - """获取知识库文件列表 - - 对应 Java: GET /knowledge-base/{id}/files - - Args: - knowledge_base_id: 知识库 ID - request: 查询请求 - db: 数据库 session - - Returns: - 文件列表(分页) - """ + """获取知识库文件列表""" service = KnowledgeBaseService(db) result = await service.list_files(knowledge_base_id, request) return SuccessResponse(data=result) @@ -194,17 +127,9 @@ async def list_knowledge_base_files( async def delete_knowledge_base_files( knowledge_base_id: str, request: DeleteFilesReq, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ): - """删除知识库文件 - - 对应 Java: DELETE /knowledge-base/{id}/files - - Args: - knowledge_base_id: 知识库 ID - request: 删除文件请求 - db: 数据库 session - """ + """删除知识库文件""" service = KnowledgeBaseService(db) await service.delete_files(knowledge_base_id, request) return 
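A hedged client-side sketch of the add-files call above. The payload fields follow the AddFilesReq usage visible in this patch (files[].id, process_type, chunk_size, overlap_size, delimiter); the host, port, and /rag prefix are assumptions, and the real schema may use different field casing:

import httpx

payload = {
    "knowledge_base_id": "kb-123",      # overwritten server-side from the path parameter
    "files": [{"id": "dataset-file-1"}],
    "process_type": "DEFAULT_CHUNK",    # assumed serialization of the ProcessType enum
    "chunk_size": 500,
    "overlap_size": 50,
    "delimiter": None,
}
resp = httpx.post(
    "http://localhost:8000/rag/knowledge-base/kb-123/files",
    json=payload,
)
print(resp.json())  # expected data keys: successCount, skippedCount, skippedFileIds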
SuccessResponse(message="文件删除成功") @@ -215,22 +140,10 @@ async def get_file_chunks( knowledge_base_id: str, rag_file_id: str, paging_query: PagingQuery = Depends(), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ): - """获取指定 RAG 文件的分块列表 - - 对应 Java: GET /knowledge-base/{id}/files/{ragFileId} - - Args: - knowledge_base_id: 知识库 ID - rag_file_id: RAG 文件 ID - paging_query: 分页参数 - db: 数据库 session - - Returns: - 文件分块列表(分页) - """ - service = KnowledgeBaseService(db) + """获取指定 RAG 文件的分块列表""" + service = RetrievalService(db) result = await service.get_chunks(knowledge_base_id, rag_file_id, paging_query) return SuccessResponse(data=result) @@ -238,19 +151,9 @@ async def get_file_chunks( @router.post("/retrieve", response_model=SuccessResponse) async def retrieve_knowledge_base( request: RetrieveReq, - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ): - """检索知识库内容(向量 + BM25 混合检索) - - 对应 Java: POST /knowledge-base/retrieve - - Args: - request: 检索请求 - db: 数据库 session - - Returns: - 检索结果列表 - """ - service = KnowledgeBaseService(db) + """检索知识库内容(向量 + BM25 混合检索)""" + service = RetrievalService(db) results = await service.retrieve(request) return SuccessResponse(data=results) diff --git a/runtime/datamate-python/app/module/rag/schema/__init__.py b/runtime/datamate-python/app/module/rag/schema/__init__.py index 362bab907..e3670dc37 100644 --- a/runtime/datamate-python/app/module/rag/schema/__init__.py +++ b/runtime/datamate-python/app/module/rag/schema/__init__.py @@ -5,7 +5,7 @@ """ from .enums import ProcessType from app.db.models.knowledge_gen import RagType, FileStatus -from .entity import RagChunk +from .types import RagChunk from app.db.models.knowledge_gen import KnowledgeBase, RagFile from .request import ( KnowledgeBaseCreateReq, diff --git a/runtime/datamate-python/app/module/rag/schema/entity.py b/runtime/datamate-python/app/module/rag/schema/types.py similarity index 100% rename from runtime/datamate-python/app/module/rag/schema/entity.py rename to runtime/datamate-python/app/module/rag/schema/types.py diff --git a/runtime/datamate-python/app/module/rag/service/etl_service.py b/runtime/datamate-python/app/module/rag/service/etl_service.py deleted file mode 100644 index 18703dbf2..000000000 --- a/runtime/datamate-python/app/module/rag/service/etl_service.py +++ /dev/null @@ -1,286 +0,0 @@ -""" -ETL 服务 - -实现文件的异步 ETL 处理流程,使用 LangChain Milvus 向量存储(密集向量 + BM25 全文检索)。 -对应 Java: com.datamate.rag.indexer.infra.event.RagEtlService -""" -import uuid -from pathlib import Path -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus -from app.module.rag.schema.request import AddFilesReq -from app.module.rag.repository import RagFileRepository, KnowledgeBaseRepository -from app.module.rag.infra.pipeline import ingest_file_to_chunks -from app.module.rag.infra.embeddings import EmbeddingFactory -from app.module.rag.infra.milvus.factory import VectorStoreFactory -from app.module.rag.infra.milvus.vectorstore import ( - chunks_to_langchain_documents, - create_java_compatible_collection, - get_vector_dimension, -) -from app.module.system.service.common_service import get_model_by_id -from app.module.rag.infra.task.worker_pool import WorkerPool -from app.core.config import settings -from app.core.exception import BusinessError, ErrorCodes -from app.db.session import AsyncSessionLocal - -import logging -import asyncio - -logger = logging.getLogger(__name__) - - -class ETLService: - """RAG ETL 
服务类 - - 对应 Java: com.datamate.rag.indexer.infra.event.RagEtlService - - 替代 Java 方案: - - Java: @TransactionalEventListener + 虚拟线程 + 信号量 - - Python: asyncio + WorkerPool(信号量控制) - - 功能: - 1. 解析文档(从共享文件系统读取) - 2. 分块 - 3. 生成嵌入向量 - 4. 存储到 Milvus - 5. 更新文件状态 - """ - - def __init__(self, db: AsyncSession = None): - """初始化服务 - - Args: - db: 数据库异步 session(可选,后台任务时会创建新的) - """ - self.db = db - self.worker_pool = WorkerPool(max_workers=10) - - async def process_files_background( - self, - knowledge_base_id: str, - knowledge_base_name: str, - request_data: dict - ) -> None: - """后台处理文件的入口方法(使用新的数据库 session) - - 对应 Java 的 @TransactionalEventListener(phase = AFTER_COMMIT) + @Async - - Args: - knowledge_base_id: 知识库 ID - knowledge_base_name: 知识库名称 - request_data: 添加文件请求数据(dict 格式) - """ - # 创建新的数据库 session - async with AsyncSessionLocal() as db: - try: - file_repo = RagFileRepository(db) - kb_repo = KnowledgeBaseRepository(db) - - # 获取知识库实体 - knowledge_base = await kb_repo.get_by_id(knowledge_base_id) - if not knowledge_base: - logger.error(f"知识库不存在: {knowledge_base_id}") - return - - # 重建请求对象 - request = AddFilesReq.model_validate(request_data) - - # 获取待处理的文件 - files = await file_repo.get_unprocessed_files(knowledge_base_id) - - if not files: - logger.info(f"知识库 {knowledge_base_name} 没有待处理的文件") - return - - logger.info(f"开始处理 {len(files)} 个文件,知识库: {knowledge_base_name}") - - # 顺序处理文件(避免并发问题) - for file in files: - try: - await self._process_single_file_with_session( - db, file, knowledge_base, request - ) - except Exception as e: - logger.exception(f"文件 {file.file_name} 处理失败: {e}") - # 继续处理下一个文件 - - logger.info(f"知识库 {knowledge_base_name} 文件处理完成") - - except Exception as e: - logger.exception(f"后台处理文件失败: {e}") - finally: - await db.close() - - async def _process_single_file_with_session( - self, - db: AsyncSession, - rag_file: RagFile, - knowledge_base: KnowledgeBase, - request: AddFilesReq - ) -> None: - """处理单个文件的 ETL 流程(使用提供的 session) - - Args: - db: 数据库 session - rag_file: RAG 文件实体 - knowledge_base: 知识库实体 - request: 添加文件请求 - """ - file_repo = RagFileRepository(db) - - try: - # 1. 更新状态为处理中 - await file_repo.update_status(rag_file.id, FileStatus.PROCESSING) - await db.commit() - - # 2. 从 metadata 中获取文件路径和原始文件ID - file_path = rag_file.file_metadata.get("file_path") if rag_file.file_metadata else None - original_file_id = rag_file.file_id - dataset_id = rag_file.file_metadata.get("dataset_id") if rag_file.file_metadata else None - - # 2.1 验证文件路径 - if not file_path: - error_msg = f"文件路径未设置,file_metadata={rag_file.file_metadata}" - logger.error(error_msg) - await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=error_msg) - await db.commit() - return - - # 2.2 确保使用绝对路径 - import os - file_path = os.path.abspath(file_path) - - # 2.3 验证文件存在 - if not Path(file_path).exists(): - error_msg = f"文件不存在: {file_path}" - logger.error(error_msg) - await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=error_msg) - await db.commit() - return - - # 3. 准备完整的 metadata(不包含 file_path,避免与函数参数冲突) - file_extension = Path(file_path).suffix - base_metadata = { - "rag_file_id": rag_file.id, - "original_file_id": original_file_id, - "dataset_id": dataset_id, - "file_name": rag_file.file_name, - "file_extension": file_extension, - "knowledge_base_id": knowledge_base.id, - # file_path 不包含在此处,因为它作为位置参数传递 - } - - # 4. 
加载并分块 - try: - chunks = await ingest_file_to_chunks( - file_path, - process_type=request.process_type, - chunk_size=request.chunk_size, - overlap_size=request.overlap_size, - delimiter=request.delimiter, - **base_metadata - ) - except Exception as e: - error_msg = f"文档解析或分块失败: {str(e)}" - logger.exception(f"文件 {rag_file.file_name} 解析失败: {e}") - await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=error_msg) - await db.commit() - return - - if not chunks: - logger.warning(f"文件 {rag_file.file_name} 未生成任何分块") - await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg="文档解析后未生成任何分块") - await db.commit() - return - - logger.info(f"文件 {rag_file.file_name} 分块完成,共 {len(chunks)} 个分块") - - # 5. 写入 Milvus 向量存储 - try: - embedding_entity = await get_model_by_id(db, knowledge_base.embedding_model) - if not embedding_entity: - raise ValueError(f"嵌入模型不存在: {knowledge_base.embedding_model}") - - # 5. 获取向量维度并创建集合 - try: - dimension = get_vector_dimension( - embedding_model=embedding_entity.model_name, - base_url=getattr(embedding_entity, "base_url", None), - api_key=getattr(embedding_entity, "api_key", None), - ) - # 集合将由 VectorStoreFactory.create() 自动创建(如果已存在则删除) - except BusinessError as e: - logger.warning("获取向量维度失败: %s", e) - raise - - embedding = EmbeddingFactory.create_embeddings( - model_name=embedding_entity.model_name, - base_url=getattr(embedding_entity, "base_url", None), - api_key=getattr(embedding_entity, "api_key", None), - ) - vectorstore = VectorStoreFactory.create( - collection_name=knowledge_base.name, - embedding=embedding, - ) - for c in chunks: - for key, value in base_metadata.items(): - if key not in c.metadata: - c.metadata[key] = value - ids = [str(uuid.uuid4()) for _ in chunks] - documents, doc_ids = chunks_to_langchain_documents(chunks, ids=ids) - vectorstore.add_documents(documents=documents, ids=doc_ids) - - except Exception as e: - error_msg = f"向量化或存储到 Milvus 失败: {str(e)}" - logger.exception(f"文件 {rag_file.file_name} 向量化失败: {e}") - await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=error_msg) - await db.commit() - return - - # 6. 
更新文件状态为成功 - await file_repo.update_chunk_count(rag_file.id, len(chunks)) - await file_repo.update_status(rag_file.id, FileStatus.PROCESSED) - await db.commit() - - logger.info(f"文件 {rag_file.file_name} ETL 处理完成") - - except Exception as e: - logger.exception(f"文件 {rag_file.file_name} 处理失败: {e}") - await file_repo.update_status(rag_file.id, FileStatus.PROCESS_FAILED, err_msg=str(e)) - await db.commit() - - async def process_files( - self, - knowledge_base: KnowledgeBase, - request: AddFilesReq - ) -> None: - """处理文件的入口方法(在事务提交后调用)- 已废弃,使用 process_files_background - - 对应 Java 的 @TransactionalEventListener(phase = AFTER_COMMIT) - - Args: - knowledge_base: 知识库实体 - request: 添加文件请求 - """ - logger.warning("process_files is deprecated, use process_files_background instead") - # 这个方法保留用于兼容,但不推荐使用 - if not self.db: - logger.error("No database session available") - return - - file_repo = RagFileRepository(self.db) - files = await file_repo.get_unprocessed_files(knowledge_base.id) - - if not files: - logger.info(f"知识库 {knowledge_base.name} 没有待处理的文件") - return - - logger.info(f"开始处理 {len(files)} 个文件,知识库: {knowledge_base.name}") - - for file in files: - try: - await self._process_single_file_with_session(self.db, file, knowledge_base, request) - except Exception as e: - logger.exception(f"文件 {file.file_name} 处理失败: {e}") diff --git a/runtime/datamate-python/app/module/rag/service/file_processor.py b/runtime/datamate-python/app/module/rag/service/file_processor.py new file mode 100644 index 000000000..0db7e6006 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/service/file_processor.py @@ -0,0 +1,237 @@ +""" +文件处理器 + +负责文件的后台 ETL 处理:加载、分块、向量化、存储。 +使用全局 WorkerPool 实现并发控制,最多 10 个文件并行处理。 +""" +import logging +import os +import uuid +from pathlib import Path +from typing import List + +from fastapi import BackgroundTasks +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.exception import BusinessError, ErrorCodes +from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus +from app.db.session import AsyncSessionLocal +from app.module.rag.infra.document import ingest_file_to_chunks, DocumentChunk +from app.module.rag.infra.embeddings import EmbeddingFactory +from app.module.rag.infra.task.worker_pool import get_global_pool +from app.module.rag.infra.vectorstore import VectorStoreFactory, chunks_to_documents +from app.module.rag.repository import RagFileRepository, KnowledgeBaseRepository +from app.module.rag.schema.request import AddFilesReq +from app.module.system.service.common_service import get_model_by_id + +logger = logging.getLogger(__name__) + + +class FileProcessor: + """文件处理器 + + 负责文件的后台 ETL 处理,使用全局 WorkerPool 控制并发。 + """ + + def __init__(self): + """初始化处理器""" + self.worker_pool = get_global_pool(max_workers=10) + + def start_background_processing( + self, + background_tasks: BackgroundTasks, + knowledge_base_id: str, + knowledge_base_name: str, + request_data: dict, + ) -> None: + """启动后台文件处理 + + Args: + background_tasks: FastAPI BackgroundTasks + knowledge_base_id: 知识库 ID + knowledge_base_name: 知识库名称 + request_data: 添加文件请求数据 + """ + background_tasks.add_task( + self._process_files_background, + knowledge_base_id, + knowledge_base_name, + request_data, + ) + logger.info("已注册后台任务: 知识库=%s", knowledge_base_name) + + async def _process_files_background( + self, + knowledge_base_id: str, + knowledge_base_name: str, + request_data: dict, + ) -> None: + """后台处理文件(使用新的数据库 session)""" + async with AsyncSessionLocal() as db: + try: + kb_repo = 
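start_background_processing above defers the heavy ETL until after the HTTP response via FastAPI's BackgroundTasks. A minimal standalone sketch of that mechanism (hypothetical route, not part of this patch):

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

async def slow_job(name: str) -> None:
    # Runs only after the response has been sent to the client
    print(f"processing {name}")

@app.post("/jobs/{name}")
async def create_job(name: str, background_tasks: BackgroundTasks) -> dict:
    background_tasks.add_task(slow_job, name)
    return {"status": "accepted"}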
KnowledgeBaseRepository(db) + file_repo = RagFileRepository(db) + + knowledge_base = await kb_repo.get_by_id(knowledge_base_id) + if not knowledge_base: + logger.error("知识库不存在: %s", knowledge_base_id) + return + + request = AddFilesReq.model_validate(request_data) + files = await file_repo.get_unprocessed_files(knowledge_base_id) + + if not files: + logger.info("知识库 %s 没有待处理的文件", knowledge_base_name) + return + + logger.info("开始处理 %d 个文件,知识库: %s", len(files), knowledge_base_name) + + # 并发处理文件 + await self._process_files_concurrently(files, knowledge_base, request) + + logger.info("知识库 %s 文件处理完成", knowledge_base_name) + + except Exception as e: + logger.exception("后台处理文件失败: %s", e) + finally: + await db.close() + + async def _process_files_concurrently( + self, + files: List[RagFile], + knowledge_base: KnowledgeBase, + request: AddFilesReq, + ) -> None: + """并发处理多个文件(最多10个并行) + + AsyncSession 不支持并发复用,每个文件使用独立的数据库 session。 + """ + import asyncio + + async def process_with_semaphore(rag_file: RagFile): + async with self.worker_pool.semaphore: + async with AsyncSessionLocal() as file_db: + await self._process_single_file(file_db, rag_file, knowledge_base, request) + + tasks = [process_with_semaphore(f) for f in files] + await asyncio.gather(*tasks, return_exceptions=True) + + async def _process_single_file( + self, + db: AsyncSession, + rag_file: RagFile, + knowledge_base: KnowledgeBase, + request: AddFilesReq, + ) -> None: + """处理单个文件的 ETL 流程""" + file_repo = RagFileRepository(db) + + try: + # 1. 更新状态为处理中 + await file_repo.update_status(rag_file.id, FileStatus.PROCESSING) + await db.commit() + + # 2. 验证文件 + file_path = self._get_file_path(rag_file) + if not file_path: + await self._mark_failed(db, file_repo, rag_file.id, "文件路径未设置") + return + + if not Path(file_path).exists(): + await self._mark_failed(db, file_repo, rag_file.id, f"文件不存在: {file_path}") + return + + # 3. 加载并分块 + metadata = self._build_chunk_metadata(rag_file, knowledge_base) + chunks = await ingest_file_to_chunks( + file_path, + process_type=request.process_type, + chunk_size=request.chunk_size, + overlap_size=request.overlap_size, + delimiter=request.delimiter, + **metadata, + ) + + if not chunks: + await self._mark_failed(db, file_repo, rag_file.id, "文档解析后未生成任何分块") + return + + logger.info("文件 %s 分块完成,共 %d 个分块", rag_file.file_name, len(chunks)) + + # 4. 向量化并存储 + await self._embed_and_store(db, chunks, metadata, knowledge_base) + + # 5. 
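_process_files_concurrently leans on asyncio.gather(..., return_exceptions=True) so one failed file cannot cancel its siblings: exceptions come back as values in the result list instead of propagating. A two-coroutine demonstration:

import asyncio

async def ok() -> str:
    return "done"

async def boom() -> str:
    raise RuntimeError("failed file")

async def main() -> None:
    results = await asyncio.gather(ok(), boom(), return_exceptions=True)
    print(results)  # ['done', RuntimeError('failed file')]

asyncio.run(main())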
更新文件状态为成功 + await file_repo.update_chunk_count(rag_file.id, len(chunks)) + await file_repo.update_status(rag_file.id, FileStatus.PROCESSED) + await db.commit() + + logger.info("文件 %s ETL 处理完成", rag_file.file_name) + + except Exception as e: + logger.exception("文件 %s 处理失败: %s", rag_file.file_name, e) + await self._mark_failed(db, file_repo, rag_file.id, str(e)) + + def _get_file_path(self, rag_file: RagFile) -> str | None: + """获取文件绝对路径""" + if not rag_file.file_metadata: + return None + + file_path = rag_file.file_metadata.get("file_path") + if file_path: + return os.path.abspath(file_path) + return None + + def _build_chunk_metadata(self, rag_file: RagFile, knowledge_base: KnowledgeBase) -> dict: + """构建分块元数据""" + file_path = self._get_file_path(rag_file) or "" + return { + "rag_file_id": rag_file.id, + "original_file_id": rag_file.file_id, + "dataset_id": rag_file.file_metadata.get("dataset_id") if rag_file.file_metadata else None, + "file_name": rag_file.file_name, + "file_extension": Path(file_path).suffix, + "knowledge_base_id": knowledge_base.id, + } + + async def _embed_and_store( + self, + db: AsyncSession, + chunks: List[DocumentChunk], + metadata: dict, + knowledge_base: KnowledgeBase, + ) -> None: + """向量化并存储到 Milvus""" + embedding_entity = await get_model_by_id(db, knowledge_base.embedding_model) + if not embedding_entity: + raise ValueError(f"嵌入模型不存在: {knowledge_base.embedding_model}") + + embedding = EmbeddingFactory.create_embeddings( + model_name=embedding_entity.model_name, + base_url=getattr(embedding_entity, "base_url", None), + api_key=getattr(embedding_entity, "api_key", None), + ) + + vectorstore = VectorStoreFactory.create( + collection_name=knowledge_base.name, + embedding=embedding, + ) + + # 补充 metadata + for chunk in chunks: + for key, value in metadata.items(): + if key not in chunk.metadata: + chunk.metadata[key] = value + + ids = [str(uuid.uuid4()) for _ in chunks] + documents, doc_ids = chunks_to_documents(chunks, ids=ids) + vectorstore.add_documents(documents=documents, ids=doc_ids) + + async def _mark_failed( + self, + db: AsyncSession, + file_repo: RagFileRepository, + rag_file_id: str, + err_msg: str, + ) -> None: + """标记文件处理失败""" + logger.error("文件处理失败: %s", err_msg) + await file_repo.update_status(rag_file_id, FileStatus.PROCESS_FAILED, err_msg=err_msg) + await db.commit() diff --git a/runtime/datamate-python/app/module/rag/service/file_service.py b/runtime/datamate-python/app/module/rag/service/file_service.py deleted file mode 100644 index a2c50584c..000000000 --- a/runtime/datamate-python/app/module/rag/service/file_service.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -文件管理服务 - -实现文件相关的业务逻辑 -""" -import uuid -from typing import List, Tuple -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select - -from app.db.models.knowledge_gen import RagFile, FileStatus -from app.db.models.dataset_management import DatasetFiles -from app.module.rag.schema.request import AddFilesReq -from app.module.rag.repository import RagFileRepository, KnowledgeBaseRepository -from app.module.rag.infra.milvus.vectorstore import delete_chunks_by_rag_file_ids -from app.core.exception import BusinessError, ErrorCodes - -import logging - -logger = logging.getLogger(__name__) - - -class FileService: - """文件管理服务类 - - 功能: - 1. 添加文件到知识库 - 2. 删除文件 - 3. 
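_embed_and_store above back-fills base metadata without overwriting keys the splitter already set. dict.setdefault expresses the same "if key not in" loop more directly; a tiny runnable check:

chunk_metadata = {"chunk_index": 3, "file_name": "a.txt"}
base = {"file_name": "ignored.txt", "knowledge_base_id": "kb-123"}

for key, value in base.items():
    chunk_metadata.setdefault(key, value)  # keeps existing keys, adds missing ones

print(chunk_metadata)
# {'chunk_index': 3, 'file_name': 'a.txt', 'knowledge_base_id': 'kb-123'}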
查询文件 - """ - - def __init__(self, db: AsyncSession): - """初始化服务 - - Args: - db: 数据库异步 session - """ - self.db = db - self.file_repo = RagFileRepository(db) - self.kb_repo = KnowledgeBaseRepository(db) - - async def add_files(self, request: AddFilesReq) -> Tuple[List[RagFile], List[str]]: - """添加文件到知识库 - - Args: - request: 添加文件请求 - - Returns: - (创建的 RAG 文件列表, 跳过的文件ID列表) - - Raises: - BusinessError: 知识库不存在 - """ - # 验证知识库存在 - knowledge_base = await self.kb_repo.get_by_id(request.knowledge_base_id) - if not knowledge_base: - raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) - - # 验证文件列表不为空 - if not request.files or len(request.files) == 0: - raise BusinessError(ErrorCodes.BAD_REQUEST, "文件列表不能为空") - - # 验证文件存在并创建 RAG 文件记录 - rag_files = [] - skipped_file_ids = [] - - for file_info in request.files: - try: - # 根据 file_info.id (DatasetFile ID) 查询文件信息 - result = await self.db.execute( - select(DatasetFiles).where(DatasetFiles.id == file_info.id) - ) - dataset_file = result.scalar_one_or_none() - - # 跳过不存在的文件 - if not dataset_file: - logger.warning( - f"文件不存在,跳过处理: file_id={file_info.id}" - ) - skipped_file_ids.append(file_info.id) - continue - - # 创建 RAG 文件记录,存储 dataset_id 和 file_path 到 metadata - rag_file = RagFile( - id=str(uuid.uuid4()), - knowledge_base_id=request.knowledge_base_id, - file_name=dataset_file.file_name, - file_id=file_info.id, - file_metadata={ - "process_type": request.process_type.value, - "dataset_id": dataset_file.dataset_id, - "file_path": dataset_file.file_path - }, - status=FileStatus.UNPROCESSED, - ) - rag_files.append(rag_file) - - except Exception as e: - logger.error( - f"处理文件信息失败: file_id={file_info.id}, error={e}" - ) - skipped_file_ids.append(file_info.id) - continue - - # 批量保存 - if rag_files: - await self.file_repo.batch_create(rag_files) - logger.info(f"成功添加 {len(rag_files)} 个文件到知识库: {knowledge_base.name}") - - if skipped_file_ids: - logger.warning(f"跳过 {len(skipped_file_ids)} 个文件: {skipped_file_ids}") - - return rag_files, skipped_file_ids - - async def delete_files( - self, - knowledge_base_id: str, - file_ids: List[str] - ) -> None: - """删除文件 - - Args: - knowledge_base_id: 知识库 ID - file_ids: 文件 ID 列表 - - Raises: - BusinessError: 知识库不存在 - """ - # 验证文件列表不为空 - if not file_ids or len(file_ids) == 0: - raise BusinessError(ErrorCodes.BAD_REQUEST, "文件ID列表不能为空") - - # 获取知识库 - knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) - if not knowledge_base: - raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) - - # 获取文件列表(需要删除 Milvus 数据) - rag_files = [] - for file_id in file_ids: - try: - rag_file = await self.file_repo.get_by_id(file_id) - if rag_file: - rag_files.append(rag_file) - else: - logger.warning(f"文件不存在,跳过删除: {file_id}") - except Exception as e: - logger.error(f"查询文件失败: {file_id}, error={e}") - continue - - # 删除 Milvus 中该文件对应的分块数据 - if rag_files: - try: - delete_chunks_by_rag_file_ids( - knowledge_base.name, - [r.id for r in rag_files], - ) - except Exception as e: - logger.error("删除 Milvus 数据失败: %s", e) - # 继续删除数据库记录 - else: - logger.warning("没有找到有效的文件,跳过 Milvus 数据删除") - - # 删除数据库记录 - deleted_count = 0 - for file_id in file_ids: - try: - await self.file_repo.delete(file_id) - deleted_count += 1 - except Exception as e: - logger.error(f"删除数据库记录失败: {file_id}, error={e}") - continue - - logger.info(f"成功删除 {deleted_count} 个文件") diff --git a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py index 60d936490..8aa19881e 100644 --- 
a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py +++ b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py @@ -1,20 +1,21 @@ """ 知识库业务服务 -实现知识库的 CRUD 操作和业务逻辑 +实现知识库的 CRUD 操作和文件管理。 对应 Java: com.datamate.rag.indexer.application.KnowledgeBaseService """ import logging import uuid -from typing import List +from typing import List, Tuple +from fastapi import BackgroundTasks +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.exception import BusinessError, ErrorCodes -from app.core.config import settings -from app.module.rag.infra.milvus.vectorstore import drop_collection, rename_collection -from app.module.rag.infra.embeddings import EmbeddingFactory -from app.db.models.knowledge_gen import KnowledgeBase +from app.db.models.dataset_management import DatasetFiles +from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus +from app.module.rag.infra.vectorstore import drop_collection, rename_collection, delete_chunks_by_rag_file_ids from app.module.rag.repository import KnowledgeBaseRepository, RagFileRepository from app.module.rag.schema.request import ( KnowledgeBaseCreateReq, @@ -23,12 +24,9 @@ AddFilesReq, DeleteFilesReq, RagFileReq, - RetrieveReq, - PagingQuery, ) -from app.module.rag.schema.response import KnowledgeBaseResp, PagedResponse, RagChunkResp, RagFileResp -from app.module.rag.service.etl_service import ETLService -from app.module.rag.service.file_service import FileService +from app.module.rag.schema.response import KnowledgeBaseResp, PagedResponse, RagFileResp +from app.module.rag.service.file_processor import FileProcessor logger = logging.getLogger(__name__) @@ -36,12 +34,9 @@ class KnowledgeBaseService: """知识库业务服务类 - 对应 Java: com.datamate.rag.indexer.application.KnowledgeBaseService - 功能: 1. 知识库 CRUD 操作 - 2. 文件管理 - 3. 检索功能 + 2. 
文件管理(添加、删除、查询) """ def __init__(self, db: AsyncSession): @@ -53,71 +48,51 @@ def __init__(self, db: AsyncSession): self.db = db self.kb_repo = KnowledgeBaseRepository(db) self.file_repo = RagFileRepository(db) - self.file_service = FileService(db) - self.etl_service = ETLService(db) + self.file_processor = FileProcessor() + + # ==================== 知识库 CRUD ==================== async def create(self, request: KnowledgeBaseCreateReq) -> str: """创建知识库 - 对应 Java: create 方法 - Args: request: 创建请求 Returns: 知识库 ID - - Raises: - BusinessError: 知识库名称已存在 """ - # 创建知识库实体 knowledge_base = KnowledgeBase( id=str(uuid.uuid4()), name=request.name, description=request.description, type=request.type, embedding_model=request.embedding_model, - chat_model=request.chat_model + chat_model=request.chat_model, ) - # 保存到数据库 knowledge_base = await self.kb_repo.create(knowledge_base) - - # Milvus 集合由 LangChain Milvus 在首次 ETL add_documents 时自动创建(含 BM25 全文检索) - logger.info(f"成功创建知识库: {request.name}") - - # 提交事务 await self.db.commit() + logger.info("成功创建知识库: %s", request.name) return knowledge_base.id async def update(self, knowledge_base_id: str, request: KnowledgeBaseUpdateReq) -> None: """更新知识库 - 对应 Java: update 方法 - Args: knowledge_base_id: 知识库 ID request: 更新请求 - - Raises: - BusinessError: 知识库不存在 """ - # 获取现有知识库 knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) if not knowledge_base: raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) old_name = knowledge_base.name - - # 更新字段 knowledge_base.name = request.name knowledge_base.description = request.description - # 更新数据库 await self.kb_repo.update(knowledge_base) - # 如果名称变更,重命名 Milvus 集合 if old_name != request.name: try: rename_collection(old_name, request.name) @@ -130,26 +105,16 @@ async def update(self, knowledge_base_id: str, request: KnowledgeBaseUpdateReq) async def delete(self, knowledge_base_id: str) -> None: """删除知识库 - 对应 Java: delete 方法 - Args: knowledge_base_id: 知识库 ID - - Raises: - BusinessError: 知识库不存在 """ - # 获取知识库 knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) if not knowledge_base: raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) - # 删除所有文件 await self.file_repo.delete_by_knowledge_base(knowledge_base_id) - - # 删除知识库 await self.kb_repo.delete(knowledge_base_id) - # 删除 Milvus 集合 try: drop_collection(knowledge_base.name) except Exception as e: @@ -160,26 +125,20 @@ async def delete(self, knowledge_base_id: str) -> None: async def get_by_id(self, knowledge_base_id: str) -> KnowledgeBaseResp: """获取知识库详情 - 对应 Java: getById 方法 - Args: knowledge_base_id: 知识库 ID Returns: 知识库响应对象 - - Raises: - BusinessError: 知识库不存在 """ knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) if not knowledge_base: raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) - # 统计文件数量 file_count = await self.file_repo.count_by_knowledge_base(knowledge_base_id) chunk_count = await self.file_repo.count_chunks_by_knowledge_base(knowledge_base_id) - response = KnowledgeBaseResp( + return KnowledgeBaseResp( id=knowledge_base.id, name=knowledge_base.name, description=knowledge_base.description, @@ -191,16 +150,12 @@ async def get_by_id(self, knowledge_base_id: str) -> KnowledgeBaseResp: created_at=knowledge_base.created_at, updated_at=knowledge_base.updated_at, created_by=knowledge_base.created_by, - updated_by=knowledge_base.updated_by + updated_by=knowledge_base.updated_by, ) - return response - async def list(self, request: KnowledgeBaseQueryReq) -> PagedResponse: """分页查询知识库列表 - 对应 Java: list 方法 - Args: request: 查询请求 @@ 
-211,16 +166,14 @@ async def list(self, request: KnowledgeBaseQueryReq) -> PagedResponse: keyword=request.keyword, rag_type=request.type, page=request.page, - page_size=request.page_size + page_size=request.page_size, ) - # 转换为响应对象 responses = [] for item in items: file_count = await self.file_repo.count_by_knowledge_base(item.id) chunk_count = await self.file_repo.count_chunks_by_knowledge_base(item.id) - - response = KnowledgeBaseResp( + responses.append(KnowledgeBaseResp( id=item.id, name=item.name, description=item.description, @@ -232,60 +185,104 @@ async def list(self, request: KnowledgeBaseQueryReq) -> PagedResponse: created_at=item.created_at, updated_at=item.updated_at, created_by=item.created_by, - updated_by=item.updated_by - ) - responses.append(response) + updated_by=item.updated_by, + )) return PagedResponse.create( content=responses, total_elements=total, page=request.page, - size=request.page_size + size=request.page_size, ) - async def add_files(self, request: AddFilesReq) -> dict: + # ==================== 文件管理 ==================== + + async def add_files( + self, + request: AddFilesReq, + background_tasks: BackgroundTasks = None, + ) -> dict: """添加文件到知识库 - 对应 Java: addFiles 方法 + 存入数据库后立即返回,后台异步处理文件。 Args: request: 添加文件请求 + background_tasks: FastAPI 后台任务 Returns: 包含成功和跳过文件数量的字典 - - Raises: - BusinessError: 知识库不存在 """ - # 验证知识库存在 knowledge_base = await self.kb_repo.get_by_id(request.knowledge_base_id) if not knowledge_base: raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) # 添加文件记录 - rag_files, skipped_file_ids = await self.file_service.add_files(request) + rag_files, skipped_file_ids = await self._create_rag_files(request) - # 提交事务后触发 ETL 处理 await self.db.commit() - # 异步处理文件(在事务提交后) - if rag_files: - await self.etl_service.process_files(knowledge_base, request) + # 启动后台处理 + if rag_files and background_tasks: + self.file_processor.start_background_processing( + background_tasks=background_tasks, + knowledge_base_id=knowledge_base.id, + knowledge_base_name=knowledge_base.name, + request_data=request.model_dump(), + ) return { "success_count": len(rag_files), "skipped_count": len(skipped_file_ids), - "skipped_file_ids": skipped_file_ids + "skipped_file_ids": skipped_file_ids, } - async def list_files( - self, - knowledge_base_id: str, - request: RagFileReq - ) -> PagedResponse: - """获取知识库文件列表 + async def _create_rag_files(self, request: AddFilesReq) -> Tuple[List[RagFile], List[str]]: + """创建 RAG 文件记录""" + if not request.files: + raise BusinessError(ErrorCodes.BAD_REQUEST, "文件列表不能为空") - 对应 Java: listFiles 方法 + rag_files = [] + skipped_file_ids = [] + + for file_info in request.files: + try: + result = await self.db.execute( + select(DatasetFiles).where(DatasetFiles.id == file_info.id) + ) + dataset_file = result.scalar_one_or_none() + + if not dataset_file: + logger.warning("文件不存在,跳过: file_id=%s", file_info.id) + skipped_file_ids.append(file_info.id) + continue + + rag_file = RagFile( + id=str(uuid.uuid4()), + knowledge_base_id=request.knowledge_base_id, + file_name=dataset_file.file_name, + file_id=file_info.id, + file_metadata={ + "process_type": request.process_type.value, + "dataset_id": dataset_file.dataset_id, + "file_path": dataset_file.file_path, + }, + status=FileStatus.UNPROCESSED, + ) + rag_files.append(rag_file) + + except Exception as e: + logger.error("处理文件信息失败: file_id=%s, error=%s", file_info.id, e) + skipped_file_ids.append(file_info.id) + + if rag_files: + await self.file_repo.batch_create(rag_files) + logger.info("成功添加 %d 个文件到知识库", 
len(rag_files)) + + return rag_files, skipped_file_ids + + async def list_files(self, knowledge_base_id: str, request: RagFileReq) -> PagedResponse: + """获取知识库文件列表 Args: knowledge_base_id: 知识库 ID @@ -293,11 +290,7 @@ async def list_files( Returns: 分页响应 - - Raises: - BusinessError: 知识库不存在 """ - # 验证知识库存在 if not await self.kb_repo.get_by_id(knowledge_base_id): raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) @@ -306,10 +299,9 @@ async def list_files( keyword=request.keyword, status=request.status, page=request.page, - page_size=request.page_size + page_size=request.page_size, ) - # 转换为响应对象 responses = [RagFileResp( id=item.id, knowledge_base_id=item.knowledge_base_id, @@ -322,237 +314,53 @@ async def list_files( created_at=item.created_at, updated_at=item.updated_at, created_by=item.created_by, - updated_by=item.updated_by + updated_by=item.updated_by, ) for item in items] return PagedResponse.create( content=responses, total_elements=total, page=request.page, - size=request.page_size + size=request.page_size, ) async def delete_files(self, knowledge_base_id: str, request: DeleteFilesReq) -> None: """删除知识库文件 - 对应 Java: deleteFiles 方法 - Args: knowledge_base_id: 知识库 ID request: 删除文件请求 - - Raises: - BusinessError: 知识库不存在 - """ - # 验证知识库存在 - knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) - if not knowledge_base: - raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) - - # 删除文件(包括 Milvus 数据) - await self.file_service.delete_files(knowledge_base_id, request.file_ids) - - await self.db.commit() - - async def retrieve(self, request: RetrieveReq) -> List[dict]: - """检索知识库内容 - - 对应 Java: retrieve 方法 - - 使用混合检索(向量 + BM25) - - Args: - request: 检索请求 - - Returns: - 检索结果列表 - - Raises: - BusinessError: 知识库不存在 """ - import asyncio - - # 1. 验证所有知识库存在 - knowledge_bases = [] - for kb_id in request.knowledge_base_ids: - kb = await self.kb_repo.get_by_id(kb_id) - if not kb: - raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) - knowledge_bases.append(kb) - - # 2. 获取嵌入模型(使用第一个知识库的配置) - embedding_entity = await get_model_by_id(self.db, knowledge_bases[0].embedding_model) - if not embedding_entity: - raise BusinessError(ErrorCodes.RAG_MODEL_NOT_FOUND) - - # 3. 创建嵌入模型实例 - embedding = EmbeddingFactory.create_embeddings( - model_name=embedding_entity.model_name, - base_url=getattr(embedding_entity, "base_url", None), - api_key=getattr(embedding_entity, "api_key", None), - ) - - # 4. 生成查询向量 - try: - query_vector = await asyncio.to_thread(embedding.embed_query, request.query) - except Exception as e: - logger.error(f"Failed to embed query: {e}") - raise BusinessError(ErrorCodes.RAG_EMBEDDING_FAILED, f"查询向量化失败: {str(e)}") from e - - # 5. 
执行混合检索(向量 + BM25) - from pymilvus import MilvusClient - - all_results = [] - - try: - client = MilvusClient(uri=settings.milvus_uri) - - for kb in knowledge_bases: - try: - # 检查集合是否存在 - if not client.has_collection(kb.name): - logger.warning(f"Collection {kb.name} does not exist, skipping") - continue - - # 混合检索:密集向量 + 稀疏向量(BM25) - search_results = client.hybrid_search( - collection_name=kb.name, - data=[ - { - "vector": query_vector, - "sparse": request.query - } - ], - anns_field=["vector", "sparse"], - limit=request.top_k, - ranker={ - "type": "weighted", - "weights": [0.1, 0.9] # 10% 向量相似度,90% BM25 关键词匹配 - } - ) - - # 提取结果 - if search_results and len(search_results) > 0: - for result in search_results[0]: - result["knowledge_base_id"] = kb.id - result["knowledge_base_name"] = kb.name - all_results.append(result) - - except Exception as e: - logger.error(f"Hybrid search failed for kb {kb.name}: {e}") - # 继续处理其他知识库 - continue - - except Exception as e: - logger.error(f"Milvus client initialization or search failed: {e}") - raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"检索失败: {str(e)}") from e - - # 6. 按分数降序排序 - all_results.sort(key=lambda x: x.get("distance", 0), reverse=True) - - # 7. 应用阈值过滤 - if request.threshold is not None: - all_results = [r for r in all_results if r.get("distance", 0) >= request.threshold] - - # 8. 格式化返回结果 - formatted_results = [] - for r in all_results: - entity = r.get("entity", {}) - formatted_results.append({ - "id": entity.get("id", ""), - "text": entity.get("text", ""), - "metadata": entity.get("metadata", {}), - "score": r.get("distance", 0), - "knowledgeBaseId": r.get("knowledge_base_id", ""), - "knowledgeBaseName": r.get("knowledge_base_name", "") - }) - - logger.info(f"Retrieve completed: query='{request.query}' results={len(formatted_results)}") - return formatted_results - - async def get_chunks( - self, - knowledge_base_id: str, - rag_file_id: str, - paging_query: PagingQuery - ) -> PagedResponse: - """获取指定 RAG 文件的分块列表 - - 对应 Java: getChunks 方法 - - 从 Milvus 查询指定 rag_file_id 的分块,支持分页。 - - Args: - knowledge_base_id: 知识库 ID - rag_file_id: RAG 文件 ID - paging_query: 分页参数 - - Returns: - 分块列表(分页) - """ - # 验证知识库存在 knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) if not knowledge_base: raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) - # 验证文件存在 - rag_file = await self.file_repo.get_by_id(rag_file_id) - if not rag_file: - raise BusinessError(ErrorCodes.RAG_FILE_NOT_FOUND) - - # 使用 MilvusClient 查询指定文件的分块 - from pymilvus import MilvusClient + if not request.file_ids: + raise BusinessError(ErrorCodes.BAD_REQUEST, "文件ID列表不能为空") - from app.core.exception import BusinessError as BE, ErrorCodes as EC + # 获取文件列表 + rag_files = [] + for file_id in request.file_ids: + rag_file = await self.file_repo.get_by_id(file_id) + if rag_file: + rag_files.append(rag_file) - try: - conn_args = settings.milvus_uri - token = getattr(settings, "milvus_token", None) - client = MilvusClient(uri=conn_args, token=token) - - # 查询总数 - count_filter_expr = f'metadata["rag_file_id"] == "{rag_file_id}"' - count_res = client.query( - collection_name=knowledge_base.name, - filter=count_filter_expr, - output_fields=["id"] - ) - total = len(count_res) - - # 查询分页数据 - offset = (paging_query.page - 1) * paging_query.size - filter_expr = f'metadata["rag_file_id"] == "{rag_file_id}"' - results = client.query( - collection_name=knowledge_base.name, - filter=filter_expr, - output_fields=["id", "text", "metadata"], - limit=paging_query.size, - offset=offset - ) - - # 转换为 
RagChunkResp - chunks = [] - for item in results: - chunks.append(RagChunkResp( - id=item.get("id", ""), - text=item.get("text", ""), - metadata=item.get("metadata", {}), - score=0.0 # 非相似度查询,默认分数为 0 - )) - - logger.info( - "查询文件分块成功: kb=%s file=%s total=%d page=%d size=%d", - knowledge_base_id, rag_file_id, total, paging_query.page, paging_query.size - ) - - return PagedResponse.create( - content=chunks, - total_elements=total, - page=paging_query.page, - size=paging_query.size - ) - - except Exception as e: - logger.error("查询文件分块失败: kb=%s file=%s error=%s", knowledge_base_id, rag_file_id, e) - raise BE(EC.RAG_MILVUS_ERROR, f"查询文件分块失败: {str(e)}") from e + # 删除 Milvus 数据 + if rag_files: + try: + delete_chunks_by_rag_file_ids( + knowledge_base.name, + [r.id for r in rag_files], + ) + except Exception as e: + logger.error("删除 Milvus 数据失败: %s", e) + + # 删除数据库记录 + for file_id in request.file_ids: + try: + await self.file_repo.delete(file_id) + except Exception as e: + logger.error("删除数据库记录失败: %s, error=%s", file_id, e) + await self.db.commit() + logger.info("成功删除 %d 个文件", len(rag_files)) diff --git a/runtime/datamate-python/app/module/rag/service/rag_service.py b/runtime/datamate-python/app/module/rag/service/rag_service.py index 75c11402a..2773f66a4 100644 --- a/runtime/datamate-python/app/module/rag/service/rag_service.py +++ b/runtime/datamate-python/app/module/rag/service/rag_service.py @@ -13,7 +13,7 @@ from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus from app.db.session import get_db, AsyncSessionLocal from app.module.rag.infra.embeddings import EmbeddingFactory -from app.module.rag.infra.milvus.factory import VectorStoreFactory +from app.module.rag.infra.vectorstore import VectorStoreFactory from app.module.shared.common.document_loaders import load_documents from .graph_rag import ( DEFAULT_WORKING_DIR, diff --git a/runtime/datamate-python/app/module/rag/service/retrieval_service.py b/runtime/datamate-python/app/module/rag/service/retrieval_service.py new file mode 100644 index 000000000..58d573231 --- /dev/null +++ b/runtime/datamate-python/app/module/rag/service/retrieval_service.py @@ -0,0 +1,218 @@ +""" +检索服务 + +负责知识库内容的检索,支持向量 + BM25 混合检索。 +""" +import logging +from typing import List + +from pymilvus import MilvusClient +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import settings +from app.core.exception import BusinessError, ErrorCodes +from app.module.rag.infra.embeddings import EmbeddingFactory +from app.module.rag.repository import KnowledgeBaseRepository, RagFileRepository +from app.module.rag.schema.request import RetrieveReq, PagingQuery +from app.module.rag.schema.response import PagedResponse, RagChunkResp +from app.module.system.service.common_service import get_model_by_id + +logger = logging.getLogger(__name__) + + +class RetrievalService: + """检索服务类 + + 提供知识库内容的混合检索(向量 + BM25)和分块查询功能。 + """ + + def __init__(self, db: AsyncSession): + """初始化服务 + + Args: + db: 数据库异步 session + """ + self.db = db + self.kb_repo = KnowledgeBaseRepository(db) + self.file_repo = RagFileRepository(db) + + async def retrieve(self, request: RetrieveReq) -> List[dict]: + """检索知识库内容(混合检索:向量 + BM25) + + Args: + request: 检索请求 + + Returns: + 检索结果列表 + + Raises: + BusinessError: 知识库不存在或嵌入模型不存在 + """ + import asyncio + + # 验证所有知识库存在 + knowledge_bases = [] + for kb_id in request.knowledge_base_ids: + kb = await self.kb_repo.get_by_id(kb_id) + if not kb: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + knowledge_bases.append(kb) + + # 
获取嵌入模型 + embedding_entity = await get_model_by_id(self.db, knowledge_bases[0].embedding_model) + if not embedding_entity: + raise BusinessError(ErrorCodes.RAG_MODEL_NOT_FOUND) + + # 创建嵌入模型实例 + embedding = EmbeddingFactory.create_embeddings( + model_name=embedding_entity.model_name, + base_url=getattr(embedding_entity, "base_url", None), + api_key=getattr(embedding_entity, "api_key", None), + ) + + # 生成查询向量 + try: + query_vector = await asyncio.to_thread(embedding.embed_query, request.query) + except Exception as e: + logger.error("查询向量化失败: %s", e) + raise BusinessError(ErrorCodes.RAG_EMBEDDING_FAILED, f"查询向量化失败: {str(e)}") from e + + # 执行混合检索 + all_results = await self._execute_hybrid_search(knowledge_bases, query_vector, request.query, request.top_k) + + # 按分数排序 + all_results.sort(key=lambda x: x.get("distance", 0), reverse=True) + + # 应用阈值过滤 + if request.threshold is not None: + all_results = [r for r in all_results if r.get("distance", 0) >= request.threshold] + + # 格式化返回结果 + return self._format_results(all_results) + + async def _execute_hybrid_search( + self, + knowledge_bases: list, + query_vector: list, + query_text: str, + top_k: int, + ) -> List[dict]: + """执行混合检索""" + all_results = [] + token = getattr(settings, "milvus_token", None) + client = MilvusClient(uri=settings.milvus_uri, token=token or "") + + for kb in knowledge_bases: + try: + if not client.has_collection(kb.name): + logger.warning("集合 %s 不存在,跳过", kb.name) + continue + + search_results = client.hybrid_search( + collection_name=kb.name, + data=[{"vector": query_vector, "sparse": query_text}], + anns_field=["vector", "sparse"], + limit=top_k, + ranker={"type": "weighted", "weights": [0.1, 0.9]}, + ) + + if search_results and len(search_results) > 0: + for result in search_results[0]: + result["knowledge_base_id"] = kb.id + result["knowledge_base_name"] = kb.name + all_results.append(result) + + except Exception as e: + logger.error("知识库 %s 混合检索失败: %s", kb.name, e) + continue + + return all_results + + def _format_results(self, all_results: List[dict]) -> List[dict]: + """格式化检索结果""" + formatted = [] + for r in all_results: + entity = r.get("entity", {}) + formatted.append({ + "id": entity.get("id", ""), + "text": entity.get("text", ""), + "metadata": entity.get("metadata", {}), + "score": r.get("distance", 0), + "knowledgeBaseId": r.get("knowledge_base_id", ""), + "knowledgeBaseName": r.get("knowledge_base_name", ""), + }) + + logger.info("检索完成: 结果数=%d", len(formatted)) + return formatted + + async def get_chunks( + self, + knowledge_base_id: str, + rag_file_id: str, + paging_query: PagingQuery, + ) -> PagedResponse: + """获取指定 RAG 文件的分块列表 + + Args: + knowledge_base_id: 知识库 ID + rag_file_id: RAG 文件 ID + paging_query: 分页参数 + + Returns: + 分块列表(分页) + """ + # 验证知识库和文件存在 + knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id) + if not knowledge_base: + raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND) + + rag_file = await self.file_repo.get_by_id(rag_file_id) + if not rag_file: + raise BusinessError(ErrorCodes.RAG_FILE_NOT_FOUND) + + # 查询 Milvus + token = getattr(settings, "milvus_token", None) + client = MilvusClient(uri=settings.milvus_uri, token=token or "") + + try: + # 查询总数 + count_filter = f'metadata["rag_file_id"] == "{rag_file_id}"' + count_res = client.query( + collection_name=knowledge_base.name, + filter=count_filter, + output_fields=["id"], + ) + total = len(count_res) + + # 查询分页数据 + offset = (paging_query.page - 1) * paging_query.size + results = client.query( + 
collection_name=knowledge_base.name, + filter=count_filter, + output_fields=["id", "text", "metadata"], + limit=paging_query.size, + offset=offset, + ) + + chunks = [ + RagChunkResp( + id=item.get("id", ""), + text=item.get("text", ""), + metadata=item.get("metadata", {}), + score=0.0, + ) + for item in results + ] + + logger.info("查询文件分块成功: kb=%s file=%s total=%d", knowledge_base_id, rag_file_id, total) + + return PagedResponse.create( + content=chunks, + total_elements=total, + page=paging_query.page, + size=paging_query.size, + ) + + except Exception as e: + logger.error("查询文件分块失败: kb=%s file=%s error=%s", knowledge_base_id, rag_file_id, e) + raise BusinessError(ErrorCodes.RAG_MILVUS_ERROR, f"查询文件分块失败: {str(e)}") from e From cdfea22868ef3fed0fd2322271c8f172e8999a0d Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Thu, 26 Feb 2026 16:42:08 +0800 Subject: [PATCH 06/13] feat: add progress tracking for RAG file processing and enhance worker pool status management --- .../app/db/models/knowledge_gen.py | 1 + .../app/module/rag/infra/task/worker_pool.py | 22 +- .../module/rag/infra/vectorstore/factory.py | 18 +- .../app/module/rag/infra/vectorstore/store.py | 2 +- .../module/rag/repository/file_repository.py | 21 ++ .../app/module/rag/schema/response.py | 2 + .../app/module/rag/service/file_processor.py | 280 ++++++++++++++---- .../rag/service/knowledge_base_service.py | 11 +- scripts/db/rag-management-init.sql | 2 + 9 files changed, 289 insertions(+), 70 deletions(-) diff --git a/runtime/datamate-python/app/db/models/knowledge_gen.py b/runtime/datamate-python/app/db/models/knowledge_gen.py index c0bf37d34..fd1a4df3d 100644 --- a/runtime/datamate-python/app/db/models/knowledge_gen.py +++ b/runtime/datamate-python/app/db/models/knowledge_gen.py @@ -76,6 +76,7 @@ class RagFile(BaseEntity): comment="处理状态", ) err_msg = Column(String(2048), nullable=True, comment="错误信息") + progress = Column(Integer, default=0, nullable=False, comment="处理进度(0-100)") def __repr__(self): return f"" diff --git a/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py b/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py index 733e8ac64..60f0481c1 100644 --- a/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py +++ b/runtime/datamate-python/app/module/rag/infra/task/worker_pool.py @@ -59,6 +59,8 @@ def __init__(self, max_workers: int = 10): """ self.semaphore = asyncio.Semaphore(max_workers) self.max_workers = max_workers + self._lock = asyncio.Lock() + self._active_count = 0 async def submit( self, @@ -77,12 +79,18 @@ async def submit( 协程的返回值 """ async with self.semaphore: + async with self._lock: + self._active_count += 1 + try: result = await coro(*args, **kwargs) return result except Exception as e: logger.error(f"任务执行失败: {e}") raise + finally: + async with self._lock: + self._active_count -= 1 async def submit_batch( self, @@ -116,12 +124,14 @@ async def submit_batch( return [r for r in results if not isinstance(r, Exception)] - def get_available_workers(self) -> int: - """获取可用的工作协程数 + async def get_status(self) -> dict: + """获取工作池状态 Returns: - 可用的工作协程数 + 状态字典,包含 max_workers, active_count, available """ - # 注意:Semaphore 的值在内部维护,无法直接获取 - # 这里返回最大值作为近似 - return self.max_workers + return { + "max_workers": self.max_workers, + "active_count": self._active_count, + "available": self.max_workers - self._active_count, + } diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py 
index cb026f38a..28fffd778 100644
--- a/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py
+++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py
@@ -33,7 +33,7 @@ def create(
         collection_name: str,
         embedding: Embeddings,
         *,
-        drop_old: bool = True,
+        drop_old: bool = False,
         consistency_level: str = "Strong",
     ) -> Any:
         """创建 Milvus 向量存储实例(支持混合检索)
@@ -41,7 +41,7 @@
         Args:
             collection_name: 集合名称(知识库名称)
             embedding: LangChain Embeddings 实例
-            drop_old: 是否删除已存在同名集合
+            drop_old: 是否删除已存在同名集合(默认 False,避免数据丢失)
             consistency_level: 一致性级别
 
         Returns:
@@ -55,18 +55,22 @@
             embedding_instance=embedding,
         )
 
-        # 删除旧集合(如果存在)
+        # 根据参数删除旧集合(仅在显式要求时)
         if drop_old:
             drop_collection(collection_name)
 
-        # 创建集合(5个字段:id、text、metadata、vector、sparse)
+        # 仅在集合不存在时创建(避免覆盖已有数据)
         create_collection(
            collection_name=collection_name,
            dimension=dimension,
            consistency_level=consistency_level,
        )
 
-        # 创建 Milvus 实例
+        # 创建 Milvus 实例(drop_old=False,确保不会自动删除)
+        # 注意:必须指定 primary_field="id",因为集合定义中主键字段名为 id
+        # 注意:必须指定 metadata_field="metadata",确保所有 metadata 数据存放到此字段
+        # 注意:enable_dynamic_field=False,禁止使用动态字段,所有数据只存放到定义的字段中
+        # 注意:metadata_schema 指定 metadata 字段类型为 JSON
         return Milvus(
             embedding_function=embedding,
             collection_name=collection_name,
@@ -77,4 +81,8 @@
             drop_old=False,
             consistency_level=consistency_level,
             auto_id=False,
+            primary_field="id",
+            metadata_field="metadata",
+            enable_dynamic_field=False,
+            metadata_schema={"metadata": "JSON"},
         )
diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py
index d9c555ccd..d61ec74fd 100644
--- a/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py
+++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py
@@ -125,7 +125,7 @@
         fields=fields,
         functions=[bm25_function],
         description="Knowledge base collection",
-        enable_dynamic_field=True,
+        enable_dynamic_field=False,
     )
 
     client.create_collection(
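With drop_old now defaulting to False, the factory becomes idempotent: the first call creates the fixed five-field collection, and later calls reuse it and append. A minimal sketch of the resulting call pattern, assuming the factory and embedding helper are importable as in this module (the model name, collection name and document texts are placeholders):

from langchain_core.documents import Document

from app.module.rag.infra.embeddings import EmbeddingFactory
from app.module.rag.infra.vectorstore.factory import VectorStoreFactory

embedding = EmbeddingFactory.create_embeddings(model_name="bge-m3")

# First call creates the collection (id/text/metadata/vector/sparse schema).
store = VectorStoreFactory.create(collection_name="kb_demo", embedding=embedding)
store.add_documents([Document(page_content="first batch")], ids=["id-1"])

# Later calls reuse the collection and append instead of wiping it;
# only an explicit drop_old=True recreates it from scratch.
store = VectorStoreFactory.create(collection_name="kb_demo", embedding=embedding)
store.add_documents([Document(page_content="second batch")], ids=["id-2"])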
diff --git a/runtime/datamate-python/app/module/rag/repository/file_repository.py b/runtime/datamate-python/app/module/rag/repository/file_repository.py
index 94092de75..cb1d13f63 100644
--- a/runtime/datamate-python/app/module/rag/repository/file_repository.py
+++ b/runtime/datamate-python/app/module/rag/repository/file_repository.py
@@ -283,6 +283,27 @@
         rag_file.chunk_count = chunk_count
         await self.db.flush()
 
+    async def update_progress(
+        self,
+        rag_file_id: str,
+        progress: int
+    ) -> None:
+        """更新文件处理进度
+
+        Args:
+            rag_file_id: RAG 文件 ID
+            progress: 进度值 (0-100)
+
+        Raises:
+            BusinessError: 文件不存在
+        """
+        rag_file = await self.get_by_id(rag_file_id)
+        if not rag_file:
+            raise BusinessError(ErrorCodes.RAG_FILE_NOT_FOUND)
+
+        rag_file.progress = max(0, min(100, progress))
+        await self.db.flush()
+
     async def count_by_knowledge_base(
         self,
         knowledge_base_id: str
diff --git a/runtime/datamate-python/app/module/rag/schema/response.py b/runtime/datamate-python/app/module/rag/schema/response.py
index 68f4545e6..2984d157f 100644
--- a/runtime/datamate-python/app/module/rag/schema/response.py
+++ b/runtime/datamate-python/app/module/rag/schema/response.py
@@ -83,6 +83,7 @@ class RagFileResp(BaseModel):
     metadata: Optional[dict] = Field(None, description="元数据")
     status: FileStatus = Field(..., description="处理状态")
     err_msg: Optional[str] = Field(None, alias="errMsg", description="错误信息")
+    progress: int = Field(default=0, ge=0, le=100, description="处理进度(0-100)")
     created_at: Optional[datetime] = Field(None, alias="createdAt", description="创建时间")
     updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="更新时间")
     created_by: Optional[str] = Field(None, alias="createdBy", description="创建人")
@@ -99,6 +100,7 @@ class Config:
                 "chunkCount": 15,
                 "metadata": {"size": 1024, "format": "pdf"},
                 "status": "PROCESSED",
+                "progress": 100,
                 "createdAt": "2025-01-01T00:00:00"
             }
         }
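The repository clamps the incoming value into the 0-100 range before flushing, and the response schema enforces the same bounds on the way out. A small illustration of both behaviours; ProgressView is a hypothetical stand-in for RagFileResp, assuming Pydantic v2:

from pydantic import BaseModel, Field

class ProgressView(BaseModel):
    # Same constraint as RagFileResp.progress above.
    progress: int = Field(default=0, ge=0, le=100)

# Repository-side clamping, mirroring max(0, min(100, progress)):
for raw in (-10, 42, 250):
    print(max(0, min(100, raw)))  # 0, 42, 100

# Schema-side validation rejects out-of-range values outright:
print(ProgressView(progress=42).model_dump())  # {'progress': 42}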
diff --git a/runtime/datamate-python/app/module/rag/service/file_processor.py b/runtime/datamate-python/app/module/rag/service/file_processor.py
index 0db7e6006..ff696043d 100644
--- a/runtime/datamate-python/app/module/rag/service/file_processor.py
+++ b/runtime/datamate-python/app/module/rag/service/file_processor.py
@@ -86,7 +86,7 @@ async def _process_files_background(
 
             logger.info("开始处理 %d 个文件,知识库: %s", len(files), knowledge_base_name)
 
-            # 并发处理文件
+            # 并发处理文件(最多 10 个并行)
             await self._process_files_concurrently(db, files, knowledge_base, request)
 
             logger.info("知识库 %s 文件处理完成", knowledge_base_name)
@@ -120,34 +120,28 @@ async def _process_single_file(
         db: AsyncSession,
         rag_file: RagFile,
         knowledge_base: KnowledgeBase,
         request: AddFilesReq,
     ) -> None:
-        """处理单个文件的 ETL 流程"""
+        """处理单个文件的 ETL 流程
+
+        1. 更新状态为处理中(progress=5%)
+        2. 验证文件、加载并分块
+        3. 向量化并存储(progress=60%)
+        4. 完成(progress=100%)
+        """
         file_repo = RagFileRepository(db)
 
         try:
-            # 1. 更新状态为处理中
-            await file_repo.update_status(rag_file.id, FileStatus.PROCESSING)
+            # 更新状态为处理中
+            await self._update_status(db, file_repo, rag_file.id, FileStatus.PROCESSING, 5)
             await db.commit()
 
-            # 2. 验证文件
-            file_path = self._get_file_path(rag_file)
+            # 验证文件并获取路径
+            file_path = await self._validate_file(db, file_repo, rag_file)
             if not file_path:
-                await self._mark_failed(db, file_repo, rag_file.id, "文件路径未设置")
                 return
 
-            if not Path(file_path).exists():
-                await self._mark_failed(db, file_repo, rag_file.id, f"文件不存在: {file_path}")
-                return
-
-            # 3. 加载并分块
+            # 加载并分块文档
             metadata = self._build_chunk_metadata(rag_file, knowledge_base)
-            chunks = await ingest_file_to_chunks(
-                file_path,
-                process_type=request.process_type,
-                chunk_size=request.chunk_size,
-                overlap_size=request.overlap_size,
-                delimiter=request.delimiter,
-                **metadata,
-            )
+            chunks = await self._load_and_split(file_path, rag_file, metadata, request)
 
             if not chunks:
                 await self._mark_failed(db, file_repo, rag_file.id, "文档解析后未生成任何分块")
@@ -155,22 +149,157 @@
             logger.info("文件 %s 分块完成,共 %d 个分块", rag_file.file_name, len(chunks))
 
-            # 4. 向量化并存储
-            await self._embed_and_store(db, chunks, metadata, knowledge_base)
-
-            # 5. 更新文件状态为成功
-            await file_repo.update_chunk_count(rag_file.id, len(chunks))
-            await file_repo.update_status(rag_file.id, FileStatus.PROCESSED)
-            await db.commit()
+
+            # 向量化并存储到 Milvus
+            await self._embed_and_store(db, chunks, rag_file, knowledge_base)
+            # 标记完成
+            await self._mark_success(db, file_repo, rag_file.id, len(chunks))
             logger.info("文件 %s ETL 处理完成", rag_file.file_name)
 
         except Exception as e:
             logger.exception("文件 %s 处理失败: %s", rag_file.file_name, e)
             await self._mark_failed(db, file_repo, rag_file.id, str(e))
 
+    async def _validate_file(
+        self,
+        db: AsyncSession,
+        file_repo: RagFileRepository,
+        rag_file: RagFile,
+    ) -> str | None:
+        """验证文件路径并返回绝对路径
+
+        Args:
+            db: 数据库 session
+            file_repo: 文件仓储
+            rag_file: RAG 文件实体
+
+        Returns:
+            文件绝对路径,验证失败返回 None
+        """
+        file_path = self._get_file_path(rag_file)
+
+        if not file_path:
+            await self._mark_failed(db, file_repo, rag_file.id, "文件路径未设置")
+            return None
+
+        if not Path(file_path).exists():
+            await self._mark_failed(db, file_repo, rag_file.id, f"文件不存在: {file_path}")
+            return None
+
+        return file_path
+
+    async def _load_and_split(
+        self,
+        file_path: str,
+        rag_file: RagFile,
+        metadata: dict,
+        request: AddFilesReq,
+    ) -> List[DocumentChunk]:
+        """加载文档并分块
+
+        Args:
+            file_path: 文件路径
+            rag_file: RAG 文件实体
+            metadata: 基础元数据
+            request: 添加文件请求
+
+        Returns:
+            文档分块列表
+        """
+        chunks = await ingest_file_to_chunks(
+            file_path,
+            process_type=request.process_type,
+            chunk_size=request.chunk_size,
+            overlap_size=request.overlap_size,
+            delimiter=request.delimiter,
+            **metadata,
+        )
+
+        if chunks:
+            logger.info("文件 %s 加载分块成功,数量: %d", rag_file.file_name, len(chunks))
+
+        return chunks
+
+    async def _embed_and_store(
+        self,
+        db: AsyncSession,
+        chunks: List[DocumentChunk],
+        rag_file: RagFile,
+        knowledge_base: KnowledgeBase,
+    ) -> None:
+        """向量化并存储到 Milvus
+
+        Args:
+            db: 数据库 session
+            chunks: 文档分块列表
+            rag_file: RAG 文件实体
+            knowledge_base: 知识库实体
+        """
+        file_repo = RagFileRepository(db)
+
+        # 获取或创建 Embeddings 实例
+        embedding = await self._get_embeddings(db, knowledge_base)
+
+        # 创建向量存储
+        vectorstore = VectorStoreFactory.create(
+            collection_name=knowledge_base.name,
+            embedding=embedding,
+        )
+
+        # 更新进度
+        await self._update_progress(db, file_repo, rag_file.id, 60)
+        await db.commit()
+
+        # 构建完整的 metadata
+        base_metadata = {
+            "rag_file_id": rag_file.id,
+            "original_file_id": rag_file.file_id,
+            "knowledge_base_id": knowledge_base.id,
+        }
+
+        for chunk in chunks:
+            chunk.metadata.update(base_metadata)
+
+        # 生成 ID 并存储
+        ids = [str(uuid.uuid4()) for _ in chunks]
+        documents, doc_ids = chunks_to_documents(chunks, ids=ids)
+        vectorstore.add_documents(documents=documents, ids=doc_ids)
+
+        logger.info("文件 %s 向量存储完成,数量: %d", rag_file.file_name, len(chunks))
+
+    async def _get_embeddings(
+        self,
+        db: AsyncSession,
+        knowledge_base: KnowledgeBase,
+    ):
+        """获取嵌入模型实例
+
+        Args:
+            db: 数据库 session
+            knowledge_base: 知识库实体
+
+        Returns:
+            LangChain Embeddings 实例
+        """
+        embedding_entity = await get_model_by_id(db, knowledge_base.embedding_model)
+        if not embedding_entity:
+            raise ValueError(f"嵌入模型不存在: {knowledge_base.embedding_model}")
+
+        return EmbeddingFactory.create_embeddings(
+            model_name=embedding_entity.model_name,
+            base_url=getattr(embedding_entity, "base_url", None),
+            api_key=getattr(embedding_entity, "api_key", None),
+        )
+
     def _get_file_path(self, rag_file: RagFile) -> str | None:
-        """获取文件绝对路径"""
+        """获取文件绝对路径
+
+        Args:
+            rag_file: RAG 文件实体
+
+        Returns:
+            文件绝对路径,不存在返回 None
+        """
         if not rag_file.file_metadata:
             return None
 
@@ -180,7 +309,15 @@
         file_path = rag_file.file_metadata.get("file_path")
         if file_path:
             return os.path.abspath(file_path)
         return None
 
     def _build_chunk_metadata(self, rag_file: RagFile, knowledge_base: KnowledgeBase) -> dict:
-        """构建分块元数据"""
+        """构建分块基础元数据
+
+        Args:
+            rag_file: RAG 文件实体
+            knowledge_base: 知识库实体
+
+        Returns:
+            元数据字典
+        """
         file_path = self._get_file_path(rag_file) or ""
         return {
             "rag_file_id": rag_file.id,
@@ -191,38 +328,64 @@ def _build_chunk_metadata(self, rag_file: RagFile, knowledge_base: KnowledgeBase
             "knowledge_base_id": knowledge_base.id,
         }
 
-    async def _embed_and_store(
+    async def _update_status(
         self,
         db: AsyncSession,
-        chunks: List[DocumentChunk],
-        metadata: dict,
-        knowledge_base: KnowledgeBase,
+        file_repo: RagFileRepository,
+        rag_file_id: str,
+        status: FileStatus,
+        progress: int = 0,
     ) -> None:
-        """向量化并存储到 Milvus"""
-        embedding_entity = await get_model_by_id(db, knowledge_base.embedding_model)
-        if not embedding_entity:
-            raise ValueError(f"嵌入模型不存在: {knowledge_base.embedding_model}")
+        """更新文件状态和进度
 
-        embedding = EmbeddingFactory.create_embeddings(
-            model_name=embedding_entity.model_name,
-            base_url=getattr(embedding_entity, "base_url", None),
-            api_key=getattr(embedding_entity, "api_key", None),
-        )
+        Args:
+            db: 数据库 session
+            file_repo: 文件仓储
+            rag_file_id: RAG 文件 ID
+            status: 新状态
+            progress: 新进度
+        """
+        await file_repo.update_status(rag_file_id, status)
+        await file_repo.update_progress(rag_file_id, progress)
+        await db.flush()
 
-        vectorstore = VectorStoreFactory.create(
-            collection_name=knowledge_base.name,
-            embedding=embedding,
-        )
+    async def _update_progress(
+        self,
+        db: AsyncSession,
+        file_repo: RagFileRepository,
+        rag_file_id: str,
+        progress: int,
+    ) -> None:
+        """更新文件处理进度
 
-        # 补充 metadata
-        for chunk in chunks:
-            for key, value in metadata.items():
-                if key not in chunk.metadata:
-                    chunk.metadata[key] = value
+        Args:
+            db: 数据库 session
+            file_repo: 文件仓储
+            rag_file_id: RAG 文件 ID
+            progress: 进度值 (0-100)
+        """
+        await file_repo.update_progress(rag_file_id, progress)
+        await db.flush()
 
-        ids = [str(uuid.uuid4()) for _ in chunks]
-        documents, doc_ids = chunks_to_documents(chunks, ids=ids)
-        vectorstore.add_documents(documents=documents, ids=doc_ids)
+    async def _mark_success(
+        self,
+        db: AsyncSession,
+        file_repo: RagFileRepository,
+        rag_file_id: str,
+        chunk_count: int,
+    ) -> None:
+        """标记文件处理成功
+
+        Args:
+            db: 数据库 session
+            file_repo: 文件仓储
+            rag_file_id: RAG 文件 ID
+            chunk_count: 分块数量
+        """
+        await file_repo.update_chunk_count(rag_file_id, chunk_count)
+        await file_repo.update_status(rag_file_id, FileStatus.PROCESSED)
+        await file_repo.update_progress(rag_file_id, 100)
+        await db.commit()
 
     async def _mark_failed(
         self,
@@ -231,7 +394,14 @@
         rag_file_id: str,
         err_msg: str,
     ) -> None:
-        """标记文件处理失败"""
+        """标记文件处理失败
+
+        Args:
+            db: 数据库 session
+            file_repo: 文件仓储
+            rag_file_id: RAG 文件 ID
+            err_msg: 错误信息
+        """
         logger.error("文件处理失败: %s", err_msg)
         await file_repo.update_status(rag_file_id, FileStatus.PROCESS_FAILED, err_msg=err_msg)
         await db.commit()
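The refactor above keeps each ETL stage small and reports a coarse progress value after each one, while the files themselves still run under a semaphore-bounded asyncio.gather. A self-contained sketch of that concurrency pattern (file names and the stage percentages are illustrative):

import asyncio

async def process_file(name: str, sem: asyncio.Semaphore) -> None:
    async with sem:                    # at most max_workers files in flight
        print(name, "progress=5")      # status -> PROCESSING
        await asyncio.sleep(0)         # load + split
        print(name, "progress=60")     # about to write vectors
        await asyncio.sleep(0)         # embed + add_documents
        print(name, "progress=100")    # status -> PROCESSED

async def main() -> None:
    sem = asyncio.Semaphore(10)
    files = [f"file-{i}.pdf" for i in range(25)]
    # return_exceptions=True keeps one failure from cancelling the rest,
    # matching the worker pool's behaviour above.
    await asyncio.gather(*(process_file(f, sem) for f in files), return_exceptions=True)

asyncio.run(main())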
diff --git a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py
index 8aa19881e..a59e95728 100644
--- a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py
+++ b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py
@@ -204,7 +204,8 @@ async def add_files(
     ) -> dict:
         """添加文件到知识库
 
-        存入数据库后立即返回,后台异步处理文件。
+        验证知识库、创建文件记录、启动后台处理。
+        数据库提交后立即返回,不等待文件处理完成。
 
         Args:
             request: 添加文件请求
             background_tasks: FastAPI 后台任务
@@ -213,16 +214,18 @@
         Returns:
             包含成功和跳过文件数量的字典
         """
+        # 1. 验证知识库存在
         knowledge_base = await self.kb_repo.get_by_id(request.knowledge_base_id)
         if not knowledge_base:
             raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND)
 
-        # 添加文件记录
+        # 2. 创建文件记录
         rag_files, skipped_file_ids = await self._create_rag_files(request)
 
+        # 3. 立即提交事务,接口返回
         await self.db.commit()
 
-        # 启动后台处理
+        # 4. 注册后台任务(异步处理)
         if rag_files and background_tasks:
             self.file_processor.start_background_processing(
                 background_tasks=background_tasks,
@@ -231,6 +234,7 @@
                 request_data=request.model_dump(),
             )
 
+        # 5. 返回结果
         return {
             "success_count": len(rag_files),
             "skipped_count": len(skipped_file_ids),
@@ -311,6 +315,7 @@ async def list_files(self, knowledge_base_id: str, request: RagFileReq) -> Paged
             metadata=item.file_metadata,
             status=item.status,
             err_msg=item.err_msg,
+            progress=getattr(item, "progress", 0),
             created_at=item.created_at,
             updated_at=item.updated_at,
             created_by=item.created_by,
diff --git a/scripts/db/rag-management-init.sql b/scripts/db/rag-management-init.sql
index 1877284ef..c9bc4dbfc 100644
--- a/scripts/db/rag-management-init.sql
+++ b/scripts/db/rag-management-init.sql
@@ -40,6 +40,7 @@ CREATE TABLE IF NOT EXISTS t_rag_file
     metadata   JSONB,
     status     VARCHAR(50),
     err_msg    TEXT,
+    progress   INTEGER DEFAULT 0,
     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
     updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
     created_by VARCHAR(255),
@@ -56,6 +57,7 @@
 COMMENT ON COLUMN t_rag_file.chunk_count IS '切片数';
 COMMENT ON COLUMN t_rag_file.metadata IS '元数据';
 COMMENT ON COLUMN t_rag_file.status IS '文件状态';
 COMMENT ON COLUMN t_rag_file.err_msg IS '错误信息';
+COMMENT ON COLUMN t_rag_file.progress IS '处理进度(0-100)';
 COMMENT ON COLUMN t_rag_file.created_at IS '创建时间';
 COMMENT ON COLUMN t_rag_file.updated_at IS '更新时间';
 COMMENT ON COLUMN t_rag_file.created_by IS '创建者';
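The add_files flow above commits first and only registers the background job afterwards, so the HTTP response never waits on ETL. A minimal FastAPI sketch of the same commit-then-schedule shape (the route path and function names are hypothetical, not this module's actual router):

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

async def run_etl(knowledge_base_id: str) -> None:
    """Opens its own DB session, processes files, updates progress."""
    ...

@app.post("/knowledge-bases/{kb_id}/files")
async def add_files(kb_id: str, background_tasks: BackgroundTasks) -> dict:
    # 1) persist the RagFile records and commit (elided here)
    # 2) schedule processing; FastAPI runs it after the response is sent
    background_tasks.add_task(run_etl, kb_id)
    # 3) respond immediately with counts only
    return {"success_count": 1, "skipped_count": 0}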
From b3fa57ba23160146bedabacc6f011c5f4e89c31d Mon Sep 17 00:00:00 2001
From: Dallas98 <990259227@qq.com>
Date: Sat, 28 Feb 2026 15:18:02 +0800
Subject: [PATCH 07/13] feat: enhance retrieval service with advanced search ranking and filtering capabilities

---
 .../module/rag/service/retrieval_service.py   | 39 +++++++++++++++----
 1 file changed, 32 insertions(+), 7 deletions(-)

diff --git a/runtime/datamate-python/app/module/rag/service/retrieval_service.py b/runtime/datamate-python/app/module/rag/service/retrieval_service.py
index 58d573231..d87be2a1c 100644
--- a/runtime/datamate-python/app/module/rag/service/retrieval_service.py
+++ b/runtime/datamate-python/app/module/rag/service/retrieval_service.py
@@ -6,7 +6,7 @@
 import logging
 from typing import List
 
-from pymilvus import MilvusClient
+from pymilvus import AnnSearchRequest, MilvusClient, RRFRanker, Function, FunctionType
 from sqlalchemy.ext.asyncio import AsyncSession
 
@@ -81,11 +81,11 @@ async def retrieve(self, request: RetrieveReq) -> List[dict]:
         all_results = await self._execute_hybrid_search(knowledge_bases, query_vector, request.query, request.top_k)
 
         # 按分数排序
-        all_results.sort(key=lambda x: x.get("distance", 0), reverse=True)
+        all_results.sort(key=lambda x: x.get("score") or x.get("distance", 0), reverse=True)
 
         # 应用阈值过滤
         if request.threshold is not None:
-            all_results = [r for r in all_results if r.get("distance", 0) >= request.threshold]
+            all_results = [r for r in all_results if (r.get("score") or r.get("distance", 0)) >= request.threshold]
 
         # 格式化返回结果
         return self._format_results(all_results)
@@ -108,14 +108,39 @@ async def _execute_hybrid_search(
                     logger.warning("集合 %s 不存在,跳过", kb.name)
                     continue
 
+                dense_search = AnnSearchRequest(
+                    data=[query_vector],
+                    anns_field="vector",
+                    param={"nprobe": 10},
+                    limit=top_k,
+                )
+
+                sparse_search = AnnSearchRequest(
+                    data=[query_text],
+                    anns_field="sparse",
+                    param={"metric_type": "BM25", "params": {}},
+                    limit=top_k,
+                )
+
+                ranker = Function(
+                    name="rrf",
+                    input_field_names=[],  # Must be an empty list
+                    function_type=FunctionType.RERANK,
+                    params={
+                        "reranker": "rrf",
+                        "k": 100
+                    }
+                )
                 search_results = client.hybrid_search(
                     collection_name=kb.name,
-                    data=[{"vector": query_vector, "sparse": query_text}],
-                    anns_field=["vector", "sparse"],
+                    reqs=[dense_search, sparse_search],
+                    ranker=ranker,
+                    output_fields=["id", "text", "metadata"],
                     limit=top_k,
-                    ranker={"type": "weighted", "weights": [0.1, 0.9]},
                 )
 
+                logger.debug("hybrid_search 原始结果: %s", search_results)
+
                 if search_results and len(search_results) > 0:
                     for result in search_results[0]:
                         result["knowledge_base_id"] = kb.id
@@ -137,7 +162,7 @@ def _format_results(self, all_results: List[dict]) -> List[dict]:
-            "score": r.get("distance", 0),
+            "score": r.get("score") or r.get("distance", 0),
             "knowledgeBaseId": r.get("knowledge_base_id", ""),
             "knowledgeBaseName": r.get("knowledge_base_name", ""),
         })
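The dense and sparse request pair above follows pymilvus's standard hybrid-search shape; the patch builds its reranker via Function, which newer pymilvus releases accept alongside the built-in rankers. A stripped-down sketch using the built-in RRFRanker instead (the URI, collection name and query vector are placeholders, and the collection is assumed to define a BM25 function on the sparse field):

from pymilvus import AnnSearchRequest, MilvusClient, RRFRanker

client = MilvusClient(uri="http://localhost:19530")

dense = AnnSearchRequest(
    data=[[0.12, 0.34, 0.56]],          # one query vector; dim must match the collection
    anns_field="vector",
    param={"nprobe": 10},
    limit=5,
)
sparse = AnnSearchRequest(
    data=["what is a knowledge base"],  # raw text; the BM25 function embeds it server-side
    anns_field="sparse",
    param={},
    limit=5,
)

hits = client.hybrid_search(
    collection_name="kb_demo",
    reqs=[dense, sparse],
    ranker=RRFRanker(60),               # reciprocal rank fusion with k=60
    output_fields=["text", "metadata"],
    limit=5,
)
for hit in hits[0]:
    print(hit["distance"], hit["entity"]["text"])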
From 592defd1a881c2ba96482003148bdac42f1f7de8 Mon Sep 17 00:00:00 2001
From: Dallas98 <990259227@qq.com>
Date: Mon, 2 Mar 2026 16:08:29 +0800
Subject: [PATCH 08/13] feat: enhance retrieval service with BM25 indexing and improved ranking parameters

---
 .../app/module/rag/infra/vectorstore/store.py | 24 ++++++++++++++++++-
 .../module/rag/service/retrieval_service.py   | 11 +++++----
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py
index d61ec74fd..abb9bb0c8 100644
--- a/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py
+++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py
@@ -107,7 +107,7 @@
     # 创建字段
     fields = [
         FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=36, is_primary=True, auto_id=False),
-        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535, enable_analyzer=True, enable_match=True),
+        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535, enable_analyzer=True),
         FieldSchema(name="metadata", dtype=DataType.JSON),
         FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
         FieldSchema(name="sparse", dtype=DataType.SPARSE_FLOAT_VECTOR),
@@ -128,9 +128,31 @@
         enable_dynamic_field=False,
     )
 
+    # BM25 索引参数
+    sparse_index_params = {
+        "inverted_index_algo": "DAAT_MAXSCORE",
+        "bm25_k1": 1.2,
+        "bm25_b": 0.75,
+    }
+
+    index_params = client.prepare_index_params()
+    index_params.add_index(
+        field_name="sparse",
+        index_type="SPARSE_INVERTED_INDEX",
+        metric_type="BM25",
+        params=sparse_index_params
+    )
+    index_params.add_index(
+        field_name="vector",
+        index_type="FLAT",
+        metric_type="COSINE",
+        params={}
+    )
+
     client.create_collection(
         collection_name=collection_name,
         schema=schema,
+        index_params=index_params,
         consistency_level=consistency_level,
     )
 
diff --git a/runtime/datamate-python/app/module/rag/service/retrieval_service.py b/runtime/datamate-python/app/module/rag/service/retrieval_service.py
index d87be2a1c..ac37edd89 100644
--- a/runtime/datamate-python/app/module/rag/service/retrieval_service.py
+++ 
b/runtime/datamate-python/app/module/rag/service/retrieval_service.py @@ -118,17 +118,18 @@ async def _execute_hybrid_search( sparse_search = AnnSearchRequest( data=[query_text], anns_field="sparse", - param={"metric_type": "BM25", "params": {}}, + param={"drop_ratio_search": 0.2}, limit=top_k, ) ranker = Function( - name="rrf", - input_field_names=[], # Must be an empty list + name="weight", + input_field_names=[], function_type=FunctionType.RERANK, params={ - "reranker": "rrf", - "k": 100 + "reranker": "weighted", + "weights": [0.1, 0.9], + "norm_score": True, } ) search_results = client.hybrid_search( From 381761ff0384b1ce1242ccb0bb2508e3d7704a92 Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Mon, 2 Mar 2026 16:16:47 +0800 Subject: [PATCH 09/13] feat: enhance retrieval service with BM25 indexing and improved ranking parameters --- .../app/module/rag/service/retrieval_service.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/runtime/datamate-python/app/module/rag/service/retrieval_service.py b/runtime/datamate-python/app/module/rag/service/retrieval_service.py index ac37edd89..5d49aae33 100644 --- a/runtime/datamate-python/app/module/rag/service/retrieval_service.py +++ b/runtime/datamate-python/app/module/rag/service/retrieval_service.py @@ -3,6 +3,7 @@ 负责知识库内容的检索,支持向量 + BM25 混合检索。 """ +import json import logging from typing import List @@ -159,11 +160,20 @@ def _format_results(self, all_results: List[dict]) -> List[dict]: formatted = [] for r in all_results: entity = r.get("entity", {}) + metadata = entity.get("metadata", {}) + if isinstance(metadata, dict): + metadata_str = json.dumps(metadata, ensure_ascii=False) + else: + metadata_str = metadata if metadata else "{}" + formatted.append({ - "id": entity.get("id", ""), - "text": entity.get("text", ""), - "metadata": entity.get("metadata", {}), + "entity": { + "metadata": metadata_str, + "text": entity.get("text", ""), + "id": entity.get("id", ""), + }, "score": r.get("score") or r.get("distance", 0), + "id": entity.get("id", ""), "knowledgeBaseId": r.get("knowledge_base_id", ""), "knowledgeBaseName": r.get("knowledge_base_name", ""), }) From 9d6b359737fbbedebf9a8a8570abced5a4af48ee Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Mon, 2 Mar 2026 16:34:56 +0800 Subject: [PATCH 10/13] feat: implement Milvus client singleton management and refactor vector store interactions --- .../module/rag/infra/vectorstore/__init__.py | 2 + .../module/rag/infra/vectorstore/factory.py | 66 +++++++++++---- .../rag/infra/vectorstore/milvus_client.py | 83 +++++++++++++++++++ .../app/module/rag/infra/vectorstore/store.py | 33 +++----- .../module/rag/service/retrieval_service.py | 11 +-- 5 files changed, 148 insertions(+), 47 deletions(-) create mode 100644 runtime/datamate-python/app/module/rag/infra/vectorstore/milvus_client.py diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/__init__.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/__init__.py index fdc2e465c..289fdb7aa 100644 --- a/runtime/datamate-python/app/module/rag/infra/vectorstore/__init__.py +++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/__init__.py @@ -4,6 +4,7 @@ 提供 Milvus 向量存储的创建、管理和数据操作功能。 """ from app.module.rag.infra.vectorstore.factory import VectorStoreFactory +from app.module.rag.infra.vectorstore.milvus_client import get_milvus_client from app.module.rag.infra.vectorstore.store import ( chunks_to_documents, create_collection, @@ -15,6 +16,7 @@ __all__ = [ 
"VectorStoreFactory", + "get_milvus_client", "create_collection", "drop_collection", "rename_collection", diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py index 28fffd778..805ad4c2d 100644 --- a/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py +++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/factory.py @@ -5,7 +5,9 @@ """ from __future__ import annotations -from typing import Any +import logging +import threading +from typing import Any, Dict, Optional from langchain_core.embeddings import Embeddings @@ -16,9 +18,17 @@ get_vector_dimension, ) +logger = logging.getLogger(__name__) + class VectorStoreFactory: - """LangChain Milvus 向量存储工厂""" + """LangChain Milvus 向量存储工厂 + + 使用单例模式缓存 Milvus 实例,确保每个 collection 只有一个实例。 + """ + + _instances: Dict[str, Any] = {} + _lock = threading.Lock() @staticmethod def get_connection_args() -> dict: @@ -28,53 +38,54 @@ def get_connection_args() -> dict: args["token"] = settings.milvus_token return args - @staticmethod + @classmethod def create( + cls, collection_name: str, embedding: Embeddings, *, drop_old: bool = False, consistency_level: str = "Strong", + force_new: bool = False, ) -> Any: - """创建 Milvus 向量存储实例(支持混合检索) + """创建或获取 Milvus 向量存储实例(支持混合检索) Args: collection_name: 集合名称(知识库名称) embedding: LangChain Embeddings 实例 drop_old: 是否删除已存在同名集合(默认 False,避免数据丢失) consistency_level: 一致性级别 + force_new: 是否强制创建新实例(默认 False,优先使用缓存) Returns: langchain_milvus.Milvus 实例 """ + if drop_old: + drop_collection(collection_name) + with cls._lock: + cls._instances.pop(collection_name, None) + + if not force_new and collection_name in cls._instances: + logger.debug("使用缓存的 Milvus 实例: %s", collection_name) + return cls._instances[collection_name] + from langchain_milvus import BM25BuiltInFunction, Milvus - # 获取向量维度 dimension = get_vector_dimension( embedding_model="", embedding_instance=embedding, ) - # 根据参数删除旧集合(仅在显式要求时) - if drop_old: - drop_collection(collection_name) - - # 仅在集合不存在时创建(避免覆盖已有数据) create_collection( collection_name=collection_name, dimension=dimension, consistency_level=consistency_level, ) - # 创建 Milvus 实例(drop_old=False,确保不会自动删除) - # 注意:必须指定 primary_field="id",因为集合定义中主键字段名为 id - # 注意:必须指定 metadata_field="metadata",确保所有 metadata 数据存放到此字段 - # 注意:enable_dynamic_field=False,禁止使用动态字段没,所有数据只存放到定义的字段中 - # 注意:metadata_schema 指定 metadata 字段类型为 JSON - return Milvus( + instance = Milvus( embedding_function=embedding, collection_name=collection_name, - connection_args=VectorStoreFactory.get_connection_args(), + connection_args=cls.get_connection_args(), builtin_function=BM25BuiltInFunction(), text_field="text", vector_field=["vector"], @@ -86,3 +97,24 @@ def create( enable_dynamic_field=False, metadata_schema={"metadata": "JSON"}, ) + + with cls._lock: + cls._instances[collection_name] = instance + logger.info("创建并缓存 Milvus 实例: %s", collection_name) + + return instance + + @classmethod + def clear_cache(cls, collection_name: Optional[str] = None) -> None: + """清除缓存 + + Args: + collection_name: 集合名称,如果为 None 则清除所有缓存 + """ + with cls._lock: + if collection_name: + cls._instances.pop(collection_name, None) + logger.info("清除 Milvus 实例缓存: %s", collection_name) + else: + cls._instances.clear() + logger.info("清除所有 Milvus 实例缓存") diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/milvus_client.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/milvus_client.py new file mode 100644 index 000000000..95dcd0935 --- 
/dev/null +++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/milvus_client.py @@ -0,0 +1,83 @@ +""" +Milvus 客户端单例管理器 + +确保 MilvusClient 在全局范围内只创建一个实例,避免重复连接。 +""" +from __future__ import annotations + +import logging +import threading +from typing import Optional + +from pymilvus import MilvusClient + +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +class MilvusClientManager: + """Milvus 客户端单例管理器 + + 使用线程安全的单例模式,确保全局只有一个 MilvusClient 实例。 + """ + + _instance: Optional[MilvusClientManager] = None + _lock = threading.Lock() + _client: Optional[MilvusClient] = None + + def __new__(cls) -> MilvusClientManager: + """单例模式:确保只有一个管理器实例""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @staticmethod + def _get_connection_args() -> dict: + """获取 Milvus 连接参数""" + args: dict = {"uri": settings.milvus_uri} + token = getattr(settings, "milvus_token", None) + if token: + args["token"] = token + return args + + def get_client(self) -> MilvusClient: + """获取 Milvus 客户端实例(单例) + + Returns: + MilvusClient 实例 + """ + if self._client is None: + with self._lock: + if self._client is None: + conn_args = self._get_connection_args() + self._client = MilvusClient( + uri=conn_args["uri"], + token=conn_args.get("token", "") + ) + logger.info("创建 Milvus 客户端单例: uri=%s", conn_args["uri"]) + return self._client + + def close(self) -> None: + """关闭 Milvus 客户端连接""" + if self._client is not None: + with self._lock: + if self._client is not None: + try: + self._client.close() + logger.info("关闭 Milvus 客户端连接") + except Exception as e: + logger.warning("关闭 Milvus 客户端时出错: %s", e) + finally: + self._client = None + + +def get_milvus_client() -> MilvusClient: + """获取 Milvus 客户端实例(全局单例) + + Returns: + MilvusClient 实例 + """ + return MilvusClientManager().get_client() diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py index abb9bb0c8..dd70933e2 100644 --- a/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py +++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py @@ -18,31 +18,16 @@ from typing import List, Optional from langchain_core.documents import Document -from pymilvus import MilvusClient, DataType, FunctionType, CollectionSchema, FieldSchema, Function +from pymilvus import DataType, FunctionType, CollectionSchema, FieldSchema, Function -from app.core.config import settings from app.core.exception import BusinessError, ErrorCodes from app.module.rag.infra.document.types import DocumentChunk from app.module.rag.infra.embeddings import EmbeddingFactory +from app.module.rag.infra.vectorstore.milvus_client import get_milvus_client logger = logging.getLogger(__name__) -def _get_connection_args() -> dict: - """获取 Milvus 连接参数""" - args: dict = {"uri": settings.milvus_uri} - token = getattr(settings, "milvus_token", None) - if token: - args["token"] = token - return args - - -def _get_client() -> MilvusClient: - """获取 Milvus 客户端""" - conn_args = _get_connection_args() - return MilvusClient(uri=conn_args["uri"], token=conn_args.get("token", "")) - - def drop_collection(collection_name: str) -> None: """删除 Milvus 集合 @@ -50,7 +35,7 @@ def drop_collection(collection_name: str) -> None: collection_name: 集合名称 """ try: - client = _get_client() + client = get_milvus_client() if client.has_collection(collection_name): client.drop_collection(collection_name) logger.info("成功删除集合: %s", 
diff --git a/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py
index abb9bb0c8..dd70933e2 100644
--- a/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py
+++ b/runtime/datamate-python/app/module/rag/infra/vectorstore/store.py
@@ -18,31 +18,16 @@
 from typing import List, Optional
 
 from langchain_core.documents import Document
-from pymilvus import MilvusClient, DataType, FunctionType, CollectionSchema, FieldSchema, Function
+from pymilvus import DataType, FunctionType, CollectionSchema, FieldSchema, Function
 
-from app.core.config import settings
 from app.core.exception import BusinessError, ErrorCodes
 from app.module.rag.infra.document.types import DocumentChunk
 from app.module.rag.infra.embeddings import EmbeddingFactory
+from app.module.rag.infra.vectorstore.milvus_client import get_milvus_client
 
 logger = logging.getLogger(__name__)
 
 
-def _get_connection_args() -> dict:
-    """Build the Milvus connection arguments"""
-    args: dict = {"uri": settings.milvus_uri}
-    token = getattr(settings, "milvus_token", None)
-    if token:
-        args["token"] = token
-    return args
-
-
-def _get_client() -> MilvusClient:
-    """Get a Milvus client"""
-    conn_args = _get_connection_args()
-    return MilvusClient(uri=conn_args["uri"], token=conn_args.get("token", ""))
-
-
 def drop_collection(collection_name: str) -> None:
     """Drop a Milvus collection
 
     Args:
         collection_name: collection name
     """
     try:
-        client = _get_client()
+        client = get_milvus_client()
         if client.has_collection(collection_name):
             client.drop_collection(collection_name)
             logger.info("Successfully dropped collection: %s", collection_name)
@@ -67,13 +52,15 @@
         new_name: new collection name
     """
     from pymilvus import utility, connections
+    from app.core.config import settings
 
     try:
-        conn_args = _get_connection_args()
+        uri = settings.milvus_uri
+        token = getattr(settings, "milvus_token", None) or ""
         connections.connect(
             alias="default",
-            uri=conn_args["uri"],
-            token=conn_args.get("token", ""),
+            uri=uri,
+            token=token,
         )
         if utility.has_collection(old_name, using="default"):
             utility.rename_collection(old_name, new_name, using="default")
@@ -98,7 +85,7 @@
         consistency_level: consistency level
     """
     try:
-        client = _get_client()
+        client = get_milvus_client()
 
         if client.has_collection(collection_name):
             logger.info("Collection %s already exists; skipping creation", collection_name)
@@ -211,7 +198,7 @@
         return
 
     try:
-        client = _get_client()
+        client = get_milvus_client()
 
         for rid in rag_file_ids:
             json_value = json.dumps({"rag_file_id": rid})
diff --git a/runtime/datamate-python/app/module/rag/service/retrieval_service.py b/runtime/datamate-python/app/module/rag/service/retrieval_service.py
index 5d49aae33..3c9c04481 100644
--- a/runtime/datamate-python/app/module/rag/service/retrieval_service.py
+++ b/runtime/datamate-python/app/module/rag/service/retrieval_service.py
@@ -7,12 +7,12 @@
 import logging
 from typing import List
 
-from pymilvus import AnnSearchRequest, MilvusClient, RRFRanker, Function, FunctionType
+from pymilvus import AnnSearchRequest, Function, FunctionType
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.core.config import settings
 from app.core.exception import BusinessError, ErrorCodes
 from app.module.rag.infra.embeddings import EmbeddingFactory
+from app.module.rag.infra.vectorstore.milvus_client import get_milvus_client
 from app.module.rag.repository import KnowledgeBaseRepository, RagFileRepository
 from app.module.rag.schema.request import RetrieveReq, PagingQuery
 from app.module.rag.schema.response import PagedResponse, RagChunkResp
@@ -100,8 +100,7 @@
     ) -> List[dict]:
         """Execute hybrid retrieval"""
         all_results = []
-        token = getattr(settings, "milvus_token", None)
-        client = MilvusClient(uri=settings.milvus_uri, token=token or "")
+        client = get_milvus_client()
 
         for kb in knowledge_bases:
             try:
@@ -206,9 +205,7 @@
         if not rag_file:
             raise BusinessError(ErrorCodes.RAG_FILE_NOT_FOUND)
 
-        # Query Milvus
-        token = getattr(settings, "milvus_token", None)
-        client = MilvusClient(uri=settings.milvus_uri, token=token or "")
+        client = get_milvus_client()
 
         try:
             # Query the total count
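For context, a hedged sketch of what a metadata-scoped delete can look like with the shared client. The exact filter expression built from json_value above is not shown in this hunk, so the expression below is just one valid pymilvus form for a JSON field named "metadata":

# Sketch only: removing the chunks of a single source file by metadata filter.
from app.module.rag.infra.vectorstore.milvus_client import get_milvus_client

client = get_milvus_client()
client.delete(
    collection_name="demo_kb",  # placeholder collection
    filter='metadata["rag_file_id"] == "rag-file-uuid-123"',
)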
From 5c37b782026fd9e67ceffa16f6f66a77d1e354c4 Mon Sep 17 00:00:00 2001
From: Dallas98 <990259227@qq.com>
Date: Mon, 2 Mar 2026 16:47:16 +0800
Subject: [PATCH 11/13] feat: refactor file processing to use async session for database interactions

---
 .../app/module/rag/service/file_processor.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/runtime/datamate-python/app/module/rag/service/file_processor.py b/runtime/datamate-python/app/module/rag/service/file_processor.py
index ff696043d..967180d74 100644
--- a/runtime/datamate-python/app/module/rag/service/file_processor.py
+++ b/runtime/datamate-python/app/module/rag/service/file_processor.py
@@ -108,7 +108,11 @@ async def _process_files_concurrently(
 
         async def process_with_semaphore(rag_file: RagFile):
             async with self.worker_pool.semaphore:
-                await self._process_single_file(db, rag_file, knowledge_base, request)
+                async with AsyncSessionLocal() as file_db:
+                    try:
+                        await self._process_single_file(file_db, rag_file, knowledge_base, request)
+                    finally:
+                        await file_db.close()
 
         tasks = [process_with_semaphore(f) for f in files]
         await asyncio.gather(*tasks, return_exceptions=True)
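The session-per-task pattern introduced here, shown standalone. AsyncSessionLocal is assumed to be the project's async_sessionmaker; the names below are illustrative:

# Sketch only: one AsyncSession per concurrent task, bounded by a semaphore.
import asyncio

semaphore = asyncio.Semaphore(4)

async def process_one(item, session_factory):
    async with semaphore:
        # A fresh session per task avoids sharing one AsyncSession across
        # concurrent coroutines, which SQLAlchemy forbids.
        async with session_factory() as session:
            ...  # load, update, and commit with this task-local session

async def process_all(items, session_factory):
    await asyncio.gather(
        *(process_one(i, session_factory) for i in items),
        return_exceptions=True,
    )

Note that "async with session_factory() as session" already closes the session on exit, so the extra close() in the finally block of the patch is redundant but harmless.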
From d61027f0f4f5d8da9be1639ceb30fd9407fdb10f Mon Sep 17 00:00:00 2001
From: Dallas98 <990259227@qq.com>
Date: Mon, 2 Mar 2026 17:25:14 +0800
Subject: [PATCH 12/13] feat: enhance chunk processing with filtering, cleaning, and batch storage

---
 .claude/skills/fastapi-templates/SKILL.md     |   1 +
 .../app/module/rag/service/file_processor.py | 285 +++++++++++++++++-
 .../module/shared/common/document_loaders.py |   9 +-
 3 files changed, 285 insertions(+), 10 deletions(-)

diff --git a/.claude/skills/fastapi-templates/SKILL.md b/.claude/skills/fastapi-templates/SKILL.md
index 05c492e37..c1a3d5bc1 100644
--- a/.claude/skills/fastapi-templates/SKILL.md
+++ b/.claude/skills/fastapi-templates/SKILL.md
@@ -565,3 +565,4 @@ async def test_create_user(client):
 - **Ignoring Sessions**: Not properly managing database sessions
 - **No Testing**: Skipping integration tests
 - **Tight Coupling**: Direct database access in routes
+- **Overly Large Functions**: Avoid functions that do too much; break them into smaller, focused functions.
diff --git a/runtime/datamate-python/app/module/rag/service/file_processor.py b/runtime/datamate-python/app/module/rag/service/file_processor.py
index 967180d74..d542862be 100644
--- a/runtime/datamate-python/app/module/rag/service/file_processor.py
+++ b/runtime/datamate-python/app/module/rag/service/file_processor.py
@@ -241,20 +241,130 @@
         """
         file_repo = RagFileRepository(db)
 
-        # Get or create the Embeddings instance
-        embedding = await self._get_embeddings(db, knowledge_base)
+        # Filter and clean the chunks
+        valid_chunks = await self._filter_and_clean_chunks(chunks, rag_file)
+        if not valid_chunks:
+            return
 
         # Create the vector store
+        vectorstore = await self._create_vectorstore(db, knowledge_base, file_repo, rag_file)
+
+        # Attach the metadata
+        self._add_metadata_to_chunks(valid_chunks, rag_file, knowledge_base)
+
+        # Store in batches
+        await self._store_chunks_in_batches(vectorstore, valid_chunks, rag_file)
+
+    async def _filter_and_clean_chunks(
+        self,
+        chunks: List[DocumentChunk],
+        rag_file: RagFile,
+    ) -> List[DocumentChunk]:
+        """Filter out and clean invalid chunks
+
+        Args:
+            chunks: the raw chunk list
+            rag_file: RAG file entity
+
+        Returns:
+            the list of valid chunks
+        """
+        valid_chunks = []
+        for idx, chunk in enumerate(chunks):
+            cleaned_text = self._clean_text(chunk.text)
+            self._log_chunk_cleaning(idx, chunk.text, cleaned_text)
+
+            if cleaned_text and len(cleaned_text.strip()) > 0:
+                chunk.text = cleaned_text
+                valid_chunks.append(chunk)
+            else:
+                logger.warning(
+                    "Skipping invalid chunk: rag_file_id=%s, chunk_index=%s, original length=%d",
+                    rag_file.id,
+                    chunk.metadata.get("chunk_index"),
+                    len(chunk.text) if chunk.text else 0
+                )
+
+        if not valid_chunks:
+            logger.warning("File %s has no valid chunk content", rag_file.file_name)
+            return []
+
+        logger.info(
+            "File %s valid chunk count: %d / %d",
+            rag_file.file_name,
+            len(valid_chunks),
+            len(chunks)
+        )
+        return valid_chunks
+
+    def _log_chunk_cleaning(self, idx: int, original_text: str, cleaned_text: str) -> None:
+        """Log chunk-cleaning details
+
+        Args:
+            idx: chunk index
+            original_text: text before cleaning
+            cleaned_text: text after cleaning
+        """
+        if idx >= 3:
+            return
+
+        logger.debug(
+            "Chunk %d before cleaning (length=%d): %.100s%s",
+            idx,
+            len(original_text) if original_text else 0,
+            repr(original_text[:100]) if original_text else "None",
+            "..." if original_text and len(original_text) > 100 else ""
+        )
+        logger.debug(
+            "Chunk %d after cleaning (length=%d): %.100s%s",
+            idx,
+            len(cleaned_text) if cleaned_text else 0,
+            repr(cleaned_text[:100]) if cleaned_text else "None",
+            "..." if cleaned_text and len(cleaned_text) > 100 else ""
+        )
+
+    async def _create_vectorstore(
+        self,
+        db: AsyncSession,
+        knowledge_base: KnowledgeBase,
+        file_repo: RagFileRepository,
+        rag_file: RagFile,
+    ):
+        """Create the vector store
+
+        Args:
+            db: database session
+            knowledge_base: knowledge base entity
+            file_repo: file repository
+            rag_file: RAG file entity
+
+        Returns:
+            the vector store instance
+        """
+        embedding = await self._get_embeddings(db, knowledge_base)
         vectorstore = VectorStoreFactory.create(
             collection_name=knowledge_base.name,
             embedding=embedding,
         )
 
-        # Update the progress
         await self._update_progress(db, file_repo, rag_file.id, 60)
         await db.commit()
 
-        # Build the full metadata
+        return vectorstore
+
+    def _add_metadata_to_chunks(
+        self,
+        chunks: List[DocumentChunk],
+        rag_file: RagFile,
+        knowledge_base: KnowledgeBase,
+    ) -> None:
+        """Attach metadata to each chunk
+
+        Args:
+            chunks: chunk list
+            rag_file: RAG file entity
+            knowledge_base: knowledge base entity
+        """
         base_metadata = {
             "rag_file_id": rag_file.id,
             "original_file_id": rag_file.file_id,
@@ -264,13 +374,93 @@
         for chunk in chunks:
             chunk.metadata.update(base_metadata)
 
-        # Generate IDs and store
-        ids = [str(uuid.uuid4()) for _ in chunks]
-        documents, doc_ids = chunks_to_documents(chunks, ids=ids)
-        vectorstore.add_documents(documents=documents, ids=doc_ids)
+    async def _store_chunks_in_batches(
+        self,
+        vectorstore,
+        chunks: List[DocumentChunk],
+        rag_file: RagFile,
+    ) -> None:
+        """Store chunks to the vector database in batches
+
+        Args:
+            vectorstore: vector store instance
+            chunks: chunk list
+            rag_file: RAG file entity
+        """
+        batch_size = 20
+        total_chunks = len(chunks)
+
+        for batch_start in range(0, total_chunks, batch_size):
+            batch_end = min(batch_start + batch_size, total_chunks)
+            batch_chunks = chunks[batch_start:batch_end]
+
+            await self._store_single_batch(
+                vectorstore,
+                batch_chunks,
+                batch_start,
+                batch_end,
+                total_chunks
+            )
 
         logger.info("File %s vector storage finished, count: %d", rag_file.file_name, len(chunks))
 
+    async def _store_single_batch(
+        self,
+        vectorstore,
+        batch_chunks: List[DocumentChunk],
+        batch_start: int,
+        batch_end: int,
+        total_chunks: int,
+    ) -> None:
+        """Store a single batch of chunks
+
+        Args:
+            vectorstore: vector store instance
+            batch_chunks: chunks in this batch
+            batch_start: batch start index
+            batch_end: batch end index
+            total_chunks: total chunk count
+        """
+        logger.info("Processing chunk batch %d-%d / %d", batch_start + 1, batch_end, total_chunks)
+
+        self._log_batch_details(batch_chunks, batch_start)
+
+        ids = [str(uuid.uuid4()) for _ in batch_chunks]
+        documents, doc_ids = chunks_to_documents(batch_chunks, ids=ids)
+
+        try:
+            vectorstore.add_documents(documents=documents, ids=doc_ids)
+            logger.info("Batch %d-%d stored successfully", batch_start + 1, batch_end)
+        except Exception as e:
+            logger.error(
+                "Batch %d-%d storage failed: %s\nFirst document content: %.200s",
+                batch_start + 1,
+                batch_end,
+                str(e),
+                documents[0].page_content if documents else "N/A"
+            )
+            raise
+
+    def _log_batch_details(
+        self,
+        batch_chunks: List[DocumentChunk],
+        batch_start: int,
+    ) -> None:
+        """Log batch details
+
+        Args:
+            batch_chunks: chunks in this batch
+            batch_start: batch start index
+        """
+        for i, chunk in enumerate(batch_chunks[:2]):
+            logger.debug(
+                "Chunk %d in batch: length=%d, text=%.50s%s",
+                batch_start + i,
+                len(chunk.text),
+                chunk.text[:50],
+                "..." if len(chunk.text) > 50 else ""
+            )
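The batching arithmetic used by _store_chunks_in_batches, extracted for illustration:

# Sketch only: fixed-size batch slicing, as performed by the storage loop.
def iter_batches(items, batch_size=20):
    for start in range(0, len(items), batch_size):
        end = min(start + batch_size, len(items))
        yield start, end, items[start:end]

chunks = list(range(47))
for start, end, batch in iter_batches(chunks):
    print(f"batch {start + 1}-{end} / {len(chunks)}: {len(batch)} items")
# batch 1-20 / 47: 20 items
# batch 21-40 / 47: 20 items
# batch 41-47 / 47: 7 items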
if len(chunk.text) > 50 else "" + ) + async def _get_embeddings( self, db: AsyncSession, @@ -391,6 +581,85 @@ async def _mark_success( await file_repo.update_progress(rag_file_id, 100) await db.commit() + def _clean_text(self, text: str) -> str: + """清理文本内容 + + 移除无效字符、控制字符,并规范化空白字符。 + + Args: + text: 原始文本 + + Returns: + 清理后的文本,如果无效则返回空字符串 + """ + import re + + if not text or not isinstance(text, str): + return "" + + text = self._remove_control_characters(text) + text = self._normalize_whitespace(text) + text = self._remove_empty_lines(text) + + if not self._has_printable_content(text): + return "" + + return text.strip() + + def _remove_control_characters(self, text: str) -> str: + """移除控制字符和零宽字符 + + Args: + text: 原始文本 + + Returns: + 清理后的文本 + """ + import re + text = re.sub(r'[\x00-\x09\x0b\x0c\x0e-\x1f\x7f-\x9f]', '', text) + text = re.sub(r'[\u200b-\u200f\u2028-\u202f\ufeff]', '', text) + return text + + def _normalize_whitespace(self, text: str) -> str: + """规范化空白字符 + + Args: + text: 原始文本 + + Returns: + 规范化后的文本 + """ + import re + text = re.sub(r'[ \t]+', ' ', text) + text = '\n'.join(line.strip() for line in text.split('\n')) + text = re.sub(r'\n{3,}', '\n\n', text) + return text + + def _remove_empty_lines(self, text: str) -> str: + """移除空行 + + Args: + text: 原始文本 + + Returns: + 移除空行后的文本 + """ + return '\n'.join(line for line in text.split('\n') if line.strip()) + + def _has_printable_content(self, text: str) -> bool: + """检查文本是否包含可打印内容 + + Args: + text: 文本内容 + + Returns: + 是否包含可打印字符 + """ + import re + if not text or not text.strip(): + return False + return bool(re.search(r'[\w\u4e00-\u9fff]', text)) + async def _mark_failed( self, db: AsyncSession, diff --git a/runtime/datamate-python/app/module/shared/common/document_loaders.py b/runtime/datamate-python/app/module/shared/common/document_loaders.py index cb0a843f9..ce76382f5 100644 --- a/runtime/datamate-python/app/module/shared/common/document_loaders.py +++ b/runtime/datamate-python/app/module/shared/common/document_loaders.py @@ -92,8 +92,13 @@ def _set_default_kwargs(loader_cls, kwargs: dict) -> dict: kwargs.setdefault("text_content", False) if loader_cls == CSVLoader and "csv_args" not in kwargs: kwargs["csv_args"] = {"delimiter": ","} - if loader_cls == UnstructuredExcelLoader and "mode" not in kwargs: - kwargs.setdefault("mode", "elements") + if loader_cls == UnstructuredExcelLoader: + # Excel 文件使用 "single" 模式,避免生成过多无意义的元素 + # "elements" 模式会为每个单元格生成单独的文档,可能导致大量空内容 + if "mode" not in kwargs: + kwargs.setdefault("mode", "single") + # 确保加载表格结构 + kwargs.setdefault("include_header", True) if loader_cls == UnstructuredFileLoader and "mode" not in kwargs: kwargs.setdefault("mode", "elements") return kwargs From cd98c33eb0faa75ba26d6fb42ddd8d5e2802def4 Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Mon, 2 Mar 2026 17:53:09 +0800 Subject: [PATCH 13/13] feat: enhance API response models with additional fields and configuration --- .../app/module/rag/schema/response.py | 101 ++++++++------- .../rag/service/knowledge_base_service.py | 118 ++++++++---------- 2 files changed, 103 insertions(+), 116 deletions(-) diff --git a/runtime/datamate-python/app/module/rag/schema/response.py b/runtime/datamate-python/app/module/rag/schema/response.py index 2984d157f..3eeacdc58 100644 --- a/runtime/datamate-python/app/module/rag/schema/response.py +++ b/runtime/datamate-python/app/module/rag/schema/response.py @@ -4,7 +4,7 @@ 定义所有 API 响应的数据结构 与 Java 响应 DTO 保持字段一致 """ -from pydantic import BaseModel, Field +from pydantic import 
diff --git a/runtime/datamate-python/app/module/shared/common/document_loaders.py b/runtime/datamate-python/app/module/shared/common/document_loaders.py
index cb0a843f9..ce76382f5 100644
--- a/runtime/datamate-python/app/module/shared/common/document_loaders.py
+++ b/runtime/datamate-python/app/module/shared/common/document_loaders.py
@@ -92,8 +92,13 @@ def _set_default_kwargs(loader_cls, kwargs: dict) -> dict:
             kwargs.setdefault("text_content", False)
         if loader_cls == CSVLoader and "csv_args" not in kwargs:
             kwargs["csv_args"] = {"delimiter": ","}
-        if loader_cls == UnstructuredExcelLoader and "mode" not in kwargs:
-            kwargs.setdefault("mode", "elements")
+        if loader_cls == UnstructuredExcelLoader:
+            # Use "single" mode for Excel files to avoid producing a flood of
+            # meaningless elements: "elements" mode emits a separate document
+            # per cell, which can yield large amounts of empty content
+            kwargs.setdefault("mode", "single")
+            # Make sure the table header is loaded as well
+            kwargs.setdefault("include_header", True)
         if loader_cls == UnstructuredFileLoader and "mode" not in kwargs:
             kwargs.setdefault("mode", "elements")
         return kwargs
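A usage sketch for the loader defaults above. The constructor and load() call shape are assumed here (UniversalDocLoader's full interface is not shown in this patch), and the spreadsheet path is a placeholder:

# Sketch only: loading a workbook through UniversalDocLoader.
# With the defaults above, .xlsx files are parsed in "single" mode, so the
# workbook arrives as one Document instead of one element per cell.
from app.module.shared.common.document_loaders import UniversalDocLoader

loader = UniversalDocLoader("/tmp/report.xlsx")  # hypothetical path and API
documents = loader.load()
print(len(documents), documents[0].page_content[:200])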
From cd98c33eb0faa75ba26d6fb42ddd8d5e2802def4 Mon Sep 17 00:00:00 2001
From: Dallas98 <990259227@qq.com>
Date: Mon, 2 Mar 2026 17:53:09 +0800
Subject: [PATCH 13/13] feat: enhance API response models with additional fields and configuration

---
 .../app/module/rag/schema/response.py        | 101 ++++++++------
 .../rag/service/knowledge_base_service.py    | 118 ++++++++----------
 2 files changed, 103 insertions(+), 116 deletions(-)

diff --git a/runtime/datamate-python/app/module/rag/schema/response.py b/runtime/datamate-python/app/module/rag/schema/response.py
index 2984d157f..3eeacdc58 100644
--- a/runtime/datamate-python/app/module/rag/schema/response.py
+++ b/runtime/datamate-python/app/module/rag/schema/response.py
@@ -4,7 +4,7 @@
 Defines the data structures for all API responses
 Keeps fields aligned with the Java response DTOs
 """
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ConfigDict
 from typing import List, Optional, Any
 from datetime import datetime
 from app.db.models.knowledge_gen import RagType, FileStatus
@@ -15,18 +15,21 @@ class ModelConfig(BaseModel):
 
     Maps to Java: com.datamate.common.setting.domain.entity.ModelConfig
     """
+    model_config = ConfigDict(from_attributes=True, populate_by_name=True)
+
     id: str = Field(..., description="Model ID")
-    name: str = Field(..., description="Model name")
+    created_at: Optional[datetime] = Field(None, alias="createdAt", description="Creation time")
+    updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="Update time")
+    created_by: Optional[str] = Field(None, alias="createdBy", description="Created by")
+    updated_by: Optional[str] = Field(None, alias="updatedBy", description="Updated by")
+    model_name: str = Field(..., alias="modelName", description="Model name")
     provider: str = Field(..., description="Model provider")
-
-    class Config:
-        json_schema_extra = {
-            "example": {
-                "id": "model-uuid-123",
-                "name": "text-embedding-ada-002",
-                "provider": "openai"
-            }
-        }
+    base_url: str = Field(..., alias="baseUrl", description="API base URL")
+    api_key: str = Field(default="", alias="apiKey", description="API key")
+    type: str = Field(..., description="Model type")
+    is_enabled: bool = Field(default=True, alias="isEnabled", description="Whether enabled")
+    is_default: bool = Field(default=False, alias="isDefault", description="Whether default")
+    is_deleted: bool = Field(default=False, alias="isDeleted", description="Whether deleted")
@@ -34,24 +37,10 @@
 
 class KnowledgeBaseResp(BaseModel):
     """Knowledge base response
 
     Maps to Java: com.datamate.rag.indexer.interfaces.dto.KnowledgeBaseResp
     """
-    id: str = Field(..., description="Knowledge base ID")
-    name: str = Field(..., description="Knowledge base name")
-    description: Optional[str] = Field(None, description="Knowledge base description")
-    type: RagType = Field(..., description="RAG type")
-    embedding_model: str = Field(alias="embeddingModel", description="Embedding model ID")
-    chat_model: Optional[str] = Field(None, alias="chatModel", description="Chat model ID")
-    file_count: Optional[int] = Field(None, alias="fileCount", description="File count")
-    chunk_count: Optional[int] = Field(None, alias="chunkCount", description="Chunk count")
-    embedding: Optional[ModelConfig] = Field(None, description="Embedding model config")
-    chat: Optional[ModelConfig] = Field(None, description="Chat model config")
-    created_at: Optional[datetime] = Field(None, alias="createdAt", description="Creation time")
-    updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="Update time")
-    created_by: Optional[str] = Field(None, alias="createdBy", description="Created by")
-    updated_by: Optional[str] = Field(None, alias="updatedBy", description="Updated by")
-
-    class Config:
-        populate_by_name = True  # allow snake_case or camelCase
-        json_schema_extra = {
+    model_config = ConfigDict(
+        from_attributes=True,
+        populate_by_name=True,
+        json_schema_extra={
             "example": {
                 "id": "kb-uuid-123",
                 "name": "my_knowledge_base",
                 "fileCount": 10,
                 "chunkCount": 150,
                 "embedding": {
                     "id": "model-1",
@@ -63,35 +52,38 @@ class Config:
-                    "name": "text-embedding-ada-002",
+                    "modelName": "text-embedding-ada-002",
                     "provider": "openai"
                 }
             }
         }
+    )
 
-
-class RagFileResp(BaseModel):
-    """RAG file response
-
-    Maps to Java: com.datamate.rag.indexer.domain.model.RagFile
-    """
-    id: str = Field(..., description="RAG file ID")
-    knowledge_base_id: str = Field(alias="knowledgeBaseId", description="Knowledge base ID")
-    file_name: str = Field(alias="fileName", description="File name")
-    file_id: str = Field(alias="fileId", description="Original file ID")
+    id: str = Field(..., description="Knowledge base ID")
+    name: str = Field(..., description="Knowledge base name")
+    description: Optional[str] = Field(None, description="Knowledge base description")
+    type: RagType = Field(..., description="RAG type")
+    embedding_model: str = Field(alias="embeddingModel", description="Embedding model ID")
+    chat_model: Optional[str] = Field(None, alias="chatModel", description="Chat model ID")
+    file_count: Optional[int] = Field(None, alias="fileCount", description="File count")
     chunk_count: Optional[int] = Field(None, alias="chunkCount", description="Chunk count")
-    metadata: Optional[dict] = Field(None, description="Metadata")
-    status: FileStatus = Field(..., description="Processing status")
-    err_msg: Optional[str] = Field(None, alias="errMsg", description="Error message")
-    progress: int = Field(default=0, ge=0, le=100, description="Processing progress (0-100)")
+    embedding: Optional[ModelConfig] = Field(None, description="Embedding model config")
+    chat: Optional[ModelConfig] = Field(None, description="Chat model config")
     created_at: Optional[datetime] = Field(None, alias="createdAt", description="Creation time")
     updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="Update time")
     created_by: Optional[str] = Field(None, alias="createdBy", description="Created by")
     updated_by: Optional[str] = Field(None, alias="updatedBy", description="Updated by")
 
-    class Config:
-        populate_by_name = True  # allow snake_case or camelCase
-        json_schema_extra = {
+
+class RagFileResp(BaseModel):
+    """RAG file response
+
+    Maps to Java: com.datamate.rag.indexer.domain.model.RagFile
+    """
+    model_config = ConfigDict(
+        from_attributes=True,
+        populate_by_name=True,
+        json_schema_extra={
             "example": {
                 "id": "rag-file-uuid-123",
                 "knowledgeBaseId": "kb-uuid-123",
@@ -104,6 +96,21 @@
                 "createdAt": "2025-01-01T00:00:00"
             }
         }
+    )
+
+    id: str = Field(..., description="RAG file ID")
+    knowledge_base_id: str = Field(alias="knowledgeBaseId", description="Knowledge base ID")
+    file_name: str = Field(alias="fileName", description="File name")
+    file_id: str = Field(alias="fileId", description="Original file ID")
+    chunk_count: Optional[int] = Field(None, alias="chunkCount", description="Chunk count")
+    metadata: Optional[dict] = Field(None, validation_alias="file_metadata", description="Metadata")
+    status: FileStatus = Field(..., description="Processing status")
+    err_msg: Optional[str] = Field(None, alias="errMsg", description="Error message")
+    progress: int = Field(default=0, ge=0, le=100, description="Processing progress (0-100)")
+    created_at: Optional[datetime] = Field(None, alias="createdAt", description="Creation time")
+    updated_at: Optional[datetime] = Field(None, alias="updatedAt", description="Update time")
+    created_by: Optional[str] = Field(None, alias="createdBy", description="Created by")
+    updated_by: Optional[str] = Field(None, alias="updatedBy", description="Updated by")
 
 
 class RagChunkResp(BaseModel):
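How this ConfigDict combination behaves, in a minimal standalone model (the class names below are illustrative):

# Sketch only: from_attributes + populate_by_name + aliases, as used above.
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field

class FileResp(BaseModel):
    model_config = ConfigDict(from_attributes=True, populate_by_name=True)

    file_name: str = Field(alias="fileName")
    # validation_alias reads the ORM attribute file_metadata, while the
    # field itself is exposed (and serialized) as metadata.
    metadata: Optional[dict] = Field(None, validation_alias="file_metadata")

class OrmFile:  # stands in for the SQLAlchemy entity
    file_name = "a.pdf"
    file_metadata = {"pages": 3}

resp = FileResp.model_validate(OrmFile())
print(resp.model_dump(by_alias=True))  # {'fileName': 'a.pdf', 'metadata': {'pages': 3}}

This is what lets the service layer below replace hand-written field-by-field mapping with RagFileResp.model_validate(item).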
diff --git a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py
index a59e95728..27b1eff92 100644
--- a/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py
+++ b/runtime/datamate-python/app/module/rag/service/knowledge_base_service.py
@@ -6,7 +6,7 @@
 """
 import logging
 import uuid
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
 from fastapi import BackgroundTasks
 from sqlalchemy import select
@@ -15,6 +15,7 @@
 from app.core.exception import BusinessError, ErrorCodes
 from app.db.models.dataset_management import DatasetFiles
 from app.db.models.knowledge_gen import KnowledgeBase, RagFile, FileStatus
+from app.db.models.models import Models
 from app.module.rag.infra.vectorstore import drop_collection, rename_collection, delete_chunks_by_rag_file_ids
 from app.module.rag.repository import KnowledgeBaseRepository, RagFileRepository
 from app.module.rag.schema.request import (
@@ -25,7 +26,7 @@
     DeleteFilesReq,
     RagFileReq,
 )
-from app.module.rag.schema.response import KnowledgeBaseResp, PagedResponse, RagFileResp
+from app.module.rag.schema.response import KnowledgeBaseResp, PagedResponse, RagFileResp, ModelConfig
 from app.module.rag.service.file_processor import FileProcessor
 
 logger = logging.getLogger(__name__)
@@ -123,14 +124,7 @@ async def delete(self, knowledge_base_id: str) -> None:
         await self.db.commit()
 
     async def get_by_id(self, knowledge_base_id: str) -> KnowledgeBaseResp:
-        """Get knowledge base details
-
-        Args:
-            knowledge_base_id: knowledge base ID
-
-        Returns:
-            knowledge base response object
-        """
+        """Get knowledge base details"""
         knowledge_base = await self.kb_repo.get_by_id(knowledge_base_id)
         if not knowledge_base:
             raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND)
@@ -138,30 +132,46 @@
         file_count = await self.file_repo.count_by_knowledge_base(knowledge_base_id)
         chunk_count = await self.file_repo.count_chunks_by_knowledge_base(knowledge_base_id)
 
-        return KnowledgeBaseResp(
-            id=knowledge_base.id,
-            name=knowledge_base.name,
-            description=knowledge_base.description,
-            type=knowledge_base.type,
-            embedding_model=knowledge_base.embedding_model,
-            chat_model=knowledge_base.chat_model,
-            file_count=file_count,
-            chunk_count=chunk_count,
-            created_at=knowledge_base.created_at,
-            updated_at=knowledge_base.updated_at,
-            created_by=knowledge_base.created_by,
-            updated_by=knowledge_base.updated_by,
-        )
+        data = self._kb_to_dict(knowledge_base)
+        data.update({
+            "file_count": file_count,
+            "chunk_count": chunk_count,
+            "embedding": await self._get_model_config(knowledge_base.embedding_model),
+            "chat": await self._get_model_config(knowledge_base.chat_model),
+        })
+        return KnowledgeBaseResp(**data)
+
+    def _kb_to_dict(self, kb: KnowledgeBase) -> dict:
+        """Convert a knowledge base entity to a dict"""
+        return {
+            "id": kb.id,
+            "name": kb.name,
+            "description": kb.description,
+            "type": kb.type,
+            "embedding_model": kb.embedding_model,
+            "chat_model": kb.chat_model,
+            "created_at": kb.created_at,
+            "updated_at": kb.updated_at,
+            "created_by": kb.created_by,
+            "updated_by": kb.updated_by,
+        }
 
-    async def list(self, request: KnowledgeBaseQueryReq) -> PagedResponse:
-        """Paged query of the knowledge base list
+    async def _get_model_config(self, model_id: Optional[str]) -> Optional[ModelConfig]:
+        """Fetch a model configuration"""
+        if not model_id:
+            return None
 
-        Args:
-            request: query request
+        result = await self.db.execute(
+            select(Models).where(
+                Models.id == model_id,
+                (Models.is_deleted == False) | (Models.is_deleted.is_(None))
+            )
+        )
+        model = result.scalar_one_or_none()
+        return ModelConfig.model_validate(model) if model else None
 
-        Returns:
-            paged response
-        """
+    async def list(self, request: KnowledgeBaseQueryReq) -> PagedResponse:
+        """Paged query of the knowledge base list"""
         items, total = await self.kb_repo.list(
             keyword=request.keyword,
             rag_type=request.type,
@@ -173,20 +183,12 @@
         for item in items:
             file_count = await self.file_repo.count_by_knowledge_base(item.id)
             chunk_count = await self.file_repo.count_chunks_by_knowledge_base(item.id)
-            responses.append(KnowledgeBaseResp(
-                id=item.id,
-                name=item.name,
-                description=item.description,
-                type=item.type,
-                embedding_model=item.embedding_model,
-                chat_model=item.chat_model,
-                file_count=file_count,
-                chunk_count=chunk_count,
-                created_at=item.created_at,
-                updated_at=item.updated_at,
-                created_by=item.created_by,
-                updated_by=item.updated_by,
-            ))
+            data = self._kb_to_dict(item)
+            data.update({
+                "file_count": file_count,
+                "chunk_count": chunk_count,
+            })
+            responses.append(KnowledgeBaseResp(**data))
 
         return PagedResponse.create(
             content=responses,
@@ -286,15 +288,7 @@
         return rag_files, skipped_file_ids
 
     async def list_files(self, knowledge_base_id: str, request: RagFileReq) -> PagedResponse:
-        """Get the knowledge base file list
-
-        Args:
-            knowledge_base_id: knowledge base ID
-            request: query request
-
-        Returns:
-            paged response
-        """
+        """Get the knowledge base file list"""
         if not await self.kb_repo.get_by_id(knowledge_base_id):
             raise BusinessError(ErrorCodes.RAG_KNOWLEDGE_BASE_NOT_FOUND)
 
@@ -306,21 +300,7 @@
             page_size=request.page_size,
         )
 
-        responses = [RagFileResp(
-            id=item.id,
-            knowledge_base_id=item.knowledge_base_id,
-            file_name=item.file_name,
-            file_id=item.file_id,
-            chunk_count=item.chunk_count,
-            metadata=item.file_metadata,
-            status=item.status,
-            err_msg=item.err_msg,
-            progress=getattr(item, "progress", 0),
-            created_at=item.created_at,
-            updated_at=item.updated_at,
-            created_by=item.created_by,
-            updated_by=item.updated_by,
-        ) for item in items]
+        responses = [RagFileResp.model_validate(item) for item in items]
 
         return PagedResponse.create(
             content=responses,