From a0789536e40d5dd52522eeda57fa4af60ecd3a7f Mon Sep 17 00:00:00 2001 From: codegen-bot Date: Wed, 12 Mar 2025 20:52:17 +0000 Subject: [PATCH 1/2] Add verify=False flag to codebase.create_file --- src/codegen/sdk/core/codebase.py | 88 ++++++-------------------------- 1 file changed, 15 insertions(+), 73 deletions(-) diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index f5b853a6e..38d07d974 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -92,8 +92,7 @@ from codegen.visualizations.visualization_manager import VisualizationManager logger = get_logger(__name__) -MAX_LINES = 10000 # Maximum number of lines of text allowed to be logged - +MAX_LINES = 10000 TSourceFile = TypeVar("TSourceFile", bound="SourceFile", default=SourceFile) TDirectory = TypeVar("TDirectory", bound="Directory", default=Directory) @@ -179,7 +178,6 @@ def __init__( io: IO | None = None, progress: Progress | None = None, ) -> None: - # Sanity check inputs if repo_path is not None and projects is not None: msg = "Cannot specify both repo_path and projects" raise ValueError(msg) @@ -192,11 +190,9 @@ def __init__( msg = "Cannot specify both projects and language. Use ProjectConfig.from_path() to create projects with a custom language." raise ValueError(msg) - # If projects is a single ProjectConfig, convert it to a list if isinstance(projects, ProjectConfig): projects = [projects] - # Initialize project with repo_path if projects is None if repo_path is not None: main_project = ProjectConfig.from_path( repo_path, @@ -206,15 +202,12 @@ def __init__( else: main_project = projects[0] - # Initialize codebase self._op = main_project.repo_operator self.viz = VisualizationManager(op=self._op) self.repo_path = Path(self._op.repo_path) self.ctx = CodebaseContext(projects, config=config, secrets=secrets, io=io, progress=progress) self.console = Console(record=True, soft_wrap=True) - # Assert config assertions - # External import resolution must be enabled if syspath is enabled if self.ctx.config.py_resolve_syspath: if not self.ctx.config.allow_external: msg = "allow_external must be set to True when py_resolve_syspath is enabled" @@ -258,7 +251,7 @@ def github(self) -> RepoOperator: #################################################################################################################### # SIMPLE META #################################################################################################################### - + @property def name(self) -> str: """The name of the repository.""" @@ -278,13 +271,13 @@ def _symbols(self, symbol_type: SymbolType | None = None) -> list[TSymbol | TCla matches: list[Symbol] = self.ctx.get_nodes(NodeType.SYMBOL) return [x for x in matches if x.is_top_level and (symbol_type is None or x.symbol_type == symbol_type)] - # =====[ Node Types ]===== @overload def files(self, *, extensions: list[str]) -> list[File]: ... @overload def files(self, *, extensions: Literal["*"]) -> list[File]: ... @overload def files(self, *, extensions: None = ...) -> list[TSourceFile]: ... + @proxy_property def files(self, *, extensions: list[str] | Literal["*"] | None = None) -> list[TSourceFile] | list[File]: """A list property that returns all files in the codebase. @@ -305,21 +298,17 @@ def files(self, *, extensions: list[str] | Literal["*"] | None = None) -> list[T list[TSourceFile]: A sorted list of source files in the codebase. """ if extensions is None and len(self.ctx.get_nodes(NodeType.FILE)) > 0: - # If extensions is None AND there is at least one file in the codebase (This checks for unsupported languages or parse-off repos), - # Return all source files files = self.ctx.get_nodes(NodeType.FILE) elif isinstance(extensions, str) and extensions != "*": msg = "extensions must be a list of extensions or '*'" raise ValueError(msg) else: files = [] - # Get all files with the specified extensions for filepath, _ in self._op.iter_files( extensions=None if extensions == "*" else extensions, ignore_list=GLOBAL_FILE_IGNORE_LIST, ): files.append(self.get_file(filepath, optional=False)) - # Sort files alphabetically return sort_editables(files, alphabetical=True, dedupe=False) @cached_property @@ -474,41 +463,31 @@ def types(self) -> list[TTypeAlias]: # EXTERNAL API #################################################################################################################### - def create_file(self, filepath: str, content: str = "", sync: bool = True) -> TSourceFile: + def create_file(self, filepath: str, content: str = "", sync: bool = True, verify: bool = True) -> TSourceFile: """Creates a new file in the codebase with specified content. Args: filepath (str): The path where the file should be created. content (str): The content of the file to be created. Defaults to empty string. sync (bool): Whether to sync the graph after creating the file. Defaults to True. + verify (bool): Whether to verify the syntax of the file content. Defaults to True. Returns: File: The newly created file object. Raises: - ValueError: If the provided content cannot be parsed according to the file extension. + ValueError: If the provided content cannot be parsed according to the file extension and verify is True. """ - # Check if file already exists - # TODO: These checks break parse tests ??? - # Look into this! - # if self.has_file(filepath): - # raise ValueError(f"File {filepath} already exists in codebase.") - # if os.path.exists(filepath): - # raise ValueError(f"File {filepath} already exists on disk.") - file_exts = self.ctx.extensions - # Create file as source file if it has a registered extension if any(filepath.endswith(ext) for ext in file_exts): file_cls = self.ctx.node_classes.file_cls - file = file_cls.from_content(filepath, content, self.ctx, sync=sync) - if file is None: + file = file_cls.from_content(filepath, content, self.ctx, sync=sync, verify_syntax=verify) + if file is None and verify: msg = f"Failed to parse file with content {content}. Please make sure the content syntax is valid with respect to the filepath extension." raise ValueError(msg) else: - # Create file as non-source file file = File.from_content(filepath, content, self.ctx, sync=False) - # This is to make sure we keep track of this file for diff purposes uncache_all() return file @@ -541,6 +520,7 @@ def has_file(self, filepath: str, ignore_case: bool = False) -> bool: def get_file(self, filepath: str, *, optional: Literal[False] = ..., ignore_case: bool = ...) -> TSourceFile: ... @overload def get_file(self, filepath: str, *, optional: Literal[True], ignore_case: bool = ...) -> TSourceFile | None: ... + def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool = False) -> TSourceFile | None: """Retrieves a file from the codebase by its filepath. @@ -557,15 +537,12 @@ def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool = Raises: ValueError: If file not found and optional=False. """ - # Try to get the file from the graph first file = self.ctx.get_file(filepath, ignore_case=ignore_case) if file is not None: return file - # If the file is not in the graph, check the filesystem absolute_path = self.ctx.to_absolute(filepath) if self.ctx.io.file_exists(absolute_path): return self.ctx._get_raw_file_from_path(absolute_path) - # If the file is not in the graph, check the filesystem if absolute_path.parent.exists(): for file in absolute_path.parent.iterdir(): if ignore_case and str(absolute_path).lower() == str(file).lower(): @@ -573,7 +550,6 @@ def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool = elif not ignore_case and str(absolute_path) == str(file): return self.ctx._get_raw_file_from_path(file) - # If we get here, the file is not found if not optional: msg = f"File {filepath} not found in codebase. Use optional=True to return None instead." raise ValueError(msg) @@ -605,7 +581,6 @@ def get_directory(self, dir_path: str, optional: bool = False, ignore_case: bool Raises: ValueError: If directory not found and optional=False. """ - # Sanitize the path dir_path = os.path.normpath(dir_path) dir_path = "" if dir_path == "." else dir_path directory = self.ctx.get_directory(self.ctx.to_absolute(dir_path), ignore_case=ignore_case) @@ -745,21 +720,17 @@ def get_relative_path(self, from_file: str, to_file: str) -> str: Returns: str: The relative path from `from_file` to `to_file` (with the extension removed from `to_file`). """ - # Remove extension from the target file to_file = self._remove_extension(to_file) from_parts = from_file.split("/") to_parts = to_file.split("/") - # Find common prefix i = 0 while i < len(from_parts) - 1 and i < len(to_parts) and from_parts[i] == to_parts[i]: i += 1 - # Number of '../' we need up_levels = len(from_parts) - i - 1 - # Construct relative path relative_path = ("../" * up_levels) + "/".join(to_parts[i:]) return relative_path @@ -843,7 +814,7 @@ def reset(self, git_reset: bool = False) -> None: """ logger.info("Resetting codebase ...") if git_reset: - self._op.discard_changes() # Discard any changes made to the raw file state + self._op.discard_changes() self._num_ai_requests = 0 self.reset_logs() self.ctx.undo_applied_diffs() @@ -920,7 +891,7 @@ def get_diffs(self, base: str | None = None) -> list[Diff]: def get_diff(self, base: str | None = None, stage_files: bool = False) -> str: """Produce a single git diff for all files.""" if stage_files: - self._op.git_cli.git.add(A=True) # add all changes to the index so untracked files are included in the diff + self._op.git_cli.git.add(A=True) if base is None: diff = self._op.git_cli.git.diff("HEAD", patch=True, full_index=True) return diff @@ -1062,7 +1033,6 @@ def set_find_mode(self, find_mode: bool) -> None: @noapidoc def set_active_group(self, group: Group) -> None: """Will only fix these flags.""" - # TODO - flesh this out more with Group datatype and GroupBy self.ctx.flags.set_active_group(group) #################################################################################################################### @@ -1082,7 +1052,7 @@ def log(self, *args) -> None: """ self.ctx.transaction_manager.check_max_preview_time() if self.console.export_text(clear=False).count("\n") >= MAX_LINES: - return # if max lines has been reached, skip logging + return for arg in args: if self.__is_markup_loggable__(arg): fullName = arg.get_name() if isinstance(arg, HasName) and arg.get_name() else "" @@ -1136,7 +1106,6 @@ def _enable_experimental_language_engine( logger.info("This may take a while for large repos...") self.ctx.dependency_manager = get_dependency_manager(self.ctx.projects[0].programming_language, self.ctx, enabled=True) self.ctx.dependency_manager.start(async_start=False) - # Wait for the dependency manager to be ready self.ctx.dependency_manager.wait_until_ready(ignore_error=False) logger.info("Dependencies ready") if not self.ctx.language_engine: @@ -1151,7 +1120,6 @@ def _enable_experimental_language_engine( use_v8=use_v8, ) self.ctx.language_engine.start(async_start=async_start) - # Wait for the language engine to be ready self.ctx.language_engine.wait_until_ready(ignore_error=False) logger.info("Language engine ready") @@ -1166,7 +1134,6 @@ def _enable_experimental_language_engine( @noapidoc def ai_client(self) -> OpenAI: """Enables calling AI/LLM APIs - re-export of the initialized `openai` module""" - # Create a singleton AIHelper instance if self._ai_helper is None: if self.ctx.secrets.openai_api_key is None: msg = "OpenAI key is not set" @@ -1199,7 +1166,6 @@ def ai( Raises: MaxAIRequestsError: If the maximum number of allowed AI requests (default 150) has been exceeded. """ - # Check max transactions logger.info("Creating call to OpenAI...") self._num_ai_requests += 1 if self.ctx.session_options.max_ai_requests is not None and self._num_ai_requests > self.ctx.session_options.max_ai_requests: @@ -1219,22 +1185,17 @@ def ai( if model.startswith("gpt"): params["tool_choice"] = "required" - # Make the AI request response = self.ai_client.chat.completions.create( model=model, messages=params["messages"], - tools=params["functions"], # type: ignore + tools=params["functions"], temperature=params["temperature"], tool_choice=params["tool_choice"], ) - # Handle finish reasons - # First check if there is a response if response.choices: - # Check response reason choice = response.choices[0] if choice.finish_reason == "tool_calls" or choice.finish_reason == "function_call" or choice.finish_reason == "stop": - # Check if there is a tool call if choice.message.tool_calls: tool_call = choice.message.tool_calls[0] response_answer = json.loads(tool_call.function.arguments) @@ -1259,17 +1220,13 @@ def ai( msg = "No response from AI Provider. (response.choices is empty)" raise ValueError(msg) - # Agent sometimes fucks up and does \\\\n for some reason. response_answer = codecs.decode(response_answer, "unicode_escape") logger.info(f"OpenAI response: {response_answer}") return response_answer def set_ai_key(self, key: str) -> None: """Sets the OpenAI key for the current Codebase instance.""" - # Reset the AI client self._ai_helper = None - - # Set the AI key self.ctx.secrets.openai_api_key = key def find_by_span(self, span: Span) -> list[Editable]: @@ -1338,35 +1295,29 @@ def from_repo( """ logger.info(f"Fetching codebase for {repo_full_name}") - # Parse repo name if "/" not in repo_full_name: msg = "repo_name must be in format 'owner/repo'" raise ValueError(msg) owner, repo = repo_full_name.split("/") - # Setup temp directory os.makedirs(tmp_dir, exist_ok=True) logger.info(f"Using directory: {tmp_dir}") - # Setup repo path and URL repo_path = os.path.join(tmp_dir, repo) repo_url = f"https://github.com/{repo_full_name}.git" logger.info(f"Will clone {repo_url} to {repo_path}") access_token = secrets.github_token if secrets else None try: - # Use RepoOperator to fetch the repository logger.info("Cloning repository...") if commit is None: repo_config = RepoConfig.from_repo_path(repo_path) repo_config.full_name = repo_full_name repo_operator = RepoOperator.create_from_repo(repo_path=repo_path, url=repo_url, access_token=access_token, full_history=full_history) else: - # Ensure the operator can handle remote operations repo_operator = RepoOperator.create_from_commit(repo_path=repo_path, commit=commit, url=repo_url, full_name=repo_full_name, access_token=access_token) logger.info("Clone completed successfully") - # Initialize and return codebase with proper context logger.info("Initializing Codebase...") project = ProjectConfig.from_repo_operator( repo_operator=repo_operator, @@ -1410,11 +1361,9 @@ def from_string( logger.info("Creating codebase from string") - # Determine language and filename prog_lang = ProgrammingLanguage(language.upper()) if isinstance(language, str) else language filename = "test.ts" if prog_lang == ProgrammingLanguage.TYPESCRIPT else "test.py" - # Create codebase using factory from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory files = {filename: code} @@ -1455,7 +1404,6 @@ def from_files( >>> files = {"index.ts": "console.log('hello')", "utils.tsx": "export const App = () =>
Hello
"} >>> codebase = Codebase.from_files(files) """ - # Create codebase using factory from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory if not files: @@ -1464,7 +1412,7 @@ def from_files( logger.info("Creating codebase from files") - prog_lang = ProgrammingLanguage.PYTHON # Default language + prog_lang = ProgrammingLanguage.PYTHON if files: py_extensions = {".py"} @@ -1473,8 +1421,6 @@ def from_files( extensions = {os.path.splitext(f)[1].lower() for f in files} inferred_lang = None - # all check to ensure that the from_files method is being used for small testing purposes only. - # If parsing an actual repo, it should not be used. Instead do Codebase("path/to/repo") if all(ext in py_extensions for ext in extensions): inferred_lang = ProgrammingLanguage.PYTHON elif all(ext in ts_extensions for ext in extensions): @@ -1491,7 +1437,6 @@ def from_files( prog_lang = inferred_lang else: - # Default to Python if no files provided prog_lang = ProgrammingLanguage.PYTHON if language is None else (ProgrammingLanguage(language.upper()) if isinstance(language, str) else language) logger.info(f"Using language: {prog_lang}") @@ -1499,7 +1444,6 @@ def from_files( with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir: logger.info(f"Using directory: {tmp_dir}") - # Initialize git repo to avoid "not in a git repository" error import subprocess subprocess.run(["git", "init"], cwd=tmp_dir, check=True, capture_output=True) @@ -1547,8 +1491,6 @@ def create_pr_review_comment( return self._op.create_pr_review_comment(pr_number, body, commit_sha, path, line, side, start_line) -# The last 2 lines of code are added to the runner. See codegen-backend/cli/generate/utils.py -# Type Aliases CodebaseType = Codebase[ SourceFile, Directory, @@ -1587,4 +1529,4 @@ def create_pr_review_comment( TSTypeAlias, TSParameter, TSCodeBlock, -] +] \ No newline at end of file From d2de1a86662d3c8c21125efb9cbd083ac678b7b4 Mon Sep 17 00:00:00 2001 From: "codegen-sh[bot]" <131295404+codegen-sh[bot]@users.noreply.github.com> Date: Wed, 12 Mar 2025 20:53:01 +0000 Subject: [PATCH 2/2] Automated pre-commit update --- src/codegen/sdk/core/codebase.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 38d07d974..619fca21d 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -92,7 +92,7 @@ from codegen.visualizations.visualization_manager import VisualizationManager logger = get_logger(__name__) -MAX_LINES = 10000 +MAX_LINES = 10000 TSourceFile = TypeVar("TSourceFile", bound="SourceFile", default=SourceFile) TDirectory = TypeVar("TDirectory", bound="Directory", default=Directory) @@ -251,7 +251,7 @@ def github(self) -> RepoOperator: #################################################################################################################### # SIMPLE META #################################################################################################################### - + @property def name(self) -> str: """The name of the repository.""" @@ -277,7 +277,7 @@ def files(self, *, extensions: list[str]) -> list[File]: ... def files(self, *, extensions: Literal["*"]) -> list[File]: ... @overload def files(self, *, extensions: None = ...) -> list[TSourceFile]: ... - + @proxy_property def files(self, *, extensions: list[str] | Literal["*"] | None = None) -> list[TSourceFile] | list[File]: """A list property that returns all files in the codebase. @@ -520,7 +520,7 @@ def has_file(self, filepath: str, ignore_case: bool = False) -> bool: def get_file(self, filepath: str, *, optional: Literal[False] = ..., ignore_case: bool = ...) -> TSourceFile: ... @overload def get_file(self, filepath: str, *, optional: Literal[True], ignore_case: bool = ...) -> TSourceFile | None: ... - + def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool = False) -> TSourceFile | None: """Retrieves a file from the codebase by its filepath. @@ -814,7 +814,7 @@ def reset(self, git_reset: bool = False) -> None: """ logger.info("Resetting codebase ...") if git_reset: - self._op.discard_changes() + self._op.discard_changes() self._num_ai_requests = 0 self.reset_logs() self.ctx.undo_applied_diffs() @@ -891,7 +891,7 @@ def get_diffs(self, base: str | None = None) -> list[Diff]: def get_diff(self, base: str | None = None, stage_files: bool = False) -> str: """Produce a single git diff for all files.""" if stage_files: - self._op.git_cli.git.add(A=True) + self._op.git_cli.git.add(A=True) if base is None: diff = self._op.git_cli.git.diff("HEAD", patch=True, full_index=True) return diff @@ -1052,7 +1052,7 @@ def log(self, *args) -> None: """ self.ctx.transaction_manager.check_max_preview_time() if self.console.export_text(clear=False).count("\n") >= MAX_LINES: - return + return for arg in args: if self.__is_markup_loggable__(arg): fullName = arg.get_name() if isinstance(arg, HasName) and arg.get_name() else "" @@ -1412,7 +1412,7 @@ def from_files( logger.info("Creating codebase from files") - prog_lang = ProgrammingLanguage.PYTHON + prog_lang = ProgrammingLanguage.PYTHON if files: py_extensions = {".py"} @@ -1529,4 +1529,4 @@ def create_pr_review_comment( TSTypeAlias, TSParameter, TSCodeBlock, -] \ No newline at end of file +]