diff --git a/src/codegen/sdk/core/class_definition.py b/src/codegen/sdk/core/class_definition.py index bbf2682ab..73208bf37 100644 --- a/src/codegen/sdk/core/class_definition.py +++ b/src/codegen/sdk/core/class_definition.py @@ -110,7 +110,7 @@ def parent_class_names(self) -> list[Name | ChainedAttribute]: return [] @reader - def get_parent_class(self, parent_class_name: str) -> Editable | None: + def get_parent_class(self, parent_class_name: str, optional: bool = False) -> Editable | None: """Returns the parent class node with the specified name. Retrieves a parent class Name or ChainedAttribute node from this class's list of parent class names that matches @@ -118,11 +118,21 @@ def get_parent_class(self, parent_class_name: str) -> Editable | None: Args: parent_class_name (str): The name of the parent class to find. + optional (bool, optional): Whether to return None if the parent class is not found. Defaults to False. Returns: Editable | None: The matching parent class node, or None if no match is found. """ - return next((p for p in self.parent_class_names if p.source == parent_class_name), None) + parent_class = [p for p in self.parent_class_names if p.source == parent_class_name] + if not parent_class: + if not optional: + msg = f"Parent class {parent_class_name} not found in class {self.name}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(parent_class) > 1: + msg = f"Multiple parent classes found with name {parent_class_name} in class {self.name}." + raise ValueError(msg) + return parent_class[0] @property @reader @@ -233,13 +243,14 @@ def methods(self, *, max_depth: int | None = 0, private: bool = True, magic: boo return list(result.values()) @reader - def get_nested_class(self, name: str) -> Self | None: + def get_nested_class(self, name: str, optional: bool = False) -> Self | None: """Returns a nested class by name from the current class. Searches through the nested classes defined in the class and returns the first one that matches the given name. Args: name (str): The name of the nested class to find. + optional (bool, optional): Whether to return None if the nested class is not found. Defaults to False. Returns: Self | None: The nested class if found, None otherwise. @@ -247,16 +258,20 @@ def get_nested_class(self, name: str) -> Self | None: for m in self.nested_classes: if m.name == name: return m + if not optional: + msg = f"Nested class {name} not found in class {self.name}. Use optional=True to return None instead." + raise ValueError(msg) return None @reader - def get_method(self, name: str) -> TFunction | None: + def get_method(self, name: str, optional: bool = False) -> TFunction | None: """Returns a specific method by name from the class or any of its superclasses. Searches through the class's methods and its superclasses' methods to find a method with the specified name. Args: name (str): The name of the method to find. + optional (bool, optional): Whether to return None if the method is not found. Defaults to False. Returns: TFunction | None: The method if found, None otherwise. @@ -267,6 +282,9 @@ def get_method(self, name: str) -> TFunction | None: for m in c.methods: if m.name == name: return m + if not optional: + msg = f"Method {name} not found in class {self.name}. Use optional=True to return None instead." + raise ValueError(msg) return None @proxy_property @@ -293,13 +311,14 @@ def attributes(self, *, max_depth: int | None = 0, private: bool = True) -> list return list(result.values()) @reader - def get_attribute(self, name: str) -> Attribute | None: + def get_attribute(self, name: str, optional: bool = False) -> Attribute | None: """Returns a specific attribute by name. Searches for an attribute with the given name in the current class and its superclasses. Args: name (str): The name of the attribute to search for. + optional (bool, optional): Whether to return None if the attribute is not found. Defaults to False. Returns: Attribute | None: The matching attribute if found, None otherwise. If multiple attributes with the same name exist in the inheritance hierarchy, returns the first one found. @@ -310,6 +329,9 @@ def get_attribute(self, name: str) -> Attribute | None: for m in c.code_block.get_attributes(name): if m.name == name: return m + if not optional: + msg = f"Attribute {name} not found in class {self.name}. Use optional=True to return None instead." + raise ValueError(msg) return None #################################################################################################################### diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index d9a6e6960..7de198a2d 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -533,7 +533,7 @@ def has_file(self, filepath: str, ignore_case: bool = False) -> bool: @overload def get_file(self, filepath: str, *, optional: Literal[False] = ..., ignore_case: bool = ...) -> TSourceFile: ... @overload - def get_file(self, filepath: str, *, optional: Literal[True], ignore_case: bool = ...) -> TSourceFile | None: ... + def get_file(self, filepath: str, *, optional: Literal[True] = ..., ignore_case: bool = ...) -> TSourceFile | None: ... def get_file(self, filepath: str, *, optional: bool = False, ignore_case: bool = False) -> TSourceFile | None: """Retrieves a file from the codebase by its filepath. @@ -637,7 +637,7 @@ def get_symbol(self, symbol_name: str, optional: bool = False) -> TSymbol | None ValueError: If multiple symbols are found with the same name, or if no symbol is found and optional=False. """ symbols = self.get_symbols(symbol_name) - if len(symbols) == 0: + if not symbols: if not optional: msg = f"Symbol {symbol_name} not found in codebase. Use optional=True to return None instead." raise ValueError(msg) @@ -677,7 +677,7 @@ def get_class(self, class_name: str, optional: bool = False) -> TClass | None: ValueError: If the class is not found and optional=False, or if multiple classes with the same name exist. """ matches = [c for c in self.classes if c.name == class_name] - if len(matches) == 0: + if not matches: if not optional: msg = f"Class {class_name} not found in codebase. Use optional=True to return None instead." raise ValueError(msg) @@ -706,7 +706,7 @@ def get_function(self, function_name: str, optional: bool = False) -> TFunction ValueError: If function is not found and optional=False, or if multiple matching functions exist. """ matches = [f for f in self.functions if f.name == function_name] - if len(matches) == 0: + if not matches: if not optional: msg = f"Function {function_name} not found in codebase. Use optional=True to return None instead." raise ValueError(msg) diff --git a/src/codegen/sdk/core/detached_symbols/code_block.py b/src/codegen/sdk/core/detached_symbols/code_block.py index a5fde4d62..59f788bdf 100644 --- a/src/codegen/sdk/core/detached_symbols/code_block.py +++ b/src/codegen/sdk/core/detached_symbols/code_block.py @@ -140,7 +140,7 @@ def comments(self) -> list[Comment[Parent, Self]]: return [x for x in self.statements if x.statement_type == StatementType.COMMENT] @reader - def get_comment(self, comment_src: str) -> Comment[Parent, Self] | None: + def get_comment(self, comment_src: str, optional: bool = False) -> Comment[Parent, Self] | None: """Gets the first comment statement containing a specific text string. Searches through all nested statement levels in the code block to find a comment that contains @@ -148,12 +148,22 @@ def get_comment(self, comment_src: str) -> Comment[Parent, Self] | None: Args: comment_src (str): The text string to search for within comment statements. + optional (bool, optional): If True, returns None instead of raising an error if the comment is not found. Returns: Comment[Parent, Self] | None: The first comment statement containing the search text, or None if no matching comment is found. """ - return next((x for x in self._get_statements(StatementType.COMMENT) if comment_src in x.source), None) + comment = [x for x in self._get_statements(StatementType.COMMENT) if comment_src in x.source] + if not comment: + if not optional: + msg = f"Comment {comment_src} not found in code block {self.source}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(comment) > 1: + msg = f"Multiple comments found with text {comment_src} in code block {self.source}." + raise ValueError(msg) + return comment[0] @property @reader diff --git a/src/codegen/sdk/core/detached_symbols/function_call.py b/src/codegen/sdk/core/detached_symbols/function_call.py index 4507e65ff..70ac75777 100644 --- a/src/codegen/sdk/core/detached_symbols/function_call.py +++ b/src/codegen/sdk/core/detached_symbols/function_call.py @@ -374,7 +374,7 @@ def find_parameter_by_name(self, name: str) -> Parameter | None: return param @reader - def get_arg_by_parameter_name(self, param_name: str) -> Argument | None: + def get_arg_by_parameter_name(self, param_name: str, optional: bool = False) -> Argument | None: """Returns an argument by its parameter name. Searches through the arguments of a function call to find an argument that matches @@ -384,6 +384,7 @@ def get_arg_by_parameter_name(self, param_name: str) -> Argument | None: Args: param_name (str): The name of the parameter to search for. + optional (bool, optional): If True, returns None instead of raising an error if the parameter is not found. Returns: Argument | None: The matching argument if found, None otherwise. @@ -402,12 +403,18 @@ def get_arg_by_parameter_name(self, param_name: str) -> Argument | None: if param.name == param_name: return arg + if not optional: + msg = f"Parameter {param_name} not found in function call {self.source}. Use optional=True to return None instead." + raise ValueError(msg) + return None + @reader - def get_arg_by_index(self, arg_idx: int) -> Argument | None: + def get_arg_by_index(self, arg_idx: int, optional: bool = False) -> Argument | None: """Returns the Argument with the given index from the function call's argument list. Args: arg_idx (int): The index of the argument to retrieve. + optional (bool, optional): If True, returns None instead of raising an error if the index is out of bounds. Returns: Argument | None: The Argument object at the specified index, or None if the index is out of bounds. @@ -415,6 +422,9 @@ def get_arg_by_index(self, arg_idx: int) -> Argument | None: try: return self.args[arg_idx] except IndexError: + if not optional: + msg = f"Index {arg_idx} is out of bounds for function call {self.source}. Use optional=True to return None instead." + raise IndexError(msg) return None #################################################################################################################### diff --git a/src/codegen/sdk/core/directory.py b/src/codegen/sdk/core/directory.py index 504d608e0..e95e88987 100644 --- a/src/codegen/sdk/core/directory.py +++ b/src/codegen/sdk/core/directory.py @@ -216,7 +216,7 @@ def files_generator(self, *args: FilesParam.args, **kwargs: FilesParam.kwargs) - """Yield files recursively from the directory.""" yield from self.files(*args, extensions="*", **kwargs, recursive=True) - def get_file(self, filename: str, ignore_case: bool = False) -> TFile | None: + def get_file(self, filename: str, ignore_case: bool = False, optional: bool = False) -> TFile | None: """Get a file by its name relative to the directory.""" file_path = os.path.join(self.dirpath, filename) absolute_path = self.ctx.to_absolute(file_path) @@ -230,11 +230,19 @@ def get_file(self, filename: str, ignore_case: bool = False) -> TFile | None: return self.ctx._get_raw_file_from_path(file) elif not ignore_case and str(absolute_path) == str(file): return self.ctx._get_raw_file_from_path(file) + if not optional: + msg = f"File {filename} not found in directory {self.dirpath}. Use optional=True to return None instead." + raise ValueError(msg) return None - def get_subdirectory(self, subdirectory_name: str) -> Self | None: + def get_subdirectory(self, subdirectory_name: str, optional: bool = False) -> Self | None: """Get a subdirectory by its name (relative to the directory).""" - return self.ctx.get_directory(os.path.join(self.dirpath, subdirectory_name)) + if directory := self.ctx.get_directory(os.path.join(self.dirpath, subdirectory_name)): + return directory + if not optional: + msg = f"Subdirectory {subdirectory_name} not found in directory {self.dirpath}. Use optional=True to return None instead." + raise ValueError(msg) + return None def update_filepath(self, new_filepath: str) -> None: """Update the filepath of the directory and its contained files.""" diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index 8ad9e1385..e5db03c31 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -20,6 +20,7 @@ from codegen.sdk.core.class_definition import Class from codegen.sdk.core.dataclasses.usage import UsageType from codegen.sdk.core.directory import Directory +from codegen.sdk.core.function import Function from codegen.sdk.core.import_resolution import Import, WildcardImport from codegen.sdk.core.interfaces.editable import Editable from codegen.sdk.core.interfaces.has_attribute import HasAttribute @@ -41,7 +42,6 @@ if TYPE_CHECKING: from codegen.sdk.core.assignment import Assignment from codegen.sdk.core.detached_symbols.code_block import CodeBlock - from codegen.sdk.core.function import Function from codegen.sdk.core.interface import Interface logger = get_logger(__name__) @@ -694,16 +694,22 @@ def has_import(self, symbol_alias: str) -> bool: return any(a.source == symbol_alias for a in aliases) @reader - def get_import(self, symbol_alias: str) -> TImport | None: + def get_import(self, symbol_alias: str, optional: bool = False) -> TImport | None: """Returns the import with matching alias. Returns None if not found. Args: symbol_alias (str): The alias name to search for. This can match either the direct import name or the aliased name. + optional (bool, optional): Whether to return None if the import is not found. Defaults to False. Returns: TImport | None: The import statement with the matching alias if found, None otherwise. """ - return next((x for x in self.imports if x.alias is not None and x.alias.source == symbol_alias), None) + if import_ := next((x for x in self.imports if x.alias is not None and x.alias.source == symbol_alias), None): + return import_ + if not optional: + msg = f"Import with alias {symbol_alias} not found in file {self.filepath}. Use optional=True to return None instead." + raise ValueError(msg) + return None @proxy_property def symbols(self, nested: bool = False) -> list[Symbol | TClass | TFunction | TGlobalVar | TInterface]: @@ -732,7 +738,7 @@ def get_nodes(self, *, sort_by_id: bool = False, sort: bool = True) -> Sequence[ return ret @reader - def get_symbol(self, name: str) -> Symbol | None: + def get_symbol(self, name: str, optional: bool = False) -> Symbol | None: """Gets a symbol by its name from the file. Attempts to resolve the symbol by name using name resolution rules first. If that fails, @@ -740,6 +746,7 @@ def get_symbol(self, name: str) -> Symbol | None: Args: name (str): The name of the symbol to find. + optional (bool, optional): Whether to return None if the symbol is not found. Defaults to False. Returns: Symbol | None: The found symbol, or None if not found. @@ -747,7 +754,16 @@ def get_symbol(self, name: str) -> Symbol | None: if symbol := self.resolve_name(name, self.end_byte): if isinstance(symbol, Symbol): return symbol - return next((x for x in self.symbols if x.name == name), None) + symbol = [x for x in self.symbols if x.name == name] + if len(symbol) == 0: + if not optional: + msg = f"Symbol {name} not found in file {self.filepath}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(symbol) > 1: + msg = f"Multiple symbols found with name {name} in file {self.filepath}." + raise ValueError(msg) + return symbol[0] @property @reader(cache=False) @@ -783,16 +799,26 @@ def global_vars(self) -> list[TGlobalVar]: return [s for s in self.symbols if s.symbol_type == SymbolType.GlobalVar] @reader - def get_global_var(self, name: str) -> TGlobalVar | None: + def get_global_var(self, name: str, optional: bool = False) -> TGlobalVar | None: """Returns a specific global var by name. Returns None if not found. Args: name (str): The name of the global variable to find. + optional (bool, optional): Whether to return None if the global var is not found. Defaults to False. Returns: TGlobalVar | None: The global variable if found, None otherwise. """ - return next((x for x in self.global_vars if x.name == name), None) + global_var = [x for x in self.global_vars if x.name == name] + if not global_var: + if not optional: + msg = f"Global var {name} not found in file {self.filepath}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(global_var) > 1: + msg = f"Multiple global vars found with name {name} in file {self.filepath}." + raise ValueError(msg) + return global_var[0] @property @reader(cache=False) @@ -808,13 +834,14 @@ def classes(self) -> list[TClass]: return [s for s in self.symbols if s.symbol_type == SymbolType.Class] @reader - def get_class(self, name: str) -> TClass | None: + def get_class(self, name: str, optional: bool = False) -> TClass | None: """Returns a specific Class by full name. Returns None if not found. Searches for a class in the file with the specified name. Similar to get_symbol, but specifically for Class types. Args: name (str): The full name of the class to search for. + optional (bool, optional): Whether to return None if the class is not found. Defaults to False. Returns: TClass | None: The matching Class object if found, None otherwise. @@ -822,6 +849,16 @@ def get_class(self, name: str) -> TClass | None: if symbol := self.resolve_name(name, self.end_byte): if isinstance(symbol, Class): return symbol + class_ = [x for x in self.classes if x.name == name] + if not class_: + if not optional: + msg = f"Class {name} not found in file {self.filepath}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(class_) > 1: + msg = f"Multiple classes found with name {name} in file {self.filepath}." + raise ValueError(msg) + return class_[0] @property @reader(cache=False) @@ -837,22 +874,35 @@ def functions(self) -> list[TFunction]: return [s for s in self.symbols if s.symbol_type == SymbolType.Function] @reader - def get_function(self, name: str) -> TFunction | None: + def get_function(self, name: str, optional: bool = False) -> TFunction | None: """Returns a specific Function by name. Gets a Function object from the file by searching for a function with the given name. Args: name (str): The name of the function to find. + optional (bool, optional): Whether to return None if the function is not found. Defaults to False. Returns: TFunction | None: The matching Function object if found, None otherwise. """ - return next((x for x in self.functions if x.name == name), None) + if symbol := self.resolve_name(name, self.end_byte): + if isinstance(symbol, Function): + return symbol + function = [x for x in self.functions if x.name == name] + if not function: + if not optional: + msg = f"Function {name} not found in file {self.filepath}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(function) > 1: + msg = f"Multiple functions found with name {name} in file {self.filepath}." + raise ValueError(msg) + return function[0] @noapidoc @reader - def get_node_by_name(self, name: str) -> Symbol | TImport | None: + def get_node_by_name(self, name: str, optional: bool = False) -> Symbol | TImport | None: """Returns something defined in this file by name. Used during import resolution @@ -863,6 +913,9 @@ def get_node_by_name(self, name: str) -> Symbol | TImport | None: imp = self.get_import(name) if imp is not None: return imp + if not optional: + msg = f"Symbol {name} not found in file {self.filepath}. Use optional=True to return None instead." + raise ValueError(msg) return None @cached_property diff --git a/src/codegen/sdk/core/interface.py b/src/codegen/sdk/core/interface.py index 2c605d694..258b459d2 100644 --- a/src/codegen/sdk/core/interface.py +++ b/src/codegen/sdk/core/interface.py @@ -52,12 +52,21 @@ def attributes(self) -> list[TAttribute]: raise NotImplementedError(msg) @reader - def get_attribute(self, name: str) -> TAttribute | None: + def get_attribute(self, name: str, optional: bool = False) -> TAttribute | None: """Returns the attribute with the given name, if it exists. Otherwise, returns None. """ - return next((x for x in self.attributes if x.name == name), None) + attribute = [x for x in self.attributes if x.name == name] + if not attribute: + if not optional: + msg = f"Attribute {name} not found in interface {self.name}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(attribute) > 1: + msg = f"Multiple attributes found with name {name} in interface {self.name}." + raise ValueError(msg) + return attribute[0] @reader def extends(self, parent_interface: str | Interface, max_depth: int | None = None) -> bool: diff --git a/src/codegen/sdk/core/interfaces/callable.py b/src/codegen/sdk/core/interfaces/callable.py index 83a2db7b9..b10ae671d 100644 --- a/src/codegen/sdk/core/interfaces/callable.py +++ b/src/codegen/sdk/core/interfaces/callable.py @@ -87,44 +87,74 @@ def parameters(self) -> SymbolGroup[TParameter, Self] | list[TParameter]: return self._parameters @reader - def get_parameter(self, name: str) -> TParameter | None: + def get_parameter(self, name: str, optional: bool = False) -> TParameter | None: """Gets a specific parameter from the callable's parameters list by name. Args: name (str): The name of the parameter to retrieve. + optional (bool, optional): If True, returns None instead of raising an error if the parameter is not found. Returns: TParameter | None: The parameter with the specified name, or None if no parameter with that name exists or if there are no parameters. """ - return next((x for x in self._parameters if x.name == name), None) + parameter = [x for x in self._parameters if x.name == name] + if not parameter: + if not optional: + msg = f"Parameter {name} not found in callable {self.name}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(parameter) > 1: + msg = f"Multiple parameters with name {name} found in callable {self.name}. Use get_parameter_by_type to resolve." + raise ValueError(msg) + return parameter[0] @reader - def get_parameter_by_index(self, index: int) -> TParameter | None: + def get_parameter_by_index(self, index: int, optional: bool = False) -> TParameter | None: """Returns the parameter at the given index. Retrieves a parameter from the callable's parameter list based on its positional index. Args: index (int): The index of the parameter to retrieve. + optional (bool, optional): If True, returns None instead of raising an error if the parameter is not found. Returns: TParameter | None: The parameter at the specified index, or None if the parameter list is empty or the index does not exist. """ - return next((x for x in self._parameters if x.index == index), None) + parameter = [x for x in self._parameters if x.index == index] + if not parameter: + if not optional: + msg = f"Parameter at index {index} not found in callable {self.name}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(parameter) > 1: + msg = f"Multiple parameters at index {index} found in callable {self.name}. Use get_parameter_by_type to resolve." + raise ValueError(msg) + return parameter[0] @reader - def get_parameter_by_type(self, type: "Symbol") -> TParameter | None: + def get_parameter_by_type(self, type: "Symbol", optional: bool = False) -> TParameter | None: """Retrieves a parameter from the callable by its type. Searches through the callable's parameters to find a parameter with the specified type. Args: type (Symbol): The type to search for. + optional (bool, optional): If True, returns None instead of raising an error if the parameter is not found. Returns: TParameter | None: The parameter with the specified type, or None if no parameter is found or if the callable has no parameters. """ if self._parameters is None: return None - return next((x for x in self._parameters if x.type == type), None) + parameter = [x for x in self._parameters if x.type == type] + if not parameter: + if not optional: + msg = f"Parameter of type {type} not found in callable {self.name}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(parameter) > 1: + msg = f"Multiple parameters of type {type} found in callable {self.name}. Use get_parameter_by_name to resolve." + raise ValueError(msg) + return parameter[0] diff --git a/src/codegen/sdk/core/interfaces/has_symbols.py b/src/codegen/sdk/core/interfaces/has_symbols.py index 2c8bbe445..42dbca140 100644 --- a/src/codegen/sdk/core/interfaces/has_symbols.py +++ b/src/codegen/sdk/core/interfaces/has_symbols.py @@ -84,34 +84,98 @@ def imports(self) -> list[TImport]: """Get a recursive list of all imports in files container.""" return list(chain.from_iterable(f.imports for f in self.files_generator())) - def get_symbol(self, name: str) -> TSymbol | None: + def get_symbol(self, name: str, optional: bool = False) -> TSymbol | None: """Get a symbol by name in files container.""" - return next((s for s in self.symbols if s.name == name), None) - - def get_import_statement(self, name: str) -> TImportStatement | None: + symbol = next((s for s in self.symbols if s.name == name), None) + if not symbol: + if not optional: + msg = f"Symbol {name} not found in files container. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(symbol) > 1: + msg = f"Multiple symbols with name {name} found in files container. Use get_symbol_by_type to resolve." + raise ValueError(msg) + return symbol + + def get_import_statement(self, name: str, optional: bool = False) -> TImportStatement | None: """Get an import statement by name in files container.""" - return next((s for s in self.import_statements if s.name == name), None) - - def get_global_var(self, name: str) -> TGlobalVar | None: + import_statement = next((s for s in self.import_statements if s.name == name), None) + if not import_statement: + if not optional: + msg = f"Import statement {name} not found in files container. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(import_statement) > 1: + msg = f"Multiple import statements with name {name} found in files container. Use get_import_statement_by_type to resolve." + raise ValueError(msg) + return import_statement + + def get_global_var(self, name: str, optional: bool = False) -> TGlobalVar | None: """Get a global variable by name in files container.""" - return next((s for s in self.global_vars if s.name == name), None) - - def get_class(self, name: str) -> TClass | None: + global_var = next((s for s in self.global_vars if s.name == name), None) + if not global_var: + if not optional: + msg = f"Global variable {name} not found in files container. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(global_var) > 1: + msg = f"Multiple global variables with name {name} found in files container. Use get_global_var_by_type to resolve." + raise ValueError(msg) + return global_var + + def get_class(self, name: str, optional: bool = False) -> TClass | None: """Get a class by name in files container.""" - return next((s for s in self.classes if s.name == name), None) - - def get_function(self, name: str) -> TFunction | None: + class_ = next((s for s in self.classes if s.name == name), None) + if not class_: + if not optional: + msg = f"Class {name} not found in files container. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(class_) > 1: + msg = f"Multiple classes with name {name} found in files container. Use get_class_by_type to resolve." + raise ValueError(msg) + return class_ + + def get_function(self, name: str, optional: bool = False) -> TFunction | None: """Get a function by name in files container.""" - return next((s for s in self.functions if s.name == name), None) + function = next((s for s in self.functions if s.name == name), None) + if not function: + if not optional: + msg = f"Function {name} not found in files container. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(function) > 1: + msg = f"Multiple functions with name {name} found in files container. Use get_function_by_type to resolve." + raise ValueError(msg) + return function @py_noapidoc def get_export( self: "HasSymbols[TSFile, TSSymbol, TSImportStatement, TSGlobalVar, TSClass, TSFunction, TSImport]", name: str, + optional: bool = False, ) -> "TSExport | None": """Get an export by name in files container (supports only typescript).""" - return next((s for s in self.exports if s.name == name), None) - - def get_import(self, name: str) -> TImport | None: + export = next((s for s in self.exports if s.name == name), None) + if not export: + if not optional: + msg = f"Export {name} not found in files container. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(export) > 1: + msg = f"Multiple exports with name {name} found in files container. Use get_export_by_type to resolve." + raise ValueError(msg) + return export + + def get_import(self, name: str, optional: bool = False) -> TImport | None: """Get an import by name in files container.""" - return next((s for s in self.imports if s.name == name), None) + import_ = next((s for s in self.imports if s.name == name), None) + if not import_: + if not optional: + msg = f"Import {name} not found in files container. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(import_) > 1: + msg = f"Multiple imports with name {name} found in files container. Use get_import_by_type to resolve." + raise ValueError(msg) + return import_ diff --git a/src/codegen/sdk/core/type_alias.py b/src/codegen/sdk/core/type_alias.py index 17171155c..676e6510a 100644 --- a/src/codegen/sdk/core/type_alias.py +++ b/src/codegen/sdk/core/type_alias.py @@ -60,9 +60,18 @@ def attributes(self) -> list[TAttribute]: """List of expressions defined in this Type object.""" @reader - def get_attribute(self, name: str) -> TAttribute | None: + def get_attribute(self, name: str, optional: bool = False) -> TAttribute | None: """Get attribute by name.""" - return next((x for x in self.attributes if x.name == name), None) + attribute = [x for x in self.attributes if x.name == name] + if not attribute: + if not optional: + msg = f"Attribute {name} not found in type alias {self.name}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(attribute) > 1: + msg = f"Multiple attributes found with name {name} in type alias {self.name}." + raise ValueError(msg) + return attribute[0] @noapidoc @reader diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index 3b1fc9f93..6e67595ab 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -222,7 +222,7 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P return super().valid_import_names @noapidoc - def get_node_from_wildcard_chain(self, symbol_name: str) -> PySymbol | None: + def get_node_from_wildcard_chain(self, symbol_name: str, optional: bool = False) -> PySymbol | None: """Recursively searches for a symbol through wildcard import chains. Attempts to find a symbol by name in the current file, and if not found, recursively searches @@ -230,6 +230,7 @@ def get_node_from_wildcard_chain(self, symbol_name: str) -> PySymbol | None: Args: symbol_name (str): The name of the symbol to search for. + optional (bool, optional): Whether to return None if the symbol is not found. Defaults to False. Returns: PySymbol | None: The found symbol if it exists in this file or any of its wildcard @@ -244,10 +245,14 @@ def get_node_from_wildcard_chain(self, symbol_name: str) -> PySymbol | None: if imp_resolution := wildcard_import.resolve_import(): node = imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name) + if not node and not optional: + msg = f"Symbol {symbol_name} not found in files container. Use optional=True to return None instead." + raise ValueError(msg) + return node @noapidoc - def get_node_wildcard_resolves_for(self, symbol_name: str) -> PyImport | PySymbol | None: + def get_node_wildcard_resolves_for(self, symbol_name: str, optional: bool = False) -> PyImport | PySymbol | None: """Finds the wildcard import that resolves a given symbol name. Searches for a symbol by name, first in the current file, then through wildcard imports. @@ -256,6 +261,7 @@ def get_node_wildcard_resolves_for(self, symbol_name: str) -> PyImport | PySymbo Args: symbol_name (str): The name of the symbol to search for. + optional (bool, optional): Whether to return None if the symbol is not found. Defaults to False. Returns: PyImport | PySymbol | None: @@ -273,4 +279,8 @@ def get_node_wildcard_resolves_for(self, symbol_name: str) -> PyImport | PySymbo if imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name): node = wildcard_import + if not node and not optional: + msg = f"Symbol {symbol_name} not found in files container. Use optional=True to return None instead." + raise ValueError(msg) + return node diff --git a/src/codegen/sdk/typescript/enum_definition.py b/src/codegen/sdk/typescript/enum_definition.py index faacc6e32..a4c4d375e 100644 --- a/src/codegen/sdk/typescript/enum_definition.py +++ b/src/codegen/sdk/typescript/enum_definition.py @@ -64,16 +64,26 @@ def attributes(self) -> list[TSAttribute[Self, None]]: return self.code_block.attributes @reader - def get_attribute(self, name: str) -> TSAttribute | None: + def get_attribute(self, name: str, optional: bool = False) -> TSAttribute | None: """Returns an attribute from the TypeScript enum by its name. Args: name (str): The name of the attribute to retrieve. + optional (bool, optional): Whether to return None if the attribute is not found. Defaults to False. Returns: TSAttribute | None: The attribute with the given name if it exists, None otherwise. """ - return next((x for x in self.attributes if x.name == name), None) + attribute = [x for x in self.attributes if x.name == name] + if not attribute: + if not optional: + msg = f"Attribute {name} not found in enum {self.name}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(attribute) > 1: + msg = f"Multiple attributes found with name {name} in enum {self.name}." + raise ValueError(msg) + return attribute[0] @noapidoc @commiter diff --git a/src/codegen/sdk/typescript/file.py b/src/codegen/sdk/typescript/file.py index 4c937292d..57c8a37e9 100644 --- a/src/codegen/sdk/typescript/file.py +++ b/src/codegen/sdk/typescript/file.py @@ -105,18 +105,28 @@ def named_exports(self) -> list[TSExport]: return [x for x in self.exports if not x.is_default_export()] @reader - def get_export(self, export_name: str) -> TSExport | None: + def get_export(self, export_name: str, optional: bool = False) -> TSExport | None: """Returns an export object with the specified name from the file. This method searches for an export with the given name in the file. Args: export_name (str): The name of the export to find. + optional (bool, optional): Whether to return None if the export is not found. Defaults to False. Returns: TSExport | None: The export object if found, None otherwise. """ - return next((x for x in self.exports if x.name == export_name), None) + export = [x for x in self.exports if x.name == export_name] + if not export: + if not optional: + msg = f"Export {export_name} not found in file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(export) > 1: + msg = f"Multiple exports found with name {export_name} in file {self.file_path}." + raise ValueError(msg) + return export[0] @property @reader @@ -134,16 +144,26 @@ def interfaces(self) -> list[TSInterface]: return [s for s in self.symbols if s.symbol_type == SymbolType.Interface] @reader - def get_interface(self, name: str) -> TSInterface | None: + def get_interface(self, name: str, optional: bool = False) -> TSInterface | None: """Retrieves a specific interface from the file by its name. Args: name (str): The name of the interface to find. + optional (bool, optional): Whether to return None if the interface is not found. Defaults to False. Returns: TSInterface | None: The interface with the specified name if found, None otherwise. """ - return next((x for x in self.interfaces if x.name == name), None) + interface = [x for x in self.interfaces if x.name == name] + if not interface: + if not optional: + msg = f"Interface {name} not found in file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(interface) > 1: + msg = f"Multiple interfaces found with name {name} in file {self.file_path}." + raise ValueError(msg) + return interface[0] @property @reader @@ -158,18 +178,28 @@ def types(self) -> list[TSTypeAlias]: return [s for s in self.symbols if s.symbol_type == SymbolType.Type] @reader - def get_type(self, name: str) -> TSTypeAlias | None: + def get_type(self, name: str, optional: bool = False) -> TSTypeAlias | None: """Returns a specific Type by name from the file's types. Retrieves a TypeScript type alias by its name from the file's collection of types. Args: name (str): The name of the type alias to retrieve. + optional (bool, optional): Whether to return None if the type alias is not found. Defaults to False. Returns: TSTypeAlias | None: The TypeScript type alias with the matching name, or None if not found. """ - return next((x for x in self.types if x.name == name), None) + type_ = [x for x in self.types if x.name == name] + if not type_: + if not optional: + msg = f"Type {name} not found in file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(type_) > 1: + msg = f"Multiple types found with name {name} in file {self.file_path}." + raise ValueError(msg) + return type_[0] @staticmethod def get_extensions() -> list[str]: @@ -320,12 +350,13 @@ def has_export_statement_for_path(self, relative_path: str, export_type: str = " #################################################################################################################### @reader - def get_export_statement_for_path(self, relative_path: str, export_type: str = "EXPORT") -> ExportStatement | None: + def get_export_statement_for_path(self, relative_path: str, export_type: str = "EXPORT", optional: bool = False) -> ExportStatement | None: """Gets the first export of specified type that contains the given path in single or double quotes. Args: relative_path (str): The path to check for in export statements export_type (str): Type of export to get - "WILDCARD", "TYPE", or "EXPORT" (default) + optional (bool, optional): Whether to return None if the export statement is not found. Defaults to False. Returns: TSExport | None: The first matching export if found, None otherwise. @@ -342,6 +373,9 @@ def get_export_statement_for_path(self, relative_path: str, export_type: str = " if condition(exp): return exp + if not optional: + msg = f"Export statement for path {relative_path} not found in file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) return None @noapidoc @@ -424,16 +458,26 @@ def update_filepath(self, new_filepath: str) -> None: imp.set_import_module(f"'{new_module_name}'") @reader - def get_namespace(self, name: str) -> TSNamespace | None: + def get_namespace(self, name: str, optional: bool = False) -> TSNamespace | None: """Returns a specific namespace by name from the file's namespaces. Args: name (str): The name of the namespace to find. + optional (bool, optional): Whether to return None if the namespace is not found. Defaults to False. Returns: TSNamespace | None: The namespace with the specified name if found, None otherwise. """ - return next((x for x in self.symbols if isinstance(x, TSNamespace) and x.name == name), None) + namespace = [x for x in self.symbols if isinstance(x, TSNamespace) and x.name == name] + if not namespace: + if not optional: + msg = f"Namespace {name} not found in file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(namespace) > 1: + msg = f"Multiple namespaces found with name {name} in file {self.file_path}." + raise ValueError(msg) + return namespace[0] @property @reader diff --git a/src/codegen/sdk/typescript/interfaces/has_block.py b/src/codegen/sdk/typescript/interfaces/has_block.py index be8bb68c4..7e4eaf75b 100644 --- a/src/codegen/sdk/typescript/interfaces/has_block.py +++ b/src/codegen/sdk/typescript/interfaces/has_block.py @@ -87,7 +87,7 @@ def jsx_elements(self) -> list[JSXElement[Self]]: return jsx_elements @reader - def get_component(self, component_name: str) -> JSXElement[Self] | None: + def get_component(self, component_name: str, optional: bool = False) -> JSXElement[Self] | None: """Returns a specific JSX element from within this symbol's JSX elements. Searches through all JSX elements in this symbol's code block and returns the first one that matches @@ -95,6 +95,7 @@ def get_component(self, component_name: str) -> JSXElement[Self] | None: Args: component_name (str): The name of the JSX component to find. + optional (bool, optional): If True, return None if the component is not found. Defaults to False. Returns: JSXElement[Self] | None: The matching JSX element if found, None otherwise. @@ -102,6 +103,9 @@ def get_component(self, component_name: str) -> JSXElement[Self] | None: for component in self.jsx_elements: if component.name == component_name: return component + if not optional: + msg = f"Component {component_name} not found in symbol {self.name} of file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) return None @cached_property diff --git a/src/codegen/sdk/typescript/namespace.py b/src/codegen/sdk/typescript/namespace.py index 4d1e3f7db..086ed1608 100644 --- a/src/codegen/sdk/typescript/namespace.py +++ b/src/codegen/sdk/typescript/namespace.py @@ -73,12 +73,13 @@ def symbols(self) -> list[Symbol]: all_symbols.append(stmt) return all_symbols - def get_symbol(self, name: str, recursive: bool = True) -> Symbol | None: + def get_symbol(self, name: str, recursive: bool = True, optional: bool = False) -> Symbol | None: """Get a symbol by name from this namespace. Args: name: Name of the symbol to find recursive: If True, also search in nested namespaces + optional: If True, return None if the symbol is not found Returns: Symbol | None: The found symbol, or None if not found @@ -93,6 +94,10 @@ def get_symbol(self, name: str, recursive: bool = True) -> Symbol | None: nested_symbol = symbol.get_symbol(name, recursive=True) return nested_symbol + if not optional: + msg = f"Symbol {name} not found in namespace {self.name} of file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None @cached_property @@ -104,13 +109,14 @@ def functions(self) -> list[TSFunction]: """ return [symbol for symbol in self.symbols if isinstance(symbol, TSFunction)] - def get_function(self, name: str, recursive: bool = True, use_full_name: bool = False) -> TSFunction | None: + def get_function(self, name: str, recursive: bool = True, use_full_name: bool = False, optional: bool = False) -> TSFunction | None: """Get a function by name from this namespace. Args: name: Name of the function to find (can be fully qualified like 'Outer.Inner.func') recursive: If True, also search in nested namespaces use_full_name: If True, match against the full qualified name + optional: If True, return None if the function is not found Returns: TSFunction | None: The found function, or None if not found @@ -120,8 +126,13 @@ def get_function(self, name: str, recursive: bool = True, use_full_name: bool = target_ns = self.get_namespace(namespace_path) return target_ns.get_function(func_name, recursive=False) if target_ns else None - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSFunction) else None + if (symbol := self.get_symbol(name, recursive=recursive)) and isinstance(symbol, TSFunction): + return symbol + + if not optional: + msg = f"Function {name} not found in namespace {self.name} of file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None @cached_property def classes(self) -> list[TSClass]: @@ -132,52 +143,85 @@ def classes(self) -> list[TSClass]: """ return [symbol for symbol in self.symbols if isinstance(symbol, TSClass)] - def get_class(self, name: str, recursive: bool = True) -> TSClass | None: + def get_class(self, name: str, recursive: bool = True, optional: bool = False) -> TSClass | None: """Get a class by name from this namespace. Args: name: Name of the class to find recursive: If True, also search in nested namespaces + optional: If True, return None if the class is not found + + Returns: + TSClass | None: The found class, or None if not found """ - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSClass) else None + if (symbol := self.get_symbol(name, recursive=recursive)) and isinstance(symbol, TSClass): + return symbol + if not optional: + msg = f"Class {name} not found in namespace {self.name} of file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None - def get_interface(self, name: str, recursive: bool = True) -> TSInterface | None: + def get_interface(self, name: str, recursive: bool = True, optional: bool = False) -> TSInterface | None: """Get an interface by name from this namespace. Args: name: Name of the interface to find recursive: If True, also search in nested namespaces + optional: If True, return None if the interface is not found + + Returns: + TSInterface | None: The found interface, or None if not found """ - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSInterface) else None + if (symbol := self.get_symbol(name, recursive=recursive)) and isinstance(symbol, TSInterface): + return symbol + if not optional: + msg = f"Interface {name} not found in namespace {self.name} of file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None - def get_type(self, name: str, recursive: bool = True) -> TSTypeAlias | None: + def get_type(self, name: str, recursive: bool = True, optional: bool = False) -> TSTypeAlias | None: """Get a type alias by name from this namespace. Args: name: Name of the type to find recursive: If True, also search in nested namespaces + optional: If True, return None if the type is not found + + Returns: + TSTypeAlias | None: The found type alias, or None if not found """ - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSTypeAlias) else None + if (symbol := self.get_symbol(name, recursive=recursive)) and isinstance(symbol, TSTypeAlias): + return symbol + if not optional: + msg = f"Type alias {name} not found in namespace {self.name} of file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None - def get_enum(self, name: str, recursive: bool = True) -> TSEnum | None: + def get_enum(self, name: str, recursive: bool = True, optional: bool = False) -> TSEnum | None: """Get an enum by name from this namespace. Args: name: Name of the enum to find recursive: If True, also search in nested namespaces + optional: If True, return None if the enum is not found + + Returns: + TSEnum | None: The found enum, or None if not found """ - symbol = self.get_symbol(name, recursive=recursive) - return symbol if isinstance(symbol, TSEnum) else None + if (symbol := self.get_symbol(name, recursive=recursive)) and isinstance(symbol, TSEnum): + return symbol + if not optional: + msg = f"Enum {name} not found in namespace {self.name} of file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) + return None - def get_namespace(self, name: str, recursive: bool = True) -> TSNamespace | None: + def get_namespace(self, name: str, recursive: bool = True, optional: bool = False) -> TSNamespace | None: """Get a namespace by name from this namespace. Args: name: Name of the namespace to find recursive: If True, also search in nested namespaces + optional: If True, return None if the namespace is not found Returns: TSNamespace | None: The found namespace, or None if not found @@ -192,6 +236,9 @@ def get_namespace(self, name: str, recursive: bool = True) -> TSNamespace | None nested_namespace = symbol.get_namespace(name, recursive=True) return nested_namespace + if not optional: + msg = f"Namespace {name} not found in namespace {self.name} of file {self.file_path}. Use optional=True to return None instead." + raise ValueError(msg) return None def get_nested_namespaces(self) -> list[TSNamespace]: diff --git a/src/codegen/sdk/typescript/type_alias.py b/src/codegen/sdk/typescript/type_alias.py index d4d671909..fdda75f46 100644 --- a/src/codegen/sdk/typescript/type_alias.py +++ b/src/codegen/sdk/typescript/type_alias.py @@ -61,7 +61,7 @@ def attributes(self) -> list[TSAttribute]: return self.code_block.attributes @reader - def get_attribute(self, name: str) -> TSAttribute | None: + def get_attribute(self, name: str, optional: bool = False) -> TSAttribute | None: """Retrieves a specific attribute from a TypeScript type alias by its name. Args: @@ -70,4 +70,13 @@ def get_attribute(self, name: str) -> TSAttribute | None: Returns: TSAttribute[TSTypeAlias, None] | None: The attribute with the specified name if found, None otherwise. """ - return next((x for x in self.attributes if x.name == name), None) + attribute = [x for x in self.attributes if x.name == name] + if not attribute: + if not optional: + msg = f"Attribute {name} not found in type alias {self.name}. Use optional=True to return None instead." + raise ValueError(msg) + return None + if len(attribute) > 1: + msg = f"Multiple attributes found with name {name} in type alias {self.name}." + raise ValueError(msg) + return attribute[0]