@@ -425,7 +425,7 @@ int pthread_attr_destroy(pthread_attr_t *a)
425425#endif
426426
427427static void
428- hardware_stack_limits (uintptr_t * top , uintptr_t * base )
428+ hardware_stack_limits (uintptr_t * base , uintptr_t * top )
429429{
430430#ifdef WIN32
431431 ULONG_PTR low , high ;
@@ -468,23 +468,86 @@ hardware_stack_limits(uintptr_t *top, uintptr_t *base)
468468#endif
469469}
470470
471- void
472- _Py_InitializeRecursionLimits (PyThreadState * tstate )
471+ static void
472+ tstate_set_stack (PyThreadState * tstate ,
473+ uintptr_t base , uintptr_t top )
473474{
474- uintptr_t top ;
475- uintptr_t base ;
476- hardware_stack_limits ( & top , & base );
475+ assert ( base < top ) ;
476+ assert (( top - base ) >= _PyOS_MIN_STACK_SIZE ) ;
477+
477478#ifdef _Py_THREAD_SANITIZER
478479 // Thread sanitizer crashes if we use more than half the stack.
479480 uintptr_t stacksize = top - base ;
480- base += stacksize / 2 ;
481+ base += stacksize / 2 ;
481482#endif
482483 _PyThreadStateImpl * _tstate = (_PyThreadStateImpl * )tstate ;
483484 _tstate -> c_stack_top = top ;
484485 _tstate -> c_stack_hard_limit = base + _PyOS_STACK_MARGIN_BYTES ;
485486 _tstate -> c_stack_soft_limit = base + _PyOS_STACK_MARGIN_BYTES * 2 ;
487+
488+ #ifndef NDEBUG
489+ // Sanity checks
490+ _PyThreadStateImpl * ts = (_PyThreadStateImpl * )tstate ;
491+ assert (ts -> c_stack_hard_limit <= ts -> c_stack_soft_limit );
492+ assert (ts -> c_stack_soft_limit < ts -> c_stack_top );
493+ #endif
494+ }
495+
496+
497+ void
498+ _Py_InitializeRecursionLimits (PyThreadState * tstate )
499+ {
500+ uintptr_t base , top ;
501+ hardware_stack_limits (& base , & top );
502+ assert (top != 0 );
503+
504+ tstate_set_stack (tstate , base , top );
505+ _PyThreadStateImpl * ts = (_PyThreadStateImpl * )tstate ;
506+ ts -> c_stack_init_base = base ;
507+ ts -> c_stack_init_top = top ;
508+
509+ // Test the stack pointer
510+ #if !defined(NDEBUG ) && !defined(__wasi__ )
511+ uintptr_t here_addr = _Py_get_machine_stack_pointer ();
512+ assert (ts -> c_stack_soft_limit < here_addr );
513+ assert (here_addr < ts -> c_stack_top );
514+ #endif
515+ }
516+
517+
518+ int
519+ PyUnstable_ThreadState_SetStackProtection (PyThreadState * tstate ,
520+ void * stack_start_addr , size_t stack_size )
521+ {
522+ if (stack_size < _PyOS_MIN_STACK_SIZE ) {
523+ PyErr_Format (PyExc_ValueError ,
524+ "stack_size must be at least %zu bytes" ,
525+ _PyOS_MIN_STACK_SIZE );
526+ return -1 ;
527+ }
528+
529+ uintptr_t base = (uintptr_t )stack_start_addr ;
530+ uintptr_t top = base + stack_size ;
531+ tstate_set_stack (tstate , base , top );
532+ return 0 ;
486533}
487534
535+
536+ void
537+ PyUnstable_ThreadState_ResetStackProtection (PyThreadState * tstate )
538+ {
539+ _PyThreadStateImpl * ts = (_PyThreadStateImpl * )tstate ;
540+ if (ts -> c_stack_init_top != 0 ) {
541+ tstate_set_stack (tstate ,
542+ ts -> c_stack_init_base ,
543+ ts -> c_stack_init_top );
544+ return ;
545+ }
546+
547+ _Py_InitializeRecursionLimits (tstate );
548+ }
549+
550+
488551/* The function _Py_EnterRecursiveCallTstate() only calls _Py_CheckRecursiveCall()
489552 if the recursion_depth reaches recursion_limit. */
490553int
0 commit comments