diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 54f6b9f40..42b33d80f 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -7,12 +7,11 @@ before: builds: - env: - - CGO_ENABLED=0 + - CGO_ENABLED=1 ldflags: - -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.Date}} goos: - linux - - windows - darwin main: ./cmd/github-mcp-server diff --git a/Dockerfile b/Dockerfile index 6ff2babb8..c5426e2a0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,16 +4,16 @@ ARG VERSION="dev" # Set the working directory WORKDIR /build -# Install git +# Install git and C compiler for CGO (tree-sitter) RUN --mount=type=cache,target=/var/cache/apk \ - apk add git + apk add git gcc musl-dev # Build the server # go build automatically download required module dependencies to /go/pkg/mod RUN --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ --mount=type=bind,target=. \ - CGO_ENABLED=0 go build -ldflags="-s -w -X main.version=${VERSION} -X main.commit=$(git rev-parse HEAD) -X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ)" \ + CGO_ENABLED=1 go build -ldflags="-s -w -linkmode external -extldflags '-static' -X main.version=${VERSION} -X main.commit=$(git rev-parse HEAD) -X main.date=$(date -u +%Y-%m-%dT%H:%M:%SZ)" \ -o /bin/github-mcp-server ./cmd/github-mcp-server # Make a stage to run the app diff --git a/go.mod b/go.mod index a02997d6c..a625d62f0 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,8 @@ require ( github.com/stretchr/testify v1.11.1 ) +require github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 + require ( github.com/aymerick/douceur v0.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect diff --git a/go.sum b/go.sum index d525cb0a1..d2e79a1d7 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,8 @@ github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 h1:cYCy18SHPKRkv github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8= github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0= github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466/go.mod h1:9dIRpgIY7hVhoqfe0/FcYp0bpInZaT7dc3BYOprrIUE= +github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 h1:6C8qej6f1bStuePVkLSFxoU22XBS165D3klxlzRg8F4= +github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82/go.mod h1:xe4pgH49k4SsmkQq5OT8abwhWmnzkhpgnXeekbx2efw= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= diff --git a/pkg/github/compare_file_contents_test.go b/pkg/github/compare_file_contents_test.go index af0b9d5a2..e13f37a7d 100644 --- a/pkg/github/compare_file_contents_test.go +++ b/pkg/github/compare_file_contents_test.go @@ -114,7 +114,7 @@ func Test_CompareFileContents(t *testing.T) { expectDiff: `host: "localhost" → "production.db"`, }, { - name: "unsupported format falls back to unified diff", + name: "Go file uses structural diff", mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ GetReposContentsByOwnerByRepoByPath: mockContentsForRef(map[string]string{ "main": "func main() {}\n", @@ -128,8 +128,8 @@ func Test_CompareFileContents(t *testing.T) { "base": "main", "head": "feature", }, - expectFormat: "unified", - expectDiff: "--- a/main.go", + expectFormat: "structural", + expectDiff: "function_declaration main: modified", }, { name: "missing required parameter - owner", diff --git a/pkg/github/semantic_diff.go b/pkg/github/semantic_diff.go index f48d3625a..e05262e4e 100644 --- a/pkg/github/semantic_diff.go +++ b/pkg/github/semantic_diff.go @@ -80,6 +80,10 @@ func SemanticDiff(path string, base, head []byte) SemanticDiffResult { case ".toml": return semanticDiffTOML(path, base, head) default: + // Try tree-sitter structural diff for code files + if languageForPath(path) != nil { + return structuralDiff(path, base, head) + } return SemanticDiffResult{ Format: DiffFormatUnified, Diff: unifiedDiff(path, base, head), @@ -542,6 +546,9 @@ func DetectDiffFormat(path string) DiffFormat { case ".toml": return DiffFormatTOML default: + if languageForPath(path) != nil { + return DiffFormatStructural + } return DiffFormatUnified } } diff --git a/pkg/github/semantic_diff_test.go b/pkg/github/semantic_diff_test.go index 147547822..248d821d4 100644 --- a/pkg/github/semantic_diff_test.go +++ b/pkg/github/semantic_diff_test.go @@ -302,36 +302,17 @@ func TestSemanticDiffTOML(t *testing.T) { } func TestSemanticDiffUnifiedFallback(t *testing.T) { - tests := []struct { - name string - path string - base string - head string - expectedDiff string - }{ - { - name: "unsupported extension uses unified diff", - path: "main.go", - base: "func main() {\n}\n", - head: "func main() {\n\tfmt.Println(\"hello\")\n}\n", - expectedDiff: "--- a/main.go", - }, - { - name: "no extension uses unified diff", - path: "Makefile", - base: "all:\n\techo hello\n", - head: "all:\n\techo world\n", - expectedDiff: "--- a/Makefile", - }, - } + t.Run("Go file uses structural diff", func(t *testing.T) { + result := SemanticDiff("main.go", []byte("func main() {\n}\n"), []byte("func main() {\n\tfmt.Println(\"hello\")\n}\n")) + assert.Equal(t, DiffFormatStructural, result.Format) + assert.Contains(t, result.Diff, "function_declaration main: modified") + }) - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - result := SemanticDiff(tc.path, []byte(tc.base), []byte(tc.head)) - assert.Equal(t, DiffFormatUnified, result.Format) - assert.Contains(t, result.Diff, tc.expectedDiff) - }) - } + t.Run("no extension uses unified diff", func(t *testing.T) { + result := SemanticDiff("Makefile", []byte("all:\n\techo hello\n"), []byte("all:\n\techo world\n")) + assert.Equal(t, DiffFormatUnified, result.Format) + assert.Contains(t, result.Diff, "--- a/Makefile") + }) } func TestSemanticDiffFileSizeLimit(t *testing.T) { @@ -373,7 +354,7 @@ func TestSemanticDiffNewAndDeletedFiles(t *testing.T) { t.Run("deleted Go file", func(t *testing.T) { result := SemanticDiff("main.go", []byte("package main\n"), nil) - assert.Equal(t, DiffFormatUnified, result.Format) + assert.Equal(t, DiffFormatStructural, result.Format) assert.Equal(t, "file deleted", result.Diff) }) @@ -394,7 +375,7 @@ func TestDetectDiffFormat(t *testing.T) { {"config.yml", DiffFormatYAML}, {"data.csv", DiffFormatCSV}, {"config.toml", DiffFormatTOML}, - {"main.go", DiffFormatUnified}, + {"main.go", DiffFormatStructural}, {"README.md", DiffFormatUnified}, {"Makefile", DiffFormatUnified}, } diff --git a/pkg/github/structural_diff.go b/pkg/github/structural_diff.go new file mode 100644 index 000000000..9aab7c6e7 --- /dev/null +++ b/pkg/github/structural_diff.go @@ -0,0 +1,661 @@ +package github + +import ( + "context" + "fmt" + "path/filepath" + "sort" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + "github.com/smacker/go-tree-sitter/c" + "github.com/smacker/go-tree-sitter/cpp" + "github.com/smacker/go-tree-sitter/golang" + "github.com/smacker/go-tree-sitter/java" + "github.com/smacker/go-tree-sitter/javascript" + "github.com/smacker/go-tree-sitter/python" + "github.com/smacker/go-tree-sitter/ruby" + "github.com/smacker/go-tree-sitter/rust" + "github.com/smacker/go-tree-sitter/typescript/tsx" + "github.com/smacker/go-tree-sitter/typescript/typescript" +) + +// DiffFormatStructural indicates a tree-sitter based structural diff. +const DiffFormatStructural DiffFormat = "structural" + +// maxStructuralDiffDepth limits recursion into nested declarations. +const maxStructuralDiffDepth = 5 + +// declaration represents a named top-level code construct (function, class, etc). +type declaration struct { + Kind string // e.g. "function", "class", "type", "import" + Name string + Text string +} + +// languageConfig maps file extensions to tree-sitter languages and the node +// types that should be treated as top-level declarations. +type languageConfig struct { + language *sitter.Language + declarationKinds map[string]bool + nameExtractor func(node *sitter.Node, source []byte) string + indentationIsSignificant bool +} + +// languageForPath returns the tree-sitter language config for a file path, or nil if unsupported. +func languageForPath(path string) *languageConfig { + ext := strings.ToLower(filepath.Ext(path)) + switch ext { + case ".go": + return goConfig() + case ".py": + return pythonConfig() + case ".js", ".mjs", ".cjs": + return javascriptConfig() + case ".ts": + return typescriptConfig() + case ".tsx", ".jsx": + return tsxConfig() + case ".rb": + return rubyConfig() + case ".rs": + return rustConfig() + case ".java": + return javaConfig() + case ".c", ".h": + return cConfig() + case ".cpp", ".hpp", ".cc", ".cxx": + return cppConfig() + default: + return nil + } +} + +func goConfig() *languageConfig { + return &languageConfig{ + language: golang.GetLanguage(), + declarationKinds: map[string]bool{ + "function_declaration": true, + "method_declaration": true, + "type_declaration": true, + "var_declaration": true, + "const_declaration": true, + "import_declaration": true, + "package_clause": true, + }, + nameExtractor: goNameExtractor, + } +} + +func pythonConfig() *languageConfig { + return &languageConfig{ + language: python.GetLanguage(), + declarationKinds: map[string]bool{ + "function_definition": true, + "class_definition": true, + "import_statement": true, + "import_from_statement": true, + }, + nameExtractor: defaultNameExtractor, + indentationIsSignificant: true, + } +} + +func javascriptConfig() *languageConfig { + return &languageConfig{ + language: javascript.GetLanguage(), + declarationKinds: map[string]bool{ + "function_declaration": true, + "class_declaration": true, + "method_definition": true, + "export_statement": true, + "import_statement": true, + "lexical_declaration": true, + "variable_declaration": true, + }, + nameExtractor: jsNameExtractor, + } +} + +func typescriptConfig() *languageConfig { + return &languageConfig{ + language: typescript.GetLanguage(), + declarationKinds: map[string]bool{ + "function_declaration": true, + "class_declaration": true, + "method_definition": true, + "export_statement": true, + "import_statement": true, + "lexical_declaration": true, + "variable_declaration": true, + "interface_declaration": true, + "type_alias_declaration": true, + "enum_declaration": true, + }, + nameExtractor: jsNameExtractor, + } +} + +func tsxConfig() *languageConfig { + return &languageConfig{ + language: tsx.GetLanguage(), + declarationKinds: map[string]bool{ + "function_declaration": true, + "class_declaration": true, + "method_definition": true, + "export_statement": true, + "import_statement": true, + "lexical_declaration": true, + "variable_declaration": true, + "interface_declaration": true, + "type_alias_declaration": true, + "enum_declaration": true, + }, + nameExtractor: jsNameExtractor, + } +} + +func rubyConfig() *languageConfig { + return &languageConfig{ + language: ruby.GetLanguage(), + declarationKinds: map[string]bool{ + "method": true, + "class": true, + "module": true, + }, + nameExtractor: defaultNameExtractor, + } +} + +func rustConfig() *languageConfig { + return &languageConfig{ + language: rust.GetLanguage(), + declarationKinds: map[string]bool{ + "function_item": true, + "struct_item": true, + "enum_item": true, + "impl_item": true, + "trait_item": true, + "mod_item": true, + "use_declaration": true, + "type_item": true, + "const_item": true, + "static_item": true, + }, + nameExtractor: defaultNameExtractor, + } +} + +func javaConfig() *languageConfig { + return &languageConfig{ + language: java.GetLanguage(), + declarationKinds: map[string]bool{ + "class_declaration": true, + "method_declaration": true, + "interface_declaration": true, + "enum_declaration": true, + "import_declaration": true, + "package_declaration": true, + "constructor_declaration": true, + }, + nameExtractor: defaultNameExtractor, + } +} + +func cConfig() *languageConfig { + return &languageConfig{ + language: c.GetLanguage(), + declarationKinds: map[string]bool{ + "function_definition": true, + "declaration": true, + "preproc_include": true, + "preproc_def": true, + "struct_specifier": true, + "enum_specifier": true, + "type_definition": true, + }, + nameExtractor: cNameExtractor, + } +} + +func cppConfig() *languageConfig { + return &languageConfig{ + language: cpp.GetLanguage(), + declarationKinds: map[string]bool{ + "function_definition": true, + "declaration": true, + "preproc_include": true, + "preproc_def": true, + "struct_specifier": true, + "enum_specifier": true, + "class_specifier": true, + "type_definition": true, + "namespace_definition": true, + "template_declaration": true, + }, + nameExtractor: cNameExtractor, + } +} + +// extractDeclarations parses source code and extracts top-level declarations. +func extractDeclarations(config *languageConfig, source []byte) ([]declaration, error) { + parser := sitter.NewParser() + defer parser.Close() + parser.SetLanguage(config.language) + + tree, err := parser.ParseCtx(context.Background(), nil, source) + if err != nil { + return nil, fmt.Errorf("failed to parse: %w", err) + } + defer tree.Close() + + return extractChildDeclarations(config, tree.RootNode(), source), nil +} + +// extractChildDeclarations extracts declarations from the direct children of a node. +func extractChildDeclarations(config *languageConfig, node *sitter.Node, source []byte) []declaration { + var decls []declaration + + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + nodeType := child.Type() + + if !config.declarationKinds[nodeType] { + continue + } + + name := config.nameExtractor(child, source) + if name == "" { + name = fmt.Sprintf("_%s_%d", nodeType, i) + } + + decls = append(decls, declaration{ + Kind: nodeType, + Name: name, + Text: child.Content(source), + }) + } + + return decls +} + +// defaultNameExtractor finds the first "name" or "identifier" child node. +func defaultNameExtractor(node *sitter.Node, source []byte) string { + nameNode := node.ChildByFieldName("name") + if nameNode != nil { + return nameNode.Content(source) + } + // Try first identifier child + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "identifier" || child.Type() == "type_identifier" { + return child.Content(source) + } + } + return "" +} + +// goNameExtractor handles Go-specific naming (method receivers, type/var/const specs). +func goNameExtractor(node *sitter.Node, source []byte) string { + switch node.Type() { + case "method_declaration": + nameNode := node.ChildByFieldName("name") + if nameNode == nil { + return "" + } + name := nameNode.Content(source) + receiver := node.ChildByFieldName("receiver") + if receiver != nil { + return fmt.Sprintf("(%s).%s", extractReceiverType(receiver, source), name) + } + return name + case "type_declaration", "var_declaration", "const_declaration": + // These contain spec children (type_spec, var_spec, const_spec) with name fields + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + nameNode := child.ChildByFieldName("name") + if nameNode != nil { + return nameNode.Content(source) + } + } + return "" + default: + return defaultNameExtractor(node, source) + } +} + +// extractReceiverType extracts the type name from a Go method receiver. +func extractReceiverType(receiver *sitter.Node, source []byte) string { + for i := 0; i < int(receiver.ChildCount()); i++ { + child := receiver.Child(i) + if child.Type() == "parameter_declaration" { + typeNode := child.ChildByFieldName("type") + if typeNode != nil { + return typeNode.Content(source) + } + } + } + return receiver.Content(source) +} + +// jsNameExtractor handles JS/TS-specific naming (variable declarations, exports). +func jsNameExtractor(node *sitter.Node, source []byte) string { + switch node.Type() { + case "lexical_declaration", "variable_declaration": + // const/let/var x = ... + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + if child.Type() == "variable_declarator" { + nameNode := child.ChildByFieldName("name") + if nameNode != nil { + return nameNode.Content(source) + } + } + } + return "" + case "export_statement": + // export default/named - use the inner declaration's name + decl := node.ChildByFieldName("declaration") + if decl != nil { + return jsNameExtractor(decl, source) + } + return defaultNameExtractor(node, source) + case "import_statement": + return node.Content(source) + default: + return defaultNameExtractor(node, source) + } +} + +// cNameExtractor handles C/C++ naming where the function name is inside the declarator. +func cNameExtractor(node *sitter.Node, source []byte) string { + // function_definition: the name is in the declarator field + declarator := node.ChildByFieldName("declarator") + if declarator != nil { + return findIdentifier(declarator, source) + } + return defaultNameExtractor(node, source) +} + +// findIdentifier recursively searches for the first identifier in a node tree. +func findIdentifier(node *sitter.Node, source []byte) string { + if node.Type() == "identifier" { + return node.Content(source) + } + for i := 0; i < int(node.ChildCount()); i++ { + if name := findIdentifier(node.Child(i), source); name != "" { + return name + } + } + return "" +} + +// structuralDiff produces a structural diff using tree-sitter AST parsing. +func structuralDiff(path string, base, head []byte) SemanticDiffResult { + config := languageForPath(path) + if config == nil { + return SemanticDiffResult{ + Format: DiffFormatUnified, + Diff: unifiedDiff(path, base, head), + } + } + + baseDecls, err := extractDeclarations(config, base) + if err != nil { + return fallbackResult(path, base, head, "failed to parse base file") + } + + headDecls, err := extractDeclarations(config, head) + if err != nil { + return fallbackResult(path, base, head, "failed to parse head file") + } + + changes := diffDeclarations(config, baseDecls, headDecls, "", 0) + if len(changes) == 0 { + return SemanticDiffResult{ + Format: DiffFormatStructural, + Diff: "no structural changes detected", + } + } + + return SemanticDiffResult{ + Format: DiffFormatStructural, + Diff: strings.Join(changes, "\n"), + } +} + +// diffDeclarations compares two sets of declarations and returns change descriptions. +// indent controls visual nesting, depth limits recursion. +func diffDeclarations(config *languageConfig, base, head []declaration, indent string, depth int) []string { + baseMap := indexDeclarations(base) + headMap := indexDeclarations(head) + + // Collect all unique keys + allKeys := make(map[string]bool) + for k := range baseMap { + allKeys[k] = true + } + for k := range headMap { + allKeys[k] = true + } + + sortedKeys := make([]string, 0, len(allKeys)) + for k := range allKeys { + sortedKeys = append(sortedKeys, k) + } + sort.Strings(sortedKeys) + + var changes []string + for _, key := range sortedKeys { + baseDecl, inBase := baseMap[key] + headDecl, inHead := headMap[key] + + switch { + case inBase && !inHead: + changes = append(changes, fmt.Sprintf("%s%s %s: removed", indent, baseDecl.Kind, baseDecl.Name)) + case !inBase && inHead: + changes = append(changes, fmt.Sprintf("%s%s %s: added", indent, headDecl.Kind, headDecl.Name)) + case baseDecl.Text != headDecl.Text: + detail := modifiedDetail(config, baseDecl, headDecl, indent, depth) + changes = append(changes, fmt.Sprintf("%s%s %s: modified\n%s", indent, baseDecl.Kind, baseDecl.Name, detail)) + } + } + + return changes +} + +// indexDeclarations creates a lookup map from declaration key to declaration. +// The key combines kind and name to handle same-name declarations of different kinds. +func indexDeclarations(decls []declaration) map[string]declaration { + result := make(map[string]declaration, len(decls)) + for _, d := range decls { + key := d.Kind + ":" + d.Name + result[key] = d + } + return result +} + +// modifiedDetail produces the detail output for a modified declaration. If the +// declaration contains sub-declarations (e.g. methods in a class) and we haven't +// hit the depth limit, it recurses to show which children changed. Otherwise it +// falls back to a line-level diff of the declaration body. +func modifiedDetail(config *languageConfig, baseDecl, headDecl declaration, indent string, depth int) string { + if depth < maxStructuralDiffDepth { + baseChildren := extractChildDeclarationsFromText(config, baseDecl.Text) + headChildren := extractChildDeclarationsFromText(config, headDecl.Text) + + if len(baseChildren) > 0 || len(headChildren) > 0 { + nested := diffDeclarations(config, baseChildren, headChildren, indent+" ", depth+1) + if len(nested) > 0 { + return strings.Join(nested, "\n") + } + } + } + + return declarationDiff(baseDecl.Text, headDecl.Text, indent+" ", config.indentationIsSignificant) +} + +// extractChildDeclarationsFromText parses a declaration's text and extracts any +// nested declarations (e.g. methods inside a class body). +func extractChildDeclarationsFromText(config *languageConfig, text string) []declaration { + parser := sitter.NewParser() + defer parser.Close() + parser.SetLanguage(config.language) + + src := []byte(text) + tree, err := parser.ParseCtx(context.Background(), nil, src) + if err != nil { + return nil + } + defer tree.Close() + + // Walk all descendants looking for declaration nodes that aren't the root wrapper + root := tree.RootNode() + var decls []declaration + findNestedDeclarations(config, root, src, &decls, 0, true) + return decls +} + +// findNestedDeclarations recursively walks AST nodes to find declarations that +// are nested inside a parent. skipRoot=true on the first call to avoid matching +// the re-parsed wrapper of the parent declaration itself. +func findNestedDeclarations(config *languageConfig, node *sitter.Node, source []byte, decls *[]declaration, depth int, skipRoot bool) { + if depth > maxStructuralDiffDepth { + return + } + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + nodeType := child.Type() + + if !skipRoot && config.declarationKinds[nodeType] { + name := config.nameExtractor(child, source) + if name == "" { + name = fmt.Sprintf("_%s_%d", nodeType, i) + } + *decls = append(*decls, declaration{ + Kind: nodeType, + Name: name, + Text: child.Content(source), + }) + } else { + // Keep walking into non-declaration nodes (e.g. class_body) + findNestedDeclarations(config, child, source, decls, depth+1, false) + } + } +} + +// declarationDiff produces a compact, indented diff showing what changed inside +// a modified declaration. For languages where indentation is not significant, +// lines are compared with leading whitespace normalized so that pure formatting +// changes are collapsed. For whitespace-significant languages like Python, +// indentation differences are preserved as meaningful changes. +func declarationDiff(baseText, headText string, indent string, indentationIsSignificant bool) string { + baseLines := strings.Split(baseText, "\n") + headLines := strings.Split(headText, "\n") + + var baseCmp, headCmp []string + if indentationIsSignificant { + baseCmp = baseLines + headCmp = headLines + } else { + baseCmp = trimLines(baseLines) + headCmp = trimLines(headLines) + } + + // Compute LCS — on trimmed content for brace languages, exact for whitespace-significant + lcsIndices := longestCommonSubsequence(baseCmp, headCmp) + + var buf strings.Builder + bi, hi, li := 0, 0, 0 + + for li < len(lcsIndices) { + bIdx := lcsIndices[li][0] + hIdx := lcsIndices[li][1] + + // Lines removed from base before this common line + for bi < bIdx { + buf.WriteString(indent + "- " + baseLines[bi] + "\n") + bi++ + } + // Lines added in head before this common line + for hi < hIdx { + buf.WriteString(indent + "+ " + headLines[hi] + "\n") + hi++ + } + // Common line (by trimmed content) — skip silently + bi++ + hi++ + li++ + } + + // Remaining lines after LCS is exhausted + for bi < len(baseLines) { + buf.WriteString(indent + "- " + baseLines[bi] + "\n") + bi++ + } + for hi < len(headLines) { + buf.WriteString(indent + "+ " + headLines[hi] + "\n") + hi++ + } + + result := strings.TrimRight(buf.String(), "\n") + if result == "" { + return indent + "(whitespace/formatting changes only)" + } + return result +} + +// trimLines returns a slice with each line's leading/trailing whitespace removed. +func trimLines(lines []string) []string { + trimmed := make([]string, len(lines)) + for i, l := range lines { + trimmed[i] = strings.TrimSpace(l) + } + return trimmed +} + +// longestCommonSubsequence returns index pairs [baseIdx, headIdx] for matching +// lines between a and b. +func longestCommonSubsequence(a, b []string) [][2]int { + m, n := len(a), len(b) + dp := make([][]int, m+1) + for i := range dp { + dp[i] = make([]int, n+1) + } + + for i := 1; i <= m; i++ { + for j := 1; j <= n; j++ { + switch { + case a[i-1] == b[j-1]: + dp[i][j] = dp[i-1][j-1] + 1 + case dp[i-1][j] >= dp[i][j-1]: + dp[i][j] = dp[i-1][j] + default: + dp[i][j] = dp[i][j-1] + } + } + } + + // Backtrack to find index pairs + result := make([][2]int, 0, dp[m][n]) + i, j := m, n + for i > 0 && j > 0 { + switch { + case a[i-1] == b[j-1]: + result = append(result, [2]int{i - 1, j - 1}) + i-- + j-- + case dp[i-1][j] >= dp[i][j-1]: + i-- + default: + j-- + } + } + + // Reverse + for left, right := 0, len(result)-1; left < right; left, right = left+1, right-1 { + result[left], result[right] = result[right], result[left] + } + + return result +} diff --git a/pkg/github/structural_diff_test.go b/pkg/github/structural_diff_test.go new file mode 100644 index 000000000..ccc900058 --- /dev/null +++ b/pkg/github/structural_diff_test.go @@ -0,0 +1,387 @@ +package github + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStructuralDiffGo(t *testing.T) { + tests := []struct { + name string + base string + head string + expectedDiff string + notContains string + }{ + { + name: "no changes", + base: `package main + +func hello() {} +`, + head: `package main + +func hello() {} +`, + expectedDiff: "no structural changes detected", + }, + { + name: "function added", + base: `package main + +func hello() {} +`, + head: `package main + +func hello() {} + +func goodbye() {} +`, + expectedDiff: "function_declaration goodbye: added", + }, + { + name: "function removed", + base: `package main + +func hello() {} + +func goodbye() {} +`, + head: `package main + +func hello() {} +`, + expectedDiff: "function_declaration goodbye: removed", + }, + { + name: "function modified", + base: `package main + +func hello() { + fmt.Println("hello") +} +`, + head: `package main + +func hello() { + fmt.Println("world") +} +`, + expectedDiff: "function_declaration hello: modified", + }, + { + name: "function reorder only", + base: `package main + +func a() {} + +func b() {} +`, + head: `package main + +func b() {} + +func a() {} +`, + expectedDiff: "no structural changes detected", + }, + { + name: "method with receiver", + base: `package main + +type Server struct{} + +func (s *Server) Start() {} +`, + head: `package main + +type Server struct{} + +func (s *Server) Start() { + fmt.Println("starting") +} +`, + expectedDiff: "(*Server).Start: modified", + }, + { + name: "type added", + base: `package main +`, + head: `package main + +type Config struct { + Host string +} +`, + expectedDiff: "type_declaration Config: added", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := SemanticDiff("main.go", []byte(tc.base), []byte(tc.head)) + require.Equal(t, DiffFormatStructural, result.Format) + assert.Contains(t, result.Diff, tc.expectedDiff) + if tc.notContains != "" { + assert.NotContains(t, result.Diff, tc.notContains) + } + }) + } +} + +func TestStructuralDiffPython(t *testing.T) { + tests := []struct { + name string + base string + head string + expectedDiff string + }{ + { + name: "function added", + base: `def hello(): + pass +`, + head: `def hello(): + pass + +def goodbye(): + pass +`, + expectedDiff: "function_definition goodbye: added", + }, + { + name: "class modified", + base: `class Foo: + def bar(self): + return 1 +`, + head: `class Foo: + def bar(self): + return 2 +`, + expectedDiff: "class_definition Foo: modified", + }, + { + name: "function reorder", + base: `def a(): + pass + +def b(): + pass +`, + head: `def b(): + pass + +def a(): + pass +`, + expectedDiff: "no structural changes detected", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := SemanticDiff("app.py", []byte(tc.base), []byte(tc.head)) + require.Equal(t, DiffFormatStructural, result.Format) + assert.Contains(t, result.Diff, tc.expectedDiff) + }) + } +} + +func TestStructuralDiffJavaScript(t *testing.T) { + tests := []struct { + name string + path string + base string + head string + expectedDiff string + }{ + { + name: "function added", + path: "app.js", + base: `function hello() { + console.log("hello"); +} +`, + head: `function hello() { + console.log("hello"); +} + +function goodbye() { + console.log("goodbye"); +} +`, + expectedDiff: "function_declaration goodbye: added", + }, + { + name: "const variable modified", + path: "config.js", + base: `const PORT = 3000; +`, + head: `const PORT = 8080; +`, + expectedDiff: "lexical_declaration PORT: modified", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := SemanticDiff(tc.path, []byte(tc.base), []byte(tc.head)) + require.Equal(t, DiffFormatStructural, result.Format) + assert.Contains(t, result.Diff, tc.expectedDiff) + }) + } +} + +func TestStructuralDiffTypeScript(t *testing.T) { + tests := []struct { + name string + path string + base string + head string + expectedDiff string + }{ + { + name: "interface added", + path: "types.ts", + base: `interface User { + name: string; +} +`, + head: `interface User { + name: string; +} + +interface Admin { + role: string; +} +`, + expectedDiff: "interface_declaration Admin: added", + }, + { + name: "TSX component modified", + path: "App.tsx", + base: `function App() { + return
Hello
; +} +`, + head: `function App() { + return
World
; +} +`, + expectedDiff: "function_declaration App: modified", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := SemanticDiff(tc.path, []byte(tc.base), []byte(tc.head)) + require.Equal(t, DiffFormatStructural, result.Format) + assert.Contains(t, result.Diff, tc.expectedDiff) + }) + } +} + +func TestStructuralDiffRust(t *testing.T) { + result := SemanticDiff("lib.rs", []byte(`fn hello() {} +`), []byte(`fn hello() {} + +fn goodbye() {} +`)) + require.Equal(t, DiffFormatStructural, result.Format) + assert.Contains(t, result.Diff, "function_item goodbye: added") +} + +func TestStructuralDiffJava(t *testing.T) { + result := SemanticDiff("Main.java", + []byte(`public class Main { + public static void main(String[] args) {} +} +`), + []byte(`public class Main { + public static void main(String[] args) { + System.out.println("hello"); + } +} +`)) + require.Equal(t, DiffFormatStructural, result.Format) + assert.Contains(t, result.Diff, "class_declaration Main: modified") +} + +func TestStructuralDiffC(t *testing.T) { + result := SemanticDiff("main.c", + []byte(`#include + +int main() { + return 0; +} +`), + []byte(`#include + +int main() { + printf("hello\n"); + return 0; +} +`)) + require.Equal(t, DiffFormatStructural, result.Format) + assert.Contains(t, result.Diff, "main: modified") +} + +func TestStructuralDiffRuby(t *testing.T) { + result := SemanticDiff("app.rb", + []byte(`def hello + puts "hello" +end +`), + []byte(`def hello + puts "hello" +end + +def goodbye + puts "goodbye" +end +`)) + require.Equal(t, DiffFormatStructural, result.Format) + assert.Contains(t, result.Diff, "method goodbye: added") +} + +func TestStructuralDiffUnsupportedFallback(t *testing.T) { + // .txt files have no tree-sitter grammar, should fall back to unified + result := SemanticDiff("notes.txt", []byte("hello\n"), []byte("world\n")) + assert.Equal(t, DiffFormatUnified, result.Format) + assert.Contains(t, result.Diff, "--- a/notes.txt") +} + +func TestLanguageForPath(t *testing.T) { + supported := []string{ + "main.go", "app.py", "index.js", "index.mjs", + "app.ts", "App.tsx", "App.jsx", + "lib.rs", "Main.java", "main.c", "main.h", + "main.cpp", "main.hpp", "main.cc", + "app.rb", + } + for _, path := range supported { + t.Run(path, func(t *testing.T) { + assert.NotNil(t, languageForPath(path), "expected language config for %s", path) + }) + } + + unsupported := []string{ + "config.json", "data.yaml", "notes.txt", "Makefile", "README.md", + } + for _, path := range unsupported { + t.Run(path, func(t *testing.T) { + assert.Nil(t, languageForPath(path), "expected no language config for %s", path) + }) + } +} + +func TestDetectDiffFormatStructural(t *testing.T) { + assert.Equal(t, DiffFormatStructural, DetectDiffFormat("main.go")) + assert.Equal(t, DiffFormatStructural, DetectDiffFormat("app.py")) + assert.Equal(t, DiffFormatStructural, DetectDiffFormat("index.js")) + assert.Equal(t, DiffFormatJSON, DetectDiffFormat("config.json")) + assert.Equal(t, DiffFormatUnified, DetectDiffFormat("notes.txt")) +}