diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py index a5597351c..7b31107a4 100644 --- a/codeflash/languages/java/context.py +++ b/codeflash/languages/java/context.py @@ -19,6 +19,12 @@ if TYPE_CHECKING: from tree_sitter import Node +TYPE_DECLARATIONS = { + "class_declaration": "class", + "interface_declaration": "interface", + "enum_declaration": "enum", +} + logger = logging.getLogger(__name__) @@ -253,23 +259,25 @@ def _find_type_node(node: Node, type_name: str, source_bytes: bytes) -> tuple[No Tuple of (node, type_kind) where type_kind is "class", "interface", or "enum". """ - type_declarations = { - "class_declaration": "class", - "interface_declaration": "interface", - "enum_declaration": "enum", - } - - if node.type in type_declarations: - name_node = node.child_by_field_name("name") - if name_node: - node_name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") - if node_name == type_name: - return node, type_declarations[node.type] - - for child in node.children: - result, kind = _find_type_node(child, type_name, source_bytes) - if result: - return result, kind + # Encode the search name once to avoid repeated UTF-8 decodes of slices. + type_name_bytes = type_name.encode("utf8") + + # Use an explicit stack for DFS to avoid recursion overhead. + stack: list[Node] = [node] + + while stack: + current = stack.pop() + if current.type in TYPE_DECLARATIONS: + name_node = current.child_by_field_name("name") + if name_node: + # Compare bytes directly to avoid decoding the slice to str. + if source_bytes[name_node.start_byte : name_node.end_byte] == type_name_bytes: + return current, TYPE_DECLARATIONS[current.type] + + # Push children in reverse order so that the leftmost child is processed first, + # preserving the original recursive traversal order. + for child in reversed(current.children): + stack.append(child) return None, ""