diff --git a/crates/ide-assists/src/handlers/convert_bool_then.rs b/crates/ide-assists/src/handlers/convert_bool_then.rs index 8d391c64ce61..151c71c0a767 100644 --- a/crates/ide-assists/src/handlers/convert_bool_then.rs +++ b/crates/ide-assists/src/handlers/convert_bool_then.rs @@ -8,12 +8,13 @@ use ide_db::{ }; use itertools::Itertools; use syntax::{ - ast::{self, edit::AstNodeEdit, make, HasArgList}, - ted, AstNode, SyntaxNode, + ast::{self, edit::AstNodeEdit, syntax_factory::SyntaxFactory, HasArgList}, + syntax_editor::SyntaxEditor, + AstNode, SyntaxNode, }; use crate::{ - utils::{invert_boolean_expression_legacy, unwrap_trivial_block}, + utils::{invert_boolean_expression, unwrap_trivial_block}, AssistContext, AssistId, AssistKind, Assists, }; @@ -76,9 +77,9 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_> "Convert `if` expression to `bool::then` call", target, |builder| { - let closure_body = closure_body.clone_for_update(); + let closure_body = closure_body.clone_subtree(); + let mut editor = SyntaxEditor::new(closure_body.syntax().clone()); // Rewrite all `Some(e)` in tail position to `e` - let mut replacements = Vec::new(); for_each_tail_expr(&closure_body, &mut |e| { let e = match e { ast::Expr::BreakExpr(e) => e.expr(), @@ -88,12 +89,16 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_> if let Some(ast::Expr::CallExpr(call)) = e { if let Some(arg_list) = call.arg_list() { if let Some(arg) = arg_list.args().next() { - replacements.push((call.syntax().clone(), arg.syntax().clone())); + editor.replace(call.syntax(), arg.syntax()); } } } }); - replacements.into_iter().for_each(|(old, new)| ted::replace(old, new)); + let edit = editor.finish(); + let closure_body = ast::Expr::cast(edit.new_root().clone()).unwrap(); + + let mut editor = builder.make_editor(expr.syntax()); + let make = SyntaxFactory::new(); let closure_body = match closure_body { ast::Expr::BlockExpr(block) => unwrap_trivial_block(block), e => e, @@ -119,11 +124,18 @@ pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext<'_> | ast::Expr::WhileExpr(_) | ast::Expr::YieldExpr(_) ); - let cond = if invert_cond { invert_boolean_expression_legacy(cond) } else { cond }; - let cond = if parenthesize { make::expr_paren(cond) } else { cond }; - let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body))); - let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list); - builder.replace(target, mcall.to_string()); + let cond = if invert_cond { + invert_boolean_expression(&make, cond) + } else { + cond.clone_for_update() + }; + let cond = if parenthesize { make.expr_paren(cond).into() } else { cond }; + let arg_list = make.arg_list(Some(make.expr_closure(None, closure_body).into())); + let mcall = make.expr_method_call(cond, make.name_ref("then"), arg_list); + editor.replace(expr.syntax(), mcall.syntax()); + + editor.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.file_id(), editor); }, ) } @@ -173,16 +185,17 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_> "Convert `bool::then` call to `if`", target, |builder| { - let closure_body = match closure_body { + let mapless_make = SyntaxFactory::without_mappings(); + let closure_body = match closure_body.reset_indent() { ast::Expr::BlockExpr(block) => block, - e => make::block_expr(None, Some(e)), + e => mapless_make.block_expr(None, Some(e)), }; - let closure_body = closure_body.clone_for_update(); + let closure_body = closure_body.clone_subtree(); + let mut editor = SyntaxEditor::new(closure_body.syntax().clone()); // Wrap all tails in `Some(...)` - let none_path = make::expr_path(make::ext::ident_path("None")); - let some_path = make::expr_path(make::ext::ident_path("Some")); - let mut replacements = Vec::new(); + let none_path = mapless_make.expr_path(mapless_make.ident_path("None")); + let some_path = mapless_make.expr_path(mapless_make.ident_path("Some")); for_each_tail_expr(&ast::Expr::BlockExpr(closure_body.clone()), &mut |e| { let e = match e { ast::Expr::BreakExpr(e) => e.expr(), @@ -190,28 +203,37 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_> _ => Some(e.clone()), }; if let Some(expr) = e { - replacements.push(( + editor.replace( expr.syntax().clone(), - make::expr_call(some_path.clone(), make::arg_list(Some(expr))) + mapless_make + .expr_call(some_path.clone(), mapless_make.arg_list(Some(expr))) .syntax() - .clone_for_update(), - )); + .clone(), + ); } }); - replacements.into_iter().for_each(|(old, new)| ted::replace(old, new)); + let edit = editor.finish(); + let closure_body = ast::BlockExpr::cast(edit.new_root().clone()).unwrap(); + + let mut editor = builder.make_editor(mcall.syntax()); + let make = SyntaxFactory::new(); let cond = match &receiver { ast::Expr::ParenExpr(expr) => expr.expr().unwrap_or(receiver), _ => receiver, }; - let if_expr = make::expr_if( - cond, - closure_body.reset_indent(), - Some(ast::ElseBranch::Block(make::block_expr(None, Some(none_path)))), - ) - .indent(mcall.indent_level()); + let if_expr = make + .expr_if( + cond, + closure_body, + Some(ast::ElseBranch::Block(make.block_expr(None, Some(none_path)))), + ) + .indent(mcall.indent_level()) + .clone_for_update(); + editor.replace(mcall.syntax().clone(), if_expr.syntax().clone()); - builder.replace(target, if_expr.to_string()); + editor.add_mappings(make.finish_with_mappings()); + builder.add_file_edits(ctx.file_id(), editor); }, ) } diff --git a/crates/syntax/src/ast/syntax_factory/constructors.rs b/crates/syntax/src/ast/syntax_factory/constructors.rs index 19c5c64e2184..85393ca5b4ce 100644 --- a/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -129,7 +129,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input.into_iter(), ast.segments().map(|it| it.syntax().clone())); + builder.map_children(input, ast.segments().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -162,7 +162,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input.into_iter(), ast.pats().map(|it| it.syntax().clone())); + builder.map_children(input, ast.pats().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -175,7 +175,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone())); + builder.map_children(input, ast.fields().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -193,7 +193,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); builder.map_node(path.syntax().clone(), ast.path().unwrap().syntax().clone()); - builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone())); + builder.map_children(input, ast.fields().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -230,7 +230,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone())); + builder.map_children(input, ast.fields().map(|it| it.syntax().clone())); if let Some(rest_pat) = rest_pat { builder .map_node(rest_pat.syntax().clone(), ast.rest_pat().unwrap().syntax().clone()); @@ -315,10 +315,7 @@ impl SyntaxFactory { builder.map_node(last_stmt, ast_tail.syntax().clone()); } - builder.map_children( - input.into_iter(), - stmt_list.statements().map(|it| it.syntax().clone()), - ); + builder.map_children(input, stmt_list.statements().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -351,7 +348,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone())); + builder.map_children(input, ast.fields().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -454,7 +451,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax.clone()); - builder.map_children(input.into_iter(), ast.args().map(|it| it.syntax().clone())); + builder.map_children(input, ast.args().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -476,6 +473,31 @@ impl SyntaxFactory { ast.into() } + pub fn expr_closure( + &self, + pats: impl IntoIterator, + expr: ast::Expr, + ) -> ast::ClosureExpr { + let (args, input) = iterator_input(pats); + // FIXME: `make::expr_paren` should return a `ClosureExpr`, not just an `Expr` + let ast::Expr::ClosureExpr(ast) = make::expr_closure(args, expr.clone()).clone_for_update() + else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax.clone()); + builder.map_children( + input, + ast.param_list().unwrap().params().map(|param| param.syntax().clone()), + ); + builder.map_node(expr.syntax().clone(), ast.body().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + pub fn expr_return(&self, expr: Option) -> ast::ReturnExpr { let ast::Expr::ReturnExpr(ast) = make::expr_return(expr.clone()).clone_for_update() else { unreachable!() @@ -604,7 +626,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input.into_iter(), ast.arms().map(|it| it.syntax().clone())); + builder.map_children(input, ast.arms().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -727,6 +749,19 @@ impl SyntaxFactory { ast } + pub fn param(&self, pat: ast::Pat, ty: ast::Type) -> ast::Param { + let ast = make::param(pat.clone(), ty.clone()); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + builder.map_node(pat.syntax().clone(), ast.pat().unwrap().syntax().clone()); + builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone()); + builder.finish(&mut mapping); + } + + ast + } + pub fn generic_arg_list( &self, generic_args: impl IntoIterator, @@ -741,10 +776,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children( - input.into_iter(), - ast.generic_args().map(|arg| arg.syntax().clone()), - ); + builder.map_children(input, ast.generic_args().map(|arg| arg.syntax().clone())); builder.finish(&mut mapping); } @@ -761,7 +793,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone())); + builder.map_children(input, ast.fields().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -806,7 +838,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input.into_iter(), ast.fields().map(|it| it.syntax().clone())); + builder.map_children(input, ast.fields().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -901,7 +933,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children(input.into_iter(), ast.variants().map(|it| it.syntax().clone())); + builder.map_children(input, ast.variants().map(|it| it.syntax().clone())); builder.finish(&mut mapping); } @@ -953,6 +985,69 @@ impl SyntaxFactory { ast } + pub fn fn_( + &self, + visibility: Option, + fn_name: ast::Name, + type_params: Option, + where_clause: Option, + params: ast::ParamList, + body: ast::BlockExpr, + ret_type: Option, + is_async: bool, + is_const: bool, + is_unsafe: bool, + is_gen: bool, + ) -> ast::Fn { + let ast = make::fn_( + visibility.clone(), + fn_name.clone(), + type_params.clone(), + where_clause.clone(), + params.clone(), + body.clone(), + ret_type.clone(), + is_async, + is_const, + is_unsafe, + is_gen, + ); + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + + if let Some(visibility) = visibility { + builder.map_node( + visibility.syntax().clone(), + ast.visibility().unwrap().syntax().clone(), + ); + } + builder.map_node(fn_name.syntax().clone(), ast.name().unwrap().syntax().clone()); + if let Some(type_params) = type_params { + builder.map_node( + type_params.syntax().clone(), + ast.generic_param_list().unwrap().syntax().clone(), + ); + } + if let Some(where_clause) = where_clause { + builder.map_node( + where_clause.syntax().clone(), + ast.where_clause().unwrap().syntax().clone(), + ); + } + builder.map_node(params.syntax().clone(), ast.param_list().unwrap().syntax().clone()); + builder.map_node(body.syntax().clone(), ast.body().unwrap().syntax().clone()); + if let Some(ret_type) = ret_type { + builder + .map_node(ret_type.syntax().clone(), ast.ret_type().unwrap().syntax().clone()); + } + + builder.finish(&mut mapping); + } + + ast + } + pub fn token_tree( &self, delimiter: SyntaxKind, @@ -965,10 +1060,7 @@ impl SyntaxFactory { if let Some(mut mapping) = self.mappings() { let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); - builder.map_children( - input.into_iter(), - ast.token_trees_and_tokens().filter_map(only_nodes), - ); + builder.map_children(input, ast.token_trees_and_tokens().filter_map(only_nodes)); builder.finish(&mut mapping); } diff --git a/crates/syntax/src/syntax_editor/mapping.rs b/crates/syntax/src/syntax_editor/mapping.rs index 16bc55ed2d46..f71925a79558 100644 --- a/crates/syntax/src/syntax_editor/mapping.rs +++ b/crates/syntax/src/syntax_editor/mapping.rs @@ -239,10 +239,10 @@ impl SyntaxMappingBuilder { pub fn map_children( &mut self, - input: impl Iterator, - output: impl Iterator, + input: impl IntoIterator, + output: impl IntoIterator, ) { - for pairs in input.zip_longest(output) { + for pairs in input.into_iter().zip_longest(output) { let (input, output) = match pairs { itertools::EitherOrBoth::Both(l, r) => (l, r), itertools::EitherOrBoth::Left(_) => { diff --git a/docs/book/src/assists_generated.md b/docs/book/src/assists_generated.md index 2d233ca62ad6..72cecc2b02db 100644 --- a/docs/book/src/assists_generated.md +++ b/docs/book/src/assists_generated.md @@ -419,7 +419,7 @@ Converts comments to documentation. ### `convert_bool_then_to_if` -**Source:** [convert_bool_then.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_bool_then.rs#L131) +**Source:** [convert_bool_then.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_bool_then.rs#L143) Converts a `bool::then` method call to an equivalent if expression. @@ -443,7 +443,7 @@ fn main() { ### `convert_closure_to_fn` -**Source:** [convert_closure_to_fn.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_closure_to_fn.rs#L25) +**Source:** [convert_closure_to_fn.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_closure_to_fn.rs#L27) This converts a closure to a freestanding function, changing all captures to parameters. @@ -527,7 +527,7 @@ impl TryFrom for Thing { ### `convert_if_to_bool_then` -**Source:** [convert_bool_then.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_bool_then.rs#L20) +**Source:** [convert_bool_then.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/convert_bool_then.rs#L21) Converts an if expression into a corresponding `bool::then` call. @@ -2258,7 +2258,7 @@ fn bar() { ### `inline_local_variable` -**Source:** [inline_local_variable.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/inline_local_variable.rs#L17) +**Source:** [inline_local_variable.rs](https://github.com/rust-lang/rust-analyzer/blob/master/crates/ide-assists/src/handlers/inline_local_variable.rs#L21) Inlines a local variable.