[red-knot] per-definition inference, use-def maps (#12269)

Implements definition-level type inference, with basic control flow
(only if statements and if expressions so far) in Salsa.

There are a couple of key ideas here:

1) We can do type inference queries at any of three region
granularities: an entire scope, a single definition, or a single
expression. These are represented by the `InferenceRegion` enum, and the
entry points are the salsa queries `infer_scope_types`,
`infer_definition_types`, and `infer_expression_types`. Generally
per-scope will be used for scopes that we are directly checking and
per-definition will be used anytime we are looking up symbol types from
another module/scope. Per-expression should be uncommon: used only for
the RHS of an unpacking or multi-target assignment (to avoid
re-inferring the RHS once per symbol defined in the assignment) and for
test nodes in type narrowing (e.g. the `test` of an `If` node). All
three queries return a `TypeInference` with a map of types for all
definitions and expressions within their region. If you do e.g.
scope-level inference, when it hits a definition, or an
independently-inferable expression, it should use the relevant query
(which may already be cached) to get all types within the smaller
region. This avoids double-inferring smaller regions, even though larger
regions encompass smaller ones.

2) Instead of building a control-flow graph and lazily traversing it to
find definitions which reach a use of a name (which is O(n^2) in the
worst case), semantic indexing builds a use-def map, where every
use of a name knows which definitions can reach that use. We also no
longer track all definitions of a symbol in the symbol itself; instead
the use-def map also records which defs remain visible at the end of the
scope, and considers these the publicly-visible definitions of the
symbol (see below).
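
A minimal sketch of the use-def idea in point 2, assuming a single scope
containing only assignments and one `if` with no `else`. Every name here
(`UseDefBuilder`, `record_definition`, `record_use`, `snapshot`, `merge`)
is hypothetical and chosen for illustration; this is not the code in this
PR, and it does not model "possibly unbound" names.

```rust
use std::collections::HashMap;

// Hypothetical plain-integer ids, used only for this sketch.
type SymbolId = usize;
type DefId = usize;
type UseId = usize;

/// For each use, the definitions that can reach it; for each symbol,
/// the definitions still visible at the end of the scope.
#[derive(Debug)]
struct UseDefMap {
    definitions_by_use: Vec<Vec<DefId>>,
    public_definitions: HashMap<SymbolId, Vec<DefId>>,
}

/// Saved copy of the live definitions, taken before entering a branch.
type Snapshot = HashMap<SymbolId, Vec<DefId>>;

/// Builds the map in one forward pass over the scope.
#[derive(Default)]
struct UseDefBuilder {
    /// Definitions of each symbol that are visible at the current point.
    live: HashMap<SymbolId, Vec<DefId>>,
    definitions_by_use: Vec<Vec<DefId>>,
    next_def: DefId,
}

impl UseDefBuilder {
    /// A new binding shadows every earlier definition of the symbol.
    fn record_definition(&mut self, symbol: SymbolId) -> DefId {
        let def = self.next_def;
        self.next_def += 1;
        self.live.insert(symbol, vec![def]);
        def
    }

    /// A use sees exactly the definitions that are live at this point.
    fn record_use(&mut self, symbol: SymbolId) -> UseId {
        let reaching = self.live.get(&symbol).cloned().unwrap_or_default();
        self.definitions_by_use.push(reaching);
        self.definitions_by_use.len() - 1
    }

    fn snapshot(&self) -> Snapshot {
        self.live.clone()
    }

    /// Join point after a branch: union the saved state into the current one.
    fn merge(&mut self, other: Snapshot) {
        for (symbol, defs) in other {
            let live = self.live.entry(symbol).or_default();
            for def in defs {
                if !live.contains(&def) {
                    live.push(def);
                }
            }
            live.sort_unstable();
        }
    }

    /// Whatever is still live at the end of the scope becomes public.
    fn finish(self) -> UseDefMap {
        UseDefMap {
            definitions_by_use: self.definitions_by_use,
            public_definitions: self.live,
        }
    }
}

/// Models: `x = 1`, then `if flag: x = 2`, then a use of `x`.
fn demo() -> (Vec<DefId>, Vec<DefId>) {
    let x: SymbolId = 0;
    let mut builder = UseDefBuilder::default();
    builder.record_definition(x); // x = 1 (definition 0)
    let before_if = builder.snapshot();
    builder.record_definition(x); // x = 2 inside the `if` (definition 1)
    builder.merge(before_if); // join "branch taken" with "branch skipped"
    let use_x = builder.record_use(x);
    let map = builder.finish();
    (
        map.definitions_by_use[use_x].clone(),
        map.public_definitions[&x].clone(),
    )
}

fn main() {
    let (reaching, public) = demo();
    println!("definitions reaching the use: {reaching:?}");
    println!("public definitions of x: {public:?}");
}
```

Each use is resolved when it is recorded, so answering "which definitions
reach this use" later is a plain index lookup rather than a graph
traversal.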

Major items left as TODOs in this PR, to be done in follow-up PRs:

1) Free/global references aren't supported yet (only lookup based on
definitions in current scope), which means the override-check example
doesn't currently work. This is the first thing I'll fix as follow-up to
this PR.

2) Control flow other than if statements and if expressions.

3) Type narrowing.

There are also some smaller relevant changes here:

1) Eliminate `Option` in the return type of member lookups; instead
always return `Type::Unbound` for a name we can't find. Also use
`Type::Unbound` for modules we can't resolve (not 100% sure about this
one yet.)

2) Eliminate the use of the terms "public" and "root" to refer to
module-global scope or symbols. Instead consistently use the term
"module-global". It's longer, but it's the clearest, and the most
consistent with typical Python terminology. In particular I don't like
"public" for this use because it has other implications around author
intent (is an underscore-prefixed module-global symbol "public"?). And
"root" is just not commonly used for this in Python.

3) Eliminate the `PublicSymbol` Salsa ingredient. Many non-module-global
symbols can also be seen from other scopes (e.g. by a free var in a
nested scope, or by class attribute access), and thus need to have a
"public type" (that is, the type not as seen from a particular use in
the control flow of the same scope, but the type as seen from some other
scope.) So all symbols need to have a "public type" (here I want to keep
the use of the term "public", unless someone has a better term to
suggest -- since it's "public type of a symbol" and not "public symbol"
the confusion with e.g. initial underscores is less of an issue.) At
least initially, I would like to try not having special handling for
module-global symbols vs other symbols.

4) Switch to using "definitions that reach end of scope" rather than
"all definitions" in determining the public type of a symbol. I'm
convinced that in general this is the right way to go. We may want to
refine this further in future for some free-variable cases, but it can
be changed purely by making changes to the building of the use-def map
(the `public_definitions` index in it), without affecting any other
code. One consequence of combining this with no control-flow support
(just last-definition-wins) is that some inference tests now give more
wrong-looking results; I left TODO comments on these tests to fix them
when control flow is added.
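
To illustrate item 1 above (dropping `Option` from member lookups), here
is a small sketch of a lookup that always returns a type. The `Type`
enum and `ClassType` struct below are simplified stand-ins invented for
this example; only the `Unbound` variant name comes from this PR.

```rust
use std::collections::HashMap;

/// Simplified stand-in for a type lattice; only `Unbound` mirrors the PR.
#[derive(Debug, Clone, PartialEq)]
enum Type {
    Unbound,
    Int,
    Str,
}

/// Hypothetical class type holding its members by name.
struct ClassType {
    members: HashMap<String, Type>,
}

impl ClassType {
    /// Total lookup: a missing name yields `Type::Unbound`, never `None`.
    fn member(&self, name: &str) -> Type {
        self.members.get(name).cloned().unwrap_or(Type::Unbound)
    }
}

/// Builds a class with two members, for the example below.
fn sample_class() -> ClassType {
    let mut members = HashMap::new();
    members.insert("count".to_string(), Type::Int);
    members.insert("label".to_string(), Type::Str);
    ClassType { members }
}

fn main() {
    let class = sample_class();
    // Callers handle one return type; "not found" is just another `Type`.
    println!("{:?}", class.member("count"));
    println!("{:?}", class.member("label"));
    println!("{:?}", class.member("missing"));
}
```

With this shape, callers do not need to unwrap anything; an unknown
name flows through later inference as an ordinary type value.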

And some potential areas for consideration in the future:

1) Should `symbol_ty` be a Salsa query? This would require making all
symbols a Salsa ingredient, and tracking even more dependencies. But it
would save some repeated reconstruction of unions, for symbols with
multiple public definitions. For now I'm not making it a query, but open
to changing this in future with actual perf evidence that it's better.
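
The trade-off in item 1 can be sketched roughly as follows, assuming
per-definition inference is memoized while the symbol-level lookup is
not. The `Db` struct, its hand-rolled memo table, and the counters are
hypothetical stand-ins for Salsa's caching; `symbol_ty` here is a toy
with the same name, not the function in this PR.

```rust
use std::cell::{Cell, RefCell};
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Unbound,
    Int,
    Str,
    Union(Vec<Type>),
}

type DefId = usize;

/// Hypothetical database; the memo table stands in for a cached query.
struct Db {
    /// Pretend result of inferring each definition.
    inferred: HashMap<DefId, Type>,
    cache: RefCell<HashMap<DefId, Type>>,
    inferences_run: Cell<usize>,
    unions_built: Cell<usize>,
}

impl Db {
    fn new(inferred: HashMap<DefId, Type>) -> Self {
        Db {
            inferred,
            cache: RefCell::new(HashMap::new()),
            inferences_run: Cell::new(0),
            unions_built: Cell::new(0),
        }
    }

    /// Memoized: each definition is inferred at most once.
    fn definition_ty(&self, def: DefId) -> Type {
        let cached = self.cache.borrow().get(&def).cloned();
        if let Some(ty) = cached {
            return ty;
        }
        self.inferences_run.set(self.inferences_run.get() + 1);
        let ty = self.inferred.get(&def).cloned().unwrap_or(Type::Unbound);
        self.cache.borrow_mut().insert(def, ty.clone());
        ty
    }

    /// Not memoized: the union is rebuilt on every call.
    fn symbol_ty(&self, public_defs: &[DefId]) -> Type {
        let mut members: Vec<Type> = Vec::new();
        for &def in public_defs {
            let ty = self.definition_ty(def);
            if !members.contains(&ty) {
                members.push(ty);
            }
        }
        match members.len() {
            0 => Type::Unbound,
            1 => members.remove(0),
            _ => {
                self.unions_built.set(self.unions_built.get() + 1);
                Type::Union(members)
            }
        }
    }
}

/// Two public definitions of one symbol: one `Int`, one `Str`.
fn sample_db() -> Db {
    let mut inferred = HashMap::new();
    inferred.insert(0, Type::Int);
    inferred.insert(1, Type::Str);
    Db::new(inferred)
}

fn main() {
    let db = sample_db();
    let first = db.symbol_ty(&[0, 1]);
    let second = db.symbol_ty(&[0, 1]);
    println!("{first:?} {second:?}");
    println!(
        "inferences run: {}, unions built: {}",
        db.inferences_run.get(),
        db.unions_built.get()
    );
}
```

Looking the symbol up twice infers each definition only once but builds
the union twice. That repeated union construction is the cost a cached
`symbol_ty` would avoid, in exchange for more tracked dependencies.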
Carl Meyer 2024-07-16 11:02:30 -07:00 committed by GitHub
parent 30cef67b45
commit 595b1aa4a1
17 changed files with 1488 additions and 815 deletions


@ -10,17 +10,20 @@ use ruff_index::{IndexSlice, IndexVec};
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIds;
use crate::semantic_index::builder::SemanticIndexBuilder;
use crate::semantic_index::definition::{Definition, DefinitionNodeKey, DefinitionNodeRef};
use crate::semantic_index::definition::{Definition, DefinitionNodeKey};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, PublicSymbolId, Scope, ScopeId,
ScopedSymbolId, SymbolTable,
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolTable,
};
use crate::semantic_index::use_def::UseDefMap;
use crate::Db;
pub mod ast_ids;
mod builder;
pub mod definition;
pub mod expression;
pub mod symbol;
pub mod use_def;
type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;
@ -42,57 +45,63 @@ pub(crate) fn semantic_index(db: &dyn Db, file: File) -> SemanticIndex<'_> {
/// Salsa can avoid invalidating dependent queries if this scope's symbol table
/// is unchanged.
#[salsa::tracked]
pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<SymbolTable<'db>> {
pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<SymbolTable> {
let _span = tracing::trace_span!("symbol_table", ?scope).entered();
let index = semantic_index(db, scope.file(db));
index.symbol_table(scope.file_scope_id(db))
}
/// Returns the root scope of `file`.
/// Returns the use-def map for a specific `scope`.
///
/// Using [`use_def_map`] over [`semantic_index`] has the advantage that
/// Salsa can avoid invalidating dependent queries if this scope's use-def map
/// is unchanged.
#[salsa::tracked]
pub(crate) fn root_scope(db: &dyn Db, file: File) -> ScopeId<'_> {
let _span = tracing::trace_span!("root_scope", ?file).entered();
pub(crate) fn use_def_map<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<UseDefMap<'db>> {
let _span = tracing::trace_span!("use_def_map", ?scope).entered();
let index = semantic_index(db, scope.file(db));
FileScopeId::root().to_scope_id(db, file)
index.use_def_map(scope.file_scope_id(db))
}
/// Returns the symbol with the given name in `file`'s public scope or `None` if
/// no symbol with the given name exists.
pub(crate) fn public_symbol<'db>(
db: &'db dyn Db,
file: File,
name: &str,
) -> Option<PublicSymbolId<'db>> {
let root_scope = root_scope(db, file);
let symbol_table = symbol_table(db, root_scope);
let local = symbol_table.symbol_id_by_name(name)?;
Some(local.to_public_symbol(db, file))
/// Returns the module global scope of `file`.
#[salsa::tracked]
pub(crate) fn module_global_scope(db: &dyn Db, file: File) -> ScopeId<'_> {
let _span = tracing::trace_span!("module_global_scope", ?file).entered();
FileScopeId::module_global().to_scope_id(db, file)
}
/// The symbol tables for an entire file.
/// The symbol tables and use-def maps for all scopes in a file.
#[derive(Debug)]
pub(crate) struct SemanticIndex<'db> {
/// List of all symbol tables in this file, indexed by scope.
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable<'db>>>,
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable>>,
/// List of all scopes in this file.
scopes: IndexVec<FileScopeId, Scope>,
/// Maps expressions to their corresponding scope.
/// Map expressions to their corresponding scope.
/// We can't use [`ExpressionId`] here, because the challenge is how to get from
/// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope).
scopes_by_expression: FxHashMap<ExpressionNodeKey, FileScopeId>,
/// Maps from a node creating a definition node to its definition.
/// Map from a node creating a definition to its definition.
definitions_by_node: FxHashMap<DefinitionNodeKey, Definition<'db>>,
/// Map from a standalone expression to its [`Expression`] ingredient.
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
/// Map from nodes that create a scope to the scope they create.
scopes_by_node: FxHashMap<NodeWithScopeKey, FileScopeId>,
/// Map from the file-local [`FileScopeId`] to the salsa-ingredient [`ScopeId`].
scope_ids_by_scope: IndexVec<FileScopeId, ScopeId<'db>>,
/// Use-def map for each scope in this file.
use_def_maps: IndexVec<FileScopeId, Arc<UseDefMap<'db>>>,
/// Lookup table to map between node ids and ast nodes.
///
/// Note: We should not depend on this map when analysing other files or
@ -105,10 +114,18 @@ impl<'db> SemanticIndex<'db> {
///
/// Use the Salsa cached [`symbol_table`] query if you only need the
/// symbol table for a single scope.
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable<'db>> {
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
self.symbol_tables[scope_id].clone()
}
/// Returns the use-def map for a specific scope.
///
/// Use the Salsa cached [`use_def_map`] query if you only need the
/// use-def map for a single scope.
pub(super) fn use_def_map(&self, scope_id: FileScopeId) -> Arc<UseDefMap> {
self.use_def_maps[scope_id].clone()
}
pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds {
&self.ast_ids[scope_id]
}
@ -157,16 +174,28 @@ impl<'db> SemanticIndex<'db> {
}
/// Returns an iterator over all ancestors of `scope`, starting with `scope` itself.
#[allow(unused)]
pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter {
AncestorsIter::new(self, scope)
}
/// Returns the [`Definition`] salsa ingredient for `definition_node`.
pub(crate) fn definition<'def>(
/// Returns the [`Definition`] salsa ingredient for `definition_key`.
pub(crate) fn definition(
&self,
definition_node: impl Into<DefinitionNodeRef<'def>>,
definition_key: impl Into<DefinitionNodeKey>,
) -> Definition<'db> {
self.definitions_by_node[&definition_node.into().key()]
self.definitions_by_node[&definition_key.into()]
}
/// Returns the [`Expression`] ingredient for an expression node.
/// Panics if we have no expression ingredient for that node. We can only call this method for
/// standalone-inferable expressions, which we call `add_standalone_expression` for in
/// [`SemanticIndexBuilder`].
pub(crate) fn expression(
&self,
expression_key: impl Into<ExpressionNodeKey>,
) -> Expression<'db> {
self.expressions_by_node[&expression_key.into()]
}
/// Returns the id of the scope that `node` creates. This is different from [`Definition::scope`] which
@ -176,8 +205,6 @@ impl<'db> SemanticIndex<'db> {
}
}
/// ID that uniquely identifies an expression inside a [`Scope`].
pub struct AncestorsIter<'a> {
scopes: &'a IndexSlice<FileScopeId, Scope>,
next_id: Option<FileScopeId>,
@ -278,7 +305,7 @@ mod tests {
use crate::db::tests::TestDb;
use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable};
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::semantic_index::{module_global_scope, semantic_index, symbol_table, use_def_map};
use crate::Db;
struct TestCase {
@ -305,95 +332,110 @@ mod tests {
#[test]
fn empty() {
let TestCase { db, file } = test_case("");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
let root_names = names(&root_table);
let module_global_names = names(&module_global_table);
assert_eq!(root_names, Vec::<&str>::new());
assert_eq!(module_global_names, Vec::<&str>::new());
}
#[test]
fn simple() {
let TestCase { db, file } = test_case("x");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["x"]);
assert_eq!(names(&module_global_table), vec!["x"]);
}
#[test]
fn annotation_only() {
let TestCase { db, file } = test_case("x: int");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["int", "x"]);
assert_eq!(names(&module_global_table), vec!["int", "x"]);
// TODO record definition
}
#[test]
fn import() {
let TestCase { db, file } = test_case("import foo");
let root_table = symbol_table(&db, root_scope(&db, file));
let scope = module_global_scope(&db, file);
let module_global_table = symbol_table(&db, scope);
assert_eq!(names(&root_table), vec!["foo"]);
let foo = root_table.symbol_by_name("foo").unwrap();
assert_eq!(names(&module_global_table), vec!["foo"]);
let foo = module_global_table.symbol_id_by_name("foo").unwrap();
assert_eq!(foo.definitions().len(), 1);
let use_def = use_def_map(&db, scope);
assert_eq!(use_def.public_definitions(foo).len(), 1);
}
#[test]
fn import_sub() {
let TestCase { db, file } = test_case("import foo.bar");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["foo"]);
assert_eq!(names(&module_global_table), vec!["foo"]);
}
#[test]
fn import_as() {
let TestCase { db, file } = test_case("import foo.bar as baz");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["baz"]);
assert_eq!(names(&module_global_table), vec!["baz"]);
}
#[test]
fn import_from() {
let TestCase { db, file } = test_case("from bar import foo");
let root_table = symbol_table(&db, root_scope(&db, file));
let scope = module_global_scope(&db, file);
let module_global_table = symbol_table(&db, scope);
assert_eq!(names(&root_table), vec!["foo"]);
assert_eq!(
root_table
assert_eq!(names(&module_global_table), vec!["foo"]);
assert!(
module_global_table
.symbol_by_name("foo")
.unwrap()
.definitions()
.is_some_and(|symbol| { symbol.is_defined() && !symbol.is_used() }),
"symbols that are defined get the defined flag"
);
let use_def = use_def_map(&db, scope);
assert_eq!(
use_def
.public_definitions(
module_global_table
.symbol_id_by_name("foo")
.expect("symbol exists")
)
.len(),
1
);
assert!(
root_table
.symbol_by_name("foo")
.is_some_and(|symbol| { symbol.is_defined() || !symbol.is_used() }),
"symbols that are defined get the defined flag"
);
}
#[test]
fn assign() {
let TestCase { db, file } = test_case("x = foo");
let root_table = symbol_table(&db, root_scope(&db, file));
let scope = module_global_scope(&db, file);
let module_global_table = symbol_table(&db, scope);
assert_eq!(names(&root_table), vec!["foo", "x"]);
assert_eq!(
root_table.symbol_by_name("x").unwrap().definitions().len(),
1
);
assert_eq!(names(&module_global_table), vec!["foo", "x"]);
assert!(
root_table
module_global_table
.symbol_by_name("foo")
.is_some_and(|symbol| { !symbol.is_defined() && symbol.is_used() }),
"a symbol used but not defined in a scope should have only the used flag"
);
let use_def = use_def_map(&db, scope);
assert_eq!(
use_def
.public_definitions(
module_global_table
.symbol_id_by_name("x")
.expect("symbol exists")
)
.len(),
1
);
}
#[test]
@ -405,13 +447,13 @@ class C:
y = 2
",
);
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["C", "y"]);
assert_eq!(names(&module_global_table), vec!["C", "y"]);
let index = semantic_index(&db, file);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect();
assert_eq!(scopes.len(), 1);
let (class_scope_id, class_scope) = scopes[0];
@ -421,8 +463,12 @@ y = 2
let class_table = index.symbol_table(class_scope_id);
assert_eq!(names(&class_table), vec!["x"]);
let use_def = index.use_def_map(class_scope_id);
assert_eq!(
class_table.symbol_by_name("x").unwrap().definitions().len(),
use_def
.public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists"))
.len(),
1
);
}
@ -437,11 +483,13 @@ y = 2
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let module_global_table = index.symbol_table(FileScopeId::module_global());
assert_eq!(names(&root_table), vec!["func", "y"]);
assert_eq!(names(&module_global_table), vec!["func", "y"]);
let scopes = index.child_scopes(FileScopeId::root()).collect::<Vec<_>>();
let scopes = index
.child_scopes(FileScopeId::module_global())
.collect::<Vec<_>>();
assert_eq!(scopes.len(), 1);
let (function_scope_id, function_scope) = scopes[0];
@ -450,11 +498,15 @@ y = 2
let function_table = index.symbol_table(function_scope_id);
assert_eq!(names(&function_table), vec!["x"]);
let use_def = index.use_def_map(function_scope_id);
assert_eq!(
function_table
.symbol_by_name("x")
.unwrap()
.definitions()
use_def
.public_definitions(
function_table
.symbol_id_by_name("x")
.expect("symbol exists")
)
.len(),
1
);
@ -471,10 +523,10 @@ def func():
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let module_global_table = index.symbol_table(FileScopeId::module_global());
assert_eq!(names(&root_table), vec!["func"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(names(&module_global_table), vec!["func"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect();
assert_eq!(scopes.len(), 2);
let (func_scope1_id, func_scope_1) = scopes[0];
@ -490,13 +542,17 @@ def func():
let func2_table = index.symbol_table(func_scope2_id);
assert_eq!(names(&func1_table), vec!["x"]);
assert_eq!(names(&func2_table), vec!["y"]);
let use_def = index.use_def_map(FileScopeId::module_global());
assert_eq!(
root_table
.symbol_by_name("func")
.unwrap()
.definitions()
use_def
.public_definitions(
module_global_table
.symbol_id_by_name("func")
.expect("symbol exists")
)
.len(),
2
1
);
}
@ -510,11 +566,11 @@ def func[T]():
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let module_global_table = index.symbol_table(FileScopeId::module_global());
assert_eq!(names(&root_table), vec!["func"]);
assert_eq!(names(&module_global_table), vec!["func"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect();
assert_eq!(scopes.len(), 1);
let (ann_scope_id, ann_scope) = scopes[0];
@ -542,11 +598,11 @@ class C[T]:
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let module_global_table = index.symbol_table(FileScopeId::module_global());
assert_eq!(names(&root_table), vec!["C"]);
assert_eq!(names(&module_global_table), vec!["C"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect();
assert_eq!(scopes.len(), 1);
let (ann_scope_id, ann_scope) = scopes[0];
@ -578,7 +634,7 @@ class C[T]:
// let index = SemanticIndex::from_ast(ast);
// let table = &index.symbol_table;
// let x_sym = table
// .root_symbol_id_by_name("x")
// .module_global_symbol_id_by_name("x")
// .expect("x symbol should exist");
// let ast::Stmt::Expr(ast::StmtExpr { value: x_use, .. }) = &ast.body[1] else {
// panic!("should be an expr")
@ -616,7 +672,7 @@ class C[T]:
let x = &x_stmt.targets[0];
assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module);
assert_eq!(index.expression_scope_id(x), FileScopeId::root());
assert_eq!(index.expression_scope_id(x), FileScopeId::module_global());
let def = ast.body[1].as_function_def_stmt().unwrap();
let y_stmt = def.body[0].as_assign_stmt().unwrap();
@ -653,16 +709,20 @@ def x():
let index = semantic_index(&db, file);
let descendents = index.descendent_scopes(FileScopeId::root());
let descendents = index.descendent_scopes(FileScopeId::module_global());
assert_eq!(
scope_names(descendents, &db, file),
vec!["Test", "foo", "bar", "baz", "x"]
);
let children = index.child_scopes(FileScopeId::root());
let children = index.child_scopes(FileScopeId::module_global());
assert_eq!(scope_names(children, &db, file), vec!["Test", "x"]);
let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0;
let test_class = index
.child_scopes(FileScopeId::module_global())
.next()
.unwrap()
.0;
let test_child_scopes = index.child_scopes(test_class);
assert_eq!(
scope_names(test_child_scopes, &db, file),
@ -670,7 +730,7 @@ def x():
);
let bar_scope = index
.descendent_scopes(FileScopeId::root())
.descendent_scopes(FileScopeId::module_global())
.nth(2)
.unwrap()
.0;