From 52ff863abcc8a6cb06689b25590616982504d916 Mon Sep 17 00:00:00 2001 From: Jonas Schievink Date: Thu, 19 May 2022 18:53:08 +0200 Subject: [PATCH] Teach `Callable` about closures properly --- crates/hir/src/lib.rs | 71 ++++++++++++++++++++++++-------- crates/ide/src/inlay_hints.rs | 20 ++++++++- crates/ide/src/signature_help.rs | 8 ++-- 3 files changed, 77 insertions(+), 22 deletions(-) diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 12e06bf4ac..3f62a2cd33 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -62,9 +62,9 @@ use hir_ty::{ subst_prefix, traits::FnTrait, AliasEq, AliasTy, BoundVar, CallableDefId, CallableSig, Canonical, CanonicalVarKinds, Cast, - DebruijnIndex, GenericArgData, InEnvironment, Interner, ParamKind, QuantifiedWhereClause, - Scalar, Solution, Substitution, TraitEnvironment, TraitRefExt, Ty, TyBuilder, TyDefId, TyExt, - TyKind, TyVariableKind, WhereClause, + ClosureId, DebruijnIndex, GenericArgData, InEnvironment, Interner, ParamKind, + QuantifiedWhereClause, Scalar, Solution, Substitution, TraitEnvironment, TraitRefExt, Ty, + TyBuilder, TyDefId, TyExt, TyKind, TyVariableKind, WhereClause, }; use itertools::Itertools; use nameres::diagnostics::DefDiagnosticKind; @@ -2819,10 +2819,14 @@ impl Type { } pub fn as_callable(&self, db: &dyn HirDatabase) -> Option { - let def = self.ty.callable_def(db); + let callee = match self.ty.kind(Interner) { + TyKind::Closure(id, _) => Callee::Closure(*id), + TyKind::Function(_) => Callee::FnPtr, + _ => Callee::Def(self.ty.callable_def(db)?), + }; let sig = self.ty.callable_sig(db)?; - Some(Callable { ty: self.clone(), sig, def, is_bound_method: false }) + Some(Callable { ty: self.clone(), sig, callee, is_bound_method: false }) } pub fn is_closure(&self) -> bool { @@ -3265,34 +3269,43 @@ impl Type { } } -// FIXME: closures #[derive(Debug)] pub struct Callable { ty: Type, sig: CallableSig, - def: Option, + callee: Callee, pub(crate) is_bound_method: bool, } +#[derive(Debug)] +enum Callee { + Def(CallableDefId), + Closure(ClosureId), + FnPtr, +} + pub enum CallableKind { Function(Function), TupleStruct(Struct), TupleEnumVariant(Variant), Closure, + FnPtr, } impl Callable { pub fn kind(&self) -> CallableKind { - match self.def { - Some(CallableDefId::FunctionId(it)) => CallableKind::Function(it.into()), - Some(CallableDefId::StructId(it)) => CallableKind::TupleStruct(it.into()), - Some(CallableDefId::EnumVariantId(it)) => CallableKind::TupleEnumVariant(it.into()), - None => CallableKind::Closure, + use Callee::*; + match self.callee { + Def(CallableDefId::FunctionId(it)) => CallableKind::Function(it.into()), + Def(CallableDefId::StructId(it)) => CallableKind::TupleStruct(it.into()), + Def(CallableDefId::EnumVariantId(it)) => CallableKind::TupleEnumVariant(it.into()), + Closure(_) => CallableKind::Closure, + FnPtr => CallableKind::FnPtr, } } pub fn receiver_param(&self, db: &dyn HirDatabase) -> Option { - let func = match self.def { - Some(CallableDefId::FunctionId(it)) if self.is_bound_method => it, + let func = match self.callee { + Callee::Def(CallableDefId::FunctionId(it)) if self.is_bound_method => it, _ => return None, }; let src = func.lookup(db.upcast()).source(db.upcast()); @@ -3312,8 +3325,9 @@ impl Callable { .iter() .skip(if self.is_bound_method { 1 } else { 0 }) .map(|ty| self.ty.derived(ty.clone())); - let patterns = match self.def { - Some(CallableDefId::FunctionId(func)) => { + let map_param = |it: ast::Param| it.pat().map(Either::Right); + let patterns = match self.callee { + Callee::Def(CallableDefId::FunctionId(func)) => { let src = func.lookup(db.upcast()).source(db.upcast()); src.value.param_list().map(|param_list| { param_list @@ -3321,9 +3335,20 @@ impl Callable { .map(|it| Some(Either::Left(it))) .filter(|_| !self.is_bound_method) .into_iter() - .chain(param_list.params().map(|it| it.pat().map(Either::Right))) + .chain(param_list.params().map(map_param)) }) } + Callee::Closure(closure_id) => match closure_source(db, closure_id) { + Some(src) => src.param_list().map(|param_list| { + param_list + .self_param() + .map(|it| Some(Either::Left(it))) + .filter(|_| !self.is_bound_method) + .into_iter() + .chain(param_list.params().map(map_param)) + }), + None => None, + }, _ => None, }; patterns.into_iter().flatten().chain(iter::repeat(None)).zip(types).collect() @@ -3333,6 +3358,18 @@ impl Callable { } } +fn closure_source(db: &dyn HirDatabase, closure: ClosureId) -> Option { + let (owner, expr_id) = db.lookup_intern_closure(closure.into()); + let (_, source_map) = db.body_with_source_map(owner); + let ast = source_map.expr_syntax(expr_id).ok()?; + let root = ast.file_syntax(db.upcast()); + let expr = ast.value.to_node(&root); + match expr { + ast::Expr::ClosureExpr(it) => Some(it), + _ => None, + } +} + #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum BindingMode { Move, diff --git a/crates/ide/src/inlay_hints.rs b/crates/ide/src/inlay_hints.rs index 3cb60d9e44..47f1a08b6f 100644 --- a/crates/ide/src/inlay_hints.rs +++ b/crates/ide/src/inlay_hints.rs @@ -1169,6 +1169,23 @@ fn main() { ); } + #[test] + fn param_hints_on_closure() { + check_params( + r#" +fn main() { + let clo = |a: u8, b: u8| a + b; + clo( + 1, + //^ a + 2, + //^ b + ); +} + "#, + ); + } + #[test] fn param_name_similar_to_fn_name_still_hints() { check_params( @@ -2000,7 +2017,8 @@ fn main() { ; - let _: i32 = multiply(1, 2); + let _: i32 = multiply(1, 2); + //^ a ^ b let multiply_ref = &multiply; //^^^^^^^^^^^^ &|i32, i32| -> i32 diff --git a/crates/ide/src/signature_help.rs b/crates/ide/src/signature_help.rs index 32e7c59b2a..cb38f48f32 100644 --- a/crates/ide/src/signature_help.rs +++ b/crates/ide/src/signature_help.rs @@ -149,7 +149,7 @@ fn signature_help_for_call( variant.name(db) ); } - hir::CallableKind::Closure => (), + hir::CallableKind::Closure | hir::CallableKind::FnPtr => (), } res.signature.push('('); @@ -189,7 +189,7 @@ fn signature_help_for_call( hir::CallableKind::Function(func) if callable.return_type().contains_unknown() => { render(func.ret_type(db)) } - hir::CallableKind::Function(_) | hir::CallableKind::Closure => { + hir::CallableKind::Function(_) | hir::CallableKind::Closure | hir::CallableKind::FnPtr => { render(callable.return_type()) } hir::CallableKind::TupleStruct(_) | hir::CallableKind::TupleEnumVariant(_) => {} @@ -914,8 +914,8 @@ fn main() { } "#, expect![[r#" - (S) -> i32 - ^ + (s: S) -> i32 + ^^^^ "#]], ) }