From 96918a4bd574539974e05985c1c53cbcf0c69b50 Mon Sep 17 00:00:00 2001 From: Oli Scherer Date: Tue, 11 Mar 2025 15:27:35 +0000 Subject: [PATCH 1/6] Prevent pattern type macro invocations from having trailing tokens --- compiler/rustc_builtin_macros/src/pattern_type.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_builtin_macros/src/pattern_type.rs b/compiler/rustc_builtin_macros/src/pattern_type.rs index a55c7e962d098..a0b84c4ccf53a 100644 --- a/compiler/rustc_builtin_macros/src/pattern_type.rs +++ b/compiler/rustc_builtin_macros/src/pattern_type.rs @@ -1,6 +1,6 @@ use rustc_ast::ptr::P; use rustc_ast::tokenstream::TokenStream; -use rustc_ast::{AnonConst, DUMMY_NODE_ID, Ty, TyPat, TyPatKind, ast}; +use rustc_ast::{AnonConst, DUMMY_NODE_ID, Ty, TyPat, TyPatKind, ast, token}; use rustc_errors::PResult; use rustc_expand::base::{self, DummyResult, ExpandResult, ExtCtxt, MacroExpanderResult}; use rustc_parse::exp; @@ -40,5 +40,9 @@ fn parse_pat_ty<'a>(cx: &mut ExtCtxt<'a>, stream: TokenStream) -> PResult<'a, (P let pat = P(TyPat { id: pat.id, kind, span: pat.span, tokens: pat.tokens }); + if parser.token != token::Eof { + parser.unexpected()?; + } + Ok((ty, pat)) } From 6338e364f007a0fe4b0bf7dfb425000890336d8a Mon Sep 17 00:00:00 2001 From: Oli Scherer Date: Thu, 27 Feb 2025 09:46:46 +0000 Subject: [PATCH 2/6] Pull ast pattern type parsing out into a separate function --- .../rustc_builtin_macros/src/pattern_type.rs | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/pattern_type.rs b/compiler/rustc_builtin_macros/src/pattern_type.rs index a0b84c4ccf53a..e9d666939125f 100644 --- a/compiler/rustc_builtin_macros/src/pattern_type.rs +++ b/compiler/rustc_builtin_macros/src/pattern_type.rs @@ -26,9 +26,27 @@ fn parse_pat_ty<'a>(cx: &mut ExtCtxt<'a>, stream: TokenStream) -> PResult<'a, (P let ty = parser.parse_ty()?; parser.expect_keyword(exp!(Is))?; + + let start = parser.token.span; let pat = parser.parse_pat_no_top_alt(None, None)?.into_inner(); + let kind = pat_to_ty_pat(cx, pat); + + let span = start.to(parser.token.span); + let pat = ty_pat(kind, span); + + if parser.token != token::Eof { + parser.unexpected()?; + } + + Ok((ty, pat)) +} - let kind = match pat.kind { +fn ty_pat(kind: TyPatKind, span: Span) -> P { + P(TyPat { id: DUMMY_NODE_ID, kind, span, tokens: None }) +} + +fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> TyPatKind { + match pat.kind { ast::PatKind::Range(start, end, include_end) => TyPatKind::Range( start.map(|value| P(AnonConst { id: DUMMY_NODE_ID, value })), end.map(|value| P(AnonConst { id: DUMMY_NODE_ID, value })), @@ -36,13 +54,5 @@ fn parse_pat_ty<'a>(cx: &mut ExtCtxt<'a>, stream: TokenStream) -> PResult<'a, (P ), ast::PatKind::Err(guar) => TyPatKind::Err(guar), _ => TyPatKind::Err(cx.dcx().span_err(pat.span, "pattern not supported in pattern types")), - }; - - let pat = P(TyPat { id: pat.id, kind, span: pat.span, tokens: pat.tokens }); - - if parser.token != token::Eof { - parser.unexpected()?; } - - Ok((ty, pat)) } From 2134793792f0e53504076f037265ed6aefe12fd6 Mon Sep 17 00:00:00 2001 From: Oli Scherer Date: Thu, 27 Feb 2025 14:23:55 +0000 Subject: [PATCH 3/6] Directly generate TyPat instead of TyPatKind --- compiler/rustc_builtin_macros/src/pattern_type.rs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/pattern_type.rs b/compiler/rustc_builtin_macros/src/pattern_type.rs index e9d666939125f..aaf5b2336518f 100644 --- a/compiler/rustc_builtin_macros/src/pattern_type.rs +++ b/compiler/rustc_builtin_macros/src/pattern_type.rs @@ -27,12 +27,7 @@ fn parse_pat_ty<'a>(cx: &mut ExtCtxt<'a>, stream: TokenStream) -> PResult<'a, (P let ty = parser.parse_ty()?; parser.expect_keyword(exp!(Is))?; - let start = parser.token.span; - let pat = parser.parse_pat_no_top_alt(None, None)?.into_inner(); - let kind = pat_to_ty_pat(cx, pat); - - let span = start.to(parser.token.span); - let pat = ty_pat(kind, span); + let pat = pat_to_ty_pat(cx, parser.parse_pat_no_top_alt(None, None)?.into_inner()); if parser.token != token::Eof { parser.unexpected()?; @@ -45,8 +40,8 @@ fn ty_pat(kind: TyPatKind, span: Span) -> P { P(TyPat { id: DUMMY_NODE_ID, kind, span, tokens: None }) } -fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> TyPatKind { - match pat.kind { +fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> P { + let kind = match pat.kind { ast::PatKind::Range(start, end, include_end) => TyPatKind::Range( start.map(|value| P(AnonConst { id: DUMMY_NODE_ID, value })), end.map(|value| P(AnonConst { id: DUMMY_NODE_ID, value })), @@ -54,5 +49,6 @@ fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> TyPatKind { ), ast::PatKind::Err(guar) => TyPatKind::Err(guar), _ => TyPatKind::Err(cx.dcx().span_err(pat.span, "pattern not supported in pattern types")), - } + }; + ty_pat(kind, pat.span) } From 6008592fca37057a11d596da4bd74a944f383f35 Mon Sep 17 00:00:00 2001 From: Oli Scherer Date: Fri, 28 Feb 2025 14:48:40 +0000 Subject: [PATCH 4/6] Separate pattern lowering from pattern type lowering --- .../src/hir_ty_lowering/mod.rs | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs b/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs index 22162b8b3649a..6e1e650a8177f 100644 --- a/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs +++ b/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs @@ -2715,30 +2715,9 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ { hir::TyKind::Pat(ty, pat) => { let ty_span = ty.span; let ty = self.lower_ty(ty); - let pat_ty = match pat.kind { - hir::TyPatKind::Range(start, end) => { - let (ty, start, end) = match ty.kind() { - // Keep this list of types in sync with the list of types that - // the `RangePattern` trait is implemented for. - ty::Int(_) | ty::Uint(_) | ty::Char => { - let start = self.lower_const_arg(start, FeedConstTy::No); - let end = self.lower_const_arg(end, FeedConstTy::No); - (ty, start, end) - } - _ => { - let guar = self.dcx().span_delayed_bug( - ty_span, - "invalid base type for range pattern", - ); - let errc = ty::Const::new_error(tcx, guar); - (Ty::new_error(tcx, guar), errc, errc) - } - }; - - let pat = tcx.mk_pat(ty::PatternKind::Range { start, end }); - Ty::new_pat(tcx, ty, pat) - } - hir::TyPatKind::Err(e) => Ty::new_error(tcx, e), + let pat_ty = match self.lower_pat_ty_pat(ty, ty_span, pat) { + Ok(kind) => Ty::new_pat(tcx, ty, tcx.mk_pat(kind)), + Err(guar) => Ty::new_error(tcx, guar), }; self.record_ty(pat.hir_id, ty, pat.span); pat_ty @@ -2750,6 +2729,31 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ { result_ty } + fn lower_pat_ty_pat( + &self, + ty: Ty<'tcx>, + ty_span: Span, + pat: &hir::TyPat<'tcx>, + ) -> Result, ErrorGuaranteed> { + match pat.kind { + hir::TyPatKind::Range(start, end) => { + match ty.kind() { + // Keep this list of types in sync with the list of types that + // the `RangePattern` trait is implemented for. + ty::Int(_) | ty::Uint(_) | ty::Char => { + let start = self.lower_const_arg(start, FeedConstTy::No); + let end = self.lower_const_arg(end, FeedConstTy::No); + Ok(ty::PatternKind::Range { start, end }) + } + _ => Err(self + .dcx() + .span_delayed_bug(ty_span, "invalid base type for range pattern")), + } + } + hir::TyPatKind::Err(e) => Err(e), + } + } + /// Lower an opaque type (i.e., an existential impl-Trait type) from the HIR. #[instrument(level = "debug", skip(self), ret)] fn lower_opaque_ty(&self, def_id: LocalDefId, in_trait: bool) -> Ty<'tcx> { From cb6d3715a5447da8c291aa5d83133376e5811751 Mon Sep 17 00:00:00 2001 From: Oli Scherer Date: Thu, 27 Feb 2025 09:46:46 +0000 Subject: [PATCH 5/6] Split out various pattern type matches into their own function --- .../src/variance/constraints.rs | 21 +++-- compiler/rustc_lint/src/types.rs | 36 +++++---- compiler/rustc_middle/src/ty/relate.rs | 3 +- compiler/rustc_symbol_mangling/src/v0.rs | 33 ++++---- .../rustc_trait_selection/src/traits/wf.rs | 79 ++++++++++--------- compiler/rustc_type_ir/src/flags.rs | 6 +- compiler/rustc_type_ir/src/walk.rs | 16 ++-- 7 files changed, 115 insertions(+), 79 deletions(-) diff --git a/compiler/rustc_hir_analysis/src/variance/constraints.rs b/compiler/rustc_hir_analysis/src/variance/constraints.rs index 23223de918cfc..8123326a47f39 100644 --- a/compiler/rustc_hir_analysis/src/variance/constraints.rs +++ b/compiler/rustc_hir_analysis/src/variance/constraints.rs @@ -251,12 +251,7 @@ impl<'a, 'tcx> ConstraintContext<'a, 'tcx> { } ty::Pat(typ, pat) => { - match *pat { - ty::PatternKind::Range { start, end } => { - self.add_constraints_from_const(current, start, variance); - self.add_constraints_from_const(current, end, variance); - } - } + self.add_constraints_from_pat(current, variance, pat); self.add_constraints_from_ty(current, typ, variance); } @@ -334,6 +329,20 @@ impl<'a, 'tcx> ConstraintContext<'a, 'tcx> { } } + fn add_constraints_from_pat( + &mut self, + current: &CurrentItem, + variance: VarianceTermPtr<'a>, + pat: ty::Pattern<'tcx>, + ) { + match *pat { + ty::PatternKind::Range { start, end } => { + self.add_constraints_from_const(current, start, variance); + self.add_constraints_from_const(current, end, variance); + } + } + } + /// Adds constraints appropriate for a nominal type (enum, struct, /// object, etc) appearing in a context with ambient variance `variance` fn add_constraints_from_args( diff --git a/compiler/rustc_lint/src/types.rs b/compiler/rustc_lint/src/types.rs index f9dce5a5198d0..2511d3f6b163d 100644 --- a/compiler/rustc_lint/src/types.rs +++ b/compiler/rustc_lint/src/types.rs @@ -878,25 +878,33 @@ fn ty_is_known_nonnull<'tcx>( } ty::Pat(base, pat) => { ty_is_known_nonnull(tcx, typing_env, *base, mode) - || Option::unwrap_or_default( - try { - match **pat { - ty::PatternKind::Range { start, end } => { - let start = start.try_to_value()?.try_to_bits(tcx, typing_env)?; - let end = end.try_to_value()?.try_to_bits(tcx, typing_env)?; - - // This also works for negative numbers, as we just need - // to ensure we aren't wrapping over zero. - start > 0 && end >= start - } - } - }, - ) + || pat_ty_is_known_nonnull(tcx, typing_env, *pat) } _ => false, } } +fn pat_ty_is_known_nonnull<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + pat: ty::Pattern<'tcx>, +) -> bool { + Option::unwrap_or_default( + try { + match *pat { + ty::PatternKind::Range { start, end } => { + let start = start.try_to_value()?.try_to_bits(tcx, typing_env)?; + let end = end.try_to_value()?.try_to_bits(tcx, typing_env)?; + + // This also works for negative numbers, as we just need + // to ensure we aren't wrapping over zero. + start > 0 && end >= start + } + } + }, + ) +} + /// Given a non-null scalar (or transparent) type `ty`, return the nullable version of that type. /// If the type passed in was not scalar, returns None. fn get_nullable_type<'tcx>( diff --git a/compiler/rustc_middle/src/ty/relate.rs b/compiler/rustc_middle/src/ty/relate.rs index b1dfcb80bde57..c3ee72bcaed33 100644 --- a/compiler/rustc_middle/src/ty/relate.rs +++ b/compiler/rustc_middle/src/ty/relate.rs @@ -49,6 +49,7 @@ impl<'tcx> Relate> for ty::Pattern<'tcx> { a: Self, b: Self, ) -> RelateResult<'tcx, Self> { + let tcx = relation.cx(); match (&*a, &*b) { ( &ty::PatternKind::Range { start: start_a, end: end_a }, @@ -56,7 +57,7 @@ impl<'tcx> Relate> for ty::Pattern<'tcx> { ) => { let start = relation.relate(start_a, start_b)?; let end = relation.relate(end_a, end_b)?; - Ok(relation.cx().mk_pat(ty::PatternKind::Range { start, end })) + Ok(tcx.mk_pat(ty::PatternKind::Range { start, end })) } } } diff --git a/compiler/rustc_symbol_mangling/src/v0.rs b/compiler/rustc_symbol_mangling/src/v0.rs index f310aa6550025..6cde28d0ee90d 100644 --- a/compiler/rustc_symbol_mangling/src/v0.rs +++ b/compiler/rustc_symbol_mangling/src/v0.rs @@ -247,6 +247,17 @@ impl<'tcx> SymbolMangler<'tcx> { Ok(()) } + + fn print_pat(&mut self, pat: ty::Pattern<'tcx>) -> Result<(), std::fmt::Error> { + Ok(match *pat { + ty::PatternKind::Range { start, end } => { + let consts = [start, end]; + for ct in consts { + Ty::new_array_with_const_len(self.tcx, self.tcx.types.unit, ct).print(self)?; + } + } + }) + } } impl<'tcx> Printer<'tcx> for SymbolMangler<'tcx> { @@ -463,20 +474,14 @@ impl<'tcx> Printer<'tcx> for SymbolMangler<'tcx> { ty.print(self)?; } - ty::Pat(ty, pat) => match *pat { - ty::PatternKind::Range { start, end } => { - let consts = [start, end]; - // HACK: Represent as tuple until we have something better. - // HACK: constants are used in arrays, even if the types don't match. - self.push("T"); - ty.print(self)?; - for ct in consts { - Ty::new_array_with_const_len(self.tcx, self.tcx.types.unit, ct) - .print(self)?; - } - self.push("E"); - } - }, + ty::Pat(ty, pat) => { + // HACK: Represent as tuple until we have something better. + // HACK: constants are used in arrays, even if the types don't match. + self.push("T"); + ty.print(self)?; + self.print_pat(pat)?; + self.push("E"); + } ty::Array(ty, len) => { self.push("A"); diff --git a/compiler/rustc_trait_selection/src/traits/wf.rs b/compiler/rustc_trait_selection/src/traits/wf.rs index 62bd8e1af05de..0a6b3b2999085 100644 --- a/compiler/rustc_trait_selection/src/traits/wf.rs +++ b/compiler/rustc_trait_selection/src/traits/wf.rs @@ -658,6 +658,45 @@ impl<'a, 'tcx> WfPredicates<'a, 'tcx> { // ``` } } + + fn add_wf_preds_for_pat_ty(&mut self, base_ty: Ty<'tcx>, pat: ty::Pattern<'tcx>) { + let tcx = self.tcx(); + match *pat { + ty::PatternKind::Range { start, end } => { + let mut check = |c| { + let cause = self.cause(ObligationCauseCode::Misc); + self.out.push(traits::Obligation::with_depth( + tcx, + cause.clone(), + self.recursion_depth, + self.param_env, + ty::Binder::dummy(ty::PredicateKind::Clause( + ty::ClauseKind::ConstArgHasType(c, base_ty), + )), + )); + if !tcx.features().generic_pattern_types() { + if c.has_param() { + if self.span.is_dummy() { + self.tcx() + .dcx() + .delayed_bug("feature error should be reported elsewhere, too"); + } else { + feature_err( + &self.tcx().sess, + sym::generic_pattern_types, + self.span, + "wraparound pattern type ranges cause monomorphization time errors", + ) + .emit(); + } + } + } + }; + check(start); + check(end); + } + } + } } impl<'a, 'tcx> TypeVisitor> for WfPredicates<'a, 'tcx> { @@ -710,43 +749,9 @@ impl<'a, 'tcx> TypeVisitor> for WfPredicates<'a, 'tcx> { )); } - ty::Pat(subty, pat) => { - self.require_sized(subty, ObligationCauseCode::Misc); - match *pat { - ty::PatternKind::Range { start, end } => { - let mut check = |c| { - let cause = self.cause(ObligationCauseCode::Misc); - self.out.push(traits::Obligation::with_depth( - tcx, - cause.clone(), - self.recursion_depth, - self.param_env, - ty::Binder::dummy(ty::PredicateKind::Clause( - ty::ClauseKind::ConstArgHasType(c, subty), - )), - )); - if !tcx.features().generic_pattern_types() { - if c.has_param() { - if self.span.is_dummy() { - self.tcx().dcx().delayed_bug( - "feature error should be reported elsewhere, too", - ); - } else { - feature_err( - &self.tcx().sess, - sym::generic_pattern_types, - self.span, - "wraparound pattern type ranges cause monomorphization time errors", - ) - .emit(); - } - } - } - }; - check(start); - check(end); - } - } + ty::Pat(base_ty, pat) => { + self.require_sized(base_ty, ObligationCauseCode::Misc); + self.add_wf_preds_for_pat_ty(base_ty, pat); } ty::Tuple(tys) => { diff --git a/compiler/rustc_type_ir/src/flags.rs b/compiler/rustc_type_ir/src/flags.rs index 74fb148a7cc30..f4ad68734a529 100644 --- a/compiler/rustc_type_ir/src/flags.rs +++ b/compiler/rustc_type_ir/src/flags.rs @@ -304,7 +304,7 @@ impl FlagComputation { ty::Pat(ty, pat) => { self.add_ty(ty); - self.add_flags(pat.flags()); + self.add_ty_pat(pat); } ty::Slice(tt) => self.add_ty(tt), @@ -338,6 +338,10 @@ impl FlagComputation { } } + fn add_ty_pat(&mut self, pat: ::Pat) { + self.add_flags(pat.flags()); + } + fn add_predicate(&mut self, binder: ty::Binder>) { self.bound_computation(binder, |computation, atom| computation.add_predicate_atom(atom)); } diff --git a/compiler/rustc_type_ir/src/walk.rs b/compiler/rustc_type_ir/src/walk.rs index 5683e1f1712c8..ebfb3f786e83a 100644 --- a/compiler/rustc_type_ir/src/walk.rs +++ b/compiler/rustc_type_ir/src/walk.rs @@ -89,12 +89,7 @@ fn push_inner(stack: &mut TypeWalkerStack, parent: I::GenericArg | ty::Foreign(..) => {} ty::Pat(ty, pat) => { - match pat.kind() { - ty::PatternKind::Range { start, end } => { - stack.push(end.into()); - stack.push(start.into()); - } - } + push_ty_pat::(stack, pat); stack.push(ty.into()); } ty::Array(ty, len) => { @@ -171,3 +166,12 @@ fn push_inner(stack: &mut TypeWalkerStack, parent: I::GenericArg }, } } + +fn push_ty_pat(stack: &mut TypeWalkerStack, pat: I::Pat) { + match pat.kind() { + ty::PatternKind::Range { start, end } => { + stack.push(end.into()); + stack.push(start.into()); + } + } +} From b023856f29743a288727d13d0d1044b8e0d3f9f3 Mon Sep 17 00:00:00 2001 From: Oli Scherer Date: Thu, 27 Feb 2025 09:46:46 +0000 Subject: [PATCH 6/6] Add or-patterns to pattern types --- compiler/rustc_ast/src/ast.rs | 2 + compiler/rustc_ast/src/mut_visit.rs | 1 + compiler/rustc_ast/src/visit.rs | 1 + compiler/rustc_ast_lowering/src/pat.rs | 5 + compiler/rustc_ast_pretty/src/pprust/state.rs | 11 ++ .../rustc_builtin_macros/src/pattern_type.rs | 16 ++- .../src/interpret/intrinsics.rs | 15 ++- .../src/interpret/validity.rs | 8 ++ compiler/rustc_hir/src/hir.rs | 3 + compiler/rustc_hir/src/intravisit.rs | 1 + .../rustc_hir_analysis/src/collect/type_of.rs | 8 +- .../src/hir_ty_lowering/mod.rs | 8 ++ .../src/variance/constraints.rs | 5 + compiler/rustc_hir_pretty/src/lib.rs | 13 ++ compiler/rustc_lint/src/types.rs | 25 +++- compiler/rustc_middle/src/ty/codec.rs | 10 ++ compiler/rustc_middle/src/ty/context.rs | 12 ++ compiler/rustc_middle/src/ty/pattern.rs | 27 ++++ compiler/rustc_middle/src/ty/relate.rs | 9 ++ .../rustc_middle/src/ty/structural_impls.rs | 1 + compiler/rustc_resolve/src/late.rs | 5 + .../rustc_smir/src/rustc_smir/convert/ty.rs | 1 + compiler/rustc_symbol_mangling/src/v0.rs | 5 + .../rustc_trait_selection/src/traits/wf.rs | 5 + compiler/rustc_ty_utils/src/layout.rs | 86 +++++++++++- compiler/rustc_type_ir/src/interner.rs | 7 + compiler/rustc_type_ir/src/pattern.rs | 1 + compiler/rustc_type_ir/src/walk.rs | 5 + .../clippy/clippy_utils/src/hir_utils.rs | 5 + src/tools/rustfmt/src/types.rs | 13 ++ tests/ui/type/pattern_types/or_patterns.rs | 45 +++++++ .../ui/type/pattern_types/or_patterns.stderr | 123 ++++++++++++++++++ .../type/pattern_types/or_patterns_invalid.rs | 26 ++++ .../pattern_types/or_patterns_invalid.stderr | 10 ++ 34 files changed, 504 insertions(+), 14 deletions(-) create mode 100644 tests/ui/type/pattern_types/or_patterns.rs create mode 100644 tests/ui/type/pattern_types/or_patterns.stderr create mode 100644 tests/ui/type/pattern_types/or_patterns_invalid.rs create mode 100644 tests/ui/type/pattern_types/or_patterns_invalid.stderr diff --git a/compiler/rustc_ast/src/ast.rs b/compiler/rustc_ast/src/ast.rs index 8986430141be4..9d216ef3dd849 100644 --- a/compiler/rustc_ast/src/ast.rs +++ b/compiler/rustc_ast/src/ast.rs @@ -2469,6 +2469,8 @@ pub enum TyPatKind { /// A range pattern (e.g., `1...2`, `1..2`, `1..`, `..2`, `1..=2`, `..=2`). Range(Option>, Option>, Spanned), + Or(ThinVec>), + /// Placeholder for a pattern that wasn't syntactically well formed in some way. Err(ErrorGuaranteed), } diff --git a/compiler/rustc_ast/src/mut_visit.rs b/compiler/rustc_ast/src/mut_visit.rs index 6aae2e481a59f..cd2293423dbb5 100644 --- a/compiler/rustc_ast/src/mut_visit.rs +++ b/compiler/rustc_ast/src/mut_visit.rs @@ -612,6 +612,7 @@ pub fn walk_ty_pat(vis: &mut T, ty: &mut P) { visit_opt(start, |c| vis.visit_anon_const(c)); visit_opt(end, |c| vis.visit_anon_const(c)); } + TyPatKind::Or(variants) => visit_thin_vec(variants, |p| vis.visit_ty_pat(p)), TyPatKind::Err(_) => {} } visit_lazy_tts(vis, tokens); diff --git a/compiler/rustc_ast/src/visit.rs b/compiler/rustc_ast/src/visit.rs index 79193fcec63a4..69a186c8cf1b7 100644 --- a/compiler/rustc_ast/src/visit.rs +++ b/compiler/rustc_ast/src/visit.rs @@ -608,6 +608,7 @@ pub fn walk_ty_pat<'a, V: Visitor<'a>>(visitor: &mut V, tp: &'a TyPat) -> V::Res visit_opt!(visitor, visit_anon_const, start); visit_opt!(visitor, visit_anon_const, end); } + TyPatKind::Or(variants) => walk_list!(visitor, visit_ty_pat, variants), TyPatKind::Err(_) => {} } V::Result::output() diff --git a/compiler/rustc_ast_lowering/src/pat.rs b/compiler/rustc_ast_lowering/src/pat.rs index f94d788a9b0e6..4a6929ef0117a 100644 --- a/compiler/rustc_ast_lowering/src/pat.rs +++ b/compiler/rustc_ast_lowering/src/pat.rs @@ -464,6 +464,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> { ) }), ), + TyPatKind::Or(variants) => { + hir::TyPatKind::Or(self.arena.alloc_from_iter( + variants.iter().map(|pat| self.lower_ty_pat_mut(pat, base_type)), + )) + } TyPatKind::Err(guar) => hir::TyPatKind::Err(*guar), }; diff --git a/compiler/rustc_ast_pretty/src/pprust/state.rs b/compiler/rustc_ast_pretty/src/pprust/state.rs index 62a50c73855fc..9411237421dbc 100644 --- a/compiler/rustc_ast_pretty/src/pprust/state.rs +++ b/compiler/rustc_ast_pretty/src/pprust/state.rs @@ -1162,6 +1162,17 @@ impl<'a> State<'a> { self.print_expr_anon_const(end, &[]); } } + rustc_ast::TyPatKind::Or(variants) => { + let mut first = true; + for pat in variants { + if first { + first = false + } else { + self.word(" | "); + } + self.print_ty_pat(pat); + } + } rustc_ast::TyPatKind::Err(_) => { self.popen(); self.word("/*ERROR*/"); diff --git a/compiler/rustc_builtin_macros/src/pattern_type.rs b/compiler/rustc_builtin_macros/src/pattern_type.rs index aaf5b2336518f..3529e5525fcd2 100644 --- a/compiler/rustc_builtin_macros/src/pattern_type.rs +++ b/compiler/rustc_builtin_macros/src/pattern_type.rs @@ -4,6 +4,7 @@ use rustc_ast::{AnonConst, DUMMY_NODE_ID, Ty, TyPat, TyPatKind, ast, token}; use rustc_errors::PResult; use rustc_expand::base::{self, DummyResult, ExpandResult, ExtCtxt, MacroExpanderResult}; use rustc_parse::exp; +use rustc_parse::parser::{CommaRecoveryMode, RecoverColon, RecoverComma}; use rustc_span::Span; pub(crate) fn expand<'cx>( @@ -27,7 +28,17 @@ fn parse_pat_ty<'a>(cx: &mut ExtCtxt<'a>, stream: TokenStream) -> PResult<'a, (P let ty = parser.parse_ty()?; parser.expect_keyword(exp!(Is))?; - let pat = pat_to_ty_pat(cx, parser.parse_pat_no_top_alt(None, None)?.into_inner()); + let pat = pat_to_ty_pat( + cx, + parser + .parse_pat_no_top_guard( + None, + RecoverComma::No, + RecoverColon::No, + CommaRecoveryMode::EitherTupleOrPipe, + )? + .into_inner(), + ); if parser.token != token::Eof { parser.unexpected()?; @@ -47,6 +58,9 @@ fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> P { end.map(|value| P(AnonConst { id: DUMMY_NODE_ID, value })), include_end, ), + ast::PatKind::Or(variants) => TyPatKind::Or( + variants.into_iter().map(|pat| pat_to_ty_pat(cx, pat.into_inner())).collect(), + ), ast::PatKind::Err(guar) => TyPatKind::Err(guar), _ => TyPatKind::Err(cx.dcx().span_err(pat.span, "pattern not supported in pattern types")), }; diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs index 40c63f2b250f5..97d066ffe3fac 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs @@ -61,16 +61,21 @@ pub(crate) fn eval_nullary_intrinsic<'tcx>( ensure_monomorphic_enough(tcx, tp_ty)?; ConstValue::from_u128(tcx.type_id_hash(tp_ty).as_u128()) } - sym::variant_count => match tp_ty.kind() { + sym::variant_count => match match tp_ty.kind() { + // Pattern types have the same number of variants as their base type. + // Even if we restrict e.g. which variants are valid, the variants are essentially just uninhabited. + // And `Result<(), !>` still has two variants according to `variant_count`. + ty::Pat(base, _) => *base, + _ => tp_ty, + } + .kind() + { // Correctly handles non-monomorphic calls, so there is no need for ensure_monomorphic_enough. ty::Adt(adt, _) => ConstValue::from_target_usize(adt.variants().len() as u64, &tcx), ty::Alias(..) | ty::Param(_) | ty::Placeholder(_) | ty::Infer(_) => { throw_inval!(TooGeneric) } - ty::Pat(_, pat) => match **pat { - ty::PatternKind::Range { .. } => ConstValue::from_target_usize(0u64, &tcx), - // Future pattern kinds may have more variants - }, + ty::Pat(..) => unreachable!(), ty::Bound(_, _) => bug!("bound ty during ctfe"), ty::Bool | ty::Char diff --git a/compiler/rustc_const_eval/src/interpret/validity.rs b/compiler/rustc_const_eval/src/interpret/validity.rs index fb7ba6d7ef57e..c86af5a9a4b14 100644 --- a/compiler/rustc_const_eval/src/interpret/validity.rs +++ b/compiler/rustc_const_eval/src/interpret/validity.rs @@ -1248,6 +1248,14 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt, // Range patterns are precisely reflected into `valid_range` and thus // handled fully by `visit_scalar` (called below). ty::PatternKind::Range { .. } => {}, + + // FIXME(pattern_types): check that the value is covered by one of the variants. + // For now, we rely on layout computation setting the scalar's `valid_range` to + // match the pattern. However, this cannot always work; the layout may + // pessimistically cover actually illegal ranges and Miri would miss that UB. + // The consolation here is that codegen also will miss that UB, so at least + // we won't see optimizations actually breaking such programs. + ty::PatternKind::Or(_patterns) => {} } } _ => { diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs index 2f8a853424789..af587ee5bdcdd 100644 --- a/compiler/rustc_hir/src/hir.rs +++ b/compiler/rustc_hir/src/hir.rs @@ -1813,6 +1813,9 @@ pub enum TyPatKind<'hir> { /// A range pattern (e.g., `1..=2` or `1..2`). Range(&'hir ConstArg<'hir>, &'hir ConstArg<'hir>), + /// A list of patterns where only one needs to be satisfied + Or(&'hir [TyPat<'hir>]), + /// A placeholder for a pattern that wasn't well formed in some way. Err(ErrorGuaranteed), } diff --git a/compiler/rustc_hir/src/intravisit.rs b/compiler/rustc_hir/src/intravisit.rs index 3c2897ef1d953..a60de4b1fc31b 100644 --- a/compiler/rustc_hir/src/intravisit.rs +++ b/compiler/rustc_hir/src/intravisit.rs @@ -710,6 +710,7 @@ pub fn walk_ty_pat<'v, V: Visitor<'v>>(visitor: &mut V, pattern: &'v TyPat<'v>) try_visit!(visitor.visit_const_arg_unambig(lower_bound)); try_visit!(visitor.visit_const_arg_unambig(upper_bound)); } + TyPatKind::Or(patterns) => walk_list!(visitor, visit_pattern_type_pattern, patterns), TyPatKind::Err(_) => (), } V::Result::output() diff --git a/compiler/rustc_hir_analysis/src/collect/type_of.rs b/compiler/rustc_hir_analysis/src/collect/type_of.rs index 694c12288596f..c20b14df7704b 100644 --- a/compiler/rustc_hir_analysis/src/collect/type_of.rs +++ b/compiler/rustc_hir_analysis/src/collect/type_of.rs @@ -94,10 +94,12 @@ fn const_arg_anon_type_of<'tcx>(icx: &ItemCtxt<'tcx>, arg_hir_id: HirId, span: S } Node::TyPat(pat) => { - let hir::TyKind::Pat(ty, p) = tcx.parent_hir_node(pat.hir_id).expect_ty().kind else { - bug!() + let node = match tcx.parent_hir_node(pat.hir_id) { + // Or patterns can be nested one level deep + Node::TyPat(p) => tcx.parent_hir_node(p.hir_id), + other => other, }; - assert_eq!(p.hir_id, pat.hir_id); + let hir::TyKind::Pat(ty, _) = node.expect_ty().kind else { bug!() }; icx.lower_ty(ty) } diff --git a/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs b/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs index 6e1e650a8177f..53792c7b09312 100644 --- a/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs +++ b/compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs @@ -2735,6 +2735,7 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ { ty_span: Span, pat: &hir::TyPat<'tcx>, ) -> Result, ErrorGuaranteed> { + let tcx = self.tcx(); match pat.kind { hir::TyPatKind::Range(start, end) => { match ty.kind() { @@ -2750,6 +2751,13 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ { .span_delayed_bug(ty_span, "invalid base type for range pattern")), } } + hir::TyPatKind::Or(patterns) => { + self.tcx() + .mk_patterns_from_iter(patterns.iter().map(|pat| { + self.lower_pat_ty_pat(ty, ty_span, pat).map(|pat| tcx.mk_pat(pat)) + })) + .map(ty::PatternKind::Or) + } hir::TyPatKind::Err(e) => Err(e), } } diff --git a/compiler/rustc_hir_analysis/src/variance/constraints.rs b/compiler/rustc_hir_analysis/src/variance/constraints.rs index 8123326a47f39..ff73d163f84da 100644 --- a/compiler/rustc_hir_analysis/src/variance/constraints.rs +++ b/compiler/rustc_hir_analysis/src/variance/constraints.rs @@ -340,6 +340,11 @@ impl<'a, 'tcx> ConstraintContext<'a, 'tcx> { self.add_constraints_from_const(current, start, variance); self.add_constraints_from_const(current, end, variance); } + ty::PatternKind::Or(patterns) => { + for pat in patterns { + self.add_constraints_from_pat(current, variance, pat) + } + } } } diff --git a/compiler/rustc_hir_pretty/src/lib.rs b/compiler/rustc_hir_pretty/src/lib.rs index c95d6a277c71f..37c7e613b2c9f 100644 --- a/compiler/rustc_hir_pretty/src/lib.rs +++ b/compiler/rustc_hir_pretty/src/lib.rs @@ -1866,6 +1866,19 @@ impl<'a> State<'a> { self.word("..="); self.print_const_arg(end); } + TyPatKind::Or(patterns) => { + self.popen(); + let mut first = true; + for pat in patterns { + if first { + first = false; + } else { + self.word(" | "); + } + self.print_ty_pat(pat); + } + self.pclose(); + } TyPatKind::Err(_) => { self.popen(); self.word("/*ERROR*/"); diff --git a/compiler/rustc_lint/src/types.rs b/compiler/rustc_lint/src/types.rs index 2511d3f6b163d..42194bac3dc71 100644 --- a/compiler/rustc_lint/src/types.rs +++ b/compiler/rustc_lint/src/types.rs @@ -900,6 +900,9 @@ fn pat_ty_is_known_nonnull<'tcx>( // to ensure we aren't wrapping over zero. start > 0 && end >= start } + ty::PatternKind::Or(patterns) => { + patterns.iter().all(|pat| pat_ty_is_known_nonnull(tcx, typing_env, pat)) + } } }, ) @@ -1046,13 +1049,29 @@ pub(crate) fn repr_nullable_ptr<'tcx>( } None } - ty::Pat(base, pat) => match **pat { - ty::PatternKind::Range { .. } => get_nullable_type(tcx, typing_env, *base), - }, + ty::Pat(base, pat) => get_nullable_type_from_pat(tcx, typing_env, *base, *pat), _ => None, } } +fn get_nullable_type_from_pat<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + base: Ty<'tcx>, + pat: ty::Pattern<'tcx>, +) -> Option> { + match *pat { + ty::PatternKind::Range { .. } => get_nullable_type(tcx, typing_env, base), + ty::PatternKind::Or(patterns) => { + let first = get_nullable_type_from_pat(tcx, typing_env, base, patterns[0])?; + for &pat in &patterns[1..] { + assert_eq!(first, get_nullable_type_from_pat(tcx, typing_env, base, pat)?); + } + Some(first) + } + } +} + impl<'a, 'tcx> ImproperCTypesVisitor<'a, 'tcx> { /// Check if the type is array and emit an unsafe type lint. fn check_for_array_ty(&mut self, sp: Span, ty: Ty<'tcx>) -> bool { diff --git a/compiler/rustc_middle/src/ty/codec.rs b/compiler/rustc_middle/src/ty/codec.rs index 23927c112bcd4..5ff87959a800d 100644 --- a/compiler/rustc_middle/src/ty/codec.rs +++ b/compiler/rustc_middle/src/ty/codec.rs @@ -442,6 +442,15 @@ impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List> RefDecodable<'tcx, D> for ty::List> { + fn decode(decoder: &mut D) -> &'tcx Self { + let len = decoder.read_usize(); + decoder.interner().mk_patterns_from_iter( + (0..len).map::, _>(|_| Decodable::decode(decoder)), + ) + } +} + impl<'tcx, D: TyDecoder<'tcx>> RefDecodable<'tcx, D> for ty::List> { fn decode(decoder: &mut D) -> &'tcx Self { let len = decoder.read_usize(); @@ -503,6 +512,7 @@ impl_decodable_via_ref! { &'tcx mir::Body<'tcx>, &'tcx mir::ConcreteOpaqueTypes<'tcx>, &'tcx ty::List, + &'tcx ty::List>, &'tcx ty::ListWithCachedTypeInfo>, &'tcx ty::List, &'tcx ty::List<(VariantIdx, FieldIdx)>, diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index e8dad1e056cbd..f97e13f13b006 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -136,6 +136,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> { type AllocId = crate::mir::interpret::AllocId; type Pat = Pattern<'tcx>; + type PatList = &'tcx List>; type Safety = hir::Safety; type Abi = ExternAbi; type Const = ty::Const<'tcx>; @@ -843,6 +844,7 @@ pub struct CtxtInterners<'tcx> { captures: InternedSet<'tcx, List<&'tcx ty::CapturedPlace<'tcx>>>, offset_of: InternedSet<'tcx, List<(VariantIdx, FieldIdx)>>, valtree: InternedSet<'tcx, ty::ValTreeKind<'tcx>>, + patterns: InternedSet<'tcx, List>>, } impl<'tcx> CtxtInterners<'tcx> { @@ -879,6 +881,7 @@ impl<'tcx> CtxtInterners<'tcx> { captures: InternedSet::with_capacity(N), offset_of: InternedSet::with_capacity(N), valtree: InternedSet::with_capacity(N), + patterns: InternedSet::with_capacity(N), } } @@ -2659,6 +2662,7 @@ slice_interners!( local_def_ids: intern_local_def_ids(LocalDefId), captures: intern_captures(&'tcx ty::CapturedPlace<'tcx>), offset_of: pub mk_offset_of((VariantIdx, FieldIdx)), + patterns: pub mk_patterns(Pattern<'tcx>), ); impl<'tcx> TyCtxt<'tcx> { @@ -2932,6 +2936,14 @@ impl<'tcx> TyCtxt<'tcx> { self.intern_local_def_ids(def_ids) } + pub fn mk_patterns_from_iter(self, iter: I) -> T::Output + where + I: Iterator, + T: CollectAndApply, &'tcx List>>, + { + T::collect_and_apply(iter, |xs| self.mk_patterns(xs)) + } + pub fn mk_local_def_ids_from_iter(self, iter: I) -> T::Output where I: Iterator, diff --git a/compiler/rustc_middle/src/ty/pattern.rs b/compiler/rustc_middle/src/ty/pattern.rs index 758adc42e3ebb..5af9b17dd7777 100644 --- a/compiler/rustc_middle/src/ty/pattern.rs +++ b/compiler/rustc_middle/src/ty/pattern.rs @@ -23,6 +23,13 @@ impl<'tcx> Flags for Pattern<'tcx> { FlagComputation::for_const_kind(&start.kind()).flags | FlagComputation::for_const_kind(&end.kind()).flags } + ty::PatternKind::Or(pats) => { + let mut flags = pats[0].flags(); + for pat in pats[1..].iter() { + flags |= pat.flags(); + } + flags + } } } @@ -31,6 +38,13 @@ impl<'tcx> Flags for Pattern<'tcx> { ty::PatternKind::Range { start, end } => { start.outer_exclusive_binder().max(end.outer_exclusive_binder()) } + ty::PatternKind::Or(pats) => { + let mut idx = pats[0].outer_exclusive_binder(); + for pat in pats[1..].iter() { + idx = idx.max(pat.outer_exclusive_binder()); + } + idx + } } } } @@ -77,6 +91,19 @@ impl<'tcx> IrPrint> for TyCtxt<'tcx> { write!(f, "..={end}") } + PatternKind::Or(patterns) => { + write!(f, "(")?; + let mut first = true; + for pat in patterns { + if first { + first = false + } else { + write!(f, " | ")?; + } + write!(f, "{pat:?}")?; + } + write!(f, ")") + } } } diff --git a/compiler/rustc_middle/src/ty/relate.rs b/compiler/rustc_middle/src/ty/relate.rs index c3ee72bcaed33..6ad4e5276b253 100644 --- a/compiler/rustc_middle/src/ty/relate.rs +++ b/compiler/rustc_middle/src/ty/relate.rs @@ -59,6 +59,15 @@ impl<'tcx> Relate> for ty::Pattern<'tcx> { let end = relation.relate(end_a, end_b)?; Ok(tcx.mk_pat(ty::PatternKind::Range { start, end })) } + (&ty::PatternKind::Or(a), &ty::PatternKind::Or(b)) => { + if a.len() != b.len() { + return Err(TypeError::Mismatch); + } + let v = iter::zip(a, b).map(|(a, b)| relation.relate(a, b)); + let patterns = tcx.mk_patterns_from_iter(v)?; + Ok(tcx.mk_pat(ty::PatternKind::Or(patterns))) + } + (ty::PatternKind::Range { .. } | ty::PatternKind::Or(_), _) => Err(TypeError::Mismatch), } } } diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs index 26861666c1db2..2fcb2a1572aed 100644 --- a/compiler/rustc_middle/src/ty/structural_impls.rs +++ b/compiler/rustc_middle/src/ty/structural_impls.rs @@ -779,5 +779,6 @@ list_fold! { ty::Clauses<'tcx> : mk_clauses, &'tcx ty::List> : mk_poly_existential_predicates, &'tcx ty::List> : mk_place_elems, + &'tcx ty::List> : mk_patterns, CanonicalVarInfos<'tcx> : mk_canonical_var_infos, } diff --git a/compiler/rustc_resolve/src/late.rs b/compiler/rustc_resolve/src/late.rs index bae2fdeecafcc..faee0e7dd5ff9 100644 --- a/compiler/rustc_resolve/src/late.rs +++ b/compiler/rustc_resolve/src/late.rs @@ -958,6 +958,11 @@ impl<'ra: 'ast, 'ast, 'tcx> Visitor<'ast> for LateResolutionVisitor<'_, 'ast, 'r self.resolve_anon_const(end, AnonConstKind::ConstArg(IsRepeatExpr::No)); } } + TyPatKind::Or(patterns) => { + for pat in patterns { + self.visit_ty_pat(pat) + } + } TyPatKind::Err(_) => {} } } diff --git a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs index c0ed3b90eb416..1c33101da357b 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/ty.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/ty.rs @@ -412,6 +412,7 @@ impl<'tcx> Stable<'tcx> for ty::Pattern<'tcx> { end: Some(end.stable(tables)), include_end: true, }, + ty::PatternKind::Or(_) => todo!(), } } } diff --git a/compiler/rustc_symbol_mangling/src/v0.rs b/compiler/rustc_symbol_mangling/src/v0.rs index 6cde28d0ee90d..cc608736b672b 100644 --- a/compiler/rustc_symbol_mangling/src/v0.rs +++ b/compiler/rustc_symbol_mangling/src/v0.rs @@ -256,6 +256,11 @@ impl<'tcx> SymbolMangler<'tcx> { Ty::new_array_with_const_len(self.tcx, self.tcx.types.unit, ct).print(self)?; } } + ty::PatternKind::Or(patterns) => { + for pat in patterns { + self.print_pat(pat)?; + } + } }) } } diff --git a/compiler/rustc_trait_selection/src/traits/wf.rs b/compiler/rustc_trait_selection/src/traits/wf.rs index 0a6b3b2999085..6e85d3b5b1363 100644 --- a/compiler/rustc_trait_selection/src/traits/wf.rs +++ b/compiler/rustc_trait_selection/src/traits/wf.rs @@ -695,6 +695,11 @@ impl<'a, 'tcx> WfPredicates<'a, 'tcx> { check(start); check(end); } + ty::PatternKind::Or(patterns) => { + for pat in patterns { + self.add_wf_preds_for_pat_ty(base_ty, pat) + } + } } } } diff --git a/compiler/rustc_ty_utils/src/layout.rs b/compiler/rustc_ty_utils/src/layout.rs index 1915ba623cb2f..b962979bec71c 100644 --- a/compiler/rustc_ty_utils/src/layout.rs +++ b/compiler/rustc_ty_utils/src/layout.rs @@ -255,13 +255,95 @@ fn layout_of_uncached<'tcx>( }; layout.largest_niche = Some(niche); - - tcx.mk_layout(layout) } else { bug!("pattern type with range but not scalar layout: {ty:?}, {layout:?}") } } + ty::PatternKind::Or(variants) => match *variants[0] { + ty::PatternKind::Range { .. } => { + if let BackendRepr::Scalar(scalar) = &mut layout.backend_repr { + let variants: Result, _> = variants + .iter() + .map(|pat| match *pat { + ty::PatternKind::Range { start, end } => Ok(( + extract_const_value(cx, ty, start) + .unwrap() + .try_to_bits(tcx, cx.typing_env) + .ok_or_else(|| error(cx, LayoutError::Unknown(ty)))?, + extract_const_value(cx, ty, end) + .unwrap() + .try_to_bits(tcx, cx.typing_env) + .ok_or_else(|| error(cx, LayoutError::Unknown(ty)))?, + )), + ty::PatternKind::Or(_) => { + unreachable!("mixed or patterns are not allowed") + } + }) + .collect(); + let mut variants = variants?; + if !scalar.is_signed() { + let guar = tcx.dcx().err(format!( + "only signed integer base types are allowed for or-pattern pattern types at present" + )); + + return Err(error(cx, LayoutError::ReferencesError(guar))); + } + variants.sort(); + if variants.len() != 2 { + let guar = tcx + .dcx() + .err(format!("the only or-pattern types allowed are two range patterns that are directly connected at their overflow site")); + + return Err(error(cx, LayoutError::ReferencesError(guar))); + } + + // first is the one starting at the signed in range min + let mut first = variants[0]; + let mut second = variants[1]; + if second.0 + == layout.size.truncate(layout.size.signed_int_min() as u128) + { + (second, first) = (first, second); + } + + if layout.size.sign_extend(first.1) >= layout.size.sign_extend(second.0) + { + let guar = tcx.dcx().err(format!( + "only non-overlapping pattern type ranges are allowed at present" + )); + + return Err(error(cx, LayoutError::ReferencesError(guar))); + } + if layout.size.signed_int_max() as u128 != second.1 { + let guar = tcx.dcx().err(format!( + "one pattern needs to end at `{ty}::MAX`, but was {} instead", + second.1 + )); + + return Err(error(cx, LayoutError::ReferencesError(guar))); + } + + // Now generate a wrapping range (which aren't allowed in surface syntax). + scalar.valid_range_mut().start = second.0; + scalar.valid_range_mut().end = first.1; + + let niche = Niche { + offset: Size::ZERO, + value: scalar.primitive(), + valid_range: scalar.valid_range(cx), + }; + + layout.largest_niche = Some(niche); + } else { + bug!( + "pattern type with range but not scalar layout: {ty:?}, {layout:?}" + ) + } + } + ty::PatternKind::Or(..) => bug!("patterns cannot have nested or patterns"), + }, } + tcx.mk_layout(layout) } // Basic scalars. diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 9758cecaf6ac7..6410da1f7409f 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -113,6 +113,13 @@ pub trait Interner: + Relate + Flags + IntoKind>; + type PatList: Copy + + Debug + + Hash + + Default + + Eq + + TypeVisitable + + SliceLike; type Safety: Safety; type Abi: Abi; diff --git a/compiler/rustc_type_ir/src/pattern.rs b/compiler/rustc_type_ir/src/pattern.rs index d74a82da1f92a..7e56565917c67 100644 --- a/compiler/rustc_type_ir/src/pattern.rs +++ b/compiler/rustc_type_ir/src/pattern.rs @@ -13,4 +13,5 @@ use crate::Interner; )] pub enum PatternKind { Range { start: I::Const, end: I::Const }, + Or(I::PatList), } diff --git a/compiler/rustc_type_ir/src/walk.rs b/compiler/rustc_type_ir/src/walk.rs index ebfb3f786e83a..737550eb73e99 100644 --- a/compiler/rustc_type_ir/src/walk.rs +++ b/compiler/rustc_type_ir/src/walk.rs @@ -173,5 +173,10 @@ fn push_ty_pat(stack: &mut TypeWalkerStack, pat: I::Pat) { stack.push(end.into()); stack.push(start.into()); } + ty::PatternKind::Or(pats) => { + for pat in pats.iter() { + push_ty_pat::(stack, pat) + } + } } } diff --git a/src/tools/clippy/clippy_utils/src/hir_utils.rs b/src/tools/clippy/clippy_utils/src/hir_utils.rs index fe1fd70a9fa78..17368a7530d75 100644 --- a/src/tools/clippy/clippy_utils/src/hir_utils.rs +++ b/src/tools/clippy/clippy_utils/src/hir_utils.rs @@ -1117,6 +1117,11 @@ impl<'a, 'tcx> SpanlessHash<'a, 'tcx> { self.hash_const_arg(s); self.hash_const_arg(e); }, + TyPatKind::Or(variants) => { + for variant in variants.iter() { + self.hash_ty_pat(variant) + } + }, TyPatKind::Err(_) => {}, } } diff --git a/src/tools/rustfmt/src/types.rs b/src/tools/rustfmt/src/types.rs index 75a5a8532b84f..7ec1032dcb421 100644 --- a/src/tools/rustfmt/src/types.rs +++ b/src/tools/rustfmt/src/types.rs @@ -1093,6 +1093,19 @@ impl Rewrite for ast::TyPat { ast::TyPatKind::Range(ref lhs, ref rhs, ref end_kind) => { rewrite_range_pat(context, shape, lhs, rhs, end_kind, self.span) } + ast::TyPatKind::Or(ref variants) => { + let mut first = true; + let mut s = String::new(); + for variant in variants { + if first { + first = false + } else { + s.push_str(" | "); + } + s.push_str(&variant.rewrite_result(context, shape)?); + } + Ok(s) + } ast::TyPatKind::Err(_) => Err(RewriteError::Unknown), } } diff --git a/tests/ui/type/pattern_types/or_patterns.rs b/tests/ui/type/pattern_types/or_patterns.rs new file mode 100644 index 0000000000000..25cb1867047aa --- /dev/null +++ b/tests/ui/type/pattern_types/or_patterns.rs @@ -0,0 +1,45 @@ +//! Demonstrate some use cases of or patterns + +//@ normalize-stderr: "pref: Align\([1-8] bytes\)" -> "pref: $$SOME_ALIGN" +//@ normalize-stderr: "randomization_seed: \d+" -> "randomization_seed: $$SEED" + +#![feature( + pattern_type_macro, + pattern_types, + rustc_attrs, + const_trait_impl, + pattern_type_range_trait +)] + +use std::pat::pattern_type; + +#[rustc_layout(debug)] +type NonNullI8 = pattern_type!(i8 is ..0 | 1..); +//~^ ERROR: layout_of + +#[rustc_layout(debug)] +type NonNegOneI8 = pattern_type!(i8 is ..-1 | 0..); +//~^ ERROR: layout_of + +fn main() { + let _: NonNullI8 = 42; + let _: NonNullI8 = 1; + let _: NonNullI8 = 0; + //~^ ERROR: mismatched types + let _: NonNullI8 = -1; + //~^ ERROR: cannot apply unary operator + let _: NonNullI8 = -128; + //~^ ERROR: cannot apply unary operator + let _: NonNullI8 = 127; + + let _: NonNegOneI8 = 42; + let _: NonNegOneI8 = 1; + let _: NonNegOneI8 = 0; + let _: NonNegOneI8 = -1; + //~^ ERROR: cannot apply unary operator + let _: NonNegOneI8 = -2; + //~^ ERROR: cannot apply unary operator + let _: NonNegOneI8 = -128; + //~^ ERROR: cannot apply unary operator + let _: NonNegOneI8 = 127; +} diff --git a/tests/ui/type/pattern_types/or_patterns.stderr b/tests/ui/type/pattern_types/or_patterns.stderr new file mode 100644 index 0000000000000..58ca585f4a9a3 --- /dev/null +++ b/tests/ui/type/pattern_types/or_patterns.stderr @@ -0,0 +1,123 @@ +error[E0308]: mismatched types + --> $DIR/or_patterns.rs:27:24 + | +LL | let _: NonNullI8 = 0; + | --------- ^ expected `(i8) is (i8::MIN..=-1 | 1..)`, found integer + | | + | expected due to this + | + = note: expected pattern type `(i8) is (i8::MIN..=-1 | 1..)` + found type `{integer}` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-1 | 1..)` + --> $DIR/or_patterns.rs:29:24 + | +LL | let _: NonNullI8 = -1; + | ^^ cannot apply unary operator `-` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-1 | 1..)` + --> $DIR/or_patterns.rs:31:24 + | +LL | let _: NonNullI8 = -128; + | ^^^^ cannot apply unary operator `-` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-2 | 0..)` + --> $DIR/or_patterns.rs:38:26 + | +LL | let _: NonNegOneI8 = -1; + | ^^ cannot apply unary operator `-` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-2 | 0..)` + --> $DIR/or_patterns.rs:40:26 + | +LL | let _: NonNegOneI8 = -2; + | ^^ cannot apply unary operator `-` + +error[E0600]: cannot apply unary operator `-` to type `(i8) is (i8::MIN..=-2 | 0..)` + --> $DIR/or_patterns.rs:42:26 + | +LL | let _: NonNegOneI8 = -128; + | ^^^^ cannot apply unary operator `-` + +error: layout_of((i8) is (i8::MIN..=-1 | 1..)) = Layout { + size: Size(1 bytes), + align: AbiAndPrefAlign { + abi: Align(1 bytes), + pref: $SOME_ALIGN, + }, + backend_repr: Scalar( + Initialized { + value: Int( + I8, + true, + ), + valid_range: 1..=255, + }, + ), + fields: Primitive, + largest_niche: Some( + Niche { + offset: Size(0 bytes), + value: Int( + I8, + true, + ), + valid_range: 1..=255, + }, + ), + uninhabited: false, + variants: Single { + index: 0, + }, + max_repr_align: None, + unadjusted_abi_align: Align(1 bytes), + randomization_seed: $SEED, + } + --> $DIR/or_patterns.rs:17:1 + | +LL | type NonNullI8 = pattern_type!(i8 is ..0 | 1..); + | ^^^^^^^^^^^^^^ + +error: layout_of((i8) is (i8::MIN..=-2 | 0..)) = Layout { + size: Size(1 bytes), + align: AbiAndPrefAlign { + abi: Align(1 bytes), + pref: $SOME_ALIGN, + }, + backend_repr: Scalar( + Initialized { + value: Int( + I8, + true, + ), + valid_range: 0..=254, + }, + ), + fields: Primitive, + largest_niche: Some( + Niche { + offset: Size(0 bytes), + value: Int( + I8, + true, + ), + valid_range: 0..=254, + }, + ), + uninhabited: false, + variants: Single { + index: 0, + }, + max_repr_align: None, + unadjusted_abi_align: Align(1 bytes), + randomization_seed: $SEED, + } + --> $DIR/or_patterns.rs:21:1 + | +LL | type NonNegOneI8 = pattern_type!(i8 is ..-1 | 0..); + | ^^^^^^^^^^^^^^^^ + +error: aborting due to 8 previous errors + +Some errors have detailed explanations: E0308, E0600. +For more information about an error, try `rustc --explain E0308`. diff --git a/tests/ui/type/pattern_types/or_patterns_invalid.rs b/tests/ui/type/pattern_types/or_patterns_invalid.rs new file mode 100644 index 0000000000000..d341927601d50 --- /dev/null +++ b/tests/ui/type/pattern_types/or_patterns_invalid.rs @@ -0,0 +1,26 @@ +//! Demonstrate some use cases of or patterns + +#![feature( + pattern_type_macro, + pattern_types, + rustc_attrs, + const_trait_impl, + pattern_type_range_trait +)] + +use std::pat::pattern_type; + +fn main() { + //~? ERROR: only non-overlapping pattern type ranges are allowed at present + let not_adjacent: pattern_type!(i8 is -127..0 | 1..) = unsafe { std::mem::transmute(0) }; + + //~? ERROR: one pattern needs to end at `i8::MAX`, but was 29 instead + let not_wrapping: pattern_type!(i8 is 10..20 | 20..30) = unsafe { std::mem::transmute(0) }; + + //~? ERROR: only signed integer base types are allowed for or-pattern pattern types + let not_signed: pattern_type!(u8 is 10.. | 0..5) = unsafe { std::mem::transmute(0) }; + + //~? ERROR: allowed are two range patterns that are directly connected + let not_simple_enough_for_mvp: pattern_type!(i8 is ..0 | 1..10 | 10..) = + unsafe { std::mem::transmute(0) }; +} diff --git a/tests/ui/type/pattern_types/or_patterns_invalid.stderr b/tests/ui/type/pattern_types/or_patterns_invalid.stderr new file mode 100644 index 0000000000000..6964788a6c245 --- /dev/null +++ b/tests/ui/type/pattern_types/or_patterns_invalid.stderr @@ -0,0 +1,10 @@ +error: only non-overlapping pattern type ranges are allowed at present + +error: one pattern needs to end at `i8::MAX`, but was 29 instead + +error: only signed integer base types are allowed for or-pattern pattern types at present + +error: the only or-pattern types allowed are two range patterns that are directly connected at their overflow site + +error: aborting due to 4 previous errors +