@@ -460,6 +460,104 @@ fn replace_local<'tcx>(
460460 new_local
461461}
462462
463+ /// Transforms the `body` of the generator applying the following transforms:
464+ ///
465+ /// - Eliminates all the `get_context` calls that async lowering created.
466+ /// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
467+ ///
468+ /// The `Local`s that have their types replaced are:
469+ /// - The `resume` argument itself.
470+ /// - The argument to `get_context`.
471+ /// - The yielded value of a `yield`.
472+ ///
473+ /// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
474+ /// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
475+ ///
476+ /// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
477+ /// but rather directly use `&mut Context<'_>`, however that would currently
478+ /// lead to higher-kinded lifetime errors.
479+ /// See <https://github.com/rust-lang/rust/issues/105501>.
480+ ///
481+ /// The async lowering step and the type / lifetime inference / checking are
482+ /// still using the `ResumeTy` indirection for the time being, and that indirection
483+ /// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
484+ fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
485+ let context_mut_ref = tcx.mk_task_context();
486+
487+ // replace the type of the `resume` argument
488+ replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref);
489+
490+ let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
491+
492+ for bb in BasicBlock::new(0)..body.basic_blocks.next_index() {
493+ let bb_data = &body[bb];
494+ if bb_data.is_cleanup {
495+ continue;
496+ }
497+
498+ match &bb_data.terminator().kind {
499+ TerminatorKind::Call { func, .. } => {
500+ let func_ty = func.ty(body, tcx);
501+ if let ty::FnDef(def_id, _) = *func_ty.kind() {
502+ if def_id == get_context_def_id {
503+ let local = eliminate_get_context_call(&mut body[bb]);
504+ replace_resume_ty_local(tcx, body, local, context_mut_ref);
505+ }
506+ } else {
507+ continue;
508+ }
509+ }
510+ TerminatorKind::Yield { resume_arg, .. } => {
511+ replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
512+ }
513+ _ => {}
514+ }
515+ }
516+ }
517+
518+ fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
519+ let terminator = bb_data.terminator.take().unwrap();
520+ if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind {
521+ let arg = args.pop().unwrap();
522+ let local = arg.place().unwrap().local;
523+
524+ let arg = Rvalue::Use(arg);
525+ let assign = Statement {
526+ source_info: terminator.source_info,
527+ kind: StatementKind::Assign(Box::new((destination, arg))),
528+ };
529+ bb_data.statements.push(assign);
530+ bb_data.terminator = Some(Terminator {
531+ source_info: terminator.source_info,
532+ kind: TerminatorKind::Goto { target: target.unwrap() },
533+ });
534+ local
535+ } else {
536+ bug!();
537+ }
538+ }
539+
540+ #[cfg_attr(not(debug_assertions), allow(unused))]
541+ fn replace_resume_ty_local<'tcx>(
542+ tcx: TyCtxt<'tcx>,
543+ body: &mut Body<'tcx>,
544+ local: Local,
545+ context_mut_ref: Ty<'tcx>,
546+ ) {
547+ let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
548+ // We have to replace the `ResumeTy` that is used for type and borrow checking
549+ // with `&mut Context<'_>` in MIR.
550+ #[cfg(debug_assertions)]
551+ {
552+ if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
553+ let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
554+ assert_eq!(*resume_ty_adt, expected_adt);
555+ } else {
556+ panic!("expected `ResumeTy`, found `{:?}`", local_ty);
557+ };
558+ }
559+ }
560+
463561struct LivenessInfo {
464562 /// Which locals are live across any suspension point.
465563 saved_locals: GeneratorSavedLocals,
@@ -1283,13 +1381,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
12831381 }
12841382 };
12851383
1286- let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen ;
1384+ let is_async_kind = matches!( body.generator_kind(), Some( GeneratorKind::Async(_))) ;
12871385 let (state_adt_ref, state_substs) = if is_async_kind {
12881386 // Compute Poll<return_ty>
1289- let state_did = tcx.require_lang_item(LangItem::Poll, None);
1290- let state_adt_ref = tcx.adt_def(state_did );
1291- let state_substs = tcx.intern_substs(&[body.return_ty().into()]);
1292- (state_adt_ref, state_substs )
1387+ let poll_did = tcx.require_lang_item(LangItem::Poll, None);
1388+ let poll_adt_ref = tcx.adt_def(poll_did );
1389+ let poll_substs = tcx.intern_substs(&[body.return_ty().into()]);
1390+ (poll_adt_ref, poll_substs )
12931391 } else {
12941392 // Compute GeneratorState<yield_ty, return_ty>
12951393 let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
@@ -1303,13 +1401,19 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
13031401 // RETURN_PLACE then is a fresh unused local with type ret_ty.
13041402 let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
13051403
1404+ // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1405+ if is_async_kind {
1406+ transform_async_context(tcx, body);
1407+ }
1408+
13061409 // We also replace the resume argument and insert an `Assign`.
13071410 // This is needed because the resume argument `_2` might be live across a `yield`, in which
13081411 // case there is no `Assign` to it that the transform can turn into a store to the generator
13091412 // state. After the yield the slot in the generator state would then be uninitialized.
13101413 let resume_local = Local::new(2);
1311- let new_resume_local =
1312- replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx);
1414+ let resume_ty =
1415+ if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty };
1416+ let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
13131417
13141418 // When first entering the generator, move the resume argument into its new local.
13151419 let source_info = SourceInfo::outermost(body.span);
0 commit comments