diff --git a/deployment/clouddeploy/gke-workers/base/gitter.yaml b/deployment/clouddeploy/gke-workers/base/gitter.yaml index 8a46fd0537d..9b659954c76 100644 --- a/deployment/clouddeploy/gke-workers/base/gitter.yaml +++ b/deployment/clouddeploy/gke-workers/base/gitter.yaml @@ -29,6 +29,8 @@ spec: - "--port=8888" - "--work-dir=/work/gitter" - "--fetch-timeout=1h" + - "--repo-cache-ttl=1h" + - "--repo-cache-max-cost=100GiB" env: - name: GOMEMLIMIT value: "100GiB" diff --git a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/gitter.yaml b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/gitter.yaml index 25c098d7110..576a83cfe6f 100644 --- a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/gitter.yaml +++ b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/gitter.yaml @@ -11,6 +11,8 @@ spec: - "--port=8888" - "--work-dir=/work/gitter" - "--fetch-timeout=1h" + - "--repo-cache-ttl=1h" + - "--repo-cache-max-cost=100GiB" env: - name: GOOGLE_CLOUD_PROJECT value: oss-vdb-test diff --git a/deployment/clouddeploy/gke-workers/environments/oss-vdb/gitter.yaml b/deployment/clouddeploy/gke-workers/environments/oss-vdb/gitter.yaml index fc1558ba8b4..24554f18405 100644 --- a/deployment/clouddeploy/gke-workers/environments/oss-vdb/gitter.yaml +++ b/deployment/clouddeploy/gke-workers/environments/oss-vdb/gitter.yaml @@ -11,6 +11,8 @@ spec: - "--port=8888" - "--work-dir=/work/gitter" - "--fetch-timeout=1h" + - "--repo-cache-ttl=1h" + - "--repo-cache-max-cost=100GiB" env: - name: GOOGLE_CLOUD_PROJECT value: oss-vdb diff --git a/go/cmd/gitter/gitter.go b/go/cmd/gitter/gitter.go index 7a97df829bb..7ea59794875 100644 --- a/go/cmd/gitter/gitter.go +++ b/go/cmd/gitter/gitter.go @@ -6,11 +6,12 @@ import ( "context" "crypto/sha256" "encoding/hex" - "encoding/json" + "errors" "flag" "fmt" "io" "log/slog" + "math" "net" "net/http" "os" @@ -26,9 +27,15 @@ import ( _ "net/http/pprof" //nolint:gosec // This is a internal only service not public to the internet + "github.com/dgraph-io/ristretto/v2" + "github.com/dustin/go-humanize" "github.com/google/osv.dev/go/logger" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "golang.org/x/sync/singleflight" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + + pb "github.com/google/osv.dev/go/cmd/gitter/pb/repository" ) type contextKey string @@ -43,8 +50,9 @@ const gitStoreFileName = "git-store" // API Endpoints var endpointHandlers = map[string]http.HandlerFunc{ - "GET /git": gitHandler, - "POST /cache": cacheHandler, + "GET /git": gitHandler, + "POST /cache": cacheHandler, + "POST /affected-commits": affectedCommitsHandler, } var ( @@ -55,10 +63,47 @@ var ( gitStorePath = filepath.Join(defaultGitterWorkDir, gitStoreFileName) fetchTimeout time.Duration semaphore chan struct{} // Request concurrency control + // LRU cache for recently loaded repositories (key: repo URL) + repoCache *ristretto.Cache[string, *Repository] + repoTTL time.Duration + repoCacheMaxCostBytes int64 ) +var validURLRegex = regexp.MustCompile(`^(https?|git)://`) + const shutdownTimeout = 10 * time.Second +type SeparatedEvents struct { + Introduced []string + Fixed []string + LastAffected []string + Limit []string +} + +func separateEvents(events []*pb.Event) (*SeparatedEvents, error) { + se := &SeparatedEvents{} + for _, event := range events { + switch event.GetEventType() { + case pb.EventType_INTRODUCED: + se.Introduced = append(se.Introduced, event.GetHash()) + case pb.EventType_FIXED: + se.Fixed = append(se.Fixed, event.GetHash()) + case pb.EventType_LAST_AFFECTED: + se.LastAffected = append(se.LastAffected, event.GetHash()) + case pb.EventType_LIMIT: + se.Limit = append(se.Limit, event.GetHash()) + default: + return nil, fmt.Errorf("invalid event type: %s", event.GetEventType()) + } + } + + if len(se.Limit) > 0 && (len(se.Fixed) > 0 || len(se.LastAffected) > 0) { + return nil, errors.New("limit and fixed/last_affected shouldn't exist in the same request") + } + + return se, nil +} + // repoLocks is a map of per-repository RWMutexes, with url as the key. // It coordinates access between write operations (FetchRepo) that modify the git directory on disk // and read operations (ArchiveRepo, LoadRepository, etc). @@ -72,10 +117,57 @@ func GetRepoLock(url string) *sync.RWMutex { return lock.(*sync.RWMutex) } -// runCmd executes a command with context cancellation handled by sending SIGINT. -// It logs cancellation errors separately as requested. -func runCmd(ctx context.Context, dir string, env []string, name string, args ...string) error { - logger.DebugContext(ctx, "Running command", slog.String("cmd", name), slog.Any("args", args)) +// repoCostBytes is the cost function for a repository in the LRU cache. +// The memory cost of a repository is approximated from the num of commits and a base overhead. +func repoCostBytes(repo *Repository) int64 { + // Mutex (8 bytes), string for repo path (say 128 bytes), root commit (assume 1 root only, 32 bytes) + repoOverhead := 168 + // Assuming per commit adds: + // - Commit struct (Hash, PatchID, Parent []int of size 1, Refs []string) + // = 20 + 20 + 24 + 8 + 24 + ~= 128 bytes + // - 1 pointer into []*Commit + // = 8 bytes + // - 1 entry in commitGraph ([][]int, assuming linear history) + // = 24 + 8 = 32 bytes + // - 1 entry to hashToIndex (map[SHA1]int) + // = 20 + 8 ~= 32 bytes + // - 1 entry to patchIDToCommits (map[SHA1][]int, assuming all commits are unique) + // = 20 + 24 + 8 ~= 64 bytes + // TOTAL: 264 bytes -> We round up to 300 for some buffer + costPerCommit := 300 + + return int64(repoOverhead + len(repo.commits)*costPerCommit) +} + +// General guidance is to make NumCounters 10x the cache capacity (in terms of items) +// We're assuming the cache will hold 5000 repositories +const numCounters = int64(10 * 5000) + +// InitRepoCache initializes the LRU cache for repositories. +func InitRepoCache() { + var err error + repoCache, err = ristretto.NewCache(&ristretto.Config[string, *Repository]{ + NumCounters: numCounters, + MaxCost: repoCacheMaxCostBytes, + BufferItems: 64, + Cost: repoCostBytes, + // Check for TTL expiry every 60 seconds + TtlTickerDurationInSec: 60, + }) + if err != nil { + logger.FatalContext(context.Background(), "Failed to initialize repository cache", slog.Any("err", err)) + } +} + +// CloseRepoCache closes the LRU cache. +func CloseRepoCache() { + if repoCache != nil { + repoCache.Close() + } +} + +// prepareCmd prepares the command with context cancellation handled by sending SIGINT. +func prepareCmd(ctx context.Context, dir string, env []string, name string, args ...string) *exec.Cmd { cmd := exec.CommandContext(ctx, name, args...) if dir != "" { cmd.Dir = dir @@ -91,6 +183,14 @@ func runCmd(ctx context.Context, dir string, env []string, name string, args ... // Ensure it eventually dies if it ignores SIGINT cmd.WaitDelay = shutdownTimeout / 2 + return cmd +} + +// runCmd executes a command with context cancellation handled by sending SIGINT. +// It logs cancellation errors separately as requested. +func runCmd(ctx context.Context, dir string, env []string, name string, args ...string) error { + logger.DebugContext(ctx, "Running command", slog.String("cmd", name), slog.Any("args", args)) + cmd := prepareCmd(ctx, dir, env, name, args...) out, err := cmd.CombinedOutput() if err != nil { if ctx.Err() != nil { @@ -106,26 +206,6 @@ func runCmd(ctx context.Context, dir string, env []string, name string, args ... return nil } -// prepareCmd prepares the command with context cancellation handled by sending SIGINT. -func prepareCmd(ctx context.Context, dir string, env []string, name string, args ...string) *exec.Cmd { - cmd := exec.CommandContext(ctx, name, args...) - if dir != "" { - cmd.Dir = dir - } - if len(env) > 0 { - cmd.Env = append(os.Environ(), env...) - } - // Use SIGINT instead of SIGKILL for graceful shutdown of subprocesses - cmd.Cancel = func() error { - logger.DebugContext(ctx, "SIGINT sent to command", slog.String("cmd", name), slog.Any("args", args)) - return cmd.Process.Signal(syscall.SIGINT) - } - // Ensure it eventually dies if it ignores SIGINT - cmd.WaitDelay = shutdownTimeout / 2 - - return cmd -} - func isLocalRequest(r *http.Request) bool { host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { @@ -143,6 +223,21 @@ func isLocalRequest(r *http.Request) bool { return ip.IsLoopback() } +func validateURL(r *http.Request, url string) error { + if url == "" { + return errors.New("missing url parameter") + } + // If request came from a local ip, don't do the check + if !isLocalRequest(r) { + // Check if url starts with protocols: http(s)://, git:// + if !validURLRegex.MatchString(url) { + return errors.New("invalid url parameter") + } + } + + return nil +} + func getRepoDirName(url string) string { base := path.Base(url) base = filepath.Base(base) @@ -168,6 +263,88 @@ func isIndexLockError(err error) bool { return strings.Contains(errString, "index.lock") && strings.Contains(errString, "File exists") } +// Helper function to unmarshal request body based on Content-Type (protobuf or JSON) +func unmarshalRequest(r *http.Request, body proto.Message) error { + data, err := io.ReadAll(r.Body) + if err != nil { + return err + } + defer r.Body.Close() + + contentType := r.Header.Get("Content-Type") + if contentType == "application/json" { + return protojson.Unmarshal(data, body) + } + // Default to protobuf + return proto.Unmarshal(data, body) +} + +// Helper function to marshal response body based on Content-Type (protobuf or JSON) +func marshalResponse(r *http.Request, m proto.Message) ([]byte, error) { + contentType := r.Header.Get("Content-Type") + if contentType == "application/json" { + return protojson.Marshal(m) + } + // Default to protobuf + return proto.Marshal(m) +} + +func doFetch(ctx context.Context, w http.ResponseWriter, url string, forceUpdate bool) error { + _, err, _ := gFetch.Do(url, func() (any, error) { + return nil, FetchRepo(ctx, url, forceUpdate) + }) + if err != nil { + logger.ErrorContext(ctx, "Error fetching blob", slog.Any("error", err)) + if isAuthError(err) { + http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusForbidden) + } else { + http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusInternalServerError) + } + + return err + } + + return nil +} + +// getFreshRepo handles fetching and loading of a repository +// If forceUpdate is true, it will always refetch and rebuild the repository (commit graph, patch ID, etc) +// Otherwise, it will use a cache if available +func getFreshRepo(ctx context.Context, w http.ResponseWriter, url string, forceUpdate bool) (*Repository, error) { + repoDirName := getRepoDirName(url) + repoPath := filepath.Join(gitStorePath, repoDirName) + + if !forceUpdate { + if repo, ok := repoCache.Get(url); ok { + // repoCache.Get() will not return expired items, so we can safely return the repo + logger.DebugContext(ctx, "Repository already in cache, skipping fetch and load") + return repo, nil + } + } + + if err := doFetch(ctx, w, url, forceUpdate); err != nil { + return nil, err + } + + repoAny, err, _ := gLoad.Do(repoPath, func() (any, error) { + repoLock := GetRepoLock(url) + repoLock.RLock() + defer repoLock.RUnlock() + + return LoadRepository(ctx, repoPath) + }) + if err != nil { + logger.ErrorContext(ctx, "Failed to load repository", slog.Any("error", err)) + http.Error(w, fmt.Sprintf("Failed to load repository: %v", err), http.StatusInternalServerError) + + return nil, err + } + repo := repoAny.(*Repository) + repoCache.SetWithTTL(url, repo, 0, repoTTL) + + return repo, nil +} + func FetchRepo(ctx context.Context, url string, forceUpdate bool) error { logger.InfoContext(ctx, "Starting fetch repo") start := time.Now() @@ -188,6 +365,7 @@ func FetchRepo(ctx context.Context, url string, forceUpdate bool) error { logger.InfoContext(ctx, "Fetching git blob", slog.Duration("sinceAccessTime", time.Since(accessTime))) if _, err := os.Stat(filepath.Join(repoPath, ".git")); os.IsNotExist(err) { // Clone + logger.InfoContext(ctx, "Cloning git repository", slog.Duration("sinceAccessTime", time.Since(accessTime))) err := runCmd(ctx, "", []string{"GIT_TERMINAL_PROMPT=0"}, "git", "clone", "--", url, repoPath) if err != nil { return fmt.Errorf("git clone failed: %w", err) @@ -196,6 +374,7 @@ func FetchRepo(ctx context.Context, url string, forceUpdate bool) error { // Fetch/Pull - implementing simple git pull for now, might need reset --hard if we want exact mirrors // For a generic "get latest", pull is usually sufficient if we treat it as read-only. // Ideally safely: git fetch origin && git reset --hard origin/HEAD + logger.InfoContext(ctx, "Fetching git repository", slog.Duration("sinceAccessTime", time.Since(accessTime))) err := runCmd(ctx, repoPath, nil, "git", "fetch", "origin") if err != nil { return fmt.Errorf("git fetch failed: %w", err) @@ -285,17 +464,29 @@ func main() { workDir := flag.String("work-dir", defaultGitterWorkDir, "Work directory") flag.DurationVar(&fetchTimeout, "fetch-timeout", time.Hour, "Fetch timeout duration") concurrentLimit := flag.Int("concurrent-limit", 100, "Concurrent limit for unique requests") + flag.DurationVar(&repoTTL, "repo-cache-ttl", time.Hour, "Repository LRU cache time-to-live duration") + repoMaxCostStr := flag.String("repo-cache-max-cost", "1GiB", "Repository LRU cache max cost (in bytes)") flag.Parse() semaphore = make(chan struct{}, *concurrentLimit) persistencePath = filepath.Join(*workDir, persistenceFileName) gitStorePath = filepath.Join(*workDir, gitStoreFileName) - if err := os.MkdirAll(gitStorePath, 0755); err != nil { logger.Fatal("Failed to create git store path", slog.String("path", gitStorePath), slog.Any("error", err)) } + repoMaxCostUint, err := humanize.ParseBytes(*repoMaxCostStr) + if err != nil { + logger.Fatal("Failed to parse repo cache max cost", slog.String("maxCost", *repoMaxCostStr), slog.Any("error", err)) + } + if repoMaxCostUint > math.MaxInt64 { + logger.Fatal("Repo cache max cost too large", slog.Uint64("maxCost", repoMaxCostUint)) + } + repoCacheMaxCostBytes = int64(repoMaxCostUint) + loadLastFetchMap() + InitRepoCache() + defer CloseRepoCache() // Create a context that listens for the interrupt signal from the OS. ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) @@ -354,23 +545,14 @@ func main() { func gitHandler(w http.ResponseWriter, r *http.Request) { url := r.URL.Query().Get("url") - if url == "" { - http.Error(w, "Missing url parameter", http.StatusBadRequest) + if err := validateURL(r, url); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) return } forceUpdate := r.URL.Query().Get("force-update") == "true" ctx := context.WithValue(r.Context(), urlKey, url) - logger.InfoContext(ctx, "Received request: /git", slog.Bool("forceUpdate", forceUpdate), slog.String("remoteAddr", r.RemoteAddr)) - // If request came from a local ip, don't do the check - if !isLocalRequest(r) { - // Check if url starts with protocols: http(s)://, git://, ssh://, (s)ftp:// - if match, _ := regexp.MatchString("^(https?|git|ssh)://", url); !match { - http.Error(w, "Invalid url parameter", http.StatusBadRequest) - return - } - } select { case semaphore <- struct{}{}: @@ -384,30 +566,11 @@ func gitHandler(w http.ResponseWriter, r *http.Request) { logger.DebugContext(ctx, "Concurrent requests", slog.Int("count", len(semaphore))) // Fetch repo first - // Keep the key as the url regardless of forceUpdate. - // Occasionally this could be problematic if an existing unforce updated - // query is already inplace, no force update will happen. - // That is highly unlikely in our use case, as importer only queries - // the repo once, and always with force update. - // This is a tradeoff for simplicity to avoid having to setup locks per repo. - // I can't change singleflight's interface - _, err, _ := gFetch.Do(url, func() (any, error) { - return nil, FetchRepo(ctx, url, forceUpdate) - }) - if err != nil { - logger.ErrorContext(ctx, "Error fetching blob", slog.Any("error", err)) - if isAuthError(err) { - http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusForbidden) - - return - } - http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusInternalServerError) - + if err := doFetch(ctx, w, url, forceUpdate); err != nil { return } // Archive repo - // I can't change singleflight's interface fileDataAny, err, _ := gArchive.Do(url, func() (any, error) { return ArchiveRepo(ctx, url) }) @@ -434,20 +597,18 @@ func gitHandler(w http.ResponseWriter, r *http.Request) { func cacheHandler(w http.ResponseWriter, r *http.Request) { start := time.Now() - // POST requets body processing - var body struct { - URL string `json:"url"` - ForceUpdate bool `json:"force_update"` + body := &pb.CacheRequest{} + if err := unmarshalRequest(r, body); err != nil { + http.Error(w, fmt.Sprintf("Error unmarshaling request: %v", err), http.StatusBadRequest) + return } - err := json.NewDecoder(r.Body).Decode(&body) - if err != nil { - http.Error(w, fmt.Sprintf("Error decoding JSON: %v", err), http.StatusBadRequest) + url := body.GetUrl() + if err := validateURL(r, url); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) return } - defer r.Body.Close() - url := body.URL ctx := context.WithValue(r.Context(), urlKey, url) logger.InfoContext(ctx, "Received request: /cache") @@ -462,40 +623,92 @@ func cacheHandler(w http.ResponseWriter, r *http.Request) { } logger.DebugContext(ctx, "Concurrent requests", slog.Int("count", len(semaphore))) - // Fetch repo if it's not fresh - // I can't change singleflight's interface - if _, err, _ := gFetch.Do(url, func() (any, error) { - return nil, FetchRepo(ctx, url, body.ForceUpdate) - }); err != nil { - logger.ErrorContext(ctx, "Error fetching blob", slog.Any("error", err)) - if isAuthError(err) { - http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusForbidden) + if _, err := getFreshRepo(ctx, w, url, body.GetForceUpdate()); err != nil { + return + } - return - } - http.Error(w, fmt.Sprintf("Error fetching blob: %v", err), http.StatusInternalServerError) + w.WriteHeader(http.StatusOK) + logger.InfoContext(ctx, "Request completed successfully: /cache", slog.Duration("duration", time.Since(start))) +} +func affectedCommitsHandler(w http.ResponseWriter, r *http.Request) { + start := time.Now() + body := &pb.AffectedCommitsRequest{} + if err := unmarshalRequest(r, body); err != nil { + http.Error(w, fmt.Sprintf("Error unmarshaling request: %v", err), http.StatusBadRequest) return } - repoDirName := getRepoDirName(url) - repoPath := filepath.Join(gitStorePath, repoDirName) + url := body.GetUrl() + if err := validateURL(r, url); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } - // I can't change singleflight's interface - _, err, _ = gLoad.Do(repoPath, func() (any, error) { - repoLock := GetRepoLock(url) - repoLock.RLock() - defer repoLock.RUnlock() + se, err := separateEvents(body.GetEvents()) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } - return LoadRepository(ctx, repoPath) - }) + cherrypickIntro := body.GetDetectCherrypicksIntroduced() + cherrypickFixed := body.GetDetectCherrypicksFixed() + cherrypickLimit := body.GetDetectCherrypicksLimit() + + ctx := context.WithValue(r.Context(), urlKey, url) + logger.InfoContext(ctx, "Received request: /affected-commits", slog.Any("introduced", se.Introduced), slog.Any("fixed", se.Fixed), slog.Any("last_affected", se.LastAffected), slog.Any("limit", se.Limit), slog.Bool("cherrypickIntro", cherrypickIntro), slog.Bool("cherrypickFixed", cherrypickFixed), slog.Bool("cherrypickLimit", cherrypickLimit)) + + select { + case semaphore <- struct{}{}: + defer func() { <-semaphore }() + case <-ctx.Done(): + logger.WarnContext(ctx, "Request cancelled while waiting for semaphore") + http.Error(w, "Server context cancelled", http.StatusServiceUnavailable) + + return + } + logger.DebugContext(ctx, "Concurrent requests", slog.Int("count", len(semaphore))) + + repo, err := getFreshRepo(ctx, w, url, body.GetForceUpdate()) if err != nil { - logger.ErrorContext(ctx, "Failed to load repository", slog.Any("error", err)) - http.Error(w, fmt.Sprintf("Failed to load repository: %v", err), http.StatusInternalServerError) + return + } + + var affectedCommits []*Commit + if len(se.Limit) > 0 { + affectedCommits = repo.Limit(ctx, se, cherrypickIntro, cherrypickLimit) + } else { + affectedCommits = repo.Affected(ctx, se, cherrypickIntro, cherrypickFixed) + } + + resp := &pb.AffectedCommitsResponse{ + Commits: make([]*pb.AffectedCommit, 0, len(affectedCommits)), + Refs: make([]*pb.AffectedRefs, 0), + } + for _, c := range affectedCommits { + resp.Commits = append(resp.Commits, &pb.AffectedCommit{ + Hash: c.Hash[:], + }) + for _, ref := range c.Refs { + resp.Refs = append(resp.Refs, &pb.AffectedRefs{ + Ref: ref, + Hash: c.Hash[:], + }) + } + } + + out, err := marshalResponse(r, resp) + if err != nil { + logger.ErrorContext(ctx, "Error marshaling affected commits", slog.Any("error", err)) + http.Error(w, fmt.Sprintf("Error marshaling affected commits: %v", err), http.StatusInternalServerError) return } + w.Header().Set("Content-Type", r.Header.Get("Content-Type")) w.WriteHeader(http.StatusOK) - logger.InfoContext(ctx, "Request completed successfully: /cache", slog.Duration("duration", time.Since(start))) + if _, err := w.Write(out); err != nil { + logger.ErrorContext(ctx, "Error writing response", slog.Any("error", err)) + } + logger.InfoContext(ctx, "Request completed successfully: /affected-commits", slog.Duration("duration", time.Since(start))) } diff --git a/go/cmd/gitter/gitter_test.go b/go/cmd/gitter/gitter_test.go index 770ba200865..3eb915f602e 100644 --- a/go/cmd/gitter/gitter_test.go +++ b/go/cmd/gitter/gitter_test.go @@ -2,12 +2,17 @@ package main import ( "bytes" - "encoding/json" + "encoding/hex" "errors" "net/http" "net/http/httptest" "testing" "time" + + "github.com/google/go-cmp/cmp" + pb "github.com/google/osv.dev/go/cmd/gitter/pb/repository" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" ) func TestGetRepoDirName(t *testing.T) { @@ -79,13 +84,24 @@ func TestGitHandler_InvalidURL(t *testing.T) { } } +func resetSaveTimer() { + lastFetchMu.Lock() + defer lastFetchMu.Unlock() + if saveTimer != nil { + saveTimer.Stop() + saveTimer = nil + } +} + // Override global variables for test // Note: In a real app we might want to dependency inject these, // but for this simple script we modify package globals. func setupTest(t *testing.T) { t.Helper() - tmpDir := t.TempDir() + resetSaveTimer() + + tmpDir := t.TempDir() gitStorePath = tmpDir persistencePath = tmpDir + "/last-fetch.json" // Use simple path join for test fetchTimeout = time.Minute @@ -98,11 +114,7 @@ func setupTest(t *testing.T) { // Initialize semaphore for tests semaphore = make(chan struct{}, 100) - // Stop any existing timer - if saveTimer != nil { - saveTimer.Stop() - saveTimer = nil - } + t.Cleanup(resetSaveTimer) } func TestGitHandler_Integration(t *testing.T) { @@ -172,8 +184,10 @@ func TestCacheHandler(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - body, _ := json.Marshal(map[string]string{"url": tt.url}) + reqProto := &pb.CacheRequest{Url: tt.url} + body, _ := protojson.Marshal(reqProto) req, err := http.NewRequest(http.MethodPost, "/cache", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") if err != nil { t.Fatal(err) } @@ -187,3 +201,119 @@ func TestCacheHandler(t *testing.T) { }) } } + +func TestAffectedCommitsHandler(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + setupTest(t) + + tests := []struct { + name string + url string + introduced []string + fixed []string + lastAffected []string + limit []string + invalidType []string + expectedCode int + expectedBody []string + }{ + { + name: "Valid range in public repo", + url: "https://github.com/google/oss-fuzz-vulns.git", + introduced: []string{"3350c55f9525cb83fc3e0b61bde076433c2da8dc"}, + fixed: []string{"8920ed8e47c660a0c20c28cb1004a600780c5b59"}, + expectedCode: http.StatusOK, + expectedBody: []string{"3350c55f9525cb83fc3e0b61bde076433c2da8dc"}, + }, + { + name: "Invalid mixed limit and fixed", + url: "https://github.com/google/oss-fuzz-vulns.git", + introduced: []string{"3350c55f9525cb83fc3e0b61bde076433c2da8dc"}, + fixed: []string{"8920ed8e47c660a0c20c28cb1004a600780c5b59"}, + limit: []string{"996962b987c856bf751948e55b9366751e806c64"}, + expectedCode: http.StatusBadRequest, + }, + { + name: "Non-existent repo", + url: "https://github.com/google/this-repo-does-not-exist-12345.git", + introduced: []string{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, + expectedCode: http.StatusForbidden, + }, + { + name: "Invalid event type", + url: "https://github.com/google/oss-fuzz-vulns.git", + invalidType: []string{"3350c55f9525cb83fc3e0b61bde076433c2da8dc"}, + expectedCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var events []*pb.Event + for _, h := range tt.introduced { + events = append(events, &pb.Event{EventType: pb.EventType_INTRODUCED, Hash: h}) + } + for _, h := range tt.fixed { + events = append(events, &pb.Event{EventType: pb.EventType_FIXED, Hash: h}) + } + for _, h := range tt.lastAffected { + events = append(events, &pb.Event{EventType: pb.EventType_LAST_AFFECTED, Hash: h}) + } + for _, h := range tt.limit { + events = append(events, &pb.Event{EventType: pb.EventType_LIMIT, Hash: h}) + } + for _, h := range tt.invalidType { + events = append(events, &pb.Event{EventType: 999, Hash: h}) + } + + reqProto := &pb.AffectedCommitsRequest{ + Url: tt.url, + Events: events, + } + + body, _ := protojson.Marshal(reqProto) + req, err := http.NewRequest(http.MethodPost, "/affected-commits", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + affectedCommitsHandler(rr, req) + + if status := rr.Code; status != tt.expectedCode { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tt.expectedCode) + } + + if tt.expectedBody == nil { + return + } + + respBody := &pb.AffectedCommitsResponse{} + if rr.Header().Get("Content-Type") == "application/json" { + if err := protojson.Unmarshal(rr.Body.Bytes(), respBody); err != nil { + t.Fatalf("Failed to unmarshal JSON response: %v", err) + } + } else { + if err := proto.Unmarshal(rr.Body.Bytes(), respBody); err != nil { + t.Fatalf("Failed to unmarshal proto response: %v", err) + } + } + + var gotHashes []string + for _, c := range respBody.GetCommits() { + gotHashes = append(gotHashes, hex.EncodeToString(c.GetHash())) + } + if gotHashes == nil { + gotHashes = []string{} + } + + if diff := cmp.Diff(tt.expectedBody, gotHashes); diff != "" { + t.Errorf("handler returned wrong body (-want +got):\n%s", diff) + } + }) + } +} diff --git a/go/cmd/gitter/pb/repository/repository.pb.go b/go/cmd/gitter/pb/repository/repository.pb.go index 8d30547714c..d8e9ad5de98 100644 --- a/go/cmd/gitter/pb/repository/repository.pb.go +++ b/go/cmd/gitter/pb/repository/repository.pb.go @@ -7,12 +7,11 @@ package repository import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" unsafe "unsafe" - - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( @@ -22,6 +21,58 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type EventType int32 + +const ( + EventType_INTRODUCED EventType = 0 + EventType_FIXED EventType = 1 + EventType_LAST_AFFECTED EventType = 2 + EventType_LIMIT EventType = 3 +) + +// Enum value maps for EventType. +var ( + EventType_name = map[int32]string{ + 0: "INTRODUCED", + 1: "FIXED", + 2: "LAST_AFFECTED", + 3: "LIMIT", + } + EventType_value = map[string]int32{ + "INTRODUCED": 0, + "FIXED": 1, + "LAST_AFFECTED": 2, + "LIMIT": 3, + } +) + +func (x EventType) Enum() *EventType { + p := new(EventType) + *p = x + return p +} + +func (x EventType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (EventType) Descriptor() protoreflect.EnumDescriptor { + return file_repository_proto_enumTypes[0].Descriptor() +} + +func (EventType) Type() protoreflect.EnumType { + return &file_repository_proto_enumTypes[0] +} + +func (x EventType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use EventType.Descriptor instead. +func (EventType) EnumDescriptor() ([]byte, []int) { + return file_repository_proto_rawDescGZIP(), []int{0} +} + type CommitDetail struct { state protoimpl.MessageState `protogen:"open.v1"` Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` @@ -120,6 +171,342 @@ func (x *RepositoryCache) GetCommits() []*CommitDetail { return nil } +type AffectedCommit struct { + state protoimpl.MessageState `protogen:"open.v1"` + Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AffectedCommit) Reset() { + *x = AffectedCommit{} + mi := &file_repository_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AffectedCommit) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AffectedCommit) ProtoMessage() {} + +func (x *AffectedCommit) ProtoReflect() protoreflect.Message { + mi := &file_repository_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AffectedCommit.ProtoReflect.Descriptor instead. +func (*AffectedCommit) Descriptor() ([]byte, []int) { + return file_repository_proto_rawDescGZIP(), []int{2} +} + +func (x *AffectedCommit) GetHash() []byte { + if x != nil { + return x.Hash + } + return nil +} + +type AffectedRefs struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ref string `protobuf:"bytes,1,opt,name=ref,proto3" json:"ref,omitempty"` + Hash []byte `protobuf:"bytes,2,opt,name=hash,proto3" json:"hash,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AffectedRefs) Reset() { + *x = AffectedRefs{} + mi := &file_repository_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AffectedRefs) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AffectedRefs) ProtoMessage() {} + +func (x *AffectedRefs) ProtoReflect() protoreflect.Message { + mi := &file_repository_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AffectedRefs.ProtoReflect.Descriptor instead. +func (*AffectedRefs) Descriptor() ([]byte, []int) { + return file_repository_proto_rawDescGZIP(), []int{3} +} + +func (x *AffectedRefs) GetRef() string { + if x != nil { + return x.Ref + } + return "" +} + +func (x *AffectedRefs) GetHash() []byte { + if x != nil { + return x.Hash + } + return nil +} + +type AffectedCommitsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Commits []*AffectedCommit `protobuf:"bytes,1,rep,name=commits,proto3" json:"commits,omitempty"` + Refs []*AffectedRefs `protobuf:"bytes,2,rep,name=refs,proto3" json:"refs,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AffectedCommitsResponse) Reset() { + *x = AffectedCommitsResponse{} + mi := &file_repository_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AffectedCommitsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AffectedCommitsResponse) ProtoMessage() {} + +func (x *AffectedCommitsResponse) ProtoReflect() protoreflect.Message { + mi := &file_repository_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AffectedCommitsResponse.ProtoReflect.Descriptor instead. +func (*AffectedCommitsResponse) Descriptor() ([]byte, []int) { + return file_repository_proto_rawDescGZIP(), []int{4} +} + +func (x *AffectedCommitsResponse) GetCommits() []*AffectedCommit { + if x != nil { + return x.Commits + } + return nil +} + +func (x *AffectedCommitsResponse) GetRefs() []*AffectedRefs { + if x != nil { + return x.Refs + } + return nil +} + +type Event struct { + state protoimpl.MessageState `protogen:"open.v1"` + EventType EventType `protobuf:"varint,1,opt,name=event_type,json=eventType,proto3,enum=gitter.EventType" json:"event_type,omitempty"` + Hash string `protobuf:"bytes,2,opt,name=hash,proto3" json:"hash,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Event) Reset() { + *x = Event{} + mi := &file_repository_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Event) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Event) ProtoMessage() {} + +func (x *Event) ProtoReflect() protoreflect.Message { + mi := &file_repository_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Event.ProtoReflect.Descriptor instead. +func (*Event) Descriptor() ([]byte, []int) { + return file_repository_proto_rawDescGZIP(), []int{5} +} + +func (x *Event) GetEventType() EventType { + if x != nil { + return x.EventType + } + return EventType_INTRODUCED +} + +func (x *Event) GetHash() string { + if x != nil { + return x.Hash + } + return "" +} + +type CacheRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` + ForceUpdate bool `protobuf:"varint,2,opt,name=force_update,json=forceUpdate,proto3" json:"force_update,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CacheRequest) Reset() { + *x = CacheRequest{} + mi := &file_repository_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CacheRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CacheRequest) ProtoMessage() {} + +func (x *CacheRequest) ProtoReflect() protoreflect.Message { + mi := &file_repository_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CacheRequest.ProtoReflect.Descriptor instead. +func (*CacheRequest) Descriptor() ([]byte, []int) { + return file_repository_proto_rawDescGZIP(), []int{6} +} + +func (x *CacheRequest) GetUrl() string { + if x != nil { + return x.Url + } + return "" +} + +func (x *CacheRequest) GetForceUpdate() bool { + if x != nil { + return x.ForceUpdate + } + return false +} + +type AffectedCommitsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Url string `protobuf:"bytes,1,opt,name=url,proto3" json:"url,omitempty"` + Events []*Event `protobuf:"bytes,2,rep,name=events,proto3" json:"events,omitempty"` + DetectCherrypicksIntroduced bool `protobuf:"varint,3,opt,name=detect_cherrypicks_introduced,json=detectCherrypicksIntroduced,proto3" json:"detect_cherrypicks_introduced,omitempty"` + DetectCherrypicksFixed bool `protobuf:"varint,4,opt,name=detect_cherrypicks_fixed,json=detectCherrypicksFixed,proto3" json:"detect_cherrypicks_fixed,omitempty"` + DetectCherrypicksLimit bool `protobuf:"varint,5,opt,name=detect_cherrypicks_limit,json=detectCherrypicksLimit,proto3" json:"detect_cherrypicks_limit,omitempty"` + ForceUpdate bool `protobuf:"varint,6,opt,name=force_update,json=forceUpdate,proto3" json:"force_update,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AffectedCommitsRequest) Reset() { + *x = AffectedCommitsRequest{} + mi := &file_repository_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AffectedCommitsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AffectedCommitsRequest) ProtoMessage() {} + +func (x *AffectedCommitsRequest) ProtoReflect() protoreflect.Message { + mi := &file_repository_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AffectedCommitsRequest.ProtoReflect.Descriptor instead. +func (*AffectedCommitsRequest) Descriptor() ([]byte, []int) { + return file_repository_proto_rawDescGZIP(), []int{7} +} + +func (x *AffectedCommitsRequest) GetUrl() string { + if x != nil { + return x.Url + } + return "" +} + +func (x *AffectedCommitsRequest) GetEvents() []*Event { + if x != nil { + return x.Events + } + return nil +} + +func (x *AffectedCommitsRequest) GetDetectCherrypicksIntroduced() bool { + if x != nil { + return x.DetectCherrypicksIntroduced + } + return false +} + +func (x *AffectedCommitsRequest) GetDetectCherrypicksFixed() bool { + if x != nil { + return x.DetectCherrypicksFixed + } + return false +} + +func (x *AffectedCommitsRequest) GetDetectCherrypicksLimit() bool { + if x != nil { + return x.DetectCherrypicksLimit + } + return false +} + +func (x *AffectedCommitsRequest) GetForceUpdate() bool { + if x != nil { + return x.ForceUpdate + } + return false +} + var File_repository_proto protoreflect.FileDescriptor const file_repository_proto_rawDesc = "" + @@ -129,7 +516,35 @@ const file_repository_proto_rawDesc = "" + "\x04hash\x18\x01 \x01(\fR\x04hash\x12\x19\n" + "\bpatch_id\x18\x02 \x01(\fR\apatchId\"A\n" + "\x0fRepositoryCache\x12.\n" + - "\acommits\x18\x01 \x03(\v2\x14.gitter.CommitDetailR\acommitsB\x0eZ\f./repositoryb\x06proto3" + "\acommits\x18\x01 \x03(\v2\x14.gitter.CommitDetailR\acommits\"$\n" + + "\x0eAffectedCommit\x12\x12\n" + + "\x04hash\x18\x01 \x01(\fR\x04hash\"4\n" + + "\fAffectedRefs\x12\x10\n" + + "\x03ref\x18\x01 \x01(\tR\x03ref\x12\x12\n" + + "\x04hash\x18\x02 \x01(\fR\x04hash\"u\n" + + "\x17AffectedCommitsResponse\x120\n" + + "\acommits\x18\x01 \x03(\v2\x16.gitter.AffectedCommitR\acommits\x12(\n" + + "\x04refs\x18\x02 \x03(\v2\x14.gitter.AffectedRefsR\x04refs\"M\n" + + "\x05Event\x120\n" + + "\n" + + "event_type\x18\x01 \x01(\x0e2\x11.gitter.EventTypeR\teventType\x12\x12\n" + + "\x04hash\x18\x02 \x01(\tR\x04hash\"C\n" + + "\fCacheRequest\x12\x10\n" + + "\x03url\x18\x01 \x01(\tR\x03url\x12!\n" + + "\fforce_update\x18\x02 \x01(\bR\vforceUpdate\"\xac\x02\n" + + "\x16AffectedCommitsRequest\x12\x10\n" + + "\x03url\x18\x01 \x01(\tR\x03url\x12%\n" + + "\x06events\x18\x02 \x03(\v2\r.gitter.EventR\x06events\x12B\n" + + "\x1ddetect_cherrypicks_introduced\x18\x03 \x01(\bR\x1bdetectCherrypicksIntroduced\x128\n" + + "\x18detect_cherrypicks_fixed\x18\x04 \x01(\bR\x16detectCherrypicksFixed\x128\n" + + "\x18detect_cherrypicks_limit\x18\x05 \x01(\bR\x16detectCherrypicksLimit\x12!\n" + + "\fforce_update\x18\x06 \x01(\bR\vforceUpdate*D\n" + + "\tEventType\x12\x0e\n" + + "\n" + + "INTRODUCED\x10\x00\x12\t\n" + + "\x05FIXED\x10\x01\x12\x11\n" + + "\rLAST_AFFECTED\x10\x02\x12\t\n" + + "\x05LIMIT\x10\x03B\x0eZ\f./repositoryb\x06proto3" var ( file_repository_proto_rawDescOnce sync.Once @@ -143,18 +558,30 @@ func file_repository_proto_rawDescGZIP() []byte { return file_repository_proto_rawDescData } -var file_repository_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_repository_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_repository_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_repository_proto_goTypes = []any{ - (*CommitDetail)(nil), // 0: gitter.CommitDetail - (*RepositoryCache)(nil), // 1: gitter.RepositoryCache + (EventType)(0), // 0: gitter.EventType + (*CommitDetail)(nil), // 1: gitter.CommitDetail + (*RepositoryCache)(nil), // 2: gitter.RepositoryCache + (*AffectedCommit)(nil), // 3: gitter.AffectedCommit + (*AffectedRefs)(nil), // 4: gitter.AffectedRefs + (*AffectedCommitsResponse)(nil), // 5: gitter.AffectedCommitsResponse + (*Event)(nil), // 6: gitter.Event + (*CacheRequest)(nil), // 7: gitter.CacheRequest + (*AffectedCommitsRequest)(nil), // 8: gitter.AffectedCommitsRequest } var file_repository_proto_depIdxs = []int32{ - 0, // 0: gitter.RepositoryCache.commits:type_name -> gitter.CommitDetail - 1, // [1:1] is the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 1, // 0: gitter.RepositoryCache.commits:type_name -> gitter.CommitDetail + 3, // 1: gitter.AffectedCommitsResponse.commits:type_name -> gitter.AffectedCommit + 4, // 2: gitter.AffectedCommitsResponse.refs:type_name -> gitter.AffectedRefs + 0, // 3: gitter.Event.event_type:type_name -> gitter.EventType + 6, // 4: gitter.AffectedCommitsRequest.events:type_name -> gitter.Event + 5, // [5:5] is the sub-list for method output_type + 5, // [5:5] is the sub-list for method input_type + 5, // [5:5] is the sub-list for extension type_name + 5, // [5:5] is the sub-list for extension extendee + 0, // [0:5] is the sub-list for field type_name } func init() { file_repository_proto_init() } @@ -167,13 +594,14 @@ func file_repository_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_repository_proto_rawDesc), len(file_repository_proto_rawDesc)), - NumEnums: 0, - NumMessages: 2, + NumEnums: 1, + NumMessages: 8, NumExtensions: 0, NumServices: 0, }, GoTypes: file_repository_proto_goTypes, DependencyIndexes: file_repository_proto_depIdxs, + EnumInfos: file_repository_proto_enumTypes, MessageInfos: file_repository_proto_msgTypes, }.Build() File_repository_proto = out.File diff --git a/go/cmd/gitter/pb/repository/repository.proto b/go/cmd/gitter/pb/repository/repository.proto index 2a128251da6..26ea0d29657 100644 --- a/go/cmd/gitter/pb/repository/repository.proto +++ b/go/cmd/gitter/pb/repository/repository.proto @@ -14,3 +14,43 @@ message CommitDetail { message RepositoryCache { repeated CommitDetail commits = 1; } + +message AffectedCommit { + bytes hash = 1; +} + +message AffectedRefs { + string ref = 1; + bytes hash = 2; +} + +message AffectedCommitsResponse { + repeated AffectedCommit commits = 1; + repeated AffectedRefs refs = 2; +} + +enum EventType { + INTRODUCED = 0; + FIXED = 1; + LAST_AFFECTED = 2; + LIMIT = 3; +} + +message Event { + EventType event_type = 1; + string hash = 2; +} + +message CacheRequest { + string url = 1; + bool force_update = 2; +} + +message AffectedCommitsRequest { + string url = 1; + repeated Event events = 2; + bool detect_cherrypicks_introduced = 3; + bool detect_cherrypicks_fixed = 4; + bool detect_cherrypicks_limit = 5; + bool force_update = 6; +} diff --git a/go/cmd/gitter/persistence.go b/go/cmd/gitter/persistence.go index 67f06859e67..67e83d31a69 100644 --- a/go/cmd/gitter/persistence.go +++ b/go/cmd/gitter/persistence.go @@ -90,7 +90,7 @@ func saveRepositoryCache(cachePath string, repo *Repository) error { logger.Info("Saving repository cache", slog.String("path", cachePath)) cache := &pb.RepositoryCache{} - for _, commit := range repo.commitDetails { + for _, commit := range repo.commits { cache.Commits = append(cache.Commits, &pb.CommitDetail{ Hash: commit.Hash[:], PatchId: commit.PatchID[:], diff --git a/go/cmd/gitter/repository.go b/go/cmd/gitter/repository.go index 1c23f768694..0b26dde611f 100644 --- a/go/cmd/gitter/repository.go +++ b/go/cmd/gitter/repository.go @@ -7,7 +7,9 @@ import ( "errors" "fmt" "log/slog" + "maps" "os" + "slices" "strings" "sync" "time" @@ -20,10 +22,10 @@ import ( type SHA1 [20]byte type Commit struct { - Hash SHA1 `json:"hash"` - PatchID SHA1 `json:"patch_id"` - Parents []SHA1 `json:"parents"` - Tags []string `json:"tags"` + Hash SHA1 + PatchID SHA1 + Parents []int + Refs []string } // Repository holds the commit graph and other details for a git repository. @@ -32,14 +34,19 @@ type Repository struct { patchIDMu sync.Mutex // Path to the .git directory within gitter's working dir repoPath string - // Adjacency list: Parent -> []Children - commitGraph map[SHA1][]SHA1 - // Actual commit details - commitDetails map[SHA1]*Commit - // Store tags to commit because it's useful for CVE conversion - tagToCommit map[string]SHA1 - // For cherry-pick detection: PatchID -> []commit hash - patchIDToCommits map[SHA1][]SHA1 + // All commits in the repository (the array index is used as the commit index below) + commits []*Commit + // Adjacency list: Parent index -> []Children indexes + commitGraph [][]int + // Map of commit hash to its index in the commits slice + hashToIndex map[SHA1]int + // Store refs to commit because it's useful for CVE conversion + refToCommit map[string]int + // For cherry-pick detection: PatchID -> []commit indexes + patchIDToCommits map[SHA1][]int + // Root commits (commits with no parents) + // In a typical repository this is the initial commit + rootCommits []int } // %H commit hash; %P parent hashes; %D:refs (tab delimited) @@ -53,10 +60,9 @@ var workers = 16 func NewRepository(repoPath string) *Repository { return &Repository{ repoPath: repoPath, - commitGraph: make(map[SHA1][]SHA1), - commitDetails: make(map[SHA1]*Commit), - tagToCommit: make(map[string]SHA1), - patchIDToCommits: make(map[SHA1][]SHA1), + hashToIndex: make(map[SHA1]int), + refToCommit: make(map[string]int), + patchIDToCommits: make(map[SHA1][]int), } } @@ -100,10 +106,27 @@ func LoadRepository(ctx context.Context, repoPath string) (*Repository, error) { return repo, nil } +// getOrCreateIndex returns the index for a given commit hash. +// If the hash is new, it creates a new barebone commit and expands the graph structure to accommodate it. +func (r *Repository) getOrCreateIndex(hash SHA1) int { + // Check if we've already assigned an index to this hash + if idx, ok := r.hashToIndex[hash]; ok { + return idx + } + + idx := len(r.commits) + r.commits = append(r.commits, &Commit{Hash: hash}) + r.hashToIndex[hash] = idx + // Expand the commitGraph (adjacency list) to match the commits slice. + r.commitGraph = append(r.commitGraph, nil) + + return idx +} + // buildCommitGraph builds the commit graph and associate commit details from scratch -// Returns a list of new commit hashes that don't have cached Patch IDs. +// Returns a list of new commit indexes that don't have cached Patch IDs. // The new commit list is in reverse chronological order based on commit date (the default for git log). -func (r *Repository) buildCommitGraph(ctx context.Context, cache *pb.RepositoryCache) ([]SHA1, error) { +func (r *Repository) buildCommitGraph(ctx context.Context, cache *pb.RepositoryCache) ([]int, error) { logger.InfoContext(ctx, "Starting graph construction") start := time.Now() @@ -119,7 +142,7 @@ func (r *Repository) buildCommitGraph(ctx context.Context, cache *pb.RepositoryC } } } - var newCommits []SHA1 + var newCommits []int // Temp outFile for git log output tmpFile, err := os.CreateTemp(r.repoPath, "git-log.out") @@ -158,17 +181,20 @@ func (r *Repository) buildCommitGraph(ctx context.Context, cache *pb.RepositoryC var childHash SHA1 parentHashes := []SHA1{} - tags := []string{} + refs := []string{} switch len(commitInfo) { case 3: // refs are separated by commas - refs := strings.Split(commitInfo[2], ", ") - for _, ref := range refs { - // Remove prefixes from tags, other refs such as HEAD will be left as is - if strings.Contains(ref, "tag: ") { - tags = append(tags, strings.TrimPrefix(ref, "tag: ")) + rawRefs := strings.Split(commitInfo[2], ", ") + for _, ref := range rawRefs { + if ref == "" { + continue } + // Remove prefixes from tags, other refs such as branches will be left as is + ref = strings.TrimPrefix(ref, "tag: ") + ref = strings.TrimPrefix(ref, "HEAD -> ") // clean up HEAD -> branch-name to just keep the branch name + refs = append(refs, ref) } fallthrough @@ -198,32 +224,36 @@ func (r *Repository) buildCommitGraph(ctx context.Context, cache *pb.RepositoryC continue } + childIdx := r.getOrCreateIndex(childHash) + commit := r.commits[childIdx] + commit.Refs = refs + + // We want to keep the root commit (no parent) easily accessible for introduced=0 + if len(parentHashes) == 0 { + r.rootCommits = append(r.rootCommits, childIdx) + } + // Add commit to graph (parent -> []child) for _, parentHash := range parentHashes { - r.commitGraph[parentHash] = append(r.commitGraph[parentHash], childHash) - } + parentIdx := r.getOrCreateIndex(parentHash) + commit.Parents = append(commit.Parents, parentIdx) - commit := Commit{ - Hash: childHash, - Tags: tags, - Parents: parentHashes, + r.commitGraph[parentIdx] = append(r.commitGraph[parentIdx], childIdx) } if patchID, ok := cachedPatchIDs[childHash]; ok { // Assign saved patch ID to commit details and map if found commit.PatchID = patchID // Also populate patchIDToCommits map - r.patchIDToCommits[patchID] = append(r.patchIDToCommits[patchID], childHash) + r.patchIDToCommits[patchID] = append(r.patchIDToCommits[patchID], childIdx) } else { // Add to slice for patch ID to be generated later - newCommits = append(newCommits, childHash) + newCommits = append(newCommits, childIdx) } - r.commitDetails[childHash] = &commit - - // Also populate the tag-to-commit map - for _, tag := range tags { - r.tagToCommit[tag] = childHash + // Also populate the ref-to-commit map + for _, ref := range refs { + r.refToCommit[ref] = childIdx } } @@ -234,7 +264,7 @@ func (r *Repository) buildCommitGraph(ctx context.Context, cache *pb.RepositoryC // calculatePatchIDs calculates patch IDs only for the specific commits provided. // Commits should be passed in order if possible. Processing linear commits sequentially improves performance slightly (in the 'git show' commands). -func (r *Repository) calculatePatchIDs(ctx context.Context, commits []SHA1) error { +func (r *Repository) calculatePatchIDs(ctx context.Context, commits []int) error { logger.InfoContext(ctx, "Starting patch ID calculation") start := time.Now() @@ -272,7 +302,7 @@ func (r *Repository) calculatePatchIDs(ctx context.Context, commits []SHA1) erro // calculatePatchIDsWorker calculates patch IDs and update CommitDetail and patchIDToCommits map. // Essentially running `git show | git patch-id --stable` -func (r *Repository) calculatePatchIDsWorker(ctx context.Context, chunk []SHA1) error { +func (r *Repository) calculatePatchIDsWorker(ctx context.Context, chunk []int) error { // Prepare git commands // `git show --stdin --patch --first-parent --no-color`: // --patch to show diffs in a format that can be directly piped into `git patch-id` @@ -319,11 +349,12 @@ func (r *Repository) calculatePatchIDsWorker(ctx context.Context, chunk []SHA1) // Write hashes to git show stdin go func() { defer in.Close() - for _, hash := range chunk { + for _, idx := range chunk { // Handle context cancel if ctx.Err() != nil { return } + hash := r.commits[idx].Hash fmt.Fprintf(in, "%s\n", hex.EncodeToString(hash[:])) } }() @@ -385,13 +416,267 @@ func (r *Repository) calculatePatchIDsWorker(ctx context.Context, chunk []SHA1) return nil } +// updatePatchID updates the PatchID for a given commit and adds it to the patchIDToCommits map. func (r *Repository) updatePatchID(commitHash, patchID SHA1) { r.patchIDMu.Lock() defer r.patchIDMu.Unlock() - commit := r.commitDetails[commitHash] + idx, ok := r.hashToIndex[commitHash] + if !ok { + // This should never happen because we only call git patch-id on commits we see when building commit graph. + return + } + commit := r.commits[idx] commit.PatchID = patchID - r.commitDetails[commitHash] = commit - r.patchIDToCommits[patchID] = append(r.patchIDToCommits[patchID], commitHash) + r.patchIDToCommits[patchID] = append(r.patchIDToCommits[patchID], idx) +} + +// parseHashes converts a slice of string hashes into a slice of commit indexes. +func (r *Repository) parseHashes(ctx context.Context, hashesStr []string) []int { + indices := make([]int, 0, len(hashesStr)) + addedRoot := false // Only add root commits once if multiple intro=0 are provided + + for _, hash := range hashesStr { + if hash == "0" { + if !addedRoot { + indices = append(indices, r.rootCommits...) + addedRoot = true + } + + continue + } + + hashBytes, err := hex.DecodeString(hash) + // Log error but continue with the rest of the hashes if a commit hash is invalid + if err != nil { + logger.ErrorContext(ctx, "failed to decode commit hash", slog.String("hash", hash), slog.Any("err", err)) + continue + } + if len(hashBytes) != 20 { + logger.ErrorContext(ctx, "invalid hash length", slog.String("hash", hash), slog.Int("len", len(hashBytes))) + continue + } + + h := SHA1(hashBytes) + if idx, ok := r.hashToIndex[h]; ok { + indices = append(indices, idx) + } else { + logger.ErrorContext(ctx, "commit hash not found in repository", slog.String("hash", hash)) + } + } + + return indices +} + +// expandByCherrypick expands a slice of commits by adding commits that have the same Patch ID (cherrypicked commits) returns a new list containing the original commits + any other commits that share the same Patch ID +func (r *Repository) expandByCherrypick(commits []int) []int { + unique := make(map[int]struct{}, len(commits)) // avoid duplication + var zeroPatchID SHA1 + + for _, idx := range commits { + // Find patch ID from commit details + commit := r.commits[idx] + if commit.PatchID == zeroPatchID { + unique[idx] = struct{}{} + continue + } + + // Add equivalent commits with the same Patch ID (including the current commit) + equivalents := r.patchIDToCommits[commit.PatchID] + for _, eq := range equivalents { + unique[eq] = struct{}{} + } + } + + keys := slices.Collect(maps.Keys(unique)) + + return keys +} + +// Affected returns a list of commits that are affected by the given introduced, fixed and last_affected events +// A commit is affected when: from at least one introduced that is an ancestor of the commit, there is no path between them that passes through a fix. +// A fix can either be a fixed commit, or the children of a lastAffected commit. +func (r *Repository) Affected(ctx context.Context, se *SeparatedEvents, cherrypickIntro, cherrypickFixed bool) []*Commit { + logger.InfoContext(ctx, "Starting affected commit walking") + start := time.Now() + + introduced := r.parseHashes(ctx, se.Introduced) + fixed := r.parseHashes(ctx, se.Fixed) + lastAffected := r.parseHashes(ctx, se.LastAffected) + + // Expands the introduced and fixed commits to include cherrypick equivalents + // lastAffected should not be expanded because it does not imply a "fix" commit that can be cherrypicked to other branches + if cherrypickIntro { + introduced = r.expandByCherrypick(introduced) + } + if cherrypickFixed { + fixed = r.expandByCherrypick(fixed) + } + + // Fixed commits and children of last affected are both in this set + // For graph traversal sake they are both considered the fix + fixedMap := make([]bool, len(r.commits)) + + for _, idx := range fixed { + fixedMap[idx] = true + } + + for _, idx := range lastAffected { + if idx < len(r.commitGraph) { + for _, childIdx := range r.commitGraph[idx] { + fixedMap[childIdx] = true + } + } + } + + // The graph traversal + // affectedMap deduplicates the affected commits from the graph walk from each introduced commit + affectedMap := make([]bool, len(r.commits)) + + // Preallocating the big slices, they will be cleared inside the per-intro graph walking + queue := make([]int, 0, len(r.commits)) + affectedFromIntro := make([]bool, len(r.commits)) + updatedIdx := make([]int, 0, len(r.commits)) + unaffectable := make([]bool, len(r.commits)) + visited := make([]bool, len(r.commits)) + + // Walk each introduced commit and find its affected commit + for _, introIdx := range introduced { + // BFS from intro + queue = append(queue, introIdx) + clear(affectedFromIntro) + clear(updatedIdx) + clear(unaffectable) + clear(visited) + + for len(queue) > 0 { + curr := queue[0] + queue = queue[1:] + + if visited[curr] { + continue + } + visited[curr] = true + + // Descendant of a fixed commit + if unaffectable[curr] { + continue + } + + // If we hit a fixed commit, its entire tree is treated as unaffectable + // as any downstream commit can go through this fixed commit to become unaffected + if fixedMap[curr] { + unaffectable[curr] = true + // Inline DFS from current (fixed) node to make all descendants as unaffected / unaffectable + // 1. If a previous path added the descendant to affected list, remove it + // 2. Add to the unaffectable set to block future paths + stack := []int{curr} + for len(stack) > 0 { + unaffected := stack[len(stack)-1] + stack = stack[:len(stack)-1] + + // Remove from affected list if it was reached via a previous non-fixed path. + affectedFromIntro[unaffected] = false + + if unaffected < len(r.commitGraph) { + for _, childIdx := range r.commitGraph[unaffected] { + // Continue down the path if the child isn't already blocked. + if !unaffectable[childIdx] { + unaffectable[childIdx] = true + stack = append(stack, childIdx) + } + } + } + } + + continue + } + + // Otherwise, add to the intro-specific affected list and continue + affectedFromIntro[curr] = true + updatedIdx = append(updatedIdx, curr) + if curr < len(r.commitGraph) { + queue = append(queue, r.commitGraph[curr]...) + } + } + + // Add the final affected list of this introduced commit to the global set + // We only look at the index that are updated in this loop + for _, commitIdx := range updatedIdx { + if affectedFromIntro[commitIdx] { + affectedMap[commitIdx] = true + } + } + } + + // Return the affected commit details + affectedCommits := make([]*Commit, 0) + for idx, affected := range affectedMap { + if affected { + affectedCommits = append(affectedCommits, r.commits[idx]) + } + } + + logger.InfoContext(ctx, "Affected commit walking completed", slog.Duration("duration", time.Since(start))) + + return affectedCommits +} + +// Limit walks and returns the commits that are strictly between introduced (inclusive) and limit (exclusive) +func (r *Repository) Limit(ctx context.Context, se *SeparatedEvents, cherrypickIntro, cherrypickLimit bool) []*Commit { + introduced := r.parseHashes(ctx, se.Introduced) + limit := r.parseHashes(ctx, se.Limit) + + if cherrypickIntro { + introduced = r.expandByCherrypick(introduced) + } + if cherrypickLimit { + limit = r.expandByCherrypick(limit) + } + + var affectedCommits []*Commit + + introMap := make([]bool, len(r.commits)) + for _, idx := range introduced { + introMap[idx] = true + } + + // DFS to walk from limit(s) to introduced (follow first parent) + stack := make([]int, 0, len(limit)) + // Start from limits' parents + for _, idx := range limit { + commit := r.commits[idx] + if len(commit.Parents) > 0 { + stack = append(stack, commit.Parents[0]) + } + } + + visited := make([]bool, len(r.commits)) + + for len(stack) > 0 { + curr := stack[len(stack)-1] + stack = stack[:len(stack)-1] + + if visited[curr] { + continue + } + visited[curr] = true + + // Add current node to affected commits + commit := r.commits[curr] + affectedCommits = append(affectedCommits, commit) + + // If commit is in introduced, we can stop the traversal after adding it to affected + if introMap[curr] { + continue + } + + // In git merge, first parent is the HEAD commit at the time of merge (on the branch that gets merged into) + if len(commit.Parents) > 0 { + stack = append(stack, commit.Parents[0]) + } + } + + return affectedCommits } diff --git a/go/cmd/gitter/repository_test.go b/go/cmd/gitter/repository_test.go index 9ec916c30fc..adbf702ebe2 100644 --- a/go/cmd/gitter/repository_test.go +++ b/go/cmd/gitter/repository_test.go @@ -2,10 +2,15 @@ package main import ( "context" + "encoding/hex" "os" "os/exec" "path/filepath" + "strings" "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) // A very simple test repository with 3 commits and 2 tags. @@ -69,12 +74,13 @@ func TestBuildCommitGraph(t *testing.T) { t.Errorf("expected 3 new commits, got %d", len(newCommits)) } - if len(r.commitDetails) != 3 { - t.Errorf("expected 3 commits with details, got %d", len(r.commitDetails)) + if len(r.commits) != 3 { + t.Errorf("expected 3 commits, got %d", len(r.commits)) } - if len(r.tagToCommit) != 2 { - t.Errorf("expected 2 tags, got %d", len(r.tagToCommit)) + // 2 tags + main branch + if len(r.refToCommit) != 3 { + t.Errorf("expected 3 refs, got %d", len(r.refToCommit)) } } @@ -94,10 +100,10 @@ func TestCalculatePatchIDs(t *testing.T) { } // Verify all commits have patch IDs - for _, hash := range newCommits { - details := r.commitDetails[hash] - if details.PatchID == [20]byte{} { - t.Errorf("missing patch ID for commit %x", hash) + for _, idx := range newCommits { + commit := r.commits[idx] + if commit.PatchID == [20]byte{} { + t.Errorf("missing patch ID for commit %s", printSHA1(commit.Hash)) } } } @@ -125,9 +131,753 @@ func TestLoadRepository(t *testing.T) { } // Check that the two sets of Patch IDs are the same - for hash, details := range r1.commitDetails { - if details.PatchID != r2.commitDetails[hash].PatchID { - t.Errorf("patch ID mismatch for commit %x", hash) + for idx, commit := range r1.commits { + if commit.PatchID != r2.commits[idx].PatchID { + t.Errorf("patch ID mismatch for commit %s", printSHA1(commit.Hash)) + } + } +} + +// For test setup +func (r *Repository) addEdgeForTest(parent, child SHA1) { + pIdx := r.getOrCreateIndex(parent) + cIdx := r.getOrCreateIndex(child) + r.commitGraph[pIdx] = append(r.commitGraph[pIdx], cIdx) + r.commits[cIdx].Parents = append(r.commits[cIdx].Parents, pIdx) +} + +// Helper to decode string into SHA1 +func decodeSHA1(s string) SHA1 { + var hash SHA1 + // Pad with zeros because the test strings are shorter than 40 char + padded := strings.Repeat("0", 40-len(s)) + s + b, err := hex.DecodeString(padded) + if err != nil { + panic(err) + } + copy(hash[:], b) + + return hash +} + +// Helper to encode SHA1 into string +func encodeSHA1(hash SHA1) string { + return hex.EncodeToString(hash[:]) +} + +// Helper to pretty print SHA1 as string (leading 0's removed) +func printSHA1(hash SHA1) string { + // Remove padding zeros for a cleaner results + str := hex.EncodeToString(hash[:]) + + return strings.TrimLeft(str, "0") +} + +// cmpSHA1Opts are applied to the cmp.Diff function to make the output more readable +// 1. Transform SHA1s to pretty strings +// 2. Sorts slices to ensure deterministic comparisons +var cmpSHA1Opts = []cmp.Option{ + cmp.Transformer("SHA1s", func(in []SHA1) []string { + out := make([]string, len(in)) + for i, h := range in { + out[i] = printSHA1(h) } + + return out + }), + cmpopts.SortSlices(func(a, b string) bool { + return a < b + }), +} + +func TestExpandByCherrypick(t *testing.T) { + repo := NewRepository("/repo") + + // Commit hashes + h1 := decodeSHA1("aaaa") + h2 := decodeSHA1("bbbb") + h3 := decodeSHA1("cccc") + + // Patch ID + p1 := decodeSHA1("1111") + + // Setup commit details + idx1 := repo.getOrCreateIndex(h1) + idx2 := repo.getOrCreateIndex(h2) + idx3 := repo.getOrCreateIndex(h3) + + repo.commits[idx1].PatchID = p1 + repo.commits[idx3].PatchID = p1 // h3 has the same patch ID as h1 should be cherry picked + + // Setup patch ID map + repo.patchIDToCommits[p1] = []int{idx1, idx3} + + tests := []struct { + name string + input []int + expected []SHA1 + }{ + { + name: "Expand single commit with cherry-pick", + input: []int{idx1}, + expected: []SHA1{h1, h3}, + }, + { + name: "No expansion for commit without cherry-pick", + input: []int{idx2}, + expected: []SHA1{h2}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIdxs := repo.expandByCherrypick(tt.input) + var got []SHA1 + for _, idx := range gotIdxs { + got = append(got, repo.commits[idx].Hash) + } + + if diff := cmp.Diff(tt.expected, got, cmpSHA1Opts...); diff != "" { + t.Errorf("expandByCherrypick() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// Testing cases with introduced and fixed only. +func TestAffected_Introduced_Fixed(t *testing.T) { + repo := NewRepository("/repo") + + // Graph: (Parent -> Child) + // -> F -> G + // / + // A -> B -> C -> D -> E + // \ / + // -> H -> + + hA := decodeSHA1("aaaa") + hB := decodeSHA1("bbbb") + hC := decodeSHA1("cccc") + hD := decodeSHA1("dddd") + hE := decodeSHA1("eeee") + hF := decodeSHA1("ffff") + hG := decodeSHA1("abab") + hH := decodeSHA1("acac") + + // Setup graph (Parent -> Children) + repo.addEdgeForTest(hA, hB) + repo.addEdgeForTest(hB, hC) + repo.addEdgeForTest(hB, hH) + repo.addEdgeForTest(hC, hD) + repo.addEdgeForTest(hC, hF) + repo.addEdgeForTest(hD, hE) + repo.addEdgeForTest(hF, hG) + repo.addEdgeForTest(hH, hD) + repo.rootCommits = []int{0} // Root commit is A + + tests := []struct { + name string + se *SeparatedEvents + expected []SHA1 + }{ + { + name: "Linear: A introduced, B fixed", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Fixed: []string{encodeSHA1(hB)}, + }, + expected: []SHA1{hA}, + }, + { + name: "Branch propagation: A introduced, C fixed", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Fixed: []string{encodeSHA1(hC)}, + }, + expected: []SHA1{hA, hB, hH}, + }, + { + name: "Re-introduced: (A,C) introduced, (B,D,G) fixed", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA), encodeSHA1(hC)}, + Fixed: []string{encodeSHA1(hB), encodeSHA1(hD), encodeSHA1(hG)}, + }, + expected: []SHA1{hA, hC, hF}, + }, + { + name: "Merge intro: H introduced, E fixed", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hH)}, + Fixed: []string{encodeSHA1(hE)}, + }, + expected: []SHA1{hH, hD}, + }, + { + name: "Merge fix: A introduced, H fixed", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Fixed: []string{encodeSHA1(hH)}, + }, + expected: []SHA1{hA, hB, hC, hF, hG}, + }, + { + name: "Merge intro and fix (different branches): C introduced, H fixed", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hC)}, + Fixed: []string{encodeSHA1(hH)}, + }, + expected: []SHA1{hC, hD, hE, hF, hG}, + }, + { + name: "Introduced = 0: C fixed", + se: &SeparatedEvents{ + Introduced: []string{"0"}, + Fixed: []string{encodeSHA1(hC)}, + }, + expected: []SHA1{hA, hB, hH}, + }, + { + name: "Introduced = 0: no fix", + se: &SeparatedEvents{ + Introduced: []string{"0"}, + }, + expected: []SHA1{hA, hB, hC, hD, hE, hF, hG, hH}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCommits := repo.Affected(t.Context(), tt.se, false, false) + + var got []SHA1 + for _, c := range gotCommits { + got = append(got, c.Hash) + } + + if diff := cmp.Diff(tt.expected, got, cmpSHA1Opts...); diff != "" { + t.Errorf("TestAffected_Introduced_Fixed() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestAffected_Introduced_LastAffected(t *testing.T) { + repo := NewRepository("/repo") + + // Graph: (Parent -> Child) + // -> F -> G + // / + // A -> B -> C -> D -> E + // \ / + // -> H -> + + hA := decodeSHA1("aaaa") + hB := decodeSHA1("bbbb") + hC := decodeSHA1("cccc") + hD := decodeSHA1("dddd") + hE := decodeSHA1("eeee") + hF := decodeSHA1("ffff") + hG := decodeSHA1("abab") + hH := decodeSHA1("acac") + + // Setup graph (Parent -> Children) + repo.addEdgeForTest(hA, hB) + repo.addEdgeForTest(hB, hC) + repo.addEdgeForTest(hB, hH) + repo.addEdgeForTest(hC, hD) + repo.addEdgeForTest(hC, hF) + repo.addEdgeForTest(hD, hE) + repo.addEdgeForTest(hF, hG) + repo.addEdgeForTest(hH, hD) + repo.rootCommits = []int{0} // Root commit is A + + tests := []struct { + name string + se *SeparatedEvents + expected []SHA1 + }{ + { + name: "Linear: D introduced, E lastAffected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hD)}, + LastAffected: []string{encodeSHA1(hE)}, + }, + expected: []SHA1{hD, hE}, + }, + { + name: "Branch propagation: A introduced, C lastAffected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + LastAffected: []string{encodeSHA1(hC)}, + }, + expected: []SHA1{hA, hB, hC, hH}, + }, + { + name: "Re-introduced: (A,D) introduced, (B,E) lastAffected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA), encodeSHA1(hD)}, + LastAffected: []string{encodeSHA1(hB), encodeSHA1(hE)}, + }, + expected: []SHA1{hA, hB, hD, hE}, + }, + { + name: "Merge intro: H introduced, D lastAffected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hH)}, + LastAffected: []string{encodeSHA1(hD)}, + }, + expected: []SHA1{hH, hD}, + }, + { + name: "Merge lastAffected: A introduced, H lastAffected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + LastAffected: []string{encodeSHA1(hH)}, + }, + expected: []SHA1{hA, hB, hC, hF, hG, hH}, + }, + { + name: "Merge intro and lastAffected (different branches): C introduced, H lastAffected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hC)}, + LastAffected: []string{encodeSHA1(hH)}, + }, + expected: []SHA1{hC, hF, hG}, + }, + { + name: "Introduced = 0: C lastAffected", + se: &SeparatedEvents{ + Introduced: []string{"0"}, + LastAffected: []string{encodeSHA1(hC)}, + }, + expected: []SHA1{hA, hB, hC, hH}, + }, + { + name: "Introduced = 0: no fix", + se: &SeparatedEvents{ + Introduced: []string{"0"}, + }, + expected: []SHA1{hA, hB, hC, hD, hE, hF, hG, hH}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCommits := repo.Affected(t.Context(), tt.se, false, false) + + var got []SHA1 + for _, c := range gotCommits { + got = append(got, c.Hash) + } + + if diff := cmp.Diff(tt.expected, got, cmpSHA1Opts...); diff != "" { + t.Errorf("TestAffected_Introduced_LastAffected() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +// Testing with both fixed and lastAffected +func TestAffected_Combined(t *testing.T) { + repo := NewRepository("/repo") + + // Graph: (Parent -> Child) + // -> F -> G + // / + // A -> B -> C -> D -> E + // \ / + // -> H -> + + hA := decodeSHA1("aaaa") + hB := decodeSHA1("bbbb") + hC := decodeSHA1("cccc") + hD := decodeSHA1("dddd") + hE := decodeSHA1("eeee") + hF := decodeSHA1("ffff") + hG := decodeSHA1("abab") + hH := decodeSHA1("acac") + + // Setup graph (Parent -> Children) + repo.addEdgeForTest(hA, hB) + repo.addEdgeForTest(hB, hC) + repo.addEdgeForTest(hB, hH) + repo.addEdgeForTest(hC, hD) + repo.addEdgeForTest(hC, hF) + repo.addEdgeForTest(hD, hE) + repo.addEdgeForTest(hF, hG) + repo.addEdgeForTest(hH, hD) + + tests := []struct { + name string + se *SeparatedEvents + expected []SHA1 + }{ + { + name: "Branching out: C introduced, G fixed, D lastAffected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hC)}, + Fixed: []string{encodeSHA1(hG)}, + LastAffected: []string{encodeSHA1(hD)}, + }, + expected: []SHA1{hC, hD, hF}, + }, + { + name: "Redundant Blocking: A introduced, B fixed, E lastAffected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Fixed: []string{encodeSHA1(hB)}, + LastAffected: []string{encodeSHA1(hE)}, + }, + expected: []SHA1{hA}, + }, + { + name: "Introduced=Fixed: No affected commit", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hB)}, + Fixed: []string{encodeSHA1(hB)}, + }, + expected: []SHA1{}, + }, + { + name: "Introduced=lastAffected: Only current commit affected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hB)}, + LastAffected: []string{encodeSHA1(hB)}, + }, + expected: []SHA1{hB}, + }, + { + name: "Fixed=lastAffected: Stop at fix, lastAffected no effect", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Fixed: []string{encodeSHA1(hB)}, + LastAffected: []string{encodeSHA1(hB)}, + }, + expected: []SHA1{hA}, + }, + { + // This is the current behaviour as we treat child of lastAffected commit as a fixed commit + name: "Intro=lastAffected+1: commit not affected", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA), encodeSHA1(hC)}, // C is the child of B + LastAffected: []string{encodeSHA1(hB)}, + }, + expected: []SHA1{hA, hB}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCommits := repo.Affected(t.Context(), tt.se, false, false) + + var got []SHA1 + for _, c := range gotCommits { + got = append(got, c.Hash) + } + + if diff := cmp.Diff(tt.expected, got, cmpSHA1Opts...); diff != "" { + t.Errorf("TestAffected_Combined() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestAffected_Cherrypick(t *testing.T) { + repo := NewRepository("/repo") + + // Graph: (Parent -> Child) + // A -> B -> C -> D + // | | + // | (cherrypick) + // | | + // E -> F -> G -> H + + hA := decodeSHA1("aaaa") + hB := decodeSHA1("bbbb") + hC := decodeSHA1("cccc") + hD := decodeSHA1("dddd") + hE := decodeSHA1("eeee") + hF := decodeSHA1("ffff") + hG := decodeSHA1("abab") + hH := decodeSHA1("acac") + + c1 := decodeSHA1("c1") + c2 := decodeSHA1("c2") + + // Setup graph (Parent -> Children) + repo.addEdgeForTest(hA, hB) + repo.addEdgeForTest(hB, hC) + repo.addEdgeForTest(hC, hD) + repo.addEdgeForTest(hE, hF) + repo.addEdgeForTest(hF, hG) + repo.addEdgeForTest(hG, hH) + repo.rootCommits = []int{0} + + // Setup PatchID map for cherrypicking + idxA := repo.getOrCreateIndex(hA) + idxE := repo.getOrCreateIndex(hE) + repo.patchIDToCommits[c1] = []int{idxA, idxE} + idxC := repo.getOrCreateIndex(hC) + idxG := repo.getOrCreateIndex(hG) + repo.patchIDToCommits[c2] = []int{idxC, idxG} + + repo.commits[idxA].PatchID = c1 + repo.commits[idxE].PatchID = c1 + repo.commits[idxC].PatchID = c2 + repo.commits[idxG].PatchID = c2 + + tests := []struct { + name string + se *SeparatedEvents + cherrypickIntro bool + cherrypickFixed bool + expected []SHA1 + }{ + { + name: "Cherrypick Introduced Only: A introduced, G fixed", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Fixed: []string{encodeSHA1(hG)}, + }, + cherrypickIntro: true, + cherrypickFixed: false, + expected: []SHA1{hA, hB, hC, hD, hE, hF}, + }, + { + name: "Cherrypick Fixed Only: A introduced, G fixed", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Fixed: []string{encodeSHA1(hG)}, + }, + cherrypickIntro: false, + cherrypickFixed: true, + expected: []SHA1{hA, hB}, + }, + { + name: "Cherrypick Introduced and Fixed: A introduced, G fixed", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Fixed: []string{encodeSHA1(hG)}, + }, + cherrypickIntro: true, + cherrypickFixed: true, + expected: []SHA1{hA, hB, hE, hF}, + }, + { + name: "Cherrypick Introduced=0: G fixed", + se: &SeparatedEvents{ + Introduced: []string{"0"}, + Fixed: []string{encodeSHA1(hG)}, + }, + cherrypickIntro: true, + cherrypickFixed: true, + expected: []SHA1{hA, hB, hE, hF}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCommits := repo.Affected(t.Context(), tt.se, tt.cherrypickIntro, tt.cherrypickFixed) + + var got []SHA1 + for _, c := range gotCommits { + got = append(got, c.Hash) + } + + if diff := cmp.Diff(tt.expected, got, cmpSHA1Opts...); diff != "" { + t.Errorf("TestAffected_Cherrypick() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLimit(t *testing.T) { + repo := NewRepository("/repo") + + // Graph: (Parent -> Child) + // A -> B -> C -> D -> E + // \ + // -> F -> G -> H + + hA := decodeSHA1("aaaa") + hB := decodeSHA1("bbbb") + hC := decodeSHA1("cccc") + hD := decodeSHA1("dddd") + hE := decodeSHA1("eeee") + hF := decodeSHA1("ffff") + hG := decodeSHA1("abab") + hH := decodeSHA1("acac") + + // Setup graph (Parent -> Children) + repo.addEdgeForTest(hA, hB) + repo.addEdgeForTest(hB, hC) + repo.addEdgeForTest(hB, hF) + repo.addEdgeForTest(hC, hD) + repo.addEdgeForTest(hD, hE) + repo.addEdgeForTest(hF, hG) + repo.addEdgeForTest(hG, hH) + repo.rootCommits = []int{0} // A is root commit + + tests := []struct { + name string + se *SeparatedEvents + expected []SHA1 + }{ + { + name: "One branch: A introduced, D limit", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Limit: []string{encodeSHA1(hD)}, + }, + expected: []SHA1{hA, hB, hC}, + }, + { + name: "Side branch: A introduced, G limit", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Limit: []string{encodeSHA1(hG)}, + }, + expected: []SHA1{hA, hB, hF}, + }, + { + name: "Two branches: A introduced, (D,G) limit", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Limit: []string{encodeSHA1(hD), encodeSHA1(hG)}, + }, + expected: []SHA1{hA, hB, hC, hF}, + }, + { + name: "Introduced=0, G limit", + se: &SeparatedEvents{ + Introduced: []string{"0"}, + Limit: []string{encodeSHA1(hG)}, + }, + expected: []SHA1{hA, hB, hF}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCommits := repo.Limit(t.Context(), tt.se, false, false) + + var got []SHA1 + for _, c := range gotCommits { + got = append(got, c.Hash) + } + + if diff := cmp.Diff(tt.expected, got, cmpSHA1Opts...); diff != "" { + t.Errorf("TestLimit() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestLimit_Cherrypick(t *testing.T) { + repo := NewRepository("/repo") + + // Graph: (Parent -> Child) + // A -> B -> C -> D + // | | + // (cherrypick) + // | | + // E -> F -> G -> H + + hA := decodeSHA1("aaaa") + hB := decodeSHA1("bbbb") + hC := decodeSHA1("cccc") + hD := decodeSHA1("dddd") + hE := decodeSHA1("eeee") + hF := decodeSHA1("ffff") + hG := decodeSHA1("abab") + hH := decodeSHA1("acac") + + c1 := decodeSHA1("c1") + c2 := decodeSHA1("c2") + + // Setup graph (Parent -> Children) + repo.addEdgeForTest(hA, hB) + repo.addEdgeForTest(hB, hC) + repo.addEdgeForTest(hC, hD) + repo.addEdgeForTest(hE, hF) + repo.addEdgeForTest(hF, hG) + repo.addEdgeForTest(hG, hH) + repo.rootCommits = []int{0} + + // Setup PatchID map for cherrypicking + idxB := repo.getOrCreateIndex(hB) + idxF := repo.getOrCreateIndex(hF) + repo.patchIDToCommits[c1] = []int{idxB, idxF} + idxC := repo.getOrCreateIndex(hC) + idxG := repo.getOrCreateIndex(hG) + repo.patchIDToCommits[c2] = []int{idxC, idxG} + + repo.commits[idxB].PatchID = c1 + repo.commits[idxF].PatchID = c1 + repo.commits[idxC].PatchID = c2 + repo.commits[idxG].PatchID = c2 + + tests := []struct { + name string + se *SeparatedEvents + cherrypickIntro bool + cherrypickLimit bool + expected []SHA1 + }{ + { + name: "Cherrypick Introduced Only: B introduced, G limit", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hB)}, + Limit: []string{encodeSHA1(hG)}, + }, + cherrypickIntro: true, + cherrypickLimit: false, + expected: []SHA1{hF}, + }, + { + name: "Cherrypick Limit Only: B introduced, G limit", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hB)}, + Limit: []string{encodeSHA1(hG)}, + }, + cherrypickIntro: false, + cherrypickLimit: true, + expected: []SHA1{hB, hE, hF}, + }, + { + name: "Cherrypick Introduced and Limit: A introduced, G limit", + se: &SeparatedEvents{ + Introduced: []string{encodeSHA1(hA)}, + Limit: []string{encodeSHA1(hG)}, + }, + cherrypickIntro: true, + cherrypickLimit: true, + expected: []SHA1{hA, hB, hE, hF}, + }, + { + name: "Cherrypick Introduced=0: G limit", + se: &SeparatedEvents{ + Introduced: []string{"0"}, + Limit: []string{encodeSHA1(hG)}, + }, + cherrypickIntro: true, + cherrypickLimit: true, + expected: []SHA1{hA, hB, hE, hF}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCommits := repo.Limit(t.Context(), tt.se, tt.cherrypickIntro, tt.cherrypickLimit) + + var got []SHA1 + for _, c := range gotCommits { + got = append(got, c.Hash) + } + + if diff := cmp.Diff(tt.expected, got, cmpSHA1Opts...); diff != "" { + t.Errorf("TestLimit_Cherrypick() mismatch (-want +got):\n%s", diff) + } + }) } } diff --git a/go/go.mod b/go/go.mod index 02989daec1f..9e94788661a 100644 --- a/go/go.mod +++ b/go/go.mod @@ -35,6 +35,8 @@ require ( github.com/charmbracelet/x/windows v0.2.2 // indirect github.com/clipperhouse/displaywidth v0.11.0 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect + github.com/dgraph-io/ristretto/v2 v2.4.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/muesli/cancelreader v0.2.2 // indirect ) diff --git a/go/go.sum b/go/go.sum index f018ad10bef..e53582bfc14 100644 --- a/go/go.sum +++ b/go/go.sum @@ -77,6 +77,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/ristretto/v2 v2.4.0 h1:I/w09yLjhdcVD2QV192UJcq8dPBaAJb9pOuMyNy0XlU= +github.com/dgraph-io/ristretto/v2 v2.4.0/go.mod h1:0KsrXtXvnv0EqnzyowllbVJB8yBonswa2lTCK2gGo9E= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=