Skip to content

Commit db10882

Browse files
Add symbol extraction to get_file_contents tool
Adds an optional 'symbol' parameter to get_file_contents that uses tree-sitter to extract a specific named symbol (function, class, type, method, etc.) from a file. Instead of returning the entire file, only the matching symbol's source code is returned. Supports all languages from the structural diff engine: Go, Python, JavaScript, TypeScript, Ruby, Rust, Java, C/C++. For unsupported file types, returns an error suggesting the feature is not available. If the symbol is not found, the error message includes a list of available symbols in the file to help the model self-correct. This pairs well with the structural diff tool — a model can see which symbols changed via compare_file_contents, then fetch specific symbols via get_file_contents to examine them in detail.
1 parent d8ed060 commit db10882

File tree

4 files changed

+219
-1
lines changed

4 files changed

+219
-1
lines changed

pkg/github/__toolsnaps__/get_file_contents.snap

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
"sha": {
2727
"description": "Accepts optional commit SHA. If specified, it will be used instead of ref",
2828
"type": "string"
29+
},
30+
"symbol": {
31+
"description": "Optional: extract a specific symbol (function, class, type, etc.) from the file. For supported languages, returns only the symbol's source code instead of the entire file. If the symbol is not found, returns a list of available symbols.",
32+
"type": "string"
2933
}
3034
},
3135
"required": [

pkg/github/repositories.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,10 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
652652
Type: "string",
653653
Description: "Accepts optional commit SHA. If specified, it will be used instead of ref",
654654
},
655+
"symbol": {
656+
Type: "string",
657+
Description: "Optional: extract a specific symbol (function, class, type, etc.) from the file. For supported languages, returns only the symbol's source code instead of the entire file. If the symbol is not found, returns a list of available symbols.",
658+
},
655659
},
656660
Required: []string{"owner", "repo"},
657661
},
@@ -684,6 +688,11 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
684688
return utils.NewToolResultError(err.Error()), nil, nil
685689
}
686690

691+
symbol, err := OptionalParam[string](args, "symbol")
692+
if err != nil {
693+
return utils.NewToolResultError(err.Error()), nil, nil
694+
}
695+
687696
client, err := deps.GetClient(ctx)
688697
if err != nil {
689698
return utils.NewToolResultError("failed to get GitHub client"), nil, nil
@@ -769,9 +778,31 @@ func GetFileContents(t translations.TranslationHelperFunc) inventory.ServerTool
769778
strings.HasSuffix(contentType, "+xml")
770779

771780
if isTextContent {
781+
content := string(body)
782+
783+
// If a symbol was requested, extract just that symbol
784+
if symbol != "" {
785+
symbolText, symbolKind, extractErr := ExtractSymbol(path, body, symbol)
786+
if extractErr != nil {
787+
return utils.NewToolResultError(extractErr.Error()), nil, nil
788+
}
789+
content = symbolText
790+
successMsg := fmt.Sprintf("extracted %s %q from %s", symbolKind, symbol, path)
791+
if fileSHA != "" {
792+
successMsg += fmt.Sprintf(" (SHA: %s)", fileSHA)
793+
}
794+
successMsg += successNote
795+
result := &mcp.ResourceContents{
796+
URI: resourceURI,
797+
Text: content,
798+
MIMEType: contentType,
799+
}
800+
return utils.NewToolResultResource(successMsg, result), nil, nil
801+
}
802+
772803
result := &mcp.ResourceContents{
773804
URI: resourceURI,
774-
Text: string(body),
805+
Text: content,
775806
MIMEType: contentType,
776807
}
777808
// Include SHA in the result metadata

pkg/github/symbol_extraction.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package github
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
)
7+
8+
// ExtractSymbol searches source code for a named symbol and returns its text.
9+
// It searches top-level declarations first, then recursively searches nested
10+
// declarations (e.g. methods inside classes). Returns the symbol text and its
11+
// kind, or an error if the symbol is not found or the language is unsupported.
12+
func ExtractSymbol(path string, source []byte, symbolName string) (text string, kind string, err error) {
13+
config := languageForPath(path)
14+
if config == nil {
15+
return "", "", fmt.Errorf("symbol extraction is not supported for this file type")
16+
}
17+
18+
decls, err := extractDeclarations(config, source)
19+
if err != nil {
20+
return "", "", fmt.Errorf("failed to parse file: %w", err)
21+
}
22+
23+
// Search top-level declarations
24+
if text, kind, found := findSymbol(decls, symbolName); found {
25+
return text, kind, nil
26+
}
27+
28+
// Search nested declarations (methods inside classes, etc.)
29+
for _, decl := range decls {
30+
nested := extractChildDeclarationsFromText(config, decl.Text)
31+
if text, kind, found := findSymbol(nested, symbolName); found {
32+
return text, kind, nil
33+
}
34+
}
35+
36+
// Build list of available symbols for the error message
37+
available := listSymbolNames(config, decls)
38+
return "", "", fmt.Errorf("symbol %q not found. Available symbols: %s", symbolName, strings.Join(available, ", "))
39+
}
40+
41+
// findSymbol searches a slice of declarations for a matching name.
42+
func findSymbol(decls []declaration, name string) (string, string, bool) {
43+
for _, d := range decls {
44+
if d.Name == name {
45+
return d.Text, d.Kind, true
46+
}
47+
}
48+
return "", "", false
49+
}
50+
51+
// listSymbolNames returns all symbol names from top-level and one level of
52+
// nested declarations, for use in error messages.
53+
func listSymbolNames(config *languageConfig, decls []declaration) []string {
54+
var names []string
55+
for _, d := range decls {
56+
if !strings.HasPrefix(d.Name, "_") {
57+
names = append(names, d.Name)
58+
}
59+
nested := extractChildDeclarationsFromText(config, d.Text)
60+
for _, n := range nested {
61+
if !strings.HasPrefix(n.Name, "_") {
62+
names = append(names, n.Name)
63+
}
64+
}
65+
}
66+
return names
67+
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package github
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestExtractSymbol(t *testing.T) {
11+
t.Run("Go function", func(t *testing.T) {
12+
source := []byte("package main\n\nfunc hello() {\n\tfmt.Println(\"hello\")\n}\n\nfunc world() {\n\tfmt.Println(\"world\")\n}\n")
13+
text, kind, err := ExtractSymbol("main.go", source, "hello")
14+
require.NoError(t, err)
15+
assert.Equal(t, "function_declaration", kind)
16+
assert.Contains(t, text, "func hello()")
17+
assert.Contains(t, text, "hello")
18+
assert.NotContains(t, text, "world")
19+
})
20+
21+
t.Run("Go method with receiver", func(t *testing.T) {
22+
source := []byte("package main\n\ntype Server struct{}\n\nfunc (s *Server) Start() {\n\tlog.Println(\"start\")\n}\n\nfunc (s *Server) Stop() {\n\tlog.Println(\"stop\")\n}\n")
23+
text, kind, err := ExtractSymbol("main.go", source, "(*Server).Start")
24+
require.NoError(t, err)
25+
assert.Equal(t, "method_declaration", kind)
26+
assert.Contains(t, text, "Start")
27+
assert.NotContains(t, text, "Stop")
28+
})
29+
30+
t.Run("Go type", func(t *testing.T) {
31+
source := []byte("package main\n\ntype Config struct {\n\tHost string\n\tPort int\n}\n")
32+
text, kind, err := ExtractSymbol("main.go", source, "Config")
33+
require.NoError(t, err)
34+
assert.Equal(t, "type_declaration", kind)
35+
assert.Contains(t, text, "Host string")
36+
})
37+
38+
t.Run("Python function", func(t *testing.T) {
39+
source := []byte("def hello():\n print('hello')\n\ndef world():\n print('world')\n")
40+
text, kind, err := ExtractSymbol("app.py", source, "hello")
41+
require.NoError(t, err)
42+
assert.Equal(t, "function_definition", kind)
43+
assert.Contains(t, text, "print('hello')")
44+
assert.NotContains(t, text, "world")
45+
})
46+
47+
t.Run("Python class method (nested)", func(t *testing.T) {
48+
source := []byte("class Dog:\n def bark(self):\n return 'woof'\n def fetch(self):\n return 'ball'\n")
49+
text, kind, err := ExtractSymbol("app.py", source, "bark")
50+
require.NoError(t, err)
51+
assert.Equal(t, "function_definition", kind)
52+
assert.Contains(t, text, "woof")
53+
assert.NotContains(t, text, "ball")
54+
})
55+
56+
t.Run("TypeScript class", func(t *testing.T) {
57+
source := []byte("class Api {\n get() {\n return fetch('/data');\n }\n}\n\nfunction helper() { return 1; }\n")
58+
text, kind, err := ExtractSymbol("api.ts", source, "Api")
59+
require.NoError(t, err)
60+
assert.Equal(t, "class_declaration", kind)
61+
assert.Contains(t, text, "get()")
62+
assert.NotContains(t, text, "helper")
63+
})
64+
65+
t.Run("TypeScript class method (nested)", func(t *testing.T) {
66+
source := []byte("class Api {\n get() {\n return fetch('/data');\n }\n post() {\n return fetch('/post');\n }\n}\n")
67+
text, kind, err := ExtractSymbol("api.ts", source, "get")
68+
require.NoError(t, err)
69+
assert.Equal(t, "method_definition", kind)
70+
assert.Contains(t, text, "/data")
71+
assert.NotContains(t, text, "/post")
72+
})
73+
74+
t.Run("symbol not found lists available", func(t *testing.T) {
75+
source := []byte("package main\n\nfunc hello() {}\n\nfunc world() {}\n")
76+
_, _, err := ExtractSymbol("main.go", source, "nonexistent")
77+
require.Error(t, err)
78+
assert.Contains(t, err.Error(), "not found")
79+
assert.Contains(t, err.Error(), "hello")
80+
assert.Contains(t, err.Error(), "world")
81+
})
82+
83+
t.Run("unsupported file type", func(t *testing.T) {
84+
source := []byte("some content")
85+
_, _, err := ExtractSymbol("README.md", source, "anything")
86+
require.Error(t, err)
87+
assert.Contains(t, err.Error(), "not supported")
88+
})
89+
90+
t.Run("Java class with methods", func(t *testing.T) {
91+
source := []byte("class Calculator {\n int add(int a, int b) {\n return a + b;\n }\n int multiply(int a, int b) {\n return a * b;\n }\n}\n")
92+
text, kind, err := ExtractSymbol("Calculator.java", source, "add")
93+
require.NoError(t, err)
94+
assert.Equal(t, "method_declaration", kind)
95+
assert.Contains(t, text, "a + b")
96+
assert.NotContains(t, text, "a * b")
97+
})
98+
99+
t.Run("Rust function", func(t *testing.T) {
100+
source := []byte("fn hello() {\n println!(\"hello\");\n}\n\nfn world() {\n println!(\"world\");\n}\n")
101+
text, kind, err := ExtractSymbol("main.rs", source, "hello")
102+
require.NoError(t, err)
103+
assert.Equal(t, "function_item", kind)
104+
assert.Contains(t, text, "hello")
105+
assert.NotContains(t, text, "world")
106+
})
107+
108+
t.Run("Go var declaration", func(t *testing.T) {
109+
source := []byte("package main\n\nvar defaultTimeout = 30\n\nvar maxRetries = 3\n")
110+
text, kind, err := ExtractSymbol("main.go", source, "defaultTimeout")
111+
require.NoError(t, err)
112+
assert.Equal(t, "var_declaration", kind)
113+
assert.Contains(t, text, "30")
114+
assert.NotContains(t, text, "maxRetries")
115+
})
116+
}

0 commit comments

Comments
 (0)