From 3767bfaa957e9938394f47b30872b14601e928d3 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Mon, 17 Mar 2025 14:21:28 -0700 Subject: [PATCH 1/3] Fix bugs in search tool --- src/codegen/extensions/tools/search.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/codegen/extensions/tools/search.py b/src/codegen/extensions/tools/search.py index 47df419c7..2a347c133 100644 --- a/src/codegen/extensions/tools/search.py +++ b/src/codegen/extensions/tools/search.py @@ -5,10 +5,11 @@ Results are paginated with a default of 10 files per page. """ +import logging import os import re import subprocess -from typing import ClassVar, Optional +from typing import ClassVar from pydantic import Field @@ -16,6 +17,8 @@ from .observation import Observation +logger = logging.getLogger(__name__) + class SearchMatch(Observation): """Information about a single line match.""" @@ -114,7 +117,7 @@ def render(self) -> str: def _search_with_ripgrep( codebase: Codebase, query: str, - file_extensions: Optional[list[str]] = None, + file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False, @@ -136,10 +139,10 @@ def _search_with_ripgrep( for ext in file_extensions: # Remove leading dot if present ext = ext[1:] if ext.startswith(".") else ext - cmd.extend(["--type-add", f"custom:{ext}", "--type", "custom"]) + cmd.extend(["--type-add", f"custom:*.{ext}", "--type", "custom"]) # Add target directories if specified - search_path = codebase.repo_path + search_path = str(codebase.repo_path) # Add the query and path cmd.append(f"{query}") @@ -147,6 +150,7 @@ def _search_with_ripgrep( # Run ripgrep try: + logger.info(f"Running ripgrep command: {' '.join(cmd)}") # Use text mode and UTF-8 encoding result = subprocess.run( cmd, @@ -256,7 +260,7 @@ def _search_with_ripgrep( def _search_with_python( codebase: Codebase, query: str, - file_extensions: Optional[list[str]] = None, + file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False, @@ -353,7 +357,7 @@ def _search_with_python( def search( codebase: Codebase, query: str, - file_extensions: Optional[list[str]] = None, + file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False, From 652cf8c6eeb7198055e077ecd56893fa954d1d69 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Mon, 17 Mar 2025 14:22:09 -0700 Subject: [PATCH 2/3] Add global replacement tool --- src/codegen/extensions/langchain/agent.py | 18 ++- src/codegen/extensions/langchain/tools.py | 94 ++++++++++--- src/codegen/extensions/tools/__init__.py | 2 + .../tools/global_replacement_edit.py | 129 ++++++++++++++++++ tests/unit/codegen/extensions/test_tools.py | 66 +++++++++ 5 files changed, 285 insertions(+), 24 deletions(-) create mode 100644 src/codegen/extensions/tools/global_replacement_edit.py diff --git a/src/codegen/extensions/langchain/agent.py b/src/codegen/extensions/langchain/agent.py index 8917daa7f..167aa3128 100644 --- a/src/codegen/extensions/langchain/agent.py +++ b/src/codegen/extensions/langchain/agent.py @@ -1,6 +1,6 @@ """Demo implementation of an agent with Codegen tools.""" -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from langchain.tools import BaseTool from langchain_core.messages import SystemMessage @@ -13,6 +13,7 @@ from codegen.extensions.langchain.tools import ( CreateFileTool, DeleteFileTool, + GlobalReplacementEditTool, ListDirectoryTool, MoveSymbolTool, ReflectionTool, @@ -20,6 +21,7 @@ RenameFileTool, ReplacementEditTool, RevealSymbolTool, + SearchFilesByNameTool, SearchTool, # SemanticEditTool, ViewFileTool, @@ -38,8 +40,8 @@ def create_codebase_agent( system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE), memory: bool = True, debug: bool = False, - additional_tools: Optional[list[BaseTool]] = None, - config: Optional[AgentConfig] = None, + additional_tools: list[BaseTool] | None = None, + config: AgentConfig | None = None, **kwargs, ) -> CompiledGraph: """Create an agent with all codebase tools. @@ -76,6 +78,8 @@ def create_codebase_agent( ReplacementEditTool(codebase), RelaceEditTool(codebase), ReflectionTool(codebase), + SearchFilesByNameTool(codebase), + GlobalReplacementEditTool(codebase), # SemanticSearchTool(codebase), # =====[ Github Integration ]===== # Enable Github integration @@ -101,8 +105,8 @@ def create_chat_agent( system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE), memory: bool = True, debug: bool = False, - additional_tools: Optional[list[BaseTool]] = None, - config: Optional[dict[str, Any]] = None, # over here you can pass in the max length of the number of messages + additional_tools: list[BaseTool] | None = None, + config: dict[str, Any] | None = None, # over here you can pass in the max length of the number of messages **kwargs, ) -> CompiledGraph: """Create an agent with all codebase tools. @@ -151,7 +155,7 @@ def create_codebase_inspector_agent( system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE), memory: bool = True, debug: bool = True, - config: Optional[dict[str, Any]] = None, + config: dict[str, Any] | None = None, **kwargs, ) -> CompiledGraph: """Create an inspector agent with read-only codebase tools. @@ -189,7 +193,7 @@ def create_agent_with_tools( system_message: SystemMessage = SystemMessage(REASONER_SYSTEM_MESSAGE), memory: bool = True, debug: bool = True, - config: Optional[dict[str, Any]] = None, + config: dict[str, Any] | None = None, **kwargs, ) -> CompiledGraph: """Create an agent with a specific set of tools. diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 3ae1d6ae9..b0d314683 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -1,6 +1,7 @@ """Langchain tools for workspace operations.""" -from typing import Callable, ClassVar, Literal, Optional +from collections.abc import Callable +from typing import ClassVar, Literal from langchain_core.tools.base import BaseTool from pydantic import BaseModel, Field @@ -9,6 +10,7 @@ from codegen.extensions.tools.bash import run_bash_command from codegen.extensions.tools.github.checkout_pr import checkout_pr from codegen.extensions.tools.github.view_pr_checks import view_pr_checks +from codegen.extensions.tools.global_replacement_edit import replacement_edit_global from codegen.extensions.tools.linear.linear import ( linear_comment_on_issue_tool, linear_create_issue_tool, @@ -50,10 +52,10 @@ class ViewFileInput(BaseModel): """Input for viewing a file.""" filepath: str = Field(..., description="Path to the file relative to workspace root") - start_line: Optional[int] = Field(None, description="Starting line number to view (1-indexed, inclusive)") - end_line: Optional[int] = Field(None, description="Ending line number to view (1-indexed, inclusive)") - max_lines: Optional[int] = Field(None, description="Maximum number of lines to view at once, defaults to 250") - line_numbers: Optional[bool] = Field(True, description="If True, add line numbers to the content (1-indexed)") + start_line: int | None = Field(None, description="Starting line number to view (1-indexed, inclusive)") + end_line: int | None = Field(None, description="Ending line number to view (1-indexed, inclusive)") + max_lines: int | None = Field(None, description="Maximum number of lines to view at once, defaults to 250") + line_numbers: bool | None = Field(True, description="If True, add line numbers to the content (1-indexed)") class ViewFileTool(BaseTool): @@ -72,10 +74,10 @@ def __init__(self, codebase: Codebase) -> None: def _run( self, filepath: str, - start_line: Optional[int] = None, - end_line: Optional[int] = None, - max_lines: Optional[int] = None, - line_numbers: Optional[bool] = True, + start_line: int | None = None, + end_line: int | None = None, + max_lines: int | None = None, + line_numbers: bool | None = True, ) -> str: result = view_file( self.codebase, @@ -120,7 +122,7 @@ class SearchInput(BaseModel): description="""The search query to find in the codebase. When ripgrep is available, this will be passed as a ripgrep pattern. For regex searches, set use_regex=True. Ripgrep is the preferred method.""", ) - file_extensions: Optional[list[str]] = Field(default=None, description="Optional list of file extensions to search (e.g. ['.py', '.ts'])") + file_extensions: list[str] | None = Field(default=None, description="Optional list of file extensions to search (e.g. ['.py', '.ts'])") page: int = Field(default=1, description="Page number to return (1-based, default: 1)") files_per_page: int = Field(default=10, description="Number of files to return per page (default: 10)") use_regex: bool = Field(default=False, description="Whether to treat query as a regex pattern (default: False)") @@ -137,7 +139,7 @@ class SearchTool(BaseTool): def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) - def _run(self, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str: + def _run(self, query: str, file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str: result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex) return result.render() @@ -273,7 +275,7 @@ class RevealSymbolInput(BaseModel): symbol_name: str = Field(..., description="Name of the symbol to analyze") degree: int = Field(default=1, description="How many degrees of separation to traverse") - max_tokens: Optional[int] = Field( + max_tokens: int | None = Field( default=None, description="Optional maximum number of tokens for all source code combined", ) @@ -296,7 +298,7 @@ def _run( self, symbol_name: str, degree: int = 1, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, collect_dependencies: bool = True, collect_usages: bool = True, ) -> str: @@ -849,8 +851,10 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: RenameFileTool(codebase), ReplacementEditTool(codebase), RevealSymbolTool(codebase), + GlobalReplacementEditTool(codebase), RunBashCommandTool(), # Note: This tool doesn't need the codebase SearchTool(codebase), + SearchFilesByNameTool(codebase), # SemanticEditTool(codebase), # SemanticSearchTool(codebase), ViewFileTool(codebase), @@ -872,6 +876,62 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: ] +class GlobalReplacementEditInput(BaseModel): + """Input for replacement editing across the entire codebase.""" + + file_pattern: str = Field( + default="*", + description=("Glob pattern to match files that should be edited. Supports all Python glob syntax including wildcards (*, ?, **)"), + ) + pattern: str = Field( + ..., + description=( + "Regular expression pattern to match text that should be replaced. " + "Supports all Python regex syntax including capture groups (\\1, \\2, etc). " + "The pattern is compiled with re.MULTILINE flag by default." + ), + ) + replacement: str = Field( + ..., + description=( + "Text to replace matched patterns with. Can reference regex capture groups using \\1, \\2, etc. If using regex groups in pattern, make sure to preserve them in replacement if needed." + ), + ) + count: int | None = Field( + default=None, + description=( + "Maximum number of replacements to make. " + "Use None to replace all occurrences (default), or specify a number to limit replacements. " + "Useful when you only want to replace the first N occurrences." + ), + ) + + +class GlobalReplacementEditTool(BaseTool): + """Tool for regex-based replacement editing of files across the entire codebase. + + Use this to make a change across an entire codebase if you have a regex pattern that matches the text you want to replace and are trying to edit a large number of files. + """ + + name: ClassVar[str] = "global_replace" + description: ClassVar[str] = "Replace text in the entire codebase using regex pattern matching." + args_schema: ClassVar[type[BaseModel]] = GlobalReplacementEditInput + codebase: Codebase = Field(exclude=True) + + def __init__(self, codebase: Codebase) -> None: + super().__init__(codebase=codebase) + + def _run( + self, + file_pattern: str, + pattern: str, + replacement: str, + count: int | None = None, + ) -> str: + result = replacement_edit_global(self.codebase, file_pattern, pattern, replacement, count) + return result.render() + + class ReplacementEditInput(BaseModel): """Input for replacement editing.""" @@ -905,7 +965,7 @@ class ReplacementEditInput(BaseModel): "Default is -1 (end of file)." ), ) - count: Optional[int] = Field( + count: int | None = Field( default=None, description=( "Maximum number of replacements to make. " @@ -933,7 +993,7 @@ def _run( replacement: str, start: int = 1, end: int = -1, - count: Optional[int] = None, + count: int | None = None, ) -> str: result = replacement_edit( self.codebase, @@ -997,7 +1057,7 @@ class ReflectionInput(BaseModel): context_summary: str = Field(..., description="Summary of the current context and problem being solved") findings_so_far: str = Field(..., description="Key information and insights gathered so far") current_challenges: str = Field(default="", description="Current obstacles or questions that need to be addressed") - reflection_focus: Optional[str] = Field(default=None, description="Optional specific aspect to focus reflection on (e.g., 'architecture', 'performance', 'next steps')") + reflection_focus: str | None = Field(default=None, description="Optional specific aspect to focus reflection on (e.g., 'architecture', 'performance', 'next steps')") class ReflectionTool(BaseTool): @@ -1020,7 +1080,7 @@ def _run( context_summary: str, findings_so_far: str, current_challenges: str = "", - reflection_focus: Optional[str] = None, + reflection_focus: str | None = None, ) -> str: result = perform_reflection(context_summary=context_summary, findings_so_far=findings_so_far, current_challenges=current_challenges, reflection_focus=reflection_focus, codebase=self.codebase) diff --git a/src/codegen/extensions/tools/__init__.py b/src/codegen/extensions/tools/__init__.py index 44305e61a..33e864e39 100644 --- a/src/codegen/extensions/tools/__init__.py +++ b/src/codegen/extensions/tools/__init__.py @@ -8,6 +8,7 @@ from .github.create_pr_comment import create_pr_comment from .github.create_pr_review_comment import create_pr_review_comment from .github.view_pr import view_pr +from .global_replacement_edit import replacement_edit_global from .linear import ( linear_comment_on_issue_tool, linear_get_issue_comments_tool, @@ -49,6 +50,7 @@ "perform_reflection", "rename_file", "replacement_edit", + "replacement_edit_global", "reveal_symbol", "run_codemod", # Search operations diff --git a/src/codegen/extensions/tools/global_replacement_edit.py b/src/codegen/extensions/tools/global_replacement_edit.py new file mode 100644 index 000000000..4717b2f58 --- /dev/null +++ b/src/codegen/extensions/tools/global_replacement_edit.py @@ -0,0 +1,129 @@ +"""Tool for making regex-based replacements in files.""" + +import difflib +import logging +import re +from typing import ClassVar + +from pydantic import Field + +from codegen.extensions.tools.search_files_by_name import search_files_by_name +from codegen.sdk.core.codebase import Codebase + +from .observation import Observation + +logger = logging.getLogger(__name__) + + +class GlobalReplacementEditObservation(Observation): + """Response from making regex-based replacements in a file.""" + + diff: str | None = Field( + default=None, + description="Unified diff showing the changes made. Only the first 5 file's changes are shown.", + ) + message: str | None = Field( + default=None, + description="Message describing the result", + ) + error: str | None = Field( + default=None, + description="Error message if an error occurred", + ) + error_pattern: str | None = Field( + default=None, + description="Regex pattern that failed to compile", + ) + + str_template: ClassVar[str] = "{message}" if "{message}" else "Edited file {filepath}" + + +def generate_diff(original: str, modified: str, path: str) -> str: + """Generate a unified diff between two strings. + + Args: + original: Original content + modified: Modified content + + Returns: + Unified diff as a string + """ + original_lines = original.splitlines(keepends=True) + modified_lines = modified.splitlines(keepends=True) + + diff = difflib.unified_diff( + original_lines, + modified_lines, + fromfile=path, + tofile=path, + lineterm="", + ) + + return "".join(diff) + + +def replacement_edit_global( + codebase: Codebase, + file_pattern: str, + pattern: str, + replacement: str, + count: int | None = None, + flags: re.RegexFlag = re.MULTILINE, +) -> GlobalReplacementEditObservation: + """Replace text in a file using regex pattern matching. + + Args: + codebase: The codebase to operate on + file_pattern: Glob pattern to match files + pattern: Regex pattern to match + replacement: Replacement text (can include regex groups) + count: Maximum number of replacements (None for all) + flags: Regex flags (default: re.MULTILINE) + + Returns: + GlobalReplacementEditObservation containing edit results and status + + Raises: + FileNotFoundError: If file not found + ValueError: If invalid regex pattern + """ + logger.info(f"Replacing text in files matching {file_pattern} using regex pattern {pattern}") + + if count == 0: + count = None + try: + # Compile pattern for better error messages + regex = re.compile(pattern, flags) + except re.error as e: + return GlobalReplacementEditObservation( + status="error", + error=f"Invalid regex pattern: {e!s}", + error_pattern=pattern, + message="Invalid regex pattern", + ) + + diffs = [] + for file in search_files_by_name(codebase, file_pattern).files: + if count is not None and count <= 0: + break + try: + file = codebase.get_file(file) + except ValueError: + msg = f"File not found: {file}" + raise FileNotFoundError(msg) + content = file.content + new_content, n = regex.subn(replacement, content, count=(count or 0)) + if count is not None: + count -= n + if n > 0: + file.edit(new_content) + if new_content != content: + diff = generate_diff(content, new_content, file.filepath) + diffs.append(diff) + diff = "\n".join(diffs[:5]) + codebase.commit() + return GlobalReplacementEditObservation( + status="success", + diff=diff, + message=f"Successfully replaced text in files matching {file_pattern} using regex pattern {pattern}", + ) diff --git a/tests/unit/codegen/extensions/test_tools.py b/tests/unit/codegen/extensions/test_tools.py index 006d4f1e3..aa90e1be4 100644 --- a/tests/unit/codegen/extensions/test_tools.py +++ b/tests/unit/codegen/extensions/test_tools.py @@ -13,6 +13,7 @@ move_symbol, rename_file, replacement_edit, + replacement_edit_global, reveal_symbol, run_codemod, search_files_by_name, @@ -431,6 +432,71 @@ def test_replacement_edit(codebase): assert "No matches found" in str(result) +def test_replacement_edit_global(codebase): + """Test global regex-based replacement editing.""" + # Create additional test file + create_file( + codebase, + filepath="src/other.py", + content=""" +def hello(): + print("Hello, world!") + +def greet(): + print("Hello!") +""", + ) + codebase.commit() # Commit the new file so it can be found + + # List directory to debug + print("Directory contents:") + print(list_directory(codebase, "src")) + + # Test basic global replacement across files + result = replacement_edit_global( + codebase, + file_pattern="src/*.py", + pattern=r'print\("Hello.*?"\)', + replacement='print("Goodbye!")', + ) + print(f"Found files: {search_files_by_name(codebase, 'src/*.py').files}") # Debug print + assert result.status == "success" + assert result.diff # Should have modified both files + assert 'print("Goodbye!")' in result.diff + + # Test with count limit + result = replacement_edit_global( + codebase, + file_pattern="src/*.py", + pattern=r"def", + replacement="async def", + count=1, # Only replace first occurrence in each file + ) + assert result.status == "success" + assert result.diff # Should have modified both files + + # Test invalid regex pattern + result = replacement_edit_global( + codebase, + file_pattern="src/*.py", + pattern=r"[invalid", # Invalid regex pattern + replacement="replacement", + ) + assert result.status == "error" + assert result.error_pattern == "[invalid" + assert "Invalid regex pattern" in result.message + + # Test no matches + result = replacement_edit_global( + codebase, + file_pattern="src/*.py", + pattern=r"nonexistent_pattern", + replacement="replacement", + ) + assert result.status == "success" + assert not result.diff # Should be empty since no files were modified + + def test_run_codemod(codebase): """Test running custom codemods.""" # Test adding type hints From 5c9e50a3d8e7c107f95d21e8f7a02c04aaa4cb1d Mon Sep 17 00:00:00 2001 From: bagel897 Date: Mon, 17 Mar 2025 14:37:35 -0700 Subject: [PATCH 3/3] fix test --- src/codegen/extensions/langchain/tools.py | 6 ++++-- tests/unit/codegen/extensions/test_tools.py | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index b0d314683..a1697d63e 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -1102,6 +1102,8 @@ class SearchFilesByNameTool(BaseTool): - Find specific file types (e.g., '*.py', '*.tsx') - Locate configuration files (e.g., 'package.json', 'requirements.txt') - Find files with specific names (e.g., 'README.md', 'Dockerfile') + + Uses fd under the hood """ args_schema: ClassVar[type[BaseModel]] = SearchFilesByNameInput codebase: Codebase = Field(exclude=True) @@ -1109,6 +1111,6 @@ class SearchFilesByNameTool(BaseTool): def __init__(self, codebase: Codebase): super().__init__(codebase=codebase) - def _run(self, pattern: str) -> str: + def _run(self, pattern: str, full_path: bool = False) -> str: """Execute the glob pattern search using fd.""" - return search_files_by_name(self.codebase, pattern).render() + return search_files_by_name(self.codebase, pattern, full_path).render() diff --git a/tests/unit/codegen/extensions/test_tools.py b/tests/unit/codegen/extensions/test_tools.py index aa90e1be4..b7c728d74 100644 --- a/tests/unit/codegen/extensions/test_tools.py +++ b/tests/unit/codegen/extensions/test_tools.py @@ -455,11 +455,11 @@ def greet(): # Test basic global replacement across files result = replacement_edit_global( codebase, - file_pattern="src/*.py", + file_pattern="*.py", pattern=r'print\("Hello.*?"\)', replacement='print("Goodbye!")', ) - print(f"Found files: {search_files_by_name(codebase, 'src/*.py').files}") # Debug print + print(f"Found files: {search_files_by_name(codebase, '*.py').files}") # Debug print assert result.status == "success" assert result.diff # Should have modified both files assert 'print("Goodbye!")' in result.diff @@ -467,7 +467,7 @@ def greet(): # Test with count limit result = replacement_edit_global( codebase, - file_pattern="src/*.py", + file_pattern="*.py", pattern=r"def", replacement="async def", count=1, # Only replace first occurrence in each file @@ -478,7 +478,7 @@ def greet(): # Test invalid regex pattern result = replacement_edit_global( codebase, - file_pattern="src/*.py", + file_pattern="*.py", pattern=r"[invalid", # Invalid regex pattern replacement="replacement", ) @@ -489,7 +489,7 @@ def greet(): # Test no matches result = replacement_edit_global( codebase, - file_pattern="src/*.py", + file_pattern="*.py", pattern=r"nonexistent_pattern", replacement="replacement", )