Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions internal/ghmcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type MCPServerConfig struct {

const stdioServerLogPrefix = "stdioserver"

func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
func NewMCPServer(cfg MCPServerConfig, logger *slog.Logger) (*server.MCPServer, error) {
apiHost, err := parseAPIHost(cfg.Host)
if err != nil {
return nil, fmt.Errorf("failed to parse API host: %w", err)
Expand All @@ -88,6 +88,9 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
if cfg.RepoAccessTTL != nil {
repoAccessOpts = append(repoAccessOpts, lockdown.WithTTL(*cfg.RepoAccessTTL))
}

repoAccessLogger := logger.With("component", "lockdown")
repoAccessOpts = append(repoAccessOpts, lockdown.WithLogger(repoAccessLogger))
var repoAccessCache *lockdown.RepoAccessCache
if cfg.LockdownMode {
repoAccessCache = lockdown.GetInstance(gqlClient, repoAccessOpts...)
Expand Down Expand Up @@ -273,7 +276,7 @@ func RunStdioServer(cfg StdioServerConfig) error {
ContentWindowSize: cfg.ContentWindowSize,
LockdownMode: cfg.LockdownMode,
RepoAccessTTL: cfg.RepoAccessCacheTTL,
})
}, logger)
if err != nil {
return fmt.Errorf("failed to create MCP server: %w", err)
}
Expand Down
65 changes: 49 additions & 16 deletions pkg/lockdown/lockdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ import (
// RepoAccessCache caches repository metadata related to lockdown checks so that
// multiple tools can reuse the same access information safely across goroutines.
type RepoAccessCache struct {
client *githubv4.Client
mu sync.Mutex
cache *cache2go.CacheTable
ttl time.Duration
logger *slog.Logger
client *githubv4.Client
mu sync.Mutex
cache *cache2go.CacheTable
ttl time.Duration
logger *slog.Logger
trustedBotLogins map[string]struct{}
}

type repoAccessCacheEntry struct {
Expand Down Expand Up @@ -85,6 +86,9 @@ func GetInstance(client *githubv4.Client, opts ...RepoAccessOption) *RepoAccessC
client: client,
cache: cache2go.Cache(defaultRepoAccessCacheKey),
ttl: defaultRepoAccessTTL,
trustedBotLogins: map[string]struct{}{
"copilot": {},
},
}
for _, opt := range opts {
if opt != nil {
Expand All @@ -109,13 +113,22 @@ type CacheStats struct {
Evictions int64
}

// IsSafeContent determines if the specified user can safely access the requested repository content.
// Safe access applies when any of the following is true:
// - the content was created by a trusted bot;
// - the author currently has push access to the repository;
// - the repository is private;
// - the content was created by the viewer.
func (c *RepoAccessCache) IsSafeContent(ctx context.Context, username, owner, repo string) (bool, error) {
repoInfo, err := c.getRepoAccessInfo(ctx, username, owner, repo)
if err != nil {
c.logDebug("error checking repo access info for content filtering", "owner", owner, "repo", repo, "user", username, "error", err)
return false, err
}
if repoInfo.IsPrivate || repoInfo.ViewerLogin == username {

c.logDebug(ctx, fmt.Sprintf("evaluated repo access for user %s to %s/%s for content filtering, result: hasPushAccess=%t, isPrivate=%t",
username, owner, repo, repoInfo.HasPushAccess, repoInfo.IsPrivate))

if c.isTrustedBot(username) || repoInfo.IsPrivate || repoInfo.ViewerLogin == strings.ToLower(username) {
return true, nil
}
return repoInfo.HasPushAccess, nil
Expand All @@ -136,30 +149,34 @@ func (c *RepoAccessCache) getRepoAccessInfo(ctx context.Context, username, owner
if err == nil {
entry := cacheItem.Data().(*repoAccessCacheEntry)
if cachedHasPush, known := entry.knownUsers[userKey]; known {
c.logDebug("repo access cache hit", "owner", owner, "repo", repo, "user", username)
c.logDebug(ctx, fmt.Sprintf("repo access cache hit for user %s to %s/%s", username, owner, repo))
return RepoAccessInfo{
IsPrivate: entry.isPrivate,
HasPushAccess: cachedHasPush,
ViewerLogin: entry.viewerLogin,
}, nil
}
c.logDebug("known users cache miss", "owner", owner, "repo", repo, "user", username)

c.logDebug(ctx, "known users cache miss, fetching from graphql API")

info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo)
if queryErr != nil {
return RepoAccessInfo{}, queryErr
}

entry.knownUsers[userKey] = info.HasPushAccess
entry.viewerLogin = info.ViewerLogin
entry.isPrivate = info.IsPrivate
c.cache.Add(key, c.ttl, entry)

return RepoAccessInfo{
IsPrivate: entry.isPrivate,
HasPushAccess: entry.knownUsers[userKey],
ViewerLogin: entry.viewerLogin,
}, nil
}

c.logDebug("repo access cache miss", "owner", owner, "repo", repo, "user", username)
c.logDebug(ctx, fmt.Sprintf("repo access cache miss for user %s to %s/%s", username, owner, repo))

info, queryErr := c.queryRepoAccessInfo(ctx, username, owner, repo)
if queryErr != nil {
Expand Down Expand Up @@ -223,19 +240,35 @@ func (c *RepoAccessCache) queryRepoAccessInfo(ctx context.Context, username, own
}
}

c.logDebug(ctx, fmt.Sprintf("queried repo access info for user %s to %s/%s: isPrivate=%t, hasPushAccess=%t, viewerLogin=%s",
username, owner, repo, bool(query.Repository.IsPrivate), hasPush, query.Viewer.Login))

return RepoAccessInfo{
IsPrivate: bool(query.Repository.IsPrivate),
HasPushAccess: hasPush,
ViewerLogin: string(query.Viewer.Login),
}, nil
}

func cacheKey(owner, repo string) string {
return fmt.Sprintf("%s/%s", strings.ToLower(owner), strings.ToLower(repo))
func (c *RepoAccessCache) log(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) {
if c == nil || c.logger == nil {
return
}
if !c.logger.Enabled(ctx, level) {
return
}
c.logger.LogAttrs(ctx, level, msg, attrs...)
}

func (c *RepoAccessCache) logDebug(msg string, args ...any) {
if c != nil && c.logger != nil {
c.logger.Debug(msg, args...)
}
func (c *RepoAccessCache) logDebug(ctx context.Context, msg string, attrs ...slog.Attr) {
c.log(ctx, slog.LevelDebug, msg, attrs...)
}

func (c *RepoAccessCache) isTrustedBot(username string) bool {
_, ok := c.trustedBotLogins[strings.ToLower(username)]
return ok
}

func cacheKey(owner, repo string) string {
return fmt.Sprintf("%s/%s", strings.ToLower(owner), strings.ToLower(repo))
}
Loading