diff --git a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs index 5ba045d3c8..3e33c62144 100644 --- a/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs +++ b/crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs @@ -1,10 +1,13 @@ use hir::ModuleDef; -use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast}; +use ide_db::helpers::{ + get_path_at_cursor_in_tt, import_assets::NameToImport, mod_path_to_ast, + parse_tt_as_comma_sep_paths, +}; use ide_db::items_locator; use itertools::Itertools; use syntax::{ - ast::{self, make, AstNode, HasName}, - SyntaxKind::{IDENT, WHITESPACE}, + ast::{self, AstNode, AstToken, HasName}, + SyntaxKind::WHITESPACE, }; use crate::{ @@ -52,9 +55,8 @@ pub(crate) fn replace_derive_with_manual_impl( return None; } - let trait_token = args.syntax().token_at_offset(ctx.offset()).find(|t| t.kind() == IDENT)?; - let trait_name = trait_token.text(); - + let ident = args.syntax().token_at_offset(ctx.offset()).find_map(ast::Ident::cast)?; + let trait_path = get_path_at_cursor_in_tt(&ident)?; let adt = attr.syntax().parent().and_then(ast::Adt::cast)?; let current_module = ctx.sema.scope(adt.syntax()).module()?; @@ -63,7 +65,7 @@ pub(crate) fn replace_derive_with_manual_impl( let found_traits = items_locator::items_with_name( &ctx.sema, current_crate, - NameToImport::Exact(trait_name.to_string()), + NameToImport::Exact(trait_path.segments().last()?.to_string()), items_locator::AssocItemSearch::Exclude, Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()), ) @@ -80,12 +82,23 @@ pub(crate) fn replace_derive_with_manual_impl( }); let mut no_traits_found = true; - for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) { - add_assist(acc, ctx, &attr, &args, &trait_path, Some(trait_), &adt)?; + let current_derives = parse_tt_as_comma_sep_paths(args.clone())?; + let current_derives = current_derives.as_slice(); + for (replace_trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) { + add_assist( + acc, + ctx, + &attr, + ¤t_derives, + &args, + &trait_path, + &replace_trait_path, + Some(trait_), + &adt, + )?; } if no_traits_found { - let trait_path = make::ext::ident_path(trait_name); - add_assist(acc, ctx, &attr, &args, &trait_path, None, &adt)?; + add_assist(acc, ctx, &attr, ¤t_derives, &args, &trait_path, &trait_path, None, &adt)?; } Some(()) } @@ -94,15 +107,16 @@ fn add_assist( acc: &mut Assists, ctx: &AssistContext, attr: &ast::Attr, - input: &ast::TokenTree, - trait_path: &ast::Path, + old_derives: &[ast::Path], + old_tree: &ast::TokenTree, + old_trait_path: &ast::Path, + replace_trait_path: &ast::Path, trait_: Option, adt: &ast::Adt, ) -> Option<()> { let target = attr.syntax().text_range(); let annotated_name = adt.name()?; - let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name); - let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?; + let label = format!("Convert to manual `impl {} for {}`", replace_trait_path, annotated_name); acc.add( AssistId("replace_derive_with_manual_impl", AssistKind::Refactor), @@ -111,9 +125,9 @@ fn add_assist( |builder| { let insert_pos = adt.syntax().text_range().end(); let impl_def_with_items = - impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, trait_path); - update_attribute(builder, input, &trait_name, attr); - let trait_path = format!("{}", trait_path); + impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, replace_trait_path); + update_attribute(builder, old_derives, old_tree, old_trait_path, attr); + let trait_path = format!("{}", replace_trait_path); match (ctx.config.snippet_cap, impl_def_with_items) { (None, _) => { builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, "")) @@ -192,23 +206,20 @@ fn impl_def_from_trait( fn update_attribute( builder: &mut AssistBuilder, - input: &ast::TokenTree, - trait_name: &ast::NameRef, + old_derives: &[ast::Path], + old_tree: &ast::TokenTree, + old_trait_path: &ast::Path, attr: &ast::Attr, ) { - let trait_name = trait_name.text(); - let new_attr_input = input - .syntax() - .descendants_with_tokens() - .filter(|t| t.kind() == IDENT) - .filter_map(|t| t.into_token().map(|t| t.text().to_string())) - .filter(|t| t != &trait_name) + let new_derives = old_derives + .iter() + .filter(|t| t.to_string() != old_trait_path.to_string()) .collect::>(); - let has_more_derives = !new_attr_input.is_empty(); + let has_more_derives = !new_derives.is_empty(); if has_more_derives { - let new_attr_input = format!("({})", new_attr_input.iter().format(", ")); - builder.replace(input.syntax().text_range(), new_attr_input); + let new_derives = format!("({})", new_derives.iter().format(", ")); + builder.replace(old_tree.syntax().text_range(), new_derives); } else { let attr_range = attr.syntax().text_range(); builder.delete(attr_range); @@ -1165,4 +1176,48 @@ struct S; "#, ); } + + #[test] + fn add_custom_impl_keep_path() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: clone +#[derive(std::fmt::Debug, Clo$0ne)] +pub struct Foo; +"#, + r#" +#[derive(std::fmt::Debug)] +pub struct Foo; + +impl Clone for Foo { + $0fn clone(&self) -> Self { + Self { } + } +} +"#, + ) + } + + #[test] + fn add_custom_impl_replace_path() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: fmt +#[derive(core::fmt::Deb$0ug, Clone)] +pub struct Foo; +"#, + r#" +#[derive(Clone)] +pub struct Foo; + +impl core::fmt::Debug for Foo { + $0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Foo").finish() + } +} +"#, + ) + } } diff --git a/crates/ide_completion/src/completions/attribute.rs b/crates/ide_completion/src/completions/attribute.rs index e86f38aaa4..8740db326b 100644 --- a/crates/ide_completion/src/completions/attribute.rs +++ b/crates/ide_completion/src/completions/attribute.rs @@ -4,7 +4,10 @@ //! for built-in attributes. use hir::HasAttrs; -use ide_db::helpers::generated_lints::{CLIPPY_LINTS, DEFAULT_LINTS, FEATURES, RUSTDOC_LINTS}; +use ide_db::helpers::{ + generated_lints::{CLIPPY_LINTS, DEFAULT_LINTS, FEATURES, RUSTDOC_LINTS}, + parse_tt_as_comma_sep_paths, +}; use itertools::Itertools; use once_cell::sync::Lazy; use rustc_hash::FxHashMap; @@ -30,12 +33,14 @@ pub(crate) fn complete_attribute(acc: &mut Completions, ctx: &CompletionContext) match (name_ref, attribute.token_tree()) { (Some(path), Some(token_tree)) => match path.text().as_str() { "repr" => repr::complete_repr(acc, ctx, token_tree), - "derive" => derive::complete_derive(acc, ctx, &parse_comma_sep_paths(token_tree)?), + "derive" => { + derive::complete_derive(acc, ctx, &parse_tt_as_comma_sep_paths(token_tree)?) + } "feature" => { - lint::complete_lint(acc, ctx, &parse_comma_sep_paths(token_tree)?, FEATURES) + lint::complete_lint(acc, ctx, &parse_tt_as_comma_sep_paths(token_tree)?, FEATURES) } "allow" | "warn" | "deny" | "forbid" => { - let existing_lints = parse_comma_sep_paths(token_tree)?; + let existing_lints = parse_tt_as_comma_sep_paths(token_tree)?; lint::complete_lint(acc, ctx, &existing_lints, DEFAULT_LINTS); lint::complete_lint(acc, ctx, &existing_lints, CLIPPY_LINTS); lint::complete_lint(acc, ctx, &existing_lints, RUSTDOC_LINTS); @@ -307,23 +312,6 @@ const ATTRIBUTES: &[AttrCompletion] = &[ .prefer_inner(), ]; -fn parse_comma_sep_paths(input: ast::TokenTree) -> Option> { - let r_paren = input.r_paren_token()?; - let tokens = input - .syntax() - .children_with_tokens() - .skip(1) - .take_while(|it| it.as_token() != Some(&r_paren)); - let input_expressions = tokens.into_iter().group_by(|tok| tok.kind() == T![,]); - Some( - input_expressions - .into_iter() - .filter_map(|(is_sep, group)| (!is_sep).then(|| group)) - .filter_map(|mut tokens| ast::Path::parse(&tokens.join("")).ok()) - .collect::>(), - ) -} - fn parse_comma_sep_expr(input: ast::TokenTree) -> Option> { let r_paren = input.r_paren_token()?; let tokens = input diff --git a/crates/ide_db/src/helpers.rs b/crates/ide_db/src/helpers.rs index 97aff0970a..1b9cb7ff51 100644 --- a/crates/ide_db/src/helpers.rs +++ b/crates/ide_db/src/helpers.rs @@ -39,10 +39,9 @@ pub fn get_path_in_derive_attr( attr: &ast::Attr, cursor: &Ident, ) -> Option { - let cursor = cursor.syntax(); let path = attr.path()?; let tt = attr.token_tree()?; - if !tt.syntax().text_range().contains_range(cursor.text_range()) { + if !tt.syntax().text_range().contains_range(cursor.syntax().text_range()) { return None; } let scope = sema.scope(attr.syntax()); @@ -51,7 +50,12 @@ pub fn get_path_in_derive_attr( if PathResolution::Macro(derive) != resolved_attr { return None; } + get_path_at_cursor_in_tt(cursor) +} +/// Parses the path the identifier is part of inside a token tree. +pub fn get_path_at_cursor_in_tt(cursor: &Ident) -> Option { + let cursor = cursor.syntax(); let first = cursor .siblings_with_tokens(Direction::Prev) .filter_map(SyntaxElement::into_token) @@ -300,3 +304,21 @@ pub fn lint_eq_or_in_group(lint: &str, lint_is: &str) -> bool { false } } + +/// Parses the input token tree as comma separated paths. +pub fn parse_tt_as_comma_sep_paths(input: ast::TokenTree) -> Option> { + let r_paren = input.r_paren_token()?; + let tokens = input + .syntax() + .children_with_tokens() + .skip(1) + .take_while(|it| it.as_token() != Some(&r_paren)); + let input_expressions = tokens.into_iter().group_by(|tok| tok.kind() == T![,]); + Some( + input_expressions + .into_iter() + .filter_map(|(is_sep, group)| (!is_sep).then(|| group)) + .filter_map(|mut tokens| ast::Path::parse(&tokens.join("")).ok()) + .collect::>(), + ) +}