fix: derive path handling

This commit is contained in:
rainy-me 2021-11-25 00:21:29 +09:00
parent 3e4ac8a2c9
commit 0bb08ccb8f
3 changed files with 118 additions and 53 deletions

View file

@ -1,10 +1,13 @@
use hir::ModuleDef; use hir::ModuleDef;
use ide_db::helpers::{import_assets::NameToImport, mod_path_to_ast}; use ide_db::helpers::{
get_path_at_cursor_in_tt, import_assets::NameToImport, mod_path_to_ast,
parse_tt_as_comma_sep_paths,
};
use ide_db::items_locator; use ide_db::items_locator;
use itertools::Itertools; use itertools::Itertools;
use syntax::{ use syntax::{
ast::{self, make, AstNode, HasName}, ast::{self, AstNode, AstToken, HasName},
SyntaxKind::{IDENT, WHITESPACE}, SyntaxKind::WHITESPACE,
}; };
use crate::{ use crate::{
@ -52,9 +55,8 @@ pub(crate) fn replace_derive_with_manual_impl(
return None; return None;
} }
let trait_token = args.syntax().token_at_offset(ctx.offset()).find(|t| t.kind() == IDENT)?; let ident = args.syntax().token_at_offset(ctx.offset()).find_map(ast::Ident::cast)?;
let trait_name = trait_token.text(); let trait_path = get_path_at_cursor_in_tt(&ident)?;
let adt = attr.syntax().parent().and_then(ast::Adt::cast)?; let adt = attr.syntax().parent().and_then(ast::Adt::cast)?;
let current_module = ctx.sema.scope(adt.syntax()).module()?; let current_module = ctx.sema.scope(adt.syntax()).module()?;
@ -63,7 +65,7 @@ pub(crate) fn replace_derive_with_manual_impl(
let found_traits = items_locator::items_with_name( let found_traits = items_locator::items_with_name(
&ctx.sema, &ctx.sema,
current_crate, current_crate,
NameToImport::Exact(trait_name.to_string()), NameToImport::Exact(trait_path.segments().last()?.to_string()),
items_locator::AssocItemSearch::Exclude, items_locator::AssocItemSearch::Exclude,
Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()), Some(items_locator::DEFAULT_QUERY_SEARCH_LIMIT.inner()),
) )
@ -80,12 +82,23 @@ pub(crate) fn replace_derive_with_manual_impl(
}); });
let mut no_traits_found = true; let mut no_traits_found = true;
for (trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) { let current_derives = parse_tt_as_comma_sep_paths(args.clone())?;
add_assist(acc, ctx, &attr, &args, &trait_path, Some(trait_), &adt)?; let current_derives = current_derives.as_slice();
for (replace_trait_path, trait_) in found_traits.inspect(|_| no_traits_found = false) {
add_assist(
acc,
ctx,
&attr,
&current_derives,
&args,
&trait_path,
&replace_trait_path,
Some(trait_),
&adt,
)?;
} }
if no_traits_found { if no_traits_found {
let trait_path = make::ext::ident_path(trait_name); add_assist(acc, ctx, &attr, &current_derives, &args, &trait_path, &trait_path, None, &adt)?;
add_assist(acc, ctx, &attr, &args, &trait_path, None, &adt)?;
} }
Some(()) Some(())
} }
@ -94,15 +107,16 @@ fn add_assist(
acc: &mut Assists, acc: &mut Assists,
ctx: &AssistContext, ctx: &AssistContext,
attr: &ast::Attr, attr: &ast::Attr,
input: &ast::TokenTree, old_derives: &[ast::Path],
trait_path: &ast::Path, old_tree: &ast::TokenTree,
old_trait_path: &ast::Path,
replace_trait_path: &ast::Path,
trait_: Option<hir::Trait>, trait_: Option<hir::Trait>,
adt: &ast::Adt, adt: &ast::Adt,
) -> Option<()> { ) -> Option<()> {
let target = attr.syntax().text_range(); let target = attr.syntax().text_range();
let annotated_name = adt.name()?; let annotated_name = adt.name()?;
let label = format!("Convert to manual `impl {} for {}`", trait_path, annotated_name); let label = format!("Convert to manual `impl {} for {}`", replace_trait_path, annotated_name);
let trait_name = trait_path.segment().and_then(|seg| seg.name_ref())?;
acc.add( acc.add(
AssistId("replace_derive_with_manual_impl", AssistKind::Refactor), AssistId("replace_derive_with_manual_impl", AssistKind::Refactor),
@ -111,9 +125,9 @@ fn add_assist(
|builder| { |builder| {
let insert_pos = adt.syntax().text_range().end(); let insert_pos = adt.syntax().text_range().end();
let impl_def_with_items = let impl_def_with_items =
impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, trait_path); impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, replace_trait_path);
update_attribute(builder, input, &trait_name, attr); update_attribute(builder, old_derives, old_tree, old_trait_path, attr);
let trait_path = format!("{}", trait_path); let trait_path = format!("{}", replace_trait_path);
match (ctx.config.snippet_cap, impl_def_with_items) { match (ctx.config.snippet_cap, impl_def_with_items) {
(None, _) => { (None, _) => {
builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, "")) builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, ""))
@ -192,23 +206,20 @@ fn impl_def_from_trait(
fn update_attribute( fn update_attribute(
builder: &mut AssistBuilder, builder: &mut AssistBuilder,
input: &ast::TokenTree, old_derives: &[ast::Path],
trait_name: &ast::NameRef, old_tree: &ast::TokenTree,
old_trait_path: &ast::Path,
attr: &ast::Attr, attr: &ast::Attr,
) { ) {
let trait_name = trait_name.text(); let new_derives = old_derives
let new_attr_input = input .iter()
.syntax() .filter(|t| t.to_string() != old_trait_path.to_string())
.descendants_with_tokens()
.filter(|t| t.kind() == IDENT)
.filter_map(|t| t.into_token().map(|t| t.text().to_string()))
.filter(|t| t != &trait_name)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let has_more_derives = !new_attr_input.is_empty(); let has_more_derives = !new_derives.is_empty();
if has_more_derives { if has_more_derives {
let new_attr_input = format!("({})", new_attr_input.iter().format(", ")); let new_derives = format!("({})", new_derives.iter().format(", "));
builder.replace(input.syntax().text_range(), new_attr_input); builder.replace(old_tree.syntax().text_range(), new_derives);
} else { } else {
let attr_range = attr.syntax().text_range(); let attr_range = attr.syntax().text_range();
builder.delete(attr_range); builder.delete(attr_range);
@ -1165,4 +1176,48 @@ struct S;
"#, "#,
); );
} }
#[test]
fn add_custom_impl_keep_path() {
check_assist(
replace_derive_with_manual_impl,
r#"
//- minicore: clone
#[derive(std::fmt::Debug, Clo$0ne)]
pub struct Foo;
"#,
r#"
#[derive(std::fmt::Debug)]
pub struct Foo;
impl Clone for Foo {
$0fn clone(&self) -> Self {
Self { }
}
}
"#,
)
}
#[test]
fn add_custom_impl_replace_path() {
check_assist(
replace_derive_with_manual_impl,
r#"
//- minicore: fmt
#[derive(core::fmt::Deb$0ug, Clone)]
pub struct Foo;
"#,
r#"
#[derive(Clone)]
pub struct Foo;
impl core::fmt::Debug for Foo {
$0fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("Foo").finish()
}
}
"#,
)
}
} }

View file

@ -4,7 +4,10 @@
//! for built-in attributes. //! for built-in attributes.
use hir::HasAttrs; use hir::HasAttrs;
use ide_db::helpers::generated_lints::{CLIPPY_LINTS, DEFAULT_LINTS, FEATURES, RUSTDOC_LINTS}; use ide_db::helpers::{
generated_lints::{CLIPPY_LINTS, DEFAULT_LINTS, FEATURES, RUSTDOC_LINTS},
parse_tt_as_comma_sep_paths,
};
use itertools::Itertools; use itertools::Itertools;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
@ -30,12 +33,14 @@ pub(crate) fn complete_attribute(acc: &mut Completions, ctx: &CompletionContext)
match (name_ref, attribute.token_tree()) { match (name_ref, attribute.token_tree()) {
(Some(path), Some(token_tree)) => match path.text().as_str() { (Some(path), Some(token_tree)) => match path.text().as_str() {
"repr" => repr::complete_repr(acc, ctx, token_tree), "repr" => repr::complete_repr(acc, ctx, token_tree),
"derive" => derive::complete_derive(acc, ctx, &parse_comma_sep_paths(token_tree)?), "derive" => {
derive::complete_derive(acc, ctx, &parse_tt_as_comma_sep_paths(token_tree)?)
}
"feature" => { "feature" => {
lint::complete_lint(acc, ctx, &parse_comma_sep_paths(token_tree)?, FEATURES) lint::complete_lint(acc, ctx, &parse_tt_as_comma_sep_paths(token_tree)?, FEATURES)
} }
"allow" | "warn" | "deny" | "forbid" => { "allow" | "warn" | "deny" | "forbid" => {
let existing_lints = parse_comma_sep_paths(token_tree)?; let existing_lints = parse_tt_as_comma_sep_paths(token_tree)?;
lint::complete_lint(acc, ctx, &existing_lints, DEFAULT_LINTS); lint::complete_lint(acc, ctx, &existing_lints, DEFAULT_LINTS);
lint::complete_lint(acc, ctx, &existing_lints, CLIPPY_LINTS); lint::complete_lint(acc, ctx, &existing_lints, CLIPPY_LINTS);
lint::complete_lint(acc, ctx, &existing_lints, RUSTDOC_LINTS); lint::complete_lint(acc, ctx, &existing_lints, RUSTDOC_LINTS);
@ -307,23 +312,6 @@ const ATTRIBUTES: &[AttrCompletion] = &[
.prefer_inner(), .prefer_inner(),
]; ];
fn parse_comma_sep_paths(input: ast::TokenTree) -> Option<Vec<ast::Path>> {
let r_paren = input.r_paren_token()?;
let tokens = input
.syntax()
.children_with_tokens()
.skip(1)
.take_while(|it| it.as_token() != Some(&r_paren));
let input_expressions = tokens.into_iter().group_by(|tok| tok.kind() == T![,]);
Some(
input_expressions
.into_iter()
.filter_map(|(is_sep, group)| (!is_sep).then(|| group))
.filter_map(|mut tokens| ast::Path::parse(&tokens.join("")).ok())
.collect::<Vec<ast::Path>>(),
)
}
fn parse_comma_sep_expr(input: ast::TokenTree) -> Option<Vec<ast::Expr>> { fn parse_comma_sep_expr(input: ast::TokenTree) -> Option<Vec<ast::Expr>> {
let r_paren = input.r_paren_token()?; let r_paren = input.r_paren_token()?;
let tokens = input let tokens = input

View file

@ -39,10 +39,9 @@ pub fn get_path_in_derive_attr(
attr: &ast::Attr, attr: &ast::Attr,
cursor: &Ident, cursor: &Ident,
) -> Option<ast::Path> { ) -> Option<ast::Path> {
let cursor = cursor.syntax();
let path = attr.path()?; let path = attr.path()?;
let tt = attr.token_tree()?; let tt = attr.token_tree()?;
if !tt.syntax().text_range().contains_range(cursor.text_range()) { if !tt.syntax().text_range().contains_range(cursor.syntax().text_range()) {
return None; return None;
} }
let scope = sema.scope(attr.syntax()); let scope = sema.scope(attr.syntax());
@ -51,7 +50,12 @@ pub fn get_path_in_derive_attr(
if PathResolution::Macro(derive) != resolved_attr { if PathResolution::Macro(derive) != resolved_attr {
return None; return None;
} }
get_path_at_cursor_in_tt(cursor)
}
/// Parses the path the identifier is part of inside a token tree.
pub fn get_path_at_cursor_in_tt(cursor: &Ident) -> Option<ast::Path> {
let cursor = cursor.syntax();
let first = cursor let first = cursor
.siblings_with_tokens(Direction::Prev) .siblings_with_tokens(Direction::Prev)
.filter_map(SyntaxElement::into_token) .filter_map(SyntaxElement::into_token)
@ -300,3 +304,21 @@ pub fn lint_eq_or_in_group(lint: &str, lint_is: &str) -> bool {
false false
} }
} }
/// Parses the input token tree as comma separated paths.
pub fn parse_tt_as_comma_sep_paths(input: ast::TokenTree) -> Option<Vec<ast::Path>> {
let r_paren = input.r_paren_token()?;
let tokens = input
.syntax()
.children_with_tokens()
.skip(1)
.take_while(|it| it.as_token() != Some(&r_paren));
let input_expressions = tokens.into_iter().group_by(|tok| tok.kind() == T![,]);
Some(
input_expressions
.into_iter()
.filter_map(|(is_sep, group)| (!is_sep).then(|| group))
.filter_map(|mut tokens| ast::Path::parse(&tokens.join("")).ok())
.collect::<Vec<ast::Path>>(),
)
}