diff --git a/core/userpat/service.go b/core/userpat/service.go index b0751d0b6..de0762bab 100644 --- a/core/userpat/service.go +++ b/core/userpat/service.go @@ -94,6 +94,26 @@ func (s *Service) GetByID(ctx context.Context, id string) (patmodels.PAT, error) return s.repo.GetByID(ctx, id) } +// Get retrieves a PAT by ID, verifying it belongs to the given user. +// Returns ErrDisabled if PATs are not enabled, ErrNotFound if the PAT +// does not exist or belongs to a different user. +func (s *Service) Get(ctx context.Context, userID, id string) (patmodels.PAT, error) { + if !s.config.Enabled { + return patmodels.PAT{}, paterrors.ErrDisabled + } + pat, err := s.repo.GetByID(ctx, id) + if err != nil { + return patmodels.PAT{}, err + } + if pat.UserID != userID { + return patmodels.PAT{}, paterrors.ErrNotFound + } + if err := s.enrichWithScope(ctx, &pat); err != nil { + return patmodels.PAT{}, fmt.Errorf("enriching PAT scope: %w", err) + } + return pat, nil +} + // Create generates a new PAT and returns it with the plaintext value. // The plaintext value is only available at creation time. func (s *Service) Create(ctx context.Context, req CreateRequest) (patmodels.PAT, string, error) { @@ -307,24 +327,7 @@ func (s *Service) createProjectScopedPolicies(ctx context.Context, patID, orgID return nil } -// List retrieves all PATs for a user in an org and enriches each with scope fields. -func (s *Service) List(ctx context.Context, userID, orgID string, query *rql.Query) (patmodels.PATList, error) { - if !s.config.Enabled { - return patmodels.PATList{}, paterrors.ErrDisabled - } - result, err := s.repo.List(ctx, userID, orgID, query) - if err != nil { - return patmodels.PATList{}, err - } - for i := range result.PATs { - if err := s.enrichWithScope(ctx, &result.PATs[i]); err != nil { - return patmodels.PATList{}, fmt.Errorf("enriching PAT scope: %w", err) - } - } - return result, nil -} - -// enrichWithScope derives role_ids and project_ids +// enrichWithScope derives role_ids and project_ids from the PAT's SpiceDB policies. func (s *Service) enrichWithScope(ctx context.Context, pat *patmodels.PAT) error { policies, err := s.policyService.List(ctx, policy.Filter{ PrincipalID: pat.ID, @@ -349,12 +352,29 @@ func (s *Service) enrichWithScope(ctx context.Context, pat *patmodels.PAT) error pat.RoleIDs = pkgUtils.Deduplicate(roleIDs) if !allProjects { - pat.ProjectIDs = projectIDs + pat.ProjectIDs = pkgUtils.Deduplicate(projectIDs) } // allProjects → pat.ProjectIDs stays nil (empty = all projects, matching create semantics) return nil } +// List retrieves all PATs for a user in an org and enriches each with scope fields. +func (s *Service) List(ctx context.Context, userID, orgID string, query *rql.Query) (patmodels.PATList, error) { + if !s.config.Enabled { + return patmodels.PATList{}, paterrors.ErrDisabled + } + result, err := s.repo.List(ctx, userID, orgID, query) + if err != nil { + return patmodels.PATList{}, err + } + for i := range result.PATs { + if err := s.enrichWithScope(ctx, &result.PATs[i]); err != nil { + return patmodels.PATList{}, fmt.Errorf("enriching PAT scope: %w", err) + } + } + return result, nil +} + // generatePAT creates a random PAT string with the configured prefix and returns // the plaintext value along with its SHA3-256 hash for storage. // The hash is computed over the raw secret bytes (not the formatted PAT string) diff --git a/core/userpat/service_test.go b/core/userpat/service_test.go index dd1629472..92b7fbb6a 100644 --- a/core/userpat/service_test.go +++ b/core/userpat/service_test.go @@ -1243,3 +1243,123 @@ func TestConfig_MaxExpiry(t *testing.T) { }) } } + +func TestService_Get(t *testing.T) { + testPAT := models.PAT{ + ID: "pat-1", + UserID: "user-1", + OrgID: "org-1", + Title: "my-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + tests := []struct { + name string + setup func() *userpat.Service + userID string + patID string + wantErr bool + wantErrIs error + }{ + { + name: "should return ErrDisabled when PAT feature is disabled", + userID: "user-1", + patID: "pat-1", + setup: func() *userpat.Service { + repo := mocks.NewRepository(t) + orgSvc, _, policySvc, auditRepo := newSuccessMocks(t) + return userpat.NewService(log.NewNoop(), repo, userpat.Config{ + Enabled: false, + }, orgSvc, nil, policySvc, auditRepo) + }, + wantErr: true, + wantErrIs: paterrors.ErrDisabled, + }, + { + name: "should return error when repo GetByID fails", + userID: "user-1", + patID: "pat-1", + setup: func() *userpat.Service { + repo := mocks.NewRepository(t) + repo.EXPECT().GetByID(mock.Anything, "pat-1"). + Return(models.PAT{}, paterrors.ErrNotFound) + orgSvc, _, policySvc, auditRepo := newSuccessMocks(t) + return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo) + }, + wantErr: true, + wantErrIs: paterrors.ErrNotFound, + }, + { + name: "should return ErrNotFound when PAT belongs to different user", + userID: "user-2", + patID: "pat-1", + setup: func() *userpat.Service { + repo := mocks.NewRepository(t) + repo.EXPECT().GetByID(mock.Anything, "pat-1"). + Return(testPAT, nil) + orgSvc, _, policySvc, auditRepo := newSuccessMocks(t) + return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo) + }, + wantErr: true, + wantErrIs: paterrors.ErrNotFound, + }, + { + name: "should return PAT when user owns it", + userID: "user-1", + patID: "pat-1", + setup: func() *userpat.Service { + repo := mocks.NewRepository(t) + repo.EXPECT().GetByID(mock.Anything, "pat-1"). + Return(testPAT, nil) + orgSvc, _, policySvc, auditRepo := newSuccessMocks(t) + policySvc.On("List", mock.Anything, mock.Anything). + Return([]policy.Policy{ + {RoleID: "role-1", ResourceType: "app/organization", ResourceID: "org-1"}, + }, nil).Maybe() + return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo) + }, + wantErr: false, + }, + { + name: "should return error when enrichWithScope fails", + userID: "user-1", + patID: "pat-1", + setup: func() *userpat.Service { + repo := mocks.NewRepository(t) + repo.EXPECT().GetByID(mock.Anything, "pat-1"). + Return(testPAT, nil) + orgSvc := mocks.NewOrganizationService(t) + policySvc := mocks.NewPolicyService(t) + policySvc.On("List", mock.Anything, mock.Anything). + Return(nil, errors.New("spicedb down")) + auditRepo := mocks.NewAuditRecordRepository(t) + return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := tt.setup() + got, err := svc.Get(context.Background(), tt.userID, tt.patID) + if tt.wantErr { + if err == nil { + t.Fatal("Get() expected error, got nil") + } + if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) { + t.Errorf("Get() error = %v, want %v", err, tt.wantErrIs) + } + return + } + if err != nil { + t.Fatalf("Get() unexpected error: %v", err) + } + if got.ID != testPAT.ID { + t.Errorf("Get() PAT ID = %v, want %v", got.ID, testPAT.ID) + } + }) + } +} diff --git a/internal/api/v1beta1connect/interfaces.go b/internal/api/v1beta1connect/interfaces.go index 523366e87..42075027b 100644 --- a/internal/api/v1beta1connect/interfaces.go +++ b/internal/api/v1beta1connect/interfaces.go @@ -403,4 +403,5 @@ type UserPATService interface { ValidateExpiry(expiresAt time.Time) error Create(ctx context.Context, req userpat.CreateRequest) (models.PAT, string, error) List(ctx context.Context, userID, orgID string, query *rql.Query) (models.PATList, error) + Get(ctx context.Context, userID, id string) (models.PAT, error) } diff --git a/internal/api/v1beta1connect/mocks/user_pat_service.go b/internal/api/v1beta1connect/mocks/user_pat_service.go index c05f09f59..ea07cdcdd 100644 --- a/internal/api/v1beta1connect/mocks/user_pat_service.go +++ b/internal/api/v1beta1connect/mocks/user_pat_service.go @@ -151,6 +151,64 @@ func (_c *UserPATService_List_Call) RunAndReturn(run func(context.Context, strin return _c } +// Get provides a mock function with given fields: ctx, userID, id +func (_m *UserPATService) Get(ctx context.Context, userID string, id string) (models.PAT, error) { + ret := _m.Called(ctx, userID, id) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 models.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (models.PAT, error)); ok { + return rf(ctx, userID, id) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) models.PAT); ok { + r0 = rf(ctx, userID, id) + } else { + r0 = ret.Get(0).(models.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, userID, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UserPATService_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type UserPATService_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - userID string +// - id string +func (_e *UserPATService_Expecter) Get(ctx interface{}, userID interface{}, id interface{}) *UserPATService_Get_Call { + return &UserPATService_Get_Call{Call: _e.mock.On("Get", ctx, userID, id)} +} + +func (_c *UserPATService_Get_Call) Run(run func(ctx context.Context, userID string, id string)) *UserPATService_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *UserPATService_Get_Call) Return(_a0 models.PAT, _a1 error) *UserPATService_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *UserPATService_Get_Call) RunAndReturn(run func(context.Context, string, string) (models.PAT, error)) *UserPATService_Get_Call { + _c.Call.Return(run) + return _c +} + // ValidateExpiry provides a mock function with given fields: expiresAt func (_m *UserPATService) ValidateExpiry(expiresAt time.Time) error { ret := _m.Called(expiresAt) diff --git a/internal/api/v1beta1connect/user_pat.go b/internal/api/v1beta1connect/user_pat.go index 77fa1ade7..69af691bc 100644 --- a/internal/api/v1beta1connect/user_pat.go +++ b/internal/api/v1beta1connect/user_pat.go @@ -74,15 +74,53 @@ func (h *ConnectHandler) CreateCurrentUserPAT(ctx context.Context, request *conn }), nil } +func (h *ConnectHandler) GetCurrentUserPAT(ctx context.Context, request *connect.Request[frontierv1beta1.GetCurrentUserPATRequest]) (*connect.Response[frontierv1beta1.GetCurrentUserPATResponse], error) { + errorLogger := NewErrorLogger() + + principal, err := h.GetLoggedInPrincipal(ctx) + if err != nil { + return nil, err + } + if principal.User == nil { + return nil, connect.NewError(connect.CodePermissionDenied, ErrUnauthenticated) + } + + if err := request.Msg.Validate(); err != nil { + return nil, connect.NewError(connect.CodeInvalidArgument, err) + } + + pat, err := h.userPATService.Get(ctx, principal.User.ID, request.Msg.GetId()) + if err != nil { + errorLogger.LogServiceError(ctx, request, "GetCurrentUserPAT", err, + zap.String("user_id", principal.User.ID), + zap.String("pat_id", request.Msg.GetId())) + + switch { + case errors.Is(err, paterrors.ErrDisabled): + return nil, connect.NewError(connect.CodeFailedPrecondition, err) + case errors.Is(err, paterrors.ErrNotFound): + return nil, connect.NewError(connect.CodeNotFound, err) + default: + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) + } + } + + return connect.NewResponse(&frontierv1beta1.GetCurrentUserPATResponse{ + Pat: transformPATToPB(pat, ""), + }), nil +} + func transformPATToPB(pat models.PAT, patValue string) *frontierv1beta1.PAT { pbPAT := &frontierv1beta1.PAT{ - Id: pat.ID, - Title: pat.Title, - UserId: pat.UserID, - OrgId: pat.OrgID, - ExpiresAt: timestamppb.New(pat.ExpiresAt), - CreatedAt: timestamppb.New(pat.CreatedAt), - UpdatedAt: timestamppb.New(pat.UpdatedAt), + Id: pat.ID, + Title: pat.Title, + UserId: pat.UserID, + OrgId: pat.OrgID, + RoleIds: pat.RoleIDs, + ProjectIds: pat.ProjectIDs, + ExpiresAt: timestamppb.New(pat.ExpiresAt), + CreatedAt: timestamppb.New(pat.CreatedAt), + UpdatedAt: timestamppb.New(pat.UpdatedAt), } if patValue != "" { pbPAT.Token = patValue diff --git a/internal/api/v1beta1connect/user_pat_test.go b/internal/api/v1beta1connect/user_pat_test.go index f42e028e8..bc1c1101b 100644 --- a/internal/api/v1beta1connect/user_pat_test.go +++ b/internal/api/v1beta1connect/user_pat_test.go @@ -385,6 +385,169 @@ func TestHandler_CreateCurrentUserPAT(t *testing.T) { } } +func TestHandler_GetCurrentUserPAT(t *testing.T) { + testCreatedAt := time.Date(2026, 2, 10, 0, 0, 0, 0, time.UTC) + testExpiry := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC) + testUserID := "8e256f86-31a3-11ec-8d3d-0242ac130003" + testPATID := "6c256f86-31a3-11ec-8d3d-0242ac130003" + + tests := []struct { + name string + setup func(ps *mocks.UserPATService, as *mocks.AuthnService) + request *connect.Request[frontierv1beta1.GetCurrentUserPATRequest] + want *frontierv1beta1.GetCurrentUserPATResponse + wantErr error + }{ + { + name: "should return unauthenticated error when GetLoggedInPrincipal fails", + setup: func(ps *mocks.UserPATService, as *mocks.AuthnService) { + as.EXPECT().GetPrincipal(mock.Anything).Return(authenticate.Principal{}, errors.ErrUnauthenticated) + }, + request: connect.NewRequest(&frontierv1beta1.GetCurrentUserPATRequest{ + Id: testPATID, + }), + want: nil, + wantErr: connect.NewError(connect.CodeUnauthenticated, ErrUnauthenticated), + }, + { + name: "should return permission denied when principal is not a user", + setup: func(ps *mocks.UserPATService, as *mocks.AuthnService) { + as.EXPECT().GetPrincipal(mock.Anything).Return(authenticate.Principal{ + ID: "sv-1", + Type: schema.ServiceUserPrincipal, + }, nil) + }, + request: connect.NewRequest(&frontierv1beta1.GetCurrentUserPATRequest{ + Id: testPATID, + }), + want: nil, + wantErr: connect.NewError(connect.CodePermissionDenied, ErrUnauthenticated), + }, + { + name: "should return failed precondition when PAT is disabled", + setup: func(ps *mocks.UserPATService, as *mocks.AuthnService) { + as.EXPECT().GetPrincipal(mock.Anything).Return(authenticate.Principal{ + ID: testUserID, + Type: schema.UserPrincipal, + User: &user.User{ID: testUserID}, + }, nil) + ps.EXPECT().Get(mock.Anything, testUserID, testPATID). + Return(models.PAT{}, paterrors.ErrDisabled) + }, + request: connect.NewRequest(&frontierv1beta1.GetCurrentUserPATRequest{ + Id: testPATID, + }), + want: nil, + wantErr: connect.NewError(connect.CodeFailedPrecondition, paterrors.ErrDisabled), + }, + { + name: "should return not found when PAT does not exist", + setup: func(ps *mocks.UserPATService, as *mocks.AuthnService) { + as.EXPECT().GetPrincipal(mock.Anything).Return(authenticate.Principal{ + ID: testUserID, + Type: schema.UserPrincipal, + User: &user.User{ID: testUserID}, + }, nil) + ps.EXPECT().Get(mock.Anything, testUserID, testPATID). + Return(models.PAT{}, paterrors.ErrNotFound) + }, + request: connect.NewRequest(&frontierv1beta1.GetCurrentUserPATRequest{ + Id: testPATID, + }), + want: nil, + wantErr: connect.NewError(connect.CodeNotFound, paterrors.ErrNotFound), + }, + { + name: "should return internal error for unknown failure", + setup: func(ps *mocks.UserPATService, as *mocks.AuthnService) { + as.EXPECT().GetPrincipal(mock.Anything).Return(authenticate.Principal{ + ID: testUserID, + Type: schema.UserPrincipal, + User: &user.User{ID: testUserID}, + }, nil) + ps.EXPECT().Get(mock.Anything, testUserID, testPATID). + Return(models.PAT{}, errors.New("unexpected error")) + }, + request: connect.NewRequest(&frontierv1beta1.GetCurrentUserPATRequest{ + Id: testPATID, + }), + want: nil, + wantErr: connect.NewError(connect.CodeInternal, ErrInternalServerError), + }, + { + name: "should return PAT successfully", + setup: func(ps *mocks.UserPATService, as *mocks.AuthnService) { + as.EXPECT().GetPrincipal(mock.Anything).Return(authenticate.Principal{ + ID: testUserID, + Type: schema.UserPrincipal, + User: &user.User{ID: testUserID}, + }, nil) + ps.EXPECT().Get(mock.Anything, testUserID, testPATID). + Return(models.PAT{ + ID: testPATID, + UserID: testUserID, + OrgID: "org-1", + Title: "my-token", + RoleIDs: []string{"role-1"}, + ExpiresAt: testExpiry, + CreatedAt: testCreatedAt, + UpdatedAt: testCreatedAt, + }, nil) + }, + request: connect.NewRequest(&frontierv1beta1.GetCurrentUserPATRequest{ + Id: testPATID, + }), + want: &frontierv1beta1.GetCurrentUserPATResponse{ + Pat: &frontierv1beta1.PAT{ + Id: testPATID, + UserId: testUserID, + OrgId: "org-1", + Title: "my-token", + RoleIds: []string{"role-1"}, + ExpiresAt: timestamppb.New(testExpiry), + CreatedAt: timestamppb.New(testCreatedAt), + UpdatedAt: timestamppb.New(testCreatedAt), + }, + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockPATSrv := new(mocks.UserPATService) + mockAuthnSrv := new(mocks.AuthnService) + + if tt.setup != nil { + tt.setup(mockPATSrv, mockAuthnSrv) + } + + handler := &ConnectHandler{ + userPATService: mockPATSrv, + authnService: mockAuthnSrv, + } + + resp, err := handler.GetCurrentUserPAT(context.Background(), tt.request) + + if tt.wantErr != nil { + assert.Error(t, err) + assert.Equal(t, tt.wantErr.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + + if tt.want != nil { + assert.Equal(t, tt.want, resp.Msg) + } else { + assert.Nil(t, resp) + } + + mockPATSrv.AssertExpectations(t) + mockAuthnSrv.AssertExpectations(t) + }) + } +} + func TestTransformPATToPB(t *testing.T) { testTime := time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC) testCreatedAt := time.Date(2026, 2, 10, 0, 0, 0, 0, time.UTC) diff --git a/pkg/server/connect_interceptors/authorization.go b/pkg/server/connect_interceptors/authorization.go index 40f43f0e8..073016d81 100644 --- a/pkg/server/connect_interceptors/authorization.go +++ b/pkg/server/connect_interceptors/authorization.go @@ -134,6 +134,8 @@ var authorizationSkipEndpoints = map[string]bool{ "/raystack.frontier.v1beta1.FrontierService/ListSessions": true, "/raystack.frontier.v1beta1.FrontierService/PingUserSession": true, "/raystack.frontier.v1beta1.FrontierService/RevokeSession": true, + + "/raystack.frontier.v1beta1.FrontierService/GetCurrentUserPAT": true, } // authorizationValidationMap stores path to validation function