diff --git a/crates/hir/src/code_model.rs b/crates/hir/src/code_model.rs index 849c8f6d09..a2a166e0ae 100644 --- a/crates/hir/src/code_model.rs +++ b/crates/hir/src/code_model.rs @@ -709,11 +709,23 @@ impl Function { } pub fn params(self, db: &dyn HirDatabase) -> Vec { + let resolver = self.id.resolver(db.upcast()); + let ctx = hir_ty::TyLoweringContext::new(db, &resolver); + let environment = TraitEnvironment::lower(db, &resolver); db.function_data(self.id) .params .iter() .skip(if self.self_param(db).is_some() { 1 } else { 0 }) - .map(|_| Param { _ty: () }) + .map(|type_ref| { + let ty = Type { + krate: self.id.lookup(db.upcast()).container.module(db.upcast()).krate, + ty: InEnvironment { + value: Ty::from_hir_ext(&ctx, type_ref).0, + environment: environment.clone(), + }, + }; + Param { ty } + }) .collect() } @@ -742,15 +754,21 @@ impl From for Access { } } +pub struct Param { + ty: Type, +} + +impl Param { + pub fn ty(&self) -> &Type { + &self.ty + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct SelfParam { func: FunctionId, } -pub struct Param { - _ty: (), -} - impl SelfParam { pub fn access(self, db: &dyn HirDatabase) -> Access { let func_data = db.function_data(self.func); @@ -1276,6 +1294,14 @@ impl Type { ) } + pub fn remove_ref(&self) -> Option { + if let Ty::Apply(ApplicationTy { ctor: TypeCtor::Ref(_), .. }) = self.ty.value { + self.ty.value.substs().map(|substs| self.derived(substs[0].clone())) + } else { + None + } + } + pub fn is_unknown(&self) -> bool { matches!(self.ty.value, Ty::Unknown) } diff --git a/crates/ide/src/completion/completion_context.rs b/crates/ide/src/completion/completion_context.rs index 161f59c1e4..671b13328c 100644 --- a/crates/ide/src/completion/completion_context.rs +++ b/crates/ide/src/completion/completion_context.rs @@ -1,7 +1,7 @@ //! FIXME: write short doc here use base_db::SourceDatabase; -use hir::{Semantics, SemanticsScope, Type}; +use hir::{Local, ScopeDef, Semantics, SemanticsScope, Type}; use ide_db::RootDatabase; use syntax::{ algo::{find_covering_element, find_node_at_offset}, @@ -91,6 +91,7 @@ pub(crate) struct CompletionContext<'a> { pub(super) impl_as_prev_sibling: bool, pub(super) is_match_arm: bool, pub(super) has_item_list_or_source_file_parent: bool, + pub(super) locals: Vec<(String, Local)>, } impl<'a> CompletionContext<'a> { @@ -119,6 +120,12 @@ impl<'a> CompletionContext<'a> { original_file.syntax().token_at_offset(position.offset).left_biased()?; let token = sema.descend_into_macros(original_token.clone()); let scope = sema.scope_at_offset(&token.parent(), position.offset); + let mut locals = vec![]; + scope.process_all_names(&mut |name, scope| { + if let ScopeDef::Local(local) = scope { + locals.push((name.to_string(), local)); + } + }); let mut ctx = CompletionContext { sema, scope, @@ -167,6 +174,7 @@ impl<'a> CompletionContext<'a> { if_is_prev: false, is_match_arm: false, has_item_list_or_source_file_parent: false, + locals, }; let mut original_file = original_file.syntax().clone(); diff --git a/crates/ide/src/completion/presentation.rs b/crates/ide/src/completion/presentation.rs index 24c507f9b2..987cbfa7a8 100644 --- a/crates/ide/src/completion/presentation.rs +++ b/crates/ide/src/completion/presentation.rs @@ -191,6 +191,17 @@ impl Completions { func: hir::Function, local_name: Option, ) { + fn add_arg(arg: &str, ty: &Type, ctx: &CompletionContext) -> String { + if let Some(derefed_ty) = ty.remove_ref() { + for (name, local) in ctx.locals.iter() { + if name == arg && local.ty(ctx.db) == derefed_ty { + return (if ty.is_mutable_reference() { "&mut " } else { "&" }).to_string() + + &arg.to_string(); + } + } + } + arg.to_string() + }; let name = local_name.unwrap_or_else(|| func.name(ctx.db).to_string()); let ast_node = func.source(ctx.db).value; @@ -205,12 +216,20 @@ impl Completions { .set_deprecated(is_deprecated(func, ctx.db)) .detail(function_declaration(&ast_node)); + let params_ty = func.params(ctx.db); let params = ast_node .param_list() .into_iter() .flat_map(|it| it.params()) - .flat_map(|it| it.pat()) - .map(|pat| pat.to_string().trim_start_matches('_').into()) + .zip(params_ty) + .flat_map(|(it, param_ty)| { + if let Some(pat) = it.pat() { + let name = pat.to_string(); + let arg = name.trim_start_matches("mut ").trim_start_matches('_'); + return Some(add_arg(arg, param_ty.ty(), ctx)); + } + None + }) .collect(); builder = builder.add_call_parens(ctx, name, Params::Named(params)); @@ -863,6 +882,106 @@ fn main() { foo(${1:foo}, ${2:bar}, ${3:ho_ge_})$0 } ); } + #[test] + fn insert_ref_when_matching_local_in_scope() { + check_edit( + "ref_arg", + r#" +struct Foo {} +fn ref_arg(x: &Foo) {} +fn main() { + let x = Foo {}; + ref_ar<|> +} +"#, + r#" +struct Foo {} +fn ref_arg(x: &Foo) {} +fn main() { + let x = Foo {}; + ref_arg(${1:&x})$0 +} +"#, + ); + } + + #[test] + fn insert_mut_ref_when_matching_local_in_scope() { + check_edit( + "ref_arg", + r#" +struct Foo {} +fn ref_arg(x: &mut Foo) {} +fn main() { + let x = Foo {}; + ref_ar<|> +} +"#, + r#" +struct Foo {} +fn ref_arg(x: &mut Foo) {} +fn main() { + let x = Foo {}; + ref_arg(${1:&mut x})$0 +} +"#, + ); + } + + #[test] + fn insert_ref_when_matching_local_in_scope_for_method() { + check_edit( + "apply_foo", + r#" +struct Foo {} +struct Bar {} +impl Bar { + fn apply_foo(&self, x: &Foo) {} +} + +fn main() { + let x = Foo {}; + let y = Bar {}; + y.<|> +} +"#, + r#" +struct Foo {} +struct Bar {} +impl Bar { + fn apply_foo(&self, x: &Foo) {} +} + +fn main() { + let x = Foo {}; + let y = Bar {}; + y.apply_foo(${1:&x})$0 +} +"#, + ); + } + + #[test] + fn trim_mut_keyword_in_func_completion() { + check_edit( + "take_mutably", + r#" +fn take_mutably(mut x: &i32) {} + +fn main() { + take_m<|> +} +"#, + r#" +fn take_mutably(mut x: &i32) {} + +fn main() { + take_mutably(${1:x})$0 +} +"#, + ); + } + #[test] fn inserts_parens_for_tuple_enums() { mark::check!(inserts_parens_for_tuple_enums);