diff --git a/Cargo.lock b/Cargo.lock index e7c8dcb705..7460c790b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1913,7 +1913,6 @@ dependencies = [ "ruff_text_size", "rustc-hash 2.0.0", "salsa", - "smallvec", "tracing", ] diff --git a/crates/red_knot_python_semantic/Cargo.toml b/crates/red_knot_python_semantic/Cargo.toml index cc273e4ecc..eb66270ff2 100644 --- a/crates/red_knot_python_semantic/Cargo.toml +++ b/crates/red_knot_python_semantic/Cargo.toml @@ -20,7 +20,6 @@ ruff_text_size = { workspace = true } bitflags = { workspace = true } indexmap = { workspace = true } salsa = { workspace = true } -smallvec = { workspace = true } tracing = { workspace = true } rustc-hash = { workspace = true } hashbrown = { workspace = true } diff --git a/crates/red_knot_python_semantic/src/db.rs b/crates/red_knot_python_semantic/src/db.rs index 11c7a88352..a40dcf7a3b 100644 --- a/crates/red_knot_python_semantic/src/db.rs +++ b/crates/red_knot_python_semantic/src/db.rs @@ -4,7 +4,8 @@ use ruff_db::{Db as SourceDb, Upcast}; use red_knot_module_resolver::Db as ResolverDb; -use crate::semantic_index::symbol::{public_symbols_map, scopes_map, PublicSymbolId, ScopeId}; +use crate::semantic_index::definition::Definition; +use crate::semantic_index::symbol::{public_symbols_map, PublicSymbolId, ScopeId}; use crate::semantic_index::{root_scope, semantic_index, symbol_table}; use crate::types::{infer_types, public_symbol_ty}; @@ -12,8 +13,8 @@ use crate::types::{infer_types, public_symbol_ty}; pub struct Jar( ScopeId<'_>, PublicSymbolId<'_>, + Definition<'_>, symbol_table, - scopes_map, root_scope, semantic_index, infer_types, diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index b85683889b..5e055bd9f7 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -6,13 +6,14 @@ use rustc_hash::FxHashMap; use ruff_db::parsed::parsed_module; use ruff_db::vfs::VfsFile; use ruff_index::{IndexSlice, IndexVec}; -use ruff_python_ast as ast; -use crate::node_key::NodeKey; -use crate::semantic_index::ast_ids::{AstId, AstIds, ScopedClassId, ScopedFunctionId}; +use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; +use crate::semantic_index::ast_ids::AstIds; use crate::semantic_index::builder::SemanticIndexBuilder; +use crate::semantic_index::definition::{Definition, DefinitionNodeKey, DefinitionNodeRef}; use crate::semantic_index::symbol::{ - FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable, + FileScopeId, NodeWithScopeKey, NodeWithScopeRef, PublicSymbolId, Scope, ScopeId, + ScopedSymbolId, SymbolTable, }; use crate::Db; @@ -27,12 +28,12 @@ type SymbolMap = hashbrown::HashMap; /// /// Prefer using [`symbol_table`] when working with symbols from a single scope. #[salsa::tracked(return_ref, no_eq)] -pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex { +pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex<'_> { let _span = tracing::trace_span!("semantic_index", ?file).entered(); let parsed = parsed_module(db.upcast(), file); - SemanticIndexBuilder::new(parsed).build() + SemanticIndexBuilder::new(db, file, parsed).build() } /// Returns the symbol table for a specific `scope`. @@ -41,7 +42,7 @@ pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex { /// Salsa can avoid invalidating dependent queries if this scope's symbol table /// is unchanged. #[salsa::tracked] -pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc { +pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc> { let _span = tracing::trace_span!("symbol_table", ?scope).entered(); let index = semantic_index(db, scope.file(db)); @@ -71,9 +72,9 @@ pub fn public_symbol<'db>( /// The symbol tables for an entire file. #[derive(Debug)] -pub struct SemanticIndex { +pub struct SemanticIndex<'db> { /// List of all symbol tables in this file, indexed by scope. - symbol_tables: IndexVec>, + symbol_tables: IndexVec>>, /// List of all scopes in this file. scopes: IndexVec, @@ -81,10 +82,16 @@ pub struct SemanticIndex { /// Maps expressions to their corresponding scope. /// We can't use [`ExpressionId`] here, because the challenge is how to get from /// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope). - scopes_by_expression: FxHashMap, + scopes_by_expression: FxHashMap, - /// Map from the definition that introduce a scope to the scope they define. - scopes_by_definition: FxHashMap, + /// Maps from a node creating a definition node to its definition. + definitions_by_node: FxHashMap>, + + /// Map from nodes that create a scope to the scope they create. + scopes_by_node: FxHashMap, + + /// Map from the file-local [`FileScopeId`] to the salsa-ingredient [`ScopeId`]. + scope_ids_by_scope: IndexVec>, /// Lookup table to map between node ids and ast nodes. /// @@ -93,12 +100,12 @@ pub struct SemanticIndex { ast_ids: IndexVec, } -impl SemanticIndex { +impl<'db> SemanticIndex<'db> { /// Returns the symbol table for a specific scope. /// /// Use the Salsa cached [`symbol_table`] query if you only need the /// symbol table for a single scope. - pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc { + pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc> { self.symbol_tables[scope_id].clone() } @@ -107,19 +114,16 @@ impl SemanticIndex { } /// Returns the ID of the `expression`'s enclosing scope. - pub(crate) fn expression_scope_id<'expr>( + pub(crate) fn expression_scope_id( &self, - expression: impl Into>, + expression: impl Into, ) -> FileScopeId { - self.scopes_by_expression[&NodeKey::from_node(expression.into())] + self.scopes_by_expression[&expression.into()] } /// Returns the [`Scope`] of the `expression`'s enclosing scope. #[allow(unused)] - pub(crate) fn expression_scope<'expr>( - &self, - expression: impl Into>, - ) -> &Scope { + pub(crate) fn expression_scope(&self, expression: impl Into) -> &Scope { &self.scopes[self.expression_scope_id(expression)] } @@ -157,45 +161,18 @@ impl SemanticIndex { AncestorsIter::new(self, scope) } - /// Returns the scope that is created by `node`. - pub(crate) fn node_scope(&self, node: impl Into) -> FileScopeId { - self.scopes_by_definition[&node.into()] + /// Returns the [`Definition`] salsa ingredient for `definition_node`. + pub(crate) fn definition<'def>( + &self, + definition_node: impl Into>, + ) -> Definition<'db> { + self.definitions_by_node[&definition_node.into().key()] } - /// Returns the scope in which `node_with_scope` is defined. - /// - /// The returned scope can be used to lookup the symbol of the definition or its type. - /// - /// * Annotation: Returns the direct parent scope - /// * Function and classes: Returns the parent scope unless they have type parameters in which case - /// the grandparent scope is returned. - pub(crate) fn definition_scope( - &self, - node_with_scope: impl Into, - ) -> FileScopeId { - fn resolve_scope(index: &SemanticIndex, node_with_scope: NodeWithScopeKey) -> FileScopeId { - let scope_id = index.node_scope(node_with_scope); - let scope = index.scope(scope_id); - - match scope.kind() { - ScopeKind::Module => scope_id, - ScopeKind::Annotation => scope.parent.unwrap(), - ScopeKind::Class | ScopeKind::Function => { - let mut ancestors = index.ancestor_scopes(scope_id).skip(1); - - let (mut scope_id, mut scope) = ancestors.next().unwrap(); - if scope.kind() == ScopeKind::Annotation { - (scope_id, scope) = ancestors.next().unwrap(); - } - - debug_assert_ne!(scope.kind(), ScopeKind::Annotation); - - scope_id - } - } - } - - resolve_scope(self, node_with_scope.into()) + /// Returns the id of the scope that `node` creates. This is different from [`Definition::scope`] which + /// returns the scope in which that definition is defined in. + pub(crate) fn node_scope(&self, node: NodeWithScopeRef) -> FileScopeId { + self.scopes_by_node[&node.node_key()] } } @@ -293,42 +270,6 @@ impl<'a> Iterator for ChildrenIter<'a> { impl FusedIterator for ChildrenIter<'_> {} -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub(crate) enum NodeWithScopeId { - Module, - Class(AstId), - ClassTypeParams(AstId), - Function(AstId), - FunctionTypeParams(AstId), -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub(crate) struct NodeWithScopeKey(NodeKey); - -impl From<&ast::StmtClassDef> for NodeWithScopeKey { - fn from(node: &ast::StmtClassDef) -> Self { - Self(NodeKey::from_node(node)) - } -} - -impl From<&ast::StmtFunctionDef> for NodeWithScopeKey { - fn from(value: &ast::StmtFunctionDef) -> Self { - Self(NodeKey::from_node(value)) - } -} - -impl From<&ast::TypeParams> for NodeWithScopeKey { - fn from(value: &ast::TypeParams) -> Self { - Self(NodeKey::from_node(value)) - } -} - -impl From<&ast::ModModule> for NodeWithScopeKey { - fn from(value: &ast::ModModule) -> Self { - Self(NodeKey::from_node(value)) - } -} - #[cfg(test)] mod tests { use ruff_db::parsed::parsed_module; @@ -355,10 +296,10 @@ mod tests { TestCase { db, file } } - fn names(table: &SymbolTable) -> Vec<&str> { + fn names(table: &SymbolTable) -> Vec { table .symbols() - .map(|symbol| symbol.name().as_str()) + .map(|symbol| symbol.name().to_string()) .collect() } @@ -367,7 +308,9 @@ mod tests { let TestCase { db, file } = test_case(""); let root_table = symbol_table(&db, root_scope(&db, file)); - assert_eq!(names(&root_table), Vec::<&str>::new()); + let root_names = names(&root_table); + + assert_eq!(root_names, Vec::<&str>::new()); } #[test] @@ -474,7 +417,8 @@ y = 2 let (class_scope_id, class_scope) = scopes[0]; assert_eq!(class_scope.kind(), ScopeKind::Class); - assert_eq!(class_scope.name(&db, file), "C"); + + assert_eq!(class_scope_id.to_scope_id(&db, file).name(&db), "C"); let class_table = index.symbol_table(class_scope_id); assert_eq!(names(&class_table), vec!["x"]); @@ -503,7 +447,7 @@ y = 2 let (function_scope_id, function_scope) = scopes[0]; assert_eq!(function_scope.kind(), ScopeKind::Function); - assert_eq!(function_scope.name(&db, file), "func"); + assert_eq!(function_scope_id.to_scope_id(&db, file).name(&db), "func"); let function_table = index.symbol_table(function_scope_id); assert_eq!(names(&function_table), vec!["x"]); @@ -539,9 +483,9 @@ def func(): assert_eq!(func_scope_1.kind(), ScopeKind::Function); - assert_eq!(func_scope_1.name(&db, file), "func"); + assert_eq!(func_scope1_id.to_scope_id(&db, file).name(&db), "func"); assert_eq!(func_scope_2.kind(), ScopeKind::Function); - assert_eq!(func_scope_2.name(&db, file), "func"); + assert_eq!(func_scope2_id.to_scope_id(&db, file).name(&db), "func"); let func1_table = index.symbol_table(func_scope1_id); let func2_table = index.symbol_table(func_scope2_id); @@ -576,7 +520,7 @@ def func[T](): let (ann_scope_id, ann_scope) = scopes[0]; assert_eq!(ann_scope.kind(), ScopeKind::Annotation); - assert_eq!(ann_scope.name(&db, file), "func"); + assert_eq!(ann_scope_id.to_scope_id(&db, file).name(&db), "func"); let ann_table = index.symbol_table(ann_scope_id); assert_eq!(names(&ann_table), vec!["T"]); @@ -584,7 +528,7 @@ def func[T](): assert_eq!(scopes.len(), 1); let (func_scope_id, func_scope) = scopes[0]; assert_eq!(func_scope.kind(), ScopeKind::Function); - assert_eq!(func_scope.name(&db, file), "func"); + assert_eq!(func_scope_id.to_scope_id(&db, file).name(&db), "func"); let func_table = index.symbol_table(func_scope_id); assert_eq!(names(&func_table), vec!["x"]); } @@ -608,7 +552,7 @@ class C[T]: assert_eq!(scopes.len(), 1); let (ann_scope_id, ann_scope) = scopes[0]; assert_eq!(ann_scope.kind(), ScopeKind::Annotation); - assert_eq!(ann_scope.name(&db, file), "C"); + assert_eq!(ann_scope_id.to_scope_id(&db, file).name(&db), "C"); let ann_table = index.symbol_table(ann_scope_id); assert_eq!(names(&ann_table), vec!["T"]); assert!( @@ -620,11 +564,11 @@ class C[T]: let scopes: Vec<_> = index.child_scopes(ann_scope_id).collect(); assert_eq!(scopes.len(), 1); - let (func_scope_id, class_scope) = scopes[0]; + let (class_scope_id, class_scope) = scopes[0]; assert_eq!(class_scope.kind(), ScopeKind::Class); - assert_eq!(class_scope.name(&db, file), "C"); - assert_eq!(names(&index.symbol_table(func_scope_id)), vec!["x"]); + assert_eq!(class_scope_id.to_scope_id(&db, file).name(&db), "C"); + assert_eq!(names(&index.symbol_table(class_scope_id)), vec!["x"]); } // TODO: After porting the control flow graph. @@ -691,7 +635,7 @@ class C[T]: ) -> Vec<&'a str> { scopes .into_iter() - .map(|(_, scope)| scope.name(db, file)) + .map(|(scope_id, _)| scope_id.to_scope_id(db, file).name(db)) .collect() } diff --git a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs index 892d92fc40..86f17216b8 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs @@ -1,16 +1,12 @@ use rustc_hash::FxHashMap; -use ruff_db::parsed::ParsedModule; -use ruff_db::vfs::VfsFile; -use ruff_index::{newtype_index, IndexVec}; +use ruff_index::{newtype_index, Idx}; use ruff_python_ast as ast; -use ruff_python_ast::{AnyNodeRef, ExpressionRef}; +use ruff_python_ast::ExpressionRef; -use crate::ast_node_ref::AstNodeRef; -use crate::node_key::NodeKey; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::semantic_index; -use crate::semantic_index::symbol::{FileScopeId, ScopeId}; +use crate::semantic_index::symbol::ScopeId; use crate::Db; /// AST ids for a single scope. @@ -28,41 +24,18 @@ use crate::Db; /// /// x = foo() /// ``` +#[derive(Debug)] pub(crate) struct AstIds { - /// Maps expression ids to their expressions. - expressions: IndexVec>, - /// Maps expressions to their expression id. Uses `NodeKey` because it avoids cloning [`Parsed`]. expressions_map: FxHashMap, - - statements: IndexVec>, - - statements_map: FxHashMap, } impl AstIds { - fn statement_id<'a, N>(&self, node: N) -> ScopedStatementId - where - N: Into>, - { - self.statements_map[&NodeKey::from_node(node.into())] - } - fn expression_id(&self, key: impl Into) -> ScopedExpressionId { self.expressions_map[&key.into()] } } -#[allow(clippy::missing_fields_in_debug)] -impl std::fmt::Debug for AstIds { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("AstIds") - .field("expressions", &self.expressions) - .field("statements", &self.statements) - .finish() - } -} - fn ast_ids<'db>(db: &'db dyn Db, scope: ScopeId) -> &'db AstIds { semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db)) } @@ -75,79 +48,7 @@ pub trait HasScopedAstId { fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id; } -/// Node that can be uniquely identified by an id in a [`FileScopeId`]. -pub trait ScopedAstIdNode: HasScopedAstId { - /// Looks up the AST node by its ID. - /// - /// ## Panics - /// May panic if the `id` does not belong to the AST of `scope`. - fn lookup_in_scope<'db>(db: &'db dyn Db, scope: ScopeId<'db>, id: Self::Id) -> &'db Self - where - Self: Sized; -} - -/// Extension trait for AST nodes that can be resolved by an `AstId`. -pub trait AstIdNode { - type ScopeId: Copy; - - /// Resolves the AST id of the node. - /// - /// ## Panics - /// May panic if the node does not belong to `scope`. It may also - /// return an incorrect node if that's the case. - fn ast_id(&self, db: &dyn Db, scope: ScopeId) -> AstId; - - /// Resolves the AST node for `id`. - /// - /// ## Panics - /// May panic if the `id` does not belong to the AST of `file` or it returns an incorrect node. - - fn lookup(db: &dyn Db, file: VfsFile, id: AstId) -> &Self - where - Self: Sized; -} - -impl AstIdNode for T -where - T: ScopedAstIdNode, -{ - type ScopeId = T::Id; - - fn ast_id(&self, db: &dyn Db, scope: ScopeId) -> AstId { - let in_scope_id = self.scoped_ast_id(db, scope); - AstId { - scope: scope.file_scope_id(db), - in_scope_id, - } - } - - fn lookup(db: &dyn Db, file: VfsFile, id: AstId) -> &Self - where - Self: Sized, - { - let scope = id.scope.to_scope_id(db, file); - - Self::lookup_in_scope(db, scope, id.in_scope_id) - } -} - -/// Uniquely identifies an AST node in a file. -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub struct AstId { - /// The node's scope. - scope: FileScopeId, - - /// The ID of the node inside [`Self::scope`]. - in_scope_id: L, -} - -impl AstId { - pub(super) fn new(scope: FileScopeId, in_scope_id: L) -> Self { - Self { scope, in_scope_id } - } -} - -/// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`]. +/// Uniquely identifies an [`ast::Expr`] in a [`crate::semantic_index::symbol::FileScopeId`]. #[newtype_index] pub struct ScopedExpressionId; @@ -207,133 +108,29 @@ impl HasScopedAstId for ast::ExpressionRef<'_> { } } -impl ScopedAstIdNode for ast::Expr { - fn lookup_in_scope<'db>(db: &'db dyn Db, scope: ScopeId<'db>, id: Self::Id) -> &'db Self { - let ast_ids = ast_ids(db, scope); - ast_ids.expressions[id].node() - } -} - -/// Uniquely identifies an [`ast::Stmt`] in a [`FileScopeId`]. -#[newtype_index] -pub struct ScopedStatementId; - -macro_rules! impl_has_scoped_statement_id { - ($ty: ty) => { - impl HasScopedAstId for $ty { - type Id = ScopedStatementId; - - fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id { - let ast_ids = ast_ids(db, scope); - ast_ids.statement_id(self) - } - } - }; -} - -impl_has_scoped_statement_id!(ast::Stmt); - -impl ScopedAstIdNode for ast::Stmt { - fn lookup_in_scope<'db>(db: &'db dyn Db, scope: ScopeId<'db>, id: Self::Id) -> &'db Self { - let ast_ids = ast_ids(db, scope); - - ast_ids.statements[id].node() - } -} - -#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] -pub struct ScopedFunctionId(pub(super) ScopedStatementId); - -impl HasScopedAstId for ast::StmtFunctionDef { - type Id = ScopedFunctionId; - - fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id { - let ast_ids = ast_ids(db, scope); - ScopedFunctionId(ast_ids.statement_id(self)) - } -} - -impl ScopedAstIdNode for ast::StmtFunctionDef { - fn lookup_in_scope<'db>(db: &'db dyn Db, scope: ScopeId<'db>, id: Self::Id) -> &'db Self { - ast::Stmt::lookup_in_scope(db, scope, id.0) - .as_function_def_stmt() - .unwrap() - } -} - -#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] -pub struct ScopedClassId(pub(super) ScopedStatementId); - -impl HasScopedAstId for ast::StmtClassDef { - type Id = ScopedClassId; - - fn scoped_ast_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id { - let ast_ids = ast_ids(db, scope); - ScopedClassId(ast_ids.statement_id(self)) - } -} - -impl ScopedAstIdNode for ast::StmtClassDef { - fn lookup_in_scope<'db>(db: &'db dyn Db, scope: ScopeId<'db>, id: Self::Id) -> &'db Self { - let statement = ast::Stmt::lookup_in_scope(db, scope, id.0); - statement.as_class_def_stmt().unwrap() - } -} - -impl_has_scoped_statement_id!(ast::StmtAssign); -impl_has_scoped_statement_id!(ast::StmtAnnAssign); -impl_has_scoped_statement_id!(ast::StmtImport); -impl_has_scoped_statement_id!(ast::StmtImportFrom); - #[derive(Debug)] pub(super) struct AstIdsBuilder { - expressions: IndexVec>, + next_id: ScopedExpressionId, expressions_map: FxHashMap, - statements: IndexVec>, - statements_map: FxHashMap, } impl AstIdsBuilder { pub(super) fn new() -> Self { Self { - expressions: IndexVec::default(), + next_id: ScopedExpressionId::new(0), expressions_map: FxHashMap::default(), - statements: IndexVec::default(), - statements_map: FxHashMap::default(), } } - /// Adds `stmt` to the AST ids map and returns its id. - /// - /// ## Safety - /// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires - /// that `stmt` is a child of `parsed`. - #[allow(unsafe_code)] - pub(super) unsafe fn record_statement( - &mut self, - stmt: &ast::Stmt, - parsed: &ParsedModule, - ) -> ScopedStatementId { - let statement_id = self.statements.push(AstNodeRef::new(parsed.clone(), stmt)); - - self.statements_map - .insert(NodeKey::from_node(stmt), statement_id); - - statement_id - } - /// Adds `expr` to the AST ids map and returns its id. /// /// ## Safety /// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires /// that `expr` is a child of `parsed`. #[allow(unsafe_code)] - pub(super) unsafe fn record_expression( - &mut self, - expr: &ast::Expr, - parsed: &ParsedModule, - ) -> ScopedExpressionId { - let expression_id = self.expressions.push(AstNodeRef::new(parsed.clone(), expr)); + pub(super) fn record_expression(&mut self, expr: &ast::Expr) -> ScopedExpressionId { + let expression_id = self.next_id; + self.next_id = expression_id + 1; self.expressions_map.insert(expr.into(), expression_id); @@ -341,28 +138,22 @@ impl AstIdsBuilder { } pub(super) fn finish(mut self) -> AstIds { - self.expressions.shrink_to_fit(); self.expressions_map.shrink_to_fit(); - self.statements.shrink_to_fit(); - self.statements_map.shrink_to_fit(); AstIds { - expressions: self.expressions, expressions_map: self.expressions_map, - statements: self.statements, - statements_map: self.statements_map, } } } /// Node key that can only be constructed for expressions. -mod node_key { +pub(crate) mod node_key { use ruff_python_ast as ast; use crate::node_key::NodeKey; #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] - pub(super) struct ExpressionNodeKey(NodeKey); + pub(crate) struct ExpressionNodeKey(NodeKey); impl From> for ExpressionNodeKey { fn from(value: ast::ExpressionRef<'_>) -> Self { diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 750f928229..e4a2d60184 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -3,52 +3,64 @@ use std::sync::Arc; use rustc_hash::FxHashMap; use ruff_db::parsed::ParsedModule; +use ruff_db::vfs::VfsFile; use ruff_index::IndexVec; use ruff_python_ast as ast; use ruff_python_ast::name::Name; use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor}; -use crate::node_key::NodeKey; -use crate::semantic_index::ast_ids::{AstId, AstIdsBuilder, ScopedClassId, ScopedFunctionId}; -use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition}; +use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; +use crate::semantic_index::ast_ids::AstIdsBuilder; +use crate::semantic_index::definition::{Definition, DefinitionNodeKey, DefinitionNodeRef}; use crate::semantic_index::symbol::{ - FileScopeId, Scope, ScopeKind, ScopedSymbolId, SymbolFlags, SymbolTableBuilder, + FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolFlags, + SymbolTableBuilder, }; -use crate::semantic_index::{NodeWithScopeId, NodeWithScopeKey, SemanticIndex}; +use crate::semantic_index::SemanticIndex; +use crate::Db; -pub(super) struct SemanticIndexBuilder<'a> { +pub(super) struct SemanticIndexBuilder<'db, 'ast> { // Builder state - module: &'a ParsedModule, + db: &'db dyn Db, + file: VfsFile, + module: &'db ParsedModule, scope_stack: Vec, - /// the definition whose target(s) we are currently walking - current_definition: Option, + /// the target we're currently inferring + current_target: Option>, // Semantic Index fields scopes: IndexVec, - symbol_tables: IndexVec, + scope_ids_by_scope: IndexVec>, + symbol_tables: IndexVec>, ast_ids: IndexVec, - scopes_by_expression: FxHashMap, - scopes_by_definition: FxHashMap, + scopes_by_node: FxHashMap, + scopes_by_expression: FxHashMap, + definitions_by_node: FxHashMap>, } -impl<'a> SemanticIndexBuilder<'a> { - pub(super) fn new(parsed: &'a ParsedModule) -> Self { +impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> +where + 'db: 'ast, +{ + pub(super) fn new(db: &'db dyn Db, file: VfsFile, parsed: &'db ParsedModule) -> Self { let mut builder = Self { + db, + file, module: parsed, scope_stack: Vec::new(), - current_definition: None, + current_target: None, scopes: IndexVec::new(), symbol_tables: IndexVec::new(), ast_ids: IndexVec::new(), + scope_ids_by_scope: IndexVec::new(), + scopes_by_expression: FxHashMap::default(), - scopes_by_definition: FxHashMap::default(), + scopes_by_node: FxHashMap::default(), + definitions_by_node: FxHashMap::default(), }; - builder.push_scope_with_parent( - &NodeWithScope::new(parsed.syntax(), NodeWithScopeId::Module), - None, - ); + builder.push_scope_with_parent(NodeWithScopeRef::Module, None); builder } @@ -60,29 +72,40 @@ impl<'a> SemanticIndexBuilder<'a> { .expect("Always to have a root scope") } - fn push_scope(&mut self, node: &NodeWithScope) { + fn push_scope(&mut self, node: NodeWithScopeRef<'ast>) { let parent = self.current_scope(); self.push_scope_with_parent(node, Some(parent)); } - fn push_scope_with_parent(&mut self, node: &NodeWithScope, parent: Option) { + fn push_scope_with_parent( + &mut self, + node: NodeWithScopeRef<'ast>, + parent: Option, + ) { let children_start = self.scopes.next_index() + 1; let scope = Scope { - node: node.id(), parent, kind: node.scope_kind(), descendents: children_start..children_start, }; - let scope_id = self.scopes.push(scope); + let file_scope_id = self.scopes.push(scope); self.symbol_tables.push(SymbolTableBuilder::new()); let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new()); - debug_assert_eq!(ast_id_scope, scope_id); + #[allow(unsafe_code)] + // SAFETY: `node` is guaranteed to be a child of `self.module` + let scope_id = ScopeId::new(self.db, self.file, file_scope_id, unsafe { + node.to_kind(self.module.clone()) + }); - self.scope_stack.push(scope_id); - self.scopes_by_definition.insert(node.key(), scope_id); + self.scope_ids_by_scope.push(scope_id); + self.scopes_by_node.insert(node.node_key(), file_scope_id); + + debug_assert_eq!(ast_id_scope, file_scope_id); + + self.scope_stack.push(file_scope_id); } fn pop_scope(&mut self) -> FileScopeId { @@ -93,7 +116,7 @@ impl<'a> SemanticIndexBuilder<'a> { id } - fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder { + fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder<'db> { let scope_id = self.current_scope(); &mut self.symbol_tables[scope_id] } @@ -105,33 +128,64 @@ impl<'a> SemanticIndexBuilder<'a> { fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId { let symbol_table = self.current_symbol_table(); - symbol_table.add_or_update_symbol(name, flags, None) + symbol_table.add_or_update_symbol(name, flags) + } + + fn add_definition( + &mut self, + definition_node: impl Into>, + symbol_id: ScopedSymbolId, + ) -> Definition<'db> { + let definition_node = definition_node.into(); + let definition = Definition::new( + self.db, + self.file, + self.current_scope(), + symbol_id, + #[allow(unsafe_code)] + unsafe { + definition_node.into_owned(self.module.clone()) + }, + ); + + self.definitions_by_node + .insert(definition_node.key(), definition); + + definition } fn add_or_update_symbol_with_definition( &mut self, name: Name, - definition: Definition, - ) -> ScopedSymbolId { + definition: impl Into>, + ) -> (ScopedSymbolId, Definition<'db>) { let symbol_table = self.current_symbol_table(); - symbol_table.add_or_update_symbol(name, SymbolFlags::IS_DEFINED, Some(definition)) + let id = symbol_table.add_or_update_symbol(name, SymbolFlags::IS_DEFINED); + let definition = self.add_definition(definition, id); + self.current_symbol_table().add_definition(id, definition); + (id, definition) } fn with_type_params( &mut self, - with_params: &WithTypeParams, + with_params: &WithTypeParams<'ast>, nested: impl FnOnce(&mut Self) -> FileScopeId, ) -> FileScopeId { let type_params = with_params.type_parameters(); if let Some(type_params) = type_params { - let type_params_id = match with_params { - WithTypeParams::ClassDef { id, .. } => NodeWithScopeId::ClassTypeParams(*id), - WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id), + let with_scope = match with_params { + WithTypeParams::ClassDef { node, .. } => { + NodeWithScopeRef::ClassTypeParameters(node) + } + WithTypeParams::FunctionDef { node, .. } => { + NodeWithScopeRef::FunctionTypeParameters(node) + } }; - self.push_scope(&NodeWithScope::new(type_params, type_params_id)); + self.push_scope(with_scope); + for type_param in &type_params.type_params { let name = match type_param { ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name, @@ -151,7 +205,7 @@ impl<'a> SemanticIndexBuilder<'a> { nested_scope } - pub(super) fn build(mut self) -> SemanticIndex { + pub(super) fn build(mut self) -> SemanticIndex<'db> { let module = self.module; self.visit_body(module.suite()); @@ -159,7 +213,7 @@ impl<'a> SemanticIndexBuilder<'a> { self.pop_scope(); assert!(self.scope_stack.is_empty()); - assert!(self.current_definition.is_none()); + assert!(self.current_target.is_none()); let mut symbol_tables: IndexVec<_, _> = self .symbol_tables @@ -177,53 +231,48 @@ impl<'a> SemanticIndexBuilder<'a> { ast_ids.shrink_to_fit(); symbol_tables.shrink_to_fit(); self.scopes_by_expression.shrink_to_fit(); + self.definitions_by_node.shrink_to_fit(); + + self.scope_ids_by_scope.shrink_to_fit(); + self.scopes_by_node.shrink_to_fit(); SemanticIndex { symbol_tables, scopes: self.scopes, - scopes_by_definition: self.scopes_by_definition, + definitions_by_node: self.definitions_by_node, + scope_ids_by_scope: self.scope_ids_by_scope, ast_ids, scopes_by_expression: self.scopes_by_expression, + scopes_by_node: self.scopes_by_node, } } } -impl Visitor<'_> for SemanticIndexBuilder<'_> { - fn visit_stmt(&mut self, stmt: &ast::Stmt) { - let module = self.module; - #[allow(unsafe_code)] - let statement_id = unsafe { - // SAFETY: The builder only visits nodes that are part of `module`. This guarantees that - // the current statement must be a child of `module`. - self.current_ast_ids().record_statement(stmt, module) - }; +impl<'db, 'ast> Visitor<'ast> for SemanticIndexBuilder<'db, 'ast> +where + 'db: 'ast, +{ + fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) { match stmt { ast::Stmt::FunctionDef(function_def) => { for decorator in &function_def.decorator_list { self.visit_decorator(decorator); } - let name = &function_def.name.id; - let function_id = ScopedFunctionId(statement_id); - let definition = Definition::FunctionDef(function_id); - let scope = self.current_scope(); - self.add_or_update_symbol_with_definition(name.clone(), definition); + self.add_or_update_symbol_with_definition( + function_def.name.id.clone(), + function_def, + ); self.with_type_params( - &WithTypeParams::FunctionDef { - node: function_def, - id: AstId::new(scope, function_id), - }, + &WithTypeParams::FunctionDef { node: function_def }, |builder| { builder.visit_parameters(&function_def.parameters); for expr in &function_def.returns { builder.visit_annotation(expr); } - builder.push_scope(&NodeWithScope::new( - function_def, - NodeWithScopeId::Function(AstId::new(scope, function_id)), - )); + builder.push_scope(NodeWithScopeRef::Function(function_def)); builder.visit_body(&function_def.body); builder.pop_scope() }, @@ -234,46 +283,28 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { self.visit_decorator(decorator); } - let name = &class.name.id; - let class_id = ScopedClassId(statement_id); - let definition = Definition::ClassDef(class_id); - let scope = self.current_scope(); + self.add_or_update_symbol_with_definition(class.name.id.clone(), class); - self.add_or_update_symbol_with_definition(name.clone(), definition); + self.with_type_params(&WithTypeParams::ClassDef { node: class }, |builder| { + if let Some(arguments) = &class.arguments { + builder.visit_arguments(arguments); + } - self.with_type_params( - &WithTypeParams::ClassDef { - node: class, - id: AstId::new(scope, class_id), - }, - |builder| { - if let Some(arguments) = &class.arguments { - builder.visit_arguments(arguments); - } + builder.push_scope(NodeWithScopeRef::Class(class)); + builder.visit_body(&class.body); - builder.push_scope(&NodeWithScope::new( - class, - NodeWithScopeId::Class(AstId::new(scope, class_id)), - )); - builder.visit_body(&class.body); - - builder.pop_scope() - }, - ); + builder.pop_scope() + }); } ast::Stmt::Import(ast::StmtImport { names, .. }) => { - for (i, alias) in names.iter().enumerate() { + for alias in names { let symbol_name = if let Some(asname) = &alias.asname { asname.id.clone() } else { Name::new(alias.name.id.split('.').next().unwrap()) }; - let def = Definition::Import(ImportDefinition { - import_id: statement_id, - alias: u32::try_from(i).unwrap(), - }); - self.add_or_update_symbol_with_definition(symbol_name, def); + self.add_or_update_symbol_with_definition(symbol_name, alias); } } ast::Stmt::ImportFrom(ast::StmtImportFrom { @@ -282,27 +313,24 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { level: _, .. }) => { - for (i, alias) in names.iter().enumerate() { + for alias in names { let symbol_name = if let Some(asname) = &alias.asname { &asname.id } else { &alias.name.id }; - let def = Definition::ImportFrom(ImportFromDefinition { - import_id: statement_id, - name: u32::try_from(i).unwrap(), - }); - self.add_or_update_symbol_with_definition(symbol_name.clone(), def); + + self.add_or_update_symbol_with_definition(symbol_name.clone(), alias); } } ast::Stmt::Assign(node) => { - debug_assert!(self.current_definition.is_none()); + debug_assert!(self.current_target.is_none()); self.visit_expr(&node.value); - self.current_definition = Some(Definition::Assignment(statement_id)); for target in &node.targets { + self.current_target = Some(CurrentTarget::Expr(target)); self.visit_expr(target); } - self.current_definition = None; + self.current_target = None; } _ => { walk_stmt(self, stmt); @@ -310,17 +338,10 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { } } - fn visit_expr(&mut self, expr: &'_ ast::Expr) { - let module = self.module; - #[allow(unsafe_code)] - let expression_id = unsafe { - // SAFETY: The builder only visits nodes that are part of `module`. This guarantees that - // the current expression must be a child of `module`. - self.current_ast_ids().record_expression(expr, module) - }; - + fn visit_expr(&mut self, expr: &'ast ast::Expr) { self.scopes_by_expression - .insert(NodeKey::from_node(expr), self.current_scope()); + .insert(expr.into(), self.current_scope()); + self.current_ast_ids().record_expression(expr); match expr { ast::Expr::Name(ast::ExprName { id, ctx, .. }) => { @@ -330,9 +351,9 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { ast::ExprContext::Del => SymbolFlags::IS_DEFINED, ast::ExprContext::Invalid => SymbolFlags::empty(), }; - match self.current_definition { - Some(definition) if flags.contains(SymbolFlags::IS_DEFINED) => { - self.add_or_update_symbol_with_definition(id.clone(), definition); + match self.current_target { + Some(target) if flags.contains(SymbolFlags::IS_DEFINED) => { + self.add_or_update_symbol_with_definition(id.clone(), target); } _ => { self.add_or_update_symbol(id.clone(), flags); @@ -342,11 +363,11 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { walk_expr(self, expr); } ast::Expr::Named(node) => { - debug_assert!(self.current_definition.is_none()); - self.current_definition = Some(Definition::NamedExpr(expression_id)); + debug_assert!(self.current_target.is_none()); + self.current_target = Some(CurrentTarget::ExprNamed(node)); // TODO walrus in comprehensions is implicitly nonlocal self.visit_expr(&node.target); - self.current_definition = None; + self.current_target = None; self.visit_expr(&node.value); } ast::Expr::If(ast::ExprIf { @@ -382,19 +403,13 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { } } -enum WithTypeParams<'a> { - ClassDef { - node: &'a ast::StmtClassDef, - id: AstId, - }, - FunctionDef { - node: &'a ast::StmtFunctionDef, - id: AstId, - }, +enum WithTypeParams<'node> { + ClassDef { node: &'node ast::StmtClassDef }, + FunctionDef { node: &'node ast::StmtFunctionDef }, } -impl<'a> WithTypeParams<'a> { - fn type_parameters(&self) -> Option<&'a ast::TypeParams> { +impl<'node> WithTypeParams<'node> { + fn type_parameters(&self) -> Option<&'node ast::TypeParams> { match self { WithTypeParams::ClassDef { node, .. } => node.type_params.as_deref(), WithTypeParams::FunctionDef { node, .. } => node.type_params.as_deref(), @@ -402,35 +417,17 @@ impl<'a> WithTypeParams<'a> { } } -struct NodeWithScope { - id: NodeWithScopeId, - key: NodeWithScopeKey, +#[derive(Copy, Clone, Debug)] +enum CurrentTarget<'a> { + Expr(&'a ast::Expr), + ExprNamed(&'a ast::ExprNamed), } -impl NodeWithScope { - fn new(node: impl Into, id: NodeWithScopeId) -> Self { - Self { - id, - key: node.into(), - } - } - - fn id(&self) -> NodeWithScopeId { - self.id - } - - fn key(&self) -> NodeWithScopeKey { - self.key - } - - fn scope_kind(&self) -> ScopeKind { - match self.id { - NodeWithScopeId::Module => ScopeKind::Module, - NodeWithScopeId::Class(_) => ScopeKind::Class, - NodeWithScopeId::Function(_) => ScopeKind::Function, - NodeWithScopeId::ClassTypeParams(_) | NodeWithScopeId::FunctionTypeParams(_) => { - ScopeKind::Annotation - } +impl<'a> From> for DefinitionNodeRef<'a> { + fn from(val: CurrentTarget<'a>) -> Self { + match val { + CurrentTarget::Expr(expression) => DefinitionNodeRef::Target(expression), + CurrentTarget::ExprNamed(named) => DefinitionNodeRef::NamedExpression(named), } } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index f1427ace93..90081435be 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -1,57 +1,103 @@ -use crate::semantic_index::ast_ids::{ - ScopedClassId, ScopedExpressionId, ScopedFunctionId, ScopedStatementId, -}; +use ruff_db::parsed::ParsedModule; +use ruff_db::vfs::VfsFile; +use ruff_python_ast as ast; -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub enum Definition { - Import(ImportDefinition), - ImportFrom(ImportFromDefinition), - ClassDef(ScopedClassId), - FunctionDef(ScopedFunctionId), - Assignment(ScopedStatementId), - AnnotatedAssignment(ScopedStatementId), - NamedExpr(ScopedExpressionId), - /// represents the implicit initial definition of every name as "unbound" - Unbound, - // TODO with statements, except handlers, function args... +use crate::ast_node_ref::AstNodeRef; +use crate::node_key::NodeKey; +use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId}; + +#[salsa::tracked] +pub struct Definition<'db> { + /// The file in which the definition is defined. + #[id] + pub(super) file: VfsFile, + + /// The scope in which the definition is defined. + #[id] + pub(crate) scope: FileScopeId, + + /// The id of the corresponding symbol. Mainly used as ID. + #[id] + symbol_id: ScopedSymbolId, + + #[no_eq] + #[return_ref] + pub(crate) node: DefinitionKind, } -impl From for Definition { - fn from(value: ImportDefinition) -> Self { - Self::Import(value) +#[derive(Copy, Clone, Debug)] +pub(crate) enum DefinitionNodeRef<'a> { + Alias(&'a ast::Alias), + Function(&'a ast::StmtFunctionDef), + Class(&'a ast::StmtClassDef), + NamedExpression(&'a ast::ExprNamed), + Target(&'a ast::Expr), +} + +impl<'a> From<&'a ast::Alias> for DefinitionNodeRef<'a> { + fn from(node: &'a ast::Alias) -> Self { + Self::Alias(node) + } +} +impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> { + fn from(node: &'a ast::StmtFunctionDef) -> Self { + Self::Function(node) + } +} +impl<'a> From<&'a ast::StmtClassDef> for DefinitionNodeRef<'a> { + fn from(node: &'a ast::StmtClassDef) -> Self { + Self::Class(node) + } +} +impl<'a> From<&'a ast::ExprNamed> for DefinitionNodeRef<'a> { + fn from(node: &'a ast::ExprNamed) -> Self { + Self::NamedExpression(node) } } -impl From for Definition { - fn from(value: ImportFromDefinition) -> Self { - Self::ImportFrom(value) +impl DefinitionNodeRef<'_> { + #[allow(unsafe_code)] + pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind { + match self { + DefinitionNodeRef::Alias(alias) => { + DefinitionKind::Alias(AstNodeRef::new(parsed, alias)) + } + DefinitionNodeRef::Function(function) => { + DefinitionKind::Function(AstNodeRef::new(parsed, function)) + } + DefinitionNodeRef::Class(class) => { + DefinitionKind::Class(AstNodeRef::new(parsed, class)) + } + DefinitionNodeRef::NamedExpression(named) => { + DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named)) + } + DefinitionNodeRef::Target(target) => { + DefinitionKind::Target(AstNodeRef::new(parsed, target)) + } + } } } -impl From for Definition { - fn from(value: ScopedClassId) -> Self { - Self::ClassDef(value) +impl DefinitionNodeRef<'_> { + pub(super) fn key(self) -> DefinitionNodeKey { + match self { + Self::Alias(node) => DefinitionNodeKey(NodeKey::from_node(node)), + Self::Function(node) => DefinitionNodeKey(NodeKey::from_node(node)), + Self::Class(node) => DefinitionNodeKey(NodeKey::from_node(node)), + Self::NamedExpression(node) => DefinitionNodeKey(NodeKey::from_node(node)), + Self::Target(node) => DefinitionNodeKey(NodeKey::from_node(node)), + } } } -impl From for Definition { - fn from(value: ScopedFunctionId) -> Self { - Self::FunctionDef(value) - } +#[derive(Clone, Debug)] +pub enum DefinitionKind { + Alias(AstNodeRef), + Function(AstNodeRef), + Class(AstNodeRef), + NamedExpression(AstNodeRef), + Target(AstNodeRef), } -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub struct ImportDefinition { - pub(crate) import_id: ScopedStatementId, - - /// Index into [`ruff_python_ast::StmtImport::names`]. - pub(crate) alias: u32, -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] -pub struct ImportFromDefinition { - pub(crate) import_id: ScopedStatementId, - - /// Index into [`ruff_python_ast::StmtImportFrom::names`]. - pub(crate) name: u32, -} +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +pub(super) struct DefinitionNodeKey(NodeKey); diff --git a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs index 8c5ebb8c23..dc746081fa 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs @@ -3,35 +3,39 @@ use std::ops::Range; use bitflags::bitflags; use hashbrown::hash_map::RawEntryMut; -use rustc_hash::FxHasher; -use smallvec::SmallVec; - +use ruff_db::parsed::ParsedModule; use ruff_db::vfs::VfsFile; use ruff_index::{newtype_index, IndexVec}; use ruff_python_ast::name::Name; +use ruff_python_ast::{self as ast}; +use rustc_hash::FxHasher; +use crate::ast_node_ref::AstNodeRef; +use crate::node_key::NodeKey; use crate::semantic_index::definition::Definition; -use crate::semantic_index::{root_scope, semantic_index, symbol_table, NodeWithScopeId, SymbolMap}; +use crate::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap}; use crate::Db; #[derive(Eq, PartialEq, Debug)] -pub struct Symbol { +pub struct Symbol<'db> { name: Name, flags: SymbolFlags, /// The nodes that define this symbol, in source order. - definitions: SmallVec<[Definition; 4]>, + /// + /// TODO: Use smallvec here, but it creates the same lifetime issues as in [QualifiedName](https://github.com/astral-sh/ruff/blob/5109b50bb3847738eeb209352cf26bda392adf62/crates/ruff_python_ast/src/name.rs#L562-L569) + definitions: Vec>, } -impl Symbol { - fn new(name: Name, definition: Option) -> Self { +impl<'db> Symbol<'db> { + fn new(name: Name) -> Self { Self { name, flags: SymbolFlags::empty(), - definitions: definition.into_iter().collect(), + definitions: Vec::new(), } } - fn push_definition(&mut self, definition: Definition) { + fn push_definition(&mut self, definition: Definition<'db>) { self.definitions.push(definition); } @@ -118,39 +122,6 @@ impl ScopedSymbolId { } } -/// Returns a mapping from [`FileScopeId`] to globally unique [`ScopeId`]. -#[salsa::tracked(return_ref)] -pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap<'_> { - let _span = tracing::trace_span!("scopes_map", ?file).entered(); - - let index = semantic_index(db, file); - - let scopes: IndexVec<_, _> = index - .scopes - .indices() - .map(|id| ScopeId::new(db, file, id)) - .collect(); - - ScopesMap { scopes } -} - -/// Maps from the file specific [`FileScopeId`] to the global [`ScopeId`] that can be used as a Salsa query parameter. -/// -/// The [`SemanticIndex`] uses [`FileScopeId`] on a per-file level to identify scopes -/// because they allow for more efficient storage of associated data -/// (use of an [`IndexVec`] keyed by [`FileScopeId`] over an [`FxHashMap`] keyed by [`ScopeId`]). -#[derive(Eq, PartialEq, Debug)] -pub(crate) struct ScopesMap<'db> { - scopes: IndexVec>, -} - -impl<'db> ScopesMap<'db> { - /// Gets the program-wide unique scope id for the given file specific `scope_id`. - fn get(&self, scope: FileScopeId) -> ScopeId<'db> { - self.scopes[scope] - } -} - #[salsa::tracked(return_ref)] pub(crate) fn public_symbols_map(db: &dyn Db, file: VfsFile) -> PublicSymbolsMap<'_> { let _span = tracing::trace_span!("public_symbols_map", ?file).entered(); @@ -189,6 +160,25 @@ pub struct ScopeId<'db> { pub file: VfsFile, #[id] pub file_scope_id: FileScopeId, + + /// The node that introduces this scope. + #[no_eq] + #[return_ref] + pub node: NodeWithScopeKind, +} + +impl<'db> ScopeId<'db> { + #[cfg(test)] + pub(crate) fn name(self, db: &'db dyn Db) -> &'db str { + match self.node(db) { + NodeWithScopeKind::Module => "", + NodeWithScopeKind::Class(class) | NodeWithScopeKind::ClassTypeParameters(class) => { + class.name.as_str() + } + NodeWithScopeKind::Function(function) + | NodeWithScopeKind::FunctionTypeParameters(function) => function.name.as_str(), + } + } } /// ID that uniquely identifies a scope inside of a module. @@ -202,42 +192,19 @@ impl FileScopeId { } pub fn to_scope_id(self, db: &dyn Db, file: VfsFile) -> ScopeId<'_> { - scopes_map(db, file).get(self) + let index = semantic_index(db, file); + index.scope_ids_by_scope[self] } } #[derive(Debug, Eq, PartialEq)] pub struct Scope { pub(super) parent: Option, - pub(super) node: NodeWithScopeId, pub(super) kind: ScopeKind, pub(super) descendents: Range, } impl Scope { - #[cfg(test)] - pub(crate) fn name<'db>(&self, db: &'db dyn Db, file: VfsFile) -> &'db str { - use crate::semantic_index::ast_ids::AstIdNode; - use ruff_python_ast as ast; - - match self.node { - NodeWithScopeId::Module => "", - NodeWithScopeId::Class(class) | NodeWithScopeId::ClassTypeParams(class) => { - let class = ast::StmtClassDef::lookup(db, file, class); - class.name.as_str() - } - NodeWithScopeId::Function(function) | NodeWithScopeId::FunctionTypeParams(function) => { - let function = ast::StmtFunctionDef::lookup(db, file, function); - function.name.as_str() - } - } - } - - /// The node that creates this scope. - pub(crate) fn node(&self) -> NodeWithScopeId { - self.node - } - pub fn parent(self) -> Option { self.parent } @@ -257,15 +224,15 @@ pub enum ScopeKind { /// Symbol table for a specific [`Scope`]. #[derive(Debug)] -pub struct SymbolTable { +pub struct SymbolTable<'db> { /// The symbols in this scope. - symbols: IndexVec, + symbols: IndexVec>, /// The symbols indexed by name. symbols_by_name: SymbolMap, } -impl SymbolTable { +impl<'db> SymbolTable<'db> { fn new() -> Self { Self { symbols: IndexVec::new(), @@ -277,21 +244,21 @@ impl SymbolTable { self.symbols.shrink_to_fit(); } - pub(crate) fn symbol(&self, symbol_id: impl Into) -> &Symbol { + pub(crate) fn symbol(&self, symbol_id: impl Into) -> &Symbol<'db> { &self.symbols[symbol_id.into()] } - pub(crate) fn symbol_ids(&self) -> impl Iterator { + pub(crate) fn symbol_ids(&self) -> impl Iterator + 'db { self.symbols.indices() } - pub fn symbols(&self) -> impl Iterator { + pub fn symbols(&self) -> impl Iterator> { self.symbols.iter() } /// Returns the symbol named `name`. #[allow(unused)] - pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol> { + pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol<'db>> { let id = self.symbol_id_by_name(name)?; Some(self.symbol(id)) } @@ -315,21 +282,21 @@ impl SymbolTable { } } -impl PartialEq for SymbolTable { +impl PartialEq for SymbolTable<'_> { fn eq(&self, other: &Self) -> bool { // We don't need to compare the symbols_by_name because the name is already captured in `Symbol`. self.symbols == other.symbols } } -impl Eq for SymbolTable {} +impl Eq for SymbolTable<'_> {} #[derive(Debug)] -pub(super) struct SymbolTableBuilder { - table: SymbolTable, +pub(super) struct SymbolTableBuilder<'db> { + table: SymbolTable<'db>, } -impl SymbolTableBuilder { +impl<'db> SymbolTableBuilder<'db> { pub(super) fn new() -> Self { Self { table: SymbolTable::new(), @@ -340,7 +307,6 @@ impl SymbolTableBuilder { &mut self, name: Name, flags: SymbolFlags, - definition: Option, ) -> ScopedSymbolId { let hash = SymbolTable::hash_name(&name); let entry = self @@ -354,14 +320,10 @@ impl SymbolTableBuilder { let symbol = &mut self.table.symbols[*entry.key()]; symbol.insert_flags(flags); - if let Some(definition) = definition { - symbol.push_definition(definition); - } - *entry.key() } RawEntryMut::Vacant(entry) => { - let mut symbol = Symbol::new(name, definition); + let mut symbol = Symbol::new(name); symbol.insert_flags(flags); let id = self.table.symbols.push(symbol); @@ -373,8 +335,92 @@ impl SymbolTableBuilder { } } - pub(super) fn finish(mut self) -> SymbolTable { + pub(super) fn add_definition(&mut self, symbol: ScopedSymbolId, definition: Definition<'db>) { + self.table.symbols[symbol].push_definition(definition); + } + + pub(super) fn finish(mut self) -> SymbolTable<'db> { self.table.shrink_to_fit(); self.table } } + +/// Reference to a node that introduces a new scope. +#[derive(Copy, Clone, Debug)] +pub(crate) enum NodeWithScopeRef<'a> { + Module, + Class(&'a ast::StmtClassDef), + Function(&'a ast::StmtFunctionDef), + FunctionTypeParameters(&'a ast::StmtFunctionDef), + ClassTypeParameters(&'a ast::StmtClassDef), +} + +impl NodeWithScopeRef<'_> { + /// Converts the unowned reference to an owned [`NodeWithScopeKind`]. + /// + /// # Safety + /// The node wrapped by `self` must be a child of `module`. + #[allow(unsafe_code)] + pub(super) unsafe fn to_kind(self, module: ParsedModule) -> NodeWithScopeKind { + match self { + NodeWithScopeRef::Module => NodeWithScopeKind::Module, + NodeWithScopeRef::Class(class) => { + NodeWithScopeKind::Class(AstNodeRef::new(module, class)) + } + NodeWithScopeRef::Function(function) => { + NodeWithScopeKind::Function(AstNodeRef::new(module, function)) + } + NodeWithScopeRef::FunctionTypeParameters(function) => { + NodeWithScopeKind::FunctionTypeParameters(AstNodeRef::new(module, function)) + } + NodeWithScopeRef::ClassTypeParameters(class) => { + NodeWithScopeKind::Class(AstNodeRef::new(module, class)) + } + } + } + + pub(super) fn scope_kind(self) -> ScopeKind { + match self { + NodeWithScopeRef::Module => ScopeKind::Module, + NodeWithScopeRef::Class(_) => ScopeKind::Class, + NodeWithScopeRef::Function(_) => ScopeKind::Function, + NodeWithScopeRef::FunctionTypeParameters(_) + | NodeWithScopeRef::ClassTypeParameters(_) => ScopeKind::Annotation, + } + } + + pub(crate) fn node_key(self) -> NodeWithScopeKey { + match self { + NodeWithScopeRef::Module => NodeWithScopeKey::Module, + NodeWithScopeRef::Class(class) => NodeWithScopeKey::Class(NodeKey::from_node(class)), + NodeWithScopeRef::Function(function) => { + NodeWithScopeKey::Function(NodeKey::from_node(function)) + } + NodeWithScopeRef::FunctionTypeParameters(function) => { + NodeWithScopeKey::FunctionTypeParameters(NodeKey::from_node(function)) + } + NodeWithScopeRef::ClassTypeParameters(class) => { + NodeWithScopeKey::ClassTypeParameters(NodeKey::from_node(class)) + } + } + } +} + +/// Node that introduces a new scope. +#[derive(Clone, Debug)] +pub enum NodeWithScopeKind { + Module, + Class(AstNodeRef), + ClassTypeParameters(AstNodeRef), + Function(AstNodeRef), + FunctionTypeParameters(AstNodeRef), +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub(crate) enum NodeWithScopeKey { + Module, + Class(NodeKey), + ClassTypeParameters(NodeKey), + Function(NodeKey), + FunctionTypeParameters(NodeKey), +} diff --git a/crates/red_knot_python_semantic/src/semantic_model.rs b/crates/red_knot_python_semantic/src/semantic_model.rs index 834f81fa52..5078a44d64 100644 --- a/crates/red_knot_python_semantic/src/semantic_model.rs +++ b/crates/red_knot_python_semantic/src/semantic_model.rs @@ -4,9 +4,8 @@ use ruff_python_ast as ast; use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef}; use crate::semantic_index::ast_ids::HasScopedAstId; -use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::PublicSymbolId; -use crate::semantic_index::{public_symbol, semantic_index, NodeWithScopeKey}; +use crate::semantic_index::{public_symbol, semantic_index}; use crate::types::{infer_types, public_symbol_ty, Type, TypingContext}; use crate::Db; @@ -143,12 +142,10 @@ impl HasTy for ast::Expr { impl HasTy for ast::StmtFunctionDef { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); - let definition_scope = index.definition_scope(NodeWithScopeKey::from(self)); - - let scope = definition_scope.to_scope_id(model.db, model.file); + let definition = index.definition(self); + let scope = definition.scope(model.db).to_scope_id(model.db, model.file); let types = infer_types(model.db, scope); - let definition = Definition::FunctionDef(self.scoped_ast_id(model.db, scope)); types.definition_ty(definition) } @@ -157,11 +154,10 @@ impl HasTy for ast::StmtFunctionDef { impl HasTy for StmtClassDef { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); - let definition_scope = index.definition_scope(NodeWithScopeKey::from(self)); - let scope = definition_scope.to_scope_id(model.db, model.file); + let definition = index.definition(self); + let scope = definition.scope(model.db).to_scope_id(model.db, model.file); let types = infer_types(model.db, scope); - let definition = Definition::ClassDef(self.scoped_ast_id(model.db, scope)); types.definition_ty(definition) } diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 825f50e464..e0116a6a7b 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,17 +1,14 @@ -use crate::semantic_index::ast_ids::AstIdNode; -use crate::semantic_index::symbol::{FileScopeId, PublicSymbolId, ScopeId}; -use crate::semantic_index::{ - public_symbol, root_scope, semantic_index, symbol_table, NodeWithScopeId, -}; -use crate::types::infer::{TypeInference, TypeInferenceBuilder}; -use crate::Db; -use crate::FxIndexSet; use ruff_db::parsed::parsed_module; use ruff_db::vfs::VfsFile; use ruff_index::newtype_index; -use ruff_python_ast as ast; use ruff_python_ast::name::Name; +use crate::semantic_index::symbol::{FileScopeId, NodeWithScopeKind, PublicSymbolId, ScopeId}; +use crate::semantic_index::{public_symbol, root_scope, semantic_index, symbol_table}; +use crate::types::infer::{TypeInference, TypeInferenceBuilder}; +use crate::Db; +use crate::FxIndexSet; + mod display; mod infer; @@ -70,31 +67,22 @@ pub(crate) fn infer_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInfe // The isolation of the query is by the return inferred types. let index = semantic_index(db, file); - let scope_id = scope.file_scope_id(db); - let node = index.scope(scope_id).node(); + let node = scope.node(db); let mut context = TypeInferenceBuilder::new(db, scope, index); match node { - NodeWithScopeId::Module => { + NodeWithScopeKind::Module => { let parsed = parsed_module(db.upcast(), file); context.infer_module(parsed.syntax()); } - NodeWithScopeId::Class(class_id) => { - let class = ast::StmtClassDef::lookup(db, file, class_id); - context.infer_class_body(class); + NodeWithScopeKind::Function(function) => context.infer_function_body(function.node()), + NodeWithScopeKind::Class(class) => context.infer_class_body(class.node()), + NodeWithScopeKind::ClassTypeParameters(class) => { + context.infer_class_type_params(class.node()); } - NodeWithScopeId::ClassTypeParams(class_id) => { - let class = ast::StmtClassDef::lookup(db, file, class_id); - context.infer_class_type_params(class); - } - NodeWithScopeId::Function(function_id) => { - let function = ast::StmtFunctionDef::lookup(db, file, function_id); - context.infer_function_body(function); - } - NodeWithScopeId::FunctionTypeParams(function_id) => { - let function = ast::StmtFunctionDef::lookup(db, file, function_id); - context.infer_function_type_params(function); + NodeWithScopeKind::FunctionTypeParameters(function) => { + context.infer_function_type_params(function.node()); } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 4ae5c76feb..f66c1b7114 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1,6 +1,5 @@ -use std::sync::Arc; - use rustc_hash::FxHashMap; +use std::sync::Arc; use red_knot_module_resolver::resolve_module; use red_knot_module_resolver::ModuleName; @@ -9,9 +8,11 @@ use ruff_index::IndexVec; use ruff_python_ast as ast; use ruff_python_ast::{ExprContext, TypeParams}; -use crate::semantic_index::ast_ids::{HasScopedAstId, ScopedExpressionId}; -use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition}; -use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId, SymbolTable}; +use crate::semantic_index::ast_ids::ScopedExpressionId; +use crate::semantic_index::definition::{Definition, DefinitionNodeRef}; +use crate::semantic_index::symbol::{ + FileScopeId, NodeWithScopeRef, ScopeId, ScopedSymbolId, SymbolTable, +}; use crate::semantic_index::{symbol_table, SemanticIndex}; use crate::types::{ infer_types, ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, @@ -42,7 +43,7 @@ pub(crate) struct TypeInference<'db> { symbol_tys: IndexVec>, /// The type of a definition. - definition_tys: FxHashMap>, + definition_tys: FxHashMap, Type<'db>>, } impl<'db> TypeInference<'db> { @@ -92,23 +93,27 @@ impl<'db> TypeInference<'db> { } /// Builder to infer all types in a [`ScopeId`]. -pub(super) struct TypeInferenceBuilder<'a> { - db: &'a dyn Db, +pub(super) struct TypeInferenceBuilder<'db> { + db: &'db dyn Db, // Cached lookups - index: &'a SemanticIndex, - scope: ScopeId<'a>, + index: &'db SemanticIndex<'db>, + scope: ScopeId<'db>, file_scope_id: FileScopeId, file_id: VfsFile, - symbol_table: Arc, + symbol_table: Arc>, /// The type inference results - types: TypeInference<'a>, + types: TypeInference<'db>, } impl<'db> TypeInferenceBuilder<'db> { /// Creates a new builder for inferring the types of `scope`. - pub(super) fn new(db: &'db dyn Db, scope: ScopeId<'db>, index: &'db SemanticIndex) -> Self { + pub(super) fn new( + db: &'db dyn Db, + scope: ScopeId<'db>, + index: &'db SemanticIndex<'db>, + ) -> Self { let file_scope_id = scope.file_scope_id(db); let file = scope.file(db); let symbol_table = index.symbol_table(file_scope_id); @@ -188,7 +193,6 @@ impl<'db> TypeInferenceBuilder<'db> { decorator_list, } = function; - let function_id = function.scoped_ast_id(self.db, self.scope); let decorator_tys = decorator_list .iter() .map(|decorator| self.infer_decorator(decorator)) @@ -205,9 +209,8 @@ impl<'db> TypeInferenceBuilder<'db> { decorators: decorator_tys, }); - self.types - .definition_tys - .insert(Definition::FunctionDef(function_id), function_ty); + let definition = self.index.definition(function); + self.types.definition_tys.insert(definition, function_ty); } fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) { @@ -220,8 +223,6 @@ impl<'db> TypeInferenceBuilder<'db> { body: _, } = class; - let class_id = class.scoped_ast_id(self.db, self.scope); - for decorator in decorator_list { self.infer_decorator(decorator); } @@ -231,17 +232,16 @@ impl<'db> TypeInferenceBuilder<'db> { .map(|arguments| self.infer_arguments(arguments)) .unwrap_or(Vec::new()); - let class_body_scope_id = self.index.node_scope(class); + let body_scope = self.index.node_scope(NodeWithScopeRef::Class(class)); let class_ty = self.class_ty(ClassType { name: name.id.clone(), bases, - body_scope: class_body_scope_id.to_scope_id(self.db, self.file_id), + body_scope: body_scope.to_scope_id(self.db, self.file_id), }); - self.types - .definition_tys - .insert(Definition::ClassDef(class_id), class_ty); + let definition = self.index.definition(class); + self.types.definition_tys.insert(definition, class_ty); } fn infer_if_statement(&mut self, if_statement: &ast::StmtIf) { @@ -281,14 +281,12 @@ impl<'db> TypeInferenceBuilder<'db> { for target in targets { self.infer_expression(target); + + self.types.definition_tys.insert( + self.index.definition(DefinitionNodeRef::Target(target)), + value_ty, + ); } - - let assign_id = assignment.scoped_ast_id(self.db, self.scope); - - // TODO: Handle multiple targets. - self.types - .definition_tys - .insert(Definition::Assignment(assign_id), value_ty); } fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) { @@ -308,7 +306,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(target); self.types.definition_tys.insert( - Definition::AnnotatedAssignment(assignment.scoped_ast_id(self.db, self.scope)), + self.index.definition(DefinitionNodeRef::Target(target)), annotation_ty, ); } @@ -332,9 +330,7 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_import_statement(&mut self, import: &ast::StmtImport) { let ast::StmtImport { range: _, names } = import; - let import_id = import.scoped_ast_id(self.db, self.scope); - - for (i, alias) in names.iter().enumerate() { + for alias in names { let ast::Alias { range: _, name, @@ -347,13 +343,9 @@ impl<'db> TypeInferenceBuilder<'db> { .map(|module| self.typing_context().module_ty(module.file())) .unwrap_or(Type::Unknown); - self.types.definition_tys.insert( - Definition::Import(ImportDefinition { - import_id, - alias: u32::try_from(i).unwrap(), - }), - module_ty, - ); + let definition = self.index.definition(alias); + + self.types.definition_tys.insert(definition, module_ty); } } @@ -365,7 +357,6 @@ impl<'db> TypeInferenceBuilder<'db> { level: _, } = import; - let import_id = import.scoped_ast_id(self.db, self.scope); let module_name = ModuleName::new(module.as_deref().expect("Support relative imports")); let module = @@ -374,7 +365,7 @@ impl<'db> TypeInferenceBuilder<'db> { .map(|module| self.typing_context().module_ty(module.file())) .unwrap_or(Type::Unknown); - for (i, alias) in names.iter().enumerate() { + for alias in names { let ast::Alias { range: _, name, @@ -385,13 +376,8 @@ impl<'db> TypeInferenceBuilder<'db> { .member(&self.typing_context(), &name.id) .unwrap_or(Type::Unknown); - self.types.definition_tys.insert( - Definition::ImportFrom(ImportFromDefinition { - import_id, - name: u32::try_from(i).unwrap(), - }), - ty, - ); + let definition = self.index.definition(alias); + self.types.definition_tys.insert(definition, ty); } } @@ -467,10 +453,9 @@ impl<'db> TypeInferenceBuilder<'db> { let value_ty = self.infer_expression(value); self.infer_expression(target); - self.types.definition_tys.insert( - Definition::NamedExpr(named.scoped_ast_id(self.db, self.scope)), - value_ty, - ); + self.types + .definition_tys + .insert(self.index.definition(named), value_ty); value_ty }