Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions backend/apps/system/crud/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@

import requests
from fastapi import FastAPI
from sqlalchemy import Engine, create_engine
from sqlmodel import Session, select
from starlette.middleware.cors import CORSMiddleware

# from apps.datasource.embedding.table_embedding import get_table_embedding
from apps.datasource.models.datasource import CoreDatasource, DatasourceConf
from apps.datasource.models.datasource import CoreDatasource
from apps.datasource.utils.utils import aes_encrypt
from apps.system.models.system_model import AssistantModel
from apps.system.schemas.auth import CacheName, CacheNamespace
Expand All @@ -19,7 +17,7 @@
from common.core.db import engine
from common.core.sqlbot_cache import cache
from common.utils.aes_crypto import simple_aes_decrypt
from common.utils.utils import SQLBotLogUtil, equals_ignore_case, get_domain_list, string_to_numeric_hash
from common.utils.utils import SQLBotLogUtil, get_domain_list, string_to_numeric_hash
from common.core.deps import Trans
from common.core.response_middleware import ResponseMiddleware

Expand Down Expand Up @@ -101,7 +99,9 @@ def init_dynamic_cors(app: FastAPI):
if cors_middleware:
cors_middleware.kwargs['allow_origins'] = updated_origins
if response_middleware:
response_middleware.kwargs['allow_origins'] = updated_origins
for instance in ResponseMiddleware.instances:
instance.update_allow_origins(updated_origins)

except Exception as e:
return False, e

Expand Down
14 changes: 11 additions & 3 deletions backend/apps/system/crud/assistant_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from apps.system.models.system_model import AssistantModel
from common.utils.time import get_timestamp
from common.utils.utils import get_domain_list
from common.core.response_middleware import ResponseMiddleware


def dynamic_upgrade_cors(request: Request, session: Session):
Expand All @@ -24,14 +25,21 @@ def dynamic_upgrade_cors(request: Request, session: Session):
unique_domains.append(domain)
app: FastAPI = request.app
cors_middleware = None
response_middleware = None
for middleware in app.user_middleware:
if middleware.cls == CORSMiddleware:
if not cors_middleware and middleware.cls == CORSMiddleware:
cors_middleware = middleware
if not response_middleware and middleware.cls == ResponseMiddleware:
response_middleware = middleware
if cors_middleware and response_middleware:
break

updated_origins = list(set(settings.all_cors_origins + unique_domains))
if cors_middleware:
updated_origins = list(set(settings.all_cors_origins + unique_domains))
cors_middleware.kwargs['allow_origins'] = updated_origins

if response_middleware:
for instance in ResponseMiddleware.instances:
instance.update_allow_origins(updated_origins)

async def save(request: Request, session: Session, creator: AssistantBase, oid: Optional[int] = 1):
db_model = AssistantModel.model_validate(creator)
Expand Down
15 changes: 12 additions & 3 deletions backend/common/core/response_middleware.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import Optional

from redis import typing
from starlette.exceptions import HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
Expand All @@ -11,10 +11,19 @@


class ResponseMiddleware(BaseHTTPMiddleware):
def __init__(self, app):
self.allow_origins = ["'self'"]
instances = []

def __init__(self, app, allow_origins: Optional[list[str]] = None):
super().__init__(app)
self.allow_origins = allow_origins or ["'self'"]
ResponseMiddleware.instances.append(self)

def update_allow_origins(self, new_allow_origins: Optional[list[str]] = None):
if not new_allow_origins:
return
self.allow_origins = list(set(self.allow_origins + new_allow_origins))


async def dispatch(self, request, call_next):
response = await call_next(request)

Expand Down
2 changes: 1 addition & 1 deletion frontend/src/views/login/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
<el-input
v-model="loginForm.username"
clearable
:placeholder="$t('common.your_account_email_address')"
:placeholder="$t('login.input_account')"
size="large"
></el-input>
</el-form-item>
Expand Down
Loading