4141@dataclasses .dataclass
4242class _Block :
4343 label : str | None = None
44- header : list [str ] = dataclasses .field (default_factory = list )
44+ # Non-instruction lines like labels, directives, and comments:
45+ noise : list [str ] = dataclasses .field (default_factory = list )
46+ # Instruction lines:
4547 instructions : list [str ] = dataclasses .field (default_factory = list )
48+ # If this block ends in a jump, where to?
4649 target : typing .Self | None = None
50+ # The next block in the linked list:
4751 link : typing .Self | None = None
52+ # Whether control flow can fall through to the linked block above:
4853 fallthrough : bool = True
54+ # Whether this block can eventually reach the next uop (_JIT_CONTINUE):
4955 hot : bool = False
5056
5157 def resolve (self ) -> typing .Self :
@@ -68,7 +74,7 @@ class Optimizer:
6874 _root : _Block = dataclasses .field (init = False , default_factory = _Block )
6975 _labels : dict [str , _Block ] = dataclasses .field (init = False , default_factory = dict )
7076 # No groups:
71- _re_header : typing .ClassVar [re .Pattern [str ]] = re .compile (r"\s*(?:\.|#|//|$)" )
77+ _re_noise : typing .ClassVar [re .Pattern [str ]] = re .compile (r"\s*(?:\.|#|//|$)" )
7278 # One group (label):
7379 _re_label : typing .ClassVar [re .Pattern [str ]] = re .compile (
7480 r'\s*(?P<label>[\w."$?@]+):'
@@ -84,32 +90,44 @@ class Optimizer:
8490 _re_return : typing .ClassVar [re .Pattern [str ]] = _RE_NEVER_MATCH
8591
8692 def __post_init__ (self ) -> None :
93+ # Split the code into a linked list of basic blocks. A basic block is an
94+ # optional label, followed by zero or more non-instruction ("noise")
95+ # lines, followed by one or more instruction lines (only the last of
96+ # which may be a branch, jump, or return).
8797 text = self ._preprocess (self .path .read_text ())
8898 block = self ._root
8999 for line in text .splitlines ():
100+ # See if we need to start a new block:
90101 if match := self ._re_label .match (line ):
102+ # Label. New block:
91103 block .link = block = self ._lookup_label (match ["label" ])
92- block .header .append (line )
104+ block .noise .append (line )
93105 continue
94- if self ._re_header .match (line ):
106+ if self ._re_noise .match (line ):
95107 if block .instructions :
108+ # Noise lines. New block:
96109 block .link = block = _Block ()
97- block .header .append (line )
110+ block .noise .append (line )
98111 continue
99112 if block .target or not block .fallthrough :
113+ # Current block ends with a branch, jump, or return. New block:
100114 block .link = block = _Block ()
101115 block .instructions .append (line )
102116 if match := self ._re_branch .match (line ):
117+ # A block ending in a branch has a target, and fallthrough:
103118 block .target = self ._lookup_label (match ["target" ])
104119 assert block .fallthrough
105120 elif match := self ._re_jump .match (line ):
121+ # A block ending in a jump has a target, and no fallthrough:
106122 block .target = self ._lookup_label (match ["target" ])
107123 block .fallthrough = False
108124 elif self ._re_return .match (line ):
125+ # A block ending in a return has no target, and fallthrough:
109126 assert not block .target
110127 block .fallthrough = False
111128
112129 def _preprocess (self , text : str ) -> str :
130+ # Override this method to do preprocessing of the textual assembly:
113131 return text
114132
115133 @classmethod
@@ -120,13 +138,21 @@ def _invert_branch(cls, line: str, target: str) -> str | None:
120138 if not inverted :
121139 return None
122140 (a , b ), (c , d ) = match .span ("instruction" ), match .span ("target" )
141+ # Before:
142+ # je FOO
143+ # After:
144+ # jne BAR
123145 return "" .join ([line [:a ], inverted , line [b :c ], target , line [d :]])
124146
125147 @classmethod
126148 def _update_jump (cls , line : str , target : str ) -> str :
127149 match = cls ._re_jump .match (line )
128150 assert match
129151 a , b = match .span ("target" )
152+ # Before:
153+ # jmp FOO
154+ # After:
155+ # jmp BAR
130156 return "" .join ([line [:a ], target , line [b :]])
131157
132158 def _lookup_label (self , label : str ) -> _Block :
@@ -146,30 +172,41 @@ def _body(self) -> str:
146172 for block in self ._blocks ():
147173 if hot != block .hot :
148174 hot = block .hot
175+ # Make it easy to tell at a glance where cold code is:
149176 lines .append (f"# JIT: { 'HOT' if hot else 'COLD' } " .ljust (80 , "#" ))
150- lines .extend (block .header )
177+ lines .extend (block .noise )
151178 lines .extend (block .instructions )
152179 return "\n " .join (lines )
153180
154181 def _predecessors (self , block : _Block ) -> typing .Generator [_Block , None , None ]:
182+ # This is inefficient, but it's never wrong:
155183 for predecessor in self ._blocks ():
156184 if predecessor .target is block or (
157185 predecessor .fallthrough and predecessor .link is block
158186 ):
159187 yield predecessor
160188
161189 def _insert_continue_label (self ) -> None :
190+ # Find the block with the last instruction:
162191 for end in reversed (list (self ._blocks ())):
163192 if end .instructions :
164193 break
194+ # Before:
195+ # jmp FOO
196+ # After:
197+ # jmp FOO
198+ # .balign 8
199+ # _JIT_CONTINUE:
165200 align = _Block ()
166- align .header .append (f"\t .balign\t { self ._alignment } " )
201+ align .noise .append (f"\t .balign\t { self ._alignment } " )
167202 continuation = self ._lookup_label (f"{ self .prefix } _JIT_CONTINUE" )
168203 assert continuation .label
169- continuation .header .append (f"{ continuation .label } :" )
204+ continuation .noise .append (f"{ continuation .label } :" )
170205 end .link , align .link , continuation .link = align , continuation , end .link
171206
172207 def _mark_hot_blocks (self ) -> None :
208+ # Start with the last block, and perform a DFS to find all blocks that
209+ # can eventually reach it:
173210 todo = list (self ._blocks ())[- 1 :]
174211 while todo :
175212 block = todo .pop ()
@@ -181,17 +218,17 @@ def _mark_hot_blocks(self) -> None:
181218 )
182219
183220 def _invert_hot_branches (self ) -> None :
184- # Before:
185- # branch <hot>
186- # jump <cold>
187- # After:
188- # opposite-branch <cold>
189- # jump <hot>
190221 for branch in self ._blocks ():
191222 link = branch .link
192223 if link is None :
193224 continue
194225 jump = link .resolve ()
226+ # Before:
227+ # je HOT
228+ # jmp COLD
229+ # After:
230+ # jne COLD
231+ # jmp HOT
195232 if (
196233 # block ends with a branch to hot code...
197234 branch .target
@@ -209,6 +246,7 @@ def _invert_hot_branches(self) -> None:
209246 inverted = self ._invert_branch (
210247 branch .instructions [- 1 ], jump .target .label
211248 )
249+ # Check to see if the branch can even be inverted:
212250 if inverted is None :
213251 continue
214252 branch .instructions [- 1 ] = inverted
@@ -219,7 +257,14 @@ def _invert_hot_branches(self) -> None:
219257 jump .hot = True
220258
221259 def _remove_redundant_jumps (self ) -> None :
260+ # Zero-length jumps can be introduced by _insert_continue_label and
261+ # _invert_hot_branches:
222262 for block in self ._blocks ():
263+ # Before:
264+ # jmp FOO
265+ # FOO:
266+ # After:
267+ # FOO:
223268 if (
224269 block .target
225270 and block .link
0 commit comments