Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 13 additions & 71 deletions src/codegen/sdk/core/codebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@
from codegen.visualizations.visualization_manager import VisualizationManager

logger = get_logger(__name__)
MAX_LINES = 10000 # Maximum number of lines of text allowed to be logged

MAX_LINES = 10000

TSourceFile = TypeVar("TSourceFile", bound="SourceFile", default=SourceFile)
TDirectory = TypeVar("TDirectory", bound="Directory", default=Directory)
Expand Down Expand Up @@ -179,7 +178,6 @@ def __init__(
io: IO | None = None,
progress: Progress | None = None,
) -> None:
# Sanity check inputs
if repo_path is not None and projects is not None:
msg = "Cannot specify both repo_path and projects"
raise ValueError(msg)
Expand All @@ -192,11 +190,9 @@ def __init__(
msg = "Cannot specify both projects and language. Use ProjectConfig.from_path() to create projects with a custom language."
raise ValueError(msg)

# If projects is a single ProjectConfig, convert it to a list
if isinstance(projects, ProjectConfig):
projects = [projects]

# Initialize project with repo_path if projects is None
if repo_path is not None:
main_project = ProjectConfig.from_path(
repo_path,
Expand All @@ -206,15 +202,12 @@ def __init__(
else:
main_project = projects[0]

# Initialize codebase
self._op = main_project.repo_operator
self.viz = VisualizationManager(op=self._op)
self.repo_path = Path(self._op.repo_path)
self.ctx = CodebaseContext(projects, config=config, secrets=secrets, io=io, progress=progress)
self.console = Console(record=True, soft_wrap=True)

# Assert config assertions
# External import resolution must be enabled if syspath is enabled
if self.ctx.config.py_resolve_syspath:
if not self.ctx.config.allow_external:
msg = "allow_external must be set to True when py_resolve_syspath is enabled"
Expand Down Expand Up @@ -278,7 +271,6 @@ def _symbols(self, symbol_type: SymbolType | None = None) -> list[TSymbol | TCla
matches: list[Symbol] = self.ctx.get_nodes(NodeType.SYMBOL)
return [x for x in matches if x.is_top_level and (symbol_type is None or x.symbol_type == symbol_type)]

# =====[ Node Types ]=====
@overload
def files(self, *, extensions: list[str]) -> list[File]: ...
@overload
Expand All @@ -305,21 +297,17 @@ def files(self, *, extensions: list[str] | Literal["*"] | None = None) -> list[T
list[TSourceFile]: A sorted list of source files in the codebase.
"""
if extensions is None and len(self.ctx.get_nodes(NodeType.FILE)) > 0:
# If extensions is None AND there is at least one file in the codebase (This checks for unsupported languages or parse-off repos),
# Return all source files
files = self.ctx.get_nodes(NodeType.FILE)
elif isinstance(extensions, str) and extensions != "*":
msg = "extensions must be a list of extensions or '*'"
raise ValueError(msg)
else:
files = []
# Get all files with the specified extensions
for filepath, _ in self._op.iter_files(
extensions=None if extensions == "*" else extensions,
ignore_list=GLOBAL_FILE_IGNORE_LIST,
):
files.append(self.get_file(filepath, optional=False))
# Sort files alphabetically
return sort_editables(files, alphabetical=True, dedupe=False)

@cached_property
Expand Down Expand Up @@ -474,41 +462,33 @@ def types(self) -> list[TTypeAlias]:
# EXTERNAL API
####################################################################################################################

def create_file(self, filepath: str, content: str = "", sync: bool = True) -> TSourceFile:
def create_file(self, filepath: str, content: str = "", sync: bool = True, verify: bool = True) -> TSourceFile:
"""Creates a new file in the codebase with specified content.

Args:
filepath (str): The path where the file should be created.
content (str): The content of the file to be created. Defaults to empty string.
sync (bool): Whether to sync the graph after creating the file. Defaults to True.
verify (bool): Whether to verify the syntax of the file content. Defaults to True.

Returns:
File: The newly created file object.

Raises:
ValueError: If the provided content cannot be parsed according to the file extension.
ValueError: If the provided content cannot be parsed according to the file extension and verify=True.
"""
# Check if file already exists
# TODO: These checks break parse tests ???
# Look into this!
# if self.has_file(filepath):
# raise ValueError(f"File {filepath} already exists in codebase.")
# if os.path.exists(filepath):
# raise ValueError(f"File {filepath} already exists on disk.")

file_exts = self.ctx.extensions
# Create file as source file if it has a registered extension
if any(filepath.endswith(ext) for ext in file_exts):
file_cls = self.ctx.node_classes.file_cls
file = file_cls.from_content(filepath, content, self.ctx, sync=sync)
if file is None:
file = file_cls.from_content(filepath, content, self.ctx, sync=sync, verify_syntax=verify)
if file is None and verify:
msg = f"Failed to parse file with content {content}. Please make sure the content syntax is valid with respect to the filepath extension."
raise ValueError(msg)
elif file is None:
file = File.from_content(filepath, content, self.ctx, sync=False)
else:
# Create file as non-source file
file = File.from_content(filepath, content, self.ctx, sync=False)

# This is to make sure we keep track of this file for diff purposes
uncache_all()
return file

Expand Down Expand Up @@ -557,23 +537,19 @@ def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool =
Raises:
ValueError: If file not found and optional=False.
"""
# Try to get the file from the graph first
file = self.ctx.get_file(filepath, ignore_case=ignore_case)
if file is not None:
return file
# If the file is not in the graph, check the filesystem
absolute_path = self.ctx.to_absolute(filepath)
if self.ctx.io.file_exists(absolute_path):
return self.ctx._get_raw_file_from_path(absolute_path)
# If the file is not in the graph, check the filesystem
if absolute_path.parent.exists():
for file in absolute_path.parent.iterdir():
if ignore_case and str(absolute_path).lower() == str(file).lower():
return self.ctx._get_raw_file_from_path(file)
elif not ignore_case and str(absolute_path) == str(file):
return self.ctx._get_raw_file_from_path(file)

# If we get here, the file is not found
if not optional:
msg = f"File {filepath} not found in codebase. Use optional=True to return None instead."
raise ValueError(msg)
Expand Down Expand Up @@ -605,7 +581,6 @@ def get_directory(self, dir_path: str, optional: bool = False, ignore_case: bool
Raises:
ValueError: If directory not found and optional=False.
"""
# Sanitize the path
dir_path = os.path.normpath(dir_path)
dir_path = "" if dir_path == "." else dir_path
directory = self.ctx.get_directory(self.ctx.to_absolute(dir_path), ignore_case=ignore_case)
Expand Down Expand Up @@ -745,21 +720,17 @@ def get_relative_path(self, from_file: str, to_file: str) -> str:
Returns:
str: The relative path from `from_file` to `to_file` (with the extension removed from `to_file`).
"""
# Remove extension from the target file
to_file = self._remove_extension(to_file)

from_parts = from_file.split("/")
to_parts = to_file.split("/")

# Find common prefix
i = 0
while i < len(from_parts) - 1 and i < len(to_parts) and from_parts[i] == to_parts[i]:
i += 1

# Number of '../' we need
up_levels = len(from_parts) - i - 1

# Construct relative path
relative_path = ("../" * up_levels) + "/".join(to_parts[i:])

return relative_path
Expand Down Expand Up @@ -843,7 +814,7 @@ def reset(self, git_reset: bool = False) -> None:
"""
logger.info("Resetting codebase ...")
if git_reset:
self._op.discard_changes() # Discard any changes made to the raw file state
self._op.discard_changes()
self._num_ai_requests = 0
self.reset_logs()
self.ctx.undo_applied_diffs()
Expand Down Expand Up @@ -920,7 +891,7 @@ def get_diffs(self, base: str | None = None) -> list[Diff]:
def get_diff(self, base: str | None = None, stage_files: bool = False) -> str:
"""Produce a single git diff for all files."""
if stage_files:
self._op.git_cli.git.add(A=True) # add all changes to the index so untracked files are included in the diff
self._op.git_cli.git.add(A=True)
if base is None:
diff = self._op.git_cli.git.diff("HEAD", patch=True, full_index=True)
return diff
Expand Down Expand Up @@ -1062,7 +1033,6 @@ def set_find_mode(self, find_mode: bool) -> None:
@noapidoc
def set_active_group(self, group: Group) -> None:
"""Will only fix these flags."""
# TODO - flesh this out more with Group datatype and GroupBy
self.ctx.flags.set_active_group(group)

####################################################################################################################
Expand All @@ -1082,7 +1052,7 @@ def log(self, *args) -> None:
"""
self.ctx.transaction_manager.check_max_preview_time()
if self.console.export_text(clear=False).count("\n") >= MAX_LINES:
return # if max lines has been reached, skip logging
return
for arg in args:
if self.__is_markup_loggable__(arg):
fullName = arg.get_name() if isinstance(arg, HasName) and arg.get_name() else ""
Expand Down Expand Up @@ -1136,7 +1106,6 @@ def _enable_experimental_language_engine(
logger.info("This may take a while for large repos...")
self.ctx.dependency_manager = get_dependency_manager(self.ctx.projects[0].programming_language, self.ctx, enabled=True)
self.ctx.dependency_manager.start(async_start=False)
# Wait for the dependency manager to be ready
self.ctx.dependency_manager.wait_until_ready(ignore_error=False)
logger.info("Dependencies ready")
if not self.ctx.language_engine:
Expand All @@ -1151,7 +1120,6 @@ def _enable_experimental_language_engine(
use_v8=use_v8,
)
self.ctx.language_engine.start(async_start=async_start)
# Wait for the language engine to be ready
self.ctx.language_engine.wait_until_ready(ignore_error=False)
logger.info("Language engine ready")

Expand All @@ -1166,7 +1134,6 @@ def _enable_experimental_language_engine(
@noapidoc
def ai_client(self) -> OpenAI:
"""Enables calling AI/LLM APIs - re-export of the initialized `openai` module"""
# Create a singleton AIHelper instance
if self._ai_helper is None:
if self.ctx.secrets.openai_api_key is None:
msg = "OpenAI key is not set"
Expand Down Expand Up @@ -1199,7 +1166,6 @@ def ai(
Raises:
MaxAIRequestsError: If the maximum number of allowed AI requests (default 150) has been exceeded.
"""
# Check max transactions
logger.info("Creating call to OpenAI...")
self._num_ai_requests += 1
if self.ctx.session_options.max_ai_requests is not None and self._num_ai_requests > self.ctx.session_options.max_ai_requests:
Expand All @@ -1219,22 +1185,17 @@ def ai(
if model.startswith("gpt"):
params["tool_choice"] = "required"

# Make the AI request
response = self.ai_client.chat.completions.create(
model=model,
messages=params["messages"],
tools=params["functions"], # type: ignore
tools=params["functions"],
temperature=params["temperature"],
tool_choice=params["tool_choice"],
)

# Handle finish reasons
# First check if there is a response
if response.choices:
# Check response reason
choice = response.choices[0]
if choice.finish_reason == "tool_calls" or choice.finish_reason == "function_call" or choice.finish_reason == "stop":
# Check if there is a tool call
if choice.message.tool_calls:
tool_call = choice.message.tool_calls[0]
response_answer = json.loads(tool_call.function.arguments)
Expand All @@ -1259,17 +1220,13 @@ def ai(
msg = "No response from AI Provider. (response.choices is empty)"
raise ValueError(msg)

# Agent sometimes fucks up and does \\\\n for some reason.
response_answer = codecs.decode(response_answer, "unicode_escape")
logger.info(f"OpenAI response: {response_answer}")
return response_answer

def set_ai_key(self, key: str) -> None:
"""Sets the OpenAI key for the current Codebase instance."""
# Reset the AI client
self._ai_helper = None

# Set the AI key
self.ctx.secrets.openai_api_key = key

def find_by_span(self, span: Span) -> list[Editable]:
Expand Down Expand Up @@ -1338,35 +1295,29 @@ def from_repo(
"""
logger.info(f"Fetching codebase for {repo_full_name}")

# Parse repo name
if "/" not in repo_full_name:
msg = "repo_name must be in format 'owner/repo'"
raise ValueError(msg)
owner, repo = repo_full_name.split("/")

# Setup temp directory
os.makedirs(tmp_dir, exist_ok=True)
logger.info(f"Using directory: {tmp_dir}")

# Setup repo path and URL
repo_path = os.path.join(tmp_dir, repo)
repo_url = f"https://github.com/{repo_full_name}.git"
logger.info(f"Will clone {repo_url} to {repo_path}")
access_token = secrets.github_token if secrets else None

try:
# Use RepoOperator to fetch the repository
logger.info("Cloning repository...")
if commit is None:
repo_config = RepoConfig.from_repo_path(repo_path)
repo_config.full_name = repo_full_name
repo_operator = RepoOperator.create_from_repo(repo_path=repo_path, url=repo_url, access_token=access_token, full_history=full_history)
else:
# Ensure the operator can handle remote operations
repo_operator = RepoOperator.create_from_commit(repo_path=repo_path, commit=commit, url=repo_url, full_name=repo_full_name, access_token=access_token)
logger.info("Clone completed successfully")

# Initialize and return codebase with proper context
logger.info("Initializing Codebase...")
project = ProjectConfig.from_repo_operator(
repo_operator=repo_operator,
Expand Down Expand Up @@ -1410,11 +1361,9 @@ def from_string(

logger.info("Creating codebase from string")

# Determine language and filename
prog_lang = ProgrammingLanguage(language.upper()) if isinstance(language, str) else language
filename = "test.ts" if prog_lang == ProgrammingLanguage.TYPESCRIPT else "test.py"

# Create codebase using factory
from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory

files = {filename: code}
Expand Down Expand Up @@ -1455,7 +1404,6 @@ def from_files(
>>> files = {"index.ts": "console.log('hello')", "utils.tsx": "export const App = () => <div>Hello</div>"}
>>> codebase = Codebase.from_files(files)
"""
# Create codebase using factory
from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory

if not files:
Expand All @@ -1464,7 +1412,7 @@ def from_files(

logger.info("Creating codebase from files")

prog_lang = ProgrammingLanguage.PYTHON # Default language
prog_lang = ProgrammingLanguage.PYTHON

if files:
py_extensions = {".py"}
Expand All @@ -1473,8 +1421,6 @@ def from_files(
extensions = {os.path.splitext(f)[1].lower() for f in files}
inferred_lang = None

# all check to ensure that the from_files method is being used for small testing purposes only.
# If parsing an actual repo, it should not be used. Instead do Codebase("path/to/repo")
if all(ext in py_extensions for ext in extensions):
inferred_lang = ProgrammingLanguage.PYTHON
elif all(ext in ts_extensions for ext in extensions):
Expand All @@ -1491,15 +1437,13 @@ def from_files(

prog_lang = inferred_lang
else:
# Default to Python if no files provided
prog_lang = ProgrammingLanguage.PYTHON if language is None else (ProgrammingLanguage(language.upper()) if isinstance(language, str) else language)

logger.info(f"Using language: {prog_lang}")

with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir:
logger.info(f"Using directory: {tmp_dir}")

# Initialize git repo to avoid "not in a git repository" error
import subprocess

subprocess.run(["git", "init"], cwd=tmp_dir, check=True, capture_output=True)
Expand Down Expand Up @@ -1547,8 +1491,6 @@ def create_pr_review_comment(
return self._op.create_pr_review_comment(pr_number, body, commit_sha, path, line, side, start_line)


# The last 2 lines of code are added to the runner. See codegen-backend/cli/generate/utils.py
# Type Aliases
CodebaseType = Codebase[
SourceFile,
Directory,
Expand Down
Loading