diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 01e8998acc..b85683889b 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -12,7 +12,7 @@ use crate::node_key::NodeKey; use crate::semantic_index::ast_ids::{AstId, AstIds, ScopedClassId, ScopedFunctionId}; use crate::semantic_index::builder::SemanticIndexBuilder; use crate::semantic_index::symbol::{ - FileScopeId, PublicSymbolId, Scope, ScopeId, ScopedSymbolId, SymbolTable, + FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable, }; use crate::Db; @@ -83,17 +83,14 @@ pub struct SemanticIndex { /// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope). scopes_by_expression: FxHashMap, + /// Map from the definition that introduce a scope to the scope they define. + scopes_by_definition: FxHashMap, + /// Lookup table to map between node ids and ast nodes. /// /// Note: We should not depend on this map when analysing other files or /// changing a file invalidates all dependents. ast_ids: IndexVec, - - /// Map from scope to the node that introduces the scope. - nodes_by_scope: IndexVec, - - /// Map from nodes that introduce a scope to the scope they define. - scopes_by_node: FxHashMap, } impl SemanticIndex { @@ -150,6 +147,7 @@ impl SemanticIndex { } /// Returns an iterator over the direct child scopes of `scope`. + #[allow(unused)] pub(crate) fn child_scopes(&self, scope: FileScopeId) -> ChildrenIter { ChildrenIter::new(self, scope) } @@ -159,15 +157,45 @@ impl SemanticIndex { AncestorsIter::new(self, scope) } - pub(crate) fn scope_node(&self, scope_id: FileScopeId) -> NodeWithScopeId { - self.nodes_by_scope[scope_id] + /// Returns the scope that is created by `node`. + pub(crate) fn node_scope(&self, node: impl Into) -> FileScopeId { + self.scopes_by_definition[&node.into()] } + /// Returns the scope in which `node_with_scope` is defined. + /// + /// The returned scope can be used to lookup the symbol of the definition or its type. + /// + /// * Annotation: Returns the direct parent scope + /// * Function and classes: Returns the parent scope unless they have type parameters in which case + /// the grandparent scope is returned. pub(crate) fn definition_scope( &self, node_with_scope: impl Into, ) -> FileScopeId { - self.scopes_by_node[&node_with_scope.into()] + fn resolve_scope(index: &SemanticIndex, node_with_scope: NodeWithScopeKey) -> FileScopeId { + let scope_id = index.node_scope(node_with_scope); + let scope = index.scope(scope_id); + + match scope.kind() { + ScopeKind::Module => scope_id, + ScopeKind::Annotation => scope.parent.unwrap(), + ScopeKind::Class | ScopeKind::Function => { + let mut ancestors = index.ancestor_scopes(scope_id).skip(1); + + let (mut scope_id, mut scope) = ancestors.next().unwrap(); + if scope.kind() == ScopeKind::Annotation { + (scope_id, scope) = ancestors.next().unwrap(); + } + + debug_assert_ne!(scope.kind(), ScopeKind::Annotation); + + scope_id + } + } + } + + resolve_scope(self, node_with_scope.into()) } } @@ -307,8 +335,9 @@ mod tests { use ruff_db::vfs::{system_path_to_file, VfsFile}; use crate::db::tests::TestDb; - use crate::semantic_index::symbol::{FileScopeId, FileSymbolId, Scope, ScopeKind, SymbolTable}; - use crate::semantic_index::{root_scope, semantic_index, symbol_table, SemanticIndex}; + use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable}; + use crate::semantic_index::{root_scope, semantic_index, symbol_table}; + use crate::Db; struct TestCase { db: TestDb, @@ -440,18 +469,12 @@ y = 2 let index = semantic_index(&db, file); - let root = index.symbol_table(FileScopeId::root()); let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); assert_eq!(scopes.len(), 1); let (class_scope_id, class_scope) = scopes[0]; assert_eq!(class_scope.kind(), ScopeKind::Class); - assert_eq!( - class_scope - .defining_symbol() - .map(super::symbol::FileSymbolId::scoped_symbol_id), - root.symbol_id_by_name("C") - ); + assert_eq!(class_scope.name(&db, file), "C"); let class_table = index.symbol_table(class_scope_id); assert_eq!(names(&class_table), vec!["x"]); @@ -480,12 +503,7 @@ y = 2 let (function_scope_id, function_scope) = scopes[0]; assert_eq!(function_scope.kind(), ScopeKind::Function); - assert_eq!( - function_scope - .defining_symbol() - .map(FileSymbolId::scoped_symbol_id), - root_table.symbol_id_by_name("func") - ); + assert_eq!(function_scope.name(&db, file), "func"); let function_table = index.symbol_table(function_scope_id); assert_eq!(names(&function_table), vec!["x"]); @@ -521,19 +539,9 @@ def func(): assert_eq!(func_scope_1.kind(), ScopeKind::Function); - assert_eq!( - func_scope_1 - .defining_symbol() - .map(FileSymbolId::scoped_symbol_id), - root_table.symbol_id_by_name("func") - ); + assert_eq!(func_scope_1.name(&db, file), "func"); assert_eq!(func_scope_2.kind(), ScopeKind::Function); - assert_eq!( - func_scope_2 - .defining_symbol() - .map(FileSymbolId::scoped_symbol_id), - root_table.symbol_id_by_name("func") - ); + assert_eq!(func_scope_2.name(&db, file), "func"); let func1_table = index.symbol_table(func_scope1_id); let func2_table = index.symbol_table(func_scope2_id); @@ -568,12 +576,7 @@ def func[T](): let (ann_scope_id, ann_scope) = scopes[0]; assert_eq!(ann_scope.kind(), ScopeKind::Annotation); - assert_eq!( - ann_scope - .defining_symbol() - .map(FileSymbolId::scoped_symbol_id), - root_table.symbol_id_by_name("func") - ); + assert_eq!(ann_scope.name(&db, file), "func"); let ann_table = index.symbol_table(ann_scope_id); assert_eq!(names(&ann_table), vec!["T"]); @@ -581,12 +584,7 @@ def func[T](): assert_eq!(scopes.len(), 1); let (func_scope_id, func_scope) = scopes[0]; assert_eq!(func_scope.kind(), ScopeKind::Function); - assert_eq!( - func_scope - .defining_symbol() - .map(FileSymbolId::scoped_symbol_id), - root_table.symbol_id_by_name("func") - ); + assert_eq!(func_scope.name(&db, file), "func"); let func_table = index.symbol_table(func_scope_id); assert_eq!(names(&func_table), vec!["x"]); } @@ -610,12 +608,7 @@ class C[T]: assert_eq!(scopes.len(), 1); let (ann_scope_id, ann_scope) = scopes[0]; assert_eq!(ann_scope.kind(), ScopeKind::Annotation); - assert_eq!( - ann_scope - .defining_symbol() - .map(FileSymbolId::scoped_symbol_id), - root_table.symbol_id_by_name("C") - ); + assert_eq!(ann_scope.name(&db, file), "C"); let ann_table = index.symbol_table(ann_scope_id); assert_eq!(names(&ann_table), vec!["T"]); assert!( @@ -630,12 +623,7 @@ class C[T]: let (func_scope_id, class_scope) = scopes[0]; assert_eq!(class_scope.kind(), ScopeKind::Class); - assert_eq!( - class_scope - .defining_symbol() - .map(FileSymbolId::scoped_symbol_id), - root_table.symbol_id_by_name("C") - ); + assert_eq!(class_scope.name(&db, file), "C"); assert_eq!(names(&index.symbol_table(func_scope_id)), vec!["x"]); } @@ -698,23 +686,13 @@ class C[T]: fn scope_iterators() { fn scope_names<'a>( scopes: impl Iterator, - index: &'a SemanticIndex, + db: &'a dyn Db, + file: VfsFile, ) -> Vec<&'a str> { - let mut names = Vec::new(); - - for (_, scope) in scopes { - if let Some(defining_symbol) = scope.defining_symbol { - let symbol_table = &index.symbol_tables[defining_symbol.scope()]; - let symbol = symbol_table.symbol(defining_symbol.scoped_symbol_id()); - names.push(symbol.name().as_str()); - } else if scope.parent.is_none() { - names.push(""); - } else { - panic!("Unsupported"); - } - } - - names + scopes + .into_iter() + .map(|(_, scope)| scope.name(db, file)) + .collect() } let TestCase { db, file } = test_case( @@ -734,16 +712,19 @@ def x(): let descendents = index.descendent_scopes(FileScopeId::root()); assert_eq!( - scope_names(descendents, index), + scope_names(descendents, &db, file), vec!["Test", "foo", "bar", "baz", "x"] ); let children = index.child_scopes(FileScopeId::root()); - assert_eq!(scope_names(children, index), vec!["Test", "x"]); + assert_eq!(scope_names(children, &db, file), vec!["Test", "x"]); let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0; let test_child_scopes = index.child_scopes(test_class); - assert_eq!(scope_names(test_child_scopes, index), vec!["foo", "baz"]); + assert_eq!( + scope_names(test_child_scopes, &db, file), + vec!["foo", "baz"] + ); let bar_scope = index .descendent_scopes(FileScopeId::root()) @@ -753,7 +734,7 @@ def x(): let ancestors = index.ancestor_scopes(bar_scope); assert_eq!( - scope_names(ancestors, index), + scope_names(ancestors, &db, file), vec!["bar", "foo", "Test", ""] ); } diff --git a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs index 85c197073e..892d92fc40 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs @@ -145,10 +145,6 @@ impl AstId { pub(super) fn new(scope: FileScopeId, in_scope_id: L) -> Self { Self { scope, in_scope_id } } - - pub(super) fn in_scope_id(self) -> L { - self.in_scope_id - } } /// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`]. diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index d6d4042224..750f928229 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -12,7 +12,7 @@ use crate::node_key::NodeKey; use crate::semantic_index::ast_ids::{AstId, AstIdsBuilder, ScopedClassId, ScopedFunctionId}; use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition}; use crate::semantic_index::symbol::{ - FileScopeId, FileSymbolId, Scope, ScopeKind, ScopedSymbolId, SymbolFlags, SymbolTableBuilder, + FileScopeId, Scope, ScopeKind, ScopedSymbolId, SymbolFlags, SymbolTableBuilder, }; use crate::semantic_index::{NodeWithScopeId, NodeWithScopeKey, SemanticIndex}; @@ -27,9 +27,8 @@ pub(super) struct SemanticIndexBuilder<'a> { scopes: IndexVec, symbol_tables: IndexVec, ast_ids: IndexVec, - expression_scopes: FxHashMap, - scope_nodes: IndexVec, - node_scopes: FxHashMap, + scopes_by_expression: FxHashMap, + scopes_by_definition: FxHashMap, } impl<'a> SemanticIndexBuilder<'a> { @@ -42,16 +41,13 @@ impl<'a> SemanticIndexBuilder<'a> { scopes: IndexVec::new(), symbol_tables: IndexVec::new(), ast_ids: IndexVec::new(), - expression_scopes: FxHashMap::default(), - node_scopes: FxHashMap::default(), - scope_nodes: IndexVec::new(), + scopes_by_expression: FxHashMap::default(), + scopes_by_definition: FxHashMap::default(), }; builder.push_scope_with_parent( &NodeWithScope::new(parsed.syntax(), NodeWithScopeId::Module), None, - None, - None, ); builder @@ -64,45 +60,29 @@ impl<'a> SemanticIndexBuilder<'a> { .expect("Always to have a root scope") } - fn push_scope( - &mut self, - node: &NodeWithScope, - defining_symbol: Option, - definition: Option, - ) { + fn push_scope(&mut self, node: &NodeWithScope) { let parent = self.current_scope(); - self.push_scope_with_parent(node, defining_symbol, definition, Some(parent)); + self.push_scope_with_parent(node, Some(parent)); } - fn push_scope_with_parent( - &mut self, - node: &NodeWithScope, - defining_symbol: Option, - definition: Option, - parent: Option, - ) { + fn push_scope_with_parent(&mut self, node: &NodeWithScope, parent: Option) { let children_start = self.scopes.next_index() + 1; - let node_key = node.key(); - let node_id = node.id(); - let scope_kind = node.scope_kind(); let scope = Scope { + node: node.id(), parent, - defining_symbol, - definition, - kind: scope_kind, + kind: node.scope_kind(), descendents: children_start..children_start, }; let scope_id = self.scopes.push(scope); self.symbol_tables.push(SymbolTableBuilder::new()); let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new()); - let scope_node_id = self.scope_nodes.push(node_id); debug_assert_eq!(ast_id_scope, scope_id); - debug_assert_eq!(scope_id, scope_node_id); + self.scope_stack.push(scope_id); - self.node_scopes.insert(node_key, scope_id); + self.scopes_by_definition.insert(node.key(), scope_id); } fn pop_scope(&mut self) -> FileScopeId { @@ -123,18 +103,9 @@ impl<'a> SemanticIndexBuilder<'a> { &mut self.ast_ids[scope_id] } - fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> FileSymbolId { - for scope in self.scope_stack.iter().rev().skip(1) { - let builder = &self.symbol_tables[*scope]; - - if let Some(symbol) = builder.symbol_by_name(&name) { - return FileSymbolId::new(*scope, symbol); - } - } - - let scope = self.current_scope(); + fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId { let symbol_table = self.current_symbol_table(); - FileSymbolId::new(scope, symbol_table.add_or_update_symbol(name, flags, None)) + symbol_table.add_or_update_symbol(name, flags, None) } fn add_or_update_symbol_with_definition( @@ -150,7 +121,6 @@ impl<'a> SemanticIndexBuilder<'a> { fn with_type_params( &mut self, with_params: &WithTypeParams, - defining_symbol: FileSymbolId, nested: impl FnOnce(&mut Self) -> FileScopeId, ) -> FileScopeId { let type_params = with_params.type_parameters(); @@ -161,11 +131,7 @@ impl<'a> SemanticIndexBuilder<'a> { WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id), }; - self.push_scope( - &NodeWithScope::new(type_params, type_params_id), - Some(defining_symbol), - Some(with_params.definition()), - ); + self.push_scope(&NodeWithScope::new(type_params, type_params_id)); for type_param in &type_params.type_params { let name = match type_param { ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name, @@ -210,16 +176,14 @@ impl<'a> SemanticIndexBuilder<'a> { self.scopes.shrink_to_fit(); ast_ids.shrink_to_fit(); symbol_tables.shrink_to_fit(); - self.expression_scopes.shrink_to_fit(); - self.scope_nodes.shrink_to_fit(); + self.scopes_by_expression.shrink_to_fit(); SemanticIndex { symbol_tables, scopes: self.scopes, - nodes_by_scope: self.scope_nodes, - scopes_by_node: self.node_scopes, + scopes_by_definition: self.scopes_by_definition, ast_ids, - scopes_by_expression: self.expression_scopes, + scopes_by_expression: self.scopes_by_expression, } } } @@ -242,31 +206,24 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { let function_id = ScopedFunctionId(statement_id); let definition = Definition::FunctionDef(function_id); let scope = self.current_scope(); - let symbol = FileSymbolId::new( - scope, - self.add_or_update_symbol_with_definition(name.clone(), definition), - ); + + self.add_or_update_symbol_with_definition(name.clone(), definition); self.with_type_params( &WithTypeParams::FunctionDef { node: function_def, id: AstId::new(scope, function_id), }, - symbol, |builder| { builder.visit_parameters(&function_def.parameters); for expr in &function_def.returns { builder.visit_annotation(expr); } - builder.push_scope( - &NodeWithScope::new( - function_def, - NodeWithScopeId::Function(AstId::new(scope, function_id)), - ), - Some(symbol), - Some(definition), - ); + builder.push_scope(&NodeWithScope::new( + function_def, + NodeWithScopeId::Function(AstId::new(scope, function_id)), + )); builder.visit_body(&function_def.body); builder.pop_scope() }, @@ -281,29 +238,23 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { let class_id = ScopedClassId(statement_id); let definition = Definition::ClassDef(class_id); let scope = self.current_scope(); - let id = FileSymbolId::new( - self.current_scope(), - self.add_or_update_symbol_with_definition(name.clone(), definition), - ); + + self.add_or_update_symbol_with_definition(name.clone(), definition); + self.with_type_params( &WithTypeParams::ClassDef { node: class, id: AstId::new(scope, class_id), }, - id, |builder| { if let Some(arguments) = &class.arguments { builder.visit_arguments(arguments); } - builder.push_scope( - &NodeWithScope::new( - class, - NodeWithScopeId::Class(AstId::new(scope, class_id)), - ), - Some(id), - Some(definition), - ); + builder.push_scope(&NodeWithScope::new( + class, + NodeWithScopeId::Class(AstId::new(scope, class_id)), + )); builder.visit_body(&class.body); builder.pop_scope() @@ -368,7 +319,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> { self.current_ast_ids().record_expression(expr, module) }; - self.expression_scopes + self.scopes_by_expression .insert(NodeKey::from_node(expr), self.current_scope()); match expr { @@ -449,13 +400,6 @@ impl<'a> WithTypeParams<'a> { WithTypeParams::FunctionDef { node, .. } => node.type_params.as_deref(), } } - - fn definition(&self) -> Definition { - match self { - WithTypeParams::ClassDef { id, .. } => Definition::ClassDef(id.in_scope_id()), - WithTypeParams::FunctionDef { id, .. } => Definition::FunctionDef(id.in_scope_id()), - } - } } struct NodeWithScope { diff --git a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs index 8d752b665b..8c5ebb8c23 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs @@ -6,13 +6,14 @@ use hashbrown::hash_map::RawEntryMut; use rustc_hash::FxHasher; use smallvec::SmallVec; -use crate::semantic_index::definition::Definition; -use crate::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap}; -use crate::Db; use ruff_db::vfs::VfsFile; use ruff_index::{newtype_index, IndexVec}; use ruff_python_ast::name::Name; +use crate::semantic_index::definition::Definition; +use crate::semantic_index::{root_scope, semantic_index, symbol_table, NodeWithScopeId, SymbolMap}; +use crate::Db; + #[derive(Eq, PartialEq, Debug)] pub struct Symbol { name: Name, @@ -87,13 +88,6 @@ pub struct FileSymbolId { } impl FileSymbolId { - pub(super) fn new(scope: FileScopeId, symbol: ScopedSymbolId) -> Self { - Self { - scope, - scoped_symbol_id: symbol, - } - } - pub fn scope(self) -> FileScopeId { self.scope } @@ -215,19 +209,33 @@ impl FileScopeId { #[derive(Debug, Eq, PartialEq)] pub struct Scope { pub(super) parent: Option, - pub(super) definition: Option, - pub(super) defining_symbol: Option, + pub(super) node: NodeWithScopeId, pub(super) kind: ScopeKind, pub(super) descendents: Range, } impl Scope { - pub fn definition(&self) -> Option { - self.definition + #[cfg(test)] + pub(crate) fn name<'db>(&self, db: &'db dyn Db, file: VfsFile) -> &'db str { + use crate::semantic_index::ast_ids::AstIdNode; + use ruff_python_ast as ast; + + match self.node { + NodeWithScopeId::Module => "", + NodeWithScopeId::Class(class) | NodeWithScopeId::ClassTypeParams(class) => { + let class = ast::StmtClassDef::lookup(db, file, class); + class.name.as_str() + } + NodeWithScopeId::Function(function) | NodeWithScopeId::FunctionTypeParams(function) => { + let function = ast::StmtFunctionDef::lookup(db, file, function); + function.name.as_str() + } + } } - pub fn defining_symbol(&self) -> Option { - self.defining_symbol + /// The node that creates this scope. + pub(crate) fn node(&self) -> NodeWithScopeId { + self.node } pub fn parent(self) -> Option { @@ -365,10 +373,6 @@ impl SymbolTableBuilder { } } - pub(super) fn symbol_by_name(&self, name: &str) -> Option { - self.table.symbol_id_by_name(name) - } - pub(super) fn finish(mut self) -> SymbolTable { self.table.shrink_to_fit(); self.table diff --git a/crates/red_knot_python_semantic/src/semantic_model.rs b/crates/red_knot_python_semantic/src/semantic_model.rs index 6e5800bc5c..834f81fa52 100644 --- a/crates/red_knot_python_semantic/src/semantic_model.rs +++ b/crates/red_knot_python_semantic/src/semantic_model.rs @@ -5,7 +5,7 @@ use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef}; use crate::semantic_index::ast_ids::HasScopedAstId; use crate::semantic_index::definition::Definition; -use crate::semantic_index::symbol::{PublicSymbolId, ScopeKind}; +use crate::semantic_index::symbol::PublicSymbolId; use crate::semantic_index::{public_symbol, semantic_index, NodeWithScopeKey}; use crate::types::{infer_types, public_symbol_ty, Type, TypingContext}; use crate::Db; @@ -38,6 +38,10 @@ impl<'db> SemanticModel<'db> { } pub trait HasTy { + /// Returns the inferred type of `self`. + /// + /// ## Panics + /// May panic if `self` is from another file than `model`. fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db>; } @@ -48,7 +52,6 @@ impl HasTy for ast::ExpressionRef<'_> { let scope = file_scope.to_scope_id(model.db, model.file); let expression_id = self.scoped_ast_id(model.db, scope); - infer_types(model.db, scope).expression_ty(expression_id) } } @@ -142,15 +145,7 @@ impl HasTy for ast::StmtFunctionDef { let index = semantic_index(model.db, model.file); let definition_scope = index.definition_scope(NodeWithScopeKey::from(self)); - // SAFETY: A function always has either an enclosing module, function or class scope. - let mut parent_scope_id = index.parent_scope_id(definition_scope).unwrap(); - let parent_scope = index.scope(parent_scope_id); - - if parent_scope.kind() == ScopeKind::Annotation { - parent_scope_id = index.parent_scope_id(parent_scope_id).unwrap(); - } - - let scope = parent_scope_id.to_scope_id(model.db, model.file); + let scope = definition_scope.to_scope_id(model.db, model.file); let types = infer_types(model.db, scope); let definition = Definition::FunctionDef(self.scoped_ast_id(model.db, scope)); @@ -163,16 +158,7 @@ impl HasTy for StmtClassDef { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); let definition_scope = index.definition_scope(NodeWithScopeKey::from(self)); - - // SAFETY: A class always has either an enclosing module, function or class scope. - let mut parent_scope_id = index.parent_scope_id(definition_scope).unwrap(); - let parent_scope = index.scope(parent_scope_id); - - if parent_scope.kind() == ScopeKind::Annotation { - parent_scope_id = index.parent_scope_id(parent_scope_id).unwrap(); - } - - let scope = parent_scope_id.to_scope_id(model.db, model.file); + let scope = definition_scope.to_scope_id(model.db, model.file); let types = infer_types(model.db, scope); let definition = Definition::ClassDef(self.scoped_ast_id(model.db, scope)); @@ -180,3 +166,68 @@ impl HasTy for StmtClassDef { types.definition_ty(definition) } } + +#[cfg(test)] +mod tests { + use red_knot_module_resolver::{set_module_resolution_settings, ModuleResolutionSettings}; + use ruff_db::file_system::FileSystemPathBuf; + use ruff_db::parsed::parsed_module; + use ruff_db::vfs::system_path_to_file; + + use crate::db::tests::TestDb; + use crate::types::Type; + use crate::{HasTy, SemanticModel}; + + fn setup_db() -> TestDb { + let mut db = TestDb::new(); + set_module_resolution_settings( + &mut db, + ModuleResolutionSettings { + extra_paths: vec![], + workspace_root: FileSystemPathBuf::from("/src"), + site_packages: None, + custom_typeshed: None, + }, + ); + + db + } + + #[test] + fn function_ty() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system() + .write_file("/src/foo.py", "def test(): pass")?; + let foo = system_path_to_file(&db, "/src/foo.py").unwrap(); + + let ast = parsed_module(&db, foo); + + let function = ast.suite()[0].as_function_def_stmt().unwrap(); + let model = SemanticModel::new(&db, foo); + let ty = function.ty(&model); + + assert!(matches!(ty, Type::Function(_))); + + Ok(()) + } + + #[test] + fn class_ty() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system() + .write_file("/src/foo.py", "class Test: pass")?; + let foo = system_path_to_file(&db, "/src/foo.py").unwrap(); + + let ast = parsed_module(&db, foo); + + let class = ast.suite()[0].as_class_def_stmt().unwrap(); + let model = SemanticModel::new(&db, foo); + let ty = class.ty(&model); + + assert!(matches!(ty, Type::Class(_))); + + Ok(()) + } +} diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index a5fe056c26..825f50e464 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -71,7 +71,7 @@ pub(crate) fn infer_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInfe let index = semantic_index(db, file); let scope_id = scope.file_scope_id(db); - let node = index.scope_node(scope_id); + let node = index.scope(scope_id).node(); let mut context = TypeInferenceBuilder::new(db, scope, index); diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 423fce7955..4ae5c76feb 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -11,8 +11,8 @@ use ruff_python_ast::{ExprContext, TypeParams}; use crate::semantic_index::ast_ids::{HasScopedAstId, ScopedExpressionId}; use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition}; -use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable}; -use crate::semantic_index::{symbol_table, ChildrenIter, SemanticIndex}; +use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId, SymbolTable}; +use crate::semantic_index::{symbol_table, SemanticIndex}; use crate::types::{ infer_types, ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, ScopedFunctionTypeId, ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, @@ -104,7 +104,6 @@ pub(super) struct TypeInferenceBuilder<'a> { /// The type inference results types: TypeInference<'a>, - children_scopes: ChildrenIter<'a>, } impl<'db> TypeInferenceBuilder<'db> { @@ -112,7 +111,6 @@ impl<'db> TypeInferenceBuilder<'db> { pub(super) fn new(db: &'db dyn Db, scope: ScopeId<'db>, index: &'db SemanticIndex) -> Self { let file_scope_id = scope.file_scope_id(db); let file = scope.file(db); - let children_scopes = index.child_scopes(file_scope_id); let symbol_table = index.symbol_table(file_scope_id); Self { @@ -124,7 +122,6 @@ impl<'db> TypeInferenceBuilder<'db> { db, types: TypeInference::default(), - children_scopes, } } @@ -208,14 +205,6 @@ impl<'db> TypeInferenceBuilder<'db> { decorators: decorator_tys, }); - // Skip over the function or type params child scope. - let (_, scope) = self.children_scopes.next().unwrap(); - - assert!(matches!( - scope.kind(), - ScopeKind::Function | ScopeKind::Annotation - )); - self.types .definition_tys .insert(Definition::FunctionDef(function_id), function_ty); @@ -225,7 +214,7 @@ impl<'db> TypeInferenceBuilder<'db> { let ast::StmtClassDef { range: _, name, - type_params, + type_params: _, decorator_list, arguments, body: _, @@ -242,16 +231,7 @@ impl<'db> TypeInferenceBuilder<'db> { .map(|arguments| self.infer_arguments(arguments)) .unwrap_or(Vec::new()); - // If the class has type parameters, then the class body scope is the first child scope of the type parameter's scope - // Otherwise the next scope must be the class definition scope. - let (class_body_scope_id, class_body_scope) = if type_params.is_some() { - let (type_params_scope, _) = self.children_scopes.next().unwrap(); - self.index.child_scopes(type_params_scope).next().unwrap() - } else { - self.children_scopes.next().unwrap() - }; - - assert_eq!(class_body_scope.kind(), ScopeKind::Class); + let class_body_scope_id = self.index.node_scope(class); let class_ty = self.class_ty(ClassType { name: name.id.clone(), @@ -539,6 +519,12 @@ impl<'db> TypeInferenceBuilder<'db> { let symbol_table = symbol_table(self.db, ancestor_scope); if let Some(symbol_id) = symbol_table.symbol_id_by_name(id) { + let symbol = symbol_table.symbol(symbol_id); + + if !symbol.is_defined() { + continue; + } + let types = infer_types(self.db, ancestor_scope); return types.symbol_ty(symbol_id); } @@ -696,13 +682,13 @@ impl<'db> TypeInferenceBuilder<'db> { #[cfg(test)] mod tests { + use red_knot_module_resolver::{set_module_resolution_settings, ModuleResolutionSettings}; use ruff_db::file_system::FileSystemPathBuf; use ruff_db::vfs::system_path_to_file; + use ruff_python_ast::name::Name; use crate::db::tests::TestDb; use crate::types::{public_symbol_ty_by_name, Type, TypingContext}; - use red_knot_module_resolver::{set_module_resolution_settings, ModuleResolutionSettings}; - use ruff_python_ast::name::Name; fn setup_db() -> TestDb { let mut db = TestDb::new();