Add mbe expand limit and poison macro set

This commit is contained in:
Edwin Cheng 2019-04-22 15:33:55 +08:00
parent bbc5c1d24e
commit b177813f3b
5 changed files with 216 additions and 18 deletions

View file

@ -94,6 +94,13 @@ fn parse_macro(
let macro_rules = db.macro_def(loc.def).ok_or("Fail to find macro definition")?; let macro_rules = db.macro_def(loc.def).ok_or("Fail to find macro definition")?;
let tt = macro_rules.expand(&macro_arg).map_err(|err| format!("{:?}", err))?; let tt = macro_rules.expand(&macro_arg).map_err(|err| format!("{:?}", err))?;
// Set a hard limit for the expanded tt
let count = tt.count();
if count > 65536 {
return Err(format!("Total tokens count exceed limit : count = {}", count));
}
Ok(mbe::token_tree_to_ast_item_list(&tt)) Ok(mbe::token_tree_to_ast_item_list(&tt))
} }

View file

@ -55,7 +55,7 @@ mod tests;
use std::sync::Arc; use std::sync::Arc;
use rustc_hash::FxHashMap; use rustc_hash::{FxHashMap, FxHashSet};
use ra_arena::{Arena, RawId, impl_arena_id}; use ra_arena::{Arena, RawId, impl_arena_id};
use ra_db::{FileId, Edition}; use ra_db::{FileId, Edition};
use test_utils::tested_by; use test_utils::tested_by;
@ -91,6 +91,19 @@ pub struct CrateDefMap {
root: CrateModuleId, root: CrateModuleId,
modules: Arena<CrateModuleId, ModuleData>, modules: Arena<CrateModuleId, ModuleData>,
public_macros: FxHashMap<Name, MacroDefId>, public_macros: FxHashMap<Name, MacroDefId>,
/// Some macros are not well-behaved, which leads to an infinite loop,
/// e.g. macro_rules! foo { ($ty:ty) => { foo!($ty); } }
/// We mark such a macro as poisoned and skip it in the collector.
///
/// FIXME:
/// Right now this only handles a poisoned macro within a single crate,
/// such that if another crate tries to call that macro,
/// the whole process repeats until the macro becomes poisoned in that crate too.
/// We should handle this macro set globally.
/// However, do we want to put it in a global variable?
poison_macros: FxHashSet<MacroDefId>,
diagnostics: Vec<DefDiagnostic>, diagnostics: Vec<DefDiagnostic>,
} }
@ -195,6 +208,7 @@ impl CrateDefMap {
root, root,
modules, modules,
public_macros: FxHashMap::default(), public_macros: FxHashMap::default(),
poison_macros: FxHashSet::default(),
diagnostics: Vec::new(), diagnostics: Vec::new(),
} }
}; };

View file

@ -42,14 +42,40 @@ pub(super) fn collect_defs(db: &impl DefDatabase, mut def_map: CrateDefMap) -> C
unresolved_imports: Vec::new(), unresolved_imports: Vec::new(),
unexpanded_macros: Vec::new(), unexpanded_macros: Vec::new(),
global_macro_scope: FxHashMap::default(), global_macro_scope: FxHashMap::default(),
marco_stack_count: 0, macro_stack_monitor: SimpleMacroStackMonitor::default(),
}; };
collector.collect(); collector.collect();
collector.finish() collector.finish()
} }
/// Hooks that let the def collector track how deeply macro expansion has
/// recursed per macro definition, and decide when a macro should be
/// treated as poisoned (i.e. skipped to avoid runaway expansion).
trait MacroStackMonitor {
    /// Called when an expansion of `macro_def_id` is pushed on the stack.
    fn increase(&mut self, macro_def_id: MacroDefId);
    /// Called when that expansion finishes and is popped off the stack.
    fn decrease(&mut self, macro_def_id: MacroDefId);
    /// Whether further expansion of `macro_def_id` should be skipped.
    fn is_poison(&self, macro_def_id: MacroDefId) -> bool;
}
/// Default production monitor: keeps one active-expansion counter per
/// macro definition and poisons a macro past a fixed depth.
#[derive(Default)]
struct SimpleMacroStackMonitor {
    /// Number of currently-active expansions of each macro definition.
    counts: FxHashMap<MacroDefId, u32>,
}
impl MacroStackMonitor for SimpleMacroStackMonitor {
    /// Record one more active expansion of `macro_def_id`.
    fn increase(&mut self, macro_def_id: MacroDefId) {
        *self.counts.entry(macro_def_id).or_default() += 1;
    }

    /// Record that one expansion of `macro_def_id` has finished.
    ///
    /// Saturates at zero so an unmatched `decrease` (a decrease without a
    /// prior matching increase) cannot underflow the `u32` counter —
    /// underflow would panic in debug builds and wrap to a huge count
    /// (permanently "poisoning" the macro) in release builds.
    fn decrease(&mut self, macro_def_id: MacroDefId) {
        let count = self.counts.entry(macro_def_id).or_default();
        *count = count.saturating_sub(1);
    }

    /// A macro is poisoned once it has more than 100 simultaneously
    /// active expansions on the stack (runaway recursion).
    fn is_poison(&self, macro_def_id: MacroDefId) -> bool {
        self.counts.get(&macro_def_id).map_or(false, |&count| count > 100)
    }
}
/// Walks the tree of module recursively /// Walks the tree of module recursively
struct DefCollector<DB> { struct DefCollector<DB, M> {
db: DB, db: DB,
def_map: CrateDefMap, def_map: CrateDefMap,
glob_imports: FxHashMap<CrateModuleId, Vec<(CrateModuleId, raw::ImportId)>>, glob_imports: FxHashMap<CrateModuleId, Vec<(CrateModuleId, raw::ImportId)>>,
@ -59,12 +85,13 @@ struct DefCollector<DB> {
/// Some macro use `$tt:tt which mean we have to handle the macro perfectly /// Some macro use `$tt:tt which mean we have to handle the macro perfectly
/// To prevent stackoverflow, we add a deep counter here for prevent that. /// To prevent stackoverflow, we add a deep counter here for prevent that.
marco_stack_count: u32, macro_stack_monitor: M,
} }
impl<'a, DB> DefCollector<&'a DB> impl<'a, DB, M> DefCollector<&'a DB, M>
where where
DB: DefDatabase, DB: DefDatabase,
M: MacroStackMonitor,
{ {
fn collect(&mut self) { fn collect(&mut self) {
let crate_graph = self.db.crate_graph(); let crate_graph = self.db.crate_graph();
@ -317,30 +344,40 @@ where
let def_map = self.db.crate_def_map(krate); let def_map = self.db.crate_def_map(krate);
if let Some(macro_id) = def_map.public_macros.get(&path.segments[1].name).cloned() { if let Some(macro_id) = def_map.public_macros.get(&path.segments[1].name).cloned() {
let call_id = MacroCallLoc { def: macro_id, ast_id: *ast_id }.id(self.db); let call_id = MacroCallLoc { def: macro_id, ast_id: *ast_id }.id(self.db);
resolved.push((*module_id, call_id)); resolved.push((*module_id, call_id, macro_id));
} }
false false
}); });
for (module_id, macro_call_id) in resolved { for (module_id, macro_call_id, macro_def_id) in resolved {
self.collect_macro_expansion(module_id, macro_call_id); self.collect_macro_expansion(module_id, macro_call_id, macro_def_id);
} }
res res
} }
fn collect_macro_expansion(&mut self, module_id: CrateModuleId, macro_call_id: MacroCallId) { fn collect_macro_expansion(
self.marco_stack_count += 1; &mut self,
module_id: CrateModuleId,
macro_call_id: MacroCallId,
macro_def_id: MacroDefId,
) {
if self.def_map.poison_macros.contains(&macro_def_id) {
return;
}
if self.marco_stack_count < 300 { self.macro_stack_monitor.increase(macro_def_id);
if !self.macro_stack_monitor.is_poison(macro_def_id) {
let file_id: HirFileId = macro_call_id.into(); let file_id: HirFileId = macro_call_id.into();
let raw_items = self.db.raw_items(file_id); let raw_items = self.db.raw_items(file_id);
ModCollector { def_collector: &mut *self, file_id, module_id, raw_items: &raw_items } ModCollector { def_collector: &mut *self, file_id, module_id, raw_items: &raw_items }
.collect(raw_items.items()) .collect(raw_items.items());
} else { } else {
log::error!("Too deep macro expansion: {}", macro_call_id.debug_dump(self.db)); log::error!("Too deep macro expansion: {}", macro_call_id.debug_dump(self.db));
self.def_map.poison_macros.insert(macro_def_id);
} }
self.marco_stack_count -= 1; self.macro_stack_monitor.decrease(macro_def_id);
} }
fn finish(self) -> CrateDefMap { fn finish(self) -> CrateDefMap {
@ -356,9 +393,10 @@ struct ModCollector<'a, D> {
raw_items: &'a raw::RawItems, raw_items: &'a raw::RawItems,
} }
impl<DB> ModCollector<'_, &'_ mut DefCollector<&'_ DB>> impl<DB, M> ModCollector<'_, &'_ mut DefCollector<&'_ DB, M>>
where where
DB: DefDatabase, DB: DefDatabase,
M: MacroStackMonitor,
{ {
fn collect(&mut self, items: &[raw::RawItem]) { fn collect(&mut self, items: &[raw::RawItem]) {
for item in items { for item in items {
@ -484,7 +522,7 @@ where
{ {
let macro_call_id = MacroCallLoc { def: macro_id, ast_id }.id(self.def_collector.db); let macro_call_id = MacroCallLoc { def: macro_id, ast_id }.id(self.def_collector.db);
self.def_collector.collect_macro_expansion(self.module_id, macro_call_id); self.def_collector.collect_macro_expansion(self.module_id, macro_call_id, macro_id);
return; return;
} }
@ -530,3 +568,123 @@ fn resolve_submodule(
None => Err(if is_dir_owner { file_mod } else { file_dir_mod }), None => Err(if is_dir_owner { file_mod } else { file_dir_mod }),
} }
} }
#[cfg(test)]
mod tests {
use ra_db::SourceDatabase;
use crate::{Crate, mock::MockDatabase, DefDatabase};
use ra_arena::{Arena};
use super::*;
use rustc_hash::FxHashSet;
/// Test-only monitor: asserts the expansion stack never reaches `limit`
/// and reports every macro as poisoned once the depth reaches
/// `poison_limit`.
///
/// NOTE(review): unlike `SimpleMacroStackMonitor`, this tracks a single
/// global stack depth rather than a per-macro count.
struct LimitedMacroStackMonitor {
// Current total expansion stack depth (all macros combined).
count: u32,
// Hard ceiling: `increase` asserts the depth stays below this.
limit: u32,
// Depth at or above which `is_poison` reports true.
poison_limit: u32,
}
impl MacroStackMonitor for LimitedMacroStackMonitor {
fn increase(&mut self, _: MacroDefId) {
self.count += 1;
// Fail the test outright if expansion goes deeper than allowed.
assert!(self.count < self.limit);
}
fn decrease(&mut self, _: MacroDefId) {
self.count -= 1;
}
fn is_poison(&self, _: MacroDefId) -> bool {
self.count >= self.poison_limit
}
}
/// Run def collection over `def_map` with a caller-supplied stack
/// monitor, mirroring `collect_defs` but with an injectable monitor.
fn do_collect_defs(
db: &impl DefDatabase,
def_map: CrateDefMap,
monitor: impl MacroStackMonitor,
) -> CrateDefMap {
let mut collector = DefCollector {
db,
def_map,
glob_imports: FxHashMap::default(),
unresolved_imports: Vec::new(),
unexpanded_macros: Vec::new(),
global_macro_scope: FxHashMap::default(),
macro_stack_monitor: monitor,
};
collector.collect();
collector.finish()
}
/// Resolve `code` in a single-file mock crate, driving macro expansion
/// through a `LimitedMacroStackMonitor` with the given limits.
fn do_limited_resolve(code: &str, limit: u32, poison_limit: u32) -> CrateDefMap {
let (db, _source_root, _) = MockDatabase::with_single_file(&code);
let crate_id = db.crate_graph().iter().next().unwrap();
let krate = Crate { crate_id };
// Build an empty CrateDefMap by hand so collection can run with a
// custom monitor instead of the default one.
let def_map = {
let edition = krate.edition(&db);
let mut modules: Arena<CrateModuleId, ModuleData> = Arena::default();
let root = modules.alloc(ModuleData::default());
CrateDefMap {
krate,
edition,
extern_prelude: FxHashMap::default(),
prelude: None,
root,
modules,
public_macros: FxHashMap::default(),
poison_macros: FxHashSet::default(),
diagnostics: Vec::new(),
}
};
do_collect_defs(&db, def_map, LimitedMacroStackMonitor { count: 0, limit, poison_limit })
}
// The macro doubles its token list on every expansion, so the token
// count explodes; the depth limit must stop it before the assert fires.
#[test]
fn test_macro_expand_limit_width() {
do_limited_resolve(
r#"
macro_rules! foo {
($($ty:ty)*) => { foo!($($ty)*, $($ty)*); }
}
foo!(KABOOM);
"#,
16,
1000,
);
}
// Infinitely self-recursive macro: expansion hits the poison limit and
// the macro must end up in `poison_macros`.
#[test]
fn test_macro_expand_poisoned() {
let def = do_limited_resolve(
r#"
macro_rules! foo {
($ty:ty) => { foo!($ty); }
}
foo!(KABOOM);
"#,
100,
16,
);
assert_eq!(def.poison_macros.len(), 1);
}
// A well-behaved macro expands once and must not be poisoned.
#[test]
fn test_macro_expand_normal() {
let def = do_limited_resolve(
r#"
macro_rules! foo {
($ident:ident) => { struct $ident {} }
}
foo!(Bar);
"#,
16,
16,
);
assert_eq!(def.poison_macros.len(), 0);
}
}

View file

@ -5,6 +5,7 @@ use ra_syntax::{SyntaxKind};
struct OffsetTokenSink { struct OffsetTokenSink {
token_pos: usize, token_pos: usize,
error: bool,
} }
impl TreeSink for OffsetTokenSink { impl TreeSink for OffsetTokenSink {
@ -13,7 +14,9 @@ impl TreeSink for OffsetTokenSink {
} }
fn start_node(&mut self, _kind: SyntaxKind) {} fn start_node(&mut self, _kind: SyntaxKind) {}
fn finish_node(&mut self) {} fn finish_node(&mut self) {}
fn error(&mut self, _error: ra_parser::ParseError) {} fn error(&mut self, _error: ra_parser::ParseError) {
self.error = true;
}
} }
pub(crate) struct Parser<'a> { pub(crate) struct Parser<'a> {
@ -67,11 +70,15 @@ impl<'a> Parser<'a> {
F: FnOnce(&dyn TokenSource, &mut dyn TreeSink), F: FnOnce(&dyn TokenSource, &mut dyn TreeSink),
{ {
let mut src = SubtreeTokenSource::new(&self.subtree.token_trees[*self.cur_pos..]); let mut src = SubtreeTokenSource::new(&self.subtree.token_trees[*self.cur_pos..]);
let mut sink = OffsetTokenSink { token_pos: 0 }; let mut sink = OffsetTokenSink { token_pos: 0, error: false };
f(&src, &mut sink); f(&src, &mut sink);
self.finish(sink.token_pos, &mut src) let r = self.finish(sink.token_pos, &mut src);
if sink.error {
return None;
}
r
} }
fn finish(self, parsed_token: usize, src: &mut SubtreeTokenSource) -> Option<tt::TokenTree> { fn finish(self, parsed_token: usize, src: &mut SubtreeTokenSource) -> Option<tt::TokenTree> {

View file

@ -149,3 +149,15 @@ impl fmt::Display for Punct {
fmt::Display::fmt(&self.char, f) fmt::Display::fmt(&self.char, f)
} }
} }
impl Subtree {
    /// Count the total number of token trees in this subtree, recursively:
    /// every child counts as one token, and nested subtrees additionally
    /// contribute the counts of their own children.
    pub fn count(&self) -> usize {
        self.token_trees
            .iter()
            .map(|child| match child {
                TokenTree::Subtree(subtree) => 1 + subtree.count(),
                _ => 1,
            })
            .sum()
    }
}