From 9f7ef2050ec8b93b851d62982093a91d127f3882 Mon Sep 17 00:00:00 2001 From: Windel Bouwman Date: Sat, 20 Jul 2019 20:44:38 +0200 Subject: [PATCH] Add location to expressions. Change symboltable to use flags for symbols. --- src/compile.rs | 196 ++++++++++++++++---------------- src/symboltable.rs | 270 +++++++++++++++++++++++++-------------------- 2 files changed, 252 insertions(+), 214 deletions(-) diff --git a/src/compile.rs b/src/compile.rs index 8f3d695..b06262c 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -6,7 +6,7 @@ //! https://github.com/micropython/micropython/blob/master/py/compile.c use crate::error::{CompileError, CompileErrorType}; -use crate::symboltable::{make_symbol_table, statements_to_symbol_table, SymbolRole, SymbolScope}; +use crate::symboltable::{make_symbol_table, statements_to_symbol_table, Symbol, SymbolScope}; use num_complex::Complex64; use rustpython_bytecode::bytecode::{self, CallType, CodeObject, Instruction, Varargs}; use rustpython_parser::{ast, parser}; @@ -75,7 +75,7 @@ pub fn compile_program( /// Compile a single Python expression to bytecode pub fn compile_statement_eval( - statement: Vec, + statement: Vec, source_path: String, optimize: u8, ) -> Result { @@ -180,7 +180,7 @@ impl Compiler { for (i, statement) in program.statements.iter().enumerate() { let is_last = i == program.statements.len() - 1; - if let ast::Statement::Expression { ref expression } = statement.node { + if let ast::StatementType::Expression { ref expression } = statement.node { self.compile_expression(expression)?; if is_last { @@ -209,12 +209,12 @@ impl Compiler { // Compile statement in eval mode: fn compile_statement_eval( &mut self, - statements: &[ast::LocatedStatement], + statements: &[ast::Statement], symbol_table: SymbolScope, ) -> Result<(), CompileError> { self.scope_stack.push(symbol_table); for statement in statements { - if let ast::Statement::Expression { ref expression } = statement.node { + if let ast::StatementType::Expression { ref expression } = statement.node { self.compile_expression(expression)?; } else { return Err(CompileError { @@ -227,10 +227,7 @@ impl Compiler { Ok(()) } - fn compile_statements( - &mut self, - statements: &[ast::LocatedStatement], - ) -> Result<(), CompileError> { + fn compile_statements(&mut self, statements: &[ast::Statement]) -> Result<(), CompileError> { for statement in statements { self.compile_statement(statement)? } @@ -238,11 +235,13 @@ impl Compiler { } fn scope_for_name(&self, name: &str) -> bytecode::NameScope { - let role = self.lookup_name(name); - match role { - SymbolRole::Global => bytecode::NameScope::Global, - SymbolRole::Nonlocal => bytecode::NameScope::NonLocal, - _ => bytecode::NameScope::Local, + let symbol = self.lookup_name(name); + if symbol.is_global { + bytecode::NameScope::Global + } else if symbol.is_nonlocal { + bytecode::NameScope::NonLocal + } else { + bytecode::NameScope::Local } } @@ -262,12 +261,13 @@ impl Compiler { }); } - fn compile_statement(&mut self, statement: &ast::LocatedStatement) -> Result<(), CompileError> { + fn compile_statement(&mut self, statement: &ast::Statement) -> Result<(), CompileError> { trace!("Compiling {:?}", statement); self.set_source_location(&statement.location); + use ast::StatementType::*; match &statement.node { - ast::Statement::Import { names } => { + Import { names } => { // import a, b, c as d for name in names { self.emit(Instruction::Import { @@ -283,7 +283,7 @@ impl Compiler { } } } - ast::Statement::ImportFrom { + ImportFrom { level, module, names, @@ -326,16 +326,16 @@ impl Compiler { self.emit(Instruction::Pop); } } - ast::Statement::Expression { expression } => { + Expression { expression } => { self.compile_expression(expression)?; // Pop result of stack, since we not use it: self.emit(Instruction::Pop); } - ast::Statement::Global { .. } | ast::Statement::Nonlocal { .. } => { + Global { .. } | Nonlocal { .. } => { // Handled during symbol table construction. } - ast::Statement::If { test, body, orelse } => { + If { test, body, orelse } => { let end_label = self.new_label(); match orelse { None => { @@ -358,7 +358,7 @@ impl Compiler { } self.set_label(end_label); } - ast::Statement::While { test, body, orelse } => { + While { test, body, orelse } => { let start_label = self.new_label(); let else_label = self.new_label(); let end_label = self.new_label(); @@ -385,7 +385,7 @@ impl Compiler { } self.set_label(end_label); } - ast::Statement::With { items, body } => { + With { items, body } => { let end_label = self.new_label(); for item in items { self.compile_expression(&item.context_expr)?; @@ -406,16 +406,16 @@ impl Compiler { } self.set_label(end_label); } - ast::Statement::For { + For { target, iter, body, orelse, } => self.compile_for(target, iter, body, orelse)?, - ast::Statement::AsyncFor { .. } => { + AsyncFor { .. } => { unimplemented!("async for"); } - ast::Statement::Raise { exception, cause } => match exception { + Raise { exception, cause } => match exception { Some(value) => { self.compile_expression(value)?; match cause { @@ -432,30 +432,30 @@ impl Compiler { self.emit(Instruction::Raise { argc: 0 }); } }, - ast::Statement::Try { + Try { body, handlers, orelse, finalbody, } => self.compile_try_statement(body, handlers, orelse, finalbody)?, - ast::Statement::FunctionDef { + FunctionDef { name, args, body, decorator_list, returns, } => self.compile_function_def(name, args, body, decorator_list, returns)?, - ast::Statement::AsyncFunctionDef { .. } => { + AsyncFunctionDef { .. } => { unimplemented!("async def"); } - ast::Statement::ClassDef { + ClassDef { name, body, bases, keywords, decorator_list, } => self.compile_class_def(name, body, bases, keywords, decorator_list)?, - ast::Statement::Assert { test, msg } => { + Assert { test, msg } => { // if some flag, ignore all assert statements! if self.optimize == 0 { let end_label = self.new_label(); @@ -481,7 +481,7 @@ impl Compiler { self.set_label(end_label); } } - ast::Statement::Break => { + Break => { if !self.in_loop { return Err(CompileError { error: CompileErrorType::InvalidBreak, @@ -490,7 +490,7 @@ impl Compiler { } self.emit(Instruction::Break); } - ast::Statement::Continue => { + Continue => { if !self.in_loop { return Err(CompileError { error: CompileErrorType::InvalidContinue, @@ -499,7 +499,7 @@ impl Compiler { } self.emit(Instruction::Continue); } - ast::Statement::Return { value } => { + Return { value } => { if !self.in_function_def { return Err(CompileError { error: CompileErrorType::InvalidReturn, @@ -519,7 +519,7 @@ impl Compiler { self.emit(Instruction::ReturnValue); } - ast::Statement::Assign { targets, value } => { + Assign { targets, value } => { self.compile_expression(value)?; for (i, target) in targets.iter().enumerate() { @@ -529,7 +529,7 @@ impl Compiler { self.compile_store(target)?; } } - ast::Statement::AugAssign { target, op, value } => { + AugAssign { target, op, value } => { self.compile_expression(target)?; self.compile_expression(value)?; @@ -537,12 +537,12 @@ impl Compiler { self.compile_op(op, true); self.compile_store(target)?; } - ast::Statement::Delete { targets } => { + Delete { targets } => { for target in targets { self.compile_delete(target)?; } } - ast::Statement::Pass => { + Pass => { self.emit(Instruction::Pass); } } @@ -550,24 +550,24 @@ impl Compiler { } fn compile_delete(&mut self, expression: &ast::Expression) -> Result<(), CompileError> { - match expression { - ast::Expression::Identifier { name } => { + match &expression.node { + ast::ExpressionType::Identifier { name } => { self.emit(Instruction::DeleteName { name: name.to_string(), }); } - ast::Expression::Attribute { value, name } => { + ast::ExpressionType::Attribute { value, name } => { self.compile_expression(value)?; self.emit(Instruction::DeleteAttr { name: name.to_string(), }); } - ast::Expression::Subscript { a, b } => { + ast::ExpressionType::Subscript { a, b } => { self.compile_expression(a)?; self.compile_expression(b)?; self.emit(Instruction::DeleteSubscript); } - ast::Expression::Tuple { elements } => { + ast::ExpressionType::Tuple { elements } => { for element in elements { self.compile_delete(element)?; } @@ -663,10 +663,10 @@ impl Compiler { fn compile_try_statement( &mut self, - body: &[ast::LocatedStatement], + body: &[ast::Statement], handlers: &[ast::ExceptHandler], - orelse: &Option>, - finalbody: &Option>, + orelse: &Option>, + finalbody: &Option>, ) -> Result<(), CompileError> { let mut handler_label = self.new_label(); let finally_label = self.new_label(); @@ -764,7 +764,7 @@ impl Compiler { &mut self, name: &str, args: &ast::Parameters, - body: &[ast::LocatedStatement], + body: &[ast::Statement], decorator_list: &[ast::Expression], returns: &Option, // TODO: use type hint somehow.. ) -> Result<(), CompileError> { @@ -858,7 +858,7 @@ impl Compiler { fn compile_class_def( &mut self, name: &str, - body: &[ast::LocatedStatement], + body: &[ast::Statement], bases: &[ast::Expression], keywords: &[ast::Keyword], decorator_list: &[ast::Expression], @@ -989,8 +989,8 @@ impl Compiler { &mut self, target: &ast::Expression, iter: &ast::Expression, - body: &[ast::LocatedStatement], - orelse: &Option>, + body: &[ast::Statement], + orelse: &Option>, ) -> Result<(), CompileError> { // Start loop let start_label = self.new_label(); @@ -1104,27 +1104,27 @@ impl Compiler { } fn compile_store(&mut self, target: &ast::Expression) -> Result<(), CompileError> { - match target { - ast::Expression::Identifier { name } => { + match &target.node { + ast::ExpressionType::Identifier { name } => { self.store_name(name); } - ast::Expression::Subscript { a, b } => { + ast::ExpressionType::Subscript { a, b } => { self.compile_expression(a)?; self.compile_expression(b)?; self.emit(Instruction::StoreSubscript); } - ast::Expression::Attribute { value, name } => { + ast::ExpressionType::Attribute { value, name } => { self.compile_expression(value)?; self.emit(Instruction::StoreAttr { name: name.to_string(), }); } - ast::Expression::List { elements } | ast::Expression::Tuple { elements } => { + ast::ExpressionType::List { elements } | ast::ExpressionType::Tuple { elements } => { let mut seen_star = false; // Scan for star args: for (i, element) in elements.iter().enumerate() { - if let ast::Expression::Starred { .. } = element { + if let ast::ExpressionType::Starred { .. } = &element.node { if seen_star { return Err(CompileError { error: CompileErrorType::StarArgs, @@ -1147,7 +1147,7 @@ impl Compiler { } for element in elements { - if let ast::Expression::Starred { value } = element { + if let ast::ExpressionType::Starred { value } = &element.node { self.compile_store(value)?; } else { self.compile_store(element)?; @@ -1192,8 +1192,8 @@ impl Compiler { context: EvalContext, ) -> Result<(), CompileError> { // Compile expression for test, and jump to label if false - match expression { - ast::Expression::BoolOp { a, op, b } => match op { + match &expression.node { + ast::ExpressionType::BoolOp { a, op, b } => match op { ast::BooleanOperator::And => { let f = false_label.unwrap_or_else(|| self.new_label()); self.compile_test(a, None, Some(f), context)?; @@ -1246,23 +1246,27 @@ impl Compiler { fn compile_expression(&mut self, expression: &ast::Expression) -> Result<(), CompileError> { trace!("Compiling {:?}", expression); - match expression { - ast::Expression::Call { + use ast::ExpressionType::*; + match &expression.node { + Call { function, args, keywords, } => self.compile_call(function, args, keywords)?, - ast::Expression::BoolOp { .. } => { - self.compile_test(expression, None, None, EvalContext::Expression)? - } - ast::Expression::Binop { a, op, b } => { + BoolOp { .. } => self.compile_test( + expression, + Option::None, + Option::None, + EvalContext::Expression, + )?, + Binop { a, op, b } => { self.compile_expression(a)?; self.compile_expression(b)?; // Perform operation: self.compile_op(op, false); } - ast::Expression::Subscript { a, b } => { + Subscript { a, b } => { self.compile_expression(a)?; self.compile_expression(b)?; self.emit(Instruction::BinaryOperation { @@ -1270,7 +1274,7 @@ impl Compiler { inplace: false, }); } - ast::Expression::Unop { op, a } => { + Unop { op, a } => { self.compile_expression(a)?; // Perform operation: @@ -1283,16 +1287,16 @@ impl Compiler { let i = Instruction::UnaryOperation { op: i }; self.emit(i); } - ast::Expression::Attribute { value, name } => { + Attribute { value, name } => { self.compile_expression(value)?; self.emit(Instruction::LoadAttr { name: name.to_string(), }); } - ast::Expression::Compare { vals, ops } => { + Compare { vals, ops } => { self.compile_chained_comparison(vals, ops)?; } - ast::Expression::Number { value } => { + Number { value } => { let const_value = match value { ast::Number::Integer { value } => bytecode::Constant::Integer { value: value.clone(), @@ -1304,7 +1308,7 @@ impl Compiler { }; self.emit(Instruction::LoadConst { value: const_value }); } - ast::Expression::List { elements } => { + List { elements } => { let size = elements.len(); let must_unpack = self.gather_elements(elements)?; self.emit(Instruction::BuildList { @@ -1312,7 +1316,7 @@ impl Compiler { unpack: must_unpack, }); } - ast::Expression::Tuple { elements } => { + Tuple { elements } => { let size = elements.len(); let must_unpack = self.gather_elements(elements)?; self.emit(Instruction::BuildTuple { @@ -1320,7 +1324,7 @@ impl Compiler { unpack: must_unpack, }); } - ast::Expression::Set { elements } => { + Set { elements } => { let size = elements.len(); let must_unpack = self.gather_elements(elements)?; self.emit(Instruction::BuildSet { @@ -1328,7 +1332,7 @@ impl Compiler { unpack: must_unpack, }); } - ast::Expression::Dict { elements } => { + Dict { elements } => { let size = elements.len(); let has_double_star = elements.iter().any(|e| e.0.is_none()); for (key, value) in elements { @@ -1351,14 +1355,14 @@ impl Compiler { unpack: has_double_star, }); } - ast::Expression::Slice { elements } => { + Slice { elements } => { let size = elements.len(); for element in elements { self.compile_expression(element)?; } self.emit(Instruction::BuildSlice { size }); } - ast::Expression::Yield { value } => { + Yield { value } => { if !self.in_function_def { return Err(CompileError { error: CompileErrorType::InvalidYield, @@ -1368,16 +1372,16 @@ impl Compiler { self.mark_generator(); match value { Some(expression) => self.compile_expression(expression)?, - None => self.emit(Instruction::LoadConst { + Option::None => self.emit(Instruction::LoadConst { value: bytecode::Constant::None, }), }; self.emit(Instruction::YieldValue); } - ast::Expression::Await { .. } => { + Await { .. } => { unimplemented!("await"); } - ast::Expression::YieldFrom { value } => { + YieldFrom { value } => { self.mark_generator(); self.compile_expression(value)?; self.emit(Instruction::GetIter); @@ -1386,40 +1390,40 @@ impl Compiler { }); self.emit(Instruction::YieldFrom); } - ast::Expression::True => { + True => { self.emit(Instruction::LoadConst { value: bytecode::Constant::Boolean { value: true }, }); } - ast::Expression::False => { + False => { self.emit(Instruction::LoadConst { value: bytecode::Constant::Boolean { value: false }, }); } - ast::Expression::None => { + None => { self.emit(Instruction::LoadConst { value: bytecode::Constant::None, }); } - ast::Expression::Ellipsis => { + Ellipsis => { self.emit(Instruction::LoadConst { value: bytecode::Constant::Ellipsis, }); } - ast::Expression::String { value } => { + String { value } => { self.compile_string(value)?; } - ast::Expression::Bytes { value } => { + Bytes { value } => { self.emit(Instruction::LoadConst { value: bytecode::Constant::Bytes { value: value.clone(), }, }); } - ast::Expression::Identifier { name } => { + Identifier { name } => { self.load_name(name); } - ast::Expression::Lambda { args, body } => { + Lambda { args, body } => { let name = "".to_string(); // no need to worry about the self.loop_depth because there are no loops in lambda expressions let flags = self.enter_function(&name, args)?; @@ -1438,18 +1442,18 @@ impl Compiler { // Turn code object into function object: self.emit(Instruction::MakeFunction { flags }); } - ast::Expression::Comprehension { kind, generators } => { + Comprehension { kind, generators } => { self.compile_comprehension(kind, generators)?; } - ast::Expression::Starred { value } => { + Starred { value } => { self.compile_expression(value)?; self.emit(Instruction::Unpack); panic!("We should not just unpack a starred args, since the size is unknown."); } - ast::Expression::IfExpression { test, body, orelse } => { + IfExpression { test, body, orelse } => { let no_label = self.new_label(); let end_label = self.new_label(); - self.compile_test(test, None, None, EvalContext::Expression)?; + self.compile_test(test, Option::None, Option::None, EvalContext::Expression)?; self.emit(Instruction::JumpIfFalse { target: no_label }); // True case self.compile_expression(body)?; @@ -1557,7 +1561,7 @@ impl Compiler { fn gather_elements(&mut self, elements: &[ast::Expression]) -> Result { // First determine if we have starred elements: let has_stars = elements.iter().any(|e| { - if let ast::Expression::Starred { .. } = e { + if let ast::ExpressionType::Starred { .. } = &e.node { true } else { false @@ -1565,7 +1569,7 @@ impl Compiler { }); for element in elements { - if let ast::Expression::Starred { value } = element { + if let ast::ExpressionType::Starred { value } = &element.node { self.compile_expression(value)?; } else { self.compile_expression(element)?; @@ -1792,7 +1796,7 @@ impl Compiler { assert!(scope.sub_scopes.is_empty()); } - fn lookup_name(&self, name: &str) -> &SymbolRole { + fn lookup_name(&self, name: &str) -> &Symbol { // println!("Looking up {:?}", name); let scope = self.scope_stack.last().unwrap(); scope.lookup(name).unwrap() @@ -1846,10 +1850,10 @@ impl Compiler { } } -fn get_doc(body: &[ast::LocatedStatement]) -> (&[ast::LocatedStatement], Option) { +fn get_doc(body: &[ast::Statement]) -> (&[ast::Statement], Option) { if let Some(val) = body.get(0) { - if let ast::Statement::Expression { ref expression } = val.node { - if let ast::Expression::String { ref value } = expression { + if let ast::StatementType::Expression { ref expression } = val.node { + if let ast::ExpressionType::String { value } = &expression.node { if let ast::StringGroup::Constant { ref value } = value { if let Some((_, body_rest)) = body.split_first() { return (body_rest, Some(value.to_string())); diff --git a/src/symboltable.rs b/src/symboltable.rs index 9d4d677..d80755d 100644 --- a/src/symboltable.rs +++ b/src/symboltable.rs @@ -24,7 +24,7 @@ pub fn make_symbol_table(program: &ast::Program) -> Result Result { let mut builder: SymbolTableBuilder = Default::default(); builder.enter_scope(); @@ -36,25 +36,46 @@ pub fn statements_to_symbol_table( Ok(symbol_table) } -#[derive(Debug, Clone)] -pub enum SymbolRole { - Global, - Nonlocal, - Used, - Assigned, -} - /// Captures all symbols in the current scope, and has a list of subscopes in this scope. -#[derive(Clone)] +#[derive(Clone, Default)] pub struct SymbolScope { /// A set of symbols present on this scope level. - pub symbols: IndexMap, + pub symbols: IndexMap, /// A list of subscopes in the order as found in the /// AST nodes. pub sub_scopes: Vec, } +#[derive(Debug, Clone)] +pub struct Symbol { + pub name: String, + pub is_global: bool, + pub is_local: bool, + pub is_nonlocal: bool, + pub is_param: bool, + pub is_referenced: bool, + pub is_assigned: bool, + pub is_parameter: bool, + pub is_free: bool, +} + +impl Symbol { + fn new(name: &str) -> Self { + Symbol { + name: name.to_string(), + is_global: false, + is_local: false, + is_nonlocal: false, + is_param: false, + is_referenced: false, + is_assigned: false, + is_parameter: false, + is_free: false, + } + } +} + #[derive(Debug)] pub struct SymbolTableError { error: String, @@ -73,20 +94,11 @@ impl From for CompileError { type SymbolTableResult = Result<(), SymbolTableError>; impl SymbolScope { - pub fn lookup(&self, name: &str) -> Option<&SymbolRole> { + pub fn lookup(&self, name: &str) -> Option<&Symbol> { self.symbols.get(name) } } -impl Default for SymbolScope { - fn default() -> Self { - SymbolScope { - symbols: Default::default(), - sub_scopes: Default::default(), - } - } -} - impl std::fmt::Debug for SymbolScope { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!( @@ -111,74 +123,73 @@ fn analyze_symbol_table( } // Analyze symbols: - for (symbol_name, symbol_role) in &symbol_scope.symbols { - analyze_symbol(symbol_name, symbol_role, parent_symbol_scope)?; + for symbol in symbol_scope.symbols.values() { + analyze_symbol(symbol, parent_symbol_scope)?; } Ok(()) } -#[allow(clippy::single_match)] -fn analyze_symbol( - symbol_name: &str, - symbol_role: &SymbolRole, - parent_symbol_scope: Option<&SymbolScope>, -) -> SymbolTableResult { - match symbol_role { - SymbolRole::Nonlocal => { - // check if name is defined in parent scope! - if let Some(parent_symbol_scope) = parent_symbol_scope { - if !parent_symbol_scope.symbols.contains_key(symbol_name) { - return Err(SymbolTableError { - error: format!("no binding for nonlocal '{}' found", symbol_name), - location: Default::default(), - }); - } - } else { +fn analyze_symbol(symbol: &Symbol, parent_symbol_scope: Option<&SymbolScope>) -> SymbolTableResult { + if symbol.is_nonlocal { + // check if name is defined in parent scope! + if let Some(parent_symbol_scope) = parent_symbol_scope { + if !parent_symbol_scope.symbols.contains_key(&symbol.name) { return Err(SymbolTableError { - error: format!( - "nonlocal {} defined at place without an enclosing scope", - symbol_name - ), + error: format!("no binding for nonlocal '{}' found", symbol.name), location: Default::default(), }); } + } else { + return Err(SymbolTableError { + error: format!( + "nonlocal {} defined at place without an enclosing scope", + symbol.name + ), + location: Default::default(), + }); } - // TODO: add more checks for globals - _ => {} } + + // TODO: add more checks for globals + Ok(()) } -pub struct SymbolTableBuilder { - // Scope stack. - pub scopes: Vec, +#[derive(Debug, Clone)] +enum SymbolRole { + Global, + Nonlocal, + Used, + Assigned, } -impl Default for SymbolTableBuilder { - fn default() -> Self { - SymbolTableBuilder { scopes: vec![] } - } +#[derive(Default)] +struct SymbolTableBuilder { + // Scope stack. + scopes: Vec, } impl SymbolTableBuilder { - pub fn enter_scope(&mut self) { + fn enter_scope(&mut self) { let scope = Default::default(); self.scopes.push(scope); + // self.work_scopes.push(Default::default()); } fn leave_scope(&mut self) { // Pop scope and add to subscopes of parent scope. + // let work_scope = self.work_scopes.pop().unwrap(); let scope = self.scopes.pop().unwrap(); self.scopes.last_mut().unwrap().sub_scopes.push(scope); } - pub fn scan_program(&mut self, program: &ast::Program) -> SymbolTableResult { + fn scan_program(&mut self, program: &ast::Program) -> SymbolTableResult { self.scan_statements(&program.statements)?; Ok(()) } - pub fn scan_statements(&mut self, statements: &[ast::LocatedStatement]) -> SymbolTableResult { + fn scan_statements(&mut self, statements: &[ast::Statement]) -> SymbolTableResult { for statement in statements { self.scan_statement(statement)?; } @@ -210,26 +221,27 @@ impl SymbolTableBuilder { Ok(()) } - fn scan_statement(&mut self, statement: &ast::LocatedStatement) -> SymbolTableResult { + fn scan_statement(&mut self, statement: &ast::Statement) -> SymbolTableResult { + use ast::StatementType::*; match &statement.node { - ast::Statement::Global { names } => { + Global { names } => { for name in names { self.register_name(name, SymbolRole::Global)?; } } - ast::Statement::Nonlocal { names } => { + Nonlocal { names } => { for name in names { self.register_name(name, SymbolRole::Nonlocal)?; } } - ast::Statement::FunctionDef { + FunctionDef { name, body, args, decorator_list, returns, } - | ast::Statement::AsyncFunctionDef { + | AsyncFunctionDef { name, body, args, @@ -247,7 +259,7 @@ impl SymbolTableBuilder { } self.leave_scope(); } - ast::Statement::ClassDef { + ClassDef { name, body, bases, @@ -264,21 +276,21 @@ impl SymbolTableBuilder { } self.scan_expressions(decorator_list)?; } - ast::Statement::Expression { expression } => self.scan_expression(expression)?, - ast::Statement::If { test, body, orelse } => { + Expression { expression } => self.scan_expression(expression)?, + If { test, body, orelse } => { self.scan_expression(test)?; self.scan_statements(body)?; if let Some(code) = orelse { self.scan_statements(code)?; } } - ast::Statement::For { + For { target, iter, body, orelse, } - | ast::Statement::AsyncFor { + | AsyncFor { target, iter, body, @@ -291,17 +303,17 @@ impl SymbolTableBuilder { self.scan_statements(code)?; } } - ast::Statement::While { test, body, orelse } => { + While { test, body, orelse } => { self.scan_expression(test)?; self.scan_statements(body)?; if let Some(code) = orelse { self.scan_statements(code)?; } } - ast::Statement::Break | ast::Statement::Continue | ast::Statement::Pass => { + Break | Continue | Pass => { // No symbols here. } - ast::Statement::Import { names } | ast::Statement::ImportFrom { names, .. } => { + Import { names } | ImportFrom { names, .. } => { for name in names { if let Some(alias) = &name.alias { // `import mymodule as myalias` @@ -312,29 +324,29 @@ impl SymbolTableBuilder { } } } - ast::Statement::Return { value } => { + Return { value } => { if let Some(expression) = value { self.scan_expression(expression)?; } } - ast::Statement::Assert { test, msg } => { + Assert { test, msg } => { self.scan_expression(test)?; if let Some(expression) = msg { self.scan_expression(expression)?; } } - ast::Statement::Delete { targets } => { + Delete { targets } => { self.scan_expressions(targets)?; } - ast::Statement::Assign { targets, value } => { + Assign { targets, value } => { self.scan_expressions(targets)?; self.scan_expression(value)?; } - ast::Statement::AugAssign { target, value, .. } => { + AugAssign { target, value, .. } => { self.scan_expression(target)?; self.scan_expression(value)?; } - ast::Statement::With { items, body } => { + With { items, body } => { for item in items { self.scan_expression(&item.context_expr)?; if let Some(expression) = &item.optional_vars { @@ -343,7 +355,7 @@ impl SymbolTableBuilder { } self.scan_statements(body)?; } - ast::Statement::Try { + Try { body, handlers, orelse, @@ -366,7 +378,7 @@ impl SymbolTableBuilder { self.scan_statements(code)?; } } - ast::Statement::Raise { exception, cause } => { + Raise { exception, cause } => { if let Some(expression) = exception { self.scan_expression(expression)?; } @@ -386,26 +398,27 @@ impl SymbolTableBuilder { } fn scan_expression(&mut self, expression: &ast::Expression) -> SymbolTableResult { - match expression { - ast::Expression::Binop { a, b, .. } => { + use ast::ExpressionType::*; + match &expression.node { + Binop { a, b, .. } => { self.scan_expression(a)?; self.scan_expression(b)?; } - ast::Expression::BoolOp { a, b, .. } => { + BoolOp { a, b, .. } => { self.scan_expression(a)?; self.scan_expression(b)?; } - ast::Expression::Compare { vals, .. } => { + Compare { vals, .. } => { self.scan_expressions(vals)?; } - ast::Expression::Subscript { a, b } => { + Subscript { a, b } => { self.scan_expression(a)?; self.scan_expression(b)?; } - ast::Expression::Attribute { value, .. } => { + Attribute { value, .. } => { self.scan_expression(value)?; } - ast::Expression::Dict { elements } => { + Dict { elements } => { for (key, value) in elements { if let Some(key) = key { self.scan_expression(key)?; @@ -415,36 +428,30 @@ impl SymbolTableBuilder { self.scan_expression(value)?; } } - ast::Expression::Await { value } => { + Await { value } => { self.scan_expression(value)?; } - ast::Expression::Yield { value } => { + Yield { value } => { if let Some(expression) = value { self.scan_expression(expression)?; } } - ast::Expression::YieldFrom { value } => { + YieldFrom { value } => { self.scan_expression(value)?; } - ast::Expression::Unop { a, .. } => { + Unop { a, .. } => { self.scan_expression(a)?; } - ast::Expression::True - | ast::Expression::False - | ast::Expression::None - | ast::Expression::Ellipsis => {} - ast::Expression::Number { .. } => {} - ast::Expression::Starred { value } => { + True | False | None | Ellipsis => {} + Number { .. } => {} + Starred { value } => { self.scan_expression(value)?; } - ast::Expression::Bytes { .. } => {} - ast::Expression::Tuple { elements } - | ast::Expression::Set { elements } - | ast::Expression::List { elements } - | ast::Expression::Slice { elements } => { + Bytes { .. } => {} + Tuple { elements } | Set { elements } | List { elements } | Slice { elements } => { self.scan_expressions(elements)?; } - ast::Expression::Comprehension { kind, generators } => { + Comprehension { kind, generators } => { match **kind { ast::ComprehensionKind::GeneratorExpression { ref element } | ast::ComprehensionKind::List { ref element } @@ -465,7 +472,7 @@ impl SymbolTableBuilder { } } } - ast::Expression::Call { + Call { function, args, keywords, @@ -476,18 +483,18 @@ impl SymbolTableBuilder { self.scan_expression(&keyword.value)?; } } - ast::Expression::String { value } => { + String { value } => { self.scan_string_group(value)?; } - ast::Expression::Identifier { name } => { + Identifier { name } => { self.register_name(name, SymbolRole::Used)?; } - ast::Expression::Lambda { args, body } => { + Lambda { args, body } => { self.enter_function(args)?; self.scan_expression(body)?; self.leave_scope(); } - ast::Expression::IfExpression { test, body, orelse } => { + IfExpression { test, body, orelse } => { self.scan_expression(test)?; self.scan_expression(body)?; self.scan_expression(orelse)?; @@ -549,6 +556,8 @@ impl SymbolTableBuilder { let scope_depth = self.scopes.len(); let current_scope = self.scopes.last_mut().unwrap(); let location = Default::default(); + + // Some checks: if current_scope.symbols.contains_key(name) { // Role already set.. match role { @@ -568,22 +577,47 @@ impl SymbolTableBuilder { // Ok? } } - } else { - match role { - SymbolRole::Nonlocal => { - if scope_depth < 2 { - return Err(SymbolTableError { - error: format!("cannot define nonlocal '{}' at top level.", name), - location, - }); - } - } - _ => { - // Ok! + } + + // Some more checks: + match role { + SymbolRole::Nonlocal => { + if scope_depth < 2 { + return Err(SymbolTableError { + error: format!("cannot define nonlocal '{}' at top level.", name), + location, + }); } } - current_scope.symbols.insert(name.to_string(), role); + _ => { + // Ok! + } } + + // Insert symbol when required: + if !current_scope.symbols.contains_key(name) { + let symbol = Symbol::new(name); + current_scope.symbols.insert(name.to_string(), symbol); + } + + // Set proper flags on symbol: + let symbol = current_scope.symbols.get_mut(name).unwrap(); + match role { + SymbolRole::Nonlocal => { + symbol.is_nonlocal = true; + } + SymbolRole::Assigned => { + symbol.is_assigned = true; + // symbol.is_local = true; + } + SymbolRole::Global => { + symbol.is_global = true; + } + SymbolRole::Used => { + symbol.is_referenced = true; + } + } + Ok(()) } }