diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 2db4b483b6..1fc18767d6 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -2637,14 +2637,19 @@ impl GenericDef { Either::Right(x) => GenericParam::TypeParam(x), } }); - let lt_params = generics + self.lifetime_params(db).into_iter().chain(ty_params).collect() + } + + pub fn lifetime_params(self, db: &dyn HirDatabase) -> Vec { + let generics = db.generic_params(self.into()); + generics .lifetimes .iter() .map(|(local_id, _)| LifetimeParam { id: LifetimeParamId { parent: self.into(), local_id }, }) - .map(GenericParam::LifetimeParam); - lt_params.chain(ty_params).collect() + .map(GenericParam::LifetimeParam) + .collect() } pub fn type_params(self, db: &dyn HirDatabase) -> Vec { diff --git a/crates/ide-assists/src/handlers/add_missing_impl_members.rs b/crates/ide-assists/src/handlers/add_missing_impl_members.rs index 6340feda45..ae5118e950 100644 --- a/crates/ide-assists/src/handlers/add_missing_impl_members.rs +++ b/crates/ide-assists/src/handlers/add_missing_impl_members.rs @@ -365,6 +365,59 @@ impl Foo for S { ); } + #[test] + fn test_lifetime_substitution() { + check_assist( + add_missing_impl_members, + r#" +pub trait Trait<'a, 'b, A, B, C> { + fn foo(&self, one: &'a A, anoter: &'b B) -> &'a C; +} + +impl<'x, 'y, T, V, U> Trait<'x, 'y, T, V, U> for () {$0}"#, + r#" +pub trait Trait<'a, 'b, A, B, C> { + fn foo(&self, one: &'a A, anoter: &'b B) -> &'a C; +} + +impl<'x, 'y, T, V, U> Trait<'x, 'y, T, V, U> for () { + fn foo(&self, one: &'x T, anoter: &'y V) -> &'x U { + ${0:todo!()} + } +}"#, + ); + } + + #[test] + fn test_lifetime_substitution_with_body() { + check_assist( + add_missing_default_members, + r#" +pub trait Trait<'a, 'b, A, B, C: Default> { + fn foo(&self, _one: &'a A, _anoter: &'b B) -> (C, &'a i32) { + let value: &'a i32 = &0; + (C::default(), value) + } +} + +impl<'x, 'y, T, V, U: Default> Trait<'x, 'y, T, V, U> for () {$0}"#, + r#" +pub trait Trait<'a, 'b, A, B, C: Default> { + fn foo(&self, _one: &'a A, _anoter: &'b B) -> (C, &'a i32) { + let value: &'a i32 = &0; + (C::default(), value) + } +} + +impl<'x, 'y, T, V, U: Default> Trait<'x, 'y, T, V, U> for () { + $0fn foo(&self, _one: &'x T, _anoter: &'y V) -> (U, &'x i32) { + let value: &'x i32 = &0; + (::default(), value) + } +}"#, + ); + } + #[test] fn test_cursor_after_empty_impl_def() { check_assist( diff --git a/crates/ide-db/src/path_transform.rs b/crates/ide-db/src/path_transform.rs index 0ee627a44c..fe8ffc4354 100644 --- a/crates/ide-db/src/path_transform.rs +++ b/crates/ide-db/src/path_transform.rs @@ -9,6 +9,14 @@ use syntax::{ ted, SyntaxNode, }; +#[derive(Default)] +struct Substs { + types: Vec, + lifetimes: Vec, +} + +type LifetimeName = String; + /// `PathTransform` substitutes path in SyntaxNodes in bulk. /// /// This is mostly useful for IDE code generation. If you paste some existing @@ -34,7 +42,7 @@ use syntax::{ /// ``` pub struct PathTransform<'a> { generic_def: Option, - substs: Vec, + substs: Substs, target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>, } @@ -72,7 +80,7 @@ impl<'a> PathTransform<'a> { target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>, ) -> PathTransform<'a> { - PathTransform { source_scope, target_scope, generic_def: None, substs: Vec::new() } + PathTransform { source_scope, target_scope, generic_def: None, substs: Substs::default() } } pub fn apply(&self, syntax: &SyntaxNode) { @@ -91,11 +99,11 @@ impl<'a> PathTransform<'a> { let target_module = self.target_scope.module(); let source_module = self.source_scope.module(); let skip = match self.generic_def { - // this is a trait impl, so we need to skip the first type parameter -- this is a bit hacky + // this is a trait impl, so we need to skip the first type parameter (i.e. Self) -- this is a bit hacky Some(hir::GenericDef::Trait(_)) => 1, _ => 0, }; - let substs_by_param: FxHashMap<_, _> = self + let type_substs: FxHashMap<_, _> = self .generic_def .into_iter() .flat_map(|it| it.type_params(db)) @@ -106,31 +114,35 @@ impl<'a> PathTransform<'a> { // can still hit those trailing values and check if they actually have // a default type. If they do, go for that type from `hir` to `ast` so // the resulting change can be applied correctly. - .zip(self.substs.iter().map(Some).chain(std::iter::repeat(None))) + .zip(self.substs.types.iter().map(Some).chain(std::iter::repeat(None))) .filter_map(|(k, v)| match k.split(db) { - Either::Left(_) => None, + Either::Left(_) => None, // FIXME: map const types too Either::Right(t) => match v { - Some(v) => Some((k, v.clone())), + Some(v) => Some((k, v.ty()?.clone())), None => { let default = t.default(db)?; - Some(( - k, - ast::make::ty( - &default - .display_source_code(db, source_module.into(), false) - .ok()?, - ), - )) + let v = ast::make::ty( + &default.display_source_code(db, source_module.into(), false).ok()?, + ); + Some((k, v)) } }, }) .collect(); - Ctx { substs: substs_by_param, target_module, source_scope: self.source_scope } + let lifetime_substs: FxHashMap<_, _> = self + .generic_def + .into_iter() + .flat_map(|it| it.lifetime_params(db)) + .zip(self.substs.lifetimes.clone()) + .filter_map(|(k, v)| Some((k.name(db).to_string(), v.lifetime()?))) + .collect(); + Ctx { type_substs, lifetime_substs, target_module, source_scope: self.source_scope } } } struct Ctx<'a> { - substs: FxHashMap, + type_substs: FxHashMap, + lifetime_substs: FxHashMap, target_module: hir::Module, source_scope: &'a SemanticsScope<'a>, } @@ -152,7 +164,24 @@ impl<'a> Ctx<'a> { for path in paths { self.transform_path(path); } + + item.preorder() + .filter_map(|event| match event { + syntax::WalkEvent::Enter(_) => None, + syntax::WalkEvent::Leave(node) => Some(node), + }) + .filter_map(ast::Lifetime::cast) + .for_each(|lifetime| { + if let Some(subst) = self.lifetime_substs.get(&lifetime.syntax().text().to_string()) + { + ted::replace( + lifetime.syntax(), + subst.clone_subtree().clone_for_update().syntax(), + ); + } + }); } + fn transform_path(&self, path: ast::Path) -> Option<()> { if path.qualifier().is_some() { return None; @@ -169,7 +198,7 @@ impl<'a> Ctx<'a> { match resolution { hir::PathResolution::TypeParam(tp) => { - if let Some(subst) = self.substs.get(&tp.merge()) { + if let Some(subst) = self.type_substs.get(&tp.merge()) { let parent = path.syntax().parent()?; if let Some(parent) = ast::Path::cast(parent.clone()) { // Path inside path means that there is an associated @@ -250,7 +279,7 @@ impl<'a> Ctx<'a> { // FIXME: It would probably be nicer if we could get this via HIR (i.e. get the // trait ref, and then go from the types in the substs back to the syntax). -fn get_syntactic_substs(impl_def: ast::Impl) -> Option> { +fn get_syntactic_substs(impl_def: ast::Impl) -> Option { let target_trait = impl_def.trait_()?; let path_type = match target_trait { ast::Type::PathType(path) => path, @@ -261,13 +290,13 @@ fn get_syntactic_substs(impl_def: ast::Impl) -> Option> { get_type_args_from_arg_list(generic_arg_list) } -fn get_type_args_from_arg_list(generic_arg_list: ast::GenericArgList) -> Option> { - let mut result = Vec::new(); - for generic_arg in generic_arg_list.generic_args() { - if let ast::GenericArg::TypeArg(type_arg) = generic_arg { - result.push(type_arg.ty()?) - } - } +fn get_type_args_from_arg_list(generic_arg_list: ast::GenericArgList) -> Option { + let mut result = Substs::default(); + generic_arg_list.generic_args().for_each(|generic_arg| match generic_arg { + ast::GenericArg::TypeArg(type_arg) => result.types.push(type_arg), + ast::GenericArg::LifetimeArg(l_arg) => result.lifetimes.push(l_arg), + _ => (), // FIXME: don't filter out const params + }); Some(result) }