internal: replace AstTransformer with mutable syntax trees

This commit is contained in:
Aleksey Kladov 2021-05-18 14:42:41 +03:00
parent 3cfe2d0a5d
commit 47d7434dde
6 changed files with 123 additions and 208 deletions

View file

@ -1,31 +1,12 @@
//! `AstTransformer`s are functions that replace nodes in an AST and can be easily combined. //! `AstTransformer`s are functions that replace nodes in an AST and can be easily combined.
use hir::{HirDisplay, PathResolution, SemanticsScope}; use hir::{HirDisplay, SemanticsScope};
use ide_db::helpers::mod_path_to_ast; use ide_db::helpers::mod_path_to_ast;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use syntax::{ use syntax::{
ast::{self, AstNode}, ast::{self, AstNode},
ted, SyntaxNode, ted,
}; };
pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: &N) {
let mut skip_to = None;
for event in node.syntax().preorder() {
match event {
syntax::WalkEvent::Enter(node) if skip_to.is_none() => {
skip_to = transformer.get_substitution(&node, transformer).zip(Some(node));
}
syntax::WalkEvent::Enter(_) => (),
syntax::WalkEvent::Leave(node) => match &skip_to {
Some((replacement, skip_target)) if *skip_target == node => {
ted::replace(node, replacement.clone_for_update());
skip_to.take();
}
_ => (),
},
}
}
}
/// `AstTransform` helps with applying bulk transformations to syntax nodes. /// `AstTransform` helps with applying bulk transformations to syntax nodes.
/// ///
/// This is mostly useful for IDE code generation. If you paste some existing /// This is mostly useful for IDE code generation. If you paste some existing
@ -35,8 +16,8 @@ pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: &N) {
/// ///
/// ``` /// ```
/// mod x { /// mod x {
/// pub struct A; /// pub struct A<V>;
/// pub trait T<U> { fn foo(&self, _: U) -> A; } /// pub trait T<U> { fn foo(&self, _: U) -> A<U>; }
/// } /// }
/// ///
/// mod y { /// mod y {
@ -45,7 +26,7 @@ pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: &N) {
/// impl T<()> for () { /// impl T<()> for () {
/// // If we invoke **Add Missing Members** here, we want to copy-paste `foo`. /// // If we invoke **Add Missing Members** here, we want to copy-paste `foo`.
/// // But we want a slightly-modified version of it: /// // But we want a slightly-modified version of it:
/// fn foo(&self, _: ()) -> x::A {} /// fn foo(&self, _: ()) -> x::A<()> {}
/// } /// }
/// } /// }
/// ``` /// ```
@ -54,49 +35,27 @@ pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: &N) {
/// `SyntaxNode`. Note that the API here is a bit too high-order and high-brow. /// `SyntaxNode`. Note that the API here is a bit too high-order and high-brow.
/// We'd want to somehow express this concept simpler, but so far nobody got to /// We'd want to somehow express this concept simpler, but so far nobody got to
/// simplifying this! /// simplifying this!
pub trait AstTransform<'a> { pub(crate) struct AstTransform<'a> {
fn get_substitution( pub(crate) subst: (hir::Trait, ast::Impl),
&self, pub(crate) target_scope: &'a SemanticsScope<'a>,
node: &SyntaxNode, pub(crate) source_scope: &'a SemanticsScope<'a>,
recur: &dyn AstTransform<'a>, }
) -> Option<SyntaxNode>;
fn or<T: AstTransform<'a> + 'a>(self, other: T) -> Box<dyn AstTransform<'a> + 'a> impl<'a> AstTransform<'a> {
where pub(crate) fn apply(&self, item: ast::AssocItem) {
Self: Sized + 'a, if let Some(ctx) = self.build_ctx() {
{ ctx.apply(item)
Box::new(Or(Box::new(self), Box::new(other))) }
} }
} fn build_ctx(&self) -> Option<Ctx<'a>> {
let db = self.source_scope.db;
let target_module = self.target_scope.module()?;
let source_module = self.source_scope.module()?;
struct Or<'a>(Box<dyn AstTransform<'a> + 'a>, Box<dyn AstTransform<'a> + 'a>); let substs = get_syntactic_substs(self.subst.1.clone()).unwrap_or_default();
let generic_def: hir::GenericDef = self.subst.0.into();
impl<'a> AstTransform<'a> for Or<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
self.0.get_substitution(node, recur).or_else(|| self.1.get_substitution(node, recur))
}
}
pub struct SubstituteTypeParams<'a> {
source_scope: &'a SemanticsScope<'a>,
substs: FxHashMap<hir::TypeParam, ast::Type>,
}
impl<'a> SubstituteTypeParams<'a> {
pub fn for_trait_impl(
source_scope: &'a SemanticsScope<'a>,
// FIXME: there's implicit invariant that `trait_` and `source_scope` match...
trait_: hir::Trait,
impl_def: ast::Impl,
) -> SubstituteTypeParams<'a> {
let substs = get_syntactic_substs(impl_def).unwrap_or_default();
let generic_def: hir::GenericDef = trait_.into();
let substs_by_param: FxHashMap<_, _> = generic_def let substs_by_param: FxHashMap<_, _> = generic_def
.type_params(source_scope.db) .type_params(db)
.into_iter() .into_iter()
// this is a trait impl, so we need to skip the first type parameter -- this is a bit hacky // this is a trait impl, so we need to skip the first type parameter -- this is a bit hacky
.skip(1) .skip(1)
@ -110,109 +69,96 @@ impl<'a> SubstituteTypeParams<'a> {
.filter_map(|(k, v)| match v { .filter_map(|(k, v)| match v {
Some(v) => Some((k, v)), Some(v) => Some((k, v)),
None => { None => {
let default = k.default(source_scope.db)?; let default = k.default(db)?;
Some(( Some((
k, k,
ast::make::ty( ast::make::ty(&default.display_source_code(db, source_module.into()).ok()?),
&default
.display_source_code(source_scope.db, source_scope.module()?.into())
.ok()?,
),
)) ))
} }
}) })
.collect(); .collect();
return SubstituteTypeParams { source_scope, substs: substs_by_param };
// FIXME: It would probably be nicer if we could get this via HIR (i.e. get the let res = Ctx { substs: substs_by_param, target_module, source_scope: self.source_scope };
// trait ref, and then go from the types in the substs back to the syntax). Some(res)
fn get_syntactic_substs(impl_def: ast::Impl) -> Option<Vec<ast::Type>> {
let target_trait = impl_def.trait_()?;
let path_type = match target_trait {
ast::Type::PathType(path) => path,
_ => return None,
};
let generic_arg_list = path_type.path()?.segment()?.generic_arg_list()?;
let mut result = Vec::new();
for generic_arg in generic_arg_list.generic_args() {
match generic_arg {
ast::GenericArg::TypeArg(type_arg) => result.push(type_arg.ty()?),
ast::GenericArg::AssocTypeArg(_)
| ast::GenericArg::LifetimeArg(_)
| ast::GenericArg::ConstArg(_) => (),
}
}
Some(result)
}
} }
} }
impl<'a> AstTransform<'a> for SubstituteTypeParams<'a> { struct Ctx<'a> {
fn get_substitution( substs: FxHashMap<hir::TypeParam, ast::Type>,
&self, target_module: hir::Module,
node: &SyntaxNode,
_recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
let type_ref = ast::Type::cast(node.clone())?;
let path = match &type_ref {
ast::Type::PathType(path_type) => path_type.path()?,
_ => return None,
};
let resolution = self.source_scope.speculative_resolve(&path)?;
match resolution {
hir::PathResolution::TypeParam(tp) => Some(self.substs.get(&tp)?.syntax().clone()),
_ => None,
}
}
}
pub struct QualifyPaths<'a> {
target_scope: &'a SemanticsScope<'a>,
source_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>,
} }
impl<'a> QualifyPaths<'a> { impl<'a> Ctx<'a> {
pub fn new(target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>) -> Self { fn apply(&self, item: ast::AssocItem) {
Self { target_scope, source_scope } for event in item.syntax().preorder() {
let node = match event {
syntax::WalkEvent::Enter(_) => continue,
syntax::WalkEvent::Leave(it) => it,
};
if let Some(path) = ast::Path::cast(node.clone()) {
self.transform_path(path);
}
}
} }
} fn transform_path(&self, path: ast::Path) -> Option<()> {
if path.qualifier().is_some() {
impl<'a> AstTransform<'a> for QualifyPaths<'a> { return None;
fn get_substitution( }
&self, if path.segment().and_then(|s| s.param_list()).is_some() {
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
// FIXME handle value ns?
let from = self.target_scope.module()?;
let p = ast::Path::cast(node.clone())?;
if p.segment().and_then(|s| s.param_list()).is_some() {
// don't try to qualify `Fn(Foo) -> Bar` paths, they are in prelude anyway // don't try to qualify `Fn(Foo) -> Bar` paths, they are in prelude anyway
return None; return None;
} }
let resolution = self.source_scope.speculative_resolve(&p)?;
let resolution = self.source_scope.speculative_resolve(&path)?;
match resolution { match resolution {
PathResolution::Def(def) => { hir::PathResolution::TypeParam(tp) => {
let found_path = from.find_use_path(self.source_scope.db.upcast(), def)?; if let Some(subst) = self.substs.get(&tp) {
let mut path = mod_path_to_ast(&found_path); ted::replace(path.syntax(), subst.clone_subtree().clone_for_update().syntax())
let type_args = p.segment().and_then(|s| s.generic_arg_list());
if let Some(type_args) = type_args {
apply(recur, &type_args);
let last_segment = path.segment().unwrap();
path = path.with_segment(last_segment.with_generic_args(type_args))
} }
Some(path.syntax().clone())
} }
PathResolution::Local(_) hir::PathResolution::Def(def) => {
| PathResolution::TypeParam(_) let found_path =
| PathResolution::SelfType(_) self.target_module.find_use_path(self.source_scope.db.upcast(), def)?;
| PathResolution::ConstParam(_) => None, let res = mod_path_to_ast(&found_path).clone_for_update();
PathResolution::Macro(_) => None, if let Some(args) = path.segment().and_then(|it| it.generic_arg_list()) {
PathResolution::AssocItem(_) => None, if let Some(segment) = res.segment() {
let old = segment.get_or_create_generic_arg_list();
ted::replace(old.syntax(), args.clone_subtree().syntax().clone_for_update())
}
}
ted::replace(path.syntax(), res.syntax())
}
hir::PathResolution::Local(_)
| hir::PathResolution::ConstParam(_)
| hir::PathResolution::SelfType(_)
| hir::PathResolution::Macro(_)
| hir::PathResolution::AssocItem(_) => (),
} }
Some(())
} }
} }
// FIXME: It would probably be nicer if we could get this via HIR (i.e. get the
// trait ref, and then go from the types in the substs back to the syntax).
fn get_syntactic_substs(impl_def: ast::Impl) -> Option<Vec<ast::Type>> {
let target_trait = impl_def.trait_()?;
let path_type = match target_trait {
ast::Type::PathType(path) => path,
_ => return None,
};
let generic_arg_list = path_type.path()?.segment()?.generic_arg_list()?;
let mut result = Vec::new();
for generic_arg in generic_arg_list.generic_args() {
match generic_arg {
ast::GenericArg::TypeArg(type_arg) => result.push(type_arg.ty()?),
ast::GenericArg::AssocTypeArg(_)
| ast::GenericArg::LifetimeArg(_)
| ast::GenericArg::ConstArg(_) => (),
}
}
Some(result)
}

View file

@ -24,7 +24,7 @@ use syntax::{
use crate::{ use crate::{
assist_context::{AssistBuilder, AssistContext}, assist_context::{AssistBuilder, AssistContext},
ast_transform::{self, AstTransform, QualifyPaths, SubstituteTypeParams}, ast_transform::AstTransform,
}; };
pub(crate) fn unwrap_trivial_block(block: ast::BlockExpr) -> ast::Expr { pub(crate) fn unwrap_trivial_block(block: ast::BlockExpr) -> ast::Expr {
@ -132,14 +132,18 @@ pub fn add_trait_assoc_items_to_impl(
target_scope: hir::SemanticsScope, target_scope: hir::SemanticsScope,
) -> (ast::Impl, ast::AssocItem) { ) -> (ast::Impl, ast::AssocItem) {
let source_scope = sema.scope_for_def(trait_); let source_scope = sema.scope_for_def(trait_);
let ast_transform = QualifyPaths::new(&target_scope, &source_scope)
.or(SubstituteTypeParams::for_trait_impl(&source_scope, trait_, impl_.clone()));
let items = items let transform = AstTransform {
.into_iter() subst: (trait_, impl_.clone()),
.map(|it| it.clone_for_update()) source_scope: &source_scope,
.inspect(|it| ast_transform::apply(&*ast_transform, it)) target_scope: &target_scope,
.map(|it| edit::remove_attrs_and_docs(&it).clone_subtree().clone_for_update()); };
let items = items.into_iter().map(|assoc_item| {
let assoc_item = assoc_item.clone_for_update();
transform.apply(assoc_item.clone());
edit::remove_attrs_and_docs(&assoc_item).clone_subtree().clone_for_update()
});
let res = impl_.clone_for_update(); let res = impl_.clone_for_update();

View file

@ -6,14 +6,12 @@ use std::{
ops::{self, RangeInclusive}, ops::{self, RangeInclusive},
}; };
use arrayvec::ArrayVec;
use crate::{ use crate::{
algo, algo,
ast::{self, make, AstNode}, ast::{self, make, AstNode},
ted, AstToken, InsertPosition, NodeOrToken, SyntaxElement, SyntaxKind, ted, AstToken, NodeOrToken, SyntaxElement, SyntaxKind,
SyntaxKind::{ATTR, COMMENT, WHITESPACE}, SyntaxKind::{ATTR, COMMENT, WHITESPACE},
SyntaxNode, SyntaxToken, T, SyntaxNode, SyntaxToken,
}; };
impl ast::BinExpr { impl ast::BinExpr {
@ -25,46 +23,6 @@ impl ast::BinExpr {
} }
} }
impl ast::Path {
#[must_use]
pub fn with_segment(&self, segment: ast::PathSegment) -> ast::Path {
if let Some(old) = self.segment() {
return self.replace_children(
single_node(old.syntax().clone()),
iter::once(segment.syntax().clone().into()),
);
}
self.clone()
}
}
impl ast::PathSegment {
#[must_use]
pub fn with_generic_args(&self, type_args: ast::GenericArgList) -> ast::PathSegment {
self._with_generic_args(type_args, false)
}
#[must_use]
pub fn with_turbo_fish(&self, type_args: ast::GenericArgList) -> ast::PathSegment {
self._with_generic_args(type_args, true)
}
fn _with_generic_args(&self, type_args: ast::GenericArgList, turbo: bool) -> ast::PathSegment {
if let Some(old) = self.generic_arg_list() {
return self.replace_children(
single_node(old.syntax().clone()),
iter::once(type_args.syntax().clone().into()),
);
}
let mut to_insert: ArrayVec<SyntaxElement, 2> = ArrayVec::new();
if turbo {
to_insert.push(make::token(T![::]).into());
}
to_insert.push(type_args.syntax().clone().into());
self.insert_children(InsertPosition::Last, to_insert)
}
}
impl ast::UseTree { impl ast::UseTree {
/// Splits off the given prefix, making it the path component of the use tree, appending the rest of the path to all UseTreeList items. /// Splits off the given prefix, making it the path component of the use tree, appending the rest of the path to all UseTreeList items.
#[must_use] #[must_use]
@ -233,16 +191,6 @@ fn prev_tokens(token: SyntaxToken) -> impl Iterator<Item = SyntaxToken> {
} }
pub trait AstNodeEdit: AstNode + Clone + Sized { pub trait AstNodeEdit: AstNode + Clone + Sized {
#[must_use]
fn insert_children(
&self,
position: InsertPosition<SyntaxElement>,
to_insert: impl IntoIterator<Item = SyntaxElement>,
) -> Self {
let new_syntax = algo::insert_children(self.syntax(), position, to_insert);
Self::cast(new_syntax).unwrap()
}
#[must_use] #[must_use]
fn replace_children( fn replace_children(
&self, &self,

View file

@ -239,6 +239,16 @@ impl ast::TypeBoundList {
} }
} }
impl ast::PathSegment {
pub fn get_or_create_generic_arg_list(&self) -> ast::GenericArgList {
if self.generic_arg_list().is_none() {
let arg_list = make::generic_arg_list().clone_for_update();
ted::append_child(self.syntax(), arg_list.syntax())
}
self.generic_arg_list().unwrap()
}
}
impl ast::UseTree { impl ast::UseTree {
pub fn remove(&self) { pub fn remove(&self) {
for &dir in [Direction::Next, Direction::Prev].iter() { for &dir in [Direction::Next, Direction::Prev].iter() {

View file

@ -106,6 +106,10 @@ pub fn impl_trait(trait_: ast::Path, ty: ast::Path) -> ast::Impl {
ast_from_text(&format!("impl {} for {} {{}}", trait_, ty)) ast_from_text(&format!("impl {} for {} {{}}", trait_, ty))
} }
pub(crate) fn generic_arg_list() -> ast::GenericArgList {
ast_from_text("const S: T<> = ();")
}
pub fn path_segment(name_ref: ast::NameRef) -> ast::PathSegment { pub fn path_segment(name_ref: ast::NameRef) -> ast::PathSegment {
ast_from_text(&format!("use {};", name_ref)) ast_from_text(&format!("use {};", name_ref))
} }

View file

@ -184,6 +184,9 @@ fn ws_between(left: &SyntaxElement, right: &SyntaxElement) -> Option<SyntaxToken
if left.kind() == T![&] && right.kind() == SyntaxKind::LIFETIME { if left.kind() == T![&] && right.kind() == SyntaxKind::LIFETIME {
return None; return None;
} }
if right.kind() == SyntaxKind::GENERIC_ARG_LIST {
return None;
}
if right.kind() == SyntaxKind::USE { if right.kind() == SyntaxKind::USE {
let mut indent = IndentLevel::from_element(left); let mut indent = IndentLevel::from_element(left);