Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 39 additions & 19 deletions core/userpat/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,26 @@ func (s *Service) GetByID(ctx context.Context, id string) (patmodels.PAT, error)
return s.repo.GetByID(ctx, id)
}

// Get retrieves a PAT by ID, verifying it belongs to the given user.
// Returns ErrDisabled if PATs are not enabled, ErrNotFound if the PAT
// does not exist or belongs to a different user.
func (s *Service) Get(ctx context.Context, userID, id string) (patmodels.PAT, error) {
if !s.config.Enabled {
return patmodels.PAT{}, paterrors.ErrDisabled
}
pat, err := s.repo.GetByID(ctx, id)
if err != nil {
return patmodels.PAT{}, err
}
if pat.UserID != userID {
return patmodels.PAT{}, paterrors.ErrNotFound
}
if err := s.enrichWithScope(ctx, &pat); err != nil {
return patmodels.PAT{}, fmt.Errorf("enriching PAT scope: %w", err)
}
return pat, nil
}

// Create generates a new PAT and returns it with the plaintext value.
// The plaintext value is only available at creation time.
func (s *Service) Create(ctx context.Context, req CreateRequest) (patmodels.PAT, string, error) {
Expand Down Expand Up @@ -307,24 +327,7 @@ func (s *Service) createProjectScopedPolicies(ctx context.Context, patID, orgID
return nil
}

// List retrieves all PATs for a user in an org and enriches each with scope fields.
func (s *Service) List(ctx context.Context, userID, orgID string, query *rql.Query) (patmodels.PATList, error) {
if !s.config.Enabled {
return patmodels.PATList{}, paterrors.ErrDisabled
}
result, err := s.repo.List(ctx, userID, orgID, query)
if err != nil {
return patmodels.PATList{}, err
}
for i := range result.PATs {
if err := s.enrichWithScope(ctx, &result.PATs[i]); err != nil {
return patmodels.PATList{}, fmt.Errorf("enriching PAT scope: %w", err)
}
}
return result, nil
}

// enrichWithScope derives role_ids and project_ids
// enrichWithScope derives role_ids and project_ids from the PAT's SpiceDB policies.
func (s *Service) enrichWithScope(ctx context.Context, pat *patmodels.PAT) error {
policies, err := s.policyService.List(ctx, policy.Filter{
PrincipalID: pat.ID,
Expand All @@ -349,12 +352,29 @@ func (s *Service) enrichWithScope(ctx context.Context, pat *patmodels.PAT) error

pat.RoleIDs = pkgUtils.Deduplicate(roleIDs)
if !allProjects {
pat.ProjectIDs = projectIDs
pat.ProjectIDs = pkgUtils.Deduplicate(projectIDs)
}
// allProjects → pat.ProjectIDs stays nil (empty = all projects, matching create semantics)
return nil
}

// List retrieves all PATs for a user in an org and enriches each with scope fields.
func (s *Service) List(ctx context.Context, userID, orgID string, query *rql.Query) (patmodels.PATList, error) {
if !s.config.Enabled {
return patmodels.PATList{}, paterrors.ErrDisabled
}
result, err := s.repo.List(ctx, userID, orgID, query)
if err != nil {
return patmodels.PATList{}, err
}
for i := range result.PATs {
if err := s.enrichWithScope(ctx, &result.PATs[i]); err != nil {
return patmodels.PATList{}, fmt.Errorf("enriching PAT scope: %w", err)
}
}
return result, nil
}

// generatePAT creates a random PAT string with the configured prefix and returns
// the plaintext value along with its SHA3-256 hash for storage.
// The hash is computed over the raw secret bytes (not the formatted PAT string)
Expand Down
120 changes: 120 additions & 0 deletions core/userpat/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1243,3 +1243,123 @@ func TestConfig_MaxExpiry(t *testing.T) {
})
}
}

func TestService_Get(t *testing.T) {
testPAT := models.PAT{
ID: "pat-1",
UserID: "user-1",
OrgID: "org-1",
Title: "my-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}

tests := []struct {
name string
setup func() *userpat.Service
userID string
patID string
wantErr bool
wantErrIs error
}{
{
name: "should return ErrDisabled when PAT feature is disabled",
userID: "user-1",
patID: "pat-1",
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
orgSvc, _, policySvc, auditRepo := newSuccessMocks(t)
return userpat.NewService(log.NewNoop(), repo, userpat.Config{
Enabled: false,
}, orgSvc, nil, policySvc, auditRepo)
},
wantErr: true,
wantErrIs: paterrors.ErrDisabled,
},
{
name: "should return error when repo GetByID fails",
userID: "user-1",
patID: "pat-1",
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(models.PAT{}, paterrors.ErrNotFound)
orgSvc, _, policySvc, auditRepo := newSuccessMocks(t)
return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo)
},
wantErr: true,
wantErrIs: paterrors.ErrNotFound,
},
{
name: "should return ErrNotFound when PAT belongs to different user",
userID: "user-2",
patID: "pat-1",
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(testPAT, nil)
orgSvc, _, policySvc, auditRepo := newSuccessMocks(t)
return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo)
},
wantErr: true,
wantErrIs: paterrors.ErrNotFound,
},
{
name: "should return PAT when user owns it",
userID: "user-1",
patID: "pat-1",
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(testPAT, nil)
orgSvc, _, policySvc, auditRepo := newSuccessMocks(t)
policySvc.On("List", mock.Anything, mock.Anything).
Return([]policy.Policy{
{RoleID: "role-1", ResourceType: "app/organization", ResourceID: "org-1"},
}, nil).Maybe()
return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo)
},
wantErr: false,
},
{
name: "should return error when enrichWithScope fails",
userID: "user-1",
patID: "pat-1",
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(testPAT, nil)
orgSvc := mocks.NewOrganizationService(t)
policySvc := mocks.NewPolicyService(t)
policySvc.On("List", mock.Anything, mock.Anything).
Return(nil, errors.New("spicedb down"))
auditRepo := mocks.NewAuditRecordRepository(t)
return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo)
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := tt.setup()
got, err := svc.Get(context.Background(), tt.userID, tt.patID)
if tt.wantErr {
if err == nil {
t.Fatal("Get() expected error, got nil")
}
if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) {
t.Errorf("Get() error = %v, want %v", err, tt.wantErrIs)
}
return
}
if err != nil {
t.Fatalf("Get() unexpected error: %v", err)
}
if got.ID != testPAT.ID {
t.Errorf("Get() PAT ID = %v, want %v", got.ID, testPAT.ID)
}
})
}
}
1 change: 1 addition & 0 deletions internal/api/v1beta1connect/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,4 +403,5 @@ type UserPATService interface {
ValidateExpiry(expiresAt time.Time) error
Create(ctx context.Context, req userpat.CreateRequest) (models.PAT, string, error)
List(ctx context.Context, userID, orgID string, query *rql.Query) (models.PATList, error)
Get(ctx context.Context, userID, id string) (models.PAT, error)
}
58 changes: 58 additions & 0 deletions internal/api/v1beta1connect/mocks/user_pat_service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 45 additions & 7 deletions internal/api/v1beta1connect/user_pat.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,53 @@ func (h *ConnectHandler) CreateCurrentUserPAT(ctx context.Context, request *conn
}), nil
}

func (h *ConnectHandler) GetCurrentUserPAT(ctx context.Context, request *connect.Request[frontierv1beta1.GetCurrentUserPATRequest]) (*connect.Response[frontierv1beta1.GetCurrentUserPATResponse], error) {
errorLogger := NewErrorLogger()

principal, err := h.GetLoggedInPrincipal(ctx)
if err != nil {
return nil, err
}
if principal.User == nil {
return nil, connect.NewError(connect.CodePermissionDenied, ErrUnauthenticated)
}

if err := request.Msg.Validate(); err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}

pat, err := h.userPATService.Get(ctx, principal.User.ID, request.Msg.GetId())
if err != nil {
errorLogger.LogServiceError(ctx, request, "GetCurrentUserPAT", err,
zap.String("user_id", principal.User.ID),
zap.String("pat_id", request.Msg.GetId()))

switch {
case errors.Is(err, paterrors.ErrDisabled):
return nil, connect.NewError(connect.CodeFailedPrecondition, err)
case errors.Is(err, paterrors.ErrNotFound):
return nil, connect.NewError(connect.CodeNotFound, err)
default:
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
}
}

return connect.NewResponse(&frontierv1beta1.GetCurrentUserPATResponse{
Pat: transformPATToPB(pat, ""),
}), nil
}

func transformPATToPB(pat models.PAT, patValue string) *frontierv1beta1.PAT {
pbPAT := &frontierv1beta1.PAT{
Id: pat.ID,
Title: pat.Title,
UserId: pat.UserID,
OrgId: pat.OrgID,
ExpiresAt: timestamppb.New(pat.ExpiresAt),
CreatedAt: timestamppb.New(pat.CreatedAt),
UpdatedAt: timestamppb.New(pat.UpdatedAt),
Id: pat.ID,
Title: pat.Title,
UserId: pat.UserID,
OrgId: pat.OrgID,
RoleIds: pat.RoleIDs,
ProjectIds: pat.ProjectIDs,
ExpiresAt: timestamppb.New(pat.ExpiresAt),
CreatedAt: timestamppb.New(pat.CreatedAt),
UpdatedAt: timestamppb.New(pat.UpdatedAt),
}
if patValue != "" {
pbPAT.Token = patValue
Expand Down
Loading
Loading