Skip to content

Commit 79966ae

Browse files
committed
Auto merge of #147493 - cjgillot:single-pin, r=oli-obk
StateTransform: Only load pin field once. The current implementation starts by transforming all instances of `_1` into `(*_1)`, and then traverses the body again to transform `(*_1)` into `(*(_1.0))`, and again for `Derefer`. This PR changes the implementation to only traverse the body once. As `_1.0` cannot be not modified inside the body (we just changed its type!), we have no risk of loading from the wrong pointer.
2 parents 04ff05c + ed85b96 commit 79966ae

12 files changed

+203
-270
lines changed

compiler/rustc_mir_transform/src/coroutine.rs

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ struct SelfArgVisitor<'tcx> {
132132
}
133133

134134
impl<'tcx> SelfArgVisitor<'tcx> {
135-
fn new(tcx: TyCtxt<'tcx>, elem: ProjectionElem<Local, Ty<'tcx>>) -> Self {
136-
Self { tcx, new_base: Place { local: SELF_ARG, projection: tcx.mk_place_elems(&[elem]) } }
135+
fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
136+
Self { tcx, new_base }
137137
}
138138
}
139139

@@ -146,16 +146,14 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
146146
assert_ne!(*local, SELF_ARG);
147147
}
148148

149-
fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
149+
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
150150
if place.local == SELF_ARG {
151151
replace_base(place, self.new_base, self.tcx);
152-
} else {
153-
self.visit_local(&mut place.local, context, location);
152+
}
154153

155-
for elem in place.projection.iter() {
156-
if let PlaceElem::Index(local) = elem {
157-
assert_ne!(local, SELF_ARG);
158-
}
154+
for elem in place.projection.iter() {
155+
if let PlaceElem::Index(local) = elem {
156+
assert_ne!(local, SELF_ARG);
159157
}
160158
}
161159
}
@@ -515,32 +513,56 @@ fn make_aggregate_adt<'tcx>(
515513

516514
#[tracing::instrument(level = "trace", skip(tcx, body))]
517515
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
518-
let coroutine_ty = body.local_decls.raw[1].ty;
516+
let coroutine_ty = body.local_decls[SELF_ARG].ty;
519517

520518
let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
521519

522520
// Replace the by value coroutine argument
523-
body.local_decls.raw[1].ty = ref_coroutine_ty;
521+
body.local_decls[SELF_ARG].ty = ref_coroutine_ty;
524522

525523
// Add a deref to accesses of the coroutine state
526-
SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body);
524+
SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
527525
}
528526

529527
#[tracing::instrument(level = "trace", skip(tcx, body))]
530528
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
531-
let ref_coroutine_ty = body.local_decls.raw[1].ty;
529+
let coroutine_ty = body.local_decls[SELF_ARG].ty;
530+
531+
let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
532532

533533
let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
534534
let pin_adt_ref = tcx.adt_def(pin_did);
535535
let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
536536
let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);
537537

538538
// Replace the by ref coroutine argument
539-
body.local_decls.raw[1].ty = pin_ref_coroutine_ty;
539+
body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;
540+
541+
let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));
540542

541543
// Add the Pin field access to accesses of the coroutine state
542-
SelfArgVisitor::new(tcx, ProjectionElem::Field(FieldIdx::ZERO, ref_coroutine_ty))
543-
.visit_body(body);
544+
SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);
545+
546+
let source_info = SourceInfo::outermost(body.span);
547+
let pin_field = tcx.mk_place_field(SELF_ARG.into(), FieldIdx::ZERO, ref_coroutine_ty);
548+
549+
let statements = &mut body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements;
550+
// Miri requires retags to be the very first thing in the body.
551+
// We insert this assignment just after.
552+
let insert_point = statements
553+
.iter()
554+
.position(|stmt| !matches!(stmt.kind, StatementKind::Retag(..)))
555+
.unwrap_or(statements.len());
556+
statements.insert(
557+
insert_point,
558+
Statement::new(
559+
source_info,
560+
StatementKind::Assign(Box::new((
561+
unpinned_local.into(),
562+
Rvalue::Use(Operand::Copy(pin_field)),
563+
))),
564+
),
565+
);
544566
}
545567

546568
/// Transforms the `body` of the coroutine applying the following transforms:
@@ -1292,8 +1314,6 @@ fn create_coroutine_resume_function<'tcx>(
12921314
let default_block = insert_term_block(body, TerminatorKind::Unreachable);
12931315
insert_switch(body, cases, &transform, default_block);
12941316

1295-
make_coroutine_state_argument_indirect(tcx, body);
1296-
12971317
match transform.coroutine_kind {
12981318
CoroutineKind::Coroutine(_)
12991319
| CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
@@ -1302,7 +1322,9 @@ fn create_coroutine_resume_function<'tcx>(
13021322
}
13031323
// Iterator::next doesn't accept a pinned argument,
13041324
// unlike for all other coroutine kinds.
1305-
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
1325+
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
1326+
make_coroutine_state_argument_indirect(tcx, body);
1327+
}
13061328
}
13071329

13081330
// Make sure we remove dead blocks to remove

compiler/rustc_mir_transform/src/coroutine/drop.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -684,12 +684,13 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
684684
let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
685685
body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
686686

687-
make_coroutine_state_argument_indirect(tcx, &mut body);
688-
689687
match transform.coroutine_kind {
690688
// Iterator::next doesn't accept a pinned argument,
691689
// unlike for all other coroutine kinds.
692-
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
690+
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
691+
make_coroutine_state_argument_indirect(tcx, &mut body);
692+
}
693+
693694
_ => {
694695
make_coroutine_state_argument_pinned(tcx, &mut body);
695696
}

tests/mir-opt/async_drop_live_dead.a-{closure#0}.coroutine_drop_async.0.panic-abort.mir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
44
debug _task_context => _2;
5-
debug x => ((*(_1.0: &mut {async fn body of a<T>()})).0: T);
5+
debug x => ((*_20).0: T);
66
let mut _0: std::task::Poll<()>;
77
let _3: T;
88
let mut _4: impl std::future::Future<Output = ()>;
@@ -21,12 +21,14 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
2121
let mut _17: isize;
2222
let mut _18: ();
2323
let mut _19: u32;
24+
let mut _20: &mut {async fn body of a<T>()};
2425
scope 1 {
25-
debug x => (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).0: T);
26+
debug x => (((*_20) as variant#4).0: T);
2627
}
2728

2829
bb0: {
29-
_19 = discriminant((*(_1.0: &mut {async fn body of a<T>()})));
30+
_20 = copy (_1.0: &mut {async fn body of a<T>()});
31+
_19 = discriminant((*_20));
3032
switchInt(move _19) -> [0: bb9, 3: bb12, 4: bb13, otherwise: bb14];
3133
}
3234

@@ -43,13 +45,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
4345

4446
bb3: {
4547
_0 = Poll::<()>::Pending;
46-
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 4;
48+
discriminant((*_20)) = 4;
4749
return;
4850
}
4951

5052
bb4: {
5153
StorageLive(_16);
52-
_15 = &mut (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).1: impl std::future::Future<Output = ()>);
54+
_15 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
5355
_16 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _15) -> [return: bb7, unwind unreachable];
5456
}
5557

@@ -81,7 +83,7 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
8183
}
8284

8385
bb11: {
84-
drop(((*(_1.0: &mut {async fn body of a<T>()})).0: T)) -> [return: bb10, unwind unreachable];
86+
drop(((*_20).0: T)) -> [return: bb10, unwind unreachable];
8587
}
8688

8789
bb12: {

tests/mir-opt/async_drop_live_dead.a-{closure#0}.coroutine_drop_async.0.panic-unwind.mir

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>) -> Poll<()> {
44
debug _task_context => _2;
5-
debug x => ((*(_1.0: &mut {async fn body of a<T>()})).0: T);
5+
debug x => ((*_20).0: T);
66
let mut _0: std::task::Poll<()>;
77
let _3: T;
88
let mut _4: impl std::future::Future<Output = ()>;
@@ -21,12 +21,14 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
2121
let mut _17: isize;
2222
let mut _18: ();
2323
let mut _19: u32;
24+
let mut _20: &mut {async fn body of a<T>()};
2425
scope 1 {
25-
debug x => (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).0: T);
26+
debug x => (((*_20) as variant#4).0: T);
2627
}
2728

2829
bb0: {
29-
_19 = discriminant((*(_1.0: &mut {async fn body of a<T>()})));
30+
_20 = copy (_1.0: &mut {async fn body of a<T>()});
31+
_19 = discriminant((*_20));
3032
switchInt(move _19) -> [0: bb12, 2: bb18, 3: bb16, 4: bb17, otherwise: bb19];
3133
}
3234

@@ -57,13 +59,13 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
5759

5860
bb6: {
5961
_0 = Poll::<()>::Pending;
60-
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 4;
62+
discriminant((*_20)) = 4;
6163
return;
6264
}
6365

6466
bb7: {
6567
StorageLive(_16);
66-
_15 = &mut (((*(_1.0: &mut {async fn body of a<T>()})) as variant#4).1: impl std::future::Future<Output = ()>);
68+
_15 = &mut (((*_20) as variant#4).1: impl std::future::Future<Output = ()>);
6769
_16 = Pin::<&mut impl Future<Output = ()>>::new_unchecked(move _15) -> [return: bb10, unwind: bb15];
6870
}
6971

@@ -95,11 +97,11 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a<T>()}>, _2: &mut Context<'_>)
9597
}
9698

9799
bb14: {
98-
drop(((*(_1.0: &mut {async fn body of a<T>()})).0: T)) -> [return: bb13, unwind: bb4];
100+
drop(((*_20).0: T)) -> [return: bb13, unwind: bb4];
99101
}
100102

101103
bb15 (cleanup): {
102-
discriminant((*(_1.0: &mut {async fn body of a<T>()}))) = 2;
104+
discriminant((*_20)) = 2;
103105
resume;
104106
}
105107

tests/mir-opt/building/async_await.a-{closure#0}.coroutine_resume.0.mir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) ->
1414
let mut _0: std::task::Poll<()>;
1515
let mut _3: ();
1616
let mut _4: u32;
17+
let mut _5: &mut {async fn body of a()};
1718

1819
bb0: {
19-
_4 = discriminant((*(_1.0: &mut {async fn body of a()})));
20+
_5 = copy (_1.0: &mut {async fn body of a()});
21+
_4 = discriminant((*_5));
2022
switchInt(move _4) -> [0: bb1, 1: bb4, otherwise: bb5];
2123
}
2224

@@ -27,7 +29,7 @@ fn a::{closure#0}(_1: Pin<&mut {async fn body of a()}>, _2: &mut Context<'_>) ->
2729

2830
bb2: {
2931
_0 = Poll::<()>::Ready(move _3);
30-
discriminant((*(_1.0: &mut {async fn body of a()}))) = 1;
32+
discriminant((*_5)) = 1;
3133
return;
3234
}
3335

tests/mir-opt/building/async_await.b-{closure#0}.coroutine_resume.0.mir

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,23 +86,25 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
8686
let mut _36: ();
8787
let mut _37: ();
8888
let mut _38: u32;
89+
let mut _39: &mut {async fn body of b()};
8990
scope 1 {
90-
debug __awaitee => (((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()});
91+
debug __awaitee => (((*_39) as variant#3).0: {async fn body of a()});
9192
let _17: ();
9293
scope 2 {
9394
debug result => _17;
9495
}
9596
}
9697
scope 3 {
97-
debug __awaitee => (((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()});
98+
debug __awaitee => (((*_39) as variant#4).0: {async fn body of a()});
9899
let _33: ();
99100
scope 4 {
100101
debug result => _33;
101102
}
102103
}
103104

104105
bb0: {
105-
_38 = discriminant((*(_1.0: &mut {async fn body of b()})));
106+
_39 = copy (_1.0: &mut {async fn body of b()});
107+
_38 = discriminant((*_39));
106108
switchInt(move _38) -> [0: bb1, 1: bb29, 3: bb27, 4: bb28, otherwise: bb8];
107109
}
108110

@@ -121,7 +123,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
121123
StorageDead(_5);
122124
PlaceMention(_4);
123125
nop;
124-
(((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()}) = move _4;
126+
(((*_39) as variant#3).0: {async fn body of a()}) = move _4;
125127
goto -> bb4;
126128
}
127129

@@ -131,7 +133,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
131133
StorageLive(_10);
132134
StorageLive(_11);
133135
StorageLive(_12);
134-
_12 = &mut (((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()});
136+
_12 = &mut (((*_39) as variant#3).0: {async fn body of a()});
135137
_11 = &mut (*_12);
136138
_10 = Pin::<&mut {async fn body of a()}>::new_unchecked(move _11) -> [return: bb5, unwind unreachable];
137139
}
@@ -178,7 +180,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
178180
StorageDead(_4);
179181
StorageDead(_19);
180182
StorageDead(_20);
181-
discriminant((*(_1.0: &mut {async fn body of b()}))) = 3;
183+
discriminant((*_39)) = 3;
182184
return;
183185
}
184186

@@ -191,7 +193,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
191193
StorageDead(_12);
192194
StorageDead(_9);
193195
StorageDead(_8);
194-
drop((((*(_1.0: &mut {async fn body of b()})) as variant#3).0: {async fn body of a()})) -> [return: bb12, unwind unreachable];
196+
drop((((*_39) as variant#3).0: {async fn body of a()})) -> [return: bb12, unwind unreachable];
195197
}
196198

197199
bb11: {
@@ -223,7 +225,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
223225
StorageDead(_22);
224226
PlaceMention(_21);
225227
nop;
226-
(((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()}) = move _21;
228+
(((*_39) as variant#4).0: {async fn body of a()}) = move _21;
227229
goto -> bb16;
228230
}
229231

@@ -233,7 +235,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
233235
StorageLive(_26);
234236
StorageLive(_27);
235237
StorageLive(_28);
236-
_28 = &mut (((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()});
238+
_28 = &mut (((*_39) as variant#4).0: {async fn body of a()});
237239
_27 = &mut (*_28);
238240
_26 = Pin::<&mut {async fn body of a()}>::new_unchecked(move _27) -> [return: bb17, unwind unreachable];
239241
}
@@ -275,7 +277,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
275277
StorageDead(_21);
276278
StorageDead(_35);
277279
StorageDead(_36);
278-
discriminant((*(_1.0: &mut {async fn body of b()}))) = 4;
280+
discriminant((*_39)) = 4;
279281
return;
280282
}
281283

@@ -288,7 +290,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
288290
StorageDead(_28);
289291
StorageDead(_25);
290292
StorageDead(_24);
291-
drop((((*(_1.0: &mut {async fn body of b()})) as variant#4).0: {async fn body of a()})) -> [return: bb23, unwind unreachable];
293+
drop((((*_39) as variant#4).0: {async fn body of a()})) -> [return: bb23, unwind unreachable];
292294
}
293295

294296
bb22: {
@@ -311,7 +313,7 @@ fn b::{closure#0}(_1: Pin<&mut {async fn body of b()}>, _2: &mut Context<'_>) ->
311313

312314
bb25: {
313315
_0 = Poll::<()>::Ready(move _37);
314-
discriminant((*(_1.0: &mut {async fn body of b()}))) = 1;
316+
discriminant((*_39)) = 1;
315317
return;
316318
}
317319

0 commit comments

Comments
 (0)