@@ -188,6 +188,7 @@ class ComponentInstance:
188188 parent : Optional [ComponentInstance ]
189189 handles : Table [ResourceHandle | Waitable | WaitableSet | ErrorContext ]
190190 threads : Table [Thread ]
191+ may_enter : bool
191192 may_leave : bool
192193 backpressure : int
193194 exclusive : Optional [Task ]
@@ -199,6 +200,7 @@ def __init__(self, store, parent = None):
199200 self .parent = parent
200201 self .handles = Table ()
201202 self .threads = Table ()
203+ self .may_enter = True
202204 self .may_leave = True
203205 self .backpressure = 0
204206 self .exclusive = None
@@ -212,27 +214,19 @@ def reflexive_ancestors(self) -> set[ComponentInstance]:
212214 inst = inst .parent
213215 return s
214216
215- def is_reflexive_ancestor_of (self , other ):
216- while other is not None :
217- if self is other :
218- return True
219- other = other .parent
220- return False
221-
222- class Supertask :
223- inst : Optional [ComponentInstance ]
224- supertask : Optional [Supertask ]
225-
226- def call_might_be_recursive (caller : Supertask , callee_inst : ComponentInstance ):
227- if caller .inst is None :
228- while caller is not None :
229- if caller .inst and caller .inst .reflexive_ancestors () & callee_inst .reflexive_ancestors ():
230- return True
231- caller = caller .supertask
232- return False
233- else :
234- return (caller .inst .is_reflexive_ancestor_of (callee_inst ) or
235- callee_inst .is_reflexive_ancestor_of (caller .inst ))
217+ def flip_may_enter_to (self , new_value : bool , caller : Optional [ComponentInstance ]):
218+ inst = self
219+ if caller is None :
220+ while inst is not None :
221+ assert (inst .may_enter != new_value )
222+ inst .may_enter = new_value
223+ inst = inst .parent
224+ else :
225+ already_entered = caller .reflexive_ancestors ()
226+ while inst is not None and inst not in already_entered :
227+ trap_if (inst .may_enter == new_value )
228+ inst .may_enter = new_value
229+ inst = inst .parent
236230
237231## Concurrency Primitives
238232
@@ -411,9 +405,9 @@ def yield_to(self, cancellable, other: Thread) -> Cancelled:
411405OnStart = Callable [[], list [any ]]
412406OnResolve = Callable [[Optional [list [any ]]], None ]
413407OnCancel = Callable [[], None ]
414- FuncInst = Callable [[Supertask , OnStart , OnResolve ], OnCancel ]
408+ FuncInst = Callable [[OnStart , OnResolve , Optional [ ComponentInstance ] ], OnCancel ]
415409
416- class Task ( Supertask ) :
410+ class Task :
417411 class State (Enum ):
418412 INITIAL = 1
419413 STARTED = 2
@@ -425,19 +419,17 @@ class State(Enum):
425419 opts : CanonicalOptions
426420 inst : ComponentInstance
427421 ft : FuncType
428- supertask : Supertask
429422 on_start : OnStart
430423 on_resolve : OnResolve
431424 num_borrows : int
432425 waiting_to_enter : Optional [Thread ]
433426 threads : list [Thread ]
434427
435- def __init__ (self , opts , inst , ft , supertask , on_start , on_resolve ):
428+ def __init__ (self , opts , inst , ft , on_start , on_resolve ):
436429 self .state = Task .State .INITIAL
437430 self .opts = opts
438431 self .inst = inst
439432 self .ft = ft
440- self .supertask = supertask
441433 self .on_start = on_start
442434 self .on_resolve = on_resolve
443435 self .num_borrows = 0
@@ -536,29 +528,34 @@ def cancel(self):
536528
537529class Store :
538530 waiting : list [Thread ]
531+ nesting_depth : int
539532
540533 def __init__ (self ):
541534 self .waiting = []
535+ self .nesting_depth = 0
542536
543537 def lift (self , callee , opts : CanonicalOptions , ft : FuncType , inst : ComponentInstance ) -> FuncInst :
544- def func_inst (caller : Supertask , on_start , on_resolve ) -> OnCancel :
545- trap_if (call_might_be_recursive (caller , inst ))
546- task = Task (opts , inst , ft , caller , on_start , on_resolve )
538+ def func_inst (on_start , on_resolve , caller ) -> OnCancel :
539+ self .nesting_depth += 1
540+ inst .flip_may_enter_to (False , caller )
541+ task = Task (opts , inst , ft , on_start , on_resolve )
547542 Thread (task , lambda : canon_lift (callee )).resume ()
543+ inst .flip_may_enter_to (True , caller )
544+ self .nesting_depth -= 1
548545 return task .request_cancellation
549546 return func_inst
550547
551- def invoke (self , f : FuncInst , caller : Optional [Supertask ], on_start , on_resolve ) -> OnCancel :
552- host_caller = Supertask ()
553- host_caller .inst = None
554- host_caller .supertask = caller
555- return f (host_caller , on_start , on_resolve )
548+ def invoke (self , f : FuncInst , on_start , on_resolve ) -> OnCancel :
549+ return f (on_start , on_resolve , caller = None )
556550
557551 def tick (self ):
552+ assert (self .nesting_depth == 0 )
558553 random .shuffle (self .waiting )
559554 for thread in self .waiting :
560555 if thread .ready ():
556+ thread .task .inst .flip_may_enter_to (False , caller = None )
561557 thread .resume ()
558+ thread .task .inst .flip_may_enter_to (True , caller = None )
562559 return
563560
564561## Lifting and Lowering Context
@@ -2186,7 +2183,7 @@ def on_resolve(result):
21862183 nonlocal flat_results
21872184 flat_results = lower_flat_values (cx , max_flat_results , result , ft .result_type (), flat_args )
21882185
2189- subtask .on_cancel = callee (task , on_start , on_resolve )
2186+ subtask .on_cancel = callee (on_start , on_resolve , caller = inst )
21902187 assert (ft .async_ or subtask .state == Subtask .State .RETURNED )
21912188
21922189 if not opts .async_ :
0 commit comments