Skip to content

Commit 3c4edc5

Browse files
committed
Add or-patterns to pattern types
1 parent cb6d371 commit 3c4edc5

File tree

34 files changed

+502
-14
lines changed

34 files changed

+502
-14
lines changed

compiler/rustc_ast/src/ast.rs

+2
Original file line numberDiff line numberDiff line change
@@ -2469,6 +2469,8 @@ pub enum TyPatKind {
24692469
/// A range pattern (e.g., `1...2`, `1..2`, `1..`, `..2`, `1..=2`, `..=2`).
24702470
Range(Option<P<AnonConst>>, Option<P<AnonConst>>, Spanned<RangeEnd>),
24712471

2472+
Or(ThinVec<P<TyPat>>),
2473+
24722474
/// Placeholder for a pattern that wasn't syntactically well formed in some way.
24732475
Err(ErrorGuaranteed),
24742476
}

compiler/rustc_ast/src/mut_visit.rs

+1
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ pub fn walk_ty_pat<T: MutVisitor>(vis: &mut T, ty: &mut P<TyPat>) {
612612
visit_opt(start, |c| vis.visit_anon_const(c));
613613
visit_opt(end, |c| vis.visit_anon_const(c));
614614
}
615+
TyPatKind::Or(variants) => visit_thin_vec(variants, |p| vis.visit_ty_pat(p)),
615616
TyPatKind::Err(_) => {}
616617
}
617618
visit_lazy_tts(vis, tokens);

compiler/rustc_ast/src/visit.rs

+1
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ pub fn walk_ty_pat<'a, V: Visitor<'a>>(visitor: &mut V, tp: &'a TyPat) -> V::Res
608608
visit_opt!(visitor, visit_anon_const, start);
609609
visit_opt!(visitor, visit_anon_const, end);
610610
}
611+
TyPatKind::Or(variants) => walk_list!(visitor, visit_ty_pat, variants),
611612
TyPatKind::Err(_) => {}
612613
}
613614
V::Result::output()

compiler/rustc_ast_lowering/src/pat.rs

+5
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
464464
)
465465
}),
466466
),
467+
TyPatKind::Or(variants) => {
468+
hir::TyPatKind::Or(self.arena.alloc_from_iter(
469+
variants.iter().map(|pat| self.lower_ty_pat_mut(pat, base_type)),
470+
))
471+
}
467472
TyPatKind::Err(guar) => hir::TyPatKind::Err(*guar),
468473
};
469474

compiler/rustc_ast_pretty/src/pprust/state.rs

+11
Original file line numberDiff line numberDiff line change
@@ -1162,6 +1162,17 @@ impl<'a> State<'a> {
11621162
self.print_expr_anon_const(end, &[]);
11631163
}
11641164
}
1165+
rustc_ast::TyPatKind::Or(variants) => {
1166+
let mut first = true;
1167+
for pat in variants {
1168+
if first {
1169+
first = false
1170+
} else {
1171+
self.word(" | ");
1172+
}
1173+
self.print_ty_pat(pat);
1174+
}
1175+
}
11651176
rustc_ast::TyPatKind::Err(_) => {
11661177
self.popen();
11671178
self.word("/*ERROR*/");

compiler/rustc_builtin_macros/src/pattern_type.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use rustc_ast::{AnonConst, DUMMY_NODE_ID, Ty, TyPat, TyPatKind, ast, token};
44
use rustc_errors::PResult;
55
use rustc_expand::base::{self, DummyResult, ExpandResult, ExtCtxt, MacroExpanderResult};
66
use rustc_parse::exp;
7+
use rustc_parse::parser::{CommaRecoveryMode, RecoverColon, RecoverComma};
78
use rustc_span::Span;
89

910
pub(crate) fn expand<'cx>(
@@ -27,7 +28,17 @@ fn parse_pat_ty<'a>(cx: &mut ExtCtxt<'a>, stream: TokenStream) -> PResult<'a, (P
2728
let ty = parser.parse_ty()?;
2829
parser.expect_keyword(exp!(Is))?;
2930

30-
let pat = pat_to_ty_pat(cx, parser.parse_pat_no_top_alt(None, None)?.into_inner());
31+
let pat = pat_to_ty_pat(
32+
cx,
33+
parser
34+
.parse_pat_no_top_guard(
35+
None,
36+
RecoverComma::No,
37+
RecoverColon::No,
38+
CommaRecoveryMode::EitherTupleOrPipe,
39+
)?
40+
.into_inner(),
41+
);
3142

3243
if parser.token != token::Eof {
3344
parser.unexpected()?;
@@ -47,6 +58,9 @@ fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> P<TyPat> {
4758
end.map(|value| P(AnonConst { id: DUMMY_NODE_ID, value })),
4859
include_end,
4960
),
61+
ast::PatKind::Or(variants) => TyPatKind::Or(
62+
variants.into_iter().map(|pat| pat_to_ty_pat(cx, pat.into_inner())).collect(),
63+
),
5064
ast::PatKind::Err(guar) => TyPatKind::Err(guar),
5165
_ => TyPatKind::Err(cx.dcx().span_err(pat.span, "pattern not supported in pattern types")),
5266
};

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+8-5
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,19 @@ pub(crate) fn eval_nullary_intrinsic<'tcx>(
6161
ensure_monomorphic_enough(tcx, tp_ty)?;
6262
ConstValue::from_u128(tcx.type_id_hash(tp_ty).as_u128())
6363
}
64-
sym::variant_count => match tp_ty.kind() {
64+
sym::variant_count => match match tp_ty.kind() {
65+
// Pattern types have the same number of variants as their base type.
66+
// Even if we restrict e.g. which variants are valid, the variants are essentially just uninhabited.
67+
// And `Result<(), !>` still has two variants according to `variant_count`.
68+
ty::Pat(base, _) => *base,
69+
_ => tp_ty,
70+
} {
6571
// Correctly handles non-monomorphic calls, so there is no need for ensure_monomorphic_enough.
6672
ty::Adt(adt, _) => ConstValue::from_target_usize(adt.variants().len() as u64, &tcx),
6773
ty::Alias(..) | ty::Param(_) | ty::Placeholder(_) | ty::Infer(_) => {
6874
throw_inval!(TooGeneric)
6975
}
70-
ty::Pat(_, pat) => match **pat {
71-
ty::PatternKind::Range { .. } => ConstValue::from_target_usize(0u64, &tcx),
72-
// Future pattern kinds may have more variants
73-
},
76+
ty::Pat(..) => unreachable!(),
7477
ty::Bound(_, _) => bug!("bound ty during ctfe"),
7578
ty::Bool
7679
| ty::Char

compiler/rustc_const_eval/src/interpret/validity.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,14 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt,
12481248
// Range patterns are precisely reflected into `valid_range` and thus
12491249
// handled fully by `visit_scalar` (called below).
12501250
ty::PatternKind::Range { .. } => {},
1251+
1252+
// FIXME(pattern_types): check that the value is covered by one of the variants.
1253+
// For now, we rely on layout computation setting the scalar's `valid_range` to
1254+
// match the pattern. However, this cannot always work; the layout may
1255+
// pessimistically cover actually illegal ranges and Miri would miss that UB.
1256+
// The consolation here is that codegen also will miss that UB, so at least
1257+
// we won't see optimizations actually breaking such programs.
1258+
ty::PatternKind::Or(_patterns) => {}
12511259
}
12521260
}
12531261
_ => {

compiler/rustc_hir/src/hir.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1813,6 +1813,9 @@ pub enum TyPatKind<'hir> {
18131813
/// A range pattern (e.g., `1..=2` or `1..2`).
18141814
Range(&'hir ConstArg<'hir>, &'hir ConstArg<'hir>),
18151815

1816+
/// A list of patterns where only one needs to be satisfied
1817+
Or(&'hir [TyPat<'hir>]),
1818+
18161819
/// A placeholder for a pattern that wasn't well formed in some way.
18171820
Err(ErrorGuaranteed),
18181821
}

compiler/rustc_hir/src/intravisit.rs

+1
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ pub fn walk_ty_pat<'v, V: Visitor<'v>>(visitor: &mut V, pattern: &'v TyPat<'v>)
710710
try_visit!(visitor.visit_const_arg_unambig(lower_bound));
711711
try_visit!(visitor.visit_const_arg_unambig(upper_bound));
712712
}
713+
TyPatKind::Or(patterns) => walk_list!(visitor, visit_pattern_type_pattern, patterns),
713714
TyPatKind::Err(_) => (),
714715
}
715716
V::Result::output()

compiler/rustc_hir_analysis/src/collect/type_of.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,12 @@ fn const_arg_anon_type_of<'tcx>(icx: &ItemCtxt<'tcx>, arg_hir_id: HirId, span: S
9494
}
9595

9696
Node::TyPat(pat) => {
97-
let hir::TyKind::Pat(ty, p) = tcx.parent_hir_node(pat.hir_id).expect_ty().kind else {
98-
bug!()
97+
let node = match tcx.parent_hir_node(pat.hir_id) {
98+
// Or patterns can be nested one level deep
99+
Node::TyPat(p) => tcx.parent_hir_node(p.hir_id),
100+
other => other,
99101
};
100-
assert_eq!(p.hir_id, pat.hir_id);
102+
let hir::TyKind::Pat(ty, _) = node.expect_ty().kind else { bug!() };
101103
icx.lower_ty(ty)
102104
}
103105

compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs

+8
Original file line numberDiff line numberDiff line change
@@ -2735,6 +2735,7 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
27352735
ty_span: Span,
27362736
pat: &hir::TyPat<'tcx>,
27372737
) -> Result<ty::PatternKind<'tcx>, ErrorGuaranteed> {
2738+
let tcx = self.tcx();
27382739
match pat.kind {
27392740
hir::TyPatKind::Range(start, end) => {
27402741
match ty.kind() {
@@ -2750,6 +2751,13 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
27502751
.span_delayed_bug(ty_span, "invalid base type for range pattern")),
27512752
}
27522753
}
2754+
hir::TyPatKind::Or(patterns) => {
2755+
self.tcx()
2756+
.mk_patterns_from_iter(patterns.iter().map(|pat| {
2757+
self.lower_pat_ty_pat(ty, ty_span, pat).map(|pat| tcx.mk_pat(pat))
2758+
}))
2759+
.map(ty::PatternKind::Or)
2760+
}
27532761
hir::TyPatKind::Err(e) => Err(e),
27542762
}
27552763
}

compiler/rustc_hir_analysis/src/variance/constraints.rs

+5
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,11 @@ impl<'a, 'tcx> ConstraintContext<'a, 'tcx> {
340340
self.add_constraints_from_const(current, start, variance);
341341
self.add_constraints_from_const(current, end, variance);
342342
}
343+
ty::PatternKind::Or(patterns) => {
344+
for pat in patterns {
345+
self.add_constraints_from_pat(current, variance, pat)
346+
}
347+
}
343348
}
344349
}
345350

compiler/rustc_hir_pretty/src/lib.rs

+13
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,19 @@ impl<'a> State<'a> {
18661866
self.word("..=");
18671867
self.print_const_arg(end);
18681868
}
1869+
TyPatKind::Or(patterns) => {
1870+
self.popen();
1871+
let mut first = true;
1872+
for pat in patterns {
1873+
if first {
1874+
first = false;
1875+
} else {
1876+
self.word(" | ");
1877+
}
1878+
self.print_ty_pat(pat);
1879+
}
1880+
self.pclose();
1881+
}
18691882
TyPatKind::Err(_) => {
18701883
self.popen();
18711884
self.word("/*ERROR*/");

compiler/rustc_lint/src/types.rs

+22-3
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,9 @@ fn pat_ty_is_known_nonnull<'tcx>(
900900
// to ensure we aren't wrapping over zero.
901901
start > 0 && end >= start
902902
}
903+
ty::PatternKind::Or(patterns) => {
904+
patterns.iter().all(|pat| pat_ty_is_known_nonnull(tcx, typing_env, pat))
905+
}
903906
}
904907
},
905908
)
@@ -1046,13 +1049,29 @@ pub(crate) fn repr_nullable_ptr<'tcx>(
10461049
}
10471050
None
10481051
}
1049-
ty::Pat(base, pat) => match **pat {
1050-
ty::PatternKind::Range { .. } => get_nullable_type(tcx, typing_env, *base),
1051-
},
1052+
ty::Pat(base, pat) => get_nullable_type_from_pat(tcx, typing_env, *base, *pat),
10521053
_ => None,
10531054
}
10541055
}
10551056

1057+
fn get_nullable_type_from_pat<'tcx>(
1058+
tcx: TyCtxt<'tcx>,
1059+
typing_env: ty::TypingEnv<'tcx>,
1060+
base: Ty<'tcx>,
1061+
pat: ty::Pattern<'tcx>,
1062+
) -> Option<Ty<'tcx>> {
1063+
match *pat {
1064+
ty::PatternKind::Range { .. } => get_nullable_type(tcx, typing_env, base),
1065+
ty::PatternKind::Or(patterns) => {
1066+
let first = get_nullable_type_from_pat(tcx, typing_env, base, patterns[0])?;
1067+
for &pat in &patterns[1..] {
1068+
assert_eq!(first, get_nullable_type_from_pat(tcx, typing_env, base, pat)?);
1069+
}
1070+
Some(first)
1071+
}
1072+
}
1073+
}
1074+
10561075
impl<'a, 'tcx> ImproperCTypesVisitor<'a, 'tcx> {
10571076
/// Check if the type is array and emit an unsafe type lint.
10581077
fn check_for_array_ty(&mut self, sp: Span, ty: Ty<'tcx>) -> bool {

compiler/rustc_middle/src/ty/codec.rs

+10
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,15 @@ impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List<ty::BoundVaria
442442
}
443443
}
444444

445+
impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List<ty::Pattern<'tcx>> {
446+
fn decode(decoder: &mut D) -> &'tcx Self {
447+
let len = decoder.read_usize();
448+
decoder.interner().mk_patterns_from_iter(
449+
(0..len).map::<ty::Pattern<'tcx>, _>(|_| Decodable::decode(decoder)),
450+
)
451+
}
452+
}
453+
445454
impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List<ty::Const<'tcx>> {
446455
fn decode(decoder: &mut D) -> &'tcx Self {
447456
let len = decoder.read_usize();
@@ -503,6 +512,7 @@ impl_decodable_via_ref! {
503512
&'tcx mir::Body<'tcx>,
504513
&'tcx mir::ConcreteOpaqueTypes<'tcx>,
505514
&'tcx ty::List<ty::BoundVariableKind>,
515+
&'tcx ty::List<ty::Pattern<'tcx>>,
506516
&'tcx ty::ListWithCachedTypeInfo<ty::Clause<'tcx>>,
507517
&'tcx ty::List<FieldIdx>,
508518
&'tcx ty::List<(VariantIdx, FieldIdx)>,

compiler/rustc_middle/src/ty/context.rs

+12
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
136136

137137
type AllocId = crate::mir::interpret::AllocId;
138138
type Pat = Pattern<'tcx>;
139+
type PatList = &'tcx List<Pattern<'tcx>>;
139140
type Safety = hir::Safety;
140141
type Abi = ExternAbi;
141142
type Const = ty::Const<'tcx>;
@@ -843,6 +844,7 @@ pub struct CtxtInterners<'tcx> {
843844
captures: InternedSet<'tcx, List<&'tcx ty::CapturedPlace<'tcx>>>,
844845
offset_of: InternedSet<'tcx, List<(VariantIdx, FieldIdx)>>,
845846
valtree: InternedSet<'tcx, ty::ValTreeKind<'tcx>>,
847+
patterns: InternedSet<'tcx, List<ty::Pattern<'tcx>>>,
846848
}
847849

848850
impl<'tcx> CtxtInterners<'tcx> {
@@ -879,6 +881,7 @@ impl<'tcx> CtxtInterners<'tcx> {
879881
captures: InternedSet::with_capacity(N),
880882
offset_of: InternedSet::with_capacity(N),
881883
valtree: InternedSet::with_capacity(N),
884+
patterns: InternedSet::with_capacity(N),
882885
}
883886
}
884887

@@ -2659,6 +2662,7 @@ slice_interners!(
26592662
local_def_ids: intern_local_def_ids(LocalDefId),
26602663
captures: intern_captures(&'tcx ty::CapturedPlace<'tcx>),
26612664
offset_of: pub mk_offset_of((VariantIdx, FieldIdx)),
2665+
patterns: pub mk_patterns(Pattern<'tcx>),
26622666
);
26632667

26642668
impl<'tcx> TyCtxt<'tcx> {
@@ -2932,6 +2936,14 @@ impl<'tcx> TyCtxt<'tcx> {
29322936
self.intern_local_def_ids(def_ids)
29332937
}
29342938

2939+
pub fn mk_patterns_from_iter<I, T>(self, iter: I) -> T::Output
2940+
where
2941+
I: Iterator<Item = T>,
2942+
T: CollectAndApply<ty::Pattern<'tcx>, &'tcx List<ty::Pattern<'tcx>>>,
2943+
{
2944+
T::collect_and_apply(iter, |xs| self.mk_patterns(xs))
2945+
}
2946+
29352947
pub fn mk_local_def_ids_from_iter<I, T>(self, iter: I) -> T::Output
29362948
where
29372949
I: Iterator<Item = T>,

compiler/rustc_middle/src/ty/pattern.rs

+27
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ impl<'tcx> Flags for Pattern<'tcx> {
2323
FlagComputation::for_const_kind(&start.kind()).flags
2424
| FlagComputation::for_const_kind(&end.kind()).flags
2525
}
26+
ty::PatternKind::Or(pats) => {
27+
let mut flags = pats[0].flags();
28+
for pat in pats[1..].iter() {
29+
flags |= pat.flags();
30+
}
31+
flags
32+
}
2633
}
2734
}
2835

@@ -31,6 +38,13 @@ impl<'tcx> Flags for Pattern<'tcx> {
3138
ty::PatternKind::Range { start, end } => {
3239
start.outer_exclusive_binder().max(end.outer_exclusive_binder())
3340
}
41+
ty::PatternKind::Or(pats) => {
42+
let mut idx = pats[0].outer_exclusive_binder();
43+
for pat in pats[1..].iter() {
44+
idx = idx.max(pat.outer_exclusive_binder());
45+
}
46+
idx
47+
}
3448
}
3549
}
3650
}
@@ -77,6 +91,19 @@ impl<'tcx> IrPrint<PatternKind<'tcx>> for TyCtxt<'tcx> {
7791

7892
write!(f, "..={end}")
7993
}
94+
PatternKind::Or(patterns) => {
95+
write!(f, "(")?;
96+
let mut first = true;
97+
for pat in patterns {
98+
if first {
99+
first = false
100+
} else {
101+
write!(f, " | ")?;
102+
}
103+
write!(f, "{pat:?}")?;
104+
}
105+
write!(f, ")")
106+
}
80107
}
81108
}
82109

compiler/rustc_middle/src/ty/relate.rs

+9
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ impl<'tcx> Relate<TyCtxt<'tcx>> for ty::Pattern<'tcx> {
5959
let end = relation.relate(end_a, end_b)?;
6060
Ok(tcx.mk_pat(ty::PatternKind::Range { start, end }))
6161
}
62+
(&ty::PatternKind::Or(a), &ty::PatternKind::Or(b)) => {
63+
if a.len() != b.len() {
64+
return Err(TypeError::Mismatch);
65+
}
66+
let v = iter::zip(a, b).map(|(a, b)| relation.relate(a, b));
67+
let patterns = tcx.mk_patterns_from_iter(v)?;
68+
Ok(tcx.mk_pat(ty::PatternKind::Or(patterns)))
69+
}
70+
(ty::PatternKind::Range { .. } | ty::PatternKind::Or(_), _) => Err(TypeError::Mismatch),
6271
}
6372
}
6473
}

0 commit comments

Comments
 (0)