diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index c8eee4489d..c1eec8de7b 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -6,6 +6,7 @@ use std::num::NonZeroU32; use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use bitflags::bitflags; use hashbrown::hash_map::{Keys, RawEntryMut}; use rustc_hash::{FxHashMap, FxHasher}; @@ -81,15 +82,52 @@ impl Scope { } } +#[derive(Debug)] +pub(crate) enum Kind { + FreeVar, + CellVar, + CellVarAssigned, + ExplicitGlobal, + ImplicitGlobal, +} + +bitflags! { + #[derive(Copy,Clone,Debug)] + pub(crate) struct SymbolFlags: u8 { + const IS_USED = 1 << 0; + const IS_DEFINED = 1 << 1; + /// TODO: This flag is not yet set by anything + const MARKED_GLOBAL = 1 << 2; + /// TODO: This flag is not yet set by anything + const MARKED_NONLOCAL = 1 << 3; + } +} + #[derive(Debug)] pub(crate) struct Symbol { name: Name, + flags: SymbolFlags, + // kind: Kind, } impl Symbol { pub(crate) fn name(&self) -> &str { self.name.as_str() } + + /// Is the symbol used in its containing scope? + pub(crate) fn is_used(&self) -> bool { + self.flags.contains(SymbolFlags::IS_USED) + } + + /// Is the symbol defined in its containing scope? + pub(crate) fn is_defined(&self) -> bool { + self.flags.contains(SymbolFlags::IS_DEFINED) + } + + // TODO: implement Symbol.kind 2-pass analysis to categorize as: free-var, cell-var, + // explicit-global, implicit-global and implement Symbol.kind by modifying the preorder + // traversal code } // TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST; @@ -271,7 +309,12 @@ impl SymbolTable { .flat_map(|(sym_id, defs)| defs.iter().map(move |def| (*sym_id, def))) } - fn add_symbol_to_scope(&mut self, scope_id: ScopeId, name: &str) -> SymbolId { + fn add_or_update_symbol( + &mut self, + scope_id: ScopeId, + name: &str, + flags: SymbolFlags, + ) -> SymbolId { let hash = SymbolTable::hash_name(name); let scope = &mut self.scopes_by_id[scope_id]; let name = Name::new(name); @@ -282,9 +325,14 @@ impl SymbolTable { .from_hash(hash, |existing| self.symbols_by_id[*existing].name == name); match entry { - RawEntryMut::Occupied(entry) => *entry.key(), + RawEntryMut::Occupied(entry) => { + if let Some(symbol) = self.symbols_by_id.get_mut(*entry.key()) { + symbol.flags.insert(flags); + }; + *entry.key() + } RawEntryMut::Vacant(entry) => { - let id = self.symbols_by_id.push(Symbol { name }); + let id = self.symbols_by_id.push(Symbol { name, flags }); entry.insert_with_hasher(hash, id, (), |_| hash); id } @@ -392,12 +440,17 @@ struct SymbolTableBuilder { } impl SymbolTableBuilder { - fn add_symbol(&mut self, identifier: &str) -> SymbolId { - self.table.add_symbol_to_scope(self.cur_scope(), identifier) + fn add_or_update_symbol(&mut self, identifier: &str, flags: SymbolFlags) -> SymbolId { + self.table + .add_or_update_symbol(self.cur_scope(), identifier, flags) } - fn add_symbol_with_def(&mut self, identifier: &str, definition: Definition) -> SymbolId { - let symbol_id = self.add_symbol(identifier); + fn add_or_update_symbol_with_def( + &mut self, + identifier: &str, + definition: Definition, + ) -> SymbolId { + let symbol_id = self.add_or_update_symbol(identifier, SymbolFlags::IS_DEFINED); self.table .defs .entry(symbol_id) @@ -439,7 +492,7 @@ impl SymbolTableBuilder { ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name, ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name, }; - self.add_symbol(name); + self.add_or_update_symbol(name, SymbolFlags::IS_DEFINED); } } nested(self); @@ -452,10 +505,16 @@ impl SymbolTableBuilder { impl PreorderVisitor<'_> for SymbolTableBuilder { fn visit_expr(&mut self, expr: &ast::Expr) { if let ast::Expr::Name(ast::ExprName { id, ctx, .. }) = expr { - self.add_symbol(id); - if matches!(ctx, ast::ExprContext::Store | ast::ExprContext::Del) { + let flags = match ctx { + ast::ExprContext::Load => SymbolFlags::IS_USED, + ast::ExprContext::Store => SymbolFlags::IS_DEFINED, + ast::ExprContext::Del => SymbolFlags::IS_DEFINED, + ast::ExprContext::Invalid => SymbolFlags::empty(), + }; + self.add_or_update_symbol(id, flags); + if flags.contains(SymbolFlags::IS_DEFINED) { if let Some(curdef) = self.current_definition.clone() { - self.add_symbol_with_def(id, curdef); + self.add_or_update_symbol_with_def(id, curdef); } } } @@ -467,7 +526,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { match stmt { ast::Stmt::ClassDef(node) => { let def = Definition::ClassDef(TypedNodeKey::from_node(node)); - self.add_symbol_with_def(&node.name, def); + self.add_or_update_symbol_with_def(&node.name, def); self.with_type_params(&node.name, &node.type_params, |builder| { builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Class); ast::visitor::preorder::walk_stmt(builder, stmt); @@ -476,7 +535,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { } ast::Stmt::FunctionDef(node) => { let def = Definition::FunctionDef(TypedNodeKey::from_node(node)); - self.add_symbol_with_def(&node.name, def); + self.add_or_update_symbol_with_def(&node.name, def); self.with_type_params(&node.name, &node.type_params, |builder| { builder.push_scope(builder.cur_scope(), &node.name, ScopeKind::Function); ast::visitor::preorder::walk_stmt(builder, stmt); @@ -496,7 +555,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { let def = Definition::Import(ImportDefinition { module: module.clone(), }); - self.add_symbol_with_def(symbol_name, def); + self.add_or_update_symbol_with_def(symbol_name, def); self.table.dependencies.push(Dependency::Module(module)); } } @@ -519,7 +578,7 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { name: Name::new(&alias.name.id), level: *level, }); - self.add_symbol_with_def(symbol_name, def); + self.add_or_update_symbol_with_def(symbol_name, def); } let dependency = if let Some(module) = module { @@ -578,7 +637,7 @@ mod tests { use crate::parse::Parsed; use crate::symbols::ScopeKind; - use super::{SymbolId, SymbolIterator, SymbolTable}; + use super::{SymbolFlags, SymbolId, SymbolIterator, SymbolTable}; mod from_ast { use super::*; @@ -662,6 +721,13 @@ mod tests { .len(), 1 ); + assert!( + table.root_symbol_id_by_name("foo").is_some_and(|sid| { + let s = sid.symbol(&table); + s.is_defined() || !s.is_used() + }), + "symbols that are defined get the defined flag" + ); } #[test] @@ -675,6 +741,13 @@ mod tests { .len(), 1 ); + assert!( + table.root_symbol_id_by_name("foo").is_some_and(|sid| { + let s = sid.symbol(&table); + !s.is_defined() && s.is_used() + }), + "a symbol used but not defined in a scope should have only the used flag" + ); } #[test] @@ -800,6 +873,12 @@ mod tests { assert_eq!(ann_scope.kind(), ScopeKind::Annotation); assert_eq!(ann_scope.name(), "C"); assert_eq!(names(table.symbols_for_scope(ann_scope_id)), vec!["T"]); + assert!( + table + .symbol_by_name(ann_scope_id, "T") + .is_some_and(|s| s.is_defined() && !s.is_used()), + "type parameters are defined by the scope that introduces them" + ); let scopes = table.child_scope_ids_of(ann_scope_id); assert_eq!(scopes.len(), 1); let func_scope_id = scopes[0]; @@ -814,17 +893,19 @@ mod tests { fn insert_same_name_symbol_twice() { let mut table = SymbolTable::new(); let root_scope_id = SymbolTable::root_scope_id(); - let symbol_id_1 = table.add_symbol_to_scope(root_scope_id, "foo"); - let symbol_id_2 = table.add_symbol_to_scope(root_scope_id, "foo"); + let symbol_id_1 = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::IS_DEFINED); + let symbol_id_2 = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::IS_USED); assert_eq!(symbol_id_1, symbol_id_2); + assert!(symbol_id_1.symbol(&table).is_used(), "flags must merge"); + assert!(symbol_id_1.symbol(&table).is_defined(), "flags must merge"); } #[test] fn insert_different_named_symbols() { let mut table = SymbolTable::new(); let root_scope_id = SymbolTable::root_scope_id(); - let symbol_id_1 = table.add_symbol_to_scope(root_scope_id, "foo"); - let symbol_id_2 = table.add_symbol_to_scope(root_scope_id, "bar"); + let symbol_id_1 = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty()); + let symbol_id_2 = table.add_or_update_symbol(root_scope_id, "bar", SymbolFlags::empty()); assert_ne!(symbol_id_1, symbol_id_2); } @@ -832,9 +913,9 @@ mod tests { fn add_child_scope_with_symbol() { let mut table = SymbolTable::new(); let root_scope_id = SymbolTable::root_scope_id(); - let foo_symbol_top = table.add_symbol_to_scope(root_scope_id, "foo"); + let foo_symbol_top = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty()); let c_scope = table.add_child_scope(root_scope_id, "C", ScopeKind::Class); - let foo_symbol_inner = table.add_symbol_to_scope(c_scope, "foo"); + let foo_symbol_inner = table.add_or_update_symbol(c_scope, "foo", SymbolFlags::empty()); assert_ne!(foo_symbol_top, foo_symbol_inner); } @@ -851,7 +932,7 @@ mod tests { fn symbol_from_id() { let mut table = SymbolTable::new(); let root_scope_id = SymbolTable::root_scope_id(); - let foo_symbol_id = table.add_symbol_to_scope(root_scope_id, "foo"); + let foo_symbol_id = table.add_or_update_symbol(root_scope_id, "foo", SymbolFlags::empty()); let symbol = foo_symbol_id.symbol(&table); assert_eq!(symbol.name.as_str(), "foo"); }