From 3f6a6e1dae66d5b00cf809bbe302d4ea7f477845 Mon Sep 17 00:00:00 2001 From: Dev Agent Date: Thu, 9 Apr 2026 09:57:36 +0800 Subject: [PATCH 1/3] feat: Optimize aigateway sensitive check whitelist mechanism --- .../mock_RepositoryFileCheckRuleStore.go | 58 +++ aigateway/component/openai.go | 339 ++++++++++++------ aigateway/component/openai_ce_test.go | 191 +++++----- aigateway/handler/openai.go | 126 +++++-- .../handler/openai_check_sensitive_test.go | 217 +++++++++++ aigateway/handler/openai_test.go | 10 +- aigateway/types/openai.go | 110 +++++- builder/store/database/llm_config.go | 2 +- builder/store/database/llm_config_test.go | 108 +++++- ...nique_idx_to_repo_file_check_rule.down.sql | 1 + ..._unique_idx_to_repo_file_check_rule.up.sql | 2 + ...e_to_official_name_in_llm_configs.down.sql | 6 + ...ame_to_official_name_in_llm_configs.up.sql | 6 + .../database/repository_file_check_rule.go | 18 + common/types/llm_service.go | 22 +- component/llm_service.go | 12 +- moderation/component/repo.go | 11 +- moderation/component/repo_test.go | 8 +- 18 files changed, 951 insertions(+), 296 deletions(-) create mode 100644 aigateway/handler/openai_check_sensitive_test.go create mode 100644 builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.down.sql create mode 100644 builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.up.sql create mode 100644 builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.down.sql create mode 100644 builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.up.sql diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepositoryFileCheckRuleStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepositoryFileCheckRuleStore.go index b1f15ded8..0caa7add7 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepositoryFileCheckRuleStore.go 
+++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepositoryFileCheckRuleStore.go @@ -305,6 +305,64 @@ func (_c *MockRepositoryFileCheckRuleStore_ListByRuleType_Call) RunAndReturn(run return _c } +// MatchRegex provides a mock function with given fields: ctx, ruleType, targetString +func (_m *MockRepositoryFileCheckRuleStore) MatchRegex(ctx context.Context, ruleType string, targetString string) (bool, error) { + ret := _m.Called(ctx, ruleType, targetString) + + if len(ret) == 0 { + panic("no return value specified for MatchRegex") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (bool, error)); ok { + return rf(ctx, ruleType, targetString) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) bool); ok { + r0 = rf(ctx, ruleType, targetString) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, ruleType, targetString) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRepositoryFileCheckRuleStore_MatchRegex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MatchRegex' +type MockRepositoryFileCheckRuleStore_MatchRegex_Call struct { + *mock.Call +} + +// MatchRegex is a helper method to define mock.On call +// - ctx context.Context +// - ruleType string +// - targetString string +func (_e *MockRepositoryFileCheckRuleStore_Expecter) MatchRegex(ctx interface{}, ruleType interface{}, targetString interface{}) *MockRepositoryFileCheckRuleStore_MatchRegex_Call { + return &MockRepositoryFileCheckRuleStore_MatchRegex_Call{Call: _e.mock.On("MatchRegex", ctx, ruleType, targetString)} +} + +func (_c *MockRepositoryFileCheckRuleStore_MatchRegex_Call) Run(run func(ctx context.Context, ruleType string, targetString string)) *MockRepositoryFileCheckRuleStore_MatchRegex_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), 
args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockRepositoryFileCheckRuleStore_MatchRegex_Call) Return(_a0 bool, _a1 error) *MockRepositoryFileCheckRuleStore_MatchRegex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRepositoryFileCheckRuleStore_MatchRegex_Call) RunAndReturn(run func(context.Context, string, string) (bool, error)) *MockRepositoryFileCheckRuleStore_MatchRegex_Call { + _c.Call.Return(run) + return _c +} + // NewMockRepositoryFileCheckRuleStore creates a new instance of MockRepositoryFileCheckRuleStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockRepositoryFileCheckRuleStore(t interface { diff --git a/aigateway/component/openai.go b/aigateway/component/openai.go index 3c8745120..2793937f0 100644 --- a/aigateway/component/openai.go +++ b/aigateway/component/openai.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "slices" "strconv" "strings" "time" @@ -29,8 +30,8 @@ type OpenAIComponent interface { GetAvailableModels(c context.Context, user string) ([]types.Model, error) ListModels(c context.Context, user string, req types.ListModelsReq) (types.ModelList, error) GetModelByID(c context.Context, username, modelID string) (*types.Model, error) - RecordUsage(c context.Context, userUUID string, model *types.Model, tokenCounter token.Counter, sceneValue string) error - CheckBalance(ctx context.Context, username string) error + RecordUsage(c context.Context, userUUID string, model *types.Model, tokenCounter token.Counter) error + CheckBalance(ctx context.Context, username, userUUID string) error } type openaiComponentImpl struct { @@ -67,6 +68,8 @@ func (m *openaiComponentImpl) GetAvailableModels(c context.Context, userName str externalModels := m.getExternalModels(c) models = append(models, externalModels...) 
+ models = m.enrichModelsWithPrice(c, models) + // Save models to cache asynchronously go func(modelList []types.Model) { if len(modelList) == 0 { @@ -103,56 +106,72 @@ func (m *openaiComponentImpl) ListModels(c context.Context, userName string, req return filterAndPaginateModels(models, req), nil } -func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) types.ModelList { - // Apply fuzzy search filter if model_id is provided - searchQuery := req.ModelID - if searchQuery != "" { - filtered := make([]types.Model, 0, len(models)) - sq := strings.ToLower(searchQuery) - for _, model := range models { - if strings.Contains(strings.ToLower(model.ID), sq) { - filtered = append(filtered, model) - } - } - models = filtered +type modelFilter func(m *types.Model) bool + +func filterByModelID(query string) modelFilter { + return func(m *types.Model) bool { + return strings.Contains(strings.ToLower(m.ID), query) } +} - // Apply public filter if provided and parseable - if req.Public != "" { - if isPublic, err := strconv.ParseBool(req.Public); err == nil { - filtered := make([]types.Model, 0, len(models)) - for _, model := range models { - if model.Public == isPublic { - filtered = append(filtered, model) - } - } - models = filtered +func filterBySource(source string) modelFilter { + return func(m *types.Model) bool { + switch source { + case string(types.ModelSourceCSGHub): + return m.CSGHubModelID != "" + case string(types.ModelSourceExternal): + return m.Provider != "" + default: + return true } } +} - // Apply source filter if provided - if req.Source != "" { - source := strings.ToLower(req.Source) - filtered := make([]types.Model, 0, len(models)) - for _, model := range models { - switch source { - case string(types.ModelSourceCSGHub): - if model.CSGHubModelID != "" { - filtered = append(filtered, model) - } - case string(types.ModelSourceExternal): - if model.Provider != "" { - filtered = append(filtered, model) - } - default: - // Unknown source 
value, include all - filtered = append(filtered, model) +func filterByTask(task string) modelFilter { + return func(m *types.Model) bool { + modelTasks := strings.FieldsFunc(strings.ToLower(m.Task), func(r rune) bool { + return r == ',' + }) + return slices.Contains(modelTasks, task) + } +} + +func applyFilters(models []types.Model, filters []modelFilter) []types.Model { + if len(filters) == 0 { + return models + } + filtered := make([]types.Model, 0, len(models)) + for i := range models { + m := &models[i] + keep := true + for _, f := range filters { + if !f(m) { + keep = false + break } } - models = filtered + if keep { + filtered = append(filtered, *m) + } + } + return filtered +} + +func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) types.ModelList { + var filters []modelFilter + + if searchQuery := strings.ToLower(req.ModelID); searchQuery != "" { + filters = append(filters, filterByModelID(searchQuery)) + } + if source := strings.ToLower(req.Source); source != "" { + filters = append(filters, filterBySource(source)) } + if task := strings.ToLower(req.Task); task != "" { + filters = append(filters, filterByTask(task)) + } + + models = applyFilters(models, filters) - // Parse pagination parameters (defaults match previous handler behavior) per := 20 page := 1 if req.Per != "" { @@ -170,8 +189,7 @@ func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) type } totalCount := len(models) - offset := (page - 1) * per - startIndex := offset + startIndex := (page - 1) * per if startIndex > totalCount { startIndex = totalCount } @@ -187,18 +205,29 @@ func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) type firstID = &paginated[0].ID lastID = &paginated[len(paginated)-1].ID } - hasMore := endIndex < totalCount return types.ModelList{ Object: "list", Data: paginated, FirstID: firstID, LastID: lastID, - HasMore: hasMore, + HasMore: endIndex < totalCount, TotalCount: totalCount, } } +// 
providerTypeFromDeployType maps a deploy type integer to the LLM type string (MetaKeyLLMType). +func providerTypeFromDeployType(t int) string { + switch t { + case commontypes.ServerlessType: + return types.ProviderTypeServerless + case commontypes.InferenceType: + return types.ProviderTypeInference + default: + return types.ProviderTypeInference + } +} + func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userID int64) ([]types.Model, error) { runningDeploys, err := m.deployStore.RunningVisibleToUser(c, userID) if err != nil { @@ -212,11 +241,6 @@ func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userID int64) ( } // Check if engine_args contains tool-call-parser parameter supportFunctionCall := strings.Contains(deploy.EngineArgs, "tool-call-parser") - // Determine public/private based on deployment type, ownership and secure level. - isPublic := true - if deploy.Type == commontypes.InferenceType && deploy.SecureLevel == commontypes.EndpointPrivate && deploy.UserID == userID { - isPublic = false // private - user's own deployment with private secure level - } repoName := deploy.Repository.Name m := types.Model{ BaseModel: types.BaseModel{ @@ -224,8 +248,10 @@ func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userID int64) ( Created: deploy.CreatedAt.Unix(), SupportFunctionCall: supportFunctionCall, Task: string(deploy.Task), - DisplayName: repoName, - Public: isPublic, + OfficialName: repoName, + Metadata: map[string]any{ + types.MetaKeyLLMType: providerTypeFromDeployType(deploy.Type), + }, }, InternalModelInfo: types.InternalModelInfo{ CSGHubModelID: deploy.Repository.Path, @@ -236,6 +262,9 @@ func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userID int64) ( ImageID: deploy.ImageID, RuntimeFramework: deploy.RuntimeFramework, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, } if deploy.Type == commontypes.ServerlessType { m.BaseModel.OwnedBy = "OpenCSG" @@ -266,6 +295,8 @@ 
func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model search := &commontypes.SearchLLMConfig{} searchType := 16 search.Type = &searchType + enabled := true + search.Enabled = &enabled per := 50 page := 1 @@ -278,20 +309,37 @@ func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model } for _, extModel := range extModels { + // Extract tasks from metadata if present + task := "" + if extModel.Metadata != nil { + if tasks, ok := extModel.Metadata[types.MetaKeyTasks].([]any); ok && len(tasks) > 0 { + tasksStrings := make([]string, 0, len(tasks)) + for _, t := range tasks { + if s, ok := t.(string); ok { + tasksStrings = append(tasksStrings, s) + } + } + task = strings.Join(tasksStrings, ",") + } + } + if extModel.Metadata == nil { + extModel.Metadata = map[string]any{} + } + extModel.Metadata[types.MetaKeyLLMType] = types.ProviderTypeExternalLLM m := types.Model{ BaseModel: types.BaseModel{ - Object: "model", - ID: extModel.ModelName, - OwnedBy: extModel.Provider, - DisplayName: extModel.DisplayName, - // Metadata is allowed to be nil; JSON will contain `null` for nil maps. 
- Metadata: extModel.Metadata, - Public: true, // external models are always public + Object: "model", + ID: extModel.ModelName, + OwnedBy: extModel.Provider, + OfficialName: extModel.OfficialName, + Metadata: extModel.Metadata, + Task: task, }, Endpoint: extModel.ApiEndpoint, ExternalModelInfo: types.ExternalModelInfo{ - Provider: extModel.Provider, - AuthHead: extModel.AuthHeader, + Provider: extModel.Provider, + AuthHead: extModel.AuthHeader, + NeedSensitiveCheck: extModel.NeedSensitiveCheck, }, } models = append(models, m) @@ -397,88 +445,157 @@ func getSceneFromSvcType(svcType int) int { } } -func (m *openaiComponentImpl) RecordUsage(c context.Context, userUUID string, model *types.Model, counter token.Counter, sceneValue string) error { - usage, err := counter.Usage(c) - if err != nil { - return fmt.Errorf("failed to get token usage from counter,error:%w", err) +// csghubMeteringLLMTypeFromModel returns metadata llm_type (e.g. serverless, inference) used as the path component in csghub://… metering URIs. 
+func csghubMeteringLLMTypeFromModel(m *types.Model) (string, error) { + if m == nil { + return "", fmt.Errorf("model is nil") } + if m.Metadata == nil { + return "", fmt.Errorf("model metadata is nil: cannot resolve %s for resource path", types.MetaKeyLLMType) + } + llmType, ok := m.Metadata[types.MetaKeyLLMType].(string) + if !ok { + return "", fmt.Errorf("model metadata %s missing or not a string", types.MetaKeyLLMType) + } + return llmType, nil +} - scene := parseScene(sceneValue) - slog.DebugContext(c, "token usage", slog.Any("usage", usage), slog.Any("scene", scene)) - var tokenUsageExtra = struct { - PromptTokenNum string `json:"prompt_token_num"` - CompletionTokenNum string `json:"completion_token_num"` - // 0: external, 1: owner is user, 2: other user is inference, 3: serverless - OwnerType commontypes.TokenUsageType `json:"owner_type"` - }{ - PromptTokenNum: fmt.Sprintf("%d", usage.PromptTokens), - CompletionTokenNum: fmt.Sprintf("%d", usage.CompletionTokens), +// meteringResourceFromModel builds a MeteringResource from an OpenAI gateway model (see types.MeteringResource). +func meteringResourceFromModel(model *types.Model) (types.MeteringResource, error) { + if model == nil { + return types.MeteringResource{}, fmt.Errorf("model is nil") + } + if model.CSGHubModelID != "" { + llmType, err := csghubMeteringLLMTypeFromModel(model) + if err != nil { + return types.MeteringResource{}, err + } + id := fmt.Sprintf(types.CSGHubResourceFmt, llmType, model.CSGHubModelID) + return types.MeteringResource{ + ResourceID: id, + ResourceName: id, + CustomerID: model.SvcName, + }, nil + } + if model.Provider != "" { + id := fmt.Sprintf(types.ExternalLLMResourceFmt, model.Provider, model.ID) + return types.MeteringResource{ + ResourceID: id, + ResourceName: id, + CustomerID: id, + }, nil + } + return types.MeteringResource{}, nil +} + +// tokenUsageMeteringExtra is serialized into MeteringEvent.Extra for token billing breakdown. 
+type tokenUsageMeteringExtra struct { + PromptTokenNum string `json:"prompt_token_num"` + CompletionTokenNum string `json:"completion_token_num"` + OwnerType commontypes.TokenUsageType `json:"owner_type"` +} + +func validateModelForUsageRecord(c context.Context, model *types.Model) error { + if model == nil { + return fmt.Errorf("record usage: model is nil") } if model.CSGHubModelID != "" && model.Provider != "" { slog.WarnContext(c, "bad model info, both csghub model id and external model provider is set", - slog.Any("model info", model)) + slog.Any("model", model)) + return fmt.Errorf("record usage: conflicting csghub model id and external provider") } if model.CSGHubModelID == "" && model.Provider == "" { slog.WarnContext(c, "bad model info, both csghub model id and external model provider is not set", - slog.Any("model info", model)) + slog.Any("model", model)) + return fmt.Errorf("record usage: model missing resource identifiers") + } + return nil +} + +func (m *openaiComponentImpl) tokenUsageMeteringExtraAndScene(c context.Context, userUUID string, model *types.Model, usage *token.Usage) (tokenUsageMeteringExtra, commontypes.SceneType, error) { + scene := commontypes.SceneModelServerless + extra := tokenUsageMeteringExtra{ + PromptTokenNum: fmt.Sprintf("%d", usage.PromptTokens), + CompletionTokenNum: fmt.Sprintf("%d", usage.CompletionTokens), } if model.CSGHubModelID != "" { switch model.SvcType { case commontypes.ServerlessType: - tokenUsageExtra.OwnerType = commontypes.CSGHubServerlessInference + extra.OwnerType = commontypes.CSGHubServerlessInference case commontypes.InferenceType: if model.OwnerUUID == userUUID { - tokenUsageExtra.OwnerType = commontypes.CSGHubUserDeployedInference + extra.OwnerType = commontypes.CSGHubUserDeployedInference } else { belong, err := m.checkOrganization(c, userUUID, model.OwnerUUID) if err != nil { - return fmt.Errorf("failed to check organization,error:%w", err) + return tokenUsageMeteringExtra{}, 0, fmt.Errorf("failed to 
check organization: %w", err) } if belong { - tokenUsageExtra.OwnerType = commontypes.CSGHubOrganFellowDeployedInference + extra.OwnerType = commontypes.CSGHubOrganFellowDeployedInference } else { - tokenUsageExtra.OwnerType = commontypes.CSGHubOtherDeployedInference + extra.OwnerType = commontypes.CSGHubOtherDeployedInference } } + scene = commontypes.SceneModelInference default: - slog.WarnContext(c, "bad model info, csghub model missing service type", - slog.Any("model info", model)) + slog.ErrorContext(c, "bad model info, csghub model missing service type", slog.Any("model", model)) + return tokenUsageMeteringExtra{}, 0, fmt.Errorf("record usage: csghub model has invalid or missing service type") } + } else if model.Provider != "" { + extra.OwnerType = commontypes.ExternalInference } - if model.Provider != "" { - tokenUsageExtra.OwnerType = commontypes.ExternalInference - } + return extra, scene, nil +} - extraData, _ := json.Marshal(tokenUsageExtra) - event := commontypes.MeteringEvent{ - Uuid: uuid.New(), - UserUUID: userUUID, - Value: usage.TotalTokens, - ValueType: commontypes.TokenNumberType, // count by token - Scene: int(scene), - OpUID: "aigateway", - CreatedAt: time.Now(), - Extra: string(extraData), +func (m *openaiComponentImpl) RecordUsage(c context.Context, userUUID string, model *types.Model, counter token.Counter) error { + usage, err := counter.Usage(c) + if err != nil { + return fmt.Errorf("failed to get token usage from counter: %w", err) } - if model.CSGHubModelID != "" { - event.ResourceID = model.CSGHubModelID - event.ResourceName = model.CSGHubModelID - event.CustomerID = model.SvcName + if err := validateModelForUsageRecord(c, model); err != nil { + return err } - if model.Provider != "" { - extendModelKey := fmt.Sprintf("%s:%s", model.Provider, model.ID) - event.ResourceID = extendModelKey - event.ResourceName = extendModelKey - event.CustomerID = extendModelKey + res, ridErr := meteringResourceFromModel(model) + if ridErr != nil { + 
slog.ErrorContext(c, "cannot record usage: invalid model for resource id", slog.Any("error", ridErr), slog.Any("model", model)) + return fmt.Errorf("cannot record usage: %w", ridErr) + } + if res.ResourceID == "" { + slog.ErrorContext(c, "cannot record usage: empty resource id for model", slog.Any("model", model)) + return fmt.Errorf("cannot record usage: empty resource id") + } + extra, scene, err := m.tokenUsageMeteringExtraAndScene(c, userUUID, model, usage) + if err != nil { + return err + } + extraData, err := json.Marshal(extra) + if err != nil { + return fmt.Errorf("failed to marshal token usage extra: %w", err) + } + event := commontypes.MeteringEvent{ + Uuid: uuid.New(), + UserUUID: userUUID, + Value: usage.TotalTokens, + ValueType: commontypes.TokenNumberType, + Scene: int(scene), + OpUID: "aigateway", + CreatedAt: time.Now(), + Extra: string(extraData), + ResourceID: res.ResourceID, + ResourceName: res.ResourceName, + CustomerID: res.CustomerID, + } + eventData, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("failed to marshal metering event: %w", err) } - eventData, _ := json.Marshal(event) err = m.eventPub.PublishMeteringEvent(eventData) if err != nil { slog.ErrorContext(c, "failed to publish token usage event", slog.Any("event", event), slog.Any("error", err)) - return fmt.Errorf("failed to publish token usage event,error:%w", err) + return fmt.Errorf("failed to publish token usage event: %w", err) } - slog.InfoContext(c, "public token usage event success", slog.Any("event", event)) + slog.InfoContext(c, "published token usage event success", slog.Any("event", event)) return nil } diff --git a/aigateway/component/openai_ce_test.go b/aigateway/component/openai_ce_test.go index 229f2e3eb..4732ca9a6 100644 --- a/aigateway/component/openai_ce_test.go +++ b/aigateway/component/openai_ce_test.go @@ -90,13 +90,15 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { expectModels := []types.Model{ { BaseModel: types.BaseModel{ - ID: 
"model1:svc1", - OwnedBy: "publicuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: "text-generation", - DisplayName: "model1", - Public: true, + ID: "model1:svc1", + OwnedBy: "publicuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + Task: "text-generation", + OfficialName: "model1", + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, }, Endpoint: "endpoint1", InternalModelInfo: types.InternalModelInfo{ @@ -107,6 +109,9 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { SvcType: deploys[0].Type, ImageID: deploys[0].ImageID, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, InternalUse: true, }, { @@ -116,7 +121,9 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { Object: "model", Created: deploys[1].CreatedAt.Unix(), Task: "text-to-image", - Public: true, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeServerless, + }, }, Endpoint: "endpoint2", InternalModelInfo: types.InternalModelInfo{ @@ -126,6 +133,9 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { SvcType: deploys[1].Type, ImageID: deploys[1].ImageID, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, InternalUse: true, }, } @@ -147,10 +157,8 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { require.Len(t, models, 2) assert.Equal(t, "model1:svc1", models[0].ID) assert.Equal(t, "publicuser", models[0].OwnedBy) - assert.True(t, models[0].Public) assert.Equal(t, "hf-model2:svc2", models[1].ID) assert.Equal(t, "OpenCSG", models[1].OwnedBy) - assert.True(t, models[1].Public) wg.Wait() }) @@ -206,19 +214,27 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { expectModels := []types.Model{ { BaseModel: types.BaseModel{ - ID: "model1:svc1", - OwnedBy: "testuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: "text-generation", - DisplayName: "model1", - Public: true, + ID: 
"model1:svc1", + OwnedBy: "testuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + Task: "text-generation", + OfficialName: "model1", + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, }, Endpoint: "endpoint1", InternalModelInfo: types.InternalModelInfo{ - ClusterID: deploys[0].ClusterID, - SvcName: deploys[0].SvcName, - ImageID: deploys[0].ImageID, + CSGHubModelID: deploys[0].Repository.Path, + OwnerUUID: deploys[0].User.UUID, + ClusterID: deploys[0].ClusterID, + SvcName: deploys[0].SvcName, + SvcType: deploys[0].Type, + ImageID: deploys[0].ImageID, + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, }, InternalUse: true, }, @@ -229,13 +245,21 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { Object: "model", Created: deploys[1].CreatedAt.Unix(), Task: "text-to-image", - Public: true, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeServerless, + }, }, Endpoint: "endpoint2", InternalModelInfo: types.InternalModelInfo{ - ClusterID: deploys[1].ClusterID, - SvcName: deploys[1].SvcName, - ImageID: deploys[1].ImageID, + CSGHubModelID: deploys[1].Repository.Path, + OwnerUUID: deploys[1].User.UUID, + ClusterID: deploys[1].ClusterID, + SvcName: deploys[1].SvcName, + SvcType: deploys[1].Type, + ImageID: deploys[1].ImageID, + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, }, InternalUse: true, }, @@ -309,19 +333,27 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { expectModels := []types.Model{ { BaseModel: types.BaseModel{ - ID: "model3:svc3", - OwnedBy: "testuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: "text-generation", - DisplayName: "model3", - Public: false, + ID: "model3:svc3", + OwnedBy: "testuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + Task: "text-generation", + OfficialName: "model3", + Metadata: map[string]any{ + types.MetaKeyLLMType: 
types.ProviderTypeInference, + }, }, Endpoint: "endpoint3", InternalModelInfo: types.InternalModelInfo{ - ClusterID: deploys[0].ClusterID, - SvcName: deploys[0].SvcName, - ImageID: deploys[0].ImageID, + CSGHubModelID: deploys[0].Repository.Path, + OwnerUUID: deploys[0].User.UUID, + ClusterID: deploys[0].ClusterID, + SvcName: deploys[0].SvcName, + SvcType: deploys[0].Type, + ImageID: deploys[0].ImageID, + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, }, InternalUse: true, }, @@ -344,7 +376,6 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { assert.NoError(t, err) assert.Len(t, models, 1) assert.Equal(t, "model3:svc3", models[0].ID) - assert.False(t, models[0].Public) wg.Wait() }) @@ -396,18 +427,26 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { expectModels := []types.Model{ { BaseModel: types.BaseModel{ - ID: "model1:svc1", - OwnedBy: "testuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - DisplayName: "model1", - Public: true, + ID: "model1:svc1", + OwnedBy: "testuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + OfficialName: "model1", + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, }, Endpoint: "endpoint1", InternalModelInfo: types.InternalModelInfo{ - ClusterID: deploys[0].ClusterID, - SvcName: deploys[0].SvcName, - ImageID: deploys[0].ImageID, + CSGHubModelID: deploys[0].Repository.Path, + OwnerUUID: deploys[0].User.UUID, + ClusterID: deploys[0].ClusterID, + SvcName: deploys[0].SvcName, + SvcType: deploys[0].Type, + ImageID: deploys[0].ImageID, + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, }, InternalUse: true, }, @@ -480,13 +519,15 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { deploys[0].CreatedAt = now expectModel := types.Model{ BaseModel: types.BaseModel{ - ID: "model1:svc1", - OwnedBy: "testuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: "text-generation", - 
DisplayName: "model1", - Public: true, + ID: "model1:svc1", + OwnedBy: "testuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + Task: "text-generation", + OfficialName: "model1", + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, }, Endpoint: "endpoint1", } @@ -514,8 +555,10 @@ func TestOpenAIComponent_ExtGetAvailableModels_Error(t *testing.T) { modelListCache: mockCache, } searchType := 16 + enabled := true search := &commontypes.SearchLLMConfig{ - Type: &searchType, + Type: &searchType, + Enabled: &enabled, } mockLLMConfigStore.EXPECT().Index(ctx, 50, 1, search). Return(nil, 0, errors.New("test error")).Once() @@ -563,7 +606,9 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { ID: "test-model-1", OwnedBy: "OpenAI", Object: "model", - Public: true, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeExternalLLM, + }, }, Endpoint: "http://test-endpoint-1.com", ExternalModelInfo: types.ExternalModelInfo{ @@ -584,8 +629,10 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { mockDeployStore.EXPECT().RunningVisibleToUser(mock.Anything, user.ID). Return([]database.Deploy{}, nil) searchType := 16 + enabled := true search := &commontypes.SearchLLMConfig{ - Type: &searchType, + Type: &searchType, + Enabled: &enabled, } mockLLMConfigStore.EXPECT().Index(ctx, 50, 1, search).Return(mockModels, 1, nil) mockCache.EXPECT().HSet(mock.Anything, modelCacheKey, "test-model-1", string(expectJson)). 
@@ -604,39 +651,3 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { require.Equal(t, "test-model-1", models[0].ID) wg.Wait() } - -func TestParseScene(t *testing.T) { - tests := []struct { - name string - sceneValue string - expected commontypes.SceneType - }{ - { - name: "any scene value returns SceneModelServerless", - sceneValue: commontypes.SceneHeaderCSGHub, - expected: commontypes.SceneModelServerless, - }, - { - name: "empty scene returns SceneModelServerless", - sceneValue: "", - expected: commontypes.SceneModelServerless, - }, - { - name: "agentichub scene returns SceneModelServerless", - sceneValue: commontypes.SceneHeaderAgenticHub, - expected: commontypes.SceneModelServerless, - }, - { - name: "unknown scene returns SceneModelServerless", - sceneValue: "unknown", - expected: commontypes.SceneModelServerless, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := parseScene(tt.sceneValue) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index f31983e6f..9c6106869 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -24,6 +24,7 @@ import ( "opencsg.com/csghub-server/builder/proxy" "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/cache" + "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/errorx" commonType "opencsg.com/csghub-server/common/types" @@ -72,7 +73,8 @@ func NewOpenAIHandlerFromConfig(config *config.Config) (OpenAIHandler, error) { return nil, fmt.Errorf("failed to create cluster component, error: %w", err) } storage, _ := component.NewStorage(config) - return newOpenAIHandler(modelService, repoComp, modComponent, clusterComp, token.NewCounterFactory(), text2image.NewRegistry(), config, storage), nil + whitelistRule := database.NewRepositoryFileCheckRuleStore() + return 
newOpenAIHandler(modelService, repoComp, modComponent, clusterComp, token.NewCounterFactory(), text2image.NewRegistry(), config, storage, whitelistRule), nil } func newOpenAIHandler( @@ -84,6 +86,7 @@ func newOpenAIHandler( t2iRegistry *text2image.Registry, config *config.Config, storage types.Storage, + whitelistRule database.RepositoryFileCheckRuleStore, ) *OpenAIHandlerImpl { return &OpenAIHandlerImpl{ openaiComponent: modelService, @@ -94,6 +97,7 @@ func newOpenAIHandler( t2iRegistry: t2iRegistry, config: config, storage: storage, + whitelistRule: whitelistRule, } } @@ -126,6 +130,60 @@ func (h *OpenAIHandlerImpl) handleInsufficientBalance(c *gin.Context, isStream b } } +func (h *OpenAIHandlerImpl) checkSensitive(ctx context.Context, model *types.Model, chatReq *ChatCompletionRequest, userUUID string, stream bool) (bool, *rpc.CheckResult, error) { + if !model.NeedSensitiveCheck { + return false, nil, nil + } + + if exists, err := h.checkNamespaceWhitelist(ctx, model.OfficialName); err != nil { + return false, nil, err + } else if exists { + slog.DebugContext(ctx, "Skip Sensitive check with OfficialName in white list", slog.String("pattern", model.OfficialName)) + return false, nil, nil + } + + if exists, err := h.checkNamespaceWhitelist(ctx, model.ID); err != nil { + return false, nil, err + } else if exists { + slog.DebugContext(ctx, "Skip Sensitive check with modelID in white list", slog.String("pattern", model.ID)) + return false, nil, nil + } + + // Check model name regex match in white list + matched, err := h.whitelistRule.MatchRegex(ctx, database.RuleTypeModelName, model.ID) + if err != nil { + return false, nil, fmt.Errorf("failed to match model name regex: %w", err) + } + if matched { + slog.DebugContext(ctx, "Skip Sensitive check with MatchRegex in white list", slog.String("RuleTypeModelName", model.ID)) + return false, nil, nil + } + + key := fmt.Sprintf("%s:%s", userUUID, model.ID) + result, err := h.modComponent.CheckChatPrompts(ctx, 
chatReq.Messages, key, stream) + if err != nil { + return false, nil, fmt.Errorf("failed to call moderation error:%w", err) + } + + return true, result, nil +} + +func (h *OpenAIHandlerImpl) checkNamespaceWhitelist(ctx context.Context, modelPath string) (bool, error) { + namespace, _, err := common.GetNamespaceAndNameFromPath(modelPath) + if err != nil { + return false, nil + } + exists, err := h.whitelistRule.Exists(ctx, database.RuleTypeNamespace, namespace) + if err != nil { + return false, fmt.Errorf("failed to check namespace in white list: %w", err) + } + if exists { + slog.DebugContext(ctx, "Skip Sensitive check with namespace in white list", slog.String("namespace", namespace)) + return true, nil + } + return false, nil +} + // OpenAIHandlerImpl implements the OpenAIHandler interface type OpenAIHandlerImpl struct { openaiComponent component.OpenAIComponent @@ -136,17 +194,18 @@ type OpenAIHandlerImpl struct { t2iRegistry *text2image.Registry config *config.Config storage types.Storage + whitelistRule database.RepositoryFileCheckRuleStore } // ListModels godoc // @Summary List available models -// @Description Returns a list of available models, supports fuzzy search by model_id query parameter and filtering by public status +// @Description Returns a list of available models, supports fuzzy search by model_id query parameter and filtering by source and task // @Tags AIGateway // @Accept json // @Produce json // @Param model_id query string false "Model ID for fuzzy search" -// @Param public query bool false "Filter by public status (true for public models, false for private models)" // @Param source query string false "Filter by source (csghub for CSGHub models, external for external models)" Enums(csghub, external) +// @Param task query string false "Filter by task (e.g., text-generation, text-to-image)" // @Param per query int false "Models per page (default 20, max 100)" // @Param page query int false "Page number (1-based, default 1)" // @Success 200 
{object} types.ModelList "OK" @@ -173,8 +232,8 @@ func (h *OpenAIHandlerImpl) ListModels(c *gin.Context) { resp, err := h.openaiComponent.ListModels(c.Request.Context(), currentUser, types.ListModelsReq{ ModelID: c.Query("model_id"), - Public: c.Query("public"), Source: source, + Task: c.Query("task"), Per: c.Query("per"), Page: c.Query("page"), }) @@ -206,6 +265,7 @@ func (h *OpenAIHandlerImpl) ListModels(c *gin.Context) { func (h *OpenAIHandlerImpl) GetModel(c *gin.Context) { username := httpbase.GetCurrentUser(c) modelID := c.Param("model") + modelID = strings.TrimPrefix(modelID, "/") if modelID == "" { c.JSON(http.StatusBadRequest, gin.H{ "error": types.Error{ @@ -335,11 +395,12 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { } } - sceneValue := c.Request.Header.Get(commonType.SceneHeaderKey) // Check balance before processing request - if err := h.openaiComponent.CheckBalance(c.Request.Context(), username); err != nil { - h.handleInsufficientBalance(c, chatReq.Stream, username, modelID, err) - return + if !model.SkipBalance() { + if err := h.openaiComponent.CheckBalance(c.Request.Context(), username, userUUID); err != nil { + h.handleInsufficientBalance(c, chatReq.Stream, username, modelID, err) + return + } } // marshal updated request map back to JSON bytes @@ -352,26 +413,26 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { c.String(http.StatusInternalServerError, fmt.Errorf("failed to create reverse proxy:%w", err).Error()) return } - slog.InfoContext(c.Request.Context(), "proxy chat request to model target", slog.Any("target", target), slog.Any("host", host), - slog.Any("user", username), slog.Any("model_name", modelName)) - // Create a combined key using userUUID and modelID for caching and tracking - key := fmt.Sprintf("%s:%s", userUUID, modelID) - result, err := h.modComponent.CheckChatPrompts(c.Request.Context(), chatReq.Messages, key) + + var modComponent component.Moderation = nil + isCheck, result, err := 
h.checkSensitive(c.Request.Context(), model, chatReq, userUUID, chatReq.Stream) if err != nil { - c.String(http.StatusInternalServerError, fmt.Errorf("failed to call moderation error:%w", err).Error()) - return - } - if result.IsSensitive { - slog.DebugContext(c.Request.Context(), "sensitive content", slog.String("reason", result.Reason)) - errorChunk := generateSensitiveRespForPrompt() - errorChunkJson, _ := json.Marshal(errorChunk) - _, err := c.Writer.Write([]byte("data: " + string(errorChunkJson) + "\n\n" + "[DONE]")) - if err != nil { - slog.ErrorContext(c.Request.Context(), "write into resp error:", slog.String("err", err.Error())) + slog.ErrorContext(c.Request.Context(), "failed to check sensitive", + slog.String("model_id", modelID), + slog.String("username", username), + slog.Any("error", err)) + } + if isCheck { + modComponent = h.modComponent + if result != nil && result.IsSensitive { + handleSensitiveResponse(c, chatReq.Stream, result) + return } - c.Writer.Flush() - return } + + slog.InfoContext(c.Request.Context(), "proxy chat request to model target", slog.Any("target", target), slog.Any("host", host), + slog.Any("user", username), slog.Any("model_name", modelName)) + tokenCounter := h.tokenCounterFactory.NewChat(token.CreateParam{ Endpoint: target, Host: host, @@ -379,7 +440,8 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { ImageID: model.ImageID, Provider: model.Provider, }) - w := NewResponseWriterWrapper(c.Writer, chatReq.Stream, h.modComponent, tokenCounter) + + w := NewResponseWriterWrapper(c.Writer, chatReq.Stream, modComponent, tokenCounter) defer w.ClearBuffer() tokenCounter.AppendPrompts(chatReq.Messages) @@ -411,7 +473,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { usageCtx, cancel := context.WithTimeout(context.WithoutCancel(c.Request.Context()), 3*time.Second) defer cancel() - err := h.openaiComponent.RecordUsage(usageCtx, userUUID, model, tokenCounter, sceneValue) + err := h.openaiComponent.RecordUsage(usageCtx, userUUID, 
model, tokenCounter) if err != nil { slog.ErrorContext(usageCtx, "failed to record token usage", slog.Any("error", err)) } @@ -509,8 +571,7 @@ func (h *OpenAIHandlerImpl) GenerateImage(c *gin.Context) { return } - sceneValue := c.Request.Header.Get(commonType.SceneHeaderKey) - if err := h.openaiComponent.CheckBalance(ctx, username); err != nil { + if err := h.openaiComponent.CheckBalance(ctx, username, userUUID); err != nil { h.handleInsufficientBalance(c, false, username, modelID, err) return } @@ -594,7 +655,7 @@ func (h *OpenAIHandlerImpl) GenerateImage(c *gin.Context) { go func() { usageCtx, cancel := context.WithTimeout(context.WithoutCancel(c.Request.Context()), 3*time.Second) defer cancel() - if err := h.openaiComponent.RecordUsage(usageCtx, userUUID, model, imageCounter, sceneValue); err != nil { + if err := h.openaiComponent.RecordUsage(usageCtx, userUUID, model, imageCounter); err != nil { slog.ErrorContext(usageCtx, "failed to record image usage", slog.Any("error", err)) } }() @@ -688,9 +749,8 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) { return } - sceneValue := c.Request.Header.Get(commonType.SceneHeaderKey) // Check balance before processing request - if err := h.openaiComponent.CheckBalance(c.Request.Context(), username); err != nil { + if err := h.openaiComponent.CheckBalance(c.Request.Context(), username, userUUID); err != nil { h.handleInsufficientBalance(c, false, username, modelID, err) return } @@ -719,7 +779,7 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) { usageCtx, cancel := context.WithTimeout(context.WithoutCancel(c.Request.Context()), 3*time.Second) defer cancel() - err := h.openaiComponent.RecordUsage(usageCtx, userUUID, model, tokenCounter, sceneValue) + err := h.openaiComponent.RecordUsage(usageCtx, userUUID, model, tokenCounter) if err != nil { slog.ErrorContext(c, "failed to record embedding token usage", "error", err) } diff --git a/aigateway/handler/openai_check_sensitive_test.go 
b/aigateway/handler/openai_check_sensitive_test.go new file mode 100644 index 000000000..9612fb447 --- /dev/null +++ b/aigateway/handler/openai_check_sensitive_test.go @@ -0,0 +1,217 @@ +package handler + +import ( + "context" + "errors" + "testing" + + "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "opencsg.com/csghub-server/aigateway/types" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/builder/store/database" +) + +func TestOpenAIHandler_checkSensitive(t *testing.T) { + ctx := context.Background() + chatReq := &ChatCompletionRequest{ + Messages: []openai.ChatCompletionMessageParamUnion{}, + } + userUUID := "test-uuid" + + t.Run("NeedSensitiveCheck is false", func(t *testing.T) { + tester, _, _ := setupTest(t) + model := &types.Model{ + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: false, + }, + } + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("OfficialName namespace in whitelist", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "another/model", + OfficialName: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(true, nil).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("OfficialName namespace check fails", func(t *testing.T) 
{ + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "another/model", + OfficialName: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, errors.New("db error")).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.ErrorContains(t, err, "failed to check namespace in white list: db error") + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Namespace in whitelist", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(true, nil).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Namespace check fails", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + 
Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, errors.New("db error")).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.ErrorContains(t, err, "failed to check namespace in white list: db error") + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Check moderation API succeeds", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, nil).Once() + tester.mocks.whitelistRule.EXPECT().MatchRegex(ctx, database.RuleTypeModelName, "Qwen/Qwen3Guard").Return(false, nil).Once() + + expectedResult := &rpc.CheckResult{IsSensitive: true} + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(ctx, chatReq.Messages, "test-uuid:Qwen/Qwen3Guard", false).Return(expectedResult, nil).Once() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.True(t, needCheck) + assert.Equal(t, expectedResult, result) + }) + + t.Run("Check model name regex match succeeds", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, 
database.RuleTypeNamespace, "Qwen").Return(false, nil).Once() + tester.mocks.whitelistRule.EXPECT().MatchRegex(ctx, database.RuleTypeModelName, "Qwen/Qwen3Guard").Return(true, nil).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Check model name regex match fails", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, nil).Once() + tester.mocks.whitelistRule.EXPECT().MatchRegex(ctx, database.RuleTypeModelName, "Qwen/Qwen3Guard").Return(false, errors.New("db error")).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.ErrorContains(t, err, "failed to match model name regex: db error") + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Check moderation API fails", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, 
nil).Once() + tester.mocks.whitelistRule.EXPECT().MatchRegex(ctx, database.RuleTypeModelName, "Qwen/Qwen3Guard").Return(false, nil).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(ctx, chatReq.Messages, "test-uuid:Qwen/Qwen3Guard", false).Return(nil, errors.New("mod api error")).Once() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.ErrorContains(t, err, "failed to call moderation error:mod api error") + assert.False(t, needCheck) + assert.Nil(t, result) + }) +} diff --git a/aigateway/handler/openai_test.go b/aigateway/handler/openai_test.go index 60615f6e0..100f33f39 100644 --- a/aigateway/handler/openai_test.go +++ b/aigateway/handler/openai_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" mockcomp "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/aigateway/component" mocktoken "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/aigateway/token" + mockdatabase "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" apicomp "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" "opencsg.com/csghub-server/aigateway/component/adapter/text2image" "opencsg.com/csghub-server/aigateway/token" @@ -39,6 +40,7 @@ type testerOpenAIHandler struct { repoComp *apicomp.MockRepoComponent mockClsComp *apicomp.MockClusterComponent tokenCounterFactory *mocktoken.MockCounterFactory + whitelistRule *mockdatabase.MockRepositoryFileCheckRuleStore } handler *OpenAIHandlerImpl @@ -51,7 +53,8 @@ func setupTest(t *testing.T) (*testerOpenAIHandler, *gin.Context, *httptest.Resp mockClsComp := apicomp.NewMockClusterComponent(t) mockTokenCounterFactory := mocktoken.NewMockCounterFactory(t) cfg := &config.Config{} - handler := newOpenAIHandler(mockOpenAI, mockRepo, mockModeration, mockClsComp, mockTokenCounterFactory, text2image.NewRegistry(), cfg, nil) + mockWhitelistRule := mockdatabase.NewMockRepositoryFileCheckRuleStore(t) + handler 
:= newOpenAIHandler(mockOpenAI, mockRepo, mockModeration, mockClsComp, mockTokenCounterFactory, text2image.NewRegistry(), cfg, nil, mockWhitelistRule) // Set test user tester := &testerOpenAIHandler{ @@ -67,6 +70,11 @@ func setupTest(t *testing.T) (*testerOpenAIHandler, *gin.Context, *httptest.Resp tester.mocks.repoComp = mockRepo tester.mocks.mockClsComp = mockClsComp tester.mocks.tokenCounterFactory = mockTokenCounterFactory + tester.mocks.whitelistRule = mockWhitelistRule + + tester.mocks.whitelistRule.EXPECT().Exists(mock.Anything, database.RuleTypeNamespace, mock.Anything).Return(false, nil).Maybe() + tester.mocks.whitelistRule.EXPECT().MatchRegex(mock.Anything, database.RuleTypeModelName, mock.Anything).Return(false, nil).Maybe() + return tester, c, w } diff --git a/aigateway/types/openai.go b/aigateway/types/openai.go index 93426aa6e..8db3251c5 100644 --- a/aigateway/types/openai.go +++ b/aigateway/types/openai.go @@ -6,6 +6,33 @@ import ( "opencsg.com/csghub-server/common/types" ) +// Provider type values for Metadata[MetaKeyLLMType]. +const ( + ProviderTypeServerless = "serverless" + ProviderTypeInference = "inference" + ProviderTypeExternalLLM = "external_llm" +) + +// Metadata key constants used when enriching model metadata. +const ( + MetaKeyLLMType = "llm_type" + MetaKeyPricing = "pricing" + MetaKeyTasks = "tasks" +) + +// Resource ID format strings for external LLM (provider, model ID) and CSGHub internal (path segment, repo path). +const ( + ExternalLLMResourceFmt = "%s://%s" + CSGHubResourceFmt = "csghub://%s/%s" +) + +// MeteringResource holds ResourceID, ResourceName, and CustomerID for metering events. +type MeteringResource struct { + ResourceID string + ResourceName string + CustomerID string +} + // BaseModel represents the base model fields type BaseModel struct { ID string `json:"id"` @@ -13,10 +40,9 @@ type BaseModel struct { Created int64 `json:"created"` // organization-owner (e.g. 
openai) OwnedBy string `json:"owned_by"` Task string `json:"task"` // like text-generation, text-to-image etc - DisplayName string `json:"display_name"` + OfficialName string `json:"official_name"` SupportFunctionCall bool `json:"support_function_call,omitempty"` // whether the model supports function calling IsPinned *bool `json:"is_pinned,omitempty"` // whether the model is pinned - Public bool `json:"public"` // whether the model is public (false = private, true = public) Metadata map[string]any `json:"metadata"` } @@ -36,6 +62,10 @@ type InternalModelInfo struct { type ExternalModelInfo struct { Provider string `json:"-"` // external provider name, like openai, anthropic etc AuthHead string `json:"-"` // the auth header to access the external model + // NeedSensitiveCheck controls whether requests for this model should go + // through sensitive content detection in aigateway. Set to false to skip + // the check (e.g. for guard models or trusted internal models). + NeedSensitiveCheck bool `json:"-"` } type Model struct { @@ -57,31 +87,40 @@ func (m Model) MarshalJSON() ([]byte, error) { Task string `json:"task"` DisplayName string `json:"display_name"` SupportFunctionCall *bool `json:"support_function_call,omitempty"` - Public bool `json:"public"` Endpoint string `json:"endpoint"` Metadata map[string]any `json:"metadata"` + CSGHubModelID *string `json:"csghub_model_id,omitempty"` + OwnerUUID *string `json:"owner_uuid,omitempty"` ClusterID *string `json:"cluster_id,omitempty"` SvcName *string `json:"svc_name,omitempty"` + SvcType *int `json:"svc_type,omitempty"` ImageID *string `json:"image_id,omitempty"` AuthHead *string `json:"auth_head,omitempty"` Provider *string `json:"provider,omitempty"` + NeedSensitiveCheck bool `json:"need_sensitive_check"` } resp := internalModelResponse{ - ID: m.ID, - Object: m.Object, - Created: m.Created, - OwnedBy: m.OwnedBy, - Task: m.Task, - DisplayName: m.DisplayName, - Public: m.Public, - Endpoint: m.Endpoint, - Metadata: 
m.Metadata, + ID: m.ID, + Object: m.Object, + Created: m.Created, + OwnedBy: m.OwnedBy, + Task: m.Task, + DisplayName: m.OfficialName, + Endpoint: m.Endpoint, + Metadata: m.Metadata, + NeedSensitiveCheck: m.NeedSensitiveCheck, } if m.SupportFunctionCall { supportFC := m.SupportFunctionCall resp.SupportFunctionCall = &supportFC } + if m.CSGHubModelID != "" { + resp.CSGHubModelID = &m.CSGHubModelID + } + if m.OwnerUUID != "" { + resp.OwnerUUID = &m.OwnerUUID + } if m.Provider != "" { resp.Provider = &m.Provider } @@ -94,6 +133,9 @@ func (m Model) MarshalJSON() ([]byte, error) { if m.SvcName != "" { resp.SvcName = &m.SvcName } + if m.SvcType != 0 { + resp.SvcType = &m.SvcType + } if m.ImageID != "" { resp.ImageID = &m.ImageID } @@ -113,14 +155,17 @@ func (m *Model) UnmarshalJSON(data []byte) error { Task string `json:"task"` DisplayName string `json:"display_name"` SupportFunctionCall bool `json:"support_function_call,omitempty"` - Public bool `json:"public"` Endpoint string `json:"endpoint"` Metadata map[string]any `json:"metadata"` + CSGHubModelID string `json:"csghub_model_id,omitempty"` + OwnerUUID string `json:"owner_uuid,omitempty"` ClusterID string `json:"cluster_id,omitempty"` SvcName string `json:"svc_name,omitempty"` + SvcType int `json:"svc_type,omitempty"` ImageID string `json:"image_id,omitempty"` AuthHead string `json:"auth_head,omitempty"` Provider string `json:"provider,omitempty"` + NeedSensitiveCheck bool `json:"need_sensitive_check"` } var aux internalModelResponse if err := json.Unmarshal(data, &aux); err != nil { @@ -131,16 +176,19 @@ func (m *Model) UnmarshalJSON(data []byte) error { m.Created = aux.Created m.OwnedBy = aux.OwnedBy m.Task = aux.Task - m.DisplayName = aux.DisplayName + m.OfficialName = aux.DisplayName m.SupportFunctionCall = aux.SupportFunctionCall - m.Public = aux.Public m.Endpoint = aux.Endpoint m.Metadata = aux.Metadata + m.CSGHubModelID = aux.CSGHubModelID + m.OwnerUUID = aux.OwnerUUID m.ClusterID = aux.ClusterID m.SvcName = 
aux.SvcName + m.SvcType = aux.SvcType m.ImageID = aux.ImageID m.AuthHead = aux.AuthHead m.Provider = aux.Provider + m.NeedSensitiveCheck = aux.NeedSensitiveCheck return nil } @@ -156,6 +204,19 @@ func (m *Model) ForExternalResponse() *Model { return m } +// SkipBalance reports whether the balance check should be skipped for this model +func (m *Model) SkipBalance() bool { + // The MetaTaskKey value is an array of strings; check whether MetaTaskValGuard is in it + if tasks, ok := m.Metadata[MetaTaskKey].([]interface{}); ok { + for _, t := range tasks { + if task, ok := t.(string); ok && task == MetaTaskValGuard { + return true + } + } + } + return false +} + // ModelList represents the model list response type ModelList struct { Object string `json:"object"` @@ -172,10 +233,10 @@ type ModelList struct { // filtering, and pagination behavior consistently. type ListModelsReq struct { ModelID string `json:"model_id"` - Public string `json:"public"` Per string `json:"per"` Page string `json:"page"` Source string `json:"source"` // filter by source (csghub for CSGHub models, external for external models) + Task string `json:"task"` // filter by task } // UserPreferenceRequest defines the request parameters for UserPreference method @@ -188,7 +249,9 @@ type UserPreferenceRequest struct { const OpenCSGAppNameHeader string = "OpenCSG-App-Name" const ( - AgenticHubApp = "Agentichub" + AgenticHubApp = "Agentichub" + MetaTaskKey = "task" + MetaTaskValGuard = "guard" ) // ModelSource represents the source of a model @@ -200,3 +263,16 @@ const ( // ModelSourceExternal represents models from external providers ModelSourceExternal ModelSource = "external" ) + +// ModelTokenPrice is currency plus per-million-token rate (major units, from accounting cents + sku_unit).
+type ModelTokenPrice struct { + Currency string `json:"currency,omitempty"` + PricePerMillion float64 `json:"price_per_million,omitempty"` +} + +// ModelScenePrice is Metadata["pricing"]: serverless and external_llm use input/output token prices (SaaS serverless scene). +type ModelScenePrice struct { + InputTokenPrice *ModelTokenPrice `json:"input_token_price,omitempty"` + OutputTokenPrice *ModelTokenPrice `json:"output_token_price,omitempty"` + TokenPrice *ModelTokenPrice `json:"token_price,omitempty"` +} diff --git a/builder/store/database/llm_config.go b/builder/store/database/llm_config.go index 83c26afd5..ba9b8aca3 100644 --- a/builder/store/database/llm_config.go +++ b/builder/store/database/llm_config.go @@ -18,7 +18,7 @@ type lLMConfigStoreImpl struct { type LLMConfig struct { ID int64 `bun:",pk,autoincrement" json:"id"` ModelName string `bun:",notnull" json:"model_name"` - DisplayName string `bun:"display_name,nullzero" json:"display_name"` + OfficialName string `bun:"official_name,nullzero" json:"official_name"` ApiEndpoint string `bun:",notnull" json:"api_endpoint"` AuthHeader string `bun:",notnull" json:"auth_header"` Type int `bun:",notnull" json:"type"` // 1: optimization, 2: comparison, 4: summary readme, 8: mcp scan, 16: for aigateway call external llm diff --git a/builder/store/database/llm_config_test.go b/builder/store/database/llm_config_test.go index 7aeaf3a43..d5ccdcb1d 100644 --- a/builder/store/database/llm_config_test.go +++ b/builder/store/database/llm_config_test.go @@ -22,7 +22,7 @@ func TestLLMConfigStore_GetOptimization(t *testing.T) { Type: 1, Enabled: true, ModelName: "c1", - DisplayName: "c1", + OfficialName: "c1", Metadata: map[string]any{"source": "test"}, }).Exec(ctx) require.Nil(t, err) @@ -30,21 +30,21 @@ func TestLLMConfigStore_GetOptimization(t *testing.T) { Type: 2, Enabled: true, ModelName: "c2", - DisplayName: "c2", + OfficialName: "c2", }).Exec(ctx) require.Nil(t, err) _, err = 
db.Core.NewInsert().Model(&database.LLMConfig{ Type: 1, Enabled: false, ModelName: "c3", - DisplayName: "c3", + OfficialName: "c3", }).Exec(ctx) require.Nil(t, err) cfg, err := store.GetOptimization(ctx) require.Nil(t, err) require.Equal(t, "c1", cfg.ModelName) - require.Equal(t, "c1", cfg.DisplayName) + require.Equal(t, "c1", cfg.OfficialName) require.Equal(t, map[string]any{"source": "test"}, cfg.Metadata) } @@ -59,7 +59,7 @@ func TestLLMConfigStore_GetModelForSummaryReadme(t *testing.T) { Type: 5, Enabled: true, ModelName: "summary1", - DisplayName: "summary1", + OfficialName: "summary1", Metadata: map[string]any{"k": "v"}, }).Exec(ctx) require.Nil(t, err) @@ -82,7 +82,7 @@ func TestLLMConfigStore_GetModelForSummaryReadme(t *testing.T) { require.Nil(t, err) require.NotNil(t, cfg) require.Equal(t, "summary1", cfg.ModelName) - require.Equal(t, "summary1", cfg.DisplayName) + require.Equal(t, "summary1", cfg.OfficialName) require.Equal(t, map[string]any{"k": "v"}, cfg.Metadata) } @@ -98,7 +98,7 @@ func TestLLMConfigStore_GetByID(t *testing.T) { Type: 5, Enabled: true, ModelName: "summary1", - DisplayName: "summary1", + OfficialName: "summary1", Metadata: map[string]any{"k": "v"}, } _, err = db.Core.NewInsert().Model(&dbInput).Exec(ctx) @@ -108,7 +108,7 @@ func TestLLMConfigStore_GetByID(t *testing.T) { require.Nil(t, err) require.NotNil(t, cfg) require.Equal(t, "summary1", cfg.ModelName) - require.Equal(t, "summary1", cfg.DisplayName) + require.Equal(t, "summary1", cfg.OfficialName) require.Equal(t, map[string]any{"k": "v"}, cfg.Metadata) } @@ -123,15 +123,15 @@ func TestLLMConfigStore_CRUD(t *testing.T) { Type: 5, Enabled: true, ModelName: "summary1", - DisplayName: "summary1", - Metadata: map[string]any{"k": "v"}, + OfficialName: "summary1", + Metadata: map[string]any{"k": "v", "tasks": []interface{}{"text-generation", "text-to-image"}}, } res, err := store.Create(ctx, dbInput) require.Nil(t, err) require.NotNil(t, res) require.Equal(t, "summary1", res.ModelName) 
- require.Equal(t, "summary1", res.DisplayName) - require.Equal(t, map[string]any{"k": "v"}, res.Metadata) + require.Equal(t, "summary1", res.OfficialName) + require.Equal(t, map[string]any{"k": "v", "tasks": []interface{}{"text-generation", "text-to-image"}}, res.Metadata) searchType := 5 search := &types.SearchLLMConfig{ @@ -157,10 +157,10 @@ func TestLLMConfigStore_Search(t *testing.T) { // Create test data with hyphens and letter-number combinations testModels := []database.LLMConfig{ - {Type: 1, Enabled: true, ModelName: "deepseek-v3", DisplayName: "deepseek-v3"}, - {Type: 1, Enabled: true, ModelName: "openai/gpt-4", DisplayName: "gpt-4"}, - {Type: 1, Enabled: true, ModelName: "claude3-opus", DisplayName: "claude3-opus"}, - {Type: 1, Enabled: true, ModelName: "llama2-7b", DisplayName: "llama2-7b"}, + {Type: 1, Enabled: true, ModelName: "deepseek-v3", OfficialName: "deepseek-v3"}, + {Type: 1, Enabled: true, ModelName: "openai/gpt-4", OfficialName: "gpt-4"}, + {Type: 1, Enabled: true, ModelName: "claude3-opus", OfficialName: "claude3-opus"}, + {Type: 1, Enabled: true, ModelName: "llama2-7b", OfficialName: "llama2-7b"}, } for _, model := range testModels { @@ -232,3 +232,79 @@ func TestLLMConfigStore_Search(t *testing.T) { } require.True(t, found, "Should find gpt-4 when searching for gpt-4") } + +func TestLLMConfigStore_Index_EnabledFilter(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + config, err := config.LoadConfig() + require.Nil(t, err) + store := database.NewLLMConfigStoreWithDB(db, config) + + searchType := 16 + base := database.LLMConfig{ + Type: searchType, + ApiEndpoint: "https://example.test/v1", + AuthHeader: "{}", + Provider: "test", + } + _, err = store.Create(ctx, database.LLMConfig{ + ModelName: "idx-en-on", + OfficialName: "idx-en-on", + Enabled: true, + Type: base.Type, + ApiEndpoint: base.ApiEndpoint, + AuthHeader: base.AuthHeader, + Provider: base.Provider, + }) + require.Nil(t, err) + _, err = 
store.Create(ctx, database.LLMConfig{ + ModelName: "idx-en-off", + OfficialName: "idx-en-off", + Enabled: false, + Type: base.Type, + ApiEndpoint: base.ApiEndpoint, + AuthHeader: base.AuthHeader, + Provider: base.Provider, + }) + require.Nil(t, err) + + enabledTrue := true + enabledFalse := false + + cfgsOn, totalOn, err := store.Index(ctx, 20, 1, &types.SearchLLMConfig{ + Type: &searchType, + Enabled: &enabledTrue, + }) + require.Nil(t, err) + require.Equal(t, 1, totalOn) + require.Len(t, cfgsOn, 1) + require.Equal(t, "idx-en-on", cfgsOn[0].ModelName) + require.True(t, cfgsOn[0].Enabled) + + cfgsOff, totalOff, err := store.Index(ctx, 20, 1, &types.SearchLLMConfig{ + Type: &searchType, + Enabled: &enabledFalse, + }) + require.Nil(t, err) + require.Equal(t, 1, totalOff) + require.Len(t, cfgsOff, 1) + require.Equal(t, "idx-en-off", cfgsOff[0].ModelName) + require.False(t, cfgsOff[0].Enabled) + + cfgsBoth, totalBoth, err := store.Index(ctx, 20, 1, &types.SearchLLMConfig{ + Type: &searchType, + }) + require.Nil(t, err) + require.Equal(t, 2, totalBoth) + require.Len(t, cfgsBoth, 2) + + cfgsKeyword, totalKeyword, err := store.Index(ctx, 20, 1, &types.SearchLLMConfig{ + Keyword: "idx-en-", + Enabled: &enabledTrue, + }) + require.Nil(t, err) + require.Equal(t, 1, totalKeyword) + require.Len(t, cfgsKeyword, 1) + require.Equal(t, "idx-en-on", cfgsKeyword[0].ModelName) +} diff --git a/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.down.sql b/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.down.sql new file mode 100644 index 000000000..7f40d69fe --- /dev/null +++ b/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_rule_type_pattern; \ No newline at end of file diff --git a/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.up.sql 
b/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.up.sql new file mode 100644 index 000000000..430c4b2d6 --- /dev/null +++ b/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.up.sql @@ -0,0 +1,2 @@ +UPDATE repository_file_check_rules SET pattern = LOWER(pattern); +CREATE UNIQUE INDEX IF NOT EXISTS idx_rule_type_pattern ON repository_file_check_rules (rule_type, pattern); \ No newline at end of file diff --git a/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.down.sql b/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.down.sql new file mode 100644 index 000000000..b04bd8677 --- /dev/null +++ b/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.down.sql @@ -0,0 +1,6 @@ +SET statement_timeout = 0; + +--bun:split + +ALTER TABLE llm_configs +RENAME COLUMN official_name TO display_name; diff --git a/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.up.sql b/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.up.sql new file mode 100644 index 000000000..d0d12cb2e --- /dev/null +++ b/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.up.sql @@ -0,0 +1,6 @@ +SET statement_timeout = 0; + +--bun:split + +ALTER TABLE llm_configs +RENAME COLUMN display_name TO official_name; diff --git a/builder/store/database/repository_file_check_rule.go b/builder/store/database/repository_file_check_rule.go index 174edd1b6..f24e5a49a 100644 --- a/builder/store/database/repository_file_check_rule.go +++ b/builder/store/database/repository_file_check_rule.go @@ -2,9 +2,15 @@ package database import ( "context" + "strings" "time" ) +const ( + RuleTypeNamespace = "namespace" + RuleTypeModelName = "model_name" +) + type 
RepositoryFileCheckRule struct { ID int64 `bun:",pk,autoincrement"` RuleType string `bun:"rule_type,notnull,unique:idx_rule_type_pattern"` @@ -19,6 +25,7 @@ type RepositoryFileCheckRuleStore interface { ListByRuleType(ctx context.Context, ruleType string) ([]RepositoryFileCheckRule, error) Delete(ctx context.Context, ruleType, pattern string) error Exists(ctx context.Context, ruleType, pattern string) (bool, error) + MatchRegex(ctx context.Context, ruleType, targetString string) (bool, error) } type repositoryFileCheckRuleStore struct { @@ -34,6 +41,7 @@ func NewRepositoryFileCheckRuleStoreWithDB(db *DB) RepositoryFileCheckRuleStore } func (s *repositoryFileCheckRuleStore) Create(ctx context.Context, ruleType, pattern string) (*RepositoryFileCheckRule, error) { + pattern = strings.ToLower(pattern) rule := &RepositoryFileCheckRule{RuleType: ruleType, Pattern: pattern} _, err := s.db.Operator.Core.NewInsert().Model(rule).Exec(ctx) return rule, err @@ -52,6 +60,7 @@ func (s *repositoryFileCheckRuleStore) ListByRuleType(ctx context.Context, ruleT } func (s *repositoryFileCheckRuleStore) Delete(ctx context.Context, ruleType, pattern string) error { + pattern = strings.ToLower(pattern) _, err := s.db.Operator.Core.NewDelete().Model((*RepositoryFileCheckRule)(nil)). Where("rule_type = ?", ruleType). Where("pattern = ?", pattern). @@ -60,9 +69,18 @@ func (s *repositoryFileCheckRuleStore) Delete(ctx context.Context, ruleType, pat } func (s *repositoryFileCheckRuleStore) Exists(ctx context.Context, ruleType, pattern string) (bool, error) { + pattern = strings.ToLower(pattern) exists, err := s.db.Operator.Core.NewSelect().Model((*RepositoryFileCheckRule)(nil)). Where("rule_type = ?", ruleType). Where("pattern = ?", pattern). Exists(ctx) return exists, err } + +func (s *repositoryFileCheckRuleStore) MatchRegex(ctx context.Context, ruleType, targetString string) (bool, error) { + exists, err := s.db.Operator.Core.NewSelect().Model((*RepositoryFileCheckRule)(nil)). 
+ Where("rule_type = ?", ruleType). + Where("? ~* pattern", targetString). + Exists(ctx) + return exists, err +} diff --git a/common/types/llm_service.go b/common/types/llm_service.go index 09614f557..452bc82ce 100644 --- a/common/types/llm_service.go +++ b/common/types/llm_service.go @@ -5,13 +5,13 @@ import "time" type LLMConfig struct { ID int64 `json:"id"` ModelName string `json:"model_name"` - DisplayName string `json:"display_name"` + OfficialName string `json:"official_name"` ApiEndpoint string `json:"api_endpoint"` AuthHeader string `json:"auth_header"` Type int `json:"type"` // 1: optimization, 2: comparison, 4: summary readme Enabled bool `json:"enabled"` Provider string `json:"provider"` - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` // tasks stored as: {"tasks": ["text-generation", "text-to-image"]} CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -26,6 +26,7 @@ type PromptPrefix struct { type SearchLLMConfig struct { Keyword string `json:"keyword"` // Search keyword Type *int `json:"type"` // Type of search + Enabled *bool `json:"enabled"` // Enabled filter } type SearchPromptPrefix struct { @@ -36,13 +37,13 @@ type SearchPromptPrefix struct { type UpdateLLMConfigReq struct { ID int64 `json:"id"` ModelName *string `json:"model_name"` - DisplayName *string `json:"display_name"` + OfficialName *string `json:"official_name"` ApiEndpoint *string `json:"api_endpoint"` AuthHeader *string `json:"auth_header"` Type *int `json:"type"` // 1: optimization, 2: comparison, 4: summary readme Enabled *bool `json:"enabled"` Provider *string `json:"provider"` - Metadata *map[string]any `json:"metadata"` + Metadata *map[string]any `json:"metadata"` // tasks stored as: {"tasks": ["text-generation", "text-to-image"]} } type UpdatePromptPrefixReq struct { @@ -53,15 +54,16 @@ type UpdatePromptPrefixReq struct { } type CreateLLMConfigReq struct { - ModelName string `json:"model_name"` - DisplayName 
string `json:"display_name"` - ApiEndpoint string `json:"api_endpoint"` + ModelName string `json:"model_name" binding:"required"` + OfficialName string `json:"official_name"` + ApiEndpoint string `json:"api_endpoint" binding:"required"` AuthHeader string `json:"auth_header"` - Type int `json:"type"` // 1: optimization, 2: comparison, 4: summary readme - Provider string `json:"provider"` + Type int `json:"type" binding:"required,oneof=1 2 4 8 16"` // 1: optimization, 2: comparison, 4: summary readme, 8: mcp scan, 16: for aigateway call external llm + Provider string `json:"provider" binding:"required"` Enabled bool `json:"enabled"` - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` // tasks stored as: {"tasks": ["text-generation", "text-to-image"]} } + type CreatePromptPrefixReq struct { ZH string `json:"zh"` EN string `json:"en"` diff --git a/component/llm_service.go b/component/llm_service.go index 00258873c..892d1c44e 100644 --- a/component/llm_service.go +++ b/component/llm_service.go @@ -67,7 +67,7 @@ func (s *llmServiceComponentImpl) ShowLLMConfig(ctx context.Context, id int64) ( llmConfig := &types.LLMConfig{ ID: dbLlmConfig.ID, ModelName: dbLlmConfig.ModelName, - DisplayName: dbLlmConfig.DisplayName, + OfficialName: dbLlmConfig.OfficialName, ApiEndpoint: dbLlmConfig.ApiEndpoint, AuthHeader: dbLlmConfig.AuthHeader, Type: dbLlmConfig.Type, @@ -102,8 +102,8 @@ func (s *llmServiceComponentImpl) UpdateLLMConfig(ctx context.Context, req *type if req.ModelName != nil { llmConfig.ModelName = *req.ModelName } - if req.DisplayName != nil { - llmConfig.DisplayName = *req.DisplayName + if req.OfficialName != nil { + llmConfig.OfficialName = *req.OfficialName } if req.ApiEndpoint != nil { llmConfig.ApiEndpoint = *req.ApiEndpoint @@ -130,7 +130,7 @@ func (s *llmServiceComponentImpl) UpdateLLMConfig(ctx context.Context, req *type resLLMConfig := &types.LLMConfig{ ID: updatedConfig.ID, ModelName: updatedConfig.ModelName, - DisplayName: 
updatedConfig.DisplayName, + OfficialName: updatedConfig.OfficialName, ApiEndpoint: updatedConfig.ApiEndpoint, AuthHeader: updatedConfig.AuthHeader, Type: updatedConfig.Type, @@ -173,7 +173,7 @@ func (s *llmServiceComponentImpl) UpdatePromptPrefix(ctx context.Context, req *t func (s *llmServiceComponentImpl) CreateLLMConfig(ctx context.Context, req *types.CreateLLMConfigReq) (*types.LLMConfig, error) { dbLLMConfig := database.LLMConfig{ ModelName: req.ModelName, - DisplayName: req.DisplayName, + OfficialName: req.OfficialName, ApiEndpoint: req.ApiEndpoint, AuthHeader: req.AuthHeader, Type: req.Type, @@ -188,7 +188,7 @@ func (s *llmServiceComponentImpl) CreateLLMConfig(ctx context.Context, req *type resLLMConfig := &types.LLMConfig{ ID: dbRes.ID, ModelName: dbRes.ModelName, - DisplayName: dbRes.DisplayName, + OfficialName: dbRes.OfficialName, ApiEndpoint: dbRes.ApiEndpoint, AuthHeader: dbRes.AuthHeader, Type: dbRes.Type, diff --git a/moderation/component/repo.go b/moderation/component/repo.go index fc4d85176..9d53cc1f4 100644 --- a/moderation/component/repo.go +++ b/moderation/component/repo.go @@ -75,15 +75,12 @@ func (c *repoComponentImpl) UpdateRepoSensitiveCheckStatus(ctx context.Context, } func (c *repoComponentImpl) SkipSensitiveCheckForWhiteList(ctx context.Context, req RepoFullCheckRequest) (bool, error) { - whiteList, err := c.GetNamespaceWhiteList(ctx) + exists, err := c.whitelistRule.Exists(ctx, database.RuleTypeNamespace, req.Namespace) if err != nil { - return false, fmt.Errorf("failed to get namespace white list: %w", err) + return false, fmt.Errorf("failed to check namespace in white list: %w", err) } - for _, rule := range whiteList { - if req.Namespace != rule { - continue - } + if exists { repo, err := c.GetRepo(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return false, fmt.Errorf("failed to get repo for skip sensitive check, namespace: %s, name: %s, error: %w", req.Namespace, req.Name, err) @@ -246,7 +243,7 @@ func (cc 
*repoComponentImpl) CheckRequestV2(ctx context.Context, req types.Sensi } func (c *repoComponentImpl) GetNamespaceWhiteList(ctx context.Context) ([]string, error) { - namespaceWhiteList, err := c.whitelistRule.ListByRuleType(ctx, "namespace") + namespaceWhiteList, err := c.whitelistRule.ListByRuleType(ctx, database.RuleTypeNamespace) if err != nil { return nil, err } diff --git a/moderation/component/repo_test.go b/moderation/component/repo_test.go index 69b9ffc79..ba42c01be 100644 --- a/moderation/component/repo_test.go +++ b/moderation/component/repo_test.go @@ -168,7 +168,7 @@ func TestRepoComponent_SkipSensitiveCheckForWhiteList(t *testing.T) { RepoType: types.ModelRepo, } - mockRuleStore.EXPECT().ListByRuleType(ctx, "namespace").Return([]database.RepositoryFileCheckRule{{Pattern: "admin"}}, nil).Once() + mockRuleStore.EXPECT().Exists(ctx, database.RuleTypeNamespace, req.Namespace).Return(true, nil).Once() mockRepoStore.EXPECT().FindByPath(ctx, req.RepoType, req.Namespace, req.Name).Return(&database.Repository{ID: 10}, nil).Once() mockRepoStore.EXPECT().UpdateRepoSensitiveCheckStatus(ctx, int64(10), types.SensitiveCheckSkip).Return(nil).Once() @@ -191,7 +191,7 @@ func TestRepoComponent_SkipSensitiveCheckForWhiteList(t *testing.T) { RepoType: types.ModelRepo, } - mockRuleStore.EXPECT().ListByRuleType(ctx, "namespace").Return([]database.RepositoryFileCheckRule{{Pattern: "admin"}}, nil).Once() + mockRuleStore.EXPECT().Exists(ctx, database.RuleTypeNamespace, req.Namespace).Return(false, nil).Once() skipped, err := repoComp.SkipSensitiveCheckForWhiteList(ctx, req) require.NoError(t, err) @@ -216,7 +216,7 @@ func TestRepoComponent_RepoFullCheck(t *testing.T) { RepoType: types.ModelRepo, } - mockRuleStore.EXPECT().ListByRuleType(ctx, "namespace").Return([]database.RepositoryFileCheckRule{{Pattern: "admin"}}, nil).Once() + mockRuleStore.EXPECT().Exists(ctx, database.RuleTypeNamespace, req.Namespace).Return(true, nil).Once() mockRepoStore.EXPECT().FindByPath(ctx, 
req.RepoType, req.Namespace, req.Name).Return(&database.Repository{ID: 10}, nil).Once() mockRepoStore.EXPECT().UpdateRepoSensitiveCheckStatus(ctx, int64(10), types.SensitiveCheckSkip).Return(nil).Once() @@ -241,7 +241,7 @@ func TestRepoComponent_RepoFullCheck(t *testing.T) { RepoType: types.ModelRepo, } - mockRuleStore.EXPECT().ListByRuleType(ctx, "namespace").Return([]database.RepositoryFileCheckRule{{Pattern: "admin"}}, nil).Once() + mockRuleStore.EXPECT().Exists(ctx, database.RuleTypeNamespace, req.Namespace).Return(false, nil).Once() mockWorkflowClient := mocktemporal.NewMockClient(t) temporal.Assign(mockWorkflowClient) workflowOptions := client.StartWorkflowOptions{ From ed3c646d45b5a898d08133276d783be2778247a6 Mon Sep 17 00:00:00 2001 From: cemeng Date: Fri, 10 Apr 2026 12:15:16 +0800 Subject: [PATCH 2/3] test: fix openai chat sensitive-check fallback case --- aigateway/handler/openai_test.go | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/aigateway/handler/openai_test.go b/aigateway/handler/openai_test.go index eaaf782c5..6e69bb475 100644 --- a/aigateway/handler/openai_test.go +++ b/aigateway/handler/openai_test.go @@ -529,9 +529,28 @@ func TestOpenAIHandler_Chat(t *testing.T) { _ = json.Unmarshal(body, &expectReq) tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID, false). Return(nil, errors.New("some error")) + llmTokenCounter := mocktoken.NewMockChatTokenCounter(t) + tester.mocks.tokenCounterFactory.EXPECT().NewChat( + token.CreateParam{ + Endpoint: model.Endpoint, + Host: "", + Model: "model1", + ImageID: model.ImageID, + Provider: model.Provider, + }). + Return(llmTokenCounter) + llmTokenCounter.EXPECT().AppendPrompts(expectReq.Messages).Return() + var wg sync.WaitGroup + wg.Add(1) + tester.mocks.openAIComp.EXPECT().RecordUsage(mock.Anything, "testuuid", model, llmTokenCounter). 
+ RunAndReturn(func(ctx context.Context, uuid string, model *types.Model, counter token.Counter) error { + wg.Done() + return nil + }) tester.handler.Chat(c) + wg.Wait() - assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, http.StatusOK, w.Code) }) t.Run("success", func(t *testing.T) { tester, c, w := setupTest(t) From c087ea98c10c1d46ae0b17342b6bff97e7bd7775 Mon Sep 17 00:00:00 2001 From: Lei Da Date: Fri, 10 Apr 2026 15:13:14 +0800 Subject: [PATCH 3/3] fix build error --- aigateway/handler/openai.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index b30137851..afa89c72b 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -415,7 +415,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { } var modComponent component.Moderation = nil - isCheck, result, err := h.checkSensitive(c.Request.Context(), model, chatReq, userUUID, chatReq.Stream) + isCheck, _, err := h.checkSensitive(c.Request.Context(), model, chatReq, userUUID, chatReq.Stream) if err != nil { slog.ErrorContext(c.Request.Context(), "failed to check sensitive", slog.String("model_id", modelID),