Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions core/authenticate/authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/raystack/frontier/core/serviceuser"
"github.com/raystack/frontier/core/user"
pat "github.com/raystack/frontier/core/userpat/models"
"github.com/raystack/frontier/internal/bootstrap/schema"

"github.com/raystack/frontier/pkg/metadata"

Expand Down Expand Up @@ -142,3 +143,12 @@ type Principal struct {
ServiceUser *serviceuser.ServiceUser
PAT *pat.PAT
}

// ResolveSubject returns the subject ID and type for authorization queries.
// For PAT principals, it resolves to the underlying user.
func (p Principal) ResolveSubject() (id string, subjectType string) {
if p.PAT != nil {
return p.PAT.UserID, schema.UserPrincipal
}
return p.ID, p.Type
}
33 changes: 25 additions & 8 deletions core/group/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,20 @@ func (s Service) Update(ctx context.Context, grp Group) (Group, error) {
return Group{}, ErrInvalidID
}

func (s Service) ListByUser(ctx context.Context, principalID, principalType string, flt Filter) ([]Group, error) {
func (s Service) ListByUser(ctx context.Context, principal authenticate.Principal, flt Filter) ([]Group, error) {
subjectID, subjectType := principal.ResolveSubject()
subjectIDs, err := s.relationService.LookupResources(ctx, relation.Relation{
Object: relation.Object{
Namespace: schema.GroupNamespace,
},
Subject: relation.Subject{
Namespace: principalType,
ID: principalID,
},
Object: relation.Object{Namespace: schema.GroupNamespace},
Subject: relation.Subject{Namespace: subjectType, ID: subjectID},
RelationName: schema.MembershipPermission,
})
if err != nil {
return nil, err
}
subjectIDs, err = s.intersectPATScope(ctx, principal, schema.GroupNamespace, subjectIDs)
if err != nil {
return nil, err
}
if len(subjectIDs) == 0 {
// no groups
return nil, nil
Expand All @@ -148,6 +148,23 @@ func (s Service) ListByUser(ctx context.Context, principalID, principalType stri
return s.List(ctx, flt)
}

// intersectPATScope narrows resource IDs to only those the PAT is scoped to.
func (s Service) intersectPATScope(ctx context.Context, principal authenticate.Principal,
namespace string, resourceIDs []string) ([]string, error) {
if principal.PAT == nil || len(resourceIDs) == 0 {
return resourceIDs, nil
}
patIDs, err := s.relationService.LookupResources(ctx, relation.Relation{
Object: relation.Object{Namespace: namespace},
Subject: relation.Subject{ID: principal.PAT.ID, Namespace: schema.PATPrincipal},
RelationName: schema.GetPermission,
})
if err != nil {
return nil, err
}
return utils.Intersection(resourceIDs, patIDs), nil
}

// AddMember adds a subject(user) to group as member
func (s Service) AddMember(ctx context.Context, groupID string, principal authenticate.Principal) error {
// first create a policy for the user as member of the group
Expand Down
105 changes: 105 additions & 0 deletions core/group/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/raystack/frontier/core/policy"
"github.com/raystack/frontier/core/relation"
"github.com/raystack/frontier/core/user"
pat "github.com/raystack/frontier/core/userpat/models"
"github.com/raystack/frontier/internal/bootstrap/schema"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -262,3 +263,107 @@ func TestService_Update(t *testing.T) {
assert.Equal(t, err, group.ErrInvalidID)
})
}

func TestService_ListByUser(t *testing.T) {
ctx := context.Background()

t.Run("should resolve PAT to user and intersect with PAT group scope", func(t *testing.T) {
mockRepo := mocks.NewRepository(t)
mockRelationSvc := mocks.NewRelationService(t)
mockAuthnSvc := mocks.NewAuthnService(t)
mockPolicySvc := mocks.NewPolicyService(t)

svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc)

// LookupResources for user's group memberships
mockRelationSvc.On("LookupResources", ctx, relation.Relation{
Object: relation.Object{Namespace: schema.GroupNamespace},
Subject: relation.Subject{Namespace: schema.UserPrincipal, ID: "user-123"},
RelationName: schema.MembershipPermission,
}).Return([]string{"group-1", "group-2", "group-3"}, nil).Once()

// LookupResources for PAT's group scope
mockRelationSvc.On("LookupResources", ctx, relation.Relation{
Object: relation.Object{Namespace: schema.GroupNamespace},
Subject: relation.Subject{ID: "pat-456", Namespace: schema.PATPrincipal},
RelationName: schema.GetPermission,
}).Return([]string{"group-1", "group-3"}, nil).Once()

// Repo should be called with intersection
mockRepo.On("List", ctx, group.Filter{
GroupIDs: []string{"group-1", "group-3"},
}).Return([]group.Group{
{ID: "group-1", Name: "group-one"},
{ID: "group-3", Name: "group-three"},
}, nil).Once()

result, err := svc.ListByUser(ctx, authenticate.Principal{
ID: "pat-456",
Type: schema.PATPrincipal,
PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-1"},
}, group.Filter{})

assert.NoError(t, err)
assert.Len(t, result, 2)
})

t.Run("should return nil when PAT has no group scope overlap", func(t *testing.T) {
mockRepo := mocks.NewRepository(t)
mockRelationSvc := mocks.NewRelationService(t)
mockAuthnSvc := mocks.NewAuthnService(t)
mockPolicySvc := mocks.NewPolicyService(t)

svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc)

mockRelationSvc.On("LookupResources", ctx, relation.Relation{
Object: relation.Object{Namespace: schema.GroupNamespace},
Subject: relation.Subject{Namespace: schema.UserPrincipal, ID: "user-123"},
RelationName: schema.MembershipPermission,
}).Return([]string{"group-1"}, nil).Once()

mockRelationSvc.On("LookupResources", ctx, relation.Relation{
Object: relation.Object{Namespace: schema.GroupNamespace},
Subject: relation.Subject{ID: "pat-456", Namespace: schema.PATPrincipal},
RelationName: schema.GetPermission,
}).Return([]string{"group-2"}, nil).Once()

result, err := svc.ListByUser(ctx, authenticate.Principal{
ID: "pat-456",
Type: schema.PATPrincipal,
PAT: &pat.PAT{ID: "pat-456", UserID: "user-123", OrgID: "org-1"},
}, group.Filter{})

assert.NoError(t, err)
assert.Nil(t, result)
})

t.Run("should pass through for regular user principal", func(t *testing.T) {
mockRepo := mocks.NewRepository(t)
mockRelationSvc := mocks.NewRelationService(t)
mockAuthnSvc := mocks.NewAuthnService(t)
mockPolicySvc := mocks.NewPolicyService(t)

svc := group.NewService(mockRepo, mockRelationSvc, mockAuthnSvc, mockPolicySvc)

mockRelationSvc.On("LookupResources", ctx, relation.Relation{
Object: relation.Object{Namespace: schema.GroupNamespace},
Subject: relation.Subject{Namespace: schema.UserPrincipal, ID: "user-123"},
RelationName: schema.MembershipPermission,
}).Return([]string{"group-1", "group-2"}, nil).Once()

mockRepo.On("List", ctx, group.Filter{
GroupIDs: []string{"group-1", "group-2"},
}).Return([]group.Group{
{ID: "group-1", Name: "group-one"},
{ID: "group-2", Name: "group-two"},
}, nil).Once()

result, err := svc.ListByUser(ctx, authenticate.Principal{
ID: "user-123",
Type: schema.UserPrincipal,
}, group.Filter{})

assert.NoError(t, err)
assert.Len(t, result, 2)
})
}
31 changes: 15 additions & 16 deletions core/invitation/mocks/group_service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions core/invitation/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type OrganizationService interface {
type GroupService interface {
Get(ctx context.Context, id string) (group.Group, error)
AddMember(ctx context.Context, groupID string, principal authenticate.Principal) error
ListByUser(ctx context.Context, principalID, principalType string, flt group.Filter) ([]group.Group, error)
ListByUser(ctx context.Context, principal authenticate.Principal, flt group.Filter) ([]group.Group, error)
}

type RelationService interface {
Expand Down Expand Up @@ -315,7 +315,9 @@ func (s Service) Accept(ctx context.Context, id uuid.UUID) error {

// check if the invitation has a group membership
if len(invite.GroupIDs) > 0 {
userGroups, err := s.groupSvc.ListByUser(ctx, userOb.ID, schema.UserPrincipal, group.Filter{})
userGroups, err := s.groupSvc.ListByUser(ctx, authenticate.Principal{
ID: userOb.ID, Type: schema.UserPrincipal,
}, group.Filter{})
if err != nil {
return err
}
Expand Down
10 changes: 8 additions & 2 deletions core/organization/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,19 +308,25 @@ func (s Service) ListByUser(ctx context.Context, principal authenticate.Principa
defer promCollect()
}

subjectID, subjectType := principal.ResolveSubject()
subjectIDs, err := s.relationService.LookupResources(ctx, relation.Relation{
Object: relation.Object{
Namespace: schema.OrganizationNamespace,
},
Subject: relation.Subject{
ID: principal.ID,
Namespace: principal.Type,
ID: subjectID,
Namespace: subjectType,
},
RelationName: schema.MembershipPermission,
})
if err != nil {
return nil, err
}

if principal.PAT != nil {
subjectIDs = utils.Intersection(subjectIDs, []string{principal.PAT.OrgID})
}

if len(subjectIDs) == 0 {
// no organizations
return []Organization{}, nil
Expand Down
Loading
Loading