From 9bf8176f88a94b9ff33ea097ce843878ae393062 Mon Sep 17 00:00:00 2001 From: fit2cloud-chenyw Date: Fri, 12 Dec 2025 11:03:45 +0800 Subject: [PATCH] feat: API permission control --- backend/apps/chat/api/chat.py | 4 + backend/apps/datasource/api/datasource.py | 4 + backend/apps/system/api/aimodel.py | 6 ++ backend/apps/system/api/parameter.py | 3 + backend/apps/system/api/user.py | 21 +++- backend/apps/system/api/workspace.py | 23 +++- backend/apps/system/schemas/permission.py | 126 ++++++++++++++++++++++ backend/main.py | 2 + 8 files changed, 181 insertions(+), 8 deletions(-) create mode 100644 backend/apps/system/schemas/permission.py diff --git a/backend/apps/chat/api/chat.py b/backend/apps/chat/api/chat.py index 8e672a46..485397cd 100644 --- a/backend/apps/chat/api/chat.py +++ b/backend/apps/chat/api/chat.py @@ -14,6 +14,7 @@ format_json_data, format_json_list_data, get_chart_config, list_recent_questions from apps.chat.models.chat_model import CreateChat, ChatRecord, RenameChat, ChatQuestion, AxisObj from apps.chat.task.llm import LLMService +from apps.system.schemas.permission import SqlbotPermission, require_permissions from common.core.deps import CurrentAssistant, SessionDep, CurrentUser, Trans from common.utils.data_format import DataFormat @@ -86,6 +87,7 @@ async def delete(session: SessionDep, chart_id: int): @router.post("/start") +@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="create_chat_obj.datasource")) async def start_chat(session: SessionDep, current_user: CurrentUser, create_chat_obj: CreateChat): try: return create_chat(session, current_user, create_chat_obj) @@ -137,11 +139,13 @@ def _err(_e: Exception): @router.get("/recent_questions/{datasource_id}") +@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="datasource_id")) async def recommend_questions(session: SessionDep, current_user: CurrentUser, datasource_id: int): return list_recent_questions(session=session, current_user=current_user, datasource_id=datasource_id) @router.post("/question") +@require_permissions(permission=SqlbotPermission(type='chat', keyExpression="request_question.chat_id")) async def stream_sql(session: SessionDep, current_user: CurrentUser, request_question: ChatQuestion, current_assistant: CurrentAssistant): """Stream SQL analysis results diff --git a/backend/apps/datasource/api/datasource.py b/backend/apps/datasource/api/datasource.py index 7cb5ef32..b03f2705 100644 --- a/backend/apps/datasource/api/datasource.py +++ b/backend/apps/datasource/api/datasource.py @@ -13,6 +13,7 @@ from apps.db.db import get_schema from apps.db.engine import get_engine_conn from apps.swagger.i18n import PLACEHOLDER_PREFIX +from apps.system.schemas.permission import SqlbotPermission, require_permissions from common.core.config import settings from common.core.deps import SessionDep, CurrentUser, Trans from common.utils.utils import SQLBotLogUtil @@ -78,6 +79,7 @@ def inner(): @router.post("/update", response_model=CoreDatasource) +@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="ds.id")) async def update(session: SessionDep, trans: Trans, user: CurrentUser, ds: CoreDatasource): def inner(): return update_ds(session, trans, user, ds) @@ -86,11 +88,13 @@ def inner(): @router.post("/delete/{id}", response_model=CoreDatasource) +@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="id")) async def delete(session: SessionDep, id: int): return delete_ds(session, id) @router.post("/getTables/{id}") +@require_permissions(permission=SqlbotPermission(type='ds', keyExpression="id")) async def get_tables(session: SessionDep, id: int): return getTables(session, id) diff --git a/backend/apps/system/api/aimodel.py b/backend/apps/system/api/aimodel.py index a96354da..6abba7a2 100644 --- a/backend/apps/system/api/aimodel.py +++ b/backend/apps/system/api/aimodel.py @@ -8,6 +8,7 @@ from sqlmodel import func, select, update from apps.system.models.system_model import AiModelDetail +from apps.system.schemas.permission import SqlbotPermission, require_permissions from common.core.deps import SessionDep, Trans from common.utils.crypto import sqlbot_decrypt from common.utils.time import get_timestamp @@ -51,6 +52,7 @@ async def check_default(session: SessionDep, trans: Trans): raise Exception(trans('i18n_llm.miss_default')) @router.put("/default/{id}") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def set_default(session: SessionDep, id: int): db_model = session.get(AiModelDetail, id) if not db_model: @@ -70,6 +72,7 @@ async def set_default(session: SessionDep, id: int): raise e @router.get("", response_model=list[AiModelGridItem]) +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def query( session: SessionDep, keyword: Union[str, None] = Query(default=None, max_length=255) @@ -113,6 +116,7 @@ async def get_model_by_id( return AiModelEditor(**data) @router.post("") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def add_model( session: SessionDep, creator: AiModelCreator @@ -129,6 +133,7 @@ async def add_model( session.commit() @router.put("") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def update_model( session: SessionDep, editor: AiModelEditor @@ -144,6 +149,7 @@ async def update_model( session.commit() @router.delete("/{id}") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def delete_model( session: SessionDep, trans: Trans, diff --git a/backend/apps/system/api/parameter.py b/backend/apps/system/api/parameter.py index af254d71..89a5cf4d 100644 --- a/backend/apps/system/api/parameter.py +++ b/backend/apps/system/api/parameter.py @@ -4,6 +4,7 @@ from apps.system.crud.parameter_manage import get_groups, get_parameter_args, save_parameter_args +from apps.system.schemas.permission import SqlbotPermission, require_permissions from common.core.deps import SessionDep router = APIRouter(tags=["system/parameter"], prefix="/system/parameter") @@ -13,9 +14,11 @@ async def get_login_args(session: SessionDep) -> list[SysArgModel]: return await get_groups(session, "login") @router.get("") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def get_args(session: SessionDep) -> list[SysArgModel]: return await get_parameter_args(session) @router.post("", ) +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def save_args(session: SessionDep, request: Request): return await save_parameter_args(session = session, request = request) diff --git a/backend/apps/system/api/user.py b/backend/apps/system/api/user.py index 43ec5577..adc5096c 100644 --- a/backend/apps/system/api/user.py +++ b/backend/apps/system/api/user.py @@ -6,6 +6,7 @@ from apps.system.models.system_model import UserWsModel, WorkspaceModel from apps.system.models.user import UserModel from apps.system.schemas.auth import CacheName, CacheNamespace +from apps.system.schemas.permission import SqlbotPermission, require_permissions from apps.system.schemas.system_schema import PwdEditor, UserCreator, UserEditor, UserGrid, UserLanguage, UserStatus, UserWs from common.core.deps import CurrentUser, SessionDep, Trans from common.core.pagination import Paginator @@ -20,11 +21,14 @@ async def user_info(current_user: CurrentUser): return current_user + @router.get("/defaultPwd") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def default_pwd() -> str: return settings.DEFAULT_PWD @router.get("/pager/{pageNum}/{pageSize}", response_model=PaginatedResponse[UserGrid]) +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def pager( session: SessionDep, pageNum: int, @@ -123,6 +127,7 @@ async def ws_change(session: SessionDep, current_user: CurrentUser, trans:Trans, session.commit() @router.get("/{id}", response_model=UserEditor) +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def query(session: SessionDep, trans: Trans, id: int) -> UserEditor: db_user: UserModel = get_db_user(session = session, user_id = id) u_ws_options = await user_ws_options(session, id, trans) @@ -131,7 +136,9 @@ async def query(session: SessionDep, trans: Trans, id: int) -> UserEditor: result.oid_list = [item.id for item in u_ws_options] return result + @router.post("") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def create(session: SessionDep, creator: UserCreator, trans: Trans): if check_account_exists(session=session, account=creator.account): raise Exception(trans('i18n_exist', msg = f"{trans('i18n_user.account')} [{creator.account}]")) @@ -158,8 +165,10 @@ async def create(session: SessionDep, creator: UserCreator, trans: Trans): user_model.oid = creator.oid_list[0] session.add(user_model) session.commit() + @router.put("") +@require_permissions(permission=SqlbotPermission(role=['admin'])) @clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="editor.id") async def update(session: SessionDep, editor: UserEditor, trans: Trans): user_model: UserModel = get_db_user(session = session, user_id = editor.id) @@ -193,12 +202,14 @@ async def update(session: SessionDep, editor: UserEditor, trans: Trans): user_model.oid = origin_oid if origin_oid in editor.oid_list else editor.oid_list[0] session.add(user_model) session.commit() - + @router.delete("/{id}") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def delete(session: SessionDep, id: int): await single_delete(session, id) -@router.delete("") +@router.delete("") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def batch_del(session: SessionDep, id_list: list[int]): for id in id_list: await single_delete(session, id) @@ -213,8 +224,10 @@ async def langChange(session: SessionDep, current_user: CurrentUser, trans: Tran db_user.language = lang session.add(db_user) session.commit() - + + @router.patch("/pwd/{id}") +@require_permissions(permission=SqlbotPermission(role=['admin'])) @clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="id") async def pwdReset(session: SessionDep, current_user: CurrentUser, trans: Trans, id: int): if not current_user.isAdmin: @@ -236,8 +249,10 @@ async def pwdUpdate(session: SessionDep, current_user: CurrentUser, trans: Trans db_user.password = md5pwd(new_pwd) session.add(db_user) session.commit() + @router.patch("/status") +@require_permissions(permission=SqlbotPermission(role=['admin'])) @clear_cache(namespace=CacheNamespace.AUTH_INFO, cacheName=CacheName.USER_INFO, keyExpression="statusDto.id") async def langChange(session: SessionDep, current_user: CurrentUser, trans: Trans, statusDto: UserStatus): if not current_user.isAdmin: diff --git a/backend/apps/system/api/workspace.py b/backend/apps/system/api/workspace.py index 554e3de8..1dd9a091 100644 --- a/backend/apps/system/api/workspace.py +++ b/backend/apps/system/api/workspace.py @@ -5,6 +5,7 @@ from apps.system.crud.workspace import reset_single_user_oid, reset_user_oid from apps.system.models.system_model import UserWsModel, WorkspaceBase, WorkspaceEditor, WorkspaceModel from apps.system.models.user import UserModel +from apps.system.schemas.permission import SqlbotPermission, require_permissions from apps.system.schemas.system_schema import UserWsBase, UserWsDTO, UserWsEditor, UserWsOption, WorkspaceUser from common.core.deps import CurrentUser, SessionDep, Trans from common.core.pagination import Paginator @@ -14,6 +15,7 @@ router = APIRouter(tags=["system/workspace"], prefix="/system/workspace") @router.get("/uws/option/pager/{pageNum}/{pageSize}", response_model=PaginatedResponse[UserWsOption]) +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def option_pager( session: SessionDep, current_user: CurrentUser, @@ -48,6 +50,7 @@ async def option_pager( ) @router.get("/uws/option", response_model=UserWsOption | None) +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def option_user( session: SessionDep, current_user: CurrentUser, @@ -74,7 +77,9 @@ async def option_user( ) return session.exec(stmt).first() + @router.get("/uws/pager/{pageNum}/{pageSize}", response_model=PaginatedResponse[WorkspaceUser]) +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def pager( session: SessionDep, current_user: CurrentUser, @@ -114,7 +119,8 @@ async def pager( ) -@router.post("/uws") +@router.post("/uws") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def create(session: SessionDep, current_user: CurrentUser, trans: Trans, creator: UserWsDTO): if not current_user.isAdmin and current_user.weight == 0: raise Exception(trans('i18n_permission.no_permission', url = '', msg = '')) @@ -136,7 +142,8 @@ async def create(session: SessionDep, current_user: CurrentUser, trans: Trans, c session.add_all(db_model_list) session.commit() -@router.put("/uws") +@router.put("/uws") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def edit(session: SessionDep, trans: Trans, editor: UserWsEditor): if not editor.oid or not editor.uid: raise Exception(trans('i18n_miss_args', key = '[oid, uid]')) @@ -152,7 +159,8 @@ async def edit(session: SessionDep, trans: Trans, editor: UserWsEditor): await clean_user_cache(editor.uid) session.commit() -@router.delete("/uws") +@router.delete("/uws") +@require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def delete(session: SessionDep, current_user: CurrentUser, trans: Trans, dto: UserWsBase): if not current_user.isAdmin and current_user.weight == 0: raise Exception(trans('i18n_permission.no_permission', url = '', msg = '')) @@ -170,6 +178,7 @@ async def delete(session: SessionDep, current_user: CurrentUser, trans: Trans, d session.commit() @router.get("", response_model=list[WorkspaceModel]) +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def query(session: SessionDep, trans: Trans): list_result = session.exec(select(WorkspaceModel)).all() for ws in list_result: @@ -179,6 +188,7 @@ async def query(session: SessionDep, trans: Trans): return list_result @router.post("") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def add(session: SessionDep, creator: WorkspaceBase): db_model = WorkspaceModel.model_validate(creator) db_model.create_time = get_timestamp() @@ -186,6 +196,7 @@ async def add(session: SessionDep, creator: WorkspaceBase): session.commit() @router.put("") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def update(session: SessionDep, editor: WorkspaceEditor): id = editor.id db_model = session.get(WorkspaceModel, id) @@ -195,7 +206,8 @@ async def update(session: SessionDep, editor: WorkspaceEditor): session.add(db_model) session.commit() -@router.get("/{id}", response_model=WorkspaceModel) +@router.get("/{id}", response_model=WorkspaceModel) +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def get_one(session: SessionDep, trans: Trans, id: int): db_model = session.get(WorkspaceModel, id) if not db_model: @@ -204,7 +216,8 @@ async def get_one(session: SessionDep, trans: Trans, id: int): db_model.name = trans(db_model.name) return db_model -@router.delete("/{id}") +@router.delete("/{id}") +@require_permissions(permission=SqlbotPermission(role=['admin'])) async def single_delete(session: SessionDep, current_user: CurrentUser, id: int): if not current_user.isAdmin: raise HTTPException("only admin can delete workspace") diff --git a/backend/apps/system/schemas/permission.py b/backend/apps/system/schemas/permission.py new file mode 100644 index 00000000..d0dccd24 --- /dev/null +++ b/backend/apps/system/schemas/permission.py @@ -0,0 +1,126 @@ +from contextvars import ContextVar +from functools import wraps +from inspect import signature +from typing import Optional +from fastapi import HTTPException, Request +from pydantic import BaseModel +import re +from starlette.middleware.base import BaseHTTPMiddleware +from sqlmodel import Session, select +from apps.chat.models.chat_model import Chat +from apps.datasource.models.datasource import CoreDatasource +from common.core.db import engine +from apps.system.schemas.system_schema import UserInfoDTO + + +class SqlbotPermission(BaseModel): + role: Optional[list[str]] = None + type: Optional[str] = None + keyExpression: Optional[str] = None + +async def get_ws_resource(oid, type) -> list: + with Session(engine) as session: + stmt = None + if type == 'ds' or type == 'datasource': + stmt = select(CoreDatasource.id).where(CoreDatasource.oid == oid) + if type == 'chat': + stmt = select(Chat.id).where(Chat.oid == oid) + if stmt is not None: + db_list = session.exec(stmt).all() + return db_list + return [] + + +async def check_ws_permission(oid, type, resource) -> bool: + resource_id_list = await get_ws_resource(oid, type) + if not resource_id_list: + return False + if isinstance(resource, list): + return set(resource).issubset(set(resource_id_list)) + return resource in resource_id_list + + +def require_permissions(permission: SqlbotPermission): + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + request = RequestContext.get_request() + current_user: UserInfoDTO = getattr(request.state, 'current_user', None) + if not current_user: + raise HTTPException( + status_code=401, + detail="用户未认证" + ) + current_oid = current_user.oid + + if current_user.isAdmin: + return await func(*args, **kwargs) + role_list = permission.role + keyExpression = permission.keyExpression + resource_type = permission.type + + if role_list: + if 'admin' in role_list and not current_user.isAdmin: + raise Exception('no permission to execute, only for admin') + if 'ws_admin' in role_list and current_user.weight == 0: + raise Exception('no permission to execute, only for workspace admin') + if not resource_type: + return await func(*args, **kwargs) + if keyExpression: + sig = signature(func) + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + + if keyExpression.startswith("args["): + if match := re.match(r"args\[(\d+)\]", keyExpression): + index = int(match.group(1)) + value = bound_args.args[index] + if await check_ws_permission(current_oid, resource_type, value): + return await func(*args, **kwargs) + raise Exception('no permission to execute or resource do not exist!') + + parts = keyExpression.split('.') + if not bound_args.arguments.get(parts[0]): + return await func(*args, **kwargs) + value = bound_args.arguments[parts[0]] + for part in parts[1:]: + value = getattr(value, part) + if await check_ws_permission(current_oid, resource_type, value): + return await func(*args, **kwargs) + raise Exception('no permission to execute or resource do not exist!') + + return await func(*args, **kwargs) + + return wrapper + return decorator + +class RequestContext: + + _current_request: ContextVar[Request] = ContextVar("_current_request") + @classmethod + def set_request(cls, request: Request): + return cls._current_request.set(request) + + @classmethod + def get_request(cls) -> Request: + try: + return cls._current_request.get() + except LookupError: + raise RuntimeError( + "No request context found. " + "Make sure RequestContextMiddleware is installed." + ) + + @classmethod + def reset(cls, token): + cls._current_request.reset(token) + +class RequestContextMiddleware(BaseHTTPMiddleware): + + async def dispatch(self, request: Request, call_next): + token = RequestContext.set_request(request) + try: + response = await call_next(request) + return response + finally: + RequestContext.reset(token) \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index cfb80d72..6508db30 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,6 +20,7 @@ from apps.system.crud.aimodel_manage import async_model_info from apps.system.crud.assistant import init_dynamic_cors from apps.system.middleware.auth import TokenMiddleware +from apps.system.schemas.permission import RequestContextMiddleware from common.core.config import settings from common.core.response_middleware import ResponseMiddleware, exception_handler from common.core.sqlbot_cache import init_sqlbot_cache @@ -197,6 +198,7 @@ async def custom_swagger_ui(request: Request): app.add_middleware(TokenMiddleware) app.add_middleware(ResponseMiddleware) +app.add_middleware(RequestContextMiddleware) app.include_router(api_router, prefix=settings.API_V1_STR) # Register exception handlers