diff --git a/backend/apps/system/crud/assistant.py b/backend/apps/system/crud/assistant.py index f4d6cbc5..130ef9f7 100644 --- a/backend/apps/system/crud/assistant.py +++ b/backend/apps/system/crud/assistant.py @@ -5,12 +5,10 @@ import requests from fastapi import FastAPI -from sqlalchemy import Engine, create_engine from sqlmodel import Session, select from starlette.middleware.cors import CORSMiddleware -# from apps.datasource.embedding.table_embedding import get_table_embedding -from apps.datasource.models.datasource import CoreDatasource, DatasourceConf +from apps.datasource.models.datasource import CoreDatasource from apps.datasource.utils.utils import aes_encrypt from apps.system.models.system_model import AssistantModel from apps.system.schemas.auth import CacheName, CacheNamespace @@ -19,7 +17,7 @@ from common.core.db import engine from common.core.sqlbot_cache import cache from common.utils.aes_crypto import simple_aes_decrypt -from common.utils.utils import SQLBotLogUtil, equals_ignore_case, get_domain_list, string_to_numeric_hash +from common.utils.utils import SQLBotLogUtil, get_domain_list, string_to_numeric_hash from common.core.deps import Trans from common.core.response_middleware import ResponseMiddleware @@ -101,7 +99,9 @@ def init_dynamic_cors(app: FastAPI): if cors_middleware: cors_middleware.kwargs['allow_origins'] = updated_origins if response_middleware: - response_middleware.kwargs['allow_origins'] = updated_origins + for instance in ResponseMiddleware.instances: + instance.update_allow_origins(updated_origins) + except Exception as e: return False, e diff --git a/backend/apps/system/crud/assistant_manage.py b/backend/apps/system/crud/assistant_manage.py index b3f285ca..adec9679 100644 --- a/backend/apps/system/crud/assistant_manage.py +++ b/backend/apps/system/crud/assistant_manage.py @@ -9,6 +9,7 @@ from apps.system.models.system_model import AssistantModel from common.utils.time import get_timestamp from common.utils.utils import get_domain_list +from common.core.response_middleware import ResponseMiddleware def dynamic_upgrade_cors(request: Request, session: Session): @@ -24,14 +25,21 @@ def dynamic_upgrade_cors(request: Request, session: Session): unique_domains.append(domain) app: FastAPI = request.app cors_middleware = None + response_middleware = None for middleware in app.user_middleware: - if middleware.cls == CORSMiddleware: + if not cors_middleware and middleware.cls == CORSMiddleware: cors_middleware = middleware + if not response_middleware and middleware.cls == ResponseMiddleware: + response_middleware = middleware + if cors_middleware and response_middleware: break + + updated_origins = list(set(settings.all_cors_origins + unique_domains)) if cors_middleware: - updated_origins = list(set(settings.all_cors_origins + unique_domains)) cors_middleware.kwargs['allow_origins'] = updated_origins - + if response_middleware: + for instance in ResponseMiddleware.instances: + instance.update_allow_origins(updated_origins) async def save(request: Request, session: Session, creator: AssistantBase, oid: Optional[int] = 1): db_model = AssistantModel.model_validate(creator) diff --git a/backend/common/core/response_middleware.py b/backend/common/core/response_middleware.py index c2eeb43f..922035fa 100644 --- a/backend/common/core/response_middleware.py +++ b/backend/common/core/response_middleware.py @@ -1,6 +1,6 @@ import json +from typing import Optional -from redis import typing from starlette.exceptions import HTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request @@ -11,10 +11,19 @@ class ResponseMiddleware(BaseHTTPMiddleware): - def __init__(self, app): - self.allow_origins = ["'self'"] + instances = [] + + def __init__(self, app, allow_origins: Optional[list[str]] = None): super().__init__(app) + self.allow_origins = allow_origins or ["'self'"] + ResponseMiddleware.instances.append(self) + def update_allow_origins(self, new_allow_origins: Optional[list[str]] = None): + if not new_allow_origins: + return + self.allow_origins = list(set(self.allow_origins + new_allow_origins)) + + async def dispatch(self, request, call_next): response = await call_next(request) diff --git a/frontend/src/views/login/index.vue b/frontend/src/views/login/index.vue index b8d2f626..104f1299 100644 --- a/frontend/src/views/login/index.vue +++ b/frontend/src/views/login/index.vue @@ -41,7 +41,7 @@