diff --git a/src/compile.rs b/src/compile.rs index 3137b47..2917ccd 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -40,6 +40,31 @@ impl CodeInfo { code.cellvars.extend(cellvar_cache); code.freevars.extend(freevar_cache); + if !code.cellvars.is_empty() { + let total_args = code.arg_count + + code.kwonlyarg_count + + code.flags.contains(bytecode::CodeFlags::HAS_VARARGS) as usize + + code.flags.contains(bytecode::CodeFlags::HAS_VARKEYWORDS) as usize; + let all_args = &code.varnames[..total_args]; + let mut found_cellarg = false; + let cell2arg = code + .cellvars + .iter() + .map(|var| { + for (i, arg) in all_args.iter().enumerate() { + if var == arg { + found_cellarg = true; + return i as isize; + } + } + -1 + }) + .collect::>(); + if found_cellarg { + code.cell2arg = Some(cell2arg); + } + } + for instruction in &mut code.instructions { use Instruction::*; // this is a little bit hacky, as until now the data stored inside Labels in @@ -64,7 +89,7 @@ impl CodeInfo { } #[rustfmt::skip] - Import { .. } | ImportStar | ImportFrom { .. } | LoadFast(_) | LoadLocal(_) + Import { .. } | ImportStar | ImportFrom { .. } | LoadFast(_) | LoadNameAny(_) | LoadGlobal(_) | LoadDeref(_) | LoadClassDeref(_) | StoreFast(_) | StoreLocal(_) | StoreGlobal(_) | StoreDeref(_) | DeleteFast(_) | DeleteLocal(_) | DeleteGlobal(_) | DeleteDeref(_) | LoadClosure(_) | Subscript | StoreSubscript | DeleteSubscript @@ -114,14 +139,14 @@ impl Default for CompileOpts { } } -#[derive(Clone, Copy)] +#[derive(Debug, Clone, Copy)] struct CompileContext { in_loop: bool, in_class: bool, func: FunctionContext, } -#[derive(Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq)] enum FunctionContext { NoFunction, Function, @@ -130,7 +155,7 @@ enum FunctionContext { impl CompileContext { fn in_func(self) -> bool { - !matches!(self.func, FunctionContext::NoFunction) + self.func != FunctionContext::NoFunction } } @@ -408,9 +433,8 @@ impl Compiler { cache = &mut info.varname_cache; NameOpType::Fast } - SymbolScope::Local => NameOpType::Local, SymbolScope::GlobalImplicit if self.ctx.in_func() => NameOpType::Global, - SymbolScope::GlobalImplicit => NameOpType::Local, + SymbolScope::Local | SymbolScope::GlobalImplicit => NameOpType::Local, SymbolScope::GlobalExplicit => NameOpType::Global, SymbolScope::Free => { cache = &mut info.freevar_cache; @@ -423,9 +447,12 @@ impl Compiler { // TODO: is this right? SymbolScope::Unknown => NameOpType::Global, }; - let idx = cache + let mut idx = cache .get_index_of(name) .unwrap_or_else(|| cache.insert_full(name.to_owned()).0); + if let SymbolScope::Free = symbol.scope { + idx += info.cellvar_cache.len(); + } let op = match op_typ { NameOpType::Fast => match usage { NameUsage::Load => Instruction::LoadFast, @@ -438,12 +465,15 @@ impl Compiler { NameUsage::Delete => Instruction::DeleteGlobal, }, NameOpType::Deref => match usage { + NameUsage::Load if !self.ctx.in_func() && self.ctx.in_class => { + Instruction::LoadClassDeref + } NameUsage::Load => Instruction::LoadDeref, NameUsage::Store => Instruction::StoreDeref, NameUsage::Delete => Instruction::DeleteDeref, }, NameOpType::Local => match usage { - NameUsage::Load => Instruction::LoadLocal, + NameUsage::Load => Instruction::LoadNameAny, NameUsage::Store => Instruction::StoreLocal, NameUsage::Delete => Instruction::DeleteLocal, }, @@ -844,17 +874,17 @@ impl Compiler { )); for name in &args.args { - self.name(&name.arg); + self.varname(&name.arg); } for name in &args.kwonlyargs { - self.name(&name.arg); + self.varname(&name.arg); } let mut compile_varargs = |va: &ast::Varargs, flag| match va { ast::Varargs::None | ast::Varargs::Unnamed => {} ast::Varargs::Named(name) => { self.current_code().flags |= flag; - self.name(&name.arg); + self.varname(&name.arg); } }; @@ -1002,6 +1032,10 @@ impl Compiler { is_async: bool, ) -> CompileResult<()> { // Create bytecode for this function: + + self.prepare_decorators(decorator_list)?; + self.enter_function(name, args)?; + // remember to restore self.ctx.in_loop to the original after the function is compiled let prev_ctx = self.ctx; @@ -1019,10 +1053,6 @@ impl Compiler { let old_qualified_path = self.current_qualified_path.take(); self.current_qualified_path = Some(self.create_qualified_name(name, ".")); - self.prepare_decorators(decorator_list)?; - - self.enter_function(name, args)?; - let (body, doc_str) = get_doc(body); self.compile_statements(body)?; @@ -1090,6 +1120,35 @@ impl Compiler { code.flags |= bytecode::CodeFlags::IS_COROUTINE; } + self.build_closure(&code); + + self.emit_constant(bytecode::ConstantData::Code { + code: Box::new(code), + }); + self.emit_constant(bytecode::ConstantData::Str { + value: qualified_name, + }); + + // Turn code object into function object: + self.emit(Instruction::MakeFunction); + + self.emit(Instruction::Duplicate); + self.load_docstring(doc_str); + self.emit(Instruction::Rotate { amount: 2 }); + let doc = self.name("__doc__"); + self.emit(Instruction::StoreAttr { idx: doc }); + + self.current_qualified_path = old_qualified_path; + self.ctx = prev_ctx; + + self.apply_decorators(decorator_list); + + self.store_name(name); + + Ok(()) + } + + fn build_closure(&mut self, code: &CodeObject) { if !code.freevars.is_empty() { for var in &code.freevars { let symbol = self.symbol_table_stack.last().unwrap().lookup(var).unwrap(); @@ -1110,29 +1169,6 @@ impl Compiler { unpack: false, }) } - - self.emit_constant(bytecode::ConstantData::Code { - code: Box::new(code), - }); - self.emit_constant(bytecode::ConstantData::Str { - value: qualified_name, - }); - - // Turn code object into function object: - self.emit(Instruction::MakeFunction); - - self.emit(Instruction::Duplicate); - self.load_docstring(doc_str); - self.emit(Instruction::Rotate { amount: 2 }); - let doc = self.name("__doc__"); - self.emit(Instruction::StoreAttr { idx: doc }); - self.apply_decorators(decorator_list); - - self.store_name(name); - - self.current_qualified_path = old_qualified_path; - self.ctx = prev_ctx; - Ok(()) } fn find_ann(&self, body: &[ast::Statement]) -> bool { @@ -1243,11 +1279,30 @@ impl Compiler { self.emit(Instruction::SetupAnnotation); } self.compile_statements(new_body)?; - self.emit_constant(bytecode::ConstantData::None); + + let classcell_idx = self + .code_stack + .last_mut() + .unwrap() + .cellvar_cache + .iter() + .position(|var| *var == "__class__"); + + if let Some(classcell_idx) = classcell_idx { + self.emit(Instruction::LoadClosure(classcell_idx)); + self.emit(Instruction::Duplicate); + let classcell = self.name("__classcell__"); + self.emit(Instruction::StoreLocal(classcell)); + } else { + self.emit_constant(bytecode::ConstantData::None); + } + self.emit(Instruction::ReturnValue); let code = self.pop_code_object(); + self.build_closure(&code); + self.emit_constant(bytecode::ConstantData::Code { code: Box::new(code), }); @@ -1526,7 +1581,7 @@ impl Compiler { // Store as dict entry in __annotations__ dict: if !self.ctx.in_func() { let annotations = self.name("__annotations__"); - self.emit(Instruction::LoadLocal(annotations)); + self.emit(Instruction::LoadNameAny(annotations)); self.emit_constant(bytecode::ConstantData::Str { value: name.to_owned(), }); @@ -1942,6 +1997,7 @@ impl Compiler { self.compile_expression(body)?; self.emit(Instruction::ReturnValue); let code = self.pop_code_object(); + self.build_closure(&code); self.emit_constant(bytecode::ConstantData::Code { code: Box::new(code), }); @@ -2109,6 +2165,14 @@ impl Compiler { kind: &ast::ComprehensionKind, generators: &[ast::Comprehension], ) -> CompileResult<()> { + let prev_ctx = self.ctx; + + self.ctx = CompileContext { + in_loop: false, + in_class: prev_ctx.in_class, + func: FunctionContext::Function, + }; + // We must have at least one generator: assert!(!generators.is_empty()); @@ -2252,6 +2316,10 @@ impl Compiler { // Fetch code for listcomp function: let code = self.pop_code_object(); + self.ctx = prev_ctx; + + self.build_closure(&code); + // List comprehension code: self.emit_constant(bytecode::ConstantData::Code { code: Box::new(code), diff --git a/src/symboltable.rs b/src/symboltable.rs index 02f15ae..ce70620 100644 --- a/src/symboltable.rs +++ b/src/symboltable.rs @@ -144,6 +144,10 @@ impl Symbol { pub fn is_local(&self) -> bool { matches!(self.scope, SymbolScope::Local) } + + pub fn is_bound(&self) -> bool { + self.is_assigned || self.is_parameter || self.is_imported || self.is_iter + } } #[derive(Debug)] @@ -189,29 +193,89 @@ fn analyze_symbol_table(symbol_table: &mut SymbolTable) -> SymbolTableResult { analyzer.analyze_symbol_table(symbol_table) } +type SymbolMap = IndexMap; + +mod stack { + use std::panic; + use std::ptr::NonNull; + pub struct StackStack { + v: Vec>, + } + impl Default for StackStack { + fn default() -> Self { + Self { v: Vec::new() } + } + } + impl StackStack { + pub fn append(&mut self, x: &mut T, f: F) -> R + where + F: FnOnce(&mut Self) -> R, + { + self.v.push(x.into()); + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| f(self))); + self.v.pop(); + res.unwrap_or_else(|x| panic::resume_unwind(x)) + } + + pub fn iter(&self) -> impl Iterator + DoubleEndedIterator + '_ { + self.as_ref().iter().copied() + } + pub fn iter_mut(&mut self) -> impl Iterator + DoubleEndedIterator + '_ { + self.as_mut().iter_mut().map(|x| &mut **x) + } + // pub fn top(&self) -> Option<&T> { + // self.as_ref().last().copied() + // } + // pub fn top_mut(&mut self) -> Option<&mut T> { + // self.as_mut().last_mut().map(|x| &mut **x) + // } + pub fn len(&self) -> usize { + self.v.len() + } + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn as_ref(&self) -> &[&T] { + unsafe { &*(self.v.as_slice() as *const [NonNull] as *const [&T]) } + } + + pub fn as_mut(&mut self) -> &mut [&mut T] { + unsafe { &mut *(self.v.as_mut_slice() as *mut [NonNull] as *mut [&mut T]) } + } + } +} +use stack::StackStack; + /// Symbol table analysis. Can be used to analyze a fully /// build symbol table structure. It will mark variables /// as local variables for example. #[derive(Default)] -struct SymbolTableAnalyzer<'a> { - tables: Vec<(&'a mut IndexMap, SymbolTableType)>, +#[repr(transparent)] +struct SymbolTableAnalyzer { + tables: StackStack<(SymbolMap, SymbolTableType)>, } -impl<'a> SymbolTableAnalyzer<'a> { - fn analyze_symbol_table(&mut self, symbol_table: &'a mut SymbolTable) -> SymbolTableResult { - let symbols = &mut symbol_table.symbols; - let sub_tables = &mut symbol_table.sub_tables; +impl SymbolTableAnalyzer { + fn analyze_symbol_table(&mut self, symbol_table: &mut SymbolTable) -> SymbolTableResult { + let symbols = std::mem::take(&mut symbol_table.symbols); + let sub_tables = &mut *symbol_table.sub_tables; - self.tables.push((symbols, symbol_table.typ)); - // Analyze sub scopes: - for sub_table in sub_tables { - self.analyze_symbol_table(sub_table)?; - } - let (symbols, st_typ) = self.tables.pop().unwrap(); + let mut info = (symbols, symbol_table.typ); + self.tables.append(&mut info, |list| { + let inner_scope = unsafe { &mut *(list as *mut _ as *mut SymbolTableAnalyzer) }; + // Analyze sub scopes: + for sub_table in sub_tables.iter_mut() { + inner_scope.analyze_symbol_table(sub_table)?; + } + Ok(()) + })?; + + symbol_table.symbols = info.0; // Analyze symbols: - for symbol in symbols.values_mut() { - self.analyze_symbol(symbol, st_typ)?; + for symbol in symbol_table.symbols.values_mut() { + self.analyze_symbol(symbol, symbol_table.typ, sub_tables)?; } Ok(()) } @@ -220,6 +284,7 @@ impl<'a> SymbolTableAnalyzer<'a> { &mut self, symbol: &mut Symbol, curr_st_typ: SymbolTableType, + sub_tables: &mut [SymbolTable], ) -> SymbolTableResult { if symbol.is_assign_namedexpr_in_comprehension && curr_st_typ == SymbolTableType::Comprehension @@ -232,11 +297,11 @@ impl<'a> SymbolTableAnalyzer<'a> { } else { match symbol.scope { SymbolScope::Free => { - let scope_depth = self.tables.len(); - if scope_depth > 0 { + if !self.tables.as_ref().is_empty() { + let scope_depth = self.tables.as_ref().iter().count(); // check if the name is already defined in any outer scope // therefore - if scope_depth < 2 || !self.found_in_outer_scope(symbol) { + if scope_depth < 2 || !self.found_in_outer_scope(&symbol.name) { return Err(SymbolTableError { error: format!("no binding for nonlocal '{}' found", symbol.name), // TODO: accurate location info, somehow @@ -262,29 +327,56 @@ impl<'a> SymbolTableAnalyzer<'a> { } SymbolScope::Unknown => { // Try hard to figure out what the scope of this symbol is. - self.analyze_unknown_symbol(symbol); + self.analyze_unknown_symbol(sub_tables, symbol); } } } Ok(()) } - fn found_in_outer_scope(&self, symbol: &Symbol) -> bool { + fn found_in_outer_scope(&mut self, name: &str) -> bool { // Interesting stuff about the __class__ variable: // https://docs.python.org/3/reference/datamodel.html?highlight=__class__#creating-the-class-object - symbol.name == "__class__" - || self.tables.iter().skip(1).rev().any(|(symbols, typ)| { - *typ != SymbolTableType::Class - && symbols - .get(&symbol.name) - .map_or(false, |sym| sym.is_local() && sym.is_assigned) - }) + if name == "__class__" { + return true; + } + let decl_depth = self.tables.iter().rev().position(|(symbols, typ)| { + !matches!(typ, SymbolTableType::Class | SymbolTableType::Module) + && symbols.get(name).map_or(false, |sym| sym.is_bound()) + }); + + if let Some(decl_depth) = decl_depth { + // decl_depth is the number of tables between the current one and + // the one that declared the cell var + for (table, _) in self.tables.iter_mut().rev().take(decl_depth) { + if !table.contains_key(name) { + let mut symbol = Symbol::new(name); + symbol.scope = SymbolScope::Free; + symbol.is_referenced = true; + table.insert(name.to_owned(), symbol); + } + } + } + + decl_depth.is_some() } - fn analyze_unknown_symbol(&self, symbol: &mut Symbol) { - let scope = if symbol.is_assigned || symbol.is_parameter { - SymbolScope::Local - } else if self.found_in_outer_scope(symbol) { + fn found_in_inner_scope(sub_tables: &mut [SymbolTable], name: &str) -> bool { + sub_tables.iter().any(|x| { + x.symbols + .get(name) + .map_or(false, |sym| matches!(sym.scope, SymbolScope::Free)) + }) + } + + fn analyze_unknown_symbol(&mut self, sub_tables: &mut [SymbolTable], symbol: &mut Symbol) { + let scope = if symbol.is_bound() { + if Self::found_in_inner_scope(sub_tables, &symbol.name) { + SymbolScope::Cell + } else { + SymbolScope::Local + } + } else if self.found_in_outer_scope(&symbol.name) { // Symbol is in some outer scope. SymbolScope::Free } else if self.tables.is_empty() { @@ -305,11 +397,9 @@ impl<'a> SymbolTableAnalyzer<'a> { symbol: &mut Symbol, parent_offset: usize, ) -> SymbolTableResult { - // TODO: quite C-ish way to implement the iteration // when this is called, we expect to be in the direct parent scope of the scope that contains 'symbol' - let offs = self.tables.len() - 1 - parent_offset; - let last = self.tables.get_mut(offs).unwrap(); - let symbols = &mut *last.0; + let last = self.tables.iter_mut().rev().nth(parent_offset).unwrap(); + let symbols = &mut last.0; let table_type = last.1; // it is not allowed to use an iterator variable as assignee in a named expression @@ -532,6 +622,7 @@ impl SymbolTableBuilder { self.register_name("__module__", SymbolUsage::Assigned, location)?; self.register_name("__qualname__", SymbolUsage::Assigned, location)?; self.register_name("__doc__", SymbolUsage::Assigned, location)?; + self.register_name("__class__", SymbolUsage::Assigned, location)?; self.scan_statements(body)?; self.leave_scope(); self.scan_expressions(bases, ExpressionContext::Load)?; @@ -954,7 +1045,7 @@ impl SymbolTableBuilder { let table = self.tables.last_mut().unwrap(); // Some checks for the symbol that present on this scope level: - if let Some(symbol) = table.symbols.get(name) { + let symbol = if let Some(symbol) = table.symbols.get_mut(name) { // Role already set.. match role { SymbolUsage::Global => { @@ -1000,6 +1091,7 @@ impl SymbolTableBuilder { // Ok? } } + symbol } else { // The symbol does not present on this scope level. // Some checks to insert new symbol into symbol table: @@ -1016,11 +1108,10 @@ impl SymbolTableBuilder { } // Insert symbol when required: let symbol = Symbol::new(name); - table.symbols.insert(name.to_owned(), symbol); - } + table.symbols.entry(name.to_owned()).or_insert(symbol) + }; // Set proper flags on symbol: - let symbol = table.symbols.get_mut(name).unwrap(); match role { SymbolUsage::Nonlocal => { symbol.scope = SymbolScope::Free; @@ -1064,7 +1155,6 @@ impl SymbolTableBuilder { } SymbolUsage::Iter => { symbol.is_iter = true; - symbol.scope = SymbolScope::Local; } }