Make insert_use return a SyntaxRewriter

This commit is contained in:
Lukas Wirth 2020-11-02 21:40:52 +01:00
parent 245e1b533b
commit cd349dbbc4
5 changed files with 129 additions and 107 deletions

View file

@ -99,7 +99,6 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
let range = ctx.sema.original_range(import_assets.syntax_under_caret()).range; let range = ctx.sema.original_range(import_assets.syntax_under_caret()).range;
let group = import_group_message(import_assets.import_candidate()); let group = import_group_message(import_assets.import_candidate());
let scope = ImportScope::find_insert_use_container(import_assets.syntax_under_caret(), ctx)?; let scope = ImportScope::find_insert_use_container(import_assets.syntax_under_caret(), ctx)?;
let syntax = scope.as_syntax_node();
for (import, _) in proposed_imports { for (import, _) in proposed_imports {
acc.add_group( acc.add_group(
&group, &group,
@ -107,9 +106,9 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
format!("Import `{}`", &import), format!("Import `{}`", &import),
range, range,
|builder| { |builder| {
let new_syntax = let rewriter =
insert_use(&scope, mod_path_to_ast(&import), ctx.config.insert_use.merge); insert_use(&scope, mod_path_to_ast(&import), ctx.config.insert_use.merge);
builder.replace(syntax.text_range(), new_syntax.to_string()) builder.rewrite(rewriter);
}, },
); );
} }

View file

@ -1,16 +1,14 @@
use hir::{EnumVariant, Module, ModuleDef, Name}; use hir::{AsName, EnumVariant, Module, ModuleDef, Name};
use ide_db::base_db::FileId;
use ide_db::{defs::Definition, search::Reference, RootDatabase}; use ide_db::{defs::Definition, search::Reference, RootDatabase};
use itertools::Itertools; use rustc_hash::{FxHashMap, FxHashSet};
use rustc_hash::FxHashSet;
use syntax::{ use syntax::{
algo::find_node_at_offset, algo::find_node_at_offset,
ast::{self, edit::IndentLevel, ArgListOwner, AstNode, NameOwner, VisibilityOwner}, algo::SyntaxRewriter,
SourceFile, TextRange, TextSize, ast::{self, edit::IndentLevel, make, ArgListOwner, AstNode, NameOwner, VisibilityOwner},
SourceFile, SyntaxElement,
}; };
use crate::{ use crate::{
assist_context::AssistBuilder,
utils::{insert_use, mod_path_to_ast, ImportScope}, utils::{insert_use, mod_path_to_ast, ImportScope},
AssistContext, AssistId, AssistKind, Assists, AssistContext, AssistId, AssistKind, Assists,
}; };
@ -43,7 +41,7 @@ pub(crate) fn extract_struct_from_enum_variant(
return None; return None;
} }
let variant_name = variant.name()?.to_string(); let variant_name = variant.name()?;
let variant_hir = ctx.sema.to_def(&variant)?; let variant_hir = ctx.sema.to_def(&variant)?;
if existing_struct_def(ctx.db(), &variant_name, &variant_hir) { if existing_struct_def(ctx.db(), &variant_name, &variant_hir) {
return None; return None;
@ -62,14 +60,18 @@ pub(crate) fn extract_struct_from_enum_variant(
|builder| { |builder| {
let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)); let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir));
let res = definition.usages(&ctx.sema).all(); let res = definition.usages(&ctx.sema).all();
let start_offset = variant.parent_enum().syntax().text_range().start();
let mut visited_modules_set = FxHashSet::default(); let mut visited_modules_set = FxHashSet::default();
visited_modules_set.insert(current_module); visited_modules_set.insert(current_module);
let mut rewriters = FxHashMap::default();
for reference in res { for reference in res {
let rewriter = rewriters
.entry(reference.file_range.file_id)
.or_insert_with(SyntaxRewriter::default);
let source_file = ctx.sema.parse(reference.file_range.file_id); let source_file = ctx.sema.parse(reference.file_range.file_id);
update_reference( update_reference(
ctx, ctx,
builder, rewriter,
reference, reference,
&source_file, &source_file,
&enum_module_def, &enum_module_def,
@ -77,34 +79,39 @@ pub(crate) fn extract_struct_from_enum_variant(
&mut visited_modules_set, &mut visited_modules_set,
); );
} }
let mut rewriter =
rewriters.remove(&ctx.frange.file_id).unwrap_or_else(SyntaxRewriter::default);
for (file_id, rewriter) in rewriters {
builder.edit_file(file_id);
builder.rewrite(rewriter);
}
builder.edit_file(ctx.frange.file_id);
update_variant(&mut rewriter, &variant_name, &field_list);
extract_struct_def( extract_struct_def(
builder, &mut rewriter,
&enum_ast, &enum_ast,
&variant_name, variant_name.clone(),
&field_list.to_string(), &field_list,
start_offset, &variant.parent_enum().syntax().clone().into(),
ctx.frange.file_id, visibility,
&visibility,
); );
let list_range = field_list.syntax().text_range(); builder.rewrite(rewriter);
update_variant(builder, &variant_name, ctx.frange.file_id, list_range);
}, },
) )
} }
fn existing_struct_def(db: &RootDatabase, variant_name: &str, variant: &EnumVariant) -> bool { fn existing_struct_def(db: &RootDatabase, variant_name: &ast::Name, variant: &EnumVariant) -> bool {
variant variant
.parent_enum(db) .parent_enum(db)
.module(db) .module(db)
.scope(db, None) .scope(db, None)
.into_iter() .into_iter()
.any(|(name, _)| name.to_string() == variant_name) .any(|(name, _)| name == variant_name.as_name())
} }
#[allow(dead_code)]
fn insert_import( fn insert_import(
ctx: &AssistContext, ctx: &AssistContext,
builder: &mut AssistBuilder, rewriter: &mut SyntaxRewriter,
path: &ast::PathExpr, path: &ast::PathExpr,
module: &Module, module: &Module,
enum_module_def: &ModuleDef, enum_module_def: &ModuleDef,
@ -116,69 +123,59 @@ fn insert_import(
mod_path.segments.pop(); mod_path.segments.pop();
mod_path.segments.push(variant_hir_name.clone()); mod_path.segments.push(variant_hir_name.clone());
let scope = ImportScope::find_insert_use_container(path.syntax(), ctx)?; let scope = ImportScope::find_insert_use_container(path.syntax(), ctx)?;
let syntax = scope.as_syntax_node();
let new_syntax = *rewriter += insert_use(&scope, mod_path_to_ast(&mod_path), ctx.config.insert_use.merge);
insert_use(&scope, mod_path_to_ast(&mod_path), ctx.config.insert_use.merge);
// FIXME: this will currently panic as multiple imports will have overlapping text ranges
builder.replace(syntax.text_range(), new_syntax.to_string())
} }
Some(()) Some(())
} }
// FIXME: this should use strongly-typed `make`, rather than string manipulation.
fn extract_struct_def( fn extract_struct_def(
builder: &mut AssistBuilder, rewriter: &mut SyntaxRewriter,
enum_: &ast::Enum, enum_: &ast::Enum,
variant_name: &str, variant_name: ast::Name,
variant_list: &str, variant_list: &ast::TupleFieldList,
start_offset: TextSize, start_offset: &SyntaxElement,
file_id: FileId, visibility: Option<ast::Visibility>,
visibility: &Option<ast::Visibility>,
) -> Option<()> { ) -> Option<()> {
let visibility_string = if let Some(visibility) = visibility { let variant_list = make::tuple_field_list(
format!("{} ", visibility.to_string()) variant_list
} else { .fields()
"".to_string() .flat_map(|field| Some(make::tuple_field(Some(make::visibility_pub()), field.ty()?))),
};
let indent = IndentLevel::from_node(enum_.syntax());
let struct_def = format!(
r#"{}struct {}{};
{}"#,
visibility_string,
variant_name,
list_with_visibility(variant_list),
indent
); );
builder.edit_file(file_id);
builder.insert(start_offset, struct_def); rewriter.insert_before(
start_offset,
make::struct_(visibility, variant_name, None, variant_list.into()).syntax(),
);
rewriter.insert_before(start_offset, &make::tokens::blank_line());
if let indent_level @ 1..=usize::MAX = IndentLevel::from_node(enum_.syntax()).0 as usize {
rewriter
.insert_before(start_offset, &make::tokens::whitespace(&" ".repeat(4 * indent_level)));
}
Some(()) Some(())
} }
fn update_variant( fn update_variant(
builder: &mut AssistBuilder, rewriter: &mut SyntaxRewriter,
variant_name: &str, variant_name: &ast::Name,
file_id: FileId, field_list: &ast::TupleFieldList,
list_range: TextRange,
) -> Option<()> { ) -> Option<()> {
let inside_variant_range = TextRange::new( let (l, r): (SyntaxElement, SyntaxElement) =
list_range.start().checked_add(TextSize::from(1))?, (field_list.l_paren_token()?.into(), field_list.r_paren_token()?.into());
list_range.end().checked_sub(TextSize::from(1))?, let replacement = vec![l, variant_name.syntax().clone().into(), r];
); rewriter.replace_with_many(field_list.syntax(), replacement);
builder.edit_file(file_id);
builder.replace(inside_variant_range, variant_name);
Some(()) Some(())
} }
fn update_reference( fn update_reference(
ctx: &AssistContext, ctx: &AssistContext,
builder: &mut AssistBuilder, rewriter: &mut SyntaxRewriter,
reference: Reference, reference: Reference,
source_file: &SourceFile, source_file: &SourceFile,
_enum_module_def: &ModuleDef, enum_module_def: &ModuleDef,
_variant_hir_name: &Name, variant_hir_name: &Name,
_visited_modules_set: &mut FxHashSet<Module>, visited_modules_set: &mut FxHashSet<Module>,
) -> Option<()> { ) -> Option<()> {
let path_expr: ast::PathExpr = find_node_at_offset::<ast::PathExpr>( let path_expr: ast::PathExpr = find_node_at_offset::<ast::PathExpr>(
source_file.syntax(), source_file.syntax(),
@ -187,35 +184,21 @@ fn update_reference(
let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?; let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?;
let list = call.arg_list()?; let list = call.arg_list()?;
let segment = path_expr.path()?.segment()?; let segment = path_expr.path()?.segment()?;
let _module = ctx.sema.scope(&path_expr.syntax()).module()?; let module = ctx.sema.scope(&path_expr.syntax()).module()?;
let list_range = list.syntax().text_range();
let inside_list_range = TextRange::new(
list_range.start().checked_add(TextSize::from(1))?,
list_range.end().checked_sub(TextSize::from(1))?,
);
builder.edit_file(reference.file_range.file_id);
/* FIXME: this most likely requires AST-based editing, see `insert_import`
if !visited_modules_set.contains(&module) { if !visited_modules_set.contains(&module) {
if insert_import(ctx, builder, &path_expr, &module, enum_module_def, variant_hir_name) if insert_import(ctx, rewriter, &path_expr, &module, enum_module_def, variant_hir_name)
.is_some() .is_some()
{ {
visited_modules_set.insert(module); visited_modules_set.insert(module);
} }
} }
*/
builder.replace(inside_list_range, format!("{}{}", segment, list));
Some(())
}
fn list_with_visibility(list: &str) -> String { let lparen = syntax::SyntaxElement::from(list.l_paren_token()?);
list.split(',') let rparen = syntax::SyntaxElement::from(list.r_paren_token()?);
.map(|part| { rewriter.insert_after(&lparen, segment.syntax());
let index = if part.chars().next().unwrap() == '(' { 1usize } else { 0 }; rewriter.insert_after(&lparen, &lparen);
let mut mod_part = part.trim().to_string(); rewriter.insert_before(&rparen, &rparen);
mod_part.insert_str(index, "pub "); Some(())
mod_part
})
.join(", ")
} }
#[cfg(test)] #[cfg(test)]
@ -250,7 +233,6 @@ pub enum A { One(One) }"#,
} }
#[test] #[test]
#[ignore] // FIXME: this currently panics if `insert_import` is used
fn test_extract_struct_with_complex_imports() { fn test_extract_struct_with_complex_imports() {
check_assist( check_assist(
extract_struct_from_enum_variant, extract_struct_from_enum_variant,

View file

@ -45,10 +45,9 @@ pub(crate) fn replace_qualified_name_with_use(
// affected (that is, all paths inside the node we added the `use` to). // affected (that is, all paths inside the node we added the `use` to).
let mut rewriter = SyntaxRewriter::default(); let mut rewriter = SyntaxRewriter::default();
shorten_paths(&mut rewriter, syntax.clone(), &path); shorten_paths(&mut rewriter, syntax.clone(), &path);
let rewritten_syntax = rewriter.rewrite(&syntax); if let Some(ref import_scope) = ImportScope::from(syntax.clone()) {
if let Some(ref import_scope) = ImportScope::from(rewritten_syntax) { rewriter += insert_use(import_scope, path, ctx.config.insert_use.merge);
let new_syntax = insert_use(import_scope, path, ctx.config.insert_use.merge); builder.rewrite(rewriter);
builder.replace(syntax.text_range(), new_syntax.to_string())
} }
}, },
) )

View file

@ -1,12 +1,9 @@
//! Handle syntactic aspects of inserting a new `use`. //! Handle syntactic aspects of inserting a new `use`.
use std::{ use std::{cmp::Ordering, iter::successors};
cmp::Ordering,
iter::{self, successors},
};
use itertools::{EitherOrBoth, Itertools}; use itertools::{EitherOrBoth, Itertools};
use syntax::{ use syntax::{
algo, algo::SyntaxRewriter,
ast::{ ast::{
self, self,
edit::{AstNodeEdit, IndentLevel}, edit::{AstNodeEdit, IndentLevel},
@ -88,20 +85,19 @@ impl ImportScope {
} }
/// Insert an import path into the given file/node. A `merge` value of none indicates that no import merging is allowed to occur. /// Insert an import path into the given file/node. A `merge` value of none indicates that no import merging is allowed to occur.
pub(crate) fn insert_use( pub(crate) fn insert_use<'a>(
scope: &ImportScope, scope: &ImportScope,
path: ast::Path, path: ast::Path,
merge: Option<MergeBehaviour>, merge: Option<MergeBehaviour>,
) -> SyntaxNode { ) -> SyntaxRewriter<'a> {
let mut rewriter = SyntaxRewriter::default();
let use_item = make::use_(make::use_tree(path.clone(), None, None, false)); let use_item = make::use_(make::use_tree(path.clone(), None, None, false));
// merge into existing imports if possible // merge into existing imports if possible
if let Some(mb) = merge { if let Some(mb) = merge {
for existing_use in scope.as_syntax_node().children().filter_map(ast::Use::cast) { for existing_use in scope.as_syntax_node().children().filter_map(ast::Use::cast) {
if let Some(merged) = try_merge_imports(&existing_use, &use_item, mb) { if let Some(merged) = try_merge_imports(&existing_use, &use_item, mb) {
let to_delete: SyntaxElement = existing_use.syntax().clone().into(); rewriter.replace(existing_use.syntax(), merged.syntax());
let to_delete = to_delete.clone()..=to_delete; return rewriter;
let to_insert = iter::once(merged.syntax().clone().into());
return algo::replace_children(scope.as_syntax_node(), to_delete, to_insert);
} }
} }
} }
@ -157,7 +153,15 @@ pub(crate) fn insert_use(
buf buf
}; };
algo::insert_children(scope.as_syntax_node(), insert_position, to_insert) match insert_position {
InsertPosition::First => {
rewriter.insert_many_as_first_children(scope.as_syntax_node(), to_insert)
}
InsertPosition::Last => return rewriter, // actually unreachable
InsertPosition::Before(anchor) => rewriter.insert_many_before(&anchor, to_insert),
InsertPosition::After(anchor) => rewriter.insert_many_after(&anchor, to_insert),
}
rewriter
} }
fn eq_visibility(vis0: Option<ast::Visibility>, vis1: Option<ast::Visibility>) -> bool { fn eq_visibility(vis0: Option<ast::Visibility>, vis1: Option<ast::Visibility>) -> bool {
@ -1101,7 +1105,8 @@ use foo::bar::baz::Qux;",
.find_map(ast::Path::cast) .find_map(ast::Path::cast)
.unwrap(); .unwrap();
let result = insert_use(&file, path, mb).to_string(); let rewriter = insert_use(&file, path, mb);
let result = rewriter.rewrite(file.as_syntax_node()).to_string();
assert_eq_text!(&result, ra_fixture_after); assert_eq_text!(&result, ra_fixture_after);
} }

View file

@ -351,6 +351,23 @@ pub fn visibility_pub_crate() -> ast::Visibility {
ast_from_text("pub(crate) struct S") ast_from_text("pub(crate) struct S")
} }
pub fn visibility_pub() -> ast::Visibility {
ast_from_text("pub struct S")
}
pub fn tuple_field_list(fields: impl IntoIterator<Item = ast::TupleField>) -> ast::TupleFieldList {
let fields = fields.into_iter().join(", ");
ast_from_text(&format!("struct f({});", fields))
}
pub fn tuple_field(visibility: Option<ast::Visibility>, ty: ast::Type) -> ast::TupleField {
let visibility = match visibility {
None => String::new(),
Some(it) => format!("{} ", it),
};
ast_from_text(&format!("struct f({}{});", visibility, ty))
}
pub fn fn_( pub fn fn_(
visibility: Option<ast::Visibility>, visibility: Option<ast::Visibility>,
fn_name: ast::Name, fn_name: ast::Name,
@ -373,6 +390,26 @@ pub fn fn_(
)) ))
} }
pub fn struct_(
visibility: Option<ast::Visibility>,
strukt_name: ast::Name,
type_params: Option<ast::GenericParamList>,
field_list: ast::FieldList,
) -> ast::Struct {
let semicolon = if matches!(field_list, ast::FieldList::TupleFieldList(_)) { ";" } else { "" };
let type_params =
if let Some(type_params) = type_params { format!("<{}>", type_params) } else { "".into() };
let visibility = match visibility {
None => String::new(),
Some(it) => format!("{} ", it),
};
ast_from_text(&format!(
"{}struct {}{}{}{}",
visibility, strukt_name, type_params, field_list, semicolon
))
}
fn ast_from_text<N: AstNode>(text: &str) -> N { fn ast_from_text<N: AstNode>(text: &str) -> N {
let parse = SourceFile::parse(text); let parse = SourceFile::parse(text);
let node = match parse.tree().syntax().descendants().find_map(N::cast) { let node = match parse.tree().syntax().descendants().find_map(N::cast) {