Remove AST-node dependency from FunctionType and ClassType (#14087)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
This commit is contained in:
Micha Reiser 2024-11-05 09:02:38 +01:00 committed by GitHub
parent 9dddd73c29
commit 4323512a65
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 230 additions and 113 deletions

View file

@ -121,9 +121,11 @@ impl<'db> SemanticIndexBuilder<'db> {
fn push_scope_with_parent(&mut self, node: NodeWithScopeRef, parent: Option<FileScopeId>) {
let children_start = self.scopes.next_index() + 1;
#[allow(unsafe_code)]
let scope = Scope {
parent,
kind: node.scope_kind(),
// SAFETY: `node` is guaranteed to be a child of `self.module`
node: unsafe { node.to_kind(self.module.clone()) },
descendents: children_start..children_start,
};
self.try_node_context_stack_manager.enter_nested_scope();
@ -133,15 +135,7 @@ impl<'db> SemanticIndexBuilder<'db> {
self.use_def_maps.push(UseDefMapBuilder::new());
let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new());
#[allow(unsafe_code)]
// SAFETY: `node` is guaranteed to be a child of `self.module`
let scope_id = ScopeId::new(
self.db,
self.file,
file_scope_id,
unsafe { node.to_kind(self.module.clone()) },
countme::Count::default(),
);
let scope_id = ScopeId::new(self.db, self.file, file_scope_id, countme::Count::default());
self.scope_ids_by_scope.push(scope_id);
self.scopes_by_node.insert(node.node_key(), file_scope_id);

View file

@ -9,6 +9,19 @@ use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId};
use crate::unpack::Unpack;
use crate::Db;
/// A definition of a symbol.
///
/// ## Module-local type
/// This type should not be used as part of any cross-module API because
/// it holds a reference to the AST node. Range-offset changes
/// then propagate through all usages, and deserialization requires
/// reparsing the entire module.
///
/// E.g. don't use this type in:
///
/// * a return type of a cross-module query
/// * a field of a type that is a return type of a cross-module query
/// * an argument of a cross-module query
#[salsa::tracked]
pub struct Definition<'db> {
/// The file in which the definition occurs.

View file

@ -8,6 +8,18 @@ use salsa;
/// An independently type-inferable expression.
///
/// Includes constraint expressions (e.g. if tests) and the RHS of an unpacking assignment.
///
/// ## Module-local type
/// This type should not be used as part of any cross-module API because
/// it holds a reference to the AST node. Range-offset changes
/// then propagate through all usages, and deserialization requires
/// reparsing the entire module.
///
/// E.g. don't use this type in:
///
/// * a return type of a cross-module query
/// * a field of a type that is a return type of a cross-module query
/// * an argument of a cross-module query
#[salsa::tracked]
pub(crate) struct Expression<'db> {
/// The file in which the expression occurs.

View file

@ -103,14 +103,10 @@ pub struct ScopedSymbolId;
pub struct ScopeId<'db> {
#[id]
pub file: File,
#[id]
pub file_scope_id: FileScopeId,
/// The node that introduces this scope.
#[no_eq]
#[return_ref]
pub node: NodeWithScopeKind,
#[no_eq]
count: countme::Count<ScopeId<'static>>,
}
@ -131,6 +127,14 @@ impl<'db> ScopeId<'db> {
)
}
/// Returns the node that introduces this scope.
///
/// The node is read from the [`Scope`] returned by [`Self::scope`].
pub(crate) fn node(self, db: &dyn Db) -> &NodeWithScopeKind {
    let scope = self.scope(db);
    scope.node()
}
/// Returns the [`Scope`] this id refers to.
///
/// The scope is looked up in the semantic index of this scope's file,
/// using the file-local scope id.
pub(crate) fn scope(self, db: &dyn Db) -> &Scope {
    let index = semantic_index(db, self.file(db));
    let file_scope_id = self.file_scope_id(db);
    index.scope(file_scope_id)
}
#[cfg(test)]
pub(crate) fn name(self, db: &'db dyn Db) -> &'db str {
match self.node(db) {
@ -169,10 +173,10 @@ impl FileScopeId {
}
}
#[derive(Debug, Eq, PartialEq)]
#[derive(Debug)]
pub struct Scope {
pub(super) parent: Option<FileScopeId>,
pub(super) kind: ScopeKind,
pub(super) node: NodeWithScopeKind,
pub(super) descendents: Range<FileScopeId>,
}
@ -181,8 +185,12 @@ impl Scope {
self.parent
}
/// Returns the node that introduces this scope.
pub fn node(&self) -> &NodeWithScopeKind {
&self.node
}
pub fn kind(&self) -> ScopeKind {
self.kind
self.node().scope_kind()
}
}
@ -376,21 +384,6 @@ impl NodeWithScopeRef<'_> {
}
}
pub(super) fn scope_kind(self) -> ScopeKind {
match self {
NodeWithScopeRef::Module => ScopeKind::Module,
NodeWithScopeRef::Class(_) => ScopeKind::Class,
NodeWithScopeRef::Function(_) => ScopeKind::Function,
NodeWithScopeRef::Lambda(_) => ScopeKind::Function,
NodeWithScopeRef::FunctionTypeParameters(_)
| NodeWithScopeRef::ClassTypeParameters(_) => ScopeKind::Annotation,
NodeWithScopeRef::ListComprehension(_)
| NodeWithScopeRef::SetComprehension(_)
| NodeWithScopeRef::DictComprehension(_)
| NodeWithScopeRef::GeneratorExpression(_) => ScopeKind::Comprehension,
}
}
pub(crate) fn node_key(self) -> NodeWithScopeKey {
match self {
NodeWithScopeRef::Module => NodeWithScopeKey::Module,
@ -438,6 +431,36 @@ pub enum NodeWithScopeKind {
GeneratorExpression(AstNodeRef<ast::ExprGenerator>),
}
impl NodeWithScopeKind {
    /// Returns the [`ScopeKind`] of the scope introduced by this node.
    ///
    /// Functions and lambdas both introduce function scopes, type-parameter
    /// lists introduce annotation scopes, and every comprehension form
    /// (including generator expressions) introduces a comprehension scope.
    pub(super) const fn scope_kind(&self) -> ScopeKind {
        match self {
            Self::Module => ScopeKind::Module,
            Self::Class(_) => ScopeKind::Class,
            Self::Function(_) | Self::Lambda(_) => ScopeKind::Function,
            Self::FunctionTypeParameters(_) | Self::ClassTypeParameters(_) => ScopeKind::Annotation,
            Self::ListComprehension(_)
            | Self::SetComprehension(_)
            | Self::DictComprehension(_)
            | Self::GeneratorExpression(_) => ScopeKind::Comprehension,
        }
    }

    /// Returns the class definition statement that introduces this scope.
    ///
    /// # Panics
    ///
    /// Panics with `"expected class"` if this is not the `Class` variant.
    pub fn expect_class(&self) -> &ast::StmtClassDef {
        let Self::Class(class) = self else {
            panic!("expected class")
        };
        class.node()
    }

    /// Returns the function definition statement that introduces this scope.
    ///
    /// # Panics
    ///
    /// Panics with `"expected function"` if this is not the `Function` variant.
    pub fn expect_function(&self) -> &ast::StmtFunctionDef {
        let Self::Function(function) = self else {
            panic!("expected function")
        };
        function.node()
    }
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(crate) enum NodeWithScopeKey {
Module,

View file

@ -6,7 +6,7 @@ use itertools::Itertools;
use crate::module_resolver::file_to_module;
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::definition::{Definition, DefinitionKind};
use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::{self as symbol, ScopeId, ScopedSymbolId};
use crate::semantic_index::{
global_scope, semantic_index, symbol_table, use_def_map, BindingWithConstraints,
@ -1109,11 +1109,11 @@ impl<'db> Type<'db> {
Type::FunctionLiteral(function_type) => {
if function_type.is_known(db, KnownFunction::RevealType) {
CallOutcome::revealed(
function_type.return_type(db),
function_type.return_ty(db),
*arg_types.first().unwrap_or(&Type::Unknown),
)
} else {
CallOutcome::callable(function_type.return_type(db))
CallOutcome::callable(function_type.return_ty(db))
}
}
@ -1854,17 +1854,18 @@ pub struct FunctionType<'db> {
/// Is this a function that we special-case somehow? If so, which one?
known: Option<KnownFunction>,
definition: Definition<'db>,
body_scope: ScopeId<'db>,
/// types of all decorators on this function
decorators: Box<[Type<'db>]>,
}
#[salsa::tracked]
impl<'db> FunctionType<'db> {
/// Return true if this is a standard library function with given module name and name.
pub(crate) fn is_stdlib_symbol(self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
name == self.name(db)
&& file_to_module(db, self.definition(db).file(db)).is_some_and(|module| {
&& file_to_module(db, self.body_scope(db).file(db)).is_some_and(|module| {
module.search_path().is_standard_library() && module.name() == module_name
})
}
@ -1874,11 +1875,18 @@ impl<'db> FunctionType<'db> {
}
/// inferred return type for this function
pub fn return_type(&self, db: &'db dyn Db) -> Type<'db> {
let definition = self.definition(db);
let DefinitionKind::Function(function_stmt_node) = definition.kind(db) else {
panic!("Function type definition must have `DefinitionKind::Function`")
};
///
/// ## Why is this a salsa query?
///
/// This is a salsa query to short-circuit the invalidation
/// when the function's AST node changes.
///
/// Were this not a salsa query, then the calling query
/// would depend on the function's AST and rerun for every change in that file.
#[salsa::tracked]
pub fn return_ty(self, db: &'db dyn Db) -> Type<'db> {
let scope = self.body_scope(db);
let function_stmt_node = scope.node(db).expect_function();
// TODO if a function `bar` is decorated by `foo`,
// where `foo` is annotated as returning a type `X` that is a subtype of `Callable`,
@ -1897,6 +1905,8 @@ impl<'db> FunctionType<'db> {
// TODO: generic `types.CoroutineType`!
Type::Todo
} else {
let definition =
semantic_index(db, scope.file(db)).definition(function_stmt_node);
definition_expression_ty(db, definition, returns.as_ref())
}
})
@ -1924,8 +1934,6 @@ pub struct ClassType<'db> {
#[return_ref]
pub name: ast::name::Name,
definition: Definition<'db>,
body_scope: ScopeId<'db>,
known: Option<KnownClass>,
@ -1955,23 +1963,55 @@ impl<'db> ClassType<'db> {
/// Note that any class (except for `object`) that has no explicit
/// bases will implicitly inherit from `object` at runtime. Nonetheless,
/// this method does *not* include `object` in the bases it iterates over.
fn explicit_bases(self, db: &'db dyn Db) -> impl Iterator<Item = Type<'db>> {
let definition = self.definition(db);
let class_stmt = self.node(db);
let has_type_params = class_stmt.type_params.is_some();
///
/// ## Why is this a salsa query?
///
/// This is a salsa query to short-circuit the invalidation
/// when the class's AST node changes.
///
/// Were this not a salsa query, then the calling query
/// would depend on the class's AST and rerun for every change in that file.
fn explicit_bases(self, db: &'db dyn Db) -> &[Type<'db>] {
self.explicit_bases_query(db)
}
class_stmt
.bases()
.iter()
.map(move |base_node| infer_class_base_type(db, base_node, definition, has_type_params))
/// Infers the type of every explicit base listed in the class statement,
/// in source order.
///
/// This is the tracked salsa query (returned by reference) that reads the
/// class's AST node.
#[salsa::tracked(return_ref)]
fn explicit_bases_query(self, db: &'db dyn Db) -> Box<[Type<'db>]> {
    let class_stmt = self.node(db);
    let file = self.file(db);
    let base_nodes = class_stmt.bases().iter();

    if class_stmt.type_params.is_none() {
        // Without type parameters, we can do the lookup based on the definition scope
        let class_definition = semantic_index(db, file).definition(class_stmt);
        base_nodes
            .map(|base_node| definition_expression_ty(db, class_definition, base_node))
            .collect()
    } else {
        // when we have a specialized scope, we'll look up the inference
        // within that scope
        let model = SemanticModel::new(db, file);
        base_nodes.map(|base_node| base_node.ty(&model)).collect()
    }
}
/// Returns the file in which this class is defined, as recorded on the
/// class's body scope.
fn file(self, db: &dyn Db) -> File {
    let body_scope = self.body_scope(db);
    body_scope.file(db)
}
/// Return the original [`ast::StmtClassDef`] node associated with this class.
///
/// ## Note
/// Only call this function from queries in the same file; otherwise your
/// query will depend on the AST of another file (bad!).
fn node(self, db: &'db dyn Db) -> &'db ast::StmtClassDef {
match self.definition(db).kind(db) {
DefinitionKind::Class(class_stmt_node) => class_stmt_node,
_ => unreachable!("Class type definition should always have DefinitionKind::Class"),
}
self.body_scope(db).node(db).expect_class()
}
/// Attempt to resolve the [method resolution order] ("MRO") for this class.
@ -2049,26 +2089,6 @@ impl<'db> ClassType<'db> {
}
}
/// Infer the type of one explicit base in a class statement.
///
/// For example, this infers the type of `Foo` in `class Bar(Foo, Baz): ...`.
/// Which inference result is consulted depends on whether the class
/// declares type parameters.
fn infer_class_base_type<'db>(
    db: &'db dyn Db,
    base_node: &'db ast::Expr,
    class_definition: Definition<'db>,
    class_has_type_params: bool,
) -> Type<'db> {
    if !class_has_type_params {
        // Without type parameters, we can do the lookup based on the definition scope
        return definition_expression_ty(db, class_definition, base_node);
    }

    // when we have a specialized scope, we'll look up the inference
    // within that scope
    let file = class_definition.file(db);
    let model = SemanticModel::new(db, file);
    base_node.ty(&model)
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct InstanceType<'db> {
class: ClassType<'db>,
@ -2197,7 +2217,10 @@ mod tests {
use crate::program::{Program, SearchPathSettings};
use crate::python_version::PythonVersion;
use crate::ProgramSettings;
use ruff_db::files::system_path_to_file;
use ruff_db::parsed::parsed_module;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast as ast;
use test_case::test_case;
@ -2639,4 +2662,59 @@ mod tests {
let property_symbol_name = ast::name::Name::new_static("property");
assert!(!symbol_names.contains(&property_symbol_name));
}
/// Inferring the result of a call-expression shouldn't need to re-run after
/// a trivial change to the function's file (e.g. by adding a docstring to the function).
#[test]
fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> {
let mut db = setup_db();
// The callee module: defines `foo` with an `int` return annotation.
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
return 5
"#,
)?;
// The caller module: imports `foo` and binds the call result to `a`.
db.write_dedented(
"src/bar.py",
r#"
from foo import foo
a = foo()
"#,
)?;
// Look up `a` once before the edit and check that it is inferred as `int`.
let bar = system_path_to_file(&db, "src/bar.py")?;
let a = global_symbol(&db, bar, "a");
assert_eq!(a.expect_type(), KnownClass::Int.to_instance(&db));
// Add a docstring to `foo`, which changes the callee's file.
// Inference of the call site of `foo` in `bar.py` should not be re-run because of that.
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
"Computes a value"
return 5
"#,
)?;
// Discard the events recorded so far so that only the events produced by
// the lookup below are inspected.
db.clear_salsa_events();
let a = global_symbol(&db, bar, "a");
assert_eq!(a.expect_type(), KnownClass::Int.to_instance(&db));
let events = db.take_salsa_events();
// `body[1]` is the `a = foo()` assignment (the second statement of `bar.py`);
// its value is the call expression.
let call = &*parsed_module(&db, bar).syntax().body[1]
.as_assign_stmt()
.unwrap()
.value;
let foo_call = semantic_index(&db, bar).expression(call);
// The recorded events must show that `infer_expression_types` was not run for the call.
assert_function_query_was_not_run(&db, infer_expression_types, foo_call, &events);
Ok(())
}
}

View file

@ -444,6 +444,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
// TODO: Only call this function when diagnostics are enabled.
self.check_class_definitions();
}
@ -854,11 +855,17 @@ impl<'db> TypeInferenceBuilder<'db> {
}
_ => None,
};
let body_scope = self
.index
.node_scope(NodeWithScopeRef::Function(function))
.to_scope_id(self.db, self.file);
let function_ty = Type::FunctionLiteral(FunctionType::new(
self.db,
&*name.id,
function_kind,
definition,
body_scope,
decorator_tys,
));
@ -966,7 +973,6 @@ impl<'db> TypeInferenceBuilder<'db> {
let class_ty = Type::ClassLiteral(ClassType::new(
self.db,
&*name.id,
definition,
body_scope,
maybe_known_class,
));
@ -4684,7 +4690,7 @@ mod tests {
let function = global_symbol(&db, mod_file, "example")
.expect_type()
.expect_function_literal();
let returns = function.return_type(&db);
let returns = function.return_ty(&db);
assert_eq!(returns.display(&db).to_string(), "int");
Ok(())

View file

@ -5,10 +5,7 @@ use indexmap::IndexSet;
use itertools::Either;
use rustc_hash::FxHashSet;
use ruff_python_ast as ast;
use super::{infer_class_base_type, ClassType, KnownClass, Type};
use crate::semantic_index::definition::Definition;
use super::{ClassType, KnownClass, Type};
use crate::Db;
/// The inferred method resolution order of a given class.
@ -43,8 +40,7 @@ impl<'db> Mro<'db> {
}
fn of_class_impl(db: &'db dyn Db, class: ClassType<'db>) -> Result<Self, MroErrorKind<'db>> {
let class_stmt_node = class.node(db);
let class_bases = class_stmt_node.bases();
let class_bases = class.explicit_bases(db);
match class_bases {
// `builtins.object` is the special case:
@ -73,13 +69,8 @@ impl<'db> Mro<'db> {
// This *could* theoretically be handled by the final branch below,
// but it's a common case (i.e., worth optimizing for),
// and the `c3_merge` function requires lots of allocations.
[single_base_node] => {
let single_base = ClassBase::try_from_node(
db,
single_base_node,
class.definition(db),
class_stmt_node.type_params.is_some(),
);
[single_base] => {
let single_base = ClassBase::try_from_ty(*single_base).ok_or(*single_base);
single_base.map_or_else(
|invalid_base_ty| {
let bases_info = Box::from([(0, invalid_base_ty)]);
@ -109,13 +100,11 @@ impl<'db> Mro<'db> {
return Err(MroErrorKind::CyclicClassDefinition);
}
let definition = class.definition(db);
let has_type_params = class_stmt_node.type_params.is_some();
let mut valid_bases = vec![];
let mut invalid_bases = vec![];
for (i, base_node) in multiple_bases.iter().enumerate() {
match ClassBase::try_from_node(db, base_node, definition, has_type_params) {
for (i, base) in multiple_bases.iter().enumerate() {
match ClassBase::try_from_ty(*base).ok_or(*base) {
Ok(valid_base) => valid_bases.push(valid_base),
Err(invalid_base) => invalid_bases.push((i, invalid_base)),
}
@ -289,7 +278,7 @@ pub(super) enum MroErrorKind<'db> {
///
/// This variant records the indices and types of class bases
/// that we deem to be invalid. The indices are the indices of nodes
/// in the bases list of the class's [`ast::StmtClassDef`] node.
/// in the bases list of the class's [`StmtClassDef`](ruff_python_ast::StmtClassDef) node.
/// Each index is the index of a node representing an invalid base.
InvalidBases(Box<[(usize, Type<'db>)]>),
@ -304,7 +293,7 @@ pub(super) enum MroErrorKind<'db> {
///
/// This variant records the indices and [`ClassType`]s
/// of the duplicate bases. The indices are the indices of nodes
/// in the bases list of the class's [`ast::StmtClassDef`] node.
/// in the bases list of the class's [`StmtClassDef`](ruff_python_ast::StmtClassDef) node.
/// Each index is the index of a node representing a duplicate base.
DuplicateBases(Box<[(usize, ClassType<'db>)]>),
@ -365,20 +354,6 @@ impl<'db> ClassBase<'db> {
.map_or(Self::Unknown, Self::Class)
}
/// Attempt to resolve the class-base expression `base_node` into a `ClassBase`.
///
/// The type of `base_node` is inferred first. If that type is not an
/// acceptable class-base type, the inferred type is returned as the error.
fn try_from_node(
    db: &'db dyn Db,
    base_node: &'db ast::Expr,
    class_definition: Definition<'db>,
    class_has_type_params: bool,
) -> Result<Self, Type<'db>> {
    let inferred_ty =
        infer_class_base_type(db, base_node, class_definition, class_has_type_params);
    let resolved = Self::try_from_ty(inferred_ty);
    resolved.ok_or(inferred_ty)
}
/// Attempt to resolve `ty` into a `ClassBase`.
///
/// Return `None` if `ty` is not an acceptable type for a class base.
@ -496,6 +471,8 @@ fn class_is_cyclically_defined(db: &dyn Db, class: ClassType) -> bool {
}
for explicit_base_class in class
.explicit_bases(db)
.iter()
.copied()
.filter_map(Type::into_class_literal_type)
{
// Each base must be considered in isolation.
@ -513,6 +490,8 @@ fn class_is_cyclically_defined(db: &dyn Db, class: ClassType) -> bool {
class
.explicit_bases(db)
.iter()
.copied()
.filter_map(Type::into_class_literal_type)
.any(|base_class| is_cyclically_defined_recursive(db, base_class, &mut IndexSet::default()))
}

View file

@ -12,6 +12,18 @@ use crate::Db;
/// involved. It allows us to:
/// 1. Avoid doing structural match multiple times for each definition
/// 2. Avoid highlighting the same error multiple times
///
/// ## Module-local type
/// This type should not be used as part of any cross-module API because
/// it holds a reference to the AST node. Range-offset changes
/// then propagate through all usages, and deserialization requires
/// reparsing the entire module.
///
/// E.g. don't use this type in:
///
/// * a return type of a cross-module query
/// * a field of a type that is a return type of a cross-module query
/// * an argument of a cross-module query
#[salsa::tracked]
pub(crate) struct Unpack<'db> {
#[id]