Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/codegen/sdk/core/class_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,29 @@ def parent_class_names(self) -> list[Name | ChainedAttribute]:
return []

@reader
def get_parent_class(self, parent_class_name: str) -> Editable | None:
def get_parent_class(self, parent_class_name: str, optional: bool = False) -> Editable | None:
"""Returns the parent class node with the specified name.

Retrieves a parent class Name or ChainedAttribute node from this class's list of parent class names that matches
the specified name.

Args:
parent_class_name (str): The name of the parent class to find.
optional (bool, optional): Whether to return None if the parent class is not found. Defaults to False.

Returns:
Editable | None: The matching parent class node, or None if no match is found.
"""
return next((p for p in self.parent_class_names if p.source == parent_class_name), None)
parent_class = [p for p in self.parent_class_names if p.source == parent_class_name]
if not parent_class:
if not optional:
msg = f"Parent class {parent_class_name} not found in class {self.name}. Use optional=True to return None instead."
raise ValueError(msg)
return None
if len(parent_class) > 1:
msg = f"Multiple parent classes found with name {parent_class_name} in class {self.name}."
raise ValueError(msg)
return parent_class[0]

@property
@reader
Expand Down Expand Up @@ -233,30 +243,35 @@ def methods(self, *, max_depth: int | None = 0, private: bool = True, magic: boo
return list(result.values())

@reader
def get_nested_class(self, name: str) -> Self | None:
def get_nested_class(self, name: str, optional: bool = False) -> Self | None:
"""Returns a nested class by name from the current class.

Searches through the nested classes defined in the class and returns the first one that matches the given name.

Args:
name (str): The name of the nested class to find.
optional (bool, optional): Whether to return None if the nested class is not found. Defaults to False.

Returns:
Self | None: The nested class if found, None otherwise.
"""
for m in self.nested_classes:
if m.name == name:
return m
if not optional:
msg = f"Nested class {name} not found in class {self.name}. Use optional=True to return None instead."
raise ValueError(msg)
return None

@reader
def get_method(self, name: str) -> TFunction | None:
def get_method(self, name: str, optional: bool = False) -> TFunction | None:
"""Returns a specific method by name from the class or any of its superclasses.

Searches through the class's methods and its superclasses' methods to find a method with the specified name.

Args:
name (str): The name of the method to find.
optional (bool, optional): Whether to return None if the method is not found. Defaults to False.

Returns:
TFunction | None: The method if found, None otherwise.
Expand All @@ -267,6 +282,9 @@ def get_method(self, name: str) -> TFunction | None:
for m in c.methods:
if m.name == name:
return m
if not optional:
msg = f"Method {name} not found in class {self.name}. Use optional=True to return None instead."
raise ValueError(msg)
return None

@proxy_property
Expand All @@ -293,13 +311,14 @@ def attributes(self, *, max_depth: int | None = 0, private: bool = True) -> list
return list(result.values())

@reader
def get_attribute(self, name: str) -> Attribute | None:
def get_attribute(self, name: str, optional: bool = False) -> Attribute | None:
"""Returns a specific attribute by name.

Searches for an attribute with the given name in the current class and its superclasses.

Args:
name (str): The name of the attribute to search for.
optional (bool, optional): Whether to return None if the attribute is not found. Defaults to False.

Returns:
Attribute | None: The matching attribute if found, None otherwise. If multiple attributes with the same name exist in the inheritance hierarchy, returns the first one found.
Expand All @@ -310,6 +329,9 @@ def get_attribute(self, name: str) -> Attribute | None:
for m in c.code_block.get_attributes(name):
if m.name == name:
return m
if not optional:
msg = f"Attribute {name} not found in class {self.name}. Use optional=True to return None instead."
raise ValueError(msg)
return None

####################################################################################################################
Expand Down
8 changes: 4 additions & 4 deletions src/codegen/sdk/core/codebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def has_file(self, filepath: str, ignore_case: bool = False) -> bool:
@overload
def get_file(self, filepath: str, *, optional: Literal[False] = ..., ignore_case: bool = ...) -> TSourceFile: ...
@overload
def get_file(self, filepath: str, *, optional: Literal[True], ignore_case: bool = ...) -> TSourceFile | None: ...
def get_file(self, filepath: str, *, optional: Literal[True] = ..., ignore_case: bool = ...) -> TSourceFile | None: ...
def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool = False) -> TSourceFile | None:
"""Retrieves a file from the codebase by its filepath.

Expand Down Expand Up @@ -637,7 +637,7 @@ def get_symbol(self, symbol_name: str, optional: bool = False) -> TSymbol | None
ValueError: If multiple symbols are found with the same name, or if no symbol is found and optional=False.
"""
symbols = self.get_symbols(symbol_name)
if len(symbols) == 0:
if not symbols:
if not optional:
msg = f"Symbol {symbol_name} not found in codebase. Use optional=True to return None instead."
raise ValueError(msg)
Expand Down Expand Up @@ -677,7 +677,7 @@ def get_class(self, class_name: str, optional: bool = False) -> TClass | None:
ValueError: If the class is not found and optional=False, or if multiple classes with the same name exist.
"""
matches = [c for c in self.classes if c.name == class_name]
if len(matches) == 0:
if not matches:
if not optional:
msg = f"Class {class_name} not found in codebase. Use optional=True to return None instead."
raise ValueError(msg)
Expand Down Expand Up @@ -706,7 +706,7 @@ def get_function(self, function_name: str, optional: bool = False) -> TFunction
ValueError: If function is not found and optional=False, or if multiple matching functions exist.
"""
matches = [f for f in self.functions if f.name == function_name]
if len(matches) == 0:
if not matches:
if not optional:
msg = f"Function {function_name} not found in codebase. Use optional=True to return None instead."
raise ValueError(msg)
Expand Down
14 changes: 12 additions & 2 deletions src/codegen/sdk/core/detached_symbols/code_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,30 @@ def comments(self) -> list[Comment[Parent, Self]]:
return [x for x in self.statements if x.statement_type == StatementType.COMMENT]

@reader
def get_comment(self, comment_src: str) -> Comment[Parent, Self] | None:
def get_comment(self, comment_src: str, optional: bool = False) -> Comment[Parent, Self] | None:
"""Gets the first comment statement containing a specific text string.

Searches through all nested statement levels in the code block to find a comment that contains
the specified text.

Args:
comment_src (str): The text string to search for within comment statements.
optional (bool, optional): If True, returns None instead of raising an error if the comment is not found.

Returns:
Comment[Parent, Self] | None: The first comment statement containing the search text,
or None if no matching comment is found.
"""
return next((x for x in self._get_statements(StatementType.COMMENT) if comment_src in x.source), None)
comment = [x for x in self._get_statements(StatementType.COMMENT) if comment_src in x.source]
if not comment:
if not optional:
msg = f"Comment {comment_src} not found in code block {self.source}. Use optional=True to return None instead."
raise ValueError(msg)
return None
if len(comment) > 1:
msg = f"Multiple comments found with text {comment_src} in code block {self.source}."
raise ValueError(msg)
return comment[0]

@property
@reader
Expand Down
14 changes: 12 additions & 2 deletions src/codegen/sdk/core/detached_symbols/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def find_parameter_by_name(self, name: str) -> Parameter | None:
return param

@reader
def get_arg_by_parameter_name(self, param_name: str) -> Argument | None:
def get_arg_by_parameter_name(self, param_name: str, optional: bool = False) -> Argument | None:
"""Returns an argument by its parameter name.

Searches through the arguments of a function call to find an argument that matches
Expand All @@ -384,6 +384,7 @@ def get_arg_by_parameter_name(self, param_name: str) -> Argument | None:

Args:
param_name (str): The name of the parameter to search for.
optional (bool, optional): If True, returns None instead of raising an error if the parameter is not found.

Returns:
Argument | None: The matching argument if found, None otherwise.
Expand All @@ -402,19 +403,28 @@ def get_arg_by_parameter_name(self, param_name: str) -> Argument | None:
if param.name == param_name:
return arg

if not optional:
msg = f"Parameter {param_name} not found in function call {self.source}. Use optional=True to return None instead."
raise ValueError(msg)
return None

@reader
def get_arg_by_index(self, arg_idx: int) -> Argument | None:
def get_arg_by_index(self, arg_idx: int, optional: bool = False) -> Argument | None:
"""Returns the Argument with the given index from the function call's argument list.

Args:
arg_idx (int): The index of the argument to retrieve.
optional (bool, optional): If True, returns None instead of raising an error if the index is out of bounds.

Returns:
Argument | None: The Argument object at the specified index, or None if the index is out of bounds.
"""
try:
return self.args[arg_idx]
except IndexError:
if not optional:
msg = f"Index {arg_idx} is out of bounds for function call {self.source}. Use optional=True to return None instead."
raise IndexError(msg)
return None

####################################################################################################################
Expand Down
14 changes: 11 additions & 3 deletions src/codegen/sdk/core/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def files_generator(self, *args: FilesParam.args, **kwargs: FilesParam.kwargs) -
"""Yield files recursively from the directory."""
yield from self.files(*args, extensions="*", **kwargs, recursive=True)

def get_file(self, filename: str, ignore_case: bool = False) -> TFile | None:
def get_file(self, filename: str, ignore_case: bool = False, optional: bool = False) -> TFile | None:
"""Get a file by its name relative to the directory."""
file_path = os.path.join(self.dirpath, filename)
absolute_path = self.ctx.to_absolute(file_path)
Expand All @@ -230,11 +230,19 @@ def get_file(self, filename: str, ignore_case: bool = False) -> TFile | None:
return self.ctx._get_raw_file_from_path(file)
elif not ignore_case and str(absolute_path) == str(file):
return self.ctx._get_raw_file_from_path(file)
if not optional:
msg = f"File {filename} not found in directory {self.dirpath}. Use optional=True to return None instead."
raise ValueError(msg)
return None

def get_subdirectory(self, subdirectory_name: str) -> Self | None:
def get_subdirectory(self, subdirectory_name: str, optional: bool = False) -> Self | None:
"""Get a subdirectory by its name (relative to the directory)."""
return self.ctx.get_directory(os.path.join(self.dirpath, subdirectory_name))
if directory := self.ctx.get_directory(os.path.join(self.dirpath, subdirectory_name)):
return directory
if not optional:
msg = f"Subdirectory {subdirectory_name} not found in directory {self.dirpath}. Use optional=True to return None instead."
raise ValueError(msg)
return None

def update_filepath(self, new_filepath: str) -> None:
"""Update the filepath of the directory and its contained files."""
Expand Down
Loading
Loading