internal: Migrate wrap_return_type assist to use SyntaxEditor

This commit is contained in:
Giga Bowser 2024-11-16 14:42:17 -05:00
parent 32b86a8378
commit 651b43e551
3 changed files with 168 additions and 63 deletions

View file

@ -189,7 +189,7 @@ pub(crate) fn add_turbo_fish(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti
/// This will create a turbofish generic arg list corresponding to the number of arguments /// This will create a turbofish generic arg list corresponding to the number of arguments
fn get_fish_head(make: &SyntaxFactory, number_of_arguments: usize) -> ast::GenericArgList { fn get_fish_head(make: &SyntaxFactory, number_of_arguments: usize) -> ast::GenericArgList {
let args = (0..number_of_arguments).map(|_| make::type_arg(make::ty_placeholder()).into()); let args = (0..number_of_arguments).map(|_| make::type_arg(make::ty_placeholder()).into());
make.turbofish_generic_arg_list(args) make.generic_arg_list(args, true)
} }
#[cfg(test)] #[cfg(test)]

View file

@ -6,10 +6,9 @@ use ide_db::{
famous_defs::FamousDefs, famous_defs::FamousDefs,
syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
}; };
use itertools::Itertools;
use syntax::{ use syntax::{
ast::{self, make, Expr, HasGenericParams}, ast::{self, syntax_factory::SyntaxFactory, Expr, HasGenericArgs, HasGenericParams},
match_ast, ted, AstNode, ToSmolStr, match_ast, AstNode,
}; };
use crate::{AssistContext, AssistId, AssistKind, Assists}; use crate::{AssistContext, AssistId, AssistKind, Assists};
@ -43,11 +42,11 @@ use crate::{AssistContext, AssistId, AssistKind, Assists};
pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
let ret_type = ctx.find_node_at_offset::<ast::RetType>()?; let ret_type = ctx.find_node_at_offset::<ast::RetType>()?;
let parent = ret_type.syntax().parent()?; let parent = ret_type.syntax().parent()?;
let body = match_ast! { let body_expr = match_ast! {
match parent { match parent {
ast::Fn(func) => func.body()?, ast::Fn(func) => func.body()?.into(),
ast::ClosureExpr(closure) => match closure.body()? { ast::ClosureExpr(closure) => match closure.body()? {
Expr::BlockExpr(block) => block, Expr::BlockExpr(block) => block.into(),
// closures require a block when a return type is specified // closures require a block when a return type is specified
_ => return None, _ => return None,
}, },
@ -75,56 +74,65 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
kind.assist_id(), kind.assist_id(),
kind.label(), kind.label(),
type_ref.syntax().text_range(), type_ref.syntax().text_range(),
|edit| { |builder| {
let alias = wrapper_alias(ctx, &core_wrapper, type_ref, kind.symbol()); let mut editor = builder.make_editor(&parent);
let new_return_ty = let make = SyntaxFactory::new();
alias.unwrap_or_else(|| kind.wrap_type(type_ref)).clone_for_update(); let alias = wrapper_alias(ctx, &make, &core_wrapper, type_ref, kind.symbol());
let new_return_ty = alias.unwrap_or_else(|| match kind {
let body = edit.make_mut(ast::Expr::BlockExpr(body.clone())); WrapperKind::Option => make.ty_option(type_ref.clone()),
WrapperKind::Result => make.ty_result(type_ref.clone(), make.ty_infer().into()),
});
let mut exprs_to_wrap = Vec::new(); let mut exprs_to_wrap = Vec::new();
let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e); let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e);
walk_expr(&body, &mut |expr| { walk_expr(&body_expr, &mut |expr| {
if let Expr::ReturnExpr(ret_expr) = expr { if let Expr::ReturnExpr(ret_expr) = expr {
if let Some(ret_expr_arg) = &ret_expr.expr() { if let Some(ret_expr_arg) = &ret_expr.expr() {
for_each_tail_expr(ret_expr_arg, tail_cb); for_each_tail_expr(ret_expr_arg, tail_cb);
} }
} }
}); });
for_each_tail_expr(&body, tail_cb); for_each_tail_expr(&body_expr, tail_cb);
for ret_expr_arg in exprs_to_wrap { for ret_expr_arg in exprs_to_wrap {
let happy_wrapped = make::expr_call( let happy_wrapped = make.expr_call(
make::expr_path(make::ext::ident_path(kind.happy_ident())), make.expr_path(make.ident_path(kind.happy_ident())),
make::arg_list(iter::once(ret_expr_arg.clone())), make.arg_list(iter::once(ret_expr_arg.clone())),
) );
.clone_for_update(); editor.replace(ret_expr_arg.syntax(), happy_wrapped.syntax());
ted::replace(ret_expr_arg.syntax(), happy_wrapped.syntax());
} }
let old_return_ty = edit.make_mut(type_ref.clone()); editor.replace(type_ref.syntax(), new_return_ty.syntax());
ted::replace(old_return_ty.syntax(), new_return_ty.syntax());
if let WrapperKind::Result = kind { if let WrapperKind::Result = kind {
// Add a placeholder snippet at the first generic argument that doesn't equal the return type. // Add a placeholder snippet at the first generic argument that doesn't equal the return type.
// This is normally the error type, but that may not be the case when we inserted a type alias. // This is normally the error type, but that may not be the case when we inserted a type alias.
let args = let args = new_return_ty
new_return_ty.syntax().descendants().find_map(ast::GenericArgList::cast); .path()
let error_type_arg = args.and_then(|list| { .unwrap()
list.generic_args().find(|arg| match arg { .segment()
.unwrap()
.generic_arg_list()
.unwrap();
let error_type_arg = args.generic_args().find(|arg| match arg {
ast::GenericArg::TypeArg(_) => { ast::GenericArg::TypeArg(_) => {
arg.syntax().text() != type_ref.syntax().text() arg.syntax().text() != type_ref.syntax().text()
} }
ast::GenericArg::LifetimeArg(_) => false, ast::GenericArg::LifetimeArg(_) => false,
_ => true, _ => true,
})
}); });
if let Some(error_type_arg) = error_type_arg { if let Some(error_type_arg) = error_type_arg {
if let Some(cap) = ctx.config.snippet_cap { if let Some(cap) = ctx.config.snippet_cap {
edit.add_placeholder_snippet(cap, error_type_arg); editor.add_annotation(
error_type_arg.syntax(),
builder.make_placeholder_snippet(cap),
);
} }
} }
} }
editor.add_mappings(make.finish_with_mappings());
builder.add_file_edits(ctx.file_id(), editor);
}, },
); );
} }
@ -176,22 +184,16 @@ impl WrapperKind {
WrapperKind::Result => hir::sym::Result.clone(), WrapperKind::Result => hir::sym::Result.clone(),
} }
} }
fn wrap_type(&self, type_ref: &ast::Type) -> ast::Type {
match self {
WrapperKind::Option => make::ext::ty_option(type_ref.clone()),
WrapperKind::Result => make::ext::ty_result(type_ref.clone(), make::ty_placeholder()),
}
}
} }
// Try to find an wrapper type alias in the current scope (shadowing the default). // Try to find an wrapper type alias in the current scope (shadowing the default).
fn wrapper_alias( fn wrapper_alias(
ctx: &AssistContext<'_>, ctx: &AssistContext<'_>,
make: &SyntaxFactory,
core_wrapper: &hir::Enum, core_wrapper: &hir::Enum,
ret_type: &ast::Type, ret_type: &ast::Type,
wrapper: hir::Symbol, wrapper: hir::Symbol,
) -> Option<ast::Type> { ) -> Option<ast::PathType> {
let wrapper_path = hir::ModPath::from_segments( let wrapper_path = hir::ModPath::from_segments(
hir::PathKind::Plain, hir::PathKind::Plain,
iter::once(hir::Name::new_symbol_root(wrapper)), iter::once(hir::Name::new_symbol_root(wrapper)),
@ -207,25 +209,28 @@ fn wrapper_alias(
}) })
.find_map(|alias| { .find_map(|alias| {
let mut inserted_ret_type = false; let mut inserted_ret_type = false;
let generic_params = alias let generic_args =
.source(ctx.db())? alias.source(ctx.db())?.value.generic_param_list()?.generic_params().map(|param| {
.value match param {
.generic_param_list()? // Replace the very first type parameter with the function's return type.
.generic_params()
.map(|param| match param {
// Replace the very first type parameter with the functions return type.
ast::GenericParam::TypeParam(_) if !inserted_ret_type => { ast::GenericParam::TypeParam(_) if !inserted_ret_type => {
inserted_ret_type = true; inserted_ret_type = true;
ret_type.to_smolstr() make.type_arg(ret_type.clone()).into()
} }
ast::GenericParam::LifetimeParam(_) => make::lifetime("'_").to_smolstr(), ast::GenericParam::LifetimeParam(_) => {
_ => make::ty_placeholder().to_smolstr(), make.lifetime_arg(make.lifetime("'_")).into()
}) }
.join(", "); _ => make.type_arg(make.ty_infer().into()).into(),
}
});
let name = alias.name(ctx.db()); let name = alias.name(ctx.db());
let name = name.as_str(); let generic_arg_list = make.generic_arg_list(generic_args, false);
Some(make::ty(&format!("{name}<{generic_params}>"))) let path = make.path_unqualified(
make.path_segment_generics(make.name_ref(name.as_str()), generic_arg_list),
);
Some(make.ty_path(path))
}) })
}) })
} }

View file

@ -1,6 +1,9 @@
//! Wrappers over [`make`] constructors //! Wrappers over [`make`] constructors
use crate::{ use crate::{
ast::{self, make, HasGenericArgs, HasGenericParams, HasName, HasTypeBounds, HasVisibility}, ast::{
self, make, HasArgList, HasGenericArgs, HasGenericParams, HasName, HasTypeBounds,
HasVisibility,
},
syntax_editor::SyntaxMappingBuilder, syntax_editor::SyntaxMappingBuilder,
AstNode, NodeOrToken, SyntaxKind, SyntaxNode, SyntaxToken, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, SyntaxToken,
}; };
@ -16,6 +19,10 @@ impl SyntaxFactory {
make::name_ref(name).clone_for_update() make::name_ref(name).clone_for_update()
} }
pub fn lifetime(&self, text: &str) -> ast::Lifetime {
make::lifetime(text).clone_for_update()
}
pub fn ty(&self, text: &str) -> ast::Type { pub fn ty(&self, text: &str) -> ast::Type {
make::ty(text).clone_for_update() make::ty(text).clone_for_update()
} }
@ -28,6 +35,20 @@ impl SyntaxFactory {
ast ast
} }
pub fn ty_path(&self, path: ast::Path) -> ast::PathType {
let ast::Type::PathType(ast) = make::ty_path(path.clone()).clone_for_update() else {
unreachable!()
};
if let Some(mut mapping) = self.mappings() {
let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
builder.map_node(path.syntax().clone(), ast.path().unwrap().syntax().clone());
builder.finish(&mut mapping);
}
ast
}
pub fn type_param( pub fn type_param(
&self, &self,
name: ast::Name, name: ast::Name,
@ -253,6 +274,37 @@ impl SyntaxFactory {
ast ast
} }
pub fn expr_call(&self, expr: ast::Expr, arg_list: ast::ArgList) -> ast::CallExpr {
// FIXME: `make::expr_call`` should return a `CallExpr`, not just an `Expr`
let ast::Expr::CallExpr(ast) =
make::expr_call(expr.clone(), arg_list.clone()).clone_for_update()
else {
unreachable!()
};
if let Some(mut mapping) = self.mappings() {
let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
builder.map_node(expr.syntax().clone(), ast.expr().unwrap().syntax().clone());
builder.map_node(arg_list.syntax().clone(), ast.arg_list().unwrap().syntax().clone());
builder.finish(&mut mapping);
}
ast
}
pub fn arg_list(&self, args: impl IntoIterator<Item = ast::Expr>) -> ast::ArgList {
let (args, input) = iterator_input(args);
let ast = make::arg_list(args).clone_for_update();
if let Some(mut mapping) = self.mappings() {
let mut builder = SyntaxMappingBuilder::new(ast.syntax.clone());
builder.map_children(input.into_iter(), ast.args().map(|it| it.syntax().clone()));
builder.finish(&mut mapping);
}
ast
}
pub fn expr_ref(&self, expr: ast::Expr, exclusive: bool) -> ast::Expr { pub fn expr_ref(&self, expr: ast::Expr, exclusive: bool) -> ast::Expr {
let ast::Expr::RefExpr(ast) = make::expr_ref(expr.clone(), exclusive).clone_for_update() let ast::Expr::RefExpr(ast) = make::expr_ref(expr.clone(), exclusive).clone_for_update()
else { else {
@ -428,6 +480,30 @@ impl SyntaxFactory {
ast ast
} }
pub fn type_arg(&self, ty: ast::Type) -> ast::TypeArg {
let ast = make::type_arg(ty.clone()).clone_for_update();
if let Some(mut mapping) = self.mappings() {
let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone());
builder.finish(&mut mapping);
}
ast
}
pub fn lifetime_arg(&self, lifetime: ast::Lifetime) -> ast::LifetimeArg {
let ast = make::lifetime_arg(lifetime.clone()).clone_for_update();
if let Some(mut mapping) = self.mappings() {
let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
builder.map_node(lifetime.syntax().clone(), ast.lifetime().unwrap().syntax().clone());
builder.finish(&mut mapping);
}
ast
}
pub fn item_const( pub fn item_const(
&self, &self,
visibility: Option<ast::Visibility>, visibility: Option<ast::Visibility>,
@ -495,12 +571,17 @@ impl SyntaxFactory {
ast ast
} }
pub fn turbofish_generic_arg_list( pub fn generic_arg_list(
&self, &self,
generic_args: impl IntoIterator<Item = ast::GenericArg>, generic_args: impl IntoIterator<Item = ast::GenericArg>,
is_turbo: bool,
) -> ast::GenericArgList { ) -> ast::GenericArgList {
let (generic_args, input) = iterator_input(generic_args); let (generic_args, input) = iterator_input(generic_args);
let ast = make::turbofish_generic_arg_list(generic_args.clone()).clone_for_update(); let ast = if is_turbo {
make::turbofish_generic_arg_list(generic_args).clone_for_update()
} else {
make::generic_arg_list(generic_args).clone_for_update()
};
if let Some(mut mapping) = self.mappings() { if let Some(mut mapping) = self.mappings() {
let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
@ -753,12 +834,31 @@ impl SyntaxFactory {
// `ext` constructors // `ext` constructors
impl SyntaxFactory { impl SyntaxFactory {
pub fn ident_path(&self, ident: &str) -> ast::Path {
self.path_unqualified(self.path_segment(self.name_ref(ident)))
}
pub fn expr_unit(&self) -> ast::Expr { pub fn expr_unit(&self) -> ast::Expr {
self.expr_tuple([]).into() self.expr_tuple([]).into()
} }
pub fn ident_path(&self, ident: &str) -> ast::Path { pub fn ty_option(&self, t: ast::Type) -> ast::PathType {
self.path_unqualified(self.path_segment(self.name_ref(ident))) let generic_arg_list = self.generic_arg_list([self.type_arg(t).into()], false);
let path = self.path_unqualified(
self.path_segment_generics(self.name_ref("Option"), generic_arg_list),
);
self.ty_path(path)
}
pub fn ty_result(&self, t: ast::Type, e: ast::Type) -> ast::PathType {
let generic_arg_list =
self.generic_arg_list([self.type_arg(t).into(), self.type_arg(e).into()], false);
let path = self.path_unqualified(
self.path_segment_generics(self.name_ref("Result"), generic_arg_list),
);
self.ty_path(path)
} }
} }