From f919750e51b2df3cc01aa883ea42dcf03b864a23 Mon Sep 17 00:00:00 2001 From: Marcell Perger Date: Sat, 14 Jun 2025 18:24:01 +0100 Subject: [PATCH 1/6] refactor(ast): Move ast node subclasses into ast_nodes.py --- benchmark.py | 2 +- parser/astgen/ast_node.py | 254 ++-------------------------------- parser/astgen/ast_nodes.py | 230 ++++++++++++++++++++++++++++++ parser/astgen/astgen.py | 3 +- parser/typecheck/typecheck.py | 6 +- 5 files changed, 251 insertions(+), 244 deletions(-) create mode 100644 parser/astgen/ast_nodes.py diff --git a/benchmark.py b/benchmark.py index 460c717..3882013 100644 --- a/benchmark.py +++ b/benchmark.py @@ -2,7 +2,7 @@ import contextlib import time -from parser.astgen.ast_node import AstProgramNode +from parser.astgen.ast_nodes import AstProgramNode from parser.astgen.astgen import AstGen from parser.common.tree_print import tformat from parser.cst.nodes import ProgramNode diff --git a/parser/astgen/ast_node.py b/parser/astgen/ast_node.py index 5994210..9d33785 100644 --- a/parser/astgen/ast_node.py +++ b/parser/astgen/ast_node.py @@ -7,26 +7,27 @@ from util import flatten_force from ..common import HasRegion, StrRegion -__all__ = [ - "AstNode", "AstProgramNode", "VarDeclScope", "VarDeclType", "AstDeclNode", - "AstRepeat", "AstIf", "AstWhile", "AstAssign", "AstAugAssign", "AstDefine", - "AstNumber", "AstString", "AstAnyName", "AstIdent", "AstAttrName", - "AstListLiteral", "AstAttribute", "AstItem", "AstCall", "AstOp", "AstBinOp", - "AstUnaryOp", 'walk_ast', 'WalkableT', 'WalkerFnT', 'WalkerCallType', - "FilteredWalker" -] - - -class WalkerCallType(Enum): - PRE = 'pre' - POST = 'post' +__all__ = ['AstNode', 'walk_ast', 'WalkableT', 'WalkerFnT', 'WalkerCallType', + "FilteredWalker"] +VT = TypeVar('VT') +WT = TypeVar('WT', bound='WalkableT') WalkableL0: TypeAlias = 'AstNode | list[AstNode] | tuple[AstNode, ...] 
| None' WalkableT: TypeAlias = 'WalkableL0 | list[WalkableL0] | tuple[WalkableL0, ...]' -WalkerFnT: TypeAlias = Callable[[WalkableT, WalkerCallType], bool | None] +WalkerFnT: TypeAlias = Callable[[WalkableT, 'WalkerCallType'], bool | None] """Returns True if skip""" +SpecificCbT = Callable[[WT], bool | None] +SpecificCbsDict = dict[type[WT] | type, list[Callable[[WT], bool | None]]] +BothCbT = Callable[[WT, 'WalkerCallType'], bool | None] +BothCbsDict = dict[type[WT] | type, list[Callable[[WT, 'WalkerCallType'], bool | None]]] + + +class WalkerCallType(Enum): + PRE = 'pre' + POST = 'post' + @dataclass class AstNode(HasRegion): @@ -76,15 +77,6 @@ def walk_multiple_objects(cls, fn: WalkerFnT, objs: Iterable[WalkableT]): walk_ast = AstNode.walk_obj -# region -WT = TypeVar('WT', bound=WalkableT) -VT = TypeVar('VT') -SpecificCbT = Callable[[WT], bool | None] -SpecificCbsDict = dict[type[WT] | type, list[Callable[[WT], bool | None]]] -BothCbT = Callable[[WT, WalkerCallType], bool | None] -BothCbsDict = dict[type[WT] | type, list[Callable[[WT, WalkerCallType], bool | None]]] - - class WalkerFilterRegistry: def __init__(self, enter_cbs: SpecificCbsDict = (), exit_cbs: SpecificCbsDict = (), @@ -190,219 +182,3 @@ def _get_funcs(cls, mapping: dict[type[WT] | type, list[VT]], tp: type[WT]) -> l """Also looks at superclasses/MRO""" return flatten_force(mapping.get(sub, []) for sub in tp.mro()) # endregion - - -@dataclass -class AstProgramNode(AstNode): - name = 'program' - statements: list[AstNode] - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.statements,)) - - -# region ---- ---- -class VarDeclScope(Enum): - LET = 'let' - GLOBAL = 'global' - - -class VarDeclType(Enum): - VARIABLE = 'variable' - LIST = 'list' - - -@dataclass -class AstDeclNode(AstNode): - name = 'var_decl' - scope: VarDeclScope - type: VarDeclType - ident: AstIdent - value: AstNode | None - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, 
(self.ident, self.value)) - - -@dataclass -class AstRepeat(AstNode): - name = 'repeat' - count: AstNode - body: list[AstNode] - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.count, self.body)) - - -@dataclass -class AstIf(AstNode): - name = 'if' - cond: AstNode - if_body: list[AstNode] - # elseif = else{if - else_body: list[AstNode] | None = None - # ^ Separate cases for no block and empty block (can be else {} to easily - # add extra blocks in scratch interface) - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.cond, self.if_body, self.else_body)) - - -@dataclass -class AstWhile(AstNode): - name = 'while' - cond: AstNode - body: list[AstNode] - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.cond, self.body)) - - -@dataclass -class AstAssign(AstNode): - name = '=' - target: AstNode - source: AstNode - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.target, self.source)) - - -@dataclass -class AstAugAssign(AstNode): - op: str # maybe attach a StrRegion to the location of the op?? - target: AstNode - source: AstNode - - @property - def name(self): - return self.op - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.target, self.source)) - - -@dataclass -class AstDefine(AstNode): - name = 'def' - - ident: AstIdent - params: list[tuple[AstIdent, AstIdent]] # type, ident - body: list[AstNode] - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.ident, self.params, self.body)) -# endregion ---- ---- - - -# region ---- ---- -@dataclass -class AstNumber(AstNode): - # No real point in storing the string representation (could always StrRegion.resolve()) - value: float | int - - -@dataclass -class AstString(AstNode): - value: str # Values with escapes, etc. 
resolved - - -@dataclass -class AstAnyName(AstNode): - id: str - - def __post_init__(self): - if type(self) == AstAnyName: - raise TypeError("AstAnyName must not be instantiated directly.") - - -@dataclass -class AstIdent(AstAnyName): - name = 'ident' - - -@dataclass -class AstAttrName(AstAnyName): - name = 'attr' - - -@dataclass -class AstListLiteral(AstNode): - name = 'list' - items: list[AstNode] - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.items,)) - - -@dataclass -class AstAttribute(AstNode): - name = '.' - obj: AstNode - attr: AstAttrName - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.obj, self.attr)) - - -@dataclass -class AstItem(AstNode): - name = 'item' - obj: AstNode - index: AstNode - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.obj, self.index)) - - -@dataclass -class AstCall(AstNode): - name = 'call' - obj: AstNode - args: list[AstNode] - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.obj, self.args)) - - -@dataclass -class AstOp(AstNode): - op: str - - -@dataclass -class AstBinOp(AstOp): - left: AstNode - right: AstNode - - valid_ops = [*'+-*/%', '**', '..', '||', '&&', # ops - '==', '!=', '<', '>', '<=', '>=' # comparisons - ] # type: list[str] - - def __post_init__(self): - assert self.op in self.valid_ops - - @property - def name(self): - return self.op - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.left, self.right)) - - -@dataclass -class AstUnaryOp(AstOp): - operand: AstNode - - valid_ops = ('+', '-', '!') - - def __post_init__(self): - assert self.op in self.valid_ops - - @property - def name(self): - return self.op - - def _walk_members(self, fn: WalkerFnT): - self.walk_multiple_objects(fn, (self.operand,)) -# endregion ---- ---- diff --git a/parser/astgen/ast_nodes.py b/parser/astgen/ast_nodes.py new file mode 100644 index 0000000..fe2a720 --- /dev/null +++ 
b/parser/astgen/ast_nodes.py @@ -0,0 +1,230 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +from .ast_node import AstNode, WalkerFnT + +__all__ = [ + "AstNode", "AstProgramNode", "VarDeclScope", "VarDeclType", "AstDeclNode", + "AstRepeat", "AstIf", "AstWhile", "AstAssign", "AstAugAssign", "AstDefine", + "AstNumber", "AstString", "AstAnyName", "AstIdent", "AstAttrName", + "AstListLiteral", "AstAttribute", "AstItem", "AstCall", "AstOp", "AstBinOp", + "AstUnaryOp", +] + + +@dataclass +class AstProgramNode(AstNode): + name = 'program' + statements: list[AstNode] + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.statements,)) + + +# region ---- ---- +class VarDeclScope(Enum): + LET = 'let' + GLOBAL = 'global' + + +class VarDeclType(Enum): + VARIABLE = 'variable' + LIST = 'list' + + +@dataclass +class AstDeclNode(AstNode): + name = 'var_decl' + scope: VarDeclScope + type: VarDeclType + ident: AstIdent + value: AstNode | None + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.ident, self.value)) + + +@dataclass +class AstRepeat(AstNode): + name = 'repeat' + count: AstNode + body: list[AstNode] + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.count, self.body)) + + +@dataclass +class AstIf(AstNode): + name = 'if' + cond: AstNode + if_body: list[AstNode] + # elseif = else{if + else_body: list[AstNode] | None = None + # ^ Separate cases for no block and empty block (can be else {} to easily + # add extra blocks in scratch interface) + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.cond, self.if_body, self.else_body)) + + +@dataclass +class AstWhile(AstNode): + name = 'while' + cond: AstNode + body: list[AstNode] + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.cond, self.body)) + + +@dataclass +class AstAssign(AstNode): + name = '=' + target: AstNode + 
source: AstNode + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.target, self.source)) + + +@dataclass +class AstAugAssign(AstNode): + op: str # maybe attach a StrRegion to the location of the op?? + target: AstNode + source: AstNode + + @property + def name(self): + return self.op + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.target, self.source)) + + +@dataclass +class AstDefine(AstNode): + name = 'def' + + ident: AstIdent + params: list[tuple[AstIdent, AstIdent]] # type, ident + body: list[AstNode] + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.ident, self.params, self.body)) +# endregion ---- ---- + + +# region ---- ---- +@dataclass +class AstNumber(AstNode): + # No real point in storing the string representation (could always StrRegion.resolve()) + value: float | int + + +@dataclass +class AstString(AstNode): + value: str # Values with escapes, etc. resolved + + +@dataclass +class AstAnyName(AstNode): + id: str + + def __post_init__(self): + if type(self) == AstAnyName: + raise TypeError("AstAnyName must not be instantiated directly.") + + +@dataclass +class AstIdent(AstAnyName): + name = 'ident' + + +@dataclass +class AstAttrName(AstAnyName): + name = 'attr' + + +@dataclass +class AstListLiteral(AstNode): + name = 'list' + items: list[AstNode] + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.items,)) + + +@dataclass +class AstAttribute(AstNode): + name = '.' 
+ obj: AstNode + attr: AstAttrName + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.obj, self.attr)) + + +@dataclass +class AstItem(AstNode): + name = 'item' + obj: AstNode + index: AstNode + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.obj, self.index)) + + +@dataclass +class AstCall(AstNode): + name = 'call' + obj: AstNode + args: list[AstNode] + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.obj, self.args)) + + +@dataclass +class AstOp(AstNode): + op: str + + +@dataclass +class AstBinOp(AstOp): + left: AstNode + right: AstNode + + valid_ops = [*'+-*/%', '**', '..', '||', '&&', # ops + '==', '!=', '<', '>', '<=', '>=' # comparisons + ] # type: list[str] + + def __post_init__(self): + assert self.op in self.valid_ops + + @property + def name(self): + return self.op + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.left, self.right)) + + +@dataclass +class AstUnaryOp(AstOp): + operand: AstNode + + valid_ops = ('+', '-', '!') + + def __post_init__(self): + assert self.op in self.valid_ops + + @property + def name(self): + return self.op + + def _walk_members(self, fn: WalkerFnT): + self.walk_multiple_objects(fn, (self.operand,)) +# endregion ---- ---- diff --git a/parser/astgen/astgen.py b/parser/astgen/astgen.py index 1d2f4fe..c8be8c0 100644 --- a/parser/astgen/astgen.py +++ b/parser/astgen/astgen.py @@ -5,7 +5,7 @@ from typing import Callable, overload, TypeVar, TypeAlias from util import flatten_force, is_strict_subclass -from .ast_node import * +from .ast_nodes import * from .eval_literal import eval_number, eval_string from .errors import LocatedAstError from ..common import region_union, RegionUnionArgT, HasRegion, StrRegion @@ -68,6 +68,7 @@ def _detect_autowalk_type_from_annot(fn): bound = sig.bind(0, 1) # simulate call w/ 2 args except TypeError as e: # pragma: no cover raise TypeError("Unable to detect node_type 
(signature may be incompatible)") from e + # noinspection PyTypeChecker arg2_name: str = (*bound.arguments,)[1] # get name it's bound to param = sig.parameters[arg2_name] # lookup the param by name if param.kind not in (param.POSITIONAL_ONLY, diff --git a/parser/typecheck/typecheck.py b/parser/typecheck/typecheck.py index e989aa1..47323e0 100644 --- a/parser/typecheck/typecheck.py +++ b/parser/typecheck/typecheck.py @@ -3,9 +3,9 @@ from dataclasses import dataclass, field from util.recursive_eq import recursive_eq -from ..astgen.ast_node import ( - AstNode, walk_ast, AstIdent, AstDeclNode, AstDefine, VarDeclType, - VarDeclScope, FilteredWalker) +from ..astgen.ast_node import walk_ast, FilteredWalker +from ..astgen.ast_nodes import ( + AstNode, AstIdent, AstDeclNode, AstDefine, VarDeclType, VarDeclScope) from ..astgen.astgen import AstGen from ..common import BaseLocatedError, StrRegion From 4d499d8f1498399c6d3bc64b86deb2e67006e659 Mon Sep 17 00:00:00 2001 From: Marcell Perger Date: Sat, 14 Jun 2025 18:33:11 +0100 Subject: [PATCH 2/6] refactor(ast): Move FilteredWalker into own module --- parser/astgen/ast_node.py | 120 +---------------------------- parser/astgen/filtered_walker.py | 125 +++++++++++++++++++++++++++++++ parser/typecheck/typecheck.py | 6 +- 3 files changed, 129 insertions(+), 122 deletions(-) create mode 100644 parser/astgen/filtered_walker.py diff --git a/parser/astgen/ast_node.py b/parser/astgen/ast_node.py index 9d33785..4ab5ff4 100644 --- a/parser/astgen/ast_node.py +++ b/parser/astgen/ast_node.py @@ -2,27 +2,18 @@ from dataclasses import dataclass from enum import Enum -from typing import Callable, TypeAlias, Iterable, TypeVar +from typing import Callable, TypeAlias, Iterable -from util import flatten_force from ..common import HasRegion, StrRegion -__all__ = ['AstNode', 'walk_ast', 'WalkableT', 'WalkerFnT', 'WalkerCallType', - "FilteredWalker"] +__all__ = ['AstNode', 'walk_ast', 'WalkableT', 'WalkerFnT', 'WalkerCallType',] -VT = TypeVar('VT') 
-WT = TypeVar('WT', bound='WalkableT') WalkableL0: TypeAlias = 'AstNode | list[AstNode] | tuple[AstNode, ...] | None' WalkableT: TypeAlias = 'WalkableL0 | list[WalkableL0] | tuple[WalkableL0, ...]' WalkerFnT: TypeAlias = Callable[[WalkableT, 'WalkerCallType'], bool | None] """Returns True if skip""" -SpecificCbT = Callable[[WT], bool | None] -SpecificCbsDict = dict[type[WT] | type, list[Callable[[WT], bool | None]]] -BothCbT = Callable[[WT, 'WalkerCallType'], bool | None] -BothCbsDict = dict[type[WT] | type, list[Callable[[WT, 'WalkerCallType'], bool | None]]] - class WalkerCallType(Enum): PRE = 'pre' @@ -75,110 +66,3 @@ def walk_multiple_objects(cls, fn: WalkerFnT, objs: Iterable[WalkableT]): walk_ast = AstNode.walk_obj - - -class WalkerFilterRegistry: - def __init__(self, enter_cbs: SpecificCbsDict = (), - exit_cbs: SpecificCbsDict = (), - both_sbc: BothCbsDict = ()): - self.enter_cbs: SpecificCbsDict = dict(enter_cbs) # Copy them, - self.exit_cbs: SpecificCbsDict = dict(exit_cbs) # also converts default () -> {} - self.both_cbs: BothCbsDict = dict(both_sbc) - - def copy(self): - return WalkerFilterRegistry(self.enter_cbs, self.exit_cbs, self.both_cbs) - - def register_both(self, t: type[WT], fn: Callable[[WT, WalkerCallType], bool | None]): - self.both_cbs.setdefault(t, []).append(fn) - return self - - def register_enter(self, t: type[WT], fn: Callable[[WT], bool | None]): - self.enter_cbs.setdefault(t, []).append(fn) - return self - - def register_exit(self, t: type[WT], fn: Callable[[WT], bool | None]): - self.exit_cbs.setdefault(t, []).append(fn) - return self - - def __call__(self, *args, **kwargs): - return self - - def on_enter(self, *tps: type[WT] | type): - """Decorator version of register_enter.""" - def decor(fn: SpecificCbT): - for t in tps: - self.register_enter(t, fn) - return fn - return decor - - def on_exit(self, *tps: type[WT] | type): - """Decorator version of register_exit.""" - def decor(fn: SpecificCbT): - for t in tps: - 
self.register_exit(t, fn) - return fn - return decor - - def on_both(self, *tps: type[WT] | type): - """Decorator version of register_both.""" - def decor(fn: BothCbT): - for t in tps: - self.register_both(t, fn) - return fn - return decor - - -class FilteredWalker(WalkerFilterRegistry): - def __init__(self): - cls_reg = self.class_registry() - super().__init__(cls_reg.enter_cbs, cls_reg.exit_cbs, cls_reg.both_cbs) - - @classmethod - def class_registry(cls) -> WalkerFilterRegistry: - return WalkerFilterRegistry() - - @classmethod - def create_cls_registry(cls, fn=None): - """Create a class-level registry that can be added to using decorators. - - This can be used in two ways (at the top of your class):: - - # MUST be this name - class_registry = FilteredWalker.create_cls_registry() - - or:: - - @classmethod - @FilteredWalker.create_cls_registry - def class_registry(cls): # MUST be this name - pass - - and when registering methods:: - - @class_registry.on_enter(AstDefine) - def enter_define(self, ...): - ... 
- - The restrictions on name are because we have no other way of detecting - it (without metaclass dark magic) as we can't refer to the class while - its namespace is being evaluated - """ - if fn is not None and (parent := fn(cls)) is not None: - return WalkerFilterRegistry.copy(parent) - return WalkerFilterRegistry() - - def __call__(self, o: WalkableT, call_type: WalkerCallType): - result = None - # Call more specific ones first - specific_cbs = self.enter_cbs if call_type == WalkerCallType.PRE else self.exit_cbs - for fn in self._get_funcs(specific_cbs, type(o)): - result = fn(o) or result - for fn in self._get_funcs(self.both_cbs, type(o)): - result = fn(o, call_type) or result - return result - - @classmethod - def _get_funcs(cls, mapping: dict[type[WT] | type, list[VT]], tp: type[WT]) -> list[VT]: - """Also looks at superclasses/MRO""" - return flatten_force(mapping.get(sub, []) for sub in tp.mro()) -# endregion diff --git a/parser/astgen/filtered_walker.py b/parser/astgen/filtered_walker.py new file mode 100644 index 0000000..e76803f --- /dev/null +++ b/parser/astgen/filtered_walker.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from typing import Callable, TypeVar + +from util import flatten_force +from .ast_node import WalkerCallType, WalkableT, walk_ast + +__all__ = ['WalkerFilterRegistry', 'FilteredWalker', 'walk_ast'] + +VT = TypeVar('VT') +WT = TypeVar('WT', bound=WalkableT) + +SpecificCbT = Callable[[WT], bool | None] +SpecificCbsDict = dict[type[WT] | type, list[SpecificCbT[WT]]] +BothCbT = Callable[[WT, 'WalkerCallType'], bool | None] +BothCbsDict = dict[type[WT] | type, list[BothCbT[WT]]] + + +class WalkerFilterRegistry: + def __init__(self, enter_cbs: SpecificCbsDict = (), + exit_cbs: SpecificCbsDict = (), + both_sbc: BothCbsDict = ()): + self.enter_cbs: SpecificCbsDict = dict(enter_cbs) # Copy them, + self.exit_cbs: SpecificCbsDict = dict(exit_cbs) # also converts default () -> {} + self.both_cbs: BothCbsDict = dict(both_sbc) + + 
def copy(self): + return WalkerFilterRegistry(self.enter_cbs, self.exit_cbs, self.both_cbs) + + def register_both(self, t: type[WT], fn: Callable[[WT, WalkerCallType], bool | None]): + self.both_cbs.setdefault(t, []).append(fn) + return self + + def register_enter(self, t: type[WT], fn: Callable[[WT], bool | None]): + self.enter_cbs.setdefault(t, []).append(fn) + return self + + def register_exit(self, t: type[WT], fn: Callable[[WT], bool | None]): + self.exit_cbs.setdefault(t, []).append(fn) + return self + + def __call__(self, *args, **kwargs): + return self + + def on_enter(self, *tps: type[WT] | type): + """Decorator version of register_enter.""" + def decor(fn: SpecificCbT): + for t in tps: + self.register_enter(t, fn) + return fn + return decor + + def on_exit(self, *tps: type[WT] | type): + """Decorator version of register_exit.""" + def decor(fn: SpecificCbT): + for t in tps: + self.register_exit(t, fn) + return fn + return decor + + def on_both(self, *tps: type[WT] | type): + """Decorator version of register_both.""" + def decor(fn: BothCbT): + for t in tps: + self.register_both(t, fn) + return fn + return decor + + +class FilteredWalker(WalkerFilterRegistry): + def __init__(self): + cls_reg = self.class_registry() + super().__init__(cls_reg.enter_cbs, cls_reg.exit_cbs, cls_reg.both_cbs) + + @classmethod + def class_registry(cls) -> WalkerFilterRegistry: + return WalkerFilterRegistry() + + @classmethod + def create_cls_registry(cls, fn=None): + """Create a class-level registry that can be added to using decorators. + + This can be used in two ways (at the top of your class):: + + # MUST be this name + class_registry = FilteredWalker.create_cls_registry() + + or:: + + @classmethod + @FilteredWalker.create_cls_registry + def class_registry(cls): # MUST be this name + pass + + and when registering methods:: + + @class_registry.on_enter(AstDefine) + def enter_define(self, ...): + ... 
+ + The restrictions on name are because we have no other way of detecting + it (without metaclass dark magic) as we can't refer to the class while + its namespace is being evaluated + """ + if fn is not None and (parent := fn(cls)) is not None: + return WalkerFilterRegistry.copy(parent) + return WalkerFilterRegistry() + + def walk(self, o: WalkableT): + return walk_ast(o, self) + + def __call__(self, o: WalkableT, call_type: WalkerCallType): + result = None + # Call more specific ones first + specific_cbs = self.enter_cbs if call_type == WalkerCallType.PRE else self.exit_cbs + for fn in self._get_funcs(specific_cbs, type(o)): + result = fn(o) or result + for fn in self._get_funcs(self.both_cbs, type(o)): + result = fn(o, call_type) or result + return result + + @classmethod + def _get_funcs(cls, mapping: dict[type[WT] | type, list[VT]], tp: type[WT]) -> list[VT]: + """Also looks at superclasses/MRO""" + return flatten_force(mapping.get(sub, []) for sub in tp.mro()) diff --git a/parser/typecheck/typecheck.py b/parser/typecheck/typecheck.py index 47323e0..5181013 100644 --- a/parser/typecheck/typecheck.py +++ b/parser/typecheck/typecheck.py @@ -3,10 +3,10 @@ from dataclasses import dataclass, field from util.recursive_eq import recursive_eq -from ..astgen.ast_node import walk_ast, FilteredWalker from ..astgen.ast_nodes import ( AstNode, AstIdent, AstDeclNode, AstDefine, VarDeclType, VarDeclScope) from ..astgen.astgen import AstGen +from ..astgen.filtered_walker import FilteredWalker from ..common import BaseLocatedError, StrRegion @@ -182,7 +182,7 @@ def enter_fn_decl(fn: AstDefine): .register_enter(AstIdent, enter_ident) .register_enter(AstDeclNode, enter_decl) .register_enter(AstDefine, enter_fn_decl)) - walk_ast(block, walker) + walker.walk(block) # Walk sub-functions for fn_info, fn_decl in inner_funcs: fn_info.subscope = self.run_on_new_scope( @@ -216,5 +216,3 @@ def _typecheck(self): self.ast.walk(walker) ... 
- - From 1af40213f936715ef2286588f7c8f0b9f2cdca05 Mon Sep 17 00:00:00 2001 From: Marcell Perger Date: Sat, 14 Jun 2025 18:37:07 +0100 Subject: [PATCH 3/6] feat(ast-walker): Don't call more general walkers if skipped by a specific one --- parser/astgen/filtered_walker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/parser/astgen/filtered_walker.py b/parser/astgen/filtered_walker.py index e76803f..59075cc 100644 --- a/parser/astgen/filtered_walker.py +++ b/parser/astgen/filtered_walker.py @@ -114,9 +114,11 @@ def __call__(self, o: WalkableT, call_type: WalkerCallType): # Call more specific ones first specific_cbs = self.enter_cbs if call_type == WalkerCallType.PRE else self.exit_cbs for fn in self._get_funcs(specific_cbs, type(o)): - result = fn(o) or result + if result := result or fn(o): + return result # Don't call later ones if already skipped for fn in self._get_funcs(self.both_cbs, type(o)): - result = fn(o, call_type) or result + if result := result or fn(o, call_type): + return result return result @classmethod From 664fe3ebdc3482cbc3ace7b8d92640bd3a85e7f5 Mon Sep 17 00:00:00 2001 From: Marcell Perger Date: Sat, 14 Jun 2025 18:51:34 +0100 Subject: [PATCH 4/6] perf(filtered-walker): Replace generator with listcomp, use __mro__ instead of mro() (Improvement of ~15% in name resolution) --- parser/astgen/filtered_walker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/parser/astgen/filtered_walker.py b/parser/astgen/filtered_walker.py index 59075cc..ce6c2ef 100644 --- a/parser/astgen/filtered_walker.py +++ b/parser/astgen/filtered_walker.py @@ -124,4 +124,8 @@ def __call__(self, o: WalkableT, call_type: WalkerCallType): @classmethod def _get_funcs(cls, mapping: dict[type[WT] | type, list[VT]], tp: type[WT]) -> list[VT]: """Also looks at superclasses/MRO""" - return flatten_force(mapping.get(sub, []) for sub in tp.mro()) + return flatten_force([mapping.get(sub, []) for sub in _get_mro(tp)]) + + +def _get_mro(tp: 
type) -> tuple[type, ...]: # tp.__mro__ but with proper types + return tp.__mro__ # .mro() recalculates it every time, hence is slow From fff6268bc8d392a4f4fd0607edbea0a09133ddda Mon Sep 17 00:00:00 2001 From: Marcell Perger Date: Sat, 14 Jun 2025 23:46:33 +0100 Subject: [PATCH 5/6] feat(typecheck): Start typechecking (add statements this commit) --- parser/astgen/astgen.py | 1 + parser/typecheck/typecheck.py | 134 +++++++++++++++++++++++++++++++--- 2 files changed, 125 insertions(+), 10 deletions(-) diff --git a/parser/astgen/astgen.py b/parser/astgen/astgen.py index c8be8c0..be4097b 100644 --- a/parser/astgen/astgen.py +++ b/parser/astgen/astgen.py @@ -112,6 +112,7 @@ def _walk_smt(self, smt: AnyNode) -> list[AstNode]: elif isinstance(smt, ConditionalBlock): return self._walk_conditional(smt) elif isinstance(smt, AssignNode): # Simple assignment + # TODO: maybe separate SetAttr, SetItem, SetVar nodes? return [AstAssign(smt.region, self._walk_assign_left(smt.target), self._walk_expr(smt.source))] elif isinstance(smt, AssignOpNode): # Other (aug.) 
assignment diff --git a/parser/typecheck/typecheck.py b/parser/typecheck/typecheck.py index 5181013..655cd94 100644 --- a/parser/typecheck/typecheck.py +++ b/parser/typecheck/typecheck.py @@ -1,13 +1,14 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass, field +from typing import TypeAlias from util.recursive_eq import recursive_eq -from ..astgen.ast_nodes import ( - AstNode, AstIdent, AstDeclNode, AstDefine, VarDeclType, VarDeclScope) +from ..astgen.ast_nodes import * from ..astgen.astgen import AstGen from ..astgen.filtered_walker import FilteredWalker -from ..common import BaseLocatedError, StrRegion +from ..common import BaseLocatedError, StrRegion, region_union, RegionUnionArgT @dataclass @@ -193,7 +194,18 @@ def err(self, msg: str, region: StrRegion): return NameResolutionError(msg, region, self.src) +class TypecheckError(BaseLocatedError): + """Errors raised by the typechecker""" + + +NodeTypecheckFn: TypeAlias = 'Callable[[Typechecker, AstNode], TypeInfo | None]' + +_typecheck_dispatch: dict[type[AstNode], NodeTypecheckFn] = {} + + class Typechecker: + _curr_scope: Scope + def __init__(self, name_resolver: NameResolver): self.resolver = name_resolver self.src = self.resolver.src @@ -203,16 +215,118 @@ def _init(self): self.resolver.run() self.ast = self.resolver.ast self.top_scope = self.resolver.top_scope + self._curr_scope = self.top_scope def run(self): if self.is_ok is None: return self.is_ok - self._typecheck() - self.is_ok = True + self._init() + self._typecheck(self.ast) + self.is_ok = True # didn't raise any errors return self.is_ok - def _typecheck(self): - walker = FilteredWalker() - - self.ast.walk(walker) - ... 
+    def _node_typechecker(self, tp=None):
+        """Return a decorator that registers a typecheck function for a node type.
+
+        In this class body it is applied as ``@_node_typechecker(NodeType)``, i.e.
+        called as a plain function, so the node type arrives in the ``self`` slot
+        and ``tp`` stays ``None``. An explicitly passed ``tp`` is used instead.
+        The decorated function is stored in ``_typecheck_dispatch`` under the
+        node type and is returned unchanged.
+        """
+        if tp is None:
+            assert callable(self)
+            tp = self  # Called as decor in this class
+
+        def decor(fn: NodeTypecheckFn):
+            _typecheck_dispatch[tp] = fn
+            return fn
+        return decor
+
+    def _typecheck(self, n: AstNode):
+        """Typecheck ``n`` using the function registered for its exact type.
+
+        The lookup key is ``type(n)`` only (base classes are not consulted);
+        node types without an entry are passed to ``_typecheck_node_fallback``.
+        Returns whatever the selected function returns.
+        """
+        try:
+            fn = _typecheck_dispatch[type(n)]
+        except KeyError:
+            fn = type(self)._typecheck_node_fallback
+        return fn(self, n)
+
+    def _typecheck_node_fallback(self, n: AstNode):
+        """Reject node types that have no registered typecheck function."""
+        raise NotImplementedError(f"No typechecker function for node "
+                                  f"type {type(n).__name__}")
+
+    @_node_typechecker(AstProgramNode)
+    def _typecheck_program(self, n: AstProgramNode):
+        """Typecheck the program's statements as one block."""
+        self._typecheck_block(n.statements)
+
+    def _typecheck_block(self, block: list[AstNode]):
+        """Typecheck each statement of ``block`` in order.
+
+        A statement whose typecheck function returns ``None`` is accepted as
+        is; any other result must equal ``VoidType()``.
+        """
+        for smt in block:
+            if (tp := self._typecheck(smt)) is not None:
+                self.expect_type(tp, VoidType(), smt)
+
+    @_node_typechecker(AstDeclNode)
+    def _typecheck_decl(self, n: AstDeclNode):
+        """Check a declaration's initial value against the declared type."""
+        if not n.value:  # Nothing to check
+            return
+        expect = self._resolve_scope(n.scope).declared[n.ident.id].tp_info
+        self.expect_type(self._typecheck(n.value), expect, n)
+
+    @_node_typechecker(AstRepeat)
+    def _typecheck_repeat(self, n: AstRepeat):
+        """Check that the repeat count is a value, then typecheck the body."""
+        # For now, we don't differentiate between number/string (as sc doesn't)
+        self.expect_type(self._typecheck(n.count), ValType(), n.count)
+        self._typecheck_block(n.body)
+
+    @_node_typechecker(AstIf)
+    def _typecheck_if(self, n: AstIf):
+        """Check that the condition is a bool, then typecheck both branches."""
+        self.expect_type(self._typecheck(n.cond), BoolType(), n.cond)
+        self._typecheck_block(n.if_body)
+        if n.else_body is not None:
+            self._typecheck_block(n.else_body)
+
+    @_node_typechecker(AstWhile)
+    def _typecheck_while(self, n: AstWhile):
+        """Check that the condition is a bool, then typecheck the body."""
+        self.expect_type(self._typecheck(n.cond), BoolType(), n.cond)
+        self._typecheck_block(n.body)
+
+    @_node_typechecker(AstAssign)
+    def _typecheck_assign(self, n: AstAssign):  # super tempted to call this _typecheck_ass
+        """Check that the assigned value matches the type of the target.
+
+        Targets may be a plain identifier or an item (``ls[i] = v``).
+        Attribute targets and whole-list targets are reported as errors.
+        """
+        if isinstance(n.target, AstIdent):
+            target_tp = self._curr_scope.used[n.target.id].tp_info
+        elif isinstance(n.target, AstItem):  # ls[i] = v
+            # Typechecks the container and the index as well.
+            # NOTE(review): _typecheck_item also accepts ValType containers,
+            # so item assignment on a non-list value is not rejected here -
+            # confirm whether that is intended.
+            target_tp = self._typecheck(n.target)
+        elif isinstance(n.target, AstAttribute):
+            raise self.err("Setting attributes is currently unsupported", n.target)
+        else:
+            # Explicit raise rather than `assert 0`: asserts are removed under
+            # `python -O`, which would leave target_tp unbound below.
+            raise AssertionError("Unknown simple-assignment type")
+        if target_tp == ListType():
+            raise self.err("Cannot assign directly to list", n)
+        self.expect_type(self._typecheck(n.source), target_tp, n)
+
+    @_node_typechecker(AstAugAssign)
+    def _typecheck_aug_assign(self, n: AstAugAssign):
+        """Typecheck an augmented assignment (currently only ``var += value``)."""
+        # TODO: change this when desugaring is implemented
+        #  (for now only +=, only on variables)
+        if n.op != '+=':
+            raise self.err(f"The '{n.op}' operator is not implemented", n)
+        if not isinstance(n.target, AstIdent):
+            raise self.err("The '+=' operator is only implemented for variables", n)
+        target_tp = self._curr_scope.used[n.target.id].tp_info
+        if target_tp != ValType():
+            raise self.err(f"Cannot apply += to {target_tp}", n)
+        self.expect_type(self._typecheck(n.source), ValType(), n.source)
+
+    @_node_typechecker(AstDefine)
+    def _typecheck_define(self, n: AstDefine):
+        """Typecheck a function body inside that function's own scope.
+
+        ``_curr_scope`` is switched to the function's subscope for the
+        duration of the body and restored afterwards, even if the body fails.
+        """
+        # Don't really need to check much here - type is generated from the
+        # syntax so must be correct. Set _curr_scope and check body
+        func_info = self._curr_scope.declared[n.ident.id]
+        assert isinstance(func_info, FuncInfo)
+        old_scope = self._curr_scope
+        self._curr_scope = func_info.subscope
+        try:
+            self._typecheck_block(n.body)
+        finally:
+            self._curr_scope = old_scope
+
+    def _resolve_scope(self, scope_tp: VarDeclScope):
+        return self.top_scope if scope_tp == VarDeclScope.GLOBAL else self._curr_scope
+
+    def err(self, msg: str, loc: RegionUnionArgT):
+        """Build (but do not raise) a ``TypecheckError`` located at ``loc``."""
+        return TypecheckError(msg, region_union(loc), self.src)
+
+    def expect_type(self, actual: TypeInfo, exp: TypeInfo, loc: RegionUnionArgT):
+        """Raise a ``TypecheckError`` at ``loc`` unless ``actual`` equals ``exp``."""
+        if exp != actual:
+            # TODO: maybe better type formatting
+            raise self.err(f"Expected type {exp}, got type {actual}", loc)

From 5e37ec0ba5d2bfdc2af1e4607ba5886215c195f1 Mon Sep 17 00:00:00 2001
From: Marcell Perger
Date: Sun, 15 Jun 2025 21:58:09 +0100
Subject: [PATCH 6/6] feat(typecheck): Finish typechecker (add expressions)

---
 parser/typecheck/typecheck.py | 78 +++++++++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)

diff --git a/parser/typecheck/typecheck.py b/parser/typecheck/typecheck.py
index 655cd94..beccb5a 100644
--- a/parser/typecheck/typecheck.py
+++ b/parser/typecheck/typecheck.py
@@ -203,6 +203,7 @@ class TypecheckError(BaseLocatedError):
 _typecheck_dispatch: dict[type[AstNode], NodeTypecheckFn] = {}
 
 
+# TODO: output typed AST
 class Typechecker:
     _curr_scope: Scope
 
@@ -320,6 +321,83 @@ def _typecheck_define(self, n: AstDefine):
         finally:
             self._curr_scope = old_scope
 
+    @_node_typechecker(AstNumber)
+    def _typecheck_number(self, _n: AstNumber):
+        """A number literal is always a value."""
+        return ValType()
+
+    @_node_typechecker(AstString)
+    def _typecheck_string(self, _n: AstString):
+        """A string literal is always a value."""
+        return ValType()
+
+    @_node_typechecker(AstListLiteral)
+    def _typecheck_list(self, n: AstListLiteral):
+        """Check that every item of a list literal is a value; return ``ListType()``."""
+        for item in n.items:
+            if self._typecheck(item) != ValType():
+                raise self.err("Can only have ValType()s in list", item)
+        return ListType()
+
+    @_node_typechecker(AstIdent)
+    def _typecheck_ident(self, n: AstIdent):
+        """Return the type recorded for this name in the current scope."""
+        return self._curr_scope.used[n.id].tp_info
+
+    @_node_typechecker(AstAttrName)
+    def _typecheck_attr_name(self, _n: AstAttrName):
+        """Always fails: an attribute name is never typechecked on its own."""
+        # Explicit raise rather than `assert 0`: asserts are removed under
+        # `python -O`, which would turn this into a silent `return None`.
+        raise AssertionError("AstAttrName has no type, cannot be checked on its own")
+
+    @_node_typechecker(AstAttribute)
+    def _typecheck_attribute(self, n: AstAttribute):
+        """Attribute access is not supported yet; always reports an error."""
+        # TODO: implement this properly, with better types and stuff
+        raise self.err("Attributes are not implemented yet", n)
+
+    @_node_typechecker(AstItem)
+    def _typecheck_item(self, n: AstItem):
+        """Typecheck ``obj[index]`` and return the type of the selected item.
+
+        The container must be a list or a value and the index must be a value.
+        The result is ``ValType()``: list literals may only hold values (see
+        ``_typecheck_list``), and an item of a value is treated as a value too.
+        """
+        # TODO: this will require different intrinsics for string vs list getitem
+        container_tp = self._typecheck(n.obj)
+        if container_tp not in (ListType(), ValType()):
+            raise self.err(f"Cannot get item of {container_tp}", n)
+        self.expect_type(self._typecheck(n.index), ValType(), n.index)
+        # Without this the expression's type would be None, which breaks every
+        # consumer (e.g. `ls[i] = v` would compare the source against None).
+        return ValType()
+
+    @_node_typechecker(AstCall)
+    def _typecheck_call(self, n: AstCall):
+        """Check the callee, argument count and argument types of a call.
+
+        Returns the callee's declared return type.
+        """
+        called_tp = self._typecheck(n.obj)
+        if not isinstance(called_tp, FunctionType):
+            raise self.err(f"Cannot call {called_tp}", n.obj)
+        if len(called_tp.arg_types) != len(n.args):
+            if len(n.args) > len(called_tp.arg_types):
+                region = n.args[-1].region  # Highlight extraneous arg
+            else:
+                region = n.region
+            raise self.err(f"Incorrect number of arguments, expected "
+                           f"{len(called_tp.arg_types)}, got {len(n.args)}",
+                           region)
+        for decl_t, arg_node in zip(called_tp.arg_types, n.args):
+            self.expect_type(self._typecheck(arg_node), decl_t, arg_node)
+        return called_tp.ret_type
+
+    # Type that both operands of each binary operator are checked against
+    _BINARY_OP_TYPES = dict.fromkeys([
+        *'+-*/%', '**', '..', '==', '!=', '<', '>', '<=', '>='
+    ], ValType()) | dict.fromkeys([
+        '&&', '||'
+    ], BoolType())
+
+    # Type that the operand of each unary operator is checked against
+    _UNARY_OP_TYPES = dict.fromkeys([
+        *'+-'
+    ], ValType()) | dict.fromkeys([
+        '!'
+    ], BoolType())
+
+    # TODO: allow casting bool to val? - auto-cast or explicit?
+    @_node_typechecker(AstBinOp)
+    def _typecheck_bin_op(self, n: AstBinOp):
+        """Typecheck a binary operation and return the type of its result.
+
+        Both operands are checked against the operand type listed for the
+        operator in ``_BINARY_OP_TYPES`` (an operator missing from that table
+        raises ``KeyError``).
+        """
+        operand_tp = self._BINARY_OP_TYPES[n.op]
+        self.expect_type(self._typecheck(n.left), operand_tp, n.left)
+        self.expect_type(self._typecheck(n.right), operand_tp, n.right)
+        # The result type must be returned: callers such as the `if`/`while`
+        # checks compare it against BoolType(), and a None result would also
+        # let a bare `a + b` statement slip through the block check.
+        # Comparisons take two values but produce a bool; every other
+        # operator produces the same type as its operands.
+        if n.op in ('==', '!=', '<', '>', '<=', '>='):
+            return BoolType()
+        return operand_tp
+
+    @_node_typechecker(AstUnaryOp)
+    def _typecheck_unary_op(self, n: AstUnaryOp):
+        """Typecheck a unary operation and return the type of its result.
+
+        The operand is checked against the type listed for the operator in
+        ``_UNARY_OP_TYPES`` (an operator missing from that table raises
+        ``KeyError``); the result has that same type.
+        """
+        operand_tp = self._UNARY_OP_TYPES[n.op]
+        self.expect_type(self._typecheck(n.operand), operand_tp, n.operand)
+        return operand_tp
+
     def _resolve_scope(self, scope_tp: VarDeclScope):
         return self.top_scope if scope_tp == VarDeclScope.GLOBAL else self._curr_scope