diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 7bcb35fd5e..c4534f688c 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -121,9 +121,11 @@ impl<'db> SemanticIndexBuilder<'db> { fn push_scope_with_parent(&mut self, node: NodeWithScopeRef, parent: Option) { let children_start = self.scopes.next_index() + 1; + #[allow(unsafe_code)] let scope = Scope { parent, - kind: node.scope_kind(), + // SAFETY: `node` is guaranteed to be a child of `self.module` + node: unsafe { node.to_kind(self.module.clone()) }, descendents: children_start..children_start, }; self.try_node_context_stack_manager.enter_nested_scope(); @@ -133,15 +135,7 @@ impl<'db> SemanticIndexBuilder<'db> { self.use_def_maps.push(UseDefMapBuilder::new()); let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new()); - #[allow(unsafe_code)] - // SAFETY: `node` is guaranteed to be a child of `self.module` - let scope_id = ScopeId::new( - self.db, - self.file, - file_scope_id, - unsafe { node.to_kind(self.module.clone()) }, - countme::Count::default(), - ); + let scope_id = ScopeId::new(self.db, self.file, file_scope_id, countme::Count::default()); self.scope_ids_by_scope.push(scope_id); self.scopes_by_node.insert(node.node_key(), file_scope_id); diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index c723edff7d..9a48b1790c 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -9,6 +9,19 @@ use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId}; use crate::unpack::Unpack; use crate::Db; +/// A definition of a symbol. +/// +/// ## Module-local type +/// This type should not be used as part of any cross-module API because +/// it holds a reference to the AST node. Range-offset changes +/// then propagate through all usages, and deserialization requires +/// reparsing the entire module. +/// +/// E.g. don't use this type in: +/// +/// * a return type of a cross-module query +/// * a field of a type that is a return type of a cross-module query +/// * an argument of a cross-module query #[salsa::tracked] pub struct Definition<'db> { /// The file in which the definition occurs. diff --git a/crates/red_knot_python_semantic/src/semantic_index/expression.rs b/crates/red_knot_python_semantic/src/semantic_index/expression.rs index 4a7582bc32..327d40a4b1 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/expression.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/expression.rs @@ -8,6 +8,18 @@ use salsa; /// An independently type-inferable expression. /// /// Includes constraint expressions (e.g. if tests) and the RHS of an unpacking assignment. +/// +/// ## Module-local type +/// This type should not be used as part of any cross-module API because +/// it holds a reference to the AST node. Range-offset changes +/// then propagate through all usages, and deserialization requires +/// reparsing the entire module. +/// +/// E.g. don't use this type in: +/// +/// * a return type of a cross-module query +/// * a field of a type that is a return type of a cross-module query +/// * an argument of a cross-module query #[salsa::tracked] pub(crate) struct Expression<'db> { /// The file in which the expression occurs. diff --git a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs index 44bf466d9d..f42099dd7d 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs @@ -103,14 +103,10 @@ pub struct ScopedSymbolId; pub struct ScopeId<'db> { #[id] pub file: File, + #[id] pub file_scope_id: FileScopeId, - /// The node that introduces this scope. - #[no_eq] - #[return_ref] - pub node: NodeWithScopeKind, - #[no_eq] count: countme::Count>, } @@ -131,6 +127,14 @@ impl<'db> ScopeId<'db> { ) } + pub(crate) fn node(self, db: &dyn Db) -> &NodeWithScopeKind { + self.scope(db).node() + } + + pub(crate) fn scope(self, db: &dyn Db) -> &Scope { + semantic_index(db, self.file(db)).scope(self.file_scope_id(db)) + } + #[cfg(test)] pub(crate) fn name(self, db: &'db dyn Db) -> &'db str { match self.node(db) { @@ -169,10 +173,10 @@ impl FileScopeId { } } -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug)] pub struct Scope { pub(super) parent: Option, - pub(super) kind: ScopeKind, + pub(super) node: NodeWithScopeKind, pub(super) descendents: Range, } @@ -181,8 +185,12 @@ impl Scope { self.parent } + pub fn node(&self) -> &NodeWithScopeKind { + &self.node + } + pub fn kind(&self) -> ScopeKind { - self.kind + self.node().scope_kind() } } @@ -376,21 +384,6 @@ impl NodeWithScopeRef<'_> { } } - pub(super) fn scope_kind(self) -> ScopeKind { - match self { - NodeWithScopeRef::Module => ScopeKind::Module, - NodeWithScopeRef::Class(_) => ScopeKind::Class, - NodeWithScopeRef::Function(_) => ScopeKind::Function, - NodeWithScopeRef::Lambda(_) => ScopeKind::Function, - NodeWithScopeRef::FunctionTypeParameters(_) - | NodeWithScopeRef::ClassTypeParameters(_) => ScopeKind::Annotation, - NodeWithScopeRef::ListComprehension(_) - | NodeWithScopeRef::SetComprehension(_) - | NodeWithScopeRef::DictComprehension(_) - | NodeWithScopeRef::GeneratorExpression(_) => ScopeKind::Comprehension, - } - } - pub(crate) fn node_key(self) -> NodeWithScopeKey { match self { NodeWithScopeRef::Module => NodeWithScopeKey::Module, @@ -438,6 +431,36 @@ pub enum NodeWithScopeKind { GeneratorExpression(AstNodeRef), } +impl NodeWithScopeKind { + pub(super) const fn scope_kind(&self) -> ScopeKind { + match self { + Self::Module => ScopeKind::Module, + Self::Class(_) => ScopeKind::Class, + Self::Function(_) => ScopeKind::Function, + Self::Lambda(_) => ScopeKind::Function, + Self::FunctionTypeParameters(_) | Self::ClassTypeParameters(_) => ScopeKind::Annotation, + Self::ListComprehension(_) + | Self::SetComprehension(_) + | Self::DictComprehension(_) + | Self::GeneratorExpression(_) => ScopeKind::Comprehension, + } + } + + pub fn expect_class(&self) -> &ast::StmtClassDef { + match self { + Self::Class(class) => class.node(), + _ => panic!("expected class"), + } + } + + pub fn expect_function(&self) -> &ast::StmtFunctionDef { + match self { + Self::Function(function) => function.node(), + _ => panic!("expected function"), + } + } +} + #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub(crate) enum NodeWithScopeKey { Module, diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 1a4edae35a..bedaeaf841 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -6,7 +6,7 @@ use itertools::Itertools; use crate::module_resolver::file_to_module; use crate::semantic_index::ast_ids::HasScopedAstId; -use crate::semantic_index::definition::{Definition, DefinitionKind}; +use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::{self as symbol, ScopeId, ScopedSymbolId}; use crate::semantic_index::{ global_scope, semantic_index, symbol_table, use_def_map, BindingWithConstraints, @@ -1109,11 +1109,11 @@ impl<'db> Type<'db> { Type::FunctionLiteral(function_type) => { if function_type.is_known(db, KnownFunction::RevealType) { CallOutcome::revealed( - function_type.return_type(db), + function_type.return_ty(db), *arg_types.first().unwrap_or(&Type::Unknown), ) } else { - CallOutcome::callable(function_type.return_type(db)) + CallOutcome::callable(function_type.return_ty(db)) } } @@ -1854,17 +1854,18 @@ pub struct FunctionType<'db> { /// Is this a function that we special-case somehow? If so, which one? known: Option, - definition: Definition<'db>, + body_scope: ScopeId<'db>, /// types of all decorators on this function decorators: Box<[Type<'db>]>, } +#[salsa::tracked] impl<'db> FunctionType<'db> { /// Return true if this is a standard library function with given module name and name. pub(crate) fn is_stdlib_symbol(self, db: &'db dyn Db, module_name: &str, name: &str) -> bool { name == self.name(db) - && file_to_module(db, self.definition(db).file(db)).is_some_and(|module| { + && file_to_module(db, self.body_scope(db).file(db)).is_some_and(|module| { module.search_path().is_standard_library() && module.name() == module_name }) } @@ -1874,11 +1875,18 @@ impl<'db> FunctionType<'db> { } /// inferred return type for this function - pub fn return_type(&self, db: &'db dyn Db) -> Type<'db> { - let definition = self.definition(db); - let DefinitionKind::Function(function_stmt_node) = definition.kind(db) else { - panic!("Function type definition must have `DefinitionKind::Function`") - }; + /// + /// ## Why is this a salsa query? + /// + /// This is a salsa query to short-circuit the invalidation + /// when the function's AST node changes. + /// + /// Were this not a salsa query, then the calling query + /// would depend on the function's AST and rerun for every change in that file. + #[salsa::tracked] + pub fn return_ty(self, db: &'db dyn Db) -> Type<'db> { + let scope = self.body_scope(db); + let function_stmt_node = scope.node(db).expect_function(); // TODO if a function `bar` is decorated by `foo`, // where `foo` is annotated as returning a type `X` that is a subtype of `Callable`, @@ -1897,6 +1905,8 @@ impl<'db> FunctionType<'db> { // TODO: generic `types.CoroutineType`! Type::Todo } else { + let definition = + semantic_index(db, scope.file(db)).definition(function_stmt_node); definition_expression_ty(db, definition, returns.as_ref()) } }) @@ -1924,8 +1934,6 @@ pub struct ClassType<'db> { #[return_ref] pub name: ast::name::Name, - definition: Definition<'db>, - body_scope: ScopeId<'db>, known: Option, @@ -1955,23 +1963,55 @@ impl<'db> ClassType<'db> { /// Note that any class (except for `object`) that has no explicit /// bases will implicitly inherit from `object` at runtime. Nonetheless, /// this method does *not* include `object` in the bases it iterates over. - fn explicit_bases(self, db: &'db dyn Db) -> impl Iterator> { - let definition = self.definition(db); - let class_stmt = self.node(db); - let has_type_params = class_stmt.type_params.is_some(); + /// + /// ## Why is this a salsa query? + /// + /// This is a salsa query to short-circuit the invalidation + /// when the class's AST node changes. + /// + /// Were this not a salsa query, then the calling query + /// would depend on the class's AST and rerun for every change in that file. + fn explicit_bases(self, db: &'db dyn Db) -> &[Type<'db>] { + self.explicit_bases_query(db) + } - class_stmt - .bases() - .iter() - .map(move |base_node| infer_class_base_type(db, base_node, definition, has_type_params)) + #[salsa::tracked(return_ref)] + fn explicit_bases_query(self, db: &'db dyn Db) -> Box<[Type<'db>]> { + let class_stmt = self.node(db); + + if class_stmt.type_params.is_some() { + // when we have a specialized scope, we'll look up the inference + // within that scope + let model = SemanticModel::new(db, self.file(db)); + + class_stmt + .bases() + .iter() + .map(|base| base.ty(&model)) + .collect() + } else { + // Otherwise, we can do the lookup based on the definition scope + let class_definition = semantic_index(db, self.file(db)).definition(class_stmt); + + class_stmt + .bases() + .iter() + .map(|base_node| definition_expression_ty(db, class_definition, base_node)) + .collect() + } + } + + fn file(self, db: &dyn Db) -> File { + self.body_scope(db).file(db) } /// Return the original [`ast::StmtClassDef`] node associated with this class + /// + /// ## Note + /// Only call this function from queries in the same file or your + /// query depends on the AST of another file (bad!). fn node(self, db: &'db dyn Db) -> &'db ast::StmtClassDef { - match self.definition(db).kind(db) { - DefinitionKind::Class(class_stmt_node) => class_stmt_node, - _ => unreachable!("Class type definition should always have DefinitionKind::Class"), - } + self.body_scope(db).node(db).expect_class() } /// Attempt to resolve the [method resolution order] ("MRO") for this class. @@ -2049,26 +2089,6 @@ impl<'db> ClassType<'db> { } } -/// Infer the type of a node representing an explicit class base. -/// -/// For example, infer the type of `Foo` in the statement `class Bar(Foo, Baz): ...`. -fn infer_class_base_type<'db>( - db: &'db dyn Db, - base_node: &'db ast::Expr, - class_definition: Definition<'db>, - class_has_type_params: bool, -) -> Type<'db> { - if class_has_type_params { - // when we have a specialized scope, we'll look up the inference - // within that scope - let model = SemanticModel::new(db, class_definition.file(db)); - base_node.ty(&model) - } else { - // Otherwise, we can do the lookup based on the definition scope - definition_expression_ty(db, class_definition, base_node) - } -} - #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub struct InstanceType<'db> { class: ClassType<'db>, @@ -2197,7 +2217,10 @@ mod tests { use crate::program::{Program, SearchPathSettings}; use crate::python_version::PythonVersion; use crate::ProgramSettings; + use ruff_db::files::system_path_to_file; + use ruff_db::parsed::parsed_module; use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; + use ruff_db::testing::assert_function_query_was_not_run; use ruff_python_ast as ast; use test_case::test_case; @@ -2639,4 +2662,59 @@ mod tests { let property_symbol_name = ast::name::Name::new_static("property"); assert!(!symbol_names.contains(&property_symbol_name)); } + + /// Inferring the result of a call-expression shouldn't need to re-run after + /// a trivial change to the function's file (e.g. by adding a docstring to the function). + #[test] + fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/foo.py", + r#" + def foo() -> int: + return 5 + "#, + )?; + db.write_dedented( + "src/bar.py", + r#" + from foo import foo + + a = foo() + "#, + )?; + + let bar = system_path_to_file(&db, "src/bar.py")?; + let a = global_symbol(&db, bar, "a"); + + assert_eq!(a.expect_type(), KnownClass::Int.to_instance(&db)); + + // Add a docstring to foo to trigger a re-run. + // The bar-call site of foo should not be re-run because of that + db.write_dedented( + "src/foo.py", + r#" + def foo() -> int: + "Computes a value" + return 5 + "#, + )?; + db.clear_salsa_events(); + + let a = global_symbol(&db, bar, "a"); + + assert_eq!(a.expect_type(), KnownClass::Int.to_instance(&db)); + let events = db.take_salsa_events(); + + let call = &*parsed_module(&db, bar).syntax().body[1] + .as_assign_stmt() + .unwrap() + .value; + let foo_call = semantic_index(&db, bar).expression(call); + + assert_function_query_was_not_run(&db, infer_expression_types, foo_call, &events); + + Ok(()) + } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index c0eac9563b..1e18e8c51d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -444,6 +444,7 @@ impl<'db> TypeInferenceBuilder<'db> { } } + // TODO: Only call this function when diagnostics are enabled. self.check_class_definitions(); } @@ -854,11 +855,17 @@ impl<'db> TypeInferenceBuilder<'db> { } _ => None, }; + + let body_scope = self + .index + .node_scope(NodeWithScopeRef::Function(function)) + .to_scope_id(self.db, self.file); + let function_ty = Type::FunctionLiteral(FunctionType::new( self.db, &*name.id, function_kind, - definition, + body_scope, decorator_tys, )); @@ -966,7 +973,6 @@ impl<'db> TypeInferenceBuilder<'db> { let class_ty = Type::ClassLiteral(ClassType::new( self.db, &*name.id, - definition, body_scope, maybe_known_class, )); @@ -4684,7 +4690,7 @@ mod tests { let function = global_symbol(&db, mod_file, "example") .expect_type() .expect_function_literal(); - let returns = function.return_type(&db); + let returns = function.return_ty(&db); assert_eq!(returns.display(&db).to_string(), "int"); Ok(()) diff --git a/crates/red_knot_python_semantic/src/types/mro.rs b/crates/red_knot_python_semantic/src/types/mro.rs index ad8d991d66..ea298478f5 100644 --- a/crates/red_knot_python_semantic/src/types/mro.rs +++ b/crates/red_knot_python_semantic/src/types/mro.rs @@ -5,10 +5,7 @@ use indexmap::IndexSet; use itertools::Either; use rustc_hash::FxHashSet; -use ruff_python_ast as ast; - -use super::{infer_class_base_type, ClassType, KnownClass, Type}; -use crate::semantic_index::definition::Definition; +use super::{ClassType, KnownClass, Type}; use crate::Db; /// The inferred method resolution order of a given class. @@ -43,8 +40,7 @@ impl<'db> Mro<'db> { } fn of_class_impl(db: &'db dyn Db, class: ClassType<'db>) -> Result> { - let class_stmt_node = class.node(db); - let class_bases = class_stmt_node.bases(); + let class_bases = class.explicit_bases(db); match class_bases { // `builtins.object` is the special case: @@ -73,13 +69,8 @@ impl<'db> Mro<'db> { // This *could* theoretically be handled by the final branch below, // but it's a common case (i.e., worth optimizing for), // and the `c3_merge` function requires lots of allocations. - [single_base_node] => { - let single_base = ClassBase::try_from_node( - db, - single_base_node, - class.definition(db), - class_stmt_node.type_params.is_some(), - ); + [single_base] => { + let single_base = ClassBase::try_from_ty(*single_base).ok_or(*single_base); single_base.map_or_else( |invalid_base_ty| { let bases_info = Box::from([(0, invalid_base_ty)]); @@ -109,13 +100,11 @@ impl<'db> Mro<'db> { return Err(MroErrorKind::CyclicClassDefinition); } - let definition = class.definition(db); - let has_type_params = class_stmt_node.type_params.is_some(); let mut valid_bases = vec![]; let mut invalid_bases = vec![]; - for (i, base_node) in multiple_bases.iter().enumerate() { - match ClassBase::try_from_node(db, base_node, definition, has_type_params) { + for (i, base) in multiple_bases.iter().enumerate() { + match ClassBase::try_from_ty(*base).ok_or(*base) { Ok(valid_base) => valid_bases.push(valid_base), Err(invalid_base) => invalid_bases.push((i, invalid_base)), } @@ -289,7 +278,7 @@ pub(super) enum MroErrorKind<'db> { /// /// This variant records the indices and types of class bases /// that we deem to be invalid. The indices are the indices of nodes - /// in the bases list of the class's [`ast::StmtClassDef`] node. + /// in the bases list of the class's [`StmtClassDef`](ruff_python_ast::StmtClassDef) node. /// Each index is the index of a node representing an invalid base. InvalidBases(Box<[(usize, Type<'db>)]>), @@ -304,7 +293,7 @@ pub(super) enum MroErrorKind<'db> { /// /// This variant records the indices and [`ClassType`]s /// of the duplicate bases. The indices are the indices of nodes - /// in the bases list of the class's [`ast::StmtClassDef`] node. + /// in the bases list of the class's [`StmtClassDef`](ruff_python_ast::StmtClassDef) node. /// Each index is the index of a node representing a duplicate base. DuplicateBases(Box<[(usize, ClassType<'db>)]>), @@ -365,20 +354,6 @@ impl<'db> ClassBase<'db> { .map_or(Self::Unknown, Self::Class) } - /// Attempt to resolve the node `base_node` into a `ClassBase`. - /// - /// If the inferred type of `base_node` is not an acceptable class-base type, - /// return an error indicating what the inferred type was. - fn try_from_node( - db: &'db dyn Db, - base_node: &'db ast::Expr, - class_definition: Definition<'db>, - class_has_type_params: bool, - ) -> Result> { - let base_ty = infer_class_base_type(db, base_node, class_definition, class_has_type_params); - Self::try_from_ty(base_ty).ok_or(base_ty) - } - /// Attempt to resolve `ty` into a `ClassBase`. /// /// Return `None` if `ty` is not an acceptable type for a class base. @@ -496,6 +471,8 @@ fn class_is_cyclically_defined(db: &dyn Db, class: ClassType) -> bool { } for explicit_base_class in class .explicit_bases(db) + .iter() + .copied() .filter_map(Type::into_class_literal_type) { // Each base must be considered in isolation. @@ -513,6 +490,8 @@ fn class_is_cyclically_defined(db: &dyn Db, class: ClassType) -> bool { class .explicit_bases(db) + .iter() + .copied() .filter_map(Type::into_class_literal_type) .any(|base_class| is_cyclically_defined_recursive(db, base_class, &mut IndexSet::default())) } diff --git a/crates/red_knot_python_semantic/src/unpack.rs b/crates/red_knot_python_semantic/src/unpack.rs index 13d8d164a3..25ad40231c 100644 --- a/crates/red_knot_python_semantic/src/unpack.rs +++ b/crates/red_knot_python_semantic/src/unpack.rs @@ -12,6 +12,18 @@ use crate::Db; /// involved. It allows us to: /// 1. Avoid doing structural match multiple times for each definition /// 2. Avoid highlighting the same error multiple times +/// +/// ## Module-local type +/// This type should not be used as part of any cross-module API because +/// it holds a reference to the AST node. Range-offset changes +/// then propagate through all usages, and deserialization requires +/// reparsing the entire module. +/// +/// E.g. don't use this type in: +/// +/// * a return type of a cross-module query +/// * a field of a type that is a return type of a cross-module query +/// * an argument of a cross-module query #[salsa::tracked] pub(crate) struct Unpack<'db> { #[id]