diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepositoryFileCheckRuleStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepositoryFileCheckRuleStore.go index b1f15ded..0caa7add 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepositoryFileCheckRuleStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepositoryFileCheckRuleStore.go @@ -305,6 +305,64 @@ func (_c *MockRepositoryFileCheckRuleStore_ListByRuleType_Call) RunAndReturn(run return _c } +// MatchRegex provides a mock function with given fields: ctx, ruleType, targetString +func (_m *MockRepositoryFileCheckRuleStore) MatchRegex(ctx context.Context, ruleType string, targetString string) (bool, error) { + ret := _m.Called(ctx, ruleType, targetString) + + if len(ret) == 0 { + panic("no return value specified for MatchRegex") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (bool, error)); ok { + return rf(ctx, ruleType, targetString) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) bool); ok { + r0 = rf(ctx, ruleType, targetString) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, ruleType, targetString) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRepositoryFileCheckRuleStore_MatchRegex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MatchRegex' +type MockRepositoryFileCheckRuleStore_MatchRegex_Call struct { + *mock.Call +} + +// MatchRegex is a helper method to define mock.On call +// - ctx context.Context +// - ruleType string +// - targetString string +func (_e *MockRepositoryFileCheckRuleStore_Expecter) MatchRegex(ctx interface{}, ruleType interface{}, targetString interface{}) *MockRepositoryFileCheckRuleStore_MatchRegex_Call { + return &MockRepositoryFileCheckRuleStore_MatchRegex_Call{Call: _e.mock.On("MatchRegex", ctx, ruleType, targetString)} +} + +func (_c *MockRepositoryFileCheckRuleStore_MatchRegex_Call) Run(run func(ctx context.Context, ruleType string, targetString string)) *MockRepositoryFileCheckRuleStore_MatchRegex_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockRepositoryFileCheckRuleStore_MatchRegex_Call) Return(_a0 bool, _a1 error) *MockRepositoryFileCheckRuleStore_MatchRegex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRepositoryFileCheckRuleStore_MatchRegex_Call) RunAndReturn(run func(context.Context, string, string) (bool, error)) *MockRepositoryFileCheckRuleStore_MatchRegex_Call { + _c.Call.Return(run) + return _c +} + // NewMockRepositoryFileCheckRuleStore creates a new instance of MockRepositoryFileCheckRuleStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockRepositoryFileCheckRuleStore(t interface { diff --git a/aigateway/component/openai.go b/aigateway/component/openai.go index a55386e1..2793937f 100644 --- a/aigateway/component/openai.go +++ b/aigateway/component/openai.go @@ -248,7 +248,7 @@ func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userID int64) ( Created: deploy.CreatedAt.Unix(), SupportFunctionCall: supportFunctionCall, Task: string(deploy.Task), - DisplayName: repoName, + OfficialName: repoName, Metadata: map[string]any{ types.MetaKeyLLMType: providerTypeFromDeployType(deploy.Type), }, @@ -262,6 +262,9 @@ func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userID int64) ( ImageID: deploy.ImageID, RuntimeFramework: deploy.RuntimeFramework, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, } if deploy.Type == commontypes.ServerlessType { m.BaseModel.OwnedBy = "OpenCSG" @@ -325,12 +328,12 @@ func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model extModel.Metadata[types.MetaKeyLLMType] = types.ProviderTypeExternalLLM m := types.Model{ BaseModel: types.BaseModel{ - Object: "model", - ID: extModel.ModelName, - OwnedBy: extModel.Provider, - DisplayName: extModel.DisplayName, - Metadata: extModel.Metadata, - Task: task, + Object: "model", + ID: extModel.ModelName, + OwnedBy: extModel.Provider, + OfficialName: extModel.OfficialName, + Metadata: extModel.Metadata, + Task: task, }, Endpoint: extModel.ApiEndpoint, ExternalModelInfo: types.ExternalModelInfo{ diff --git a/aigateway/component/openai_ce_test.go b/aigateway/component/openai_ce_test.go index 0bfcf64c..4732ca9a 100644 --- a/aigateway/component/openai_ce_test.go +++ b/aigateway/component/openai_ce_test.go @@ -87,17 +87,15 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything). Return([]*database.LLMConfig{}, 0, nil) - // Must match JSON produced by saveModelsToCache (getCSGHubModels + ForInternalUse): - // DisplayName is Repository.Name; NeedSensitiveCheck is unset (false) on CSGHub models. expectModels := []types.Model{ { BaseModel: types.BaseModel{ - ID: "model1:svc1", - OwnedBy: "publicuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: "text-generation", - DisplayName: deploys[0].Repository.Name, + ID: "model1:svc1", + OwnedBy: "publicuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + Task: "text-generation", + OfficialName: "model1", Metadata: map[string]any{ types.MetaKeyLLMType: types.ProviderTypeInference, }, @@ -111,16 +109,18 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { SvcType: deploys[0].Type, ImageID: deploys[0].ImageID, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, InternalUse: true, }, { BaseModel: types.BaseModel{ - ID: "hf-model2:svc2", - OwnedBy: "OpenCSG", - Object: "model", - Created: deploys[1].CreatedAt.Unix(), - Task: "text-to-image", - DisplayName: deploys[1].Repository.Name, + ID: "hf-model2:svc2", + OwnedBy: "OpenCSG", + Object: "model", + Created: deploys[1].CreatedAt.Unix(), + Task: "text-to-image", Metadata: map[string]any{ types.MetaKeyLLMType: types.ProviderTypeServerless, }, @@ -133,6 +133,9 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { SvcType: deploys[1].Type, ImageID: deploys[1].ImageID, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, InternalUse: true, }, } @@ -211,12 +214,12 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { expectModels := []types.Model{ { BaseModel: types.BaseModel{ - ID: "model1:svc1", - OwnedBy: "testuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: "text-generation", - DisplayName: deploys[0].Repository.Name, + ID: "model1:svc1", + OwnedBy: "testuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + Task: "text-generation", + OfficialName: "model1", Metadata: map[string]any{ types.MetaKeyLLMType: types.ProviderTypeInference, }, @@ -230,16 +233,18 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { SvcType: deploys[0].Type, ImageID: deploys[0].ImageID, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, InternalUse: true, }, { BaseModel: types.BaseModel{ - ID: "hf-model2:svc2", - OwnedBy: "OpenCSG", - Object: "model", - Created: deploys[1].CreatedAt.Unix(), - Task: "text-to-image", - DisplayName: deploys[1].Repository.Name, + ID: "hf-model2:svc2", + OwnedBy: "OpenCSG", + Object: "model", + Created: deploys[1].CreatedAt.Unix(), + Task: "text-to-image", Metadata: map[string]any{ types.MetaKeyLLMType: types.ProviderTypeServerless, }, @@ -253,6 +258,9 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { SvcType: deploys[1].Type, ImageID: deploys[1].ImageID, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, InternalUse: true, }, } @@ -325,12 +333,12 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { expectModels := []types.Model{ { BaseModel: types.BaseModel{ - ID: "model3:svc3", - OwnedBy: "testuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: "text-generation", - DisplayName: deploys[0].Repository.Name, + ID: "model3:svc3", + OwnedBy: "testuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + Task: "text-generation", + OfficialName: "model3", Metadata: map[string]any{ types.MetaKeyLLMType: types.ProviderTypeInference, }, @@ -344,6 +352,9 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { SvcType: deploys[0].Type, ImageID: deploys[0].ImageID, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, InternalUse: true, }, } @@ -416,12 +427,11 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { expectModels := []types.Model{ { BaseModel: types.BaseModel{ - ID: "model1:svc1", - OwnedBy: "testuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: string(deploys[0].Task), - DisplayName: deploys[0].Repository.Name, + ID: "model1:svc1", + OwnedBy: "testuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + OfficialName: "model1", Metadata: map[string]any{ types.MetaKeyLLMType: types.ProviderTypeInference, }, @@ -435,6 +445,9 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { SvcType: deploys[0].Type, ImageID: deploys[0].ImageID, }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, InternalUse: true, }, } @@ -506,11 +519,12 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { deploys[0].CreatedAt = now expectModel := types.Model{ BaseModel: types.BaseModel{ - ID: "model1:svc1", - OwnedBy: "testuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: "text-generation", + ID: "model1:svc1", + OwnedBy: "testuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + Task: "text-generation", + OfficialName: "model1", Metadata: map[string]any{ types.MetaKeyLLMType: types.ProviderTypeInference, }, diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index ded71288..afa89c72 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -24,6 +24,7 @@ import ( "opencsg.com/csghub-server/builder/proxy" "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/cache" + "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/errorx" commonType "opencsg.com/csghub-server/common/types" @@ -72,7 +73,8 @@ func NewOpenAIHandlerFromConfig(config *config.Config) (OpenAIHandler, error) { return nil, fmt.Errorf("failed to create cluster component, error: %w", err) } storage, _ := component.NewStorage(config) - return newOpenAIHandler(modelService, repoComp, modComponent, clusterComp, token.NewCounterFactory(), text2image.NewRegistry(), config, storage), nil + whitelistRule := database.NewRepositoryFileCheckRuleStore() + return newOpenAIHandler(modelService, repoComp, modComponent, clusterComp, token.NewCounterFactory(), text2image.NewRegistry(), config, storage, whitelistRule), nil } func newOpenAIHandler( @@ -84,6 +86,7 @@ func newOpenAIHandler( t2iRegistry *text2image.Registry, config *config.Config, storage types.Storage, + whitelistRule database.RepositoryFileCheckRuleStore, ) *OpenAIHandlerImpl { return &OpenAIHandlerImpl{ openaiComponent: modelService, @@ -94,6 +97,7 @@ func newOpenAIHandler( t2iRegistry: t2iRegistry, config: config, storage: storage, + whitelistRule: whitelistRule, } } @@ -126,6 +130,60 @@ func (h *OpenAIHandlerImpl) handleInsufficientBalance(c *gin.Context, isStream b } } +func (h *OpenAIHandlerImpl) checkSensitive(ctx context.Context, model *types.Model, chatReq *ChatCompletionRequest, userUUID string, stream bool) (bool, *rpc.CheckResult, error) { + if !model.NeedSensitiveCheck { + return false, nil, nil + } + + if exists, err := h.checkNamespaceWhitelist(ctx, model.OfficialName); err != nil { + return false, nil, err + } else if exists { + slog.DebugContext(ctx, "Skip Sensitive check with OfficialName in white list", slog.String("pattern", model.OfficialName)) + return false, nil, nil + } + + if exists, err := h.checkNamespaceWhitelist(ctx, model.ID); err != nil { + return false, nil, err + } else if exists { + slog.DebugContext(ctx, "Skip Sensitive check with modelID in white list", slog.String("pattern", model.ID)) + return false, nil, nil + } + + // Check model name regex match in white list + matched, err := h.whitelistRule.MatchRegex(ctx, database.RuleTypeModelName, model.ID) + if err != nil { + return false, nil, fmt.Errorf("failed to match model name regex: %w", err) + } + if matched { + slog.DebugContext(ctx, "Skip Sensitive check with MatchRegex in white list", slog.String("RuleTypeModelName", model.ID)) + return false, nil, nil + } + + key := fmt.Sprintf("%s:%s", userUUID, model.ID) + result, err := h.modComponent.CheckChatPrompts(ctx, chatReq.Messages, key, stream) + if err != nil { + return false, nil, fmt.Errorf("failed to call moderation error:%w", err) + } + + return true, result, nil +} + +func (h *OpenAIHandlerImpl) checkNamespaceWhitelist(ctx context.Context, modelPath string) (bool, error) { + namespace, _, err := common.GetNamespaceAndNameFromPath(modelPath) + if err != nil { + return false, nil + } + exists, err := h.whitelistRule.Exists(ctx, database.RuleTypeNamespace, namespace) + if err != nil { + return false, fmt.Errorf("failed to check namespace in white list: %w", err) + } + if exists { + slog.DebugContext(ctx, "Skip Sensitive check with namespace in white list", slog.String("namespace", namespace)) + return true, nil + } + return false, nil +} + // OpenAIHandlerImpl implements the OpenAIHandler interface type OpenAIHandlerImpl struct { openaiComponent component.OpenAIComponent @@ -136,6 +194,7 @@ type OpenAIHandlerImpl struct { t2iRegistry *text2image.Registry config *config.Config storage types.Storage + whitelistRule database.RepositoryFileCheckRuleStore } // ListModels godoc @@ -356,7 +415,14 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { } var modComponent component.Moderation = nil - if model.NeedSensitiveCheck { + isCheck, _, err := h.checkSensitive(c.Request.Context(), model, chatReq, userUUID, chatReq.Stream) + if err != nil { + slog.ErrorContext(c.Request.Context(), "failed to check sensitive", + slog.String("model_id", modelID), + slog.String("username", username), + slog.Any("error", err)) + } + if isCheck { modComponent = h.modComponent // Create a combined key using userUUID and modelID for caching and tracking key := fmt.Sprintf("%s:%s", userUUID, modelID) diff --git a/aigateway/handler/openai_check_sensitive_test.go b/aigateway/handler/openai_check_sensitive_test.go new file mode 100644 index 00000000..9612fb44 --- /dev/null +++ b/aigateway/handler/openai_check_sensitive_test.go @@ -0,0 +1,217 @@ +package handler + +import ( + "context" + "errors" + "testing" + + "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "opencsg.com/csghub-server/aigateway/types" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/builder/store/database" +) + +func TestOpenAIHandler_checkSensitive(t *testing.T) { + ctx := context.Background() + chatReq := &ChatCompletionRequest{ + Messages: []openai.ChatCompletionMessageParamUnion{}, + } + userUUID := "test-uuid" + + t.Run("NeedSensitiveCheck is false", func(t *testing.T) { + tester, _, _ := setupTest(t) + model := &types.Model{ + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: false, + }, + } + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("OfficialName namespace in whitelist", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "another/model", + OfficialName: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(true, nil).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("OfficialName namespace check fails", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "another/model", + OfficialName: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, errors.New("db error")).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.ErrorContains(t, err, "failed to check namespace in white list: db error") + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Namespace in whitelist", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(true, nil).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Namespace check fails", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, errors.New("db error")).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.ErrorContains(t, err, "failed to check namespace in white list: db error") + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Check moderation API succeeds", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, nil).Once() + tester.mocks.whitelistRule.EXPECT().MatchRegex(ctx, database.RuleTypeModelName, "Qwen/Qwen3Guard").Return(false, nil).Once() + + expectedResult := &rpc.CheckResult{IsSensitive: true} + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(ctx, chatReq.Messages, "test-uuid:Qwen/Qwen3Guard", false).Return(expectedResult, nil).Once() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.True(t, needCheck) + assert.Equal(t, expectedResult, result) + }) + + t.Run("Check model name regex match succeeds", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, nil).Once() + tester.mocks.whitelistRule.EXPECT().MatchRegex(ctx, database.RuleTypeModelName, "Qwen/Qwen3Guard").Return(true, nil).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.NoError(t, err) + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Check model name regex match fails", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, nil).Once() + tester.mocks.whitelistRule.EXPECT().MatchRegex(ctx, database.RuleTypeModelName, "Qwen/Qwen3Guard").Return(false, errors.New("db error")).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.ErrorContains(t, err, "failed to match model name regex: db error") + assert.False(t, needCheck) + assert.Nil(t, result) + }) + + t.Run("Check moderation API fails", func(t *testing.T) { + tester, _, _ := setupTest(t) + tester.mocks.whitelistRule.ExpectedCalls = nil + tester.mocks.moderationComp.ExpectedCalls = nil + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "Qwen/Qwen3Guard", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + Provider: "", + }, + } + tester.mocks.whitelistRule.EXPECT().Exists(ctx, database.RuleTypeNamespace, "Qwen").Return(false, nil).Once() + tester.mocks.whitelistRule.EXPECT().MatchRegex(ctx, database.RuleTypeModelName, "Qwen/Qwen3Guard").Return(false, nil).Once() + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(ctx, chatReq.Messages, "test-uuid:Qwen/Qwen3Guard", false).Return(nil, errors.New("mod api error")).Once() + + needCheck, result, err := tester.handler.checkSensitive(ctx, model, chatReq, userUUID, false) + assert.ErrorContains(t, err, "failed to call moderation error:mod api error") + assert.False(t, needCheck) + assert.Nil(t, result) + }) +} diff --git a/aigateway/handler/openai_test.go b/aigateway/handler/openai_test.go index b86b7417..6e69bb47 100644 --- a/aigateway/handler/openai_test.go +++ b/aigateway/handler/openai_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" mockcomp "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/aigateway/component" mocktoken "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/aigateway/token" + mockdatabase "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/database" apicomp "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" "opencsg.com/csghub-server/aigateway/component/adapter/text2image" "opencsg.com/csghub-server/aigateway/token" @@ -39,6 +40,7 @@ type testerOpenAIHandler struct { repoComp *apicomp.MockRepoComponent mockClsComp *apicomp.MockClusterComponent tokenCounterFactory *mocktoken.MockCounterFactory + whitelistRule *mockdatabase.MockRepositoryFileCheckRuleStore } handler *OpenAIHandlerImpl @@ -51,7 +53,8 @@ func setupTest(t *testing.T) (*testerOpenAIHandler, *gin.Context, *httptest.Resp mockClsComp := apicomp.NewMockClusterComponent(t) mockTokenCounterFactory := mocktoken.NewMockCounterFactory(t) cfg := &config.Config{} - handler := newOpenAIHandler(mockOpenAI, mockRepo, mockModeration, mockClsComp, mockTokenCounterFactory, text2image.NewRegistry(), cfg, nil) + mockWhitelistRule := mockdatabase.NewMockRepositoryFileCheckRuleStore(t) + handler := newOpenAIHandler(mockOpenAI, mockRepo, mockModeration, mockClsComp, mockTokenCounterFactory, text2image.NewRegistry(), cfg, nil, mockWhitelistRule) // Set test user tester := &testerOpenAIHandler{ @@ -67,6 +70,11 @@ func setupTest(t *testing.T) (*testerOpenAIHandler, *gin.Context, *httptest.Resp tester.mocks.repoComp = mockRepo tester.mocks.mockClsComp = mockClsComp tester.mocks.tokenCounterFactory = mockTokenCounterFactory + tester.mocks.whitelistRule = mockWhitelistRule + + tester.mocks.whitelistRule.EXPECT().Exists(mock.Anything, database.RuleTypeNamespace, mock.Anything).Return(false, nil).Maybe() + tester.mocks.whitelistRule.EXPECT().MatchRegex(mock.Anything, database.RuleTypeModelName, mock.Anything).Return(false, nil).Maybe() + return tester, c, w } diff --git a/aigateway/types/openai.go b/aigateway/types/openai.go index 935924dc..8db3251c 100644 --- a/aigateway/types/openai.go +++ b/aigateway/types/openai.go @@ -40,7 +40,7 @@ type BaseModel struct { Created int64 `json:"created"` // organization-owner (e.g. openai) OwnedBy string `json:"owned_by"` Task string `json:"task"` // like text-generation, text-to-image etc - DisplayName string `json:"display_name"` + OfficialName string `json:"official_name"` SupportFunctionCall bool `json:"support_function_call,omitempty"` // whether the model supports function calling IsPinned *bool `json:"is_pinned,omitempty"` // whether the model is pinned Metadata map[string]any `json:"metadata"` @@ -105,7 +105,7 @@ func (m Model) MarshalJSON() ([]byte, error) { Created: m.Created, OwnedBy: m.OwnedBy, Task: m.Task, - DisplayName: m.DisplayName, + DisplayName: m.OfficialName, Endpoint: m.Endpoint, Metadata: m.Metadata, NeedSensitiveCheck: m.NeedSensitiveCheck, @@ -176,7 +176,7 @@ func (m *Model) UnmarshalJSON(data []byte) error { m.Created = aux.Created m.OwnedBy = aux.OwnedBy m.Task = aux.Task - m.DisplayName = aux.DisplayName + m.OfficialName = aux.DisplayName m.SupportFunctionCall = aux.SupportFunctionCall m.Endpoint = aux.Endpoint m.Metadata = aux.Metadata diff --git a/builder/store/database/llm_config.go b/builder/store/database/llm_config.go index 099dc8a4..8bf618e8 100644 --- a/builder/store/database/llm_config.go +++ b/builder/store/database/llm_config.go @@ -16,15 +16,15 @@ type lLMConfigStoreImpl struct { } type LLMConfig struct { - ID int64 `bun:",pk,autoincrement" json:"id"` - ModelName string `bun:",notnull" json:"model_name"` - DisplayName string `bun:"display_name,nullzero" json:"display_name"` - ApiEndpoint string `bun:",notnull" json:"api_endpoint"` - AuthHeader string `bun:",notnull" json:"auth_header"` - Type int `bun:",notnull" json:"type"` // 1: optimization, 2: comparison, 4: summary readme, 8: mcp scan, 16: for aigateway call external llm - Enabled bool `bun:",notnull" json:"enabled"` - Provider string `bun:"," json:"provider"` - Metadata map[string]any `bun:",type:jsonb,nullzero" json:"metadata"` + ID int64 `bun:",pk,autoincrement" json:"id"` + ModelName string `bun:",notnull" json:"model_name"` + OfficialName string `bun:"official_name,nullzero" json:"official_name"` + ApiEndpoint string `bun:",notnull" json:"api_endpoint"` + AuthHeader string `bun:",notnull" json:"auth_header"` + Type int `bun:",notnull" json:"type"` // 1: optimization, 2: comparison, 4: summary readme, 8: mcp scan, 16: for aigateway call external llm + Enabled bool `bun:",notnull" json:"enabled"` + Provider string `bun:"," json:"provider"` + Metadata map[string]any `bun:",type:jsonb,nullzero" json:"metadata"` // NeedSensitiveCheck controls whether requests for this model should go // through sensitive content detection in aigateway. Set to false to skip // the check (e.g. for guard models or trusted internal models). diff --git a/builder/store/database/llm_config_test.go b/builder/store/database/llm_config_test.go index 989b4149..d5ccdcb1 100644 --- a/builder/store/database/llm_config_test.go +++ b/builder/store/database/llm_config_test.go @@ -22,7 +22,7 @@ func TestLLMConfigStore_GetOptimization(t *testing.T) { Type: 1, Enabled: true, ModelName: "c1", - DisplayName: "c1", + OfficialName: "c1", Metadata: map[string]any{"source": "test"}, }).Exec(ctx) require.Nil(t, err) @@ -30,21 +30,21 @@ func TestLLMConfigStore_GetOptimization(t *testing.T) { Type: 2, Enabled: true, ModelName: "c2", - DisplayName: "c2", + OfficialName: "c2", }).Exec(ctx) require.Nil(t, err) _, err = db.Core.NewInsert().Model(&database.LLMConfig{ Type: 1, Enabled: false, ModelName: "c3", - DisplayName: "c3", + OfficialName: "c3", }).Exec(ctx) require.Nil(t, err) cfg, err := store.GetOptimization(ctx) require.Nil(t, err) require.Equal(t, "c1", cfg.ModelName) - require.Equal(t, "c1", cfg.DisplayName) + require.Equal(t, "c1", cfg.OfficialName) require.Equal(t, map[string]any{"source": "test"}, cfg.Metadata) } @@ -59,7 +59,7 @@ func TestLLMConfigStore_GetModelForSummaryReadme(t *testing.T) { Type: 5, Enabled: true, ModelName: "summary1", - DisplayName: "summary1", + OfficialName: "summary1", Metadata: map[string]any{"k": "v"}, }).Exec(ctx) require.Nil(t, err) @@ -82,7 +82,7 @@ func TestLLMConfigStore_GetModelForSummaryReadme(t *testing.T) { require.Nil(t, err) require.NotNil(t, cfg) require.Equal(t, "summary1", cfg.ModelName) - require.Equal(t, "summary1", cfg.DisplayName) + require.Equal(t, "summary1", cfg.OfficialName) require.Equal(t, map[string]any{"k": "v"}, cfg.Metadata) } @@ -98,7 +98,7 @@ func TestLLMConfigStore_GetByID(t *testing.T) { Type: 5, Enabled: true, ModelName: "summary1", - DisplayName: "summary1", + OfficialName: "summary1", Metadata: map[string]any{"k": "v"}, } _, err = db.Core.NewInsert().Model(&dbInput).Exec(ctx) @@ -108,7 +108,7 @@ func TestLLMConfigStore_GetByID(t *testing.T) { require.Nil(t, err) require.NotNil(t, cfg) require.Equal(t, "summary1", cfg.ModelName) - require.Equal(t, "summary1", cfg.DisplayName) + require.Equal(t, "summary1", cfg.OfficialName) require.Equal(t, map[string]any{"k": "v"}, cfg.Metadata) } @@ -123,14 +123,14 @@ func TestLLMConfigStore_CRUD(t *testing.T) { Type: 5, Enabled: true, ModelName: "summary1", - DisplayName: "summary1", + OfficialName: "summary1", Metadata: map[string]any{"k": "v", "tasks": []interface{}{"text-generation", "text-to-image"}}, } res, err := store.Create(ctx, dbInput) require.Nil(t, err) require.NotNil(t, res) require.Equal(t, "summary1", res.ModelName) - require.Equal(t, "summary1", res.DisplayName) + require.Equal(t, "summary1", res.OfficialName) require.Equal(t, map[string]any{"k": "v", "tasks": []interface{}{"text-generation", "text-to-image"}}, res.Metadata) searchType := 5 @@ -157,10 +157,10 @@ func TestLLMConfigStore_Search(t *testing.T) { // Create test data with hyphens and letter-number combinations testModels := []database.LLMConfig{ - {Type: 1, Enabled: true, ModelName: "deepseek-v3", DisplayName: "deepseek-v3"}, - {Type: 1, Enabled: true, ModelName: "openai/gpt-4", DisplayName: "gpt-4"}, - {Type: 1, Enabled: true, ModelName: "claude3-opus", DisplayName: "claude3-opus"}, - {Type: 1, Enabled: true, ModelName: "llama2-7b", DisplayName: "llama2-7b"}, + {Type: 1, Enabled: true, ModelName: "deepseek-v3", OfficialName: "deepseek-v3"}, + {Type: 1, Enabled: true, ModelName: "openai/gpt-4", OfficialName: "gpt-4"}, + {Type: 1, Enabled: true, ModelName: "claude3-opus", OfficialName: "claude3-opus"}, + {Type: 1, Enabled: true, ModelName: "llama2-7b", OfficialName: "llama2-7b"}, } for _, model := range testModels { @@ -250,7 +250,7 @@ func TestLLMConfigStore_Index_EnabledFilter(t *testing.T) { } _, err = store.Create(ctx, database.LLMConfig{ ModelName: "idx-en-on", - DisplayName: "idx-en-on", + OfficialName: "idx-en-on", Enabled: true, Type: base.Type, ApiEndpoint: base.ApiEndpoint, @@ -260,7 +260,7 @@ func TestLLMConfigStore_Index_EnabledFilter(t *testing.T) { require.Nil(t, err) _, err = store.Create(ctx, database.LLMConfig{ ModelName: "idx-en-off", - DisplayName: "idx-en-off", + OfficialName: "idx-en-off", Enabled: false, Type: base.Type, ApiEndpoint: base.ApiEndpoint, diff --git a/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.down.sql b/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.down.sql new file mode 100644 index 00000000..7f40d69f --- /dev/null +++ b/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_rule_type_pattern; \ No newline at end of file diff --git a/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.up.sql b/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.up.sql new file mode 100644 index 00000000..430c4b2d --- /dev/null +++ b/builder/store/database/migrations/20260403000000_add_unique_idx_to_repo_file_check_rule.up.sql @@ -0,0 +1,2 @@ +UPDATE repository_file_check_rules SET pattern = LOWER(pattern); +CREATE UNIQUE INDEX IF NOT EXISTS idx_rule_type_pattern ON repository_file_check_rules (rule_type, pattern); \ No newline at end of file diff --git a/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.down.sql b/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.down.sql new file mode 100644 index 00000000..b04bd867 --- /dev/null +++ b/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.down.sql @@ -0,0 +1,6 @@ +SET statement_timeout = 0; + +--bun:split + +ALTER TABLE llm_configs +RENAME COLUMN official_name TO display_name; diff --git a/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.up.sql b/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.up.sql new file mode 100644 index 00000000..d0d12cb2 --- /dev/null +++ b/builder/store/database/migrations/20260408093000_rename_display_name_to_official_name_in_llm_configs.up.sql @@ -0,0 +1,6 @@ +SET statement_timeout = 0; + +--bun:split + +ALTER TABLE llm_configs +RENAME COLUMN display_name TO official_name; diff --git a/builder/store/database/repository_file_check_rule.go b/builder/store/database/repository_file_check_rule.go index 174edd1b..f24e5a49 100644 --- a/builder/store/database/repository_file_check_rule.go +++ b/builder/store/database/repository_file_check_rule.go @@ -2,9 +2,15 @@ package database import ( "context" + "strings" "time" ) +const ( + RuleTypeNamespace = "namespace" + RuleTypeModelName = "model_name" +) + type RepositoryFileCheckRule struct { ID int64 `bun:",pk,autoincrement"` RuleType string `bun:"rule_type,notnull,unique:idx_rule_type_pattern"` @@ -19,6 +25,7 @@ type RepositoryFileCheckRuleStore interface { ListByRuleType(ctx context.Context, ruleType string) ([]RepositoryFileCheckRule, error) Delete(ctx context.Context, ruleType, pattern string) error Exists(ctx context.Context, ruleType, pattern string) (bool, error) + MatchRegex(ctx context.Context, ruleType, targetString string) (bool, error) } type repositoryFileCheckRuleStore struct { @@ -34,6 +41,7 @@ func NewRepositoryFileCheckRuleStoreWithDB(db *DB) RepositoryFileCheckRuleStore } func (s *repositoryFileCheckRuleStore) Create(ctx context.Context, ruleType, pattern string) (*RepositoryFileCheckRule, error) { + pattern = strings.ToLower(pattern) rule := &RepositoryFileCheckRule{RuleType: ruleType, Pattern: pattern} _, err := s.db.Operator.Core.NewInsert().Model(rule).Exec(ctx) return rule, err @@ -52,6 +60,7 @@ func (s *repositoryFileCheckRuleStore) ListByRuleType(ctx context.Context, ruleT } func (s *repositoryFileCheckRuleStore) Delete(ctx context.Context, ruleType, pattern string) error { + pattern = strings.ToLower(pattern) _, err := s.db.Operator.Core.NewDelete().Model((*RepositoryFileCheckRule)(nil)). Where("rule_type = ?", ruleType). Where("pattern = ?", pattern). @@ -60,9 +69,18 @@ func (s *repositoryFileCheckRuleStore) Delete(ctx context.Context, ruleType, pat } func (s *repositoryFileCheckRuleStore) Exists(ctx context.Context, ruleType, pattern string) (bool, error) { + pattern = strings.ToLower(pattern) exists, err := s.db.Operator.Core.NewSelect().Model((*RepositoryFileCheckRule)(nil)). Where("rule_type = ?", ruleType). Where("pattern = ?", pattern). Exists(ctx) return exists, err } + +func (s *repositoryFileCheckRuleStore) MatchRegex(ctx context.Context, ruleType, targetString string) (bool, error) { + exists, err := s.db.Operator.Core.NewSelect().Model((*RepositoryFileCheckRule)(nil)). + Where("rule_type = ?", ruleType). + Where("? ~* pattern", targetString). + Exists(ctx) + return exists, err +} diff --git a/common/types/llm_service.go b/common/types/llm_service.go index e42d3a35..452bc82c 100644 --- a/common/types/llm_service.go +++ b/common/types/llm_service.go @@ -5,7 +5,7 @@ import "time" type LLMConfig struct { ID int64 `json:"id"` ModelName string `json:"model_name"` - DisplayName string `json:"display_name"` + OfficialName string `json:"official_name"` ApiEndpoint string `json:"api_endpoint"` AuthHeader string `json:"auth_header"` Type int `json:"type"` // 1: optimization, 2: comparison, 4: summary readme @@ -37,7 +37,7 @@ type SearchPromptPrefix struct { type UpdateLLMConfigReq struct { ID int64 `json:"id"` ModelName *string `json:"model_name"` - DisplayName *string `json:"display_name"` + OfficialName *string `json:"official_name"` ApiEndpoint *string `json:"api_endpoint"` AuthHeader *string `json:"auth_header"` Type *int `json:"type"` // 1: optimization, 2: comparison, 4: summary readme @@ -55,7 +55,7 @@ type UpdatePromptPrefixReq struct { type CreateLLMConfigReq struct { ModelName string `json:"model_name" binding:"required"` - DisplayName string `json:"display_name"` + OfficialName string `json:"official_name"` ApiEndpoint string `json:"api_endpoint" binding:"required"` AuthHeader string `json:"auth_header"` Type int `json:"type" binding:"required,oneof=1 2 4 8 16"` // 1: optimization, 2: comparison, 4: summary readme, 8: mcp scan, 16: for aigateway call external llm diff --git a/component/llm_service.go b/component/llm_service.go index 00258873..892d1c44 100644 --- a/component/llm_service.go +++ b/component/llm_service.go @@ -67,7 +67,7 @@ func (s *llmServiceComponentImpl) ShowLLMConfig(ctx context.Context, id int64) ( llmConfig := &types.LLMConfig{ ID: dbLlmConfig.ID, ModelName: dbLlmConfig.ModelName, - DisplayName: dbLlmConfig.DisplayName, + OfficialName: dbLlmConfig.OfficialName, ApiEndpoint: dbLlmConfig.ApiEndpoint, AuthHeader: dbLlmConfig.AuthHeader, Type: dbLlmConfig.Type, @@ -102,8 +102,8 @@ func (s *llmServiceComponentImpl) UpdateLLMConfig(ctx context.Context, req *type if req.ModelName != nil { llmConfig.ModelName = *req.ModelName } - if req.DisplayName != nil { - llmConfig.DisplayName = *req.DisplayName + if req.OfficialName != nil { + llmConfig.OfficialName = *req.OfficialName } if req.ApiEndpoint != nil { llmConfig.ApiEndpoint = *req.ApiEndpoint @@ -130,7 +130,7 @@ func (s *llmServiceComponentImpl) UpdateLLMConfig(ctx context.Context, req *type resLLMConfig := &types.LLMConfig{ ID: updatedConfig.ID, ModelName: updatedConfig.ModelName, - DisplayName: updatedConfig.DisplayName, + OfficialName: updatedConfig.OfficialName, ApiEndpoint: updatedConfig.ApiEndpoint, AuthHeader: updatedConfig.AuthHeader, Type: updatedConfig.Type, @@ -173,7 +173,7 @@ func (s *llmServiceComponentImpl) UpdatePromptPrefix(ctx context.Context, req *t func (s *llmServiceComponentImpl) CreateLLMConfig(ctx context.Context, req *types.CreateLLMConfigReq) (*types.LLMConfig, error) { dbLLMConfig := database.LLMConfig{ ModelName: req.ModelName, - DisplayName: req.DisplayName, + OfficialName: req.OfficialName, ApiEndpoint: req.ApiEndpoint, AuthHeader: req.AuthHeader, Type: req.Type, @@ -188,7 +188,7 @@ func (s *llmServiceComponentImpl) CreateLLMConfig(ctx context.Context, req *type resLLMConfig := &types.LLMConfig{ ID: dbRes.ID, ModelName: dbRes.ModelName, - DisplayName: dbRes.DisplayName, + OfficialName: dbRes.OfficialName, ApiEndpoint: dbRes.ApiEndpoint, AuthHeader: dbRes.AuthHeader, Type: dbRes.Type, diff --git a/moderation/component/repo.go b/moderation/component/repo.go index fc4d8517..9d53cc1f 100644 --- a/moderation/component/repo.go +++ b/moderation/component/repo.go @@ -75,15 +75,12 @@ func (c *repoComponentImpl) UpdateRepoSensitiveCheckStatus(ctx context.Context, } func (c *repoComponentImpl) SkipSensitiveCheckForWhiteList(ctx context.Context, req RepoFullCheckRequest) (bool, error) { - whiteList, err := c.GetNamespaceWhiteList(ctx) + exists, err := c.whitelistRule.Exists(ctx, database.RuleTypeNamespace, req.Namespace) if err != nil { - return false, fmt.Errorf("failed to get namespace white list: %w", err) + return false, fmt.Errorf("failed to check namespace in white list: %w", err) } - for _, rule := range whiteList { - if req.Namespace != rule { - continue - } + if exists { repo, err := c.GetRepo(ctx, req.RepoType, req.Namespace, req.Name) if err != nil { return false, fmt.Errorf("failed to get repo for skip sensitive check, namespace: %s, name: %s, error: %w", req.Namespace, req.Name, err) @@ -246,7 +243,7 @@ func (cc *repoComponentImpl) CheckRequestV2(ctx context.Context, req types.Sensi } func (c *repoComponentImpl) GetNamespaceWhiteList(ctx context.Context) ([]string, error) { - namespaceWhiteList, err := c.whitelistRule.ListByRuleType(ctx, "namespace") + namespaceWhiteList, err := c.whitelistRule.ListByRuleType(ctx, database.RuleTypeNamespace) if err != nil { return nil, err } diff --git a/moderation/component/repo_test.go b/moderation/component/repo_test.go index 69b9ffc7..ba42c01b 100644 --- a/moderation/component/repo_test.go +++ b/moderation/component/repo_test.go @@ -168,7 +168,7 @@ func TestRepoComponent_SkipSensitiveCheckForWhiteList(t *testing.T) { RepoType: types.ModelRepo, } - mockRuleStore.EXPECT().ListByRuleType(ctx, "namespace").Return([]database.RepositoryFileCheckRule{{Pattern: "admin"}}, nil).Once() + mockRuleStore.EXPECT().Exists(ctx, database.RuleTypeNamespace, req.Namespace).Return(true, nil).Once() mockRepoStore.EXPECT().FindByPath(ctx, req.RepoType, req.Namespace, req.Name).Return(&database.Repository{ID: 10}, nil).Once() mockRepoStore.EXPECT().UpdateRepoSensitiveCheckStatus(ctx, int64(10), types.SensitiveCheckSkip).Return(nil).Once() @@ -191,7 +191,7 @@ func TestRepoComponent_SkipSensitiveCheckForWhiteList(t *testing.T) { RepoType: types.ModelRepo, } - mockRuleStore.EXPECT().ListByRuleType(ctx, "namespace").Return([]database.RepositoryFileCheckRule{{Pattern: "admin"}}, nil).Once() + mockRuleStore.EXPECT().Exists(ctx, database.RuleTypeNamespace, req.Namespace).Return(false, nil).Once() skipped, err := repoComp.SkipSensitiveCheckForWhiteList(ctx, req) require.NoError(t, err) @@ -216,7 +216,7 @@ func TestRepoComponent_RepoFullCheck(t *testing.T) { RepoType: types.ModelRepo, } - mockRuleStore.EXPECT().ListByRuleType(ctx, "namespace").Return([]database.RepositoryFileCheckRule{{Pattern: "admin"}}, nil).Once() + mockRuleStore.EXPECT().Exists(ctx, database.RuleTypeNamespace, req.Namespace).Return(true, nil).Once() mockRepoStore.EXPECT().FindByPath(ctx, req.RepoType, req.Namespace, req.Name).Return(&database.Repository{ID: 10}, nil).Once() mockRepoStore.EXPECT().UpdateRepoSensitiveCheckStatus(ctx, int64(10), types.SensitiveCheckSkip).Return(nil).Once() @@ -241,7 +241,7 @@ func TestRepoComponent_RepoFullCheck(t *testing.T) { RepoType: types.ModelRepo, } - mockRuleStore.EXPECT().ListByRuleType(ctx, "namespace").Return([]database.RepositoryFileCheckRule{{Pattern: "admin"}}, nil).Once() + mockRuleStore.EXPECT().Exists(ctx, database.RuleTypeNamespace, req.Namespace).Return(false, nil).Once() mockWorkflowClient := mocktemporal.NewMockClient(t) temporal.Assign(mockWorkflowClient) workflowOptions := client.StartWorkflowOptions{