diff --git a/compiler/formatter/src/format.rs b/compiler/formatter/src/format.rs index 4ba025454..cba3180e2 100644 --- a/compiler/formatter/src/format.rs +++ b/compiler/formatter/src/format.rs @@ -742,6 +742,7 @@ pub fn format_cst<'a>( } CstKind::MatchCase { pattern, + condition: _, // TODO: format match case conditions arrow, body, } => { diff --git a/compiler/frontend/src/ast.rs b/compiler/frontend/src/ast.rs index 18da8c865..2fe9e3179 100644 --- a/compiler/frontend/src/ast.rs +++ b/compiler/frontend/src/ast.rs @@ -131,6 +131,7 @@ pub struct Match { #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct MatchCase { pub pattern: Box, + pub condition: Option>, pub body: Vec, } #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -248,7 +249,14 @@ impl FindAst for Match { } impl FindAst for MatchCase { fn find(&self, id: &Id) -> Option<&Ast> { - self.pattern.find(id).or_else(|| self.body.find(id)) + self.pattern + .find(id) + .or_else(|| { + self.condition + .as_ref() + .and_then(|box condition| condition.find(id)) + }) + .or_else(|| self.body.find(id)) } } impl FindAst for OrPattern { @@ -352,8 +360,15 @@ impl CollectErrors for Ast { expression.collect_errors(errors); cases.collect_errors(errors); } - AstKind::MatchCase(MatchCase { pattern, body }) => { + AstKind::MatchCase(MatchCase { + pattern, + condition, + body, + }) => { pattern.collect_errors(errors); + if let Some(box condition) = condition { + condition.collect_errors(errors); + } body.collect_errors(errors); } AstKind::OrPattern(OrPattern(patterns)) => { diff --git a/compiler/frontend/src/ast_to_hir.rs b/compiler/frontend/src/ast_to_hir.rs index 88be6923d..3a78ffdb8 100644 --- a/compiler/frontend/src/ast_to_hir.rs +++ b/compiler/frontend/src/ast_to_hir.rs @@ -385,34 +385,65 @@ impl Context<'_> { cases .iter() .map(|case| match &case.kind { - AstKind::MatchCase(MatchCase { box pattern, body }) => { + AstKind::MatchCase(MatchCase { + box pattern, + condition, + body, + }) => { let (pattern, pattern_identifiers) = scope.lower_pattern(pattern); - let (body, ()) = scope.with_scope(None, |scope| { - for (name, (ast_id, identifier_id)) in - pattern_identifiers.clone() - { - scope.push( - ast_id, - Expression::PatternIdentifierReference(identifier_id), - name.clone(), - ); - } - scope.compile(body.as_ref()); - }); - - (pattern, body) + let (outer_body, (condition_body, body)) = + scope.with_scope(None, |scope| { + for (name, (ast_id, identifier_id)) in + pattern_identifiers.clone() + { + scope.push( + ast_id, + Expression::PatternIdentifierReference( + identifier_id, + ), + name.clone(), + ); + } + + let condition = condition.as_ref().map(|box condition| { + scope + .with_scope(match_id.clone(), |condition_scope| { + condition_scope.compile_single(condition); + }) + .0 + }); + + let (body, ()) = + scope.with_scope(match_id.clone(), |scope| { + scope.compile(body); + }); + + (condition, body) + }); + + hir::MatchCase { + pattern, + identifier_expressions: outer_body, + condition: condition_body, + body, + } } AstKind::Error { errors } => { let pattern = Pattern::Error { errors: errors.clone(), }; - let (body, ()) = scope.with_scope(None, |scope| { - scope.compile(&[]); + let (body, inner_body) = scope.with_scope(None, |scope| { + scope.with_scope(match_id.clone(), |scope| {}).0 }); - (pattern, body) + hir::MatchCase { + pattern, + identifier_expressions: body, + condition: None, + body: inner_body, + } } _ => unreachable!("Expected match case in match cases, got {case:?}."), }) diff --git a/compiler/frontend/src/cst/error.rs b/compiler/frontend/src/cst/error.rs index 73e1eb044..196c03ea7 100644 --- a/compiler/frontend/src/cst/error.rs +++ b/compiler/frontend/src/cst/error.rs @@ -8,6 +8,7 @@ pub enum CstError { ListNotClosed, MatchCaseMissesArrow, MatchCaseMissesBody, + MatchCaseMissesCondition, MatchMissesCases, OpeningParenthesisMissesExpression, OrPatternMissesRight, diff --git a/compiler/frontend/src/cst/is_multiline.rs b/compiler/frontend/src/cst/is_multiline.rs index edebe0466..ada83953c 100644 --- a/compiler/frontend/src/cst/is_multiline.rs +++ b/compiler/frontend/src/cst/is_multiline.rs @@ -107,9 +107,17 @@ impl IsMultiline for CstKind { } => expression.is_multiline() || percent.is_multiline() || cases.is_multiline(), Self::MatchCase { pattern, + condition, arrow, body, - } => pattern.is_multiline() || arrow.is_multiline() || body.is_multiline(), + } => { + pattern.is_multiline() + || condition.as_deref().map_or(false, |(comma, condition)| { + comma.is_multiline() || condition.is_multiline() + }) + || arrow.is_multiline() + || body.is_multiline() + } Self::Function { opening_curly_brace, parameters_and_arrow, diff --git a/compiler/frontend/src/cst/kind.rs b/compiler/frontend/src/cst/kind.rs index 7087edebe..6d64f1ad3 100644 --- a/compiler/frontend/src/cst/kind.rs +++ b/compiler/frontend/src/cst/kind.rs @@ -2,7 +2,7 @@ use super::{Cst, CstData, CstError}; use crate::rich_ir::{RichIrBuilder, ToRichIr, TokenType}; use enumset::EnumSet; use num_bigint::{BigInt, BigUint}; -use std::fmt::{self, Display, Formatter}; +use std::fmt::{self, Display, Formatter, Pointer}; use strum_macros::EnumIs; #[derive(Clone, Debug, EnumIs, Eq, Hash, PartialEq)] @@ -106,6 +106,7 @@ pub enum CstKind { }, MatchCase { pattern: Box>, + condition: Option>, arrow: Box>, body: Vec>, }, @@ -130,6 +131,7 @@ pub enum IntRadix { Binary, Hexadecimal, } +pub type MatchCaseWithComma = Box<(Cst, Cst)>; pub type FunctionParametersAndArrow = (Vec>, Box>); impl CstKind { @@ -291,10 +293,14 @@ impl CstKind { } Self::MatchCase { pattern, + condition, arrow, body, } => { let mut children = vec![pattern.as_ref(), arrow.as_ref()]; + if let Some(box (comma, condition)) = condition { + children.extend([&comma, &condition]); + } children.extend(body); children } @@ -507,11 +513,16 @@ impl Display for CstKind { } Self::MatchCase { pattern, + condition, arrow, body, } => { pattern.fmt(f)?; arrow.fmt(f)?; + if let Some(box (comma, condition)) = condition { + comma.fmt(f)?; + condition.fmt(f)?; + } for expression in body { expression.fmt(f)?; } @@ -904,6 +915,7 @@ where Self::MatchCase { pattern, arrow, + condition, body, } => { builder.push_cst_kind("MatchCase", |builder| { diff --git a/compiler/frontend/src/cst/tree_with_ids.rs b/compiler/frontend/src/cst/tree_with_ids.rs index 5fb98fd97..a518e0260 100644 --- a/compiler/frontend/src/cst/tree_with_ids.rs +++ b/compiler/frontend/src/cst/tree_with_ids.rs @@ -138,10 +138,16 @@ impl TreeWithIds for Cst { .or_else(|| cases.find(id)), CstKind::MatchCase { pattern, + condition, arrow, body, } => pattern .find(id) + .or_else(|| { + condition.as_deref().and_then(|(comma, condition)| { + comma.find(id).or_else(|| condition.find(id)) + }) + }) .or_else(|| arrow.find(id)) .or_else(|| body.find(id)), CstKind::Function { @@ -329,11 +335,19 @@ impl TreeWithIds for Cst { ), CstKind::MatchCase { pattern, + condition, arrow, body, } => ( pattern .find_by_offset(offset) + .or_else(|| { + condition.as_deref().and_then(|(comma, condition)| { + comma + .find_by_offset(offset) + .or_else(|| condition.find_by_offset(offset)) + }) + }) .or_else(|| arrow.find_by_offset(offset)) .or_else(|| body.find_by_offset(offset)), false, diff --git a/compiler/frontend/src/cst/unwrap_whitespace_and_comment.rs b/compiler/frontend/src/cst/unwrap_whitespace_and_comment.rs index 374c7a497..89742ead8 100644 --- a/compiler/frontend/src/cst/unwrap_whitespace_and_comment.rs +++ b/compiler/frontend/src/cst/unwrap_whitespace_and_comment.rs @@ -144,10 +144,17 @@ impl UnwrapWhitespaceAndComment for Cst { }, CstKind::MatchCase { pattern, + condition, arrow, body, } => CstKind::MatchCase { pattern: pattern.unwrap_whitespace_and_comment(), + condition: condition.as_deref().map(|(comma, condition)| { + Box::new(( + comma.unwrap_whitespace_and_comment(), + condition.unwrap_whitespace_and_comment(), + )) + }), arrow: arrow.unwrap_whitespace_and_comment(), body: body.unwrap_whitespace_and_comment(), }, diff --git a/compiler/frontend/src/cst_to_ast.rs b/compiler/frontend/src/cst_to_ast.rs index 96d808f70..00196a8f4 100644 --- a/compiler/frontend/src/cst_to_ast.rs +++ b/compiler/frontend/src/cst_to_ast.rs @@ -584,23 +584,36 @@ impl LoweringContext { } CstKind::MatchCase { pattern, - arrow: _, + condition, + arrow, body, } => { if lowering_type != LoweringType::Expression { return self.create_ast_for_invalid_expression_in_pattern(cst); }; - + let mut errors = vec![]; let pattern = self.lower_cst(pattern, LoweringType::Pattern); - // TODO: handle error in arrow + let condition = condition + .as_ref() + .map(|box (_, condition)| self.lower_cst(condition, LoweringType::Expression)); + + if let CstKind::Error { + unparsable_input: _, + error, + } = arrow.kind + { + errors.push(self.create_error(arrow, error)); + } let body = self.lower_csts(body); - self.create_ast( - cst.data.id, + self.create_errors_or_ast( + cst, + errors, MatchCase { pattern: Box::new(pattern), + condition: condition.map(Box::new), body, }, ) diff --git a/compiler/frontend/src/error.rs b/compiler/frontend/src/error.rs index 36ee54807..295156483 100644 --- a/compiler/frontend/src/error.rs +++ b/compiler/frontend/src/error.rs @@ -63,6 +63,7 @@ impl Display for CompilerErrorPayload { CstError::MatchMissesCases => "This match misses cases to match against.", CstError::MatchCaseMissesArrow => "This match case misses an arrow.", CstError::MatchCaseMissesBody => "This match case misses a body to run.", + CstError::MatchCaseMissesCondition => "This match case condition is empty.", CstError::OpeningParenthesisMissesExpression => { "Here's an opening parenthesis without an expression after it." } diff --git a/compiler/frontend/src/hir.rs b/compiler/frontend/src/hir.rs index d76f0a726..eec5b4248 100644 --- a/compiler/frontend/src/hir.rs +++ b/compiler/frontend/src/hir.rs @@ -46,7 +46,11 @@ fn containing_body_of(db: &dyn HirDb, id: Id) -> Arc { Expression::Match { cases, .. } => { let body = cases .into_iter() - .map(|(_, body)| body) + .flat_map( + |MatchCase { + condition, body, .. + }| condition.into_iter().chain([body]), + ) .find(|body| body.expressions.contains_key(&id)) .unwrap(); Arc::new(body) @@ -91,7 +95,13 @@ impl Expression { Self::PatternIdentifierReference(_) => {} Self::Match { expression, cases } => { ids.push(expression.clone()); - for (_, body) in cases { + for MatchCase { + condition, body, .. + } in cases + { + if let Some(condition) = condition { + condition.collect_all_ids(ids); + } body.collect_all_ids(ids); } } @@ -386,7 +396,7 @@ pub enum Expression { /// Each case consists of the pattern to match against, and the body /// which starts with [PatternIdentifierReference]s for all identifiers /// in the pattern. - cases: Vec<(Pattern, Body)>, + cases: Vec, }, Function(Function), Builtin(BuiltinFunction), @@ -439,6 +449,14 @@ impl ToRichIr for PatternIdentifierId { } } +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct MatchCase { + pub pattern: Pattern, + pub identifier_expressions: Body, + pub condition: Option, + pub body: Body, +} + #[derive(Clone, PartialEq, Eq, Debug)] pub enum Pattern { NewIdentifier(PatternIdentifierId), @@ -660,16 +678,29 @@ impl ToRichIr for Expression { Self::Match { expression, cases } => { expression.build_rich_ir(builder); builder.push(" %", None, EnumSet::empty()); - builder.push_children_custom_multiline(cases, |builder, (pattern, body)| { - pattern.build_rich_ir(builder); - builder.push(" ->", None, EnumSet::empty()); - builder.push_indented_foldable(|builder| { - if !body.expressions.is_empty() { - builder.push_newline(); + builder.push_children_custom_multiline( + cases, + |builder, + MatchCase { + pattern, + condition, + body, + .. + }| { + pattern.build_rich_ir(builder); + if let Some(condition) = condition { + builder.push(", ", None, EnumSet::empty()); + condition.build_rich_ir(builder); } - body.build_rich_ir(builder); - }); - }); + builder.push(" ->", None, EnumSet::empty()); + builder.push_indented_foldable(|builder| { + if !body.expressions.is_empty() { + builder.push_newline(); + } + body.build_rich_ir(builder); + }); + }, + ); } Self::Function(function) => { builder.push( @@ -833,7 +864,16 @@ impl Expression { Self::Destructure { .. } => None, Self::PatternIdentifierReference { .. } => None, // TODO: use binary search - Self::Match { cases, .. } => cases.iter().find_map(|(_, body)| body.find(id)), + Self::Match { cases, .. } => cases.iter().find_map( + |MatchCase { + condition, body, .. + }| { + condition + .as_ref() + .and_then(|condition| condition.find(id)) + .or_else(|| body.find(id)) + }, + ), Self::Function(Function { body, .. }) => body.find(id), Self::Builtin(_) => None, Self::Call { .. } => None, @@ -873,8 +913,17 @@ impl CollectErrors for Expression { | Self::Struct(_) | Self::PatternIdentifierReference { .. } => {} Self::Match { cases, .. } => { - for (pattern, body) in cases { + for MatchCase { + pattern, + identifier_expressions: _, + condition, + body, + } in cases + { pattern.collect_errors(errors); + if let Some(condition) = condition { + condition.collect_errors(errors); + } body.collect_errors(errors); } } @@ -929,3 +978,10 @@ impl CollectErrors for Body { } } } +impl CollectErrors for Option { + fn collect_errors(&self, errors: &mut Vec) { + if let Some(this) = self { + this.collect_errors(errors); + } + } +} diff --git a/compiler/frontend/src/hir_to_mir.rs b/compiler/frontend/src/hir_to_mir.rs index c82595105..ac184ecfa 100644 --- a/compiler/frontend/src/hir_to_mir.rs +++ b/compiler/frontend/src/hir_to_mir.rs @@ -126,30 +126,12 @@ fn generate_needs_function(body: &mut BodyBuilder) -> Id { // Common stuff. let needs_code = body.push_hir_id(needs_id.clone()); let builtin_equals = body.push_builtin(BuiltinFunction::Equals); - let nothing_tag = body.push_nothing(); // Make sure the condition is a bool. - let true_tag = body.push_bool(true); - let false_tag = body.push_bool(false); - let is_condition_true = - body.push_call(builtin_equals, vec![condition, true_tag], needs_code); - let is_condition_bool = body.push_if_else( - &needs_id.child("isConditionTrue"), - is_condition_true, - |body| { - body.push_reference(true_tag); - }, - |body| { - body.push_call(builtin_equals, vec![condition, false_tag], needs_code); - }, - needs_code, - ); - body.push_if_else( + let is_condition_bool = body.push_is_bool(&needs_id, condition, needs_code); + body.push_if_not( &needs_id.child("isConditionBool"), is_condition_bool, - |body| { - body.push_reference(nothing_tag); - }, |body| { let panic_reason = body.push_text("The `condition` must be either `True` or `False`.".to_string()); @@ -167,12 +149,9 @@ fn generate_needs_function(body: &mut BodyBuilder) -> Id { vec![type_of_reason, text_tag], responsible_for_call, ); - body.push_if_else( + body.push_if_not( &needs_id.child("isReasonText"), is_reason_text, - |body| { - body.push_reference(nothing_tag); - }, |body| { let panic_reason = body.push_text("The `reason` must be a text.".to_string()); body.push_panic(panic_reason, responsible_for_call); @@ -181,12 +160,9 @@ fn generate_needs_function(body: &mut BodyBuilder) -> Id { ); // The core logic of the needs. - body.push_if_else( + body.push_if_not( &needs_id.child("condition"), condition, - |body| { - body.push_reference(nothing_tag); - }, |body| { body.push_panic(reason, responsible_for_condition); }, @@ -370,14 +346,10 @@ impl<'a> LoweringContext<'a> { is_trivial: false, }); - let nothing = body.push_nothing(); let is_match = body.push_is_match(pattern_result, responsible); - body.push_if_else( + body.push_if_not( &hir_id.child("isMatch"), is_match, - |body| { - body.push_reference(nothing); - }, |body| { let reason = body.push_text("The value doesn't match the pattern on the left side of the destructuring.".to_string()); body.push_panic(reason, responsible); @@ -391,6 +363,8 @@ impl<'a> LoweringContext<'a> { self.ongoing_destructuring.clone().unwrap(); if is_trivial { + // something % + // foo -> ... body.push_reference(result) } else { let responsible = body.push_hir_id(hir_id.clone()); @@ -629,7 +603,7 @@ impl<'a> LoweringContext<'a> { hir_id: hir::Id, body: &mut BodyBuilder, expression: Id, - cases: &[(hir::Pattern, hir::Body)], + cases: &[hir::MatchCase], responsible_for_needs: Id, responsible_for_match: Id, ) -> Id { @@ -649,7 +623,7 @@ impl<'a> LoweringContext<'a> { hir_id: hir::Id, body: &mut BodyBuilder, expression: Id, - cases: &[(hir::Pattern, hir::Body)], + cases: &[hir::MatchCase], responsible_for_needs: Id, responsible_for_match: Id, case_index: usize, @@ -660,7 +634,12 @@ impl<'a> LoweringContext<'a> { // TODO: concat reasons body.push_panic(reason, responsible_for_match) } - [(case_pattern, case_body), rest @ ..] => { + [hir::MatchCase { + pattern: case_pattern, + identifier_expressions: case_identifiers, + condition: case_condition, + body: case_body, + }, rest @ ..] => { let pattern_result = PatternLoweringContext::check_pattern( body, hir_id.clone(), @@ -668,11 +647,24 @@ impl<'a> LoweringContext<'a> { expression, case_pattern, ); + let builtin_if_else = body.push_builtin(BuiltinFunction::IfElse); - let is_match = body.push_is_match(pattern_result, responsible_for_match); - + let is_pattern_match = body.push_is_match(pattern_result, responsible_for_match); let case_id = hir_id.child(format!("case-{case_index}")); let builtin_if_else = body.push_builtin(BuiltinFunction::IfElse); + + let else_function = body.push_function(case_id.child("didNotMatch"), |body, _| { + self.compile_match_rec( + hir_id, + body, + expression, + rest, + responsible_for_needs, + responsible_for_match, + case_index + 1, + ); + }); + let then_function = body.push_function(case_id.child("matched"), |body, _| { self.ongoing_destructuring = Some(OngoingDestructuring { result: pattern_result, @@ -680,25 +672,82 @@ impl<'a> LoweringContext<'a> { }); self.compile_expressions(body, responsible_for_needs, &case_body.expressions); }); - let else_function = body.push_function(case_id.child("didNotMatch"), |body, _| { - self.compile_match_rec( - hir_id, + + let then_function = body.push_function(case_id.child("patternMatch"), |body, _| { + self.ongoing_destructuring = Some(OngoingDestructuring { + result: pattern_result, + is_trivial: false, + }); + self.compile_expressions( body, - expression, - rest, + responsible_for_needs, + &case_identifiers.expressions, + ); + + self.compile_match_case_body( + &case_id, + body, + case_condition, + case_body, + else_function, responsible_for_needs, responsible_for_match, - case_index + 1, ); }); + body.push_call( builtin_if_else, - vec![is_match, then_function, else_function], + vec![is_pattern_match, then_function, else_function], responsible_for_match, ) } } } + #[allow(clippy::too_many_arguments)] + fn compile_match_case_body( + &mut self, + case_id: &hir::Id, + body: &mut BodyBuilder, + case_condition: &Option, + case_body: &hir::Body, + else_function: Id, + responsible_for_needs: Id, + responsible_for_match: Id, + ) { + let builtin_if_else = body.push_builtin(BuiltinFunction::IfElse); + if let Some(condition) = case_condition { + self.compile_expressions(body, responsible_for_needs, &condition.expressions); + let condition_result = body.current_return_value(); + + let is_boolean = body.push_is_bool(case_id, condition_result, responsible_for_match); + body.push_if_not( + &case_id.child("conditionCheck"), + is_boolean, + |body| { + let reason_parts = [ + body.push_text("Match Condition expected boolean value, got `".to_string()), + body.push_to_debug_text(condition_result, responsible_for_match), + body.push_text("`".to_string()), + ]; + let reason = body.push_text_concatenate(&reason_parts, responsible_for_match); + body.push_panic(reason, responsible_for_match); + }, + responsible_for_match, + ); + + let then_function = body.push_function(case_id.child("conditionMatch"), |body, _| { + self.compile_expressions(body, responsible_for_needs, &case_body.expressions); + }); + + body.push_call( + builtin_if_else, + vec![condition_result, then_function, else_function], + responsible_for_needs, + ); + } else { + self.compile_expressions(body, responsible_for_needs, &case_body.expressions); + }; + } fn push_call( &self, diff --git a/compiler/frontend/src/mir/body.rs b/compiler/frontend/src/mir/body.rs index a080be93e..7a42952b1 100644 --- a/compiler/frontend/src/mir/body.rs +++ b/compiler/frontend/src/mir/body.rs @@ -412,7 +412,53 @@ impl BodyBuilder { responsible, ) } + pub fn push_if_not( + &mut self, + hir_id: &hir::Id, + condition: Id, + else_builder: E, + responsible: Id, + ) -> Id + where + E: FnOnce(&mut Self), + { + self.push_if_else( + hir_id, + condition, + |body| { + body.push_nothing(); + }, + else_builder, + responsible, + ) + } + pub fn push_is_bool(&mut self, hir_id: &hir::Id, value: Id, responsible: Id) -> Id { + let is_condition_true = self.push_equals_value(value, true, responsible); + self.push_if_else( + &hir_id.child("isValueTrue"), + is_condition_true, + |body| { + body.push_reference(is_condition_true); + }, + |body| { + body.push_equals_value(value, false, responsible); + }, + responsible, + ) + } + pub fn push_equals(&mut self, a: Id, b: Id, responsible: Id) -> Id { + let builtin_equals = self.push_builtin(BuiltinFunction::Equals); + self.push_call(builtin_equals, vec![a, b], responsible) + } + pub fn push_equals_value(&mut self, a: Id, b: impl Into, responsible: Id) -> Id { + let b = self.push(b.into()); + self.push_equals(a, b, responsible) + } + pub fn push_to_debug_text(&mut self, value: Id, responsible: Id) -> Id { + let builtin_to_debug_text = self.push_builtin(BuiltinFunction::ToDebugText); + self.push_call(builtin_to_debug_text, vec![value], responsible) + } pub fn push_panic(&mut self, reason: Id, responsible: Id) -> Id { self.push(Expression::Panic { reason, @@ -439,6 +485,19 @@ impl BodyBuilder { ) } + pub fn push_text_concatenate(&mut self, parts: &[Id], responsible: Id) -> Id { + assert!(!parts.is_empty()); + + let builtin_text_concatenate = self.push_builtin(BuiltinFunction::TextConcatenate); + parts + .iter() + .copied() + .reduce(|left, right| { + self.push_call(builtin_text_concatenate, vec![left, right], responsible) + }) + .unwrap() + } + #[must_use] pub fn current_return_value(&self) -> Id { self.body.return_value() diff --git a/compiler/frontend/src/rcst.rs b/compiler/frontend/src/rcst.rs index 1dc20761e..ee8c37284 100644 --- a/compiler/frontend/src/rcst.rs +++ b/compiler/frontend/src/rcst.rs @@ -2,6 +2,7 @@ use crate::{ cst::{Cst, CstKind}, rich_ir::{RichIrBuilder, ToRichIr}, }; +use enumset::EnumSet; pub type Rcst = Cst<()>; @@ -16,7 +17,7 @@ impl From> for Rcst { impl ToRichIr for Rcst { fn build_rich_ir(&self, builder: &mut RichIrBuilder) { - builder.push(format!("{self:#?}"), None, EnumSet::empty()); + self.kind.build_rich_ir(builder); } } diff --git a/compiler/frontend/src/rcst_to_cst.rs b/compiler/frontend/src/rcst_to_cst.rs index 8d493b614..86af35a32 100644 --- a/compiler/frontend/src/rcst_to_cst.rs +++ b/compiler/frontend/src/rcst_to_cst.rs @@ -257,10 +257,12 @@ impl Rcst { }, CstKind::MatchCase { pattern, + condition, arrow, body, } => CstKind::MatchCase { pattern: Box::new(pattern.to_cst(state)), + condition: condition.as_ref().map(|v| Box::new(v.to_cst(state))), arrow: Box::new(arrow.to_cst(state)), body: body.to_csts_helper(state), }, @@ -334,3 +336,10 @@ impl RcstsToCstsHelperExt for Vec { csts } } + +#[extension_trait] +impl ConvertToCst for (Rcst, Rcst) { + fn to_cst(&self, state: &mut State) -> (Cst, Cst) { + (self.0.to_cst(state), self.1.to_cst(state)) + } +} diff --git a/compiler/frontend/src/string_to_rcst/expression.rs b/compiler/frontend/src/string_to_rcst/expression.rs index 51491c843..aee4f46be 100644 --- a/compiler/frontend/src/string_to_rcst/expression.rs +++ b/compiler/frontend/src/string_to_rcst/expression.rs @@ -5,7 +5,7 @@ use super::{ list::list, literal::{ arrow, bar, closing_bracket, closing_curly_brace, closing_parenthesis, colon_equals_sign, - dot, equals_sign, percent, + comma, dot, equals_sign, percent, }, struct_::struct_, text::text, @@ -407,6 +407,31 @@ fn match_case(input: &str, indentation: usize) -> Option<(&str, Rcst)> { let (input, whitespace) = whitespaces_and_newlines(input, indentation, true); let pattern = pattern.wrap_in_whitespace(whitespace); + let (input, condition) = if let Some((input, condition_comma)) = comma(input) { + let (input, whitespace) = whitespaces_and_newlines(input, indentation, true); + let condition_comma = condition_comma.wrap_in_whitespace(whitespace); + if let Some((input, condition_expresion)) = expression( + input, + indentation, + ExpressionParsingOptions { + allow_assignment: false, + allow_call: true, + allow_bar: true, + allow_function: true, + }, + ) { + (input, Some((condition_comma, condition_expresion))) + } else { + let error = CstKind::Error { + unparsable_input: String::new(), + error: CstError::MatchCaseMissesCondition, + }; + (input, Some((condition_comma, error.into()))) + } + } else { + (input, None) + }; + let (input, arrow) = if let Some((input, arrow)) = arrow(input) { let (input, whitespace) = whitespaces_and_newlines(input, indentation, true); (input, arrow.wrap_in_whitespace(whitespace)) @@ -431,6 +456,7 @@ fn match_case(input: &str, indentation: usize) -> Option<(&str, Rcst)> { let case = CstKind::MatchCase { pattern: Box::new(pattern), + condition: condition.map(Box::new), arrow: Box::new(arrow), body, }; diff --git a/compiler/frontend/src/string_to_rcst/utils.rs b/compiler/frontend/src/string_to_rcst/utils.rs index b7fab7fb9..733fbda0f 100644 --- a/compiler/frontend/src/string_to_rcst/utils.rs +++ b/compiler/frontend/src/string_to_rcst/utils.rs @@ -91,7 +91,7 @@ impl ToRichIr for Option<(&str, T)> { } #[cfg(test)] -macro_rules! assert_rich_ir_snapshot { +macro_rules! assert_rich_ir_snapshot { ($value:expr, @$string:literal) => { insta::_assert_snapshot_base!( transform=|it| $crate::rich_ir::ToRichIr::to_rich_ir(it, false).text, diff --git a/compiler/language_server/src/features_candy/folding_ranges.rs b/compiler/language_server/src/features_candy/folding_ranges.rs index ab44e63cf..e19efd786 100644 --- a/compiler/language_server/src/features_candy/folding_ranges.rs +++ b/compiler/language_server/src/features_candy/folding_ranges.rs @@ -137,11 +137,16 @@ where } CstKind::MatchCase { pattern, + condition, arrow, body, } => { self.visit_cst(pattern); + if let Some(box (_, condition)) = condition { + self.visit_cst(condition); + } + let arrow = arrow.unwrap_whitespace_and_comment(); let body_end = body .unwrap_whitespace_and_comment() diff --git a/compiler/language_server/src/features_candy/references.rs b/compiler/language_server/src/features_candy/references.rs index 383bef63a..a2eb8c625 100644 --- a/compiler/language_server/src/features_candy/references.rs +++ b/compiler/language_server/src/features_candy/references.rs @@ -2,7 +2,7 @@ use crate::{features::Reference, utils::LspPositionConversion}; use candy_frontend::{ ast_to_hir::AstToHir, cst::{CstDb, CstKind}, - hir::{self, Body, Expression, Function, HirDb}, + hir::{self, Body, Expression, Function, HirDb, MatchCase}, module::{Module, ModuleDb}, position::{Offset, PositionConversionDb}, }; @@ -173,7 +173,10 @@ where | Expression::Destructure { .. } | Expression::PatternIdentifierReference(_) => {} Expression::Match { cases, .. } => { - for (_, body) in cases { + for MatchCase{condition, body, ..} in cases { + if let Some(condition) = condition { + self.visit_body(condition); + } self.visit_body(body); } } diff --git a/compiler/language_server/src/features_candy/semantic_tokens.rs b/compiler/language_server/src/features_candy/semantic_tokens.rs index c8b1157e5..1b146e750 100644 --- a/compiler/language_server/src/features_candy/semantic_tokens.rs +++ b/compiler/language_server/src/features_candy/semantic_tokens.rs @@ -236,10 +236,15 @@ fn visit_cst( } CstKind::MatchCase { pattern, + condition, arrow, body, } => { visit_cst(builder, pattern, None); + if let Some(box (comma, condition)) = condition { + visit_cst(builder, comma, None); + visit_cst(builder, condition, None); + } visit_cst(builder, arrow, None); visit_csts(builder, body, None); } diff --git a/packages/Examples/match.candy b/packages/Examples/match.candy index 5c79b08d2..a6af975be 100644 --- a/packages/Examples/match.candy +++ b/packages/Examples/match.candy @@ -1,10 +1,18 @@ [ifElse, int] = use "Core" -foo value = +buildEnum value = needs (int.is value) - ifElse (value | int.isLessThan 5) { Ok value } { Error "NOPE" } + ifElse (value | int.isLessThan 10) { Ok value } { Error "NOPE" } -main = foo 2 % - Ok value, value | int.isGreaterThan 5 -> value - Error value -> 10 - _ -> 20 +testFunction value = + needs (int.is value) + buildEnum value % + Ok value, value | int.isLessThan 2 -> value + Ok value, value | int.isGreaterThan 3 -> int.multiply value 2 + Error value -> 10 + _ -> 20 + + +main := { args -> + (testFunction 1, testFunction 2, testFunction 3, testFunction 4, testFunction 40) +}