Skip to content

Commit b1c170f

Browse files
Add unwrap_unsafe_binder and wrap_unsafe_binder macro operators
1 parent c1581fc commit b1c170f

File tree

13 files changed

+135
-13
lines changed

13 files changed

+135
-13
lines changed

compiler/rustc_hir_typeck/src/expr.rs

+63-6
Original file line numberDiff line numberDiff line change
@@ -1645,14 +1645,71 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
16451645

16461646
fn check_expr_unsafe_binder_cast(
16471647
&self,
1648-
_kind: hir::UnsafeBinderCastKind,
1648+
kind: hir::UnsafeBinderCastKind,
16491649
expr: &'tcx hir::Expr<'tcx>,
1650-
_hir_ty: Option<&'tcx hir::Ty<'tcx>>,
1651-
_expected: Expectation<'tcx>,
1650+
hir_ty: Option<&'tcx hir::Ty<'tcx>>,
1651+
expected: Expectation<'tcx>,
16521652
) -> Ty<'tcx> {
1653-
let guar =
1654-
self.dcx().struct_span_err(expr.span, "unsafe binders are not yet implemented").emit();
1655-
Ty::new_error(self.tcx, guar)
1653+
match kind {
1654+
hir::UnsafeBinderCastKind::Wrap => {
1655+
let ascribed_ty =
1656+
hir_ty.map(|hir_ty| self.lower_ty_saving_user_provided_ty(hir_ty));
1657+
let expected_ty = expected.only_has_type(self);
1658+
let binder_ty = match (ascribed_ty, expected_ty) {
1659+
(Some(ascribed_ty), Some(expected_ty)) => {
1660+
self.demand_eqtype(expr.span, expected_ty, ascribed_ty);
1661+
expected_ty
1662+
}
1663+
(Some(ty), None) | (None, Some(ty)) => ty,
1664+
(None, None) => self.next_ty_var(expr.span),
1665+
};
1666+
1667+
// Unwrap the binder eagerly if we can use it to guide inference on
1668+
// the inner expr. If not, then we'll error *after* type checking.
1669+
let hint_ty = if let ty::UnsafeBinder(binder) =
1670+
*self.try_structurally_resolve_type(expr.span, binder_ty).kind()
1671+
{
1672+
self.instantiate_binder_with_fresh_vars(
1673+
expr.span,
1674+
infer::BoundRegionConversionTime::HigherRankedType,
1675+
binder.into(),
1676+
)
1677+
} else {
1678+
self.next_ty_var(expr.span)
1679+
};
1680+
1681+
self.check_expr_has_type_or_error(expr, hint_ty, |_| {});
1682+
1683+
let binder_ty = self.structurally_resolve_type(expr.span, binder_ty);
1684+
match *binder_ty.kind() {
1685+
ty::UnsafeBinder(..) => {
1686+
// Ok
1687+
}
1688+
_ => todo!(),
1689+
}
1690+
1691+
binder_ty
1692+
}
1693+
hir::UnsafeBinderCastKind::Unwrap => {
1694+
let ascribed_ty =
1695+
hir_ty.map(|hir_ty| self.lower_ty_saving_user_provided_ty(hir_ty));
1696+
let hint_ty = ascribed_ty.unwrap_or_else(|| self.next_ty_var(expr.span));
1697+
// FIXME(unsafe_binders): coerce here if needed?
1698+
let binder_ty = self.check_expr_has_type_or_error(expr, hint_ty, |_| {});
1699+
1700+
// Unwrap the binder. This will be ambiguous if it's an infer var, and will error
1701+
// if it's not an unsafe binder.
1702+
let binder_ty = self.structurally_resolve_type(expr.span, binder_ty);
1703+
match *binder_ty.kind() {
1704+
ty::UnsafeBinder(binder) => self.instantiate_binder_with_fresh_vars(
1705+
expr.span,
1706+
infer::BoundRegionConversionTime::HigherRankedType,
1707+
binder.into(),
1708+
),
1709+
_ => todo!(),
1710+
}
1711+
}
1712+
}
16561713
}
16571714

16581715
fn check_expr_array(

compiler/rustc_middle/src/thir.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use rustc_abi::{FieldIdx, Integer, Size, VariantIdx};
1616
use rustc_ast::{AsmMacro, InlineAsmOptions, InlineAsmTemplatePiece};
1717
use rustc_hir as hir;
1818
use rustc_hir::def_id::DefId;
19-
use rustc_hir::{BindingMode, ByRef, HirId, MatchSource, RangeEnd};
19+
use rustc_hir::{BindingMode, ByRef, HirId, MatchSource, RangeEnd, UnsafeBinderCastKind};
2020
use rustc_index::{IndexVec, newtype_index};
2121
use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeVisitable};
2222
use rustc_middle::middle::region;
@@ -489,6 +489,16 @@ pub enum ExprKind<'tcx> {
489489
user_ty: UserTy<'tcx>,
490490
user_ty_span: Span,
491491
},
492+
/// An unsafe binder cast on a place, e.g. `unwrap_unsafe_binder!(x)`.
493+
PlaceUnsafeBinderCast {
494+
kind: UnsafeBinderCastKind,
495+
source: ExprId,
496+
},
497+
/// An unsafe binder cast on a value, e.g. `wrap_unsafe_binder!(1; unsafe<> i32)`.
498+
ValueUnsafeBinderCast {
499+
kind: UnsafeBinderCastKind,
500+
source: ExprId,
501+
},
492502
/// A closure definition.
493503
Closure(Box<ClosureExpr<'tcx>>),
494504
/// A literal.

compiler/rustc_middle/src/thir/visit.rs

+3
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ pub fn walk_expr<'thir, 'tcx: 'thir, V: Visitor<'thir, 'tcx>>(
136136
| ValueTypeAscription { source, user_ty: _, user_ty_span: _ } => {
137137
visitor.visit_expr(&visitor.thir()[source])
138138
}
139+
PlaceUnsafeBinderCast { source, kind: _ } | ValueUnsafeBinderCast { source, kind: _ } => {
140+
visitor.visit_expr(&visitor.thir()[source])
141+
}
139142
Closure(box ClosureExpr {
140143
closure_id: _,
141144
args: _,

compiler/rustc_mir_build/src/build/expr/as_place.rs

+16
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,22 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
523523
block.and(PlaceBuilder::from(temp))
524524
}
525525

526+
ExprKind::PlaceUnsafeBinderCast { source, kind: _ } => {
527+
let place_builder = unpack!(
528+
block = this.expr_as_place(block, source, mutability, fake_borrow_temps,)
529+
);
530+
// TODO: stick on a projection elem
531+
block.and(place_builder)
532+
}
533+
ExprKind::ValueUnsafeBinderCast { source, kind: _ } => {
534+
let source_expr = &this.thir[source];
535+
let temp = unpack!(
536+
block = this.as_temp(block, source_expr.temp_lifetime, source, mutability)
537+
);
538+
// TODO: stick on a projection elem
539+
block.and(PlaceBuilder::from(temp))
540+
}
541+
526542
ExprKind::Array { .. }
527543
| ExprKind::Tuple { .. }
528544
| ExprKind::Adt { .. }

compiler/rustc_mir_build/src/build/expr/as_rvalue.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
532532
| ExprKind::Become { .. }
533533
| ExprKind::InlineAsm { .. }
534534
| ExprKind::PlaceTypeAscription { .. }
535-
| ExprKind::ValueTypeAscription { .. } => {
535+
| ExprKind::ValueTypeAscription { .. }
536+
| ExprKind::PlaceUnsafeBinderCast { .. }
537+
| ExprKind::ValueUnsafeBinderCast { .. } => {
536538
// these do not have corresponding `Rvalue` variants,
537539
// so make an operand and then return that
538540
debug_assert!(!matches!(

compiler/rustc_mir_build/src/build/expr/category.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ impl Category {
4141
| ExprKind::UpvarRef { .. }
4242
| ExprKind::VarRef { .. }
4343
| ExprKind::PlaceTypeAscription { .. }
44-
| ExprKind::ValueTypeAscription { .. } => Some(Category::Place),
44+
| ExprKind::ValueTypeAscription { .. }
45+
| ExprKind::PlaceUnsafeBinderCast { .. }
46+
| ExprKind::ValueUnsafeBinderCast { .. } => Some(Category::Place),
4547

4648
ExprKind::LogicalOp { .. }
4749
| ExprKind::Match { .. }

compiler/rustc_mir_build/src/build/expr/into.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
548548
ExprKind::VarRef { .. }
549549
| ExprKind::UpvarRef { .. }
550550
| ExprKind::PlaceTypeAscription { .. }
551-
| ExprKind::ValueTypeAscription { .. } => {
551+
| ExprKind::ValueTypeAscription { .. }
552+
| ExprKind::PlaceUnsafeBinderCast { .. }
553+
| ExprKind::ValueUnsafeBinderCast { .. } => {
552554
debug_assert!(Category::of(&expr.kind) == Some(Category::Place));
553555

554556
let place = unpack!(block = this.as_place(block, expr_id));

compiler/rustc_mir_build/src/check_unsafety.rs

+2
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,8 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
440440
| ExprKind::NeverToAny { .. }
441441
| ExprKind::PlaceTypeAscription { .. }
442442
| ExprKind::ValueTypeAscription { .. }
443+
| ExprKind::PlaceUnsafeBinderCast { .. }
444+
| ExprKind::ValueUnsafeBinderCast { .. }
443445
| ExprKind::PointerCoercion { .. }
444446
| ExprKind::Repeat { .. }
445447
| ExprKind::StaticRef { .. }

compiler/rustc_mir_build/src/thir/cx/expr.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -917,8 +917,14 @@ impl<'tcx> Cx<'tcx> {
917917
}
918918
}
919919

920-
hir::ExprKind::UnsafeBinderCast(_kind, _source, _ty) => {
921-
unreachable!("unsafe binders are not yet implemented")
920+
hir::ExprKind::UnsafeBinderCast(kind, source, _ty) => {
921+
// FIXME(unsafe_binders): Take into account the ascribed type, too.
922+
let mirrored = self.mirror_expr(source);
923+
if source.is_syntactic_place_expr() {
924+
ExprKind::PlaceUnsafeBinderCast { source: mirrored, kind }
925+
} else {
926+
ExprKind::ValueUnsafeBinderCast { source: mirrored, kind }
927+
}
922928
}
923929

924930
hir::ExprKind::DropTemps(source) => ExprKind::Use { source: self.mirror_expr(source) },

compiler/rustc_mir_build/src/thir/pattern/check_match.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,9 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
326326
| Use { source }
327327
| PointerCoercion { source, .. }
328328
| PlaceTypeAscription { source, .. }
329-
| ValueTypeAscription { source, .. } => {
329+
| ValueTypeAscription { source, .. }
330+
| ExprKind::PlaceUnsafeBinderCast { source, .. }
331+
| ExprKind::ValueUnsafeBinderCast { source, .. } => {
330332
self.is_known_valid_scrutinee(&self.thir()[*source])
331333
}
332334

compiler/rustc_mir_build/src/thir/print.rs

+14
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,20 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
475475
self.print_expr(*source, depth_lvl + 2);
476476
print_indented!(self, "}", depth_lvl);
477477
}
478+
PlaceUnsafeBinderCast { source, kind } => {
479+
print_indented!(self, "PlaceUnsafeBinderCast {", depth_lvl);
480+
print_indented!(self, format!("kind: {kind:?}"), depth_lvl + 1);
481+
print_indented!(self, "source:", depth_lvl + 1);
482+
self.print_expr(*source, depth_lvl + 2);
483+
print_indented!(self, "}", depth_lvl);
484+
}
485+
ValueUnsafeBinderCast { source, kind } => {
486+
print_indented!(self, "ValueUnsafeBinderCast {", depth_lvl);
487+
print_indented!(self, format!("kind: {kind:?}"), depth_lvl + 1);
488+
print_indented!(self, "source:", depth_lvl + 1);
489+
self.print_expr(*source, depth_lvl + 2);
490+
print_indented!(self, "}", depth_lvl);
491+
}
478492
Closure(closure_expr) => {
479493
print_indented!(self, "Closure {", depth_lvl);
480494
print_indented!(self, "closure_expr:", depth_lvl + 1);

compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -2131,6 +2131,7 @@ symbols! {
21312131
unwrap,
21322132
unwrap_binder,
21332133
unwrap_or,
2134+
unwrap_unsafe_binder,
21342135
use_extern_macros,
21352136
use_nested_groups,
21362137
used,

compiler/rustc_ty_utils/src/consts.rs

+5
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ fn recurse_build<'tcx>(
116116
| &ExprKind::ValueTypeAscription { source, .. } => {
117117
recurse_build(tcx, body, source, root_span)?
118118
}
119+
&ExprKind::PlaceUnsafeBinderCast { .. } | &ExprKind::ValueUnsafeBinderCast { .. } => {
120+
todo!()
121+
}
119122
&ExprKind::Literal { lit, neg } => {
120123
let sp = node.span;
121124
match tcx.at(sp).lit_to_const(LitToConstInput { lit: &lit.node, ty: node.ty, neg }) {
@@ -353,6 +356,8 @@ impl<'a, 'tcx> IsThirPolymorphic<'a, 'tcx> {
353356
| thir::ExprKind::Adt(_)
354357
| thir::ExprKind::PlaceTypeAscription { .. }
355358
| thir::ExprKind::ValueTypeAscription { .. }
359+
| thir::ExprKind::PlaceUnsafeBinderCast { .. }
360+
| thir::ExprKind::ValueUnsafeBinderCast { .. }
356361
| thir::ExprKind::Closure(_)
357362
| thir::ExprKind::Literal { .. }
358363
| thir::ExprKind::NonHirLiteral { .. }

0 commit comments

Comments
 (0)