diff --git a/crates/red_knot_module_resolver/src/lib.rs b/crates/red_knot_module_resolver/src/lib.rs index 9479a5c002..72be73c55d 100644 --- a/crates/red_knot_module_resolver/src/lib.rs +++ b/crates/red_knot_module_resolver/src/lib.rs @@ -4,6 +4,6 @@ mod resolver; mod typeshed; pub use db::{Db, Jar}; -pub use module::{ModuleKind, ModuleName}; +pub use module::{Module, ModuleKind, ModuleName}; pub use resolver::{resolve_module, set_module_resolution_settings, ModuleResolutionSettings}; pub use typeshed::versions::TypeshedVersions; diff --git a/crates/red_knot_module_resolver/src/resolver.rs b/crates/red_knot_module_resolver/src/resolver.rs index dbd8734049..33f7281cf1 100644 --- a/crates/red_knot_module_resolver/src/resolver.rs +++ b/crates/red_knot_module_resolver/src/resolver.rs @@ -1,4 +1,3 @@ -use salsa::DebugWithDb; use std::ops::Deref; use ruff_db::file_system::{FileSystem, FileSystemPath, FileSystemPathBuf}; @@ -42,7 +41,7 @@ pub(crate) fn resolve_module_query<'db>( db: &'db dyn Db, module_name: internal::ModuleNameIngredient<'db>, ) -> Option { - let _ = tracing::trace_span!("resolve_module", module_name = ?module_name.debug(db)).enter(); + let _span = tracing::trace_span!("resolve_module", ?module_name).entered(); let name = module_name.name(db); @@ -76,7 +75,7 @@ pub fn path_to_module(db: &dyn Db, path: &VfsPath) -> Option { #[salsa::tracked] #[allow(unused)] pub(crate) fn file_to_module(db: &dyn Db, file: VfsFile) -> Option { - let _ = tracing::trace_span!("file_to_module", file = ?file.debug(db.upcast())).enter(); + let _span = tracing::trace_span!("file_to_module", ?file).entered(); let path = file.path(db.upcast()); diff --git a/crates/red_knot_python_semantic/src/lib.rs b/crates/red_knot_python_semantic/src/lib.rs index 436fd07f4c..86c195b567 100644 --- a/crates/red_knot_python_semantic/src/lib.rs +++ b/crates/red_knot_python_semantic/src/lib.rs @@ -1,11 +1,15 @@ +use std::hash::BuildHasherDefault; + +use rustc_hash::FxHasher; + +pub use db::{Db, Jar}; +pub use semantic_model::{HasTy, SemanticModel}; + pub mod ast_node_ref; mod db; mod node_key; pub mod semantic_index; +mod semantic_model; pub mod types; type FxIndexSet = indexmap::set::IndexSet>; - -pub use db::{Db, Jar}; -use rustc_hash::FxHasher; -use std::hash::BuildHasherDefault; diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 402abffc6c..abc50aacbc 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -2,7 +2,6 @@ use std::iter::FusedIterator; use std::sync::Arc; use rustc_hash::FxHashMap; -use salsa::DebugWithDb; use ruff_db::parsed::parsed_module; use ruff_db::vfs::VfsFile; @@ -10,10 +9,10 @@ use ruff_index::{IndexSlice, IndexVec}; use ruff_python_ast as ast; use crate::node_key::NodeKey; -use crate::semantic_index::ast_ids::{AstId, AstIds, ScopeClassId, ScopeFunctionId}; +use crate::semantic_index::ast_ids::{AstId, AstIds, ScopedClassId, ScopedFunctionId}; use crate::semantic_index::builder::SemanticIndexBuilder; use crate::semantic_index::symbol::{ - FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable, + FileScopeId, PublicSymbolId, Scope, ScopeId, ScopedSymbolId, SymbolTable, }; use crate::Db; @@ -29,7 +28,7 @@ type SymbolMap = hashbrown::HashMap; /// Prefer using [`symbol_table`] when working with symbols from a single scope. #[salsa::tracked(return_ref, no_eq)] pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex { - let _ = tracing::trace_span!("semantic_index", file = ?file.debug(db.upcast())).enter(); + let _span = tracing::trace_span!("semantic_index", ?file).entered(); let parsed = parsed_module(db.upcast(), file); @@ -43,7 +42,7 @@ pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex { /// is unchanged. #[salsa::tracked] pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc { - let _ = tracing::trace_span!("symbol_table", scope = ?scope.debug(db)).enter(); + let _span = tracing::trace_span!("symbol_table", ?scope).entered(); let index = semantic_index(db, scope.file(db)); index.symbol_table(scope.file_scope_id(db)) @@ -52,7 +51,7 @@ pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc ScopeId<'_> { - let _ = tracing::trace_span!("root_scope", file = ?file.debug(db.upcast())).enter(); + let _span = tracing::trace_span!("root_scope", ?file).entered(); FileScopeId::root().to_scope_id(db, file) } @@ -82,7 +81,7 @@ pub struct SemanticIndex { /// Maps expressions to their corresponding scope. /// We can't use [`ExpressionId`] here, because the challenge is how to get from /// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope). - expression_scopes: FxHashMap, + scopes_by_expression: FxHashMap, /// Lookup table to map between node ids and ast nodes. /// @@ -91,7 +90,10 @@ pub struct SemanticIndex { ast_ids: IndexVec, /// Map from scope to the node that introduces the scope. - scope_nodes: IndexVec, + nodes_by_scope: IndexVec, + + /// Map from nodes that introduce a scope to the scope they define. + scopes_by_node: FxHashMap, } impl SemanticIndex { @@ -108,13 +110,19 @@ impl SemanticIndex { } /// Returns the ID of the `expression`'s enclosing scope. - pub(crate) fn expression_scope_id(&self, expression: &ast::Expr) -> FileScopeId { - self.expression_scopes[&NodeKey::from_node(expression)] + pub(crate) fn expression_scope_id<'expr>( + &self, + expression: impl Into>, + ) -> FileScopeId { + self.scopes_by_expression[&NodeKey::from_node(expression.into())] } /// Returns the [`Scope`] of the `expression`'s enclosing scope. #[allow(unused)] - pub(crate) fn expression_scope(&self, expression: &ast::Expr) -> &Scope { + pub(crate) fn expression_scope<'expr>( + &self, + expression: impl Into>, + ) -> &Scope { &self.scopes[self.expression_scope_id(expression)] } @@ -152,7 +160,14 @@ impl SemanticIndex { } pub(crate) fn scope_node(&self, scope_id: FileScopeId) -> NodeWithScopeId { - self.scope_nodes[scope_id] + self.nodes_by_scope[scope_id] + } + + pub(crate) fn definition_scope( + &self, + node_with_scope: impl Into, + ) -> FileScopeId { + self.scopes_by_node[&node_with_scope.into()] } } @@ -248,29 +263,43 @@ impl<'a> Iterator for ChildrenIter<'a> { } } +impl FusedIterator for ChildrenIter<'_> {} + #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub(crate) enum NodeWithScopeId { Module, - Class(AstId), - ClassTypeParams(AstId), - Function(AstId), - FunctionTypeParams(AstId), + Class(AstId), + ClassTypeParams(AstId), + Function(AstId), + FunctionTypeParams(AstId), } -impl NodeWithScopeId { - fn scope_kind(self) -> ScopeKind { - match self { - NodeWithScopeId::Module => ScopeKind::Module, - NodeWithScopeId::Class(_) => ScopeKind::Class, - NodeWithScopeId::Function(_) => ScopeKind::Function, - NodeWithScopeId::ClassTypeParams(_) | NodeWithScopeId::FunctionTypeParams(_) => { - ScopeKind::Annotation - } - } +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub(crate) struct NodeWithScopeKey(NodeKey); + +impl From<&ast::StmtClassDef> for NodeWithScopeKey { + fn from(node: &ast::StmtClassDef) -> Self { + Self(NodeKey::from_node(node)) } } -impl FusedIterator for ChildrenIter<'_> {} +impl From<&ast::StmtFunctionDef> for NodeWithScopeKey { + fn from(value: &ast::StmtFunctionDef) -> Self { + Self(NodeKey::from_node(value)) + } +} + +impl From<&ast::TypeParams> for NodeWithScopeKey { + fn from(value: &ast::TypeParams) -> Self { + Self(NodeKey::from_node(value)) + } +} + +impl From<&ast::ModModule> for NodeWithScopeKey { + fn from(value: &ast::ModModule) -> Self { + Self(NodeKey::from_node(value)) + } +} #[cfg(test)] mod tests { diff --git a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs index 184916fc2e..dd5081a1bf 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs @@ -4,7 +4,7 @@ use ruff_db::parsed::ParsedModule; use ruff_db::vfs::VfsFile; use ruff_index::{newtype_index, IndexVec}; use ruff_python_ast as ast; -use ruff_python_ast::AnyNodeRef; +use ruff_python_ast::{AnyNodeRef, ExpressionRef}; use crate::ast_node_ref::AstNodeRef; use crate::node_key::NodeKey; @@ -29,27 +29,27 @@ use crate::Db; /// ``` pub(crate) struct AstIds { /// Maps expression ids to their expressions. - expressions: IndexVec>, + expressions: IndexVec>, /// Maps expressions to their expression id. Uses `NodeKey` because it avoids cloning [`Parsed`]. - expressions_map: FxHashMap, + expressions_map: FxHashMap, - statements: IndexVec>, + statements: IndexVec>, - statements_map: FxHashMap, + statements_map: FxHashMap, } impl AstIds { - fn statement_id<'a, N>(&self, node: N) -> ScopeStatementId + fn statement_id<'a, N>(&self, node: N) -> ScopedStatementId where N: Into>, { self.statements_map[&NodeKey::from_node(node.into())] } - fn expression_id<'a, N>(&self, node: N) -> ScopeExpressionId + fn expression_id<'a, N>(&self, node: N) -> ScopedExpressionId where - N: Into>, + N: Into>, { self.expressions_map[&NodeKey::from_node(node.into())] } @@ -69,8 +69,7 @@ fn ast_ids<'db>(db: &'db dyn Db, scope: ScopeId) -> &'db AstIds { semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db)) } -/// Node that can be uniquely identified by an id in a [`FileScopeId`]. -pub trait ScopeAstIdNode { +pub trait HasScopedAstId { /// The type of the ID uniquely identifying the node. type Id: Copy; @@ -78,8 +77,11 @@ pub trait ScopeAstIdNode { /// /// ## Panics /// Panics if the node doesn't belong to `file` or is outside `scope`. - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> Self::Id; + fn scoped_ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> Self::Id; +} +/// Node that can be uniquely identified by an id in a [`FileScopeId`]. +pub trait ScopedAstIdNode: HasScopedAstId { /// Looks up the AST node by its ID. /// /// ## Panics @@ -112,12 +114,12 @@ pub trait AstIdNode { impl AstIdNode for T where - T: ScopeAstIdNode, + T: ScopedAstIdNode, { type ScopeId = T::Id; fn ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> AstId { - let in_scope_id = self.scope_ast_id(db, file, scope); + let in_scope_id = self.scoped_ast_id(db, file, scope); AstId { scope, in_scope_id } } @@ -152,17 +154,71 @@ impl AstId { /// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`]. #[newtype_index] -pub struct ScopeExpressionId; +pub struct ScopedExpressionId; -impl ScopeAstIdNode for ast::Expr { - type Id = ScopeExpressionId; +macro_rules! impl_has_scoped_expression_id { + ($ty: ty) => { + impl HasScopedAstId for $ty { + type Id = ScopedExpressionId; - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + fn scoped_ast_id( + &self, + db: &dyn Db, + file: VfsFile, + file_scope: FileScopeId, + ) -> Self::Id { + let expression_ref = ExpressionRef::from(self); + expression_ref.scoped_ast_id(db, file, file_scope) + } + } + }; +} + +impl_has_scoped_expression_id!(ast::ExprBoolOp); +impl_has_scoped_expression_id!(ast::ExprName); +impl_has_scoped_expression_id!(ast::ExprBinOp); +impl_has_scoped_expression_id!(ast::ExprUnaryOp); +impl_has_scoped_expression_id!(ast::ExprLambda); +impl_has_scoped_expression_id!(ast::ExprIf); +impl_has_scoped_expression_id!(ast::ExprDict); +impl_has_scoped_expression_id!(ast::ExprSet); +impl_has_scoped_expression_id!(ast::ExprListComp); +impl_has_scoped_expression_id!(ast::ExprSetComp); +impl_has_scoped_expression_id!(ast::ExprDictComp); +impl_has_scoped_expression_id!(ast::ExprGenerator); +impl_has_scoped_expression_id!(ast::ExprAwait); +impl_has_scoped_expression_id!(ast::ExprYield); +impl_has_scoped_expression_id!(ast::ExprYieldFrom); +impl_has_scoped_expression_id!(ast::ExprCompare); +impl_has_scoped_expression_id!(ast::ExprCall); +impl_has_scoped_expression_id!(ast::ExprFString); +impl_has_scoped_expression_id!(ast::ExprStringLiteral); +impl_has_scoped_expression_id!(ast::ExprBytesLiteral); +impl_has_scoped_expression_id!(ast::ExprNumberLiteral); +impl_has_scoped_expression_id!(ast::ExprBooleanLiteral); +impl_has_scoped_expression_id!(ast::ExprNoneLiteral); +impl_has_scoped_expression_id!(ast::ExprEllipsisLiteral); +impl_has_scoped_expression_id!(ast::ExprAttribute); +impl_has_scoped_expression_id!(ast::ExprSubscript); +impl_has_scoped_expression_id!(ast::ExprStarred); +impl_has_scoped_expression_id!(ast::ExprNamed); +impl_has_scoped_expression_id!(ast::ExprList); +impl_has_scoped_expression_id!(ast::ExprTuple); +impl_has_scoped_expression_id!(ast::ExprSlice); +impl_has_scoped_expression_id!(ast::ExprIpyEscapeCommand); +impl_has_scoped_expression_id!(ast::Expr); + +impl HasScopedAstId for ast::ExpressionRef<'_> { + type Id = ScopedExpressionId; + + fn scoped_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { let scope = file_scope.to_scope_id(db, file); let ast_ids = ast_ids(db, scope); - ast_ids.expressions_map[&NodeKey::from_node(self)] + ast_ids.expression_id(*self) } +} +impl ScopedAstIdNode for ast::Expr { fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self { let scope = file_scope.to_scope_id(db, file); let ast_ids = ast_ids(db, scope); @@ -172,17 +228,30 @@ impl ScopeAstIdNode for ast::Expr { /// Uniquely identifies an [`ast::Stmt`] in a [`FileScopeId`]. #[newtype_index] -pub struct ScopeStatementId; +pub struct ScopedStatementId; -impl ScopeAstIdNode for ast::Stmt { - type Id = ScopeStatementId; +macro_rules! impl_has_scoped_statement_id { + ($ty: ty) => { + impl HasScopedAstId for $ty { + type Id = ScopedStatementId; - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { - let scope = file_scope.to_scope_id(db, file); - let ast_ids = ast_ids(db, scope); - ast_ids.statement_id(self) - } + fn scoped_ast_id( + &self, + db: &dyn Db, + file: VfsFile, + file_scope: FileScopeId, + ) -> Self::Id { + let scope = file_scope.to_scope_id(db, file); + let ast_ids = ast_ids(db, scope); + ast_ids.statement_id(self) + } + } + }; +} +impl_has_scoped_statement_id!(ast::Stmt); + +impl ScopedAstIdNode for ast::Stmt { fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self { let scope = file_scope.to_scope_id(db, file); let ast_ids = ast_ids(db, scope); @@ -192,17 +261,19 @@ impl ScopeAstIdNode for ast::Stmt { } #[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] -pub struct ScopeFunctionId(pub(super) ScopeStatementId); +pub struct ScopedFunctionId(pub(super) ScopedStatementId); -impl ScopeAstIdNode for ast::StmtFunctionDef { - type Id = ScopeFunctionId; +impl HasScopedAstId for ast::StmtFunctionDef { + type Id = ScopedFunctionId; - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + fn scoped_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { let scope = file_scope.to_scope_id(db, file); let ast_ids = ast_ids(db, scope); - ScopeFunctionId(ast_ids.statement_id(self)) + ScopedFunctionId(ast_ids.statement_id(self)) } +} +impl ScopedAstIdNode for ast::StmtFunctionDef { fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { ast::Stmt::lookup_in_scope(db, file, scope, id.0) .as_function_def_stmt() @@ -211,122 +282,36 @@ impl ScopeAstIdNode for ast::StmtFunctionDef { } #[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] -pub struct ScopeClassId(pub(super) ScopeStatementId); +pub struct ScopedClassId(pub(super) ScopedStatementId); -impl ScopeAstIdNode for ast::StmtClassDef { - type Id = ScopeClassId; +impl HasScopedAstId for ast::StmtClassDef { + type Id = ScopedClassId; - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { + fn scoped_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { let scope = file_scope.to_scope_id(db, file); let ast_ids = ast_ids(db, scope); - ScopeClassId(ast_ids.statement_id(self)) + ScopedClassId(ast_ids.statement_id(self)) } +} +impl ScopedAstIdNode for ast::StmtClassDef { fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); statement.as_class_def_stmt().unwrap() } } -#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] -pub struct ScopeAssignmentId(pub(super) ScopeStatementId); - -impl ScopeAstIdNode for ast::StmtAssign { - type Id = ScopeAssignmentId; - - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { - let scope = file_scope.to_scope_id(db, file); - let ast_ids = ast_ids(db, scope); - ScopeAssignmentId(ast_ids.statement_id(self)) - } - - fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { - let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); - statement.as_assign_stmt().unwrap() - } -} - -#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] -pub struct ScopeAnnotatedAssignmentId(ScopeStatementId); - -impl ScopeAstIdNode for ast::StmtAnnAssign { - type Id = ScopeAnnotatedAssignmentId; - - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { - let scope = file_scope.to_scope_id(db, file); - let ast_ids = ast_ids(db, scope); - ScopeAnnotatedAssignmentId(ast_ids.statement_id(self)) - } - - fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { - let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); - statement.as_ann_assign_stmt().unwrap() - } -} - -#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] -pub struct ScopeImportId(pub(super) ScopeStatementId); - -impl ScopeAstIdNode for ast::StmtImport { - type Id = ScopeImportId; - - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { - let scope = file_scope.to_scope_id(db, file); - let ast_ids = ast_ids(db, scope); - ScopeImportId(ast_ids.statement_id(self)) - } - - fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { - let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); - statement.as_import_stmt().unwrap() - } -} - -#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] -pub struct ScopeImportFromId(pub(super) ScopeStatementId); - -impl ScopeAstIdNode for ast::StmtImportFrom { - type Id = ScopeImportFromId; - - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { - let scope = file_scope.to_scope_id(db, file); - let ast_ids = ast_ids(db, scope); - ScopeImportFromId(ast_ids.statement_id(self)) - } - - fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self { - let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0); - statement.as_import_from_stmt().unwrap() - } -} - -#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)] -pub struct ScopeNamedExprId(pub(super) ScopeExpressionId); - -impl ScopeAstIdNode for ast::ExprNamed { - type Id = ScopeNamedExprId; - - fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id { - let scope = file_scope.to_scope_id(db, file); - let ast_ids = ast_ids(db, scope); - ScopeNamedExprId(ast_ids.expression_id(self)) - } - - fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self - where - Self: Sized, - { - let expression = ast::Expr::lookup_in_scope(db, file, scope, id.0); - expression.as_named_expr().unwrap() - } -} +impl_has_scoped_statement_id!(ast::StmtAssign); +impl_has_scoped_statement_id!(ast::StmtAnnAssign); +impl_has_scoped_statement_id!(ast::StmtImport); +impl_has_scoped_statement_id!(ast::StmtImportFrom); #[derive(Debug)] pub(super) struct AstIdsBuilder { - expressions: IndexVec>, - expressions_map: FxHashMap, - statements: IndexVec>, - statements_map: FxHashMap, + expressions: IndexVec>, + expressions_map: FxHashMap, + statements: IndexVec>, + statements_map: FxHashMap, } impl AstIdsBuilder { @@ -349,7 +334,7 @@ impl AstIdsBuilder { &mut self, stmt: &ast::Stmt, parsed: &ParsedModule, - ) -> ScopeStatementId { + ) -> ScopedStatementId { let statement_id = self.statements.push(AstNodeRef::new(parsed.clone(), stmt)); self.statements_map @@ -368,7 +353,7 @@ impl AstIdsBuilder { &mut self, expr: &ast::Expr, parsed: &ParsedModule, - ) -> ScopeExpressionId { + ) -> ScopedExpressionId { let expression_id = self.expressions.push(AstNodeRef::new(parsed.clone(), expr)); self.expressions_map diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index e491e3408d..03867d9c93 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -9,15 +9,12 @@ use ruff_python_ast::name::Name; use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor}; use crate::node_key::NodeKey; -use crate::semantic_index::ast_ids::{ - AstId, AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId, - ScopeImportId, ScopeNamedExprId, -}; +use crate::semantic_index::ast_ids::{AstId, AstIdsBuilder, ScopedClassId, ScopedFunctionId}; use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition}; use crate::semantic_index::symbol::{ - FileScopeId, FileSymbolId, Scope, ScopedSymbolId, SymbolFlags, SymbolTableBuilder, + FileScopeId, FileSymbolId, Scope, ScopeKind, ScopedSymbolId, SymbolFlags, SymbolTableBuilder, }; -use crate::semantic_index::{NodeWithScopeId, SemanticIndex}; +use crate::semantic_index::{NodeWithScopeId, NodeWithScopeKey, SemanticIndex}; pub(super) struct SemanticIndexBuilder<'a> { // Builder state @@ -32,6 +29,7 @@ pub(super) struct SemanticIndexBuilder<'a> { ast_ids: IndexVec, expression_scopes: FxHashMap, scope_nodes: IndexVec, + node_scopes: FxHashMap, } impl<'a> SemanticIndexBuilder<'a> { @@ -45,12 +43,16 @@ impl<'a> SemanticIndexBuilder<'a> { symbol_tables: IndexVec::new(), ast_ids: IndexVec::new(), expression_scopes: FxHashMap::default(), + node_scopes: FxHashMap::default(), scope_nodes: IndexVec::new(), }; builder.push_scope_with_parent( - NodeWithScopeId::Module, - &Name::new_static(""), + NodeWithScope::new( + parsed.syntax(), + NodeWithScopeId::Module, + Name::new_static(""), + ), None, None, None, @@ -68,42 +70,44 @@ impl<'a> SemanticIndexBuilder<'a> { fn push_scope( &mut self, - node: NodeWithScopeId, - name: &Name, + node: NodeWithScope, defining_symbol: Option, definition: Option, ) { let parent = self.current_scope(); - self.push_scope_with_parent(node, name, defining_symbol, definition, Some(parent)); + self.push_scope_with_parent(node, defining_symbol, definition, Some(parent)); } fn push_scope_with_parent( &mut self, - node: NodeWithScopeId, - name: &Name, + node: NodeWithScope, defining_symbol: Option, definition: Option, parent: Option, ) { let children_start = self.scopes.next_index() + 1; + let node_key = node.key(); + let node_id = node.id(); + let scope_kind = node.scope_kind(); let scope = Scope { - name: name.clone(), + name: node.name, parent, defining_symbol, definition, - kind: node.scope_kind(), + kind: scope_kind, descendents: children_start..children_start, }; let scope_id = self.scopes.push(scope); self.symbol_tables.push(SymbolTableBuilder::new()); let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new()); - let scope_node_id = self.scope_nodes.push(node); + let scope_node_id = self.scope_nodes.push(node_id); debug_assert_eq!(ast_id_scope, scope_id); debug_assert_eq!(scope_id, scope_node_id); self.scope_stack.push(scope_id); + self.node_scopes.insert(node_key, scope_id); } fn pop_scope(&mut self) -> FileScopeId { @@ -124,10 +128,18 @@ impl<'a> SemanticIndexBuilder<'a> { &mut self.ast_ids[scope_id] } - fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId { - let symbol_table = self.current_symbol_table(); + fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> FileSymbolId { + for scope in self.scope_stack.iter().rev().skip(1) { + let builder = &self.symbol_tables[*scope]; - symbol_table.add_or_update_symbol(name, flags, None) + if let Some(symbol) = builder.symbol_by_name(&name) { + return FileSymbolId::new(*scope, symbol); + } + } + + let scope = self.current_scope(); + let symbol_table = self.current_symbol_table(); + FileSymbolId::new(scope, symbol_table.add_or_update_symbol(name, flags, None)) } fn add_or_update_symbol_with_definition( @@ -142,7 +154,7 @@ impl<'a> SemanticIndexBuilder<'a> { fn with_type_params( &mut self, - name: &Name, + name: Name, with_params: &WithTypeParams, defining_symbol: FileSymbolId, nested: impl FnOnce(&mut Self) -> FileScopeId, @@ -150,14 +162,13 @@ impl<'a> SemanticIndexBuilder<'a> { let type_params = with_params.type_parameters(); if let Some(type_params) = type_params { - let type_node = match with_params { + let type_params_id = match with_params { WithTypeParams::ClassDef { id, .. } => NodeWithScopeId::ClassTypeParams(*id), WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id), }; self.push_scope( - type_node, - name, + NodeWithScope::new(type_params, type_params_id, name), Some(defining_symbol), Some(with_params.definition()), ); @@ -211,9 +222,10 @@ impl<'a> SemanticIndexBuilder<'a> { SemanticIndex { symbol_tables, scopes: self.scopes, - scope_nodes: self.scope_nodes, + nodes_by_scope: self.scope_nodes, + scopes_by_node: self.node_scopes, ast_ids, - expression_scopes: self.expression_scopes, + scopes_by_expression: self.expression_scopes, } } } @@ -233,7 +245,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { self.visit_decorator(decorator); } let name = &function_def.name.id; - let function_id = ScopeFunctionId(statement_id); + let function_id = ScopedFunctionId(statement_id); let definition = Definition::FunctionDef(function_id); let scope = self.current_scope(); let symbol = FileSymbolId::new( @@ -242,7 +254,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { ); self.with_type_params( - name, + name.clone(), &WithTypeParams::FunctionDef { node: function_def, id: AstId::new(scope, function_id), @@ -255,8 +267,11 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { } builder.push_scope( - NodeWithScopeId::Function(AstId::new(scope, function_id)), - name, + NodeWithScope::new( + function_def, + NodeWithScopeId::Function(AstId::new(scope, function_id)), + name.clone(), + ), Some(symbol), Some(definition), ); @@ -271,15 +286,15 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { } let name = &class.name.id; - let class_id = ScopeClassId(statement_id); - let definition = Definition::from(class_id); + let class_id = ScopedClassId(statement_id); + let definition = Definition::ClassDef(class_id); let scope = self.current_scope(); let id = FileSymbolId::new( self.current_scope(), self.add_or_update_symbol_with_definition(name.clone(), definition), ); self.with_type_params( - name, + name.clone(), &WithTypeParams::ClassDef { node: class, id: AstId::new(scope, class_id), @@ -291,8 +306,11 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { } builder.push_scope( - NodeWithScopeId::Class(AstId::new(scope, class_id)), - name, + NodeWithScope::new( + class, + NodeWithScopeId::Class(AstId::new(scope, class_id)), + name.clone(), + ), Some(id), Some(definition), ); @@ -311,7 +329,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { }; let def = Definition::Import(ImportDefinition { - import_id: ScopeImportId(statement_id), + import_id: statement_id, alias: u32::try_from(i).unwrap(), }); self.add_or_update_symbol_with_definition(symbol_name, def); @@ -330,7 +348,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { &alias.name.id }; let def = Definition::ImportFrom(ImportFromDefinition { - import_id: ScopeImportFromId(statement_id), + import_id: statement_id, name: u32::try_from(i).unwrap(), }); self.add_or_update_symbol_with_definition(symbol_name.clone(), def); @@ -339,8 +357,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { ast::Stmt::Assign(node) => { debug_assert!(self.current_definition.is_none()); self.visit_expr(&node.value); - self.current_definition = - Some(Definition::Assignment(ScopeAssignmentId(statement_id))); + self.current_definition = Some(Definition::Assignment(statement_id)); for target in &node.targets { self.visit_expr(target); } @@ -385,8 +402,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { } ast::Expr::Named(node) => { debug_assert!(self.current_definition.is_none()); - self.current_definition = - Some(Definition::NamedExpr(ScopeNamedExprId(expression_id))); + self.current_definition = Some(Definition::NamedExpr(expression_id)); // TODO walrus in comprehensions is implicitly nonlocal self.visit_expr(&node.target); self.current_definition = None; @@ -428,11 +444,11 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { enum WithTypeParams<'a> { ClassDef { node: &'a ast::StmtClassDef, - id: AstId, + id: AstId, }, FunctionDef { node: &'a ast::StmtFunctionDef, - id: AstId, + id: AstId, }, } @@ -451,3 +467,38 @@ impl<'a> WithTypeParams<'a> { } } } + +struct NodeWithScope { + id: NodeWithScopeId, + key: NodeWithScopeKey, + name: Name, +} + +impl NodeWithScope { + fn new(node: impl Into, id: NodeWithScopeId, name: Name) -> Self { + Self { + id, + key: node.into(), + name, + } + } + + fn id(&self) -> NodeWithScopeId { + self.id + } + + fn key(&self) -> NodeWithScopeKey { + self.key + } + + fn scope_kind(&self) -> ScopeKind { + match self.id { + NodeWithScopeId::Module => ScopeKind::Module, + NodeWithScopeId::Class(_) => ScopeKind::Class, + NodeWithScopeId::Function(_) => ScopeKind::Function, + NodeWithScopeId::ClassTypeParams(_) | NodeWithScopeId::FunctionTypeParams(_) => { + ScopeKind::Annotation + } + } + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 3eb8f40c18..f1427ace93 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -1,17 +1,16 @@ use crate::semantic_index::ast_ids::{ - ScopeAnnotatedAssignmentId, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, - ScopeImportFromId, ScopeImportId, ScopeNamedExprId, + ScopedClassId, ScopedExpressionId, ScopedFunctionId, ScopedStatementId, }; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum Definition { Import(ImportDefinition), ImportFrom(ImportFromDefinition), - ClassDef(ScopeClassId), - FunctionDef(ScopeFunctionId), - Assignment(ScopeAssignmentId), - AnnotatedAssignment(ScopeAnnotatedAssignmentId), - NamedExpr(ScopeNamedExprId), + ClassDef(ScopedClassId), + FunctionDef(ScopedFunctionId), + Assignment(ScopedStatementId), + AnnotatedAssignment(ScopedStatementId), + NamedExpr(ScopedExpressionId), /// represents the implicit initial definition of every name as "unbound" Unbound, // TODO with statements, except handlers, function args... @@ -29,39 +28,21 @@ impl From for Definition { } } -impl From for Definition { - fn from(value: ScopeClassId) -> Self { +impl From for Definition { + fn from(value: ScopedClassId) -> Self { Self::ClassDef(value) } } -impl From for Definition { - fn from(value: ScopeFunctionId) -> Self { +impl From for Definition { + fn from(value: ScopedFunctionId) -> Self { Self::FunctionDef(value) } } -impl From for Definition { - fn from(value: ScopeAssignmentId) -> Self { - Self::Assignment(value) - } -} - -impl From for Definition { - fn from(value: ScopeAnnotatedAssignmentId) -> Self { - Self::AnnotatedAssignment(value) - } -} - -impl From for Definition { - fn from(value: ScopeNamedExprId) -> Self { - Self::NamedExpr(value) - } -} - #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub struct ImportDefinition { - pub(crate) import_id: ScopeImportId, + pub(crate) import_id: ScopedStatementId, /// Index into [`ruff_python_ast::StmtImport::names`]. pub(crate) alias: u32, @@ -69,7 +50,7 @@ pub struct ImportDefinition { #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub struct ImportFromDefinition { - pub(crate) import_id: ScopeImportFromId, + pub(crate) import_id: ScopedStatementId, /// Index into [`ruff_python_ast::StmtImportFrom::names`]. pub(crate) name: u32, diff --git a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs index ac447d3eee..4cca2ea263 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs @@ -4,7 +4,6 @@ use std::ops::Range; use bitflags::bitflags; use hashbrown::hash_map::RawEntryMut; use rustc_hash::FxHasher; -use salsa::DebugWithDb; use smallvec::SmallVec; use crate::semantic_index::definition::Definition; @@ -128,7 +127,7 @@ impl ScopedSymbolId { /// Returns a mapping from [`FileScopeId`] to globally unique [`ScopeId`]. #[salsa::tracked(return_ref)] pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap<'_> { - let _ = tracing::trace_span!("scopes_map", file = ?file.debug(db.upcast())).enter(); + let _span = tracing::trace_span!("scopes_map", ?file).entered(); let index = semantic_index(db, file); @@ -160,7 +159,7 @@ impl<'db> ScopesMap<'db> { #[salsa::tracked(return_ref)] pub(crate) fn public_symbols_map(db: &dyn Db, file: VfsFile) -> PublicSymbolsMap<'_> { - let _ = tracing::trace_span!("public_symbols_map", file = ?file.debug(db.upcast())).enter(); + let _span = tracing::trace_span!("public_symbols_map", ?file).entered(); let module_scope = root_scope(db, file); let symbols = symbol_table(db, module_scope); @@ -371,6 +370,10 @@ impl SymbolTableBuilder { } } + pub(super) fn symbol_by_name(&self, name: &str) -> Option { + self.table.symbol_id_by_name(name) + } + pub(super) fn finish(mut self) -> SymbolTable { self.table.shrink_to_fit(); self.table diff --git a/crates/red_knot_python_semantic/src/semantic_model.rs b/crates/red_knot_python_semantic/src/semantic_model.rs new file mode 100644 index 0000000000..3768631d91 --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_model.rs @@ -0,0 +1,183 @@ +use red_knot_module_resolver::{resolve_module, Module, ModuleName}; +use ruff_db::vfs::VfsFile; +use ruff_python_ast as ast; +use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef}; + +use crate::semantic_index::ast_ids::HasScopedAstId; +use crate::semantic_index::definition::Definition; +use crate::semantic_index::symbol::{PublicSymbolId, ScopeKind}; +use crate::semantic_index::{public_symbol, semantic_index, NodeWithScopeKey}; +use crate::types::{infer_types, public_symbol_ty, Type, TypingContext}; +use crate::Db; + +pub struct SemanticModel<'db> { + db: &'db dyn Db, + file: VfsFile, +} + +impl<'db> SemanticModel<'db> { + pub fn new(db: &'db dyn Db, file: VfsFile) -> Self { + Self { db, file } + } + + pub fn resolve_module(&self, module_name: ModuleName) -> Option { + resolve_module(self.db.upcast(), module_name) + } + + pub fn public_symbol(&self, module: &Module, symbol_name: &str) -> Option> { + public_symbol(self.db, module.file(), symbol_name) + } + + pub fn public_symbol_ty(&self, symbol: PublicSymbolId<'db>) -> Type<'db> { + public_symbol_ty(self.db, symbol) + } + + pub fn typing_context(&self) -> TypingContext<'db, '_> { + TypingContext::global(self.db) + } +} + +pub trait HasTy { + fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db>; +} + +impl HasTy for ast::ExpressionRef<'_> { + fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { + let index = semantic_index(model.db, model.file); + let file_scope = index.expression_scope_id(*self); + let expression_id = self.scoped_ast_id(model.db, model.file, file_scope); + + let scope = file_scope.to_scope_id(model.db, model.file); + infer_types(model.db, scope).expression_ty(expression_id) + } +} + +macro_rules! impl_expression_has_ty { + ($ty: ty) => { + impl HasTy for $ty { + #[inline] + fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { + let expression_ref = ExpressionRef::from(self); + expression_ref.ty(model) + } + } + }; +} + +impl_expression_has_ty!(ast::ExprBoolOp); +impl_expression_has_ty!(ast::ExprNamed); +impl_expression_has_ty!(ast::ExprBinOp); +impl_expression_has_ty!(ast::ExprUnaryOp); +impl_expression_has_ty!(ast::ExprLambda); +impl_expression_has_ty!(ast::ExprIf); +impl_expression_has_ty!(ast::ExprDict); +impl_expression_has_ty!(ast::ExprSet); +impl_expression_has_ty!(ast::ExprListComp); +impl_expression_has_ty!(ast::ExprSetComp); +impl_expression_has_ty!(ast::ExprDictComp); +impl_expression_has_ty!(ast::ExprGenerator); +impl_expression_has_ty!(ast::ExprAwait); +impl_expression_has_ty!(ast::ExprYield); +impl_expression_has_ty!(ast::ExprYieldFrom); +impl_expression_has_ty!(ast::ExprCompare); +impl_expression_has_ty!(ast::ExprCall); +impl_expression_has_ty!(ast::ExprFString); +impl_expression_has_ty!(ast::ExprStringLiteral); +impl_expression_has_ty!(ast::ExprBytesLiteral); +impl_expression_has_ty!(ast::ExprNumberLiteral); +impl_expression_has_ty!(ast::ExprBooleanLiteral); +impl_expression_has_ty!(ast::ExprNoneLiteral); +impl_expression_has_ty!(ast::ExprEllipsisLiteral); +impl_expression_has_ty!(ast::ExprAttribute); +impl_expression_has_ty!(ast::ExprSubscript); +impl_expression_has_ty!(ast::ExprStarred); +impl_expression_has_ty!(ast::ExprName); +impl_expression_has_ty!(ast::ExprList); +impl_expression_has_ty!(ast::ExprTuple); +impl_expression_has_ty!(ast::ExprSlice); +impl_expression_has_ty!(ast::ExprIpyEscapeCommand); + +impl HasTy for ast::Expr { + fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { + match self { + Expr::BoolOp(inner) => inner.ty(model), + Expr::Named(inner) => inner.ty(model), + Expr::BinOp(inner) => inner.ty(model), + Expr::UnaryOp(inner) => inner.ty(model), + Expr::Lambda(inner) => inner.ty(model), + Expr::If(inner) => inner.ty(model), + Expr::Dict(inner) => inner.ty(model), + Expr::Set(inner) => inner.ty(model), + Expr::ListComp(inner) => inner.ty(model), + Expr::SetComp(inner) => inner.ty(model), + Expr::DictComp(inner) => inner.ty(model), + Expr::Generator(inner) => inner.ty(model), + Expr::Await(inner) => inner.ty(model), + Expr::Yield(inner) => inner.ty(model), + Expr::YieldFrom(inner) => inner.ty(model), + Expr::Compare(inner) => inner.ty(model), + Expr::Call(inner) => inner.ty(model), + Expr::FString(inner) => inner.ty(model), + Expr::StringLiteral(inner) => inner.ty(model), + Expr::BytesLiteral(inner) => inner.ty(model), + Expr::NumberLiteral(inner) => inner.ty(model), + Expr::BooleanLiteral(inner) => inner.ty(model), + Expr::NoneLiteral(inner) => inner.ty(model), + Expr::EllipsisLiteral(inner) => inner.ty(model), + Expr::Attribute(inner) => inner.ty(model), + Expr::Subscript(inner) => inner.ty(model), + Expr::Starred(inner) => inner.ty(model), + Expr::Name(inner) => inner.ty(model), + Expr::List(inner) => inner.ty(model), + Expr::Tuple(inner) => inner.ty(model), + Expr::Slice(inner) => inner.ty(model), + Expr::IpyEscapeCommand(inner) => inner.ty(model), + } + } +} + +impl HasTy for ast::StmtFunctionDef { + fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { + let index = semantic_index(model.db, model.file); + let definition_scope = index.definition_scope(NodeWithScopeKey::from(self)); + + // SAFETY: A function always has either an enclosing module, function or class scope. + let mut parent_scope_id = index.parent_scope_id(definition_scope).unwrap(); + let parent_scope = index.scope(parent_scope_id); + + if parent_scope.kind() == ScopeKind::Annotation { + parent_scope_id = index.parent_scope_id(parent_scope_id).unwrap(); + } + + let scope = parent_scope_id.to_scope_id(model.db, model.file); + + let types = infer_types(model.db, scope); + let definition = + Definition::FunctionDef(self.scoped_ast_id(model.db, model.file, parent_scope_id)); + + types.definition_ty(definition) + } +} + +impl HasTy for StmtClassDef { + fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { + let index = semantic_index(model.db, model.file); + let definition_scope = index.definition_scope(NodeWithScopeKey::from(self)); + + // SAFETY: A class always has either an enclosing module, function or class scope. + let mut parent_scope_id = index.parent_scope_id(definition_scope).unwrap(); + let parent_scope = index.scope(parent_scope_id); + + if parent_scope.kind() == ScopeKind::Annotation { + parent_scope_id = index.parent_scope_id(parent_scope_id).unwrap(); + } + + let scope = parent_scope_id.to_scope_id(model.db, model.file); + + let types = infer_types(model.db, scope); + let definition = + Definition::ClassDef(self.scoped_ast_id(model.db, model.file, parent_scope_id)); + + types.definition_ty(definition) + } +} diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index e47870b960..a5fe056c26 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,6 +1,4 @@ -use salsa::DebugWithDb; - -use crate::semantic_index::ast_ids::{AstIdNode, ScopeAstIdNode}; +use crate::semantic_index::ast_ids::AstIdNode; use crate::semantic_index::symbol::{FileScopeId, PublicSymbolId, ScopeId}; use crate::semantic_index::{ public_symbol, root_scope, semantic_index, symbol_table, NodeWithScopeId, @@ -17,28 +15,6 @@ use ruff_python_ast::name::Name; mod display; mod infer; -/// Infers the type of `expr`. -/// -/// Calling this function from a salsa query adds a dependency on [`semantic_index`] -/// which changes with every AST change. That's why you should only call -/// this function for the current file that's being analyzed and not for -/// a dependency (or the query reruns whenever a dependency change). -/// -/// Prefer [`public_symbol_ty`] when resolving the type of symbol from another file. -#[tracing::instrument(level = "debug", skip(db))] -pub(crate) fn expression_ty<'db>( - db: &'db dyn Db, - file: VfsFile, - expression: &ast::Expr, -) -> Type<'db> { - let index = semantic_index(db, file); - let file_scope = index.expression_scope_id(expression); - let expression_id = expression.scope_ast_id(db, file, file_scope); - let scope = file_scope.to_scope_id(db, file); - - infer_types(db, scope).expression_ty(expression_id) -} - /// Infers the type of a public symbol. /// /// This is a Salsa query to get symbol-level invalidation instead of file-level dependency invalidation. @@ -65,7 +41,7 @@ pub(crate) fn expression_ty<'db>( /// This being a query ensures that the invalidation short-circuits if the type of this symbol didn't change. #[salsa::tracked] pub(crate) fn public_symbol_ty<'db>(db: &'db dyn Db, symbol: PublicSymbolId<'db>) -> Type<'db> { - let _ = tracing::trace_span!("public_symbol_ty", symbol = ?symbol.debug(db)).enter(); + let _span = tracing::trace_span!("public_symbol_ty", ?symbol).entered(); let file = symbol.file(db); let scope = root_scope(db, file); @@ -87,7 +63,7 @@ pub fn public_symbol_ty_by_name<'db>( /// Infers all types for `scope`. #[salsa::tracked(return_ref)] pub(crate) fn infer_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInference<'db> { - let _ = tracing::trace_span!("infer_types", scope = ?scope.debug(db)).enter(); + let _span = tracing::trace_span!("infer_types", ?scope).entered(); let file = scope.file(db); // Using the index here is fine because the code below depends on the AST anyway. @@ -270,6 +246,18 @@ impl<'a> FunctionType<'a> { } } +impl<'db> TypeId<'db, ScopedFunctionTypeId> { + pub fn name<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Name { + let function_ty = self.lookup(context); + &function_ty.name + } + + pub fn has_decorator(self, context: &TypingContext, decorator: Type<'db>) -> bool { + let function_ty = self.lookup(context); + function_ty.decorators.contains(&decorator) + } +} + #[newtype_index] pub struct ScopedClassTypeId; @@ -282,14 +270,42 @@ impl ScopedTypeId for ScopedClassTypeId { } impl<'db> TypeId<'db, ScopedClassTypeId> { + pub fn name<'a>(self, context: &'a TypingContext<'db, 'a>) -> &'a Name { + let class_ty = self.lookup(context); + &class_ty.name + } + /// Returns the class member of this class named `name`. /// /// The member resolves to a member of the class itself or any of its bases. - fn class_member(self, context: &TypingContext<'db, '_>, name: &Name) -> Option> { + pub fn class_member(self, context: &TypingContext<'db, '_>, name: &Name) -> Option> { if let Some(member) = self.own_class_member(context, name) { return Some(member); } + self.inherited_class_member(context, name) + } + + /// Returns the inferred type of the class member named `name`. + pub fn own_class_member( + self, + context: &TypingContext<'db, '_>, + name: &Name, + ) -> Option> { + let class = self.lookup(context); + + let symbols = symbol_table(context.db, class.body_scope); + let symbol = symbols.symbol_id_by_name(name)?; + let types = context.types(class.body_scope); + + Some(types.symbol_ty(symbol)) + } + + pub fn inherited_class_member( + self, + context: &TypingContext<'db, '_>, + name: &Name, + ) -> Option> { let class = self.lookup(context); for base in &class.bases { if let Some(member) = base.member(context, name) { @@ -299,17 +315,6 @@ impl<'db> TypeId<'db, ScopedClassTypeId> { None } - - /// Returns the inferred type of the class member named `name`. - fn own_class_member(self, context: &TypingContext<'db, '_>, name: &Name) -> Option> { - let class = self.lookup(context); - - let symbols = symbol_table(context.db, class.body_scope); - let symbol = symbols.symbol_id_by_name(name)?; - let types = context.types(class.body_scope); - - Some(types.symbol_ty(symbol)) - } } #[derive(Debug, Eq, PartialEq, Clone)] @@ -505,6 +510,7 @@ impl<'db, 'inference> TypingContext<'db, 'inference> { #[cfg(test)] mod tests { + use red_knot_module_resolver::{set_module_resolution_settings, ModuleResolutionSettings}; use ruff_db::file_system::FileSystemPathBuf; use ruff_db::parsed::parsed_module; use ruff_db::vfs::system_path_to_file; @@ -513,8 +519,8 @@ mod tests { assert_will_not_run_function_query, assert_will_run_function_query, TestDb, }; use crate::semantic_index::root_scope; - use crate::types::{expression_ty, infer_types, public_symbol_ty_by_name, TypingContext}; - use red_knot_module_resolver::{set_module_resolution_settings, ModuleResolutionSettings}; + use crate::types::{infer_types, public_symbol_ty_by_name, TypingContext}; + use crate::{HasTy, SemanticModel}; fn setup_db() -> TestDb { let mut db = TestDb::new(); @@ -541,8 +547,9 @@ mod tests { let parsed = parsed_module(&db, a); let statement = parsed.suite().first().unwrap().as_assign_stmt().unwrap(); + let model = SemanticModel::new(&db, a); - let literal_ty = expression_ty(&db, a, &statement.value); + let literal_ty = statement.value.ty(&model); assert_eq!( format!("{}", literal_ty.display(&TypingContext::global(&db))), diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index bfdba6d606..6830a58d6b 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -9,14 +9,14 @@ use ruff_index::IndexVec; use ruff_python_ast as ast; use ruff_python_ast::{ExprContext, TypeParams}; -use crate::semantic_index::ast_ids::{ScopeAstIdNode, ScopeExpressionId}; +use crate::semantic_index::ast_ids::{HasScopedAstId, ScopedExpressionId}; use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition}; use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable}; use crate::semantic_index::{symbol_table, ChildrenIter, SemanticIndex}; use crate::types::{ - ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, ScopedFunctionTypeId, - ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, UnionType, - UnionTypeBuilder, + infer_types, ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, + ScopedFunctionTypeId, ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, + UnionType, UnionTypeBuilder, }; use crate::Db; @@ -36,15 +36,18 @@ pub(crate) struct TypeInference<'db> { intersection_types: IndexVec>, /// The types of every expression in this scope. - expression_tys: IndexVec>, + expression_tys: IndexVec>, /// The public types of every symbol in this scope. symbol_tys: IndexVec>, + + /// The type of a definition. + definition_tys: FxHashMap>, } impl<'db> TypeInference<'db> { #[allow(unused)] - pub(super) fn expression_ty(&self, expression: ScopeExpressionId) -> Type<'db> { + pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { self.expression_tys[expression] } @@ -72,6 +75,10 @@ impl<'db> TypeInference<'db> { &self.intersection_types[id] } + pub(crate) fn definition_ty(&self, definition: Definition) -> Type<'db> { + self.definition_tys[&definition] + } + fn shrink_to_fit(&mut self) { self.class_types.shrink_to_fit(); self.function_types.shrink_to_fit(); @@ -80,6 +87,7 @@ impl<'db> TypeInference<'db> { self.expression_tys.shrink_to_fit(); self.symbol_tys.shrink_to_fit(); + self.definition_tys.shrink_to_fit(); } } @@ -96,7 +104,6 @@ pub(super) struct TypeInferenceBuilder<'a> { /// The type inference results types: TypeInference<'a>, - definition_tys: FxHashMap>, children_scopes: ChildrenIter<'a>, } @@ -117,7 +124,6 @@ impl<'db> TypeInferenceBuilder<'db> { db, types: TypeInference::default(), - definition_tys: FxHashMap::default(), children_scopes, } } @@ -185,7 +191,7 @@ impl<'db> TypeInferenceBuilder<'db> { decorator_list, } = function; - let function_id = function.scope_ast_id(self.db, self.file_id, self.file_scope_id); + let function_id = function.scoped_ast_id(self.db, self.file_id, self.file_scope_id); let decorator_tys = decorator_list .iter() .map(|decorator| self.infer_decorator(decorator)) @@ -210,7 +216,8 @@ impl<'db> TypeInferenceBuilder<'db> { ScopeKind::Function | ScopeKind::Annotation )); - self.definition_tys + self.types + .definition_tys .insert(Definition::FunctionDef(function_id), function_ty); } @@ -224,7 +231,7 @@ impl<'db> TypeInferenceBuilder<'db> { body: _, } = class; - let class_id = class.scope_ast_id(self.db, self.file_id, self.file_scope_id); + let class_id = class.scoped_ast_id(self.db, self.file_id, self.file_scope_id); for decorator in decorator_list { self.infer_decorator(decorator); @@ -252,7 +259,8 @@ impl<'db> TypeInferenceBuilder<'db> { body_scope: class_body_scope_id.to_scope_id(self.db, self.file_id), }); - self.definition_tys + self.types + .definition_tys .insert(Definition::ClassDef(class_id), class_ty); } @@ -295,10 +303,11 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_expression(target); } - let assign_id = assignment.scope_ast_id(self.db, self.file_id, self.file_scope_id); + let assign_id = assignment.scoped_ast_id(self.db, self.file_id, self.file_scope_id); // TODO: Handle multiple targets. - self.definition_tys + self.types + .definition_tys .insert(Definition::Assignment(assign_id), value_ty); } @@ -318,8 +327,8 @@ impl<'db> TypeInferenceBuilder<'db> { let annotation_ty = self.infer_expression(annotation); self.infer_expression(target); - self.definition_tys.insert( - Definition::AnnotatedAssignment(assignment.scope_ast_id( + self.types.definition_tys.insert( + Definition::AnnotatedAssignment(assignment.scoped_ast_id( self.db, self.file_id, self.file_scope_id, @@ -347,7 +356,7 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_import_statement(&mut self, import: &ast::StmtImport) { let ast::StmtImport { range: _, names } = import; - let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id); + let import_id = import.scoped_ast_id(self.db, self.file_id, self.file_scope_id); for (i, alias) in names.iter().enumerate() { let ast::Alias { @@ -362,7 +371,7 @@ impl<'db> TypeInferenceBuilder<'db> { .map(|module| self.typing_context().module_ty(module.file())) .unwrap_or(Type::Unknown); - self.definition_tys.insert( + self.types.definition_tys.insert( Definition::Import(ImportDefinition { import_id, alias: u32::try_from(i).unwrap(), @@ -380,7 +389,7 @@ impl<'db> TypeInferenceBuilder<'db> { level: _, } = import; - let import_id = import.scope_ast_id(self.db, self.file_id, self.file_scope_id); + let import_id = import.scoped_ast_id(self.db, self.file_id, self.file_scope_id); let module_name = ModuleName::new(module.as_deref().expect("Support relative imports")); let module = @@ -400,7 +409,7 @@ impl<'db> TypeInferenceBuilder<'db> { .member(&self.typing_context(), &name.id) .unwrap_or(Type::Unknown); - self.definition_tys.insert( + self.types.definition_tys.insert( Definition::ImportFrom(ImportFromDefinition { import_id, name: u32::try_from(i).unwrap(), @@ -482,8 +491,8 @@ impl<'db> TypeInferenceBuilder<'db> { let value_ty = self.infer_expression(value); self.infer_expression(target); - self.definition_tys.insert( - Definition::NamedExpr(named.scope_ast_id(self.db, self.file_id, self.file_scope_id)), + self.types.definition_tys.insert( + Definition::NamedExpr(named.scoped_ast_id(self.db, self.file_id, self.file_scope_id)), value_ty, ); @@ -530,11 +539,12 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO: Skip over class scopes unless the they are a immediately-nested type param scope. // TODO: Support built-ins - let symbol_table = - symbol_table(self.db, ancestor_id.to_scope_id(self.db, self.file_id)); + let ancestor_scope = ancestor_id.to_scope_id(self.db, self.file_id); + let symbol_table = symbol_table(self.db, ancestor_scope); - if let Some(_symbol_id) = symbol_table.symbol_id_by_name(id) { - todo!("Return type for symbol from outer scope"); + if let Some(symbol_id) = symbol_table.symbol_id_by_name(id) { + let types = infer_types(self.db, ancestor_scope); + return types.symbol_ty(symbol_id); } } Type::Unknown @@ -666,7 +676,7 @@ impl<'db> TypeInferenceBuilder<'db> { let mut definitions = symbol .definitions() .iter() - .filter_map(|definition| self.definition_tys.get(definition).copied()); + .filter_map(|definition| self.types.definition_tys.get(definition).copied()); let Some(first) = definitions.next() else { return Type::Unbound; diff --git a/crates/ruff_db/src/parsed.rs b/crates/ruff_db/src/parsed.rs index 5808bca4ae..8eaf5506a7 100644 --- a/crates/ruff_db/src/parsed.rs +++ b/crates/ruff_db/src/parsed.rs @@ -1,4 +1,3 @@ -use salsa::DebugWithDb; use std::fmt::Formatter; use std::ops::Deref; use std::sync::Arc; @@ -23,7 +22,7 @@ use crate::Db; /// for determining if a query result is unchanged. #[salsa::tracked(return_ref, no_eq)] pub fn parsed_module(db: &dyn Db, file: VfsFile) -> ParsedModule { - let _ = tracing::trace_span!("parse_module", file = ?file.debug(db)).enter(); + let _span = tracing::trace_span!("parse_module", file = ?file).entered(); let source = source_text(db, file); let path = file.path(db); diff --git a/crates/ruff_db/src/source.rs b/crates/ruff_db/src/source.rs index 0dcab3987b..ab044721cc 100644 --- a/crates/ruff_db/src/source.rs +++ b/crates/ruff_db/src/source.rs @@ -10,7 +10,7 @@ use crate::Db; /// Reads the content of file. #[salsa::tracked] pub fn source_text(db: &dyn Db, file: VfsFile) -> SourceText { - let _ = tracing::trace_span!("source_text", file = ?file.debug(db)).enter(); + let _span = tracing::trace_span!("source_text", ?file).entered(); let content = file.read(db); @@ -22,7 +22,7 @@ pub fn source_text(db: &dyn Db, file: VfsFile) -> SourceText { /// Computes the [`LineIndex`] for `file`. #[salsa::tracked] pub fn line_index(db: &dyn Db, file: VfsFile) -> LineIndex { - let _ = tracing::trace_span!("line_index", file = ?file.debug(db)).enter(); + let _span = tracing::trace_span!("line_index", file = ?file.debug(db)).entered(); let source = source_text(db, file);