diff --git a/crates/ide_assists/src/handlers/generate_function.rs b/crates/ide_assists/src/handlers/generate_function.rs index 954ad2db08..0255e508b4 100644 --- a/crates/ide_assists/src/handlers/generate_function.rs +++ b/crates/ide_assists/src/handlers/generate_function.rs @@ -1,5 +1,10 @@ -use hir::{HasSource, HirDisplay, Module, TypeInfo}; -use ide_db::{base_db::FileId, helpers::SnippetCap}; +use hir::{HasSource, HirDisplay, Module, ModuleDef, Semantics, TypeInfo}; +use ide_db::{ + base_db::FileId, + defs::{Definition, NameRefClass}, + helpers::SnippetCap, + RootDatabase, +}; use rustc_hash::{FxHashMap, FxHashSet}; use stdx::to_lower_snake_case; use syntax::{ @@ -438,7 +443,7 @@ fn fn_args( let mut arg_names = Vec::new(); let mut arg_types = Vec::new(); for arg in call.arg_list()?.args() { - arg_names.push(fn_arg_name(&arg)); + arg_names.push(fn_arg_name(&ctx.sema, &arg)); arg_types.push(match fn_arg_type(ctx, target_module, &arg) { Some(ty) => { if !ty.is_empty() && ty.starts_with('&') { @@ -503,12 +508,18 @@ fn deduplicate_arg_names(arg_names: &mut Vec) { } } -fn fn_arg_name(arg_expr: &ast::Expr) -> String { +fn fn_arg_name(sema: &Semantics, arg_expr: &ast::Expr) -> String { let name = (|| match arg_expr { - ast::Expr::CastExpr(cast_expr) => Some(fn_arg_name(&cast_expr.expr()?)), + ast::Expr::CastExpr(cast_expr) => Some(fn_arg_name(sema, &cast_expr.expr()?)), expr => { - let s = expr.syntax().descendants().filter_map(ast::NameRef::cast).last()?.to_string(); - Some(to_lower_snake_case(&s)) + let name_ref = expr.syntax().descendants().filter_map(ast::NameRef::cast).last()?; + if let Some(NameRefClass::Definition(Definition::ModuleDef( + ModuleDef::Const(_) | ModuleDef::Static(_), + ))) = NameRefClass::classify(sema, &name_ref) + { + return Some(name_ref.to_string().to_lowercase()); + }; + Some(to_lower_snake_case(&name_ref.to_string())) } })(); match name { @@ -1683,6 +1694,75 @@ fn main() { fn foo(arg0: ()) ${0:-> _} { todo!() } +", + ) + } + + #[test] + fn add_function_with_const_arg() { + check_assist( + generate_function, + r" +const VALUE: usize = 0; +fn main() { + foo$0(VALUE); +} +", + r" +const VALUE: usize = 0; +fn main() { + foo(VALUE); +} + +fn foo(value: usize) ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn add_function_with_static_arg() { + check_assist( + generate_function, + r" +static VALUE: usize = 0; +fn main() { + foo$0(VALUE); +} +", + r" +static VALUE: usize = 0; +fn main() { + foo(VALUE); +} + +fn foo(value: usize) ${0:-> _} { + todo!() +} +", + ) + } + + #[test] + fn add_function_with_static_mut_arg() { + check_assist( + generate_function, + r" +static mut VALUE: usize = 0; +fn main() { + foo$0(VALUE); +} +", + r" +static mut VALUE: usize = 0; +fn main() { + foo(VALUE); +} + +fn foo(value: usize) ${0:-> _} { + todo!() +} ", ) }