diff --git a/core/authenticate/authenticate.go b/core/authenticate/authenticate.go index f84fb5f79..6815d22ea 100644 --- a/core/authenticate/authenticate.go +++ b/core/authenticate/authenticate.go @@ -8,6 +8,7 @@ import ( "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" pat "github.com/raystack/frontier/core/userpat/models" + "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/raystack/frontier/pkg/metadata" @@ -142,3 +143,12 @@ type Principal struct { ServiceUser *serviceuser.ServiceUser PAT *pat.PAT } + +// ResolveSubject returns the subject ID and type for authorization queries. +// For PAT principals, it resolves to the underlying user. +func (p Principal) ResolveSubject() (id string, subjectType string) { + if p.PAT != nil { + return p.PAT.UserID, schema.UserPrincipal + } + return p.ID, p.Type +} diff --git a/core/group/service.go b/core/group/service.go index 3a64348f4..cad606489 100644 --- a/core/group/service.go +++ b/core/group/service.go @@ -126,20 +126,20 @@ func (s Service) Update(ctx context.Context, grp Group) (Group, error) { return Group{}, ErrInvalidID } -func (s Service) ListByUser(ctx context.Context, principalID, principalType string, flt Filter) ([]Group, error) { +func (s Service) ListByUser(ctx context.Context, principal authenticate.Principal, flt Filter) ([]Group, error) { + subjectID, subjectType := principal.ResolveSubject() subjectIDs, err := s.relationService.LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.GroupNamespace, - }, - Subject: relation.Subject{ - Namespace: principalType, - ID: principalID, - }, + Object: relation.Object{Namespace: schema.GroupNamespace}, + Subject: relation.Subject{Namespace: subjectType, ID: subjectID}, RelationName: schema.MembershipPermission, }) if err != nil { return nil, err } + subjectIDs, err = s.intersectPATScope(ctx, principal, schema.GroupNamespace, subjectIDs) + if err != nil { + return nil, err + } if len(subjectIDs) == 0 { // no groups return nil, nil @@ -148,6 +148,23 @@ func (s Service) ListByUser(ctx context.Context, principalID, principalType stri return s.List(ctx, flt) } +// intersectPATScope narrows resource IDs to only those the PAT is scoped to. +func (s Service) intersectPATScope(ctx context.Context, principal authenticate.Principal, + namespace string, resourceIDs []string) ([]string, error) { + if principal.PAT == nil || len(resourceIDs) == 0 { + return resourceIDs, nil + } + patIDs, err := s.relationService.LookupResources(ctx, relation.Relation{ + Object: relation.Object{Namespace: namespace}, + Subject: relation.Subject{ID: principal.PAT.ID, Namespace: schema.PATPrincipal}, + RelationName: schema.GetPermission, + }) + if err != nil { + return nil, err + } + return utils.Intersection(resourceIDs, patIDs), nil +} + // AddMember adds a subject(user) to group as member func (s Service) AddMember(ctx context.Context, groupID string, principal authenticate.Principal) error { // first create a policy for the user as member of the group diff --git a/core/group/service_test.go b/core/group/service_test.go index dfe1a64d7..1df317482 100644 --- a/core/group/service_test.go +++ b/core/group/service_test.go @@ -14,6 +14,7 @@ import ( "github.com/raystack/frontier/core/policy" "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/user" + pat "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -262,3 +263,107 @@ func TestService_Update(t *testing.T) { assert.Equal(t, err, group.ErrInvalidID) }) } + +func TestService_ListByUser(t *testing.T) { + ctx := context.Background() + + t.Run("should resolve PAT to user and intersect with PAT group scope", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockRelationSvc := mocks.NewRelationService(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockPolicySvc := mocks.NewPolicyService(t) + + svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + + // LookupResources for user's group memberships + mockRelationSvc.On("LookupResources", ctx, relation.Relation{ + Object: relation.Object{Namespace: schema.GroupNamespace}, + Subject: relation.Subject{Namespace: schema.UserPrincipal, ID: "user-123"}, + RelationName: schema.MembershipPermission, + }).Return([]string{"group-1", "group-2", "group-3"}, nil).Once() + + // LookupResources for PAT's group scope + mockRelationSvc.On("LookupResources", ctx, relation.Relation{ + Object: relation.Object{Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: "pat-456", Namespace: schema.PATPrincipal}, + RelationName: schema.GetPermission, + }).Return([]string{"group-1", "group-3"}, nil).Once() + + // Repo should be called with intersection + mockRepo.On("List", ctx, group.Filter{ + GroupIDs: []string{"group-1", "group-3"}, + }).Return([]group.Group{ + {ID: "group-1", Name: "group-one"}, + {ID: "group-3", Name: "group-three"}, + }, nil).Once() + + result, err := svc.ListByUser(ctx, authenticate.Principal{ + ID: "pat-456", + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-1"}, + }, group.Filter{}) + + assert.NoError(t, err) + assert.Len(t, result, 2) + }) + + t.Run("should return nil when PAT has no group scope overlap", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockRelationSvc := mocks.NewRelationService(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockPolicySvc := mocks.NewPolicyService(t) + + svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + + mockRelationSvc.On("LookupResources", ctx, relation.Relation{ + Object: relation.Object{Namespace: schema.GroupNamespace}, + Subject: relation.Subject{Namespace: schema.UserPrincipal, ID: "user-123"}, + RelationName: schema.MembershipPermission, + }).Return([]string{"group-1"}, nil).Once() + + mockRelationSvc.On("LookupResources", ctx, relation.Relation{ + Object: relation.Object{Namespace: schema.GroupNamespace}, + Subject: relation.Subject{ID: "pat-456", Namespace: schema.PATPrincipal}, + RelationName: schema.GetPermission, + }).Return([]string{"group-2"}, nil).Once() + + result, err := svc.ListByUser(ctx, authenticate.Principal{ + ID: "pat-456", + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-1"}, + }, group.Filter{}) + + assert.NoError(t, err) + assert.Nil(t, result) + }) + + t.Run("should pass through for regular user principal", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockRelationSvc := mocks.NewRelationService(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockPolicySvc := mocks.NewPolicyService(t) + + svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc) + + mockRelationSvc.On("LookupResources", ctx, relation.Relation{ + Object: relation.Object{Namespace: schema.GroupNamespace}, + Subject: relation.Subject{Namespace: schema.UserPrincipal, ID: "user-123"}, + RelationName: schema.MembershipPermission, + }).Return([]string{"group-1", "group-2"}, nil).Once() + + mockRepo.On("List", ctx, group.Filter{ + GroupIDs: []string{"group-1", "group-2"}, + }).Return([]group.Group{ + {ID: "group-1", Name: "group-one"}, + {ID: "group-2", Name: "group-two"}, + }, nil).Once() + + result, err := svc.ListByUser(ctx, authenticate.Principal{ + ID: "user-123", + Type: schema.UserPrincipal, + }, group.Filter{}) + + assert.NoError(t, err) + assert.Len(t, result, 2) + }) +} diff --git a/core/invitation/mocks/group_service.go b/core/invitation/mocks/group_service.go index 61478c28a..7e305f578 100644 --- a/core/invitation/mocks/group_service.go +++ b/core/invitation/mocks/group_service.go @@ -130,9 +130,9 @@ func (_c *GroupService_Get_Call) RunAndReturn(run func(context.Context, string) return _c } -// ListByUser provides a mock function with given fields: ctx, principalID, principalType, flt -func (_m *GroupService) ListByUser(ctx context.Context, principalID string, principalType string, flt group.Filter) ([]group.Group, error) { - ret := _m.Called(ctx, principalID, principalType, flt) +// ListByUser provides a mock function with given fields: ctx, principal, flt +func (_m *GroupService) ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) { + ret := _m.Called(ctx, principal, flt) if len(ret) == 0 { panic("no return value specified for ListByUser") @@ -140,19 +140,19 @@ func (_m *GroupService) ListByUser(ctx context.Context, principalID string, prin var r0 []group.Group var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, group.Filter) ([]group.Group, error)); ok { - return rf(ctx, principalID, principalType, flt) + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)); ok { + return rf(ctx, principal, flt) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, group.Filter) []group.Group); ok { - r0 = rf(ctx, principalID, principalType, flt) + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) []group.Group); ok { + r0 = rf(ctx, principal, flt) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]group.Group) } } - if rf, ok := ret.Get(1).(func(context.Context, string, string, group.Filter) error); ok { - r1 = rf(ctx, principalID, principalType, flt) + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, group.Filter) error); ok { + r1 = rf(ctx, principal, flt) } else { r1 = ret.Error(1) } @@ -167,16 +167,15 @@ type GroupService_ListByUser_Call struct { // ListByUser is a helper method to define mock.On call // - ctx context.Context -// - principalID string -// - principalType string +// - principal authenticate.Principal // - flt group.Filter -func (_e *GroupService_Expecter) ListByUser(ctx interface{}, principalID interface{}, principalType interface{}, flt interface{}) *GroupService_ListByUser_Call { - return &GroupService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principalID, principalType, flt)} +func (_e *GroupService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *GroupService_ListByUser_Call { + return &GroupService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} } -func (_c *GroupService_ListByUser_Call) Run(run func(ctx context.Context, principalID string, principalType string, flt group.Filter)) *GroupService_ListByUser_Call { +func (_c *GroupService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt group.Filter)) *GroupService_ListByUser_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(group.Filter)) + run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(group.Filter)) }) return _c } @@ -186,7 +185,7 @@ func (_c *GroupService_ListByUser_Call) Return(_a0 []group.Group, _a1 error) *Gr return _c } -func (_c *GroupService_ListByUser_Call) RunAndReturn(run func(context.Context, string, string, group.Filter) ([]group.Group, error)) *GroupService_ListByUser_Call { +func (_c *GroupService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)) *GroupService_ListByUser_Call { _c.Call.Return(run) return _c } diff --git a/core/invitation/service.go b/core/invitation/service.go index f429aa88b..bf59ff17d 100644 --- a/core/invitation/service.go +++ b/core/invitation/service.go @@ -53,7 +53,7 @@ type OrganizationService interface { type GroupService interface { Get(ctx context.Context, id string) (group.Group, error) AddMember(ctx context.Context, groupID string, principal authenticate.Principal) error - ListByUser(ctx context.Context, principalID, principalType string, flt group.Filter) ([]group.Group, error) + ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) } type RelationService interface { @@ -315,7 +315,9 @@ func (s Service) Accept(ctx context.Context, id uuid.UUID) error { // check if the invitation has a group membership if len(invite.GroupIDs) > 0 { - userGroups, err := s.groupSvc.ListByUser(ctx, userOb.ID, schema.UserPrincipal, group.Filter{}) + userGroups, err := s.groupSvc.ListByUser(ctx, authenticate.Principal{ + ID: userOb.ID, Type: schema.UserPrincipal, + }, group.Filter{}) if err != nil { return err } diff --git a/core/organization/service.go b/core/organization/service.go index 2ebe31c0d..d5c1d93fa 100644 --- a/core/organization/service.go +++ b/core/organization/service.go @@ -308,19 +308,25 @@ func (s Service) ListByUser(ctx context.Context, principal authenticate.Principa defer promCollect() } + subjectID, subjectType := principal.ResolveSubject() subjectIDs, err := s.relationService.LookupResources(ctx, relation.Relation{ Object: relation.Object{ Namespace: schema.OrganizationNamespace, }, Subject: relation.Subject{ - ID: principal.ID, - Namespace: principal.Type, + ID: subjectID, + Namespace: subjectType, }, RelationName: schema.MembershipPermission, }) if err != nil { return nil, err } + + if principal.PAT != nil { + subjectIDs = utils.Intersection(subjectIDs, []string{principal.PAT.OrgID}) + } + if len(subjectIDs) == 0 { // no organizations return []Organization{}, nil diff --git a/core/organization/service_test.go b/core/organization/service_test.go index 13047107a..30be35355 100644 --- a/core/organization/service_test.go +++ b/core/organization/service_test.go @@ -13,6 +13,7 @@ import ( "github.com/raystack/frontier/core/policy" "github.com/raystack/frontier/core/preference" "github.com/raystack/frontier/core/relation" + pat "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -258,3 +259,103 @@ func TestService_AttachToPlatform(t *testing.T) { assert.Equal(t, expectedErr, err) }) } + +func TestService_ListByUser(t *testing.T) { + ctx := context.Background() + + t.Run("should resolve PAT to user and intersect with PAT org", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockRelationSvc := mocks.NewRelationService(t) + mockUserSvc := mocks.NewUserService(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockPolicySvc := mocks.NewPolicyService(t) + mockPrefSvc := mocks.NewPreferencesService(t) + mockAuditRecordRepo := mocks.NewAuditRecordRepository(t) + + svc := organization.NewService(mockRepo, mockRelationSvc, mockUserSvc, mockAuthnSvc, mockPolicySvc, mockPrefSvc, mockAuditRecordRepo) + + // LookupResources should be called with user ID/type, not PAT + mockRelationSvc.On("LookupResources", ctx, relation.Relation{ + Object: relation.Object{Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ID: "user-123", Namespace: schema.UserPrincipal}, + RelationName: schema.MembershipPermission, + }).Return([]string{"org-1", "org-2"}, nil).Once() + + // Repo should only be called with the PAT's org (intersection result) + mockRepo.On("List", ctx, organization.Filter{ + IDs: []string{"org-1"}, + }).Return([]organization.Organization{ + {ID: "org-1", Name: "org-one"}, + }, nil).Once() + + result, err := svc.ListByUser(ctx, authenticate.Principal{ + ID: "pat-456", + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-1"}, + }, organization.Filter{}) + + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, "org-1", result[0].ID) + }) + + t.Run("should return empty when PAT org not in user memberships", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockRelationSvc := mocks.NewRelationService(t) + mockUserSvc := mocks.NewUserService(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockPolicySvc := mocks.NewPolicyService(t) + mockPrefSvc := mocks.NewPreferencesService(t) + mockAuditRecordRepo := mocks.NewAuditRecordRepository(t) + + svc := organization.NewService(mockRepo, mockRelationSvc, mockUserSvc, mockAuthnSvc, mockPolicySvc, mockPrefSvc, mockAuditRecordRepo) + + mockRelationSvc.On("LookupResources", ctx, relation.Relation{ + Object: relation.Object{Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ID: "user-123", Namespace: schema.UserPrincipal}, + RelationName: schema.MembershipPermission, + }).Return([]string{"org-1", "org-2"}, nil).Once() + + result, err := svc.ListByUser(ctx, authenticate.Principal{ + ID: "pat-456", + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-999"}, + }, organization.Filter{}) + + assert.NoError(t, err) + assert.Empty(t, result) + }) + + t.Run("should pass through for regular user principal", func(t *testing.T) { + mockRepo := mocks.NewRepository(t) + mockRelationSvc := mocks.NewRelationService(t) + mockUserSvc := mocks.NewUserService(t) + mockAuthnSvc := mocks.NewAuthnService(t) + mockPolicySvc := mocks.NewPolicyService(t) + mockPrefSvc := mocks.NewPreferencesService(t) + mockAuditRecordRepo := mocks.NewAuditRecordRepository(t) + + svc := organization.NewService(mockRepo, mockRelationSvc, mockUserSvc, mockAuthnSvc, mockPolicySvc, mockPrefSvc, mockAuditRecordRepo) + + mockRelationSvc.On("LookupResources", ctx, relation.Relation{ + Object: relation.Object{Namespace: schema.OrganizationNamespace}, + Subject: relation.Subject{ID: "user-123", Namespace: schema.UserPrincipal}, + RelationName: schema.MembershipPermission, + }).Return([]string{"org-1", "org-2"}, nil).Once() + + mockRepo.On("List", ctx, organization.Filter{ + IDs: []string{"org-1", "org-2"}, + }).Return([]organization.Organization{ + {ID: "org-1", Name: "org-one"}, + {ID: "org-2", Name: "org-two"}, + }, nil).Once() + + result, err := svc.ListByUser(ctx, authenticate.Principal{ + ID: "user-123", + Type: schema.UserPrincipal, + }, organization.Filter{}) + + assert.NoError(t, err) + assert.Len(t, result, 2) + }) +} diff --git a/core/project/mocks/group_service.go b/core/project/mocks/group_service.go index c7e54d041..cdd3a05b9 100644 --- a/core/project/mocks/group_service.go +++ b/core/project/mocks/group_service.go @@ -5,7 +5,10 @@ package mocks import ( context "context" + authenticate "github.com/raystack/frontier/core/authenticate" + group "github.com/raystack/frontier/core/group" + mock "github.com/stretchr/testify/mock" ) @@ -81,9 +84,9 @@ func (_c *GroupService_GetByIDs_Call) RunAndReturn(run func(context.Context, []s return _c } -// ListByUser provides a mock function with given fields: ctx, principalID, principalType, flt -func (_m *GroupService) ListByUser(ctx context.Context, principalID string, principalType string, flt group.Filter) ([]group.Group, error) { - ret := _m.Called(ctx, principalID, principalType, flt) +// ListByUser provides a mock function with given fields: ctx, principal, flt +func (_m *GroupService) ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) { + ret := _m.Called(ctx, principal, flt) if len(ret) == 0 { panic("no return value specified for ListByUser") @@ -91,19 +94,19 @@ func (_m *GroupService) ListByUser(ctx context.Context, principalID string, prin var r0 []group.Group var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, group.Filter) ([]group.Group, error)); ok { - return rf(ctx, principalID, principalType, flt) + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)); ok { + return rf(ctx, principal, flt) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, group.Filter) []group.Group); ok { - r0 = rf(ctx, principalID, principalType, flt) + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) []group.Group); ok { + r0 = rf(ctx, principal, flt) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]group.Group) } } - if rf, ok := ret.Get(1).(func(context.Context, string, string, group.Filter) error); ok { - r1 = rf(ctx, principalID, principalType, flt) + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, group.Filter) error); ok { + r1 = rf(ctx, principal, flt) } else { r1 = ret.Error(1) } @@ -118,16 +121,15 @@ type GroupService_ListByUser_Call struct { // ListByUser is a helper method to define mock.On call // - ctx context.Context -// - principalID string -// - principalType string +// - principal authenticate.Principal // - flt group.Filter -func (_e *GroupService_Expecter) ListByUser(ctx interface{}, principalID interface{}, principalType interface{}, flt interface{}) *GroupService_ListByUser_Call { - return &GroupService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principalID, principalType, flt)} +func (_e *GroupService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *GroupService_ListByUser_Call { + return &GroupService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} } -func (_c *GroupService_ListByUser_Call) Run(run func(ctx context.Context, principalID string, principalType string, flt group.Filter)) *GroupService_ListByUser_Call { +func (_c *GroupService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt group.Filter)) *GroupService_ListByUser_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(group.Filter)) + run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(group.Filter)) }) return _c } @@ -137,7 +139,7 @@ func (_c *GroupService_ListByUser_Call) Return(_a0 []group.Group, _a1 error) *Gr return _c } -func (_c *GroupService_ListByUser_Call) RunAndReturn(run func(context.Context, string, string, group.Filter) ([]group.Group, error)) *GroupService_ListByUser_Call { +func (_c *GroupService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)) *GroupService_ListByUser_Call { _c.Call.Return(run) return _c } diff --git a/core/project/service.go b/core/project/service.go index 93d9f0961..ebd41b15c 100644 --- a/core/project/service.go +++ b/core/project/service.go @@ -48,7 +48,7 @@ type AuthnService interface { type GroupService interface { GetByIDs(ctx context.Context, ids []string) ([]group.Group, error) - ListByUser(ctx context.Context, principalID, principalType string, flt group.Filter) ([]group.Group, error) + ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) } type Service struct { @@ -137,63 +137,31 @@ func (s Service) List(ctx context.Context, f Filter) ([]Project, error) { return projects, nil } -func (s Service) ListByUser(ctx context.Context, principalID, principalType string, +func (s Service) ListByUser(ctx context.Context, principal authenticate.Principal, flt Filter) ([]Project, error) { + subjectID, subjectType := principal.ResolveSubject() + var projIDs []string var err error - if flt.NonInherited == true { + if flt.NonInherited { // direct added users - policies, err := s.policyService.List(ctx, policy.Filter{ - PrincipalType: principalType, - PrincipalID: principalID, - ResourceType: schema.ProjectNamespace, - }) - if err != nil { - return nil, err - } - for _, pol := range policies { - projIDs = append(projIDs, pol.ResourceID) - } - - // added via groups - groups, err := s.groupService.ListByUser(ctx, principalID, principalType, group.Filter{}) - if err != nil { - return nil, err - } - groupIDs := utils.Map(groups, func(g group.Group) string { - return g.ID - }) - if len(groupIDs) > 0 { - policies, err = s.policyService.List(ctx, policy.Filter{ - PrincipalType: schema.GroupPrincipal, - PrincipalIDs: groupIDs, - ResourceType: schema.ProjectNamespace, - }) - if err != nil { - return nil, err - } - for _, pol := range policies { - projIDs = append(projIDs, pol.ResourceID) - } - } + projIDs, err = s.listNonInheritedProjectIDs(ctx, subjectID, subjectType) } else { projIDs, err = s.relationService.LookupResources(ctx, relation.Relation{ - Object: relation.Object{ - Namespace: schema.ProjectNamespace, - }, - Subject: relation.Subject{ - Namespace: principalType, - ID: principalID, - }, + Object: relation.Object{Namespace: schema.ProjectNamespace}, + Subject: relation.Subject{Namespace: subjectType, ID: subjectID}, RelationName: MemberPermission, }) - if err != nil { - return nil, err - } + } + if err != nil { + return nil, err } - // de-duplicate project IDs projIDs = utils.Deduplicate(projIDs) + projIDs, err = s.intersectPATScope(ctx, principal, schema.ProjectNamespace, projIDs) + if err != nil { + return nil, err + } if len(projIDs) == 0 { return []Project{}, nil } @@ -202,6 +170,62 @@ func (s Service) ListByUser(ctx context.Context, principalID, principalType stri return s.List(ctx, flt) } +// listNonInheritedProjectIDs returns project IDs where the principal has direct +// role assignments (not inherited through org), including via group memberships. +func (s Service) listNonInheritedProjectIDs(ctx context.Context, principalID, principalType string) ([]string, error) { + policies, err := s.policyService.List(ctx, policy.Filter{ + PrincipalType: principalType, + PrincipalID: principalID, + ResourceType: schema.ProjectNamespace, + }) + if err != nil { + return nil, err + } + var projIDs []string + for _, pol := range policies { + projIDs = append(projIDs, pol.ResourceID) + } + + // projects added via group memberships + groups, err := s.groupService.ListByUser(ctx, + authenticate.Principal{ID: principalID, Type: principalType}, group.Filter{}) + if err != nil { + return nil, err + } + groupIDs := utils.Map(groups, func(g group.Group) string { return g.ID }) + if len(groupIDs) > 0 { + policies, err = s.policyService.List(ctx, policy.Filter{ + PrincipalType: schema.GroupPrincipal, + PrincipalIDs: groupIDs, + ResourceType: schema.ProjectNamespace, + }) + if err != nil { + return nil, err + } + for _, pol := range policies { + projIDs = append(projIDs, pol.ResourceID) + } + } + return projIDs, nil +} + +// intersectPATScope narrows resource IDs to only those the PAT is scoped to. +func (s Service) intersectPATScope(ctx context.Context, principal authenticate.Principal, + namespace string, resourceIDs []string) ([]string, error) { + if principal.PAT == nil || len(resourceIDs) == 0 { + return resourceIDs, nil + } + patIDs, err := s.relationService.LookupResources(ctx, relation.Relation{ + Object: relation.Object{Namespace: namespace}, + Subject: relation.Subject{ID: principal.PAT.ID, Namespace: schema.PATPrincipal}, + RelationName: schema.GetPermission, + }) + if err != nil { + return nil, err + } + return utils.Intersection(resourceIDs, patIDs), nil +} + func (s Service) Update(ctx context.Context, prj Project) (Project, error) { if utils.IsValidUUID(prj.ID) { return s.repository.UpdateByID(ctx, prj) diff --git a/core/project/service_test.go b/core/project/service_test.go index e9df9511a..988df4b55 100644 --- a/core/project/service_test.go +++ b/core/project/service_test.go @@ -16,6 +16,7 @@ import ( "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/serviceuser" "github.com/raystack/frontier/core/user" + pat "github.com/raystack/frontier/core/userpat/models" "github.com/raystack/frontier/internal/bootstrap/schema" ) @@ -280,9 +281,8 @@ func TestService_List(t *testing.T) { func TestService_ListByUser(t *testing.T) { ctx := context.Background() type args struct { - principalID string - principalType string - flt project.Filter + principal authenticate.Principal + flt project.Filter } tests := []struct { name string @@ -294,9 +294,8 @@ func TestService_ListByUser(t *testing.T) { { name: "list all projects by user successfully", args: args{ - principalID: "user-id", - principalType: schema.UserPrincipal, - flt: project.Filter{}, + principal: authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, + flt: project.Filter{}, }, want: []project.Project{ { @@ -366,8 +365,7 @@ func TestService_ListByUser(t *testing.T) { { name: "list all projects by user with non-inherited policies (with no groups)", args: args{ - principalID: "user-id", - principalType: schema.UserPrincipal, + principal: authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, flt: project.Filter{ NonInherited: true, }, @@ -397,7 +395,7 @@ func TestService_ListByUser(t *testing.T) { }, }, nil) - groupService.EXPECT().ListByUser(ctx, "user-id", schema.UserPrincipal, group.Filter{}).Return([]group.Group{}, nil) + groupService.EXPECT().ListByUser(ctx, authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, group.Filter{}).Return([]group.Group{}, nil) repo.EXPECT().List(ctx, project.Filter{ ProjectIDs: []string{"project-id"}, @@ -417,8 +415,7 @@ func TestService_ListByUser(t *testing.T) { { name: "list all projects by user with non-inherited policies (with groups)", args: args{ - principalID: "user-id", - principalType: schema.UserPrincipal, + principal: authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, flt: project.Filter{ NonInherited: true, }, @@ -455,7 +452,7 @@ func TestService_ListByUser(t *testing.T) { }, }, nil) - groupService.EXPECT().ListByUser(ctx, "user-id", schema.UserPrincipal, group.Filter{}).Return([]group.Group{ + groupService.EXPECT().ListByUser(ctx, authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, group.Filter{}).Return([]group.Group{ { ID: "group-id", }, @@ -496,11 +493,188 @@ func TestService_ListByUser(t *testing.T) { return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService) }, }, + { + name: "PAT principal should resolve to user and intersect with PAT project scope", + args: args{ + principal: authenticate.Principal{ + ID: "pat-456", + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: "pat-456", UserID: "user-id", OrgID: "org-1"}, + }, + flt: project.Filter{}, + }, + want: []project.Project{ + { + ID: "project-id", + Name: "test", + Organization: organization.Organization{ + ID: "org-id", + }, + }, + }, + wantErr: false, + setup: func() *project.Service { + repo, userService, suserService, relationService, policyService, authnService, groupService := mockService(t) + // LookupResources for user's project memberships (resolved from PAT) + relationService.EXPECT().LookupResources(ctx, relation.Relation{ + Object: relation.Object{ + Namespace: schema.ProjectNamespace, + }, + Subject: relation.Subject{ + Namespace: schema.UserPrincipal, + ID: "user-id", + }, + RelationName: project.MemberPermission, + }).Return([]string{"project-id", "project-id-2", "project-id-3"}, nil) + + // LookupResources for PAT's project scope + relationService.EXPECT().LookupResources(ctx, relation.Relation{ + Object: relation.Object{ + Namespace: schema.ProjectNamespace, + }, + Subject: relation.Subject{ + ID: "pat-456", + Namespace: schema.PATPrincipal, + }, + RelationName: schema.GetPermission, + }).Return([]string{"project-id"}, nil) + + // Repo called with intersection + repo.EXPECT().List(ctx, project.Filter{ + ProjectIDs: []string{"project-id"}, + }).Return([]project.Project{ + { + ID: "project-id", + Name: "test", + Organization: organization.Organization{ + ID: "org-id", + }, + }, + }, nil) + return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService) + }, + }, + { + name: "PAT principal with non-inherited should resolve to user and intersect", + args: args{ + principal: authenticate.Principal{ + ID: "pat-456", + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: "pat-456", UserID: "user-id", OrgID: "org-1"}, + }, + flt: project.Filter{ + NonInherited: true, + }, + }, + want: []project.Project{ + { + ID: "project-id", + Name: "test", + Organization: organization.Organization{ + ID: "org-id", + }, + }, + }, + wantErr: false, + setup: func() *project.Service { + repo, userService, suserService, relationService, policyService, authnService, groupService := mockService(t) + // Direct policies for user (resolved from PAT) + policyService.EXPECT().List(ctx, policy.Filter{ + PrincipalType: schema.UserPrincipal, + PrincipalID: "user-id", + ResourceType: schema.ProjectNamespace, + }).Return([]policy.Policy{ + { + ResourceID: "project-id", + ResourceType: schema.ProjectNamespace, + PrincipalID: "user-id", + PrincipalType: schema.UserPrincipal, + }, + { + ResourceID: "project-id-2", + ResourceType: schema.ProjectNamespace, + PrincipalID: "user-id", + PrincipalType: schema.UserPrincipal, + }, + }, nil) + + // Group lookup uses user-only principal (no double PAT filtering) + groupService.EXPECT().ListByUser(ctx, authenticate.Principal{ID: "user-id", Type: schema.UserPrincipal}, group.Filter{}).Return([]group.Group{}, nil) + + // PAT scope intersection + relationService.EXPECT().LookupResources(ctx, relation.Relation{ + Object: relation.Object{ + Namespace: schema.ProjectNamespace, + }, + Subject: relation.Subject{ + ID: "pat-456", + Namespace: schema.PATPrincipal, + }, + RelationName: schema.GetPermission, + }).Return([]string{"project-id"}, nil) + + // Repo called with intersection result + repo.EXPECT().List(ctx, project.Filter{ + ProjectIDs: []string{"project-id"}, + NonInherited: true, + }).Return([]project.Project{ + { + ID: "project-id", + Name: "test", + Organization: organization.Organization{ + ID: "org-id", + }, + }, + }, nil) + return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService) + }, + }, + { + name: "PAT principal with no project overlap returns empty", + args: args{ + principal: authenticate.Principal{ + ID: "pat-456", + Type: schema.PATPrincipal, + PAT: &pat.PAT{ID: "pat-456", UserID: "user-id", OrgID: "org-1"}, + }, + flt: project.Filter{}, + }, + want: []project.Project{}, + wantErr: false, + setup: func() *project.Service { + repo, userService, suserService, relationService, policyService, authnService, groupService := mockService(t) + // User has projects + relationService.EXPECT().LookupResources(ctx, relation.Relation{ + Object: relation.Object{ + Namespace: schema.ProjectNamespace, + }, + Subject: relation.Subject{ + Namespace: schema.UserPrincipal, + ID: "user-id", + }, + RelationName: project.MemberPermission, + }).Return([]string{"project-id-1"}, nil) + + // PAT scoped to different projects + relationService.EXPECT().LookupResources(ctx, relation.Relation{ + Object: relation.Object{ + Namespace: schema.ProjectNamespace, + }, + Subject: relation.Subject{ + ID: "pat-456", + Namespace: schema.PATPrincipal, + }, + RelationName: schema.GetPermission, + }).Return([]string{"project-id-2"}, nil) + + return project.NewService(repo, relationService, userService, policyService, authnService, suserService, groupService) + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := tt.setup() - got, err := s.ListByUser(ctx, tt.args.principalID, tt.args.principalType, tt.args.flt) + got, err := s.ListByUser(ctx, tt.args.principal, tt.args.flt) if (err != nil) != tt.wantErr { t.Errorf("ListByUser() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index 3738f39ae..e9f71d97c 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -303,7 +303,7 @@ type GroupService interface { Get(ctx context.Context, id string) (group.Group, error) List(ctx context.Context, flt group.Filter) ([]group.Group, error) Update(ctx context.Context, grp group.Group) (group.Group, error) - ListByUser(ctx context.Context, principalId, principalType string, flt group.Filter) ([]group.Group, error) + ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) AddUsers(ctx context.Context, groupID string, userID []string) error RemoveUsers(ctx context.Context, groupID string, userID []string) error Enable(ctx context.Context, id string) error @@ -343,7 +343,7 @@ type ProjectService interface { Get(ctx context.Context, idOrName string) (project.Project, error) Create(ctx context.Context, prj project.Project) (project.Project, error) List(ctx context.Context, f project.Filter) ([]project.Project, error) - ListByUser(ctx context.Context, principalID, principalType string, flt project.Filter) ([]project.Project, error) + ListByUser(ctx context.Context, principal authenticate.Principal, flt project.Filter) ([]project.Project, error) Update(ctx context.Context, toUpdate project.Project) (project.Project, error) ListUsers(ctx context.Context, id string, permissionFilter string) ([]user.User, error) ListServiceUsers(ctx context.Context, id string, permissionFilter string) ([]serviceuser.ServiceUser, error) diff --git a/internal/api/v1beta1connect/mocks/group_service.go b/internal/api/v1beta1connect/mocks/group_service.go index 262f3c2b1..610d634e3 100644 --- a/internal/api/v1beta1connect/mocks/group_service.go +++ b/internal/api/v1beta1connect/mocks/group_service.go @@ -5,7 +5,10 @@ package mocks import ( context "context" + authenticate "github.com/raystack/frontier/core/authenticate" + group "github.com/raystack/frontier/core/group" + mock "github.com/stretchr/testify/mock" ) @@ -384,9 +387,9 @@ func (_c *GroupService_List_Call) RunAndReturn(run func(context.Context, group.F return _c } -// ListByUser provides a mock function with given fields: ctx, principalId, principalType, flt -func (_m *GroupService) ListByUser(ctx context.Context, principalId string, principalType string, flt group.Filter) ([]group.Group, error) { - ret := _m.Called(ctx, principalId, principalType, flt) +// ListByUser provides a mock function with given fields: ctx, principal, flt +func (_m *GroupService) ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error) { + ret := _m.Called(ctx, principal, flt) if len(ret) == 0 { panic("no return value specified for ListByUser") @@ -394,19 +397,19 @@ func (_m *GroupService) ListByUser(ctx context.Context, principalId string, prin var r0 []group.Group var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, group.Filter) ([]group.Group, error)); ok { - return rf(ctx, principalId, principalType, flt) + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)); ok { + return rf(ctx, principal, flt) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, group.Filter) []group.Group); ok { - r0 = rf(ctx, principalId, principalType, flt) + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, group.Filter) []group.Group); ok { + r0 = rf(ctx, principal, flt) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]group.Group) } } - if rf, ok := ret.Get(1).(func(context.Context, string, string, group.Filter) error); ok { - r1 = rf(ctx, principalId, principalType, flt) + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, group.Filter) error); ok { + r1 = rf(ctx, principal, flt) } else { r1 = ret.Error(1) } @@ -421,16 +424,15 @@ type GroupService_ListByUser_Call struct { // ListByUser is a helper method to define mock.On call // - ctx context.Context -// - principalId string -// - principalType string +// - principal authenticate.Principal // - flt group.Filter -func (_e *GroupService_Expecter) ListByUser(ctx interface{}, principalId interface{}, principalType interface{}, flt interface{}) *GroupService_ListByUser_Call { - return &GroupService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principalId, principalType, flt)} +func (_e *GroupService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *GroupService_ListByUser_Call { + return &GroupService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} } -func (_c *GroupService_ListByUser_Call) Run(run func(ctx context.Context, principalId string, principalType string, flt group.Filter)) *GroupService_ListByUser_Call { +func (_c *GroupService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt group.Filter)) *GroupService_ListByUser_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(group.Filter)) + run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(group.Filter)) }) return _c } @@ -440,7 +442,7 @@ func (_c *GroupService_ListByUser_Call) Return(_a0 []group.Group, _a1 error) *Gr return _c } -func (_c *GroupService_ListByUser_Call) RunAndReturn(run func(context.Context, string, string, group.Filter) ([]group.Group, error)) *GroupService_ListByUser_Call { +func (_c *GroupService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, group.Filter) ([]group.Group, error)) *GroupService_ListByUser_Call { _c.Call.Return(run) return _c } diff --git a/internal/api/v1beta1connect/mocks/project_service.go b/internal/api/v1beta1connect/mocks/project_service.go index f8c92f364..3635dccb6 100644 --- a/internal/api/v1beta1connect/mocks/project_service.go +++ b/internal/api/v1beta1connect/mocks/project_service.go @@ -5,7 +5,10 @@ package mocks import ( context "context" + authenticate "github.com/raystack/frontier/core/authenticate" + group "github.com/raystack/frontier/core/group" + mock "github.com/stretchr/testify/mock" project "github.com/raystack/frontier/core/project" @@ -295,9 +298,9 @@ func (_c *ProjectService_List_Call) RunAndReturn(run func(context.Context, proje return _c } -// ListByUser provides a mock function with given fields: ctx, principalID, principalType, flt -func (_m *ProjectService) ListByUser(ctx context.Context, principalID string, principalType string, flt project.Filter) ([]project.Project, error) { - ret := _m.Called(ctx, principalID, principalType, flt) +// ListByUser provides a mock function with given fields: ctx, principal, flt +func (_m *ProjectService) ListByUser(ctx context.Context, principal authenticate.Principal, flt project.Filter) ([]project.Project, error) { + ret := _m.Called(ctx, principal, flt) if len(ret) == 0 { panic("no return value specified for ListByUser") @@ -305,19 +308,19 @@ func (_m *ProjectService) ListByUser(ctx context.Context, principalID string, pr var r0 []project.Project var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, project.Filter) ([]project.Project, error)); ok { - return rf(ctx, principalID, principalType, flt) + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, project.Filter) ([]project.Project, error)); ok { + return rf(ctx, principal, flt) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, project.Filter) []project.Project); ok { - r0 = rf(ctx, principalID, principalType, flt) + if rf, ok := ret.Get(0).(func(context.Context, authenticate.Principal, project.Filter) []project.Project); ok { + r0 = rf(ctx, principal, flt) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]project.Project) } } - if rf, ok := ret.Get(1).(func(context.Context, string, string, project.Filter) error); ok { - r1 = rf(ctx, principalID, principalType, flt) + if rf, ok := ret.Get(1).(func(context.Context, authenticate.Principal, project.Filter) error); ok { + r1 = rf(ctx, principal, flt) } else { r1 = ret.Error(1) } @@ -332,16 +335,15 @@ type ProjectService_ListByUser_Call struct { // ListByUser is a helper method to define mock.On call // - ctx context.Context -// - principalID string -// - principalType string +// - principal authenticate.Principal // - flt project.Filter -func (_e *ProjectService_Expecter) ListByUser(ctx interface{}, principalID interface{}, principalType interface{}, flt interface{}) *ProjectService_ListByUser_Call { - return &ProjectService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principalID, principalType, flt)} +func (_e *ProjectService_Expecter) ListByUser(ctx interface{}, principal interface{}, flt interface{}) *ProjectService_ListByUser_Call { + return &ProjectService_ListByUser_Call{Call: _e.mock.On("ListByUser", ctx, principal, flt)} } -func (_c *ProjectService_ListByUser_Call) Run(run func(ctx context.Context, principalID string, principalType string, flt project.Filter)) *ProjectService_ListByUser_Call { +func (_c *ProjectService_ListByUser_Call) Run(run func(ctx context.Context, principal authenticate.Principal, flt project.Filter)) *ProjectService_ListByUser_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(project.Filter)) + run(args[0].(context.Context), args[1].(authenticate.Principal), args[2].(project.Filter)) }) return _c } @@ -351,7 +353,7 @@ func (_c *ProjectService_ListByUser_Call) Return(_a0 []project.Project, _a1 erro return _c } -func (_c *ProjectService_ListByUser_Call) RunAndReturn(run func(context.Context, string, string, project.Filter) ([]project.Project, error)) *ProjectService_ListByUser_Call { +func (_c *ProjectService_ListByUser_Call) RunAndReturn(run func(context.Context, authenticate.Principal, project.Filter) ([]project.Project, error)) *ProjectService_ListByUser_Call { _c.Call.Return(run) return _c } diff --git a/internal/api/v1beta1connect/serviceuser.go b/internal/api/v1beta1connect/serviceuser.go index 46464dfbb..3210b6f77 100644 --- a/internal/api/v1beta1connect/serviceuser.go +++ b/internal/api/v1beta1connect/serviceuser.go @@ -7,6 +7,7 @@ import ( "connectrpc.com/connect" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/raystack/frontier/core/audit" + "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/project" "github.com/raystack/frontier/core/relation" "github.com/raystack/frontier/core/serviceuser" @@ -455,7 +456,9 @@ func (h *ConnectHandler) ListServiceUserProjects(ctx context.Context, request *c serviceUserID := request.Msg.GetId() orgID := request.Msg.GetOrgId() - projList, err := h.projectService.ListByUser(ctx, serviceUserID, schema.ServiceUserPrincipal, project.Filter{ + projList, err := h.projectService.ListByUser(ctx, authenticate.Principal{ + ID: serviceUserID, Type: schema.ServiceUserPrincipal, + }, project.Filter{ OrgID: orgID, }) if err != nil { diff --git a/internal/api/v1beta1connect/serviceuser_test.go b/internal/api/v1beta1connect/serviceuser_test.go index 456362d11..eb0282f04 100644 --- a/internal/api/v1beta1connect/serviceuser_test.go +++ b/internal/api/v1beta1connect/serviceuser_test.go @@ -9,6 +9,7 @@ import ( "connectrpc.com/connect" "github.com/google/uuid" "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/raystack/frontier/core/authenticate" "github.com/raystack/frontier/core/organization" "github.com/raystack/frontier/core/permission" "github.com/raystack/frontier/core/project" @@ -1337,7 +1338,7 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { Id: "1", }), setup: func(projSvc *mocks.ProjectService, permSvc *mocks.PermissionService, resourceSvc *mocks.ResourceService) { - projSvc.EXPECT().ListByUser(mock.Anything, "1", schema.ServiceUserPrincipal, project.Filter{}).Return(nil, errors.New("test error")) + projSvc.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}, project.Filter{}).Return(nil, errors.New("test error")) }, want: nil, wantErr: ErrInternalServerError, @@ -1353,7 +1354,7 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { for _, projectID := range testProjectIDList { projects = append(projects, testProjectMap[projectID]) } - projSvc.EXPECT().ListByUser(mock.Anything, "1", schema.ServiceUserPrincipal, project.Filter{}).Return(projects, nil) + projSvc.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}, project.Filter{}).Return(projects, nil) }, want: connect.NewResponse(&frontierv1beta1.ListServiceUserProjectsResponse{ Projects: []*frontierv1beta1.Project{{ @@ -1398,7 +1399,7 @@ func TestHandler_ListServiceUserProjects(t *testing.T) { } ctx := mock.Anything - projSvc.EXPECT().ListByUser(ctx, "1", schema.ServiceUserPrincipal, project.Filter{}).Return(projects, nil) + projSvc.EXPECT().ListByUser(ctx, authenticate.Principal{ID: "1", Type: schema.ServiceUserPrincipal}, project.Filter{}).Return(projects, nil) permSvc.EXPECT().Get(ctx, "app/project:get").Return( permission.Permission{ diff --git a/internal/api/v1beta1connect/user.go b/internal/api/v1beta1connect/user.go index 4548c087f..eb99b37a9 100644 --- a/internal/api/v1beta1connect/user.go +++ b/internal/api/v1beta1connect/user.go @@ -453,8 +453,9 @@ func (h *ConnectHandler) ListUserGroups(ctx context.Context, request *connect.Re errorLogger := NewErrorLogger() var groups []*frontierv1beta1.Group - groupsList, err := h.groupService.ListByUser(ctx, request.Msg.GetId(), schema.UserPrincipal, - group.Filter{OrganizationID: request.Msg.GetOrgId()}) + groupsList, err := h.groupService.ListByUser(ctx, authenticate.Principal{ + ID: request.Msg.GetId(), Type: schema.UserPrincipal, + }, group.Filter{OrganizationID: request.Msg.GetOrgId()}) if err != nil { errorLogger.LogServiceError(ctx, request, "ListUserGroups.ListByUser", err, zap.String("user_id", request.Msg.GetId()), @@ -496,7 +497,7 @@ func (h *ConnectHandler) ListCurrentUserGroups(ctx context.Context, request *con var groupsPb []*frontierv1beta1.Group var accessPairsPb []*frontierv1beta1.ListCurrentUserGroupsResponse_AccessPair - groupsList, err := h.groupService.ListByUser(ctx, principal.ID, principal.Type, + groupsList, err := h.groupService.ListByUser(ctx, principal, group.Filter{ OrganizationID: request.Msg.GetOrgId(), WithMemberCount: request.Msg.GetWithMemberCount(), @@ -836,7 +837,9 @@ func (h *ConnectHandler) ListProjectsByUser(ctx context.Context, request *connec errorLogger := NewErrorLogger() userID := request.Msg.GetId() - projList, err := h.projectService.ListByUser(ctx, userID, schema.UserPrincipal, project.Filter{}) + projList, err := h.projectService.ListByUser(ctx, authenticate.Principal{ + ID: userID, Type: schema.UserPrincipal, + }, project.Filter{}) if err != nil { errorLogger.LogServiceError(ctx, request, "ListProjectsByUser.ListByUser", err, zap.String("user_id", userID)) @@ -877,7 +880,7 @@ func (h *ConnectHandler) ListProjectsByCurrentUser(ctx context.Context, request } paginate := pagination.NewPagination(request.Msg.GetPageNum(), request.Msg.GetPageSize()) - projList, err := h.projectService.ListByUser(ctx, principal.ID, principal.Type, project.Filter{ + projList, err := h.projectService.ListByUser(ctx, principal, project.Filter{ OrgID: request.Msg.GetOrgId(), NonInherited: request.Msg.GetNonInherited(), WithMemberCount: request.Msg.GetWithMemberCount(), diff --git a/internal/api/v1beta1connect/user_test.go b/internal/api/v1beta1connect/user_test.go index d77a22dbd..f69671229 100644 --- a/internal/api/v1beta1connect/user_test.go +++ b/internal/api/v1beta1connect/user_test.go @@ -914,7 +914,7 @@ func TestConnectHandler_ListUserGroups(t *testing.T) { { title: "should list user groups successfully", setup: func(gs *mocks.GroupService) { - gs.EXPECT().ListByUser(mock.Anything, userID, "app/user", group.Filter{OrganizationID: orgID}).Return([]group.Group{ + gs.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: userID, Type: "app/user"}, group.Filter{OrganizationID: orgID}).Return([]group.Group{ { ID: "group-1", Name: "test-group-1", @@ -966,7 +966,7 @@ func TestConnectHandler_ListUserGroups(t *testing.T) { { title: "should return empty list when user has no groups", setup: func(gs *mocks.GroupService) { - gs.EXPECT().ListByUser(mock.Anything, userID, "app/user", group.Filter{OrganizationID: orgID}).Return([]group.Group{}, nil) + gs.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: userID, Type: "app/user"}, group.Filter{OrganizationID: orgID}).Return([]group.Group{}, nil) }, req: &frontierv1beta1.ListUserGroupsRequest{ Id: userID, @@ -980,7 +980,7 @@ func TestConnectHandler_ListUserGroups(t *testing.T) { { title: "should return not found error for invalid user ID", setup: func(gs *mocks.GroupService) { - gs.EXPECT().ListByUser(mock.Anything, "invalid-id", "app/user", group.Filter{OrganizationID: orgID}).Return(nil, group.ErrInvalidID) + gs.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "invalid-id", Type: "app/user"}, group.Filter{OrganizationID: orgID}).Return(nil, group.ErrInvalidID) }, req: &frontierv1beta1.ListUserGroupsRequest{ Id: "invalid-id", @@ -992,7 +992,7 @@ func TestConnectHandler_ListUserGroups(t *testing.T) { { title: "should return internal error for service failure", setup: func(gs *mocks.GroupService) { - gs.EXPECT().ListByUser(mock.Anything, userID, "app/user", group.Filter{OrganizationID: orgID}).Return(nil, errors.New("database error")) + gs.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: userID, Type: "app/user"}, group.Filter{OrganizationID: orgID}).Return(nil, errors.New("database error")) }, req: &frontierv1beta1.ListUserGroupsRequest{ Id: userID, @@ -1061,7 +1061,7 @@ func TestConnectHandler_ListCurrentUserGroups(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - gs.EXPECT().ListByUser(mock.Anything, "user-1", "app/user", group.Filter{OrganizationID: orgID}).Return([]group.Group{ + gs.EXPECT().ListByUser(mock.Anything, mockPrincipal, group.Filter{OrganizationID: orgID}).Return([]group.Group{ { ID: "group-1", Name: "test-group-1", @@ -1102,7 +1102,7 @@ func TestConnectHandler_ListCurrentUserGroups(t *testing.T) { User: &user.User{ID: "user-1", Email: "test@example.com"}, } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - gs.EXPECT().ListByUser(mock.Anything, "user-1", "app/user", group.Filter{OrganizationID: orgID}).Return([]group.Group{}, nil) + gs.EXPECT().ListByUser(mock.Anything, mockPrincipal, group.Filter{OrganizationID: orgID}).Return([]group.Group{}, nil) }, req: &frontierv1beta1.ListCurrentUserGroupsRequest{ OrgId: orgID, @@ -1133,7 +1133,7 @@ func TestConnectHandler_ListCurrentUserGroups(t *testing.T) { User: &user.User{ID: "user-1", Email: "test@example.com"}, } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - gs.EXPECT().ListByUser(mock.Anything, "user-1", "app/user", group.Filter{OrganizationID: orgID}).Return(nil, errors.New("database error")) + gs.EXPECT().ListByUser(mock.Anything, mockPrincipal, group.Filter{OrganizationID: orgID}).Return(nil, errors.New("database error")) }, req: &frontierv1beta1.ListCurrentUserGroupsRequest{ OrgId: orgID, @@ -1559,7 +1559,7 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should list user projects successfully", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, "user-1", schema.UserPrincipal, project.Filter{}).Return([]project.Project{ + ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}, project.Filter{}).Return([]project.Project{ { ID: "project-1", Name: "test-project-1", @@ -1606,7 +1606,7 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should return empty list when user has no projects", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, "user-1", schema.UserPrincipal, project.Filter{}).Return([]project.Project{}, nil) + ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}, project.Filter{}).Return([]project.Project{}, nil) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "user-1"}, want: &frontierv1beta1.ListProjectsByUserResponse{ @@ -1617,7 +1617,7 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should return not found error when user does not exist", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, "non-existent-user", schema.UserPrincipal, project.Filter{}).Return(nil, user.ErrNotExist) + ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "non-existent-user", Type: schema.UserPrincipal}, project.Filter{}).Return(nil, user.ErrNotExist) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "non-existent-user"}, want: nil, @@ -1626,7 +1626,7 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should return bad request error for invalid user ID", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, "invalid-id", schema.UserPrincipal, project.Filter{}).Return(nil, user.ErrInvalidUUID) + ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "invalid-id", Type: schema.UserPrincipal}, project.Filter{}).Return(nil, user.ErrInvalidUUID) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "invalid-id"}, want: nil, @@ -1635,7 +1635,7 @@ func TestConnectHandler_ListProjectsByUser(t *testing.T) { { title: "should return internal error for project service failure", setup: func(ps *mocks.ProjectService, as *mocks.AuthnService) { - ps.EXPECT().ListByUser(mock.Anything, "user-1", schema.UserPrincipal, project.Filter{}).Return(nil, errors.New("database error")) + ps.EXPECT().ListByUser(mock.Anything, authenticate.Principal{ID: "user-1", Type: schema.UserPrincipal}, project.Filter{}).Return(nil, errors.New("database error")) }, req: &frontierv1beta1.ListProjectsByUserRequest{Id: "user-1"}, want: nil, @@ -1702,7 +1702,7 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - ps.EXPECT().ListByUser(mock.Anything, "user-1", "app/user", mock.MatchedBy(func(filter project.Filter) bool { + ps.EXPECT().ListByUser(mock.Anything, mockPrincipal, mock.MatchedBy(func(filter project.Filter) bool { return filter.OrgID == "" })).Return([]project.Project{ { @@ -1759,7 +1759,7 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - ps.EXPECT().ListByUser(mock.Anything, "user-1", "app/user", mock.MatchedBy(func(filter project.Filter) bool { + ps.EXPECT().ListByUser(mock.Anything, mockPrincipal, mock.MatchedBy(func(filter project.Filter) bool { return filter.OrgID == "org-1" })).Return([]project.Project{ { @@ -1799,7 +1799,7 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - ps.EXPECT().ListByUser(mock.Anything, "user-1", "app/user", mock.MatchedBy(func(filter project.Filter) bool { + ps.EXPECT().ListByUser(mock.Anything, mockPrincipal, mock.MatchedBy(func(filter project.Filter) bool { return filter.OrgID == "" })).Return([]project.Project{}, nil) }, @@ -1829,7 +1829,7 @@ func TestConnectHandler_ListProjectsByCurrentUser(t *testing.T) { } as.EXPECT().GetPrincipal(mock.Anything).Return(mockPrincipal, nil) - ps.EXPECT().ListByUser(mock.Anything, "user-1", "app/user", mock.MatchedBy(func(filter project.Filter) bool { + ps.EXPECT().ListByUser(mock.Anything, mockPrincipal, mock.MatchedBy(func(filter project.Filter) bool { return filter.OrgID == "" })).Return(nil, errors.New("database error")) },