8201: Fix recursive macro statements expansion r=edwin0cheng a=edwin0cheng

This PR attempts to properly handle macro statement expansion by implementing the following:

1.  Merge macro expanded statements to parent scope statements.
2.  Add a new hir `Expr::MacroStmts` for handle tail expression infer.

PS : The scope of macro expanded statements are so strange that it took more time than I thought to understand and implement it :(

Fixes  #8171



Co-authored-by: Edwin Cheng <edwin0cheng@gmail.com>
This commit is contained in:
bors[bot] 2021-03-27 02:57:02 +00:00 committed by GitHub
commit c8066ebd17
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 119 additions and 70 deletions

View file

@ -74,6 +74,7 @@ pub(super) fn lower(
_c: Count::new(), _c: Count::new(),
}, },
expander, expander,
statements_in_scope: Vec::new(),
} }
.collect(params, body) .collect(params, body)
} }
@ -83,6 +84,7 @@ struct ExprCollector<'a> {
expander: Expander, expander: Expander,
body: Body, body: Body,
source_map: BodySourceMap, source_map: BodySourceMap,
statements_in_scope: Vec<Statement>,
} }
impl ExprCollector<'_> { impl ExprCollector<'_> {
@ -533,15 +535,13 @@ impl ExprCollector<'_> {
ids[0] ids[0]
} }
ast::Expr::MacroStmts(e) => { ast::Expr::MacroStmts(e) => {
// FIXME: these statements should be held by some hir containter e.statements().for_each(|s| self.collect_stmt(s));
for stmt in e.statements() { let tail = e
self.collect_stmt(stmt); .expr()
} .map(|e| self.collect_expr(e))
if let Some(expr) = e.expr() { .unwrap_or_else(|| self.alloc_expr(Expr::Missing, syntax_ptr.clone()));
self.collect_expr(expr)
} else { self.alloc_expr(Expr::MacroStmts { tail }, syntax_ptr)
self.alloc_expr(Expr::Missing, syntax_ptr)
}
} }
}) })
} }
@ -618,58 +618,54 @@ impl ExprCollector<'_> {
} }
} }
fn collect_stmt(&mut self, s: ast::Stmt) -> Option<Vec<Statement>> { fn collect_stmt(&mut self, s: ast::Stmt) {
let stmt = match s { match s {
ast::Stmt::LetStmt(stmt) => { ast::Stmt::LetStmt(stmt) => {
self.check_cfg(&stmt)?; if self.check_cfg(&stmt).is_none() {
return;
}
let pat = self.collect_pat_opt(stmt.pat()); let pat = self.collect_pat_opt(stmt.pat());
let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it)); let type_ref = stmt.ty().map(|it| TypeRef::from_ast(&self.ctx(), it));
let initializer = stmt.initializer().map(|e| self.collect_expr(e)); let initializer = stmt.initializer().map(|e| self.collect_expr(e));
vec![Statement::Let { pat, type_ref, initializer }] self.statements_in_scope.push(Statement::Let { pat, type_ref, initializer });
} }
ast::Stmt::ExprStmt(stmt) => { ast::Stmt::ExprStmt(stmt) => {
self.check_cfg(&stmt)?; if self.check_cfg(&stmt).is_none() {
return;
}
// Note that macro could be expended to multiple statements // Note that macro could be expended to multiple statements
if let Some(ast::Expr::MacroCall(m)) = stmt.expr() { if let Some(ast::Expr::MacroCall(m)) = stmt.expr() {
let syntax_ptr = AstPtr::new(&stmt.expr().unwrap()); let syntax_ptr = AstPtr::new(&stmt.expr().unwrap());
let mut stmts = vec![];
self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| { self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| {
match expansion { match expansion {
Some(expansion) => { Some(expansion) => {
let statements: ast::MacroStmts = expansion; let statements: ast::MacroStmts = expansion;
statements.statements().for_each(|stmt| { statements.statements().for_each(|stmt| this.collect_stmt(stmt));
if let Some(mut r) = this.collect_stmt(stmt) {
stmts.append(&mut r);
}
});
if let Some(expr) = statements.expr() { if let Some(expr) = statements.expr() {
stmts.push(Statement::Expr(this.collect_expr(expr))); let expr = this.collect_expr(expr);
this.statements_in_scope.push(Statement::Expr(expr));
} }
} }
None => { None => {
stmts.push(Statement::Expr( let expr = this.alloc_expr(Expr::Missing, syntax_ptr.clone());
this.alloc_expr(Expr::Missing, syntax_ptr.clone()), this.statements_in_scope.push(Statement::Expr(expr));
));
} }
} }
}); });
stmts
} else { } else {
vec![Statement::Expr(self.collect_expr_opt(stmt.expr()))] let expr = self.collect_expr_opt(stmt.expr());
self.statements_in_scope.push(Statement::Expr(expr));
} }
} }
ast::Stmt::Item(item) => { ast::Stmt::Item(item) => {
self.check_cfg(&item)?; if self.check_cfg(&item).is_none() {
return;
return None; }
}
} }
};
Some(stmt)
} }
fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId { fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId {
@ -685,10 +681,12 @@ impl ExprCollector<'_> {
let module = if has_def_map { def_map.root() } else { self.expander.module }; let module = if has_def_map { def_map.root() } else { self.expander.module };
let prev_def_map = mem::replace(&mut self.expander.def_map, def_map); let prev_def_map = mem::replace(&mut self.expander.def_map, def_map);
let prev_local_module = mem::replace(&mut self.expander.module, module); let prev_local_module = mem::replace(&mut self.expander.module, module);
let prev_statements = std::mem::take(&mut self.statements_in_scope);
block.statements().for_each(|s| self.collect_stmt(s));
let statements =
block.statements().filter_map(|s| self.collect_stmt(s)).flatten().collect();
let tail = block.tail_expr().map(|e| self.collect_expr(e)); let tail = block.tail_expr().map(|e| self.collect_expr(e));
let statements = std::mem::replace(&mut self.statements_in_scope, prev_statements);
let syntax_node_ptr = AstPtr::new(&block.into()); let syntax_node_ptr = AstPtr::new(&block.into());
let expr_id = self.alloc_expr( let expr_id = self.alloc_expr(
Expr::Block { id: block_id, statements, tail, label: None }, Expr::Block { id: block_id, statements, tail, label: None },

View file

@ -171,6 +171,9 @@ pub enum Expr {
Unsafe { Unsafe {
body: ExprId, body: ExprId,
}, },
MacroStmts {
tail: ExprId,
},
Array(Array), Array(Array),
Literal(Literal), Literal(Literal),
} }
@ -357,6 +360,7 @@ impl Expr {
f(*repeat) f(*repeat)
} }
}, },
Expr::MacroStmts { tail } => f(*tail),
Expr::Literal(_) => {} Expr::Literal(_) => {}
} }
} }

View file

@ -110,15 +110,6 @@ impl ItemTree {
// still need to collect inner items. // still need to collect inner items.
ctx.lower_inner_items(e.syntax()) ctx.lower_inner_items(e.syntax())
}, },
ast::ExprStmt(stmt) => {
// Macros can expand to stmt. We return an empty item tree in this case, but
// still need to collect inner items.
ctx.lower_inner_items(stmt.syntax())
},
ast::Item(item) => {
// Macros can expand to stmt and other item, and we add it as top level item
ctx.lower_single_item(item)
},
_ => { _ => {
panic!("cannot create item tree from {:?} {}", syntax, syntax); panic!("cannot create item tree from {:?} {}", syntax, syntax);
}, },

View file

@ -87,14 +87,6 @@ impl Ctx {
self.tree self.tree
} }
pub(super) fn lower_single_item(mut self, item: ast::Item) -> ItemTree {
self.tree.top_level = self
.lower_mod_item(&item, false)
.map(|item| item.0)
.unwrap_or_else(|| Default::default());
self.tree
}
pub(super) fn lower_inner_items(mut self, within: &SyntaxNode) -> ItemTree { pub(super) fn lower_inner_items(mut self, within: &SyntaxNode) -> ItemTree {
self.collect_inner_items(within); self.collect_inner_items(within);
self.tree self.tree

View file

@ -5,7 +5,13 @@ use std::sync::Arc;
use base_db::{salsa, SourceDatabase}; use base_db::{salsa, SourceDatabase};
use mbe::{ExpandError, ExpandResult, MacroRules}; use mbe::{ExpandError, ExpandResult, MacroRules};
use parser::FragmentKind; use parser::FragmentKind;
use syntax::{algo::diff, ast::NameOwner, AstNode, GreenNode, Parse, SyntaxKind::*, SyntaxNode}; use syntax::{
algo::diff,
ast::{MacroStmts, NameOwner},
AstNode, GreenNode, Parse,
SyntaxKind::*,
SyntaxNode,
};
use crate::{ use crate::{
ast_id_map::AstIdMap, hygiene::HygieneFrame, BuiltinDeriveExpander, BuiltinFnLikeExpander, ast_id_map::AstIdMap, hygiene::HygieneFrame, BuiltinDeriveExpander, BuiltinFnLikeExpander,
@ -340,13 +346,19 @@ fn parse_macro_with_arg(
None => return ExpandResult { value: None, err: result.err }, None => return ExpandResult { value: None, err: result.err },
}; };
log::debug!("expanded = {}", tt.as_debug_string());
let fragment_kind = to_fragment_kind(db, macro_call_id); let fragment_kind = to_fragment_kind(db, macro_call_id);
log::debug!("expanded = {}", tt.as_debug_string());
log::debug!("kind = {:?}", fragment_kind);
let (parse, rev_token_map) = match mbe::token_tree_to_syntax_node(&tt, fragment_kind) { let (parse, rev_token_map) = match mbe::token_tree_to_syntax_node(&tt, fragment_kind) {
Ok(it) => it, Ok(it) => it,
Err(err) => { Err(err) => {
log::debug!(
"failed to parse expanstion to {:?} = {}",
fragment_kind,
tt.as_debug_string()
);
return ExpandResult::only_err(err); return ExpandResult::only_err(err);
} }
}; };
@ -362,16 +374,35 @@ fn parse_macro_with_arg(
return ExpandResult::only_err(err); return ExpandResult::only_err(err);
} }
}; };
if is_self_replicating(&node, &call_node.value) {
if !diff(&node, &call_node.value).is_empty() {
ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: Some(err) }
} else {
return ExpandResult::only_err(err); return ExpandResult::only_err(err);
} else {
ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: Some(err) }
} }
} }
None => ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: None }, None => {
log::debug!("parse = {:?}", parse.syntax_node().kind());
ExpandResult { value: Some((parse, Arc::new(rev_token_map))), err: None }
} }
} }
}
fn is_self_replicating(from: &SyntaxNode, to: &SyntaxNode) -> bool {
if diff(from, to).is_empty() {
return true;
}
if let Some(stmts) = MacroStmts::cast(from.clone()) {
if stmts.statements().any(|stmt| diff(stmt.syntax(), to).is_empty()) {
return true;
}
if let Some(expr) = stmts.expr() {
if diff(expr.syntax(), to).is_empty() {
return true;
}
}
}
false
}
fn hygiene_frame(db: &dyn AstDatabase, file_id: HirFileId) -> Arc<HygieneFrame> { fn hygiene_frame(db: &dyn AstDatabase, file_id: HirFileId) -> Arc<HygieneFrame> {
Arc::new(HygieneFrame::new(db, file_id)) Arc::new(HygieneFrame::new(db, file_id))
@ -390,21 +421,15 @@ fn to_fragment_kind(db: &dyn AstDatabase, id: MacroCallId) -> FragmentKind {
let parent = match syn.parent() { let parent = match syn.parent() {
Some(it) => it, Some(it) => it,
None => { None => return FragmentKind::Statements,
// FIXME:
// If it is root, which means the parent HirFile
// MacroKindFile must be non-items
// return expr now.
return FragmentKind::Expr;
}
}; };
match parent.kind() { match parent.kind() {
MACRO_ITEMS | SOURCE_FILE => FragmentKind::Items, MACRO_ITEMS | SOURCE_FILE => FragmentKind::Items,
MACRO_STMTS => FragmentKind::Statement, MACRO_STMTS => FragmentKind::Statements,
ITEM_LIST => FragmentKind::Items, ITEM_LIST => FragmentKind::Items,
LET_STMT => { LET_STMT => {
// FIXME: Handle Pattern // FIXME: Handle LHS Pattern
FragmentKind::Expr FragmentKind::Expr
} }
EXPR_STMT => FragmentKind::Statements, EXPR_STMT => FragmentKind::Statements,

View file

@ -767,6 +767,7 @@ impl<'a> InferenceContext<'a> {
None => self.table.new_float_var(), None => self.table.new_float_var(),
}, },
}, },
Expr::MacroStmts { tail } => self.infer_expr(*tail, expected),
}; };
// use a new type variable if we got unknown here // use a new type variable if we got unknown here
let ty = self.insert_type_vars_shallow(ty); let ty = self.insert_type_vars_shallow(ty);

View file

@ -226,11 +226,48 @@ fn expr_macro_expanded_in_stmts() {
"#, "#,
expect![[r#" expect![[r#"
!0..8 'leta=();': () !0..8 'leta=();': ()
!0..8 'leta=();': ()
!3..4 'a': ()
!5..7 '()': ()
57..84 '{ ...); } }': () 57..84 '{ ...); } }': ()
"#]], "#]],
); );
} }
#[test]
fn recurisve_macro_expanded_in_stmts() {
check_infer(
r#"
macro_rules! ng {
([$($tts:tt)*]) => {
$($tts)*;
};
([$($tts:tt)*] $head:tt $($rest:tt)*) => {
ng! {
[$($tts)* $head] $($rest)*
}
};
}
fn foo() {
ng!([] let a = 3);
let b = a;
}
"#,
expect![[r#"
!0..7 'leta=3;': {unknown}
!0..7 'leta=3;': {unknown}
!0..13 'ng!{[leta=3]}': {unknown}
!0..13 'ng!{[leta=]3}': {unknown}
!0..13 'ng!{[leta]=3}': {unknown}
!3..4 'a': i32
!5..6 '3': i32
196..237 '{ ...= a; }': ()
229..230 'b': i32
233..234 'a': i32
"#]],
);
}
#[test] #[test]
fn recursive_inner_item_macro_rules() { fn recursive_inner_item_macro_rules() {
check_infer( check_infer(
@ -246,7 +283,8 @@ fn recursive_inner_item_macro_rules() {
"#, "#,
expect![[r#" expect![[r#"
!0..1 '1': i32 !0..1 '1': i32
!0..7 'mac!($)': {unknown} !0..26 'macro_...>{1};}': {unknown}
!0..26 'macro_...>{1};}': {unknown}
107..143 '{ ...!(); }': () 107..143 '{ ...!(); }': ()
129..130 'a': i32 129..130 'a': i32
"#]], "#]],