[red-knot] per-definition inference, use-def maps (#12269)

Implements definition-level type inference, with basic control flow
(only if statements and if expressions so far) in Salsa.

There are a couple key ideas here:

1) We can do type inference queries at any of three region
granularities: an entire scope, a single definition, or a single
expression. These are represented by the `InferenceRegion` enum, and the
entry points are the salsa queries `infer_scope_types`,
`infer_definition_types`, and `infer_expression_types`. Generally
per-scope will be used for scopes that we are directly checking and
per-definition will be used anytime we are looking up symbol types from
another module/scope. Per-expression should be uncommon: used only for
the RHS of an unpacking or multi-target assignment (to avoid
re-inferring the RHS once per symbol defined in the assignment) and for
test nodes in type narrowing (e.g. the `test` of an `If` node). All
three queries return a `TypeInference` with a map of types for all
definitions and expressions within their region. If you do e.g.
scope-level inference, when it hits a definition, or an
independently-inferable expression, it should use the relevant query
(which may already be cached) to get all types within the smaller
region. This avoids double-inferring smaller regions, even though larger
regions encompass smaller ones.

2) Instead of building a control-flow graph and lazily traversing it to
find definitions which reach a use of a name (which is O(n^2) in the
worst case), instead semantic indexing builds a use-def map, where every
use of a name knows which definitions can reach that use. We also no
longer track all definitions of a symbol in the symbol itself; instead
the use-def map also records which defs remain visible at the end of the
scope, and considers these the publicly-visible definitions of the
symbol (see below).

Major items left as TODOs in this PR, to be done in follow-up PRs:

1) Free/global references aren't supported yet (only lookup based on
definitions in current scope), which means the override-check example
doesn't currently work. This is the first thing I'll fix as follow-up to
this PR.

2) Control flow outside of if statements and expressions.

3) Type narrowing.

There are also some smaller relevant changes here:

1) Eliminate `Option` in the return type of member lookups; instead
always return `Type::Unbound` for a name we can't find. Also use
`Type::Unbound` for modules we can't resolve (not 100% sure about this
one yet.)

2) Eliminate the use of the terms "public" and "root" to refer to
module-global scope or symbols. Instead consistently use the term
"module-global". It's longer, but it's the clearest, and the most
consistent with typical Python terminology. In particular I don't like
"public" for this use because it has other implications around author
intent (is an underscore-prefixed module-global symbol "public"?). And
"root" is just not commonly used for this in Python.

3) Eliminate the `PublicSymbol` Salsa ingredient. Many non-module-global
symbols can also be seen from other scopes (e.g. by a free var in a
nested scope, or by class attribute access), and thus need to have a
"public type" (that is, the type not as seen from a particular use in
the control flow of the same scope, but the type as seen from some other
scope.) So all symbols need to have a "public type" (here I want to keep
the use of the term "public", unless someone has a better term to
suggest -- since it's "public type of a symbol" and not "public symbol"
the confusion with e.g. initial underscores is less of an issue.) At
least initially, I would like to try not having special handling for
module-global symbols vs other symbols.

4) Switch to using "definitions that reach end of scope" rather than
"all definitions" in determining the public type of a symbol. I'm
convinced that in general this is the right way to go. We may want to
refine this further in future for some free-variable cases, but it can
be changed purely by making changes to the building of the use-def map
(the `public_definitions` index in it), without affecting any other
code. One consequence of combining this with no control-flow support
(just last-definition-wins) is that some inference tests now give more
wrong-looking results; I left TODO comments on these tests to fix them
when control flow is added.

And some potential areas for consideration in the future:

1) Should `symbol_ty` be a Salsa query? This would require making all
symbols a Salsa ingredient, and tracking even more dependencies. But it
would save some repeated reconstruction of unions, for symbols with
multiple public definitions. For now I'm not making it a query, but open
to changing this in future with actual perf evidence that it's better.
This commit is contained in:
Carl Meyer 2024-07-16 11:02:30 -07:00 committed by GitHub
parent 30cef67b45
commit 595b1aa4a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 1488 additions and 815 deletions

View file

@ -15,6 +15,7 @@ red_knot_module_resolver = { workspace = true }
ruff_db = { workspace = true }
ruff_index = { workspace = true }
ruff_python_ast = { workspace = true }
ruff_python_trivia = { workspace = true }
ruff_text_size = { workspace = true }
bitflags = { workspace = true }

View file

@ -27,12 +27,13 @@ pub struct AstNodeRef<T> {
#[allow(unsafe_code)]
impl<T> AstNodeRef<T> {
/// Creates a new `AstNodeRef` that reference `node`. The `parsed` is the [`ParsedModule`] to which
/// the `AstNodeRef` belongs.
/// Creates a new `AstNodeRef` that reference `node`. The `parsed` is the [`ParsedModule`] to
/// which the `AstNodeRef` belongs.
///
/// ## Safety
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the [`ParsedModule`] to
/// which `node` belongs. It's the caller's responsibility to ensure that the invariant `node belongs to parsed` is upheld.
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
/// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
/// the invariant `node belongs to parsed` is upheld.
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
Self {
@ -43,8 +44,8 @@ impl<T> AstNodeRef<T> {
/// Returns a reference to the wrapped node.
pub fn node(&self) -> &T {
// SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still alive
// and not moved.
// SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still
// alive and not moved.
unsafe { self.node.as_ref() }
}
}

View file

@ -4,27 +4,30 @@ use red_knot_module_resolver::Db as ResolverDb;
use ruff_db::{Db as SourceDb, Upcast};
use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::{public_symbols_map, PublicSymbolId, ScopeId};
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::ScopeId;
use crate::semantic_index::{module_global_scope, semantic_index, symbol_table, use_def_map};
use crate::types::{
infer_types, public_symbol_ty, ClassType, FunctionType, IntersectionType, UnionType,
infer_definition_types, infer_expression_types, infer_scope_types, ClassType, FunctionType,
IntersectionType, UnionType,
};
#[salsa::jar(db=Db)]
pub struct Jar(
ScopeId<'_>,
PublicSymbolId<'_>,
Definition<'_>,
Expression<'_>,
FunctionType<'_>,
ClassType<'_>,
UnionType<'_>,
IntersectionType<'_>,
symbol_table,
root_scope,
use_def_map,
module_global_scope,
semantic_index,
infer_types,
public_symbol_ty,
public_symbols_map,
infer_definition_types,
infer_expression_types,
infer_scope_types,
);
/// Database giving access to semantic information about a Python program.
@ -44,6 +47,7 @@ pub(crate) mod tests {
use ruff_db::system::{DbWithTestSystem, System, TestSystem};
use ruff_db::vendored::VendoredFileSystem;
use ruff_db::{Db as SourceDb, Jar as SourceJar, Upcast};
use ruff_python_trivia::textwrap;
use super::{Db, Jar};
@ -85,6 +89,12 @@ pub(crate) mod tests {
pub(crate) fn clear_salsa_events(&mut self) {
self.take_salsa_events();
}
/// Write auto-dedented text to a file.
pub(crate) fn write_dedented(&mut self, path: &str, content: &str) -> anyhow::Result<()> {
self.write_file(path, textwrap::dedent(content))?;
Ok(())
}
}
impl DbWithTestSystem for TestDb {

View file

@ -10,17 +10,20 @@ use ruff_index::{IndexSlice, IndexVec};
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIds;
use crate::semantic_index::builder::SemanticIndexBuilder;
use crate::semantic_index::definition::{Definition, DefinitionNodeKey, DefinitionNodeRef};
use crate::semantic_index::definition::{Definition, DefinitionNodeKey};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, PublicSymbolId, Scope, ScopeId,
ScopedSymbolId, SymbolTable,
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolTable,
};
use crate::semantic_index::use_def::UseDefMap;
use crate::Db;
pub mod ast_ids;
mod builder;
pub mod definition;
pub mod expression;
pub mod symbol;
pub mod use_def;
type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;
@ -42,57 +45,63 @@ pub(crate) fn semantic_index(db: &dyn Db, file: File) -> SemanticIndex<'_> {
/// Salsa can avoid invalidating dependent queries if this scope's symbol table
/// is unchanged.
#[salsa::tracked]
pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<SymbolTable<'db>> {
pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<SymbolTable> {
let _span = tracing::trace_span!("symbol_table", ?scope).entered();
let index = semantic_index(db, scope.file(db));
index.symbol_table(scope.file_scope_id(db))
}
/// Returns the root scope of `file`.
/// Returns the use-def map for a specific `scope`.
///
/// Using [`use_def_map`] over [`semantic_index`] has the advantage that
/// Salsa can avoid invalidating dependent queries if this scope's use-def map
/// is unchanged.
#[salsa::tracked]
pub(crate) fn root_scope(db: &dyn Db, file: File) -> ScopeId<'_> {
let _span = tracing::trace_span!("root_scope", ?file).entered();
pub(crate) fn use_def_map<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<UseDefMap<'db>> {
let _span = tracing::trace_span!("use_def_map", ?scope).entered();
let index = semantic_index(db, scope.file(db));
FileScopeId::root().to_scope_id(db, file)
index.use_def_map(scope.file_scope_id(db))
}
/// Returns the symbol with the given name in `file`'s public scope or `None` if
/// no symbol with the given name exists.
pub(crate) fn public_symbol<'db>(
db: &'db dyn Db,
file: File,
name: &str,
) -> Option<PublicSymbolId<'db>> {
let root_scope = root_scope(db, file);
let symbol_table = symbol_table(db, root_scope);
let local = symbol_table.symbol_id_by_name(name)?;
Some(local.to_public_symbol(db, file))
/// Returns the module global scope of `file`.
#[salsa::tracked]
pub(crate) fn module_global_scope(db: &dyn Db, file: File) -> ScopeId<'_> {
let _span = tracing::trace_span!("module_global_scope", ?file).entered();
FileScopeId::module_global().to_scope_id(db, file)
}
/// The symbol tables for an entire file.
/// The symbol tables and use-def maps for all scopes in a file.
#[derive(Debug)]
pub(crate) struct SemanticIndex<'db> {
/// List of all symbol tables in this file, indexed by scope.
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable<'db>>>,
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable>>,
/// List of all scopes in this file.
scopes: IndexVec<FileScopeId, Scope>,
/// Maps expressions to their corresponding scope.
/// Map expressions to their corresponding scope.
/// We can't use [`ExpressionId`] here, because the challenge is how to get from
/// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope).
scopes_by_expression: FxHashMap<ExpressionNodeKey, FileScopeId>,
/// Maps from a node creating a definition node to its definition.
/// Map from a node creating a definition to its definition.
definitions_by_node: FxHashMap<DefinitionNodeKey, Definition<'db>>,
/// Map from a standalone expression to its [`Expression`] ingredient.
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
/// Map from nodes that create a scope to the scope they create.
scopes_by_node: FxHashMap<NodeWithScopeKey, FileScopeId>,
/// Map from the file-local [`FileScopeId`] to the salsa-ingredient [`ScopeId`].
scope_ids_by_scope: IndexVec<FileScopeId, ScopeId<'db>>,
/// Use-def map for each scope in this file.
use_def_maps: IndexVec<FileScopeId, Arc<UseDefMap<'db>>>,
/// Lookup table to map between node ids and ast nodes.
///
/// Note: We should not depend on this map when analysing other files or
@ -105,10 +114,18 @@ impl<'db> SemanticIndex<'db> {
///
/// Use the Salsa cached [`symbol_table`] query if you only need the
/// symbol table for a single scope.
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable<'db>> {
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
self.symbol_tables[scope_id].clone()
}
/// Returns the use-def map for a specific scope.
///
/// Use the Salsa cached [`use_def_map`] query if you only need the
/// use-def map for a single scope.
pub(super) fn use_def_map(&self, scope_id: FileScopeId) -> Arc<UseDefMap> {
self.use_def_maps[scope_id].clone()
}
pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds {
&self.ast_ids[scope_id]
}
@ -157,16 +174,28 @@ impl<'db> SemanticIndex<'db> {
}
/// Returns an iterator over all ancestors of `scope`, starting with `scope` itself.
#[allow(unused)]
pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter {
AncestorsIter::new(self, scope)
}
/// Returns the [`Definition`] salsa ingredient for `definition_node`.
pub(crate) fn definition<'def>(
/// Returns the [`Definition`] salsa ingredient for `definition_key`.
pub(crate) fn definition(
&self,
definition_node: impl Into<DefinitionNodeRef<'def>>,
definition_key: impl Into<DefinitionNodeKey>,
) -> Definition<'db> {
self.definitions_by_node[&definition_node.into().key()]
self.definitions_by_node[&definition_key.into()]
}
/// Returns the [`Expression`] ingredient for an expression node.
/// Panics if we have no expression ingredient for that node. We can only call this method for
/// standalone-inferable expressions, which we call `add_standalone_expression` for in
/// [`SemanticIndexBuilder`].
pub(crate) fn expression(
&self,
expression_key: impl Into<ExpressionNodeKey>,
) -> Expression<'db> {
self.expressions_by_node[&expression_key.into()]
}
/// Returns the id of the scope that `node` creates. This is different from [`Definition::scope`] which
@ -176,8 +205,6 @@ impl<'db> SemanticIndex<'db> {
}
}
/// ID that uniquely identifies an expression inside a [`Scope`].
pub struct AncestorsIter<'a> {
scopes: &'a IndexSlice<FileScopeId, Scope>,
next_id: Option<FileScopeId>,
@ -278,7 +305,7 @@ mod tests {
use crate::db::tests::TestDb;
use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable};
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::semantic_index::{module_global_scope, semantic_index, symbol_table, use_def_map};
use crate::Db;
struct TestCase {
@ -305,95 +332,110 @@ mod tests {
#[test]
fn empty() {
let TestCase { db, file } = test_case("");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
let root_names = names(&root_table);
let module_global_names = names(&module_global_table);
assert_eq!(root_names, Vec::<&str>::new());
assert_eq!(module_global_names, Vec::<&str>::new());
}
#[test]
fn simple() {
let TestCase { db, file } = test_case("x");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["x"]);
assert_eq!(names(&module_global_table), vec!["x"]);
}
#[test]
fn annotation_only() {
let TestCase { db, file } = test_case("x: int");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["int", "x"]);
assert_eq!(names(&module_global_table), vec!["int", "x"]);
// TODO record definition
}
#[test]
fn import() {
let TestCase { db, file } = test_case("import foo");
let root_table = symbol_table(&db, root_scope(&db, file));
let scope = module_global_scope(&db, file);
let module_global_table = symbol_table(&db, scope);
assert_eq!(names(&root_table), vec!["foo"]);
let foo = root_table.symbol_by_name("foo").unwrap();
assert_eq!(names(&module_global_table), vec!["foo"]);
let foo = module_global_table.symbol_id_by_name("foo").unwrap();
assert_eq!(foo.definitions().len(), 1);
let use_def = use_def_map(&db, scope);
assert_eq!(use_def.public_definitions(foo).len(), 1);
}
#[test]
fn import_sub() {
let TestCase { db, file } = test_case("import foo.bar");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["foo"]);
assert_eq!(names(&module_global_table), vec!["foo"]);
}
#[test]
fn import_as() {
let TestCase { db, file } = test_case("import foo.bar as baz");
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["baz"]);
assert_eq!(names(&module_global_table), vec!["baz"]);
}
#[test]
fn import_from() {
let TestCase { db, file } = test_case("from bar import foo");
let root_table = symbol_table(&db, root_scope(&db, file));
let scope = module_global_scope(&db, file);
let module_global_table = symbol_table(&db, scope);
assert_eq!(names(&root_table), vec!["foo"]);
assert_eq!(
root_table
assert_eq!(names(&module_global_table), vec!["foo"]);
assert!(
module_global_table
.symbol_by_name("foo")
.unwrap()
.definitions()
.is_some_and(|symbol| { symbol.is_defined() && !symbol.is_used() }),
"symbols that are defined get the defined flag"
);
let use_def = use_def_map(&db, scope);
assert_eq!(
use_def
.public_definitions(
module_global_table
.symbol_id_by_name("foo")
.expect("symbol exists")
)
.len(),
1
);
assert!(
root_table
.symbol_by_name("foo")
.is_some_and(|symbol| { symbol.is_defined() || !symbol.is_used() }),
"symbols that are defined get the defined flag"
);
}
#[test]
fn assign() {
let TestCase { db, file } = test_case("x = foo");
let root_table = symbol_table(&db, root_scope(&db, file));
let scope = module_global_scope(&db, file);
let module_global_table = symbol_table(&db, scope);
assert_eq!(names(&root_table), vec!["foo", "x"]);
assert_eq!(
root_table.symbol_by_name("x").unwrap().definitions().len(),
1
);
assert_eq!(names(&module_global_table), vec!["foo", "x"]);
assert!(
root_table
module_global_table
.symbol_by_name("foo")
.is_some_and(|symbol| { !symbol.is_defined() && symbol.is_used() }),
"a symbol used but not defined in a scope should have only the used flag"
);
let use_def = use_def_map(&db, scope);
assert_eq!(
use_def
.public_definitions(
module_global_table
.symbol_id_by_name("x")
.expect("symbol exists")
)
.len(),
1
);
}
#[test]
@ -405,13 +447,13 @@ class C:
y = 2
",
);
let root_table = symbol_table(&db, root_scope(&db, file));
let module_global_table = symbol_table(&db, module_global_scope(&db, file));
assert_eq!(names(&root_table), vec!["C", "y"]);
assert_eq!(names(&module_global_table), vec!["C", "y"]);
let index = semantic_index(&db, file);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect();
assert_eq!(scopes.len(), 1);
let (class_scope_id, class_scope) = scopes[0];
@ -421,8 +463,12 @@ y = 2
let class_table = index.symbol_table(class_scope_id);
assert_eq!(names(&class_table), vec!["x"]);
let use_def = index.use_def_map(class_scope_id);
assert_eq!(
class_table.symbol_by_name("x").unwrap().definitions().len(),
use_def
.public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists"))
.len(),
1
);
}
@ -437,11 +483,13 @@ y = 2
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let module_global_table = index.symbol_table(FileScopeId::module_global());
assert_eq!(names(&root_table), vec!["func", "y"]);
assert_eq!(names(&module_global_table), vec!["func", "y"]);
let scopes = index.child_scopes(FileScopeId::root()).collect::<Vec<_>>();
let scopes = index
.child_scopes(FileScopeId::module_global())
.collect::<Vec<_>>();
assert_eq!(scopes.len(), 1);
let (function_scope_id, function_scope) = scopes[0];
@ -450,11 +498,15 @@ y = 2
let function_table = index.symbol_table(function_scope_id);
assert_eq!(names(&function_table), vec!["x"]);
let use_def = index.use_def_map(function_scope_id);
assert_eq!(
function_table
.symbol_by_name("x")
.unwrap()
.definitions()
use_def
.public_definitions(
function_table
.symbol_id_by_name("x")
.expect("symbol exists")
)
.len(),
1
);
@ -471,10 +523,10 @@ def func():
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let module_global_table = index.symbol_table(FileScopeId::module_global());
assert_eq!(names(&root_table), vec!["func"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(names(&module_global_table), vec!["func"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect();
assert_eq!(scopes.len(), 2);
let (func_scope1_id, func_scope_1) = scopes[0];
@ -490,13 +542,17 @@ def func():
let func2_table = index.symbol_table(func_scope2_id);
assert_eq!(names(&func1_table), vec!["x"]);
assert_eq!(names(&func2_table), vec!["y"]);
let use_def = index.use_def_map(FileScopeId::module_global());
assert_eq!(
root_table
.symbol_by_name("func")
.unwrap()
.definitions()
use_def
.public_definitions(
module_global_table
.symbol_id_by_name("func")
.expect("symbol exists")
)
.len(),
2
1
);
}
@ -510,11 +566,11 @@ def func[T]():
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let module_global_table = index.symbol_table(FileScopeId::module_global());
assert_eq!(names(&root_table), vec!["func"]);
assert_eq!(names(&module_global_table), vec!["func"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect();
assert_eq!(scopes.len(), 1);
let (ann_scope_id, ann_scope) = scopes[0];
@ -542,11 +598,11 @@ class C[T]:
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
let module_global_table = index.symbol_table(FileScopeId::module_global());
assert_eq!(names(&root_table), vec!["C"]);
assert_eq!(names(&module_global_table), vec!["C"]);
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect();
assert_eq!(scopes.len(), 1);
let (ann_scope_id, ann_scope) = scopes[0];
@ -578,7 +634,7 @@ class C[T]:
// let index = SemanticIndex::from_ast(ast);
// let table = &index.symbol_table;
// let x_sym = table
// .root_symbol_id_by_name("x")
// .module_global_symbol_id_by_name("x")
// .expect("x symbol should exist");
// let ast::Stmt::Expr(ast::StmtExpr { value: x_use, .. }) = &ast.body[1] else {
// panic!("should be an expr")
@ -616,7 +672,7 @@ class C[T]:
let x = &x_stmt.targets[0];
assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module);
assert_eq!(index.expression_scope_id(x), FileScopeId::root());
assert_eq!(index.expression_scope_id(x), FileScopeId::module_global());
let def = ast.body[1].as_function_def_stmt().unwrap();
let y_stmt = def.body[0].as_assign_stmt().unwrap();
@ -653,16 +709,20 @@ def x():
let index = semantic_index(&db, file);
let descendents = index.descendent_scopes(FileScopeId::root());
let descendents = index.descendent_scopes(FileScopeId::module_global());
assert_eq!(
scope_names(descendents, &db, file),
vec!["Test", "foo", "bar", "baz", "x"]
);
let children = index.child_scopes(FileScopeId::root());
let children = index.child_scopes(FileScopeId::module_global());
assert_eq!(scope_names(children, &db, file), vec!["Test", "x"]);
let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0;
let test_class = index
.child_scopes(FileScopeId::module_global())
.next()
.unwrap()
.0;
let test_child_scopes = index.child_scopes(test_class);
assert_eq!(
scope_names(test_child_scopes, &db, file),
@ -670,7 +730,7 @@ def x():
);
let bar_scope = index
.descendent_scopes(FileScopeId::root())
.descendent_scopes(FileScopeId::module_global())
.nth(2)
.unwrap()
.0;

View file

@ -1,6 +1,6 @@
use rustc_hash::FxHashMap;
use ruff_index::{newtype_index, Idx};
use ruff_index::newtype_index;
use ruff_python_ast as ast;
use ruff_python_ast::ExpressionRef;
@ -28,18 +28,54 @@ use crate::Db;
pub(crate) struct AstIds {
/// Maps expressions to their expression id. Uses `NodeKey` because it avoids cloning [`Parsed`].
expressions_map: FxHashMap<ExpressionNodeKey, ScopedExpressionId>,
/// Maps expressions which "use" a symbol (that is, [`ExprName`]) to a use id.
uses_map: FxHashMap<ExpressionNodeKey, ScopedUseId>,
}
impl AstIds {
fn expression_id(&self, key: impl Into<ExpressionNodeKey>) -> ScopedExpressionId {
self.expressions_map[&key.into()]
}
fn use_id(&self, key: impl Into<ExpressionNodeKey>) -> ScopedUseId {
self.uses_map[&key.into()]
}
}
fn ast_ids<'db>(db: &'db dyn Db, scope: ScopeId) -> &'db AstIds {
semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db))
}
pub trait HasScopedUseId {
/// The type of the ID uniquely identifying the use.
type Id: Copy;
/// Returns the ID that uniquely identifies the use in `scope`.
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id;
}
/// Uniquely identifies a use of a name in a [`crate::semantic_index::symbol::FileScopeId`].
#[newtype_index]
pub struct ScopedUseId;
impl HasScopedUseId for ast::ExprName {
type Id = ScopedUseId;
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id {
let expression_ref = ExpressionRef::from(self);
expression_ref.scoped_use_id(db, scope)
}
}
impl HasScopedUseId for ast::ExpressionRef<'_> {
type Id = ScopedUseId;
fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id {
let ast_ids = ast_ids(db, scope);
ast_ids.use_id(*self)
}
}
pub trait HasScopedAstId {
/// The type of the ID uniquely identifying the node.
type Id: Copy;
@ -110,38 +146,43 @@ impl HasScopedAstId for ast::ExpressionRef<'_> {
#[derive(Debug)]
pub(super) struct AstIdsBuilder {
next_id: ScopedExpressionId,
expressions_map: FxHashMap<ExpressionNodeKey, ScopedExpressionId>,
uses_map: FxHashMap<ExpressionNodeKey, ScopedUseId>,
}
impl AstIdsBuilder {
pub(super) fn new() -> Self {
Self {
next_id: ScopedExpressionId::new(0),
expressions_map: FxHashMap::default(),
uses_map: FxHashMap::default(),
}
}
/// Adds `expr` to the AST ids map and returns its id.
///
/// ## Safety
/// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires
/// that `expr` is a child of `parsed`.
#[allow(unsafe_code)]
/// Adds `expr` to the expression ids map and returns its id.
pub(super) fn record_expression(&mut self, expr: &ast::Expr) -> ScopedExpressionId {
let expression_id = self.next_id;
self.next_id = expression_id + 1;
let expression_id = self.expressions_map.len().into();
self.expressions_map.insert(expr.into(), expression_id);
expression_id
}
/// Adds `expr` to the use ids map and returns its id.
pub(super) fn record_use(&mut self, expr: &ast::Expr) -> ScopedUseId {
let use_id = self.uses_map.len().into();
self.uses_map.insert(expr.into(), use_id);
use_id
}
pub(super) fn finish(mut self) -> AstIds {
self.expressions_map.shrink_to_fit();
self.uses_map.shrink_to_fit();
AstIds {
expressions_map: self.expressions_map,
uses_map: self.uses_map,
}
}
}

View file

@ -9,55 +9,62 @@ use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor};
use crate::ast_node_ref::AstNodeRef;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIdsBuilder;
use crate::semantic_index::definition::{Definition, DefinitionNodeKey, DefinitionNodeRef};
use crate::semantic_index::definition::{
AssignmentDefinitionNodeRef, Definition, DefinitionNodeKey, DefinitionNodeRef,
ImportFromDefinitionNodeRef,
};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolFlags,
SymbolTableBuilder,
};
use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
use crate::semantic_index::SemanticIndex;
use crate::Db;
pub(super) struct SemanticIndexBuilder<'db, 'ast> {
pub(super) struct SemanticIndexBuilder<'db> {
// Builder state
db: &'db dyn Db,
file: File,
module: &'db ParsedModule,
scope_stack: Vec<FileScopeId>,
/// the target we're currently inferring
current_target: Option<CurrentTarget<'ast>>,
/// the assignment we're currently visiting
current_assignment: Option<CurrentAssignment<'db>>,
// Semantic Index fields
scopes: IndexVec<FileScopeId, Scope>,
scope_ids_by_scope: IndexVec<FileScopeId, ScopeId<'db>>,
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder<'db>>,
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder>,
ast_ids: IndexVec<FileScopeId, AstIdsBuilder>,
use_def_maps: IndexVec<FileScopeId, UseDefMapBuilder<'db>>,
scopes_by_node: FxHashMap<NodeWithScopeKey, FileScopeId>,
scopes_by_expression: FxHashMap<ExpressionNodeKey, FileScopeId>,
definitions_by_node: FxHashMap<DefinitionNodeKey, Definition<'db>>,
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
}
impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast>
where
'db: 'ast,
{
impl<'db> SemanticIndexBuilder<'db> {
pub(super) fn new(db: &'db dyn Db, file: File, parsed: &'db ParsedModule) -> Self {
let mut builder = Self {
db,
file,
module: parsed,
scope_stack: Vec::new(),
current_target: None,
current_assignment: None,
scopes: IndexVec::new(),
symbol_tables: IndexVec::new(),
ast_ids: IndexVec::new(),
scope_ids_by_scope: IndexVec::new(),
use_def_maps: IndexVec::new(),
scopes_by_expression: FxHashMap::default(),
scopes_by_node: FxHashMap::default(),
definitions_by_node: FxHashMap::default(),
expressions_by_node: FxHashMap::default(),
};
builder.push_scope_with_parent(NodeWithScopeRef::Module, None);
@ -72,16 +79,12 @@ where
.expect("Always to have a root scope")
}
fn push_scope(&mut self, node: NodeWithScopeRef<'ast>) {
fn push_scope(&mut self, node: NodeWithScopeRef) {
let parent = self.current_scope();
self.push_scope_with_parent(node, Some(parent));
}
fn push_scope_with_parent(
&mut self,
node: NodeWithScopeRef<'ast>,
parent: Option<FileScopeId>,
) {
fn push_scope_with_parent(&mut self, node: NodeWithScopeRef, parent: Option<FileScopeId>) {
let children_start = self.scopes.next_index() + 1;
let scope = Scope {
@ -92,6 +95,7 @@ where
let file_scope_id = self.scopes.push(scope);
self.symbol_tables.push(SymbolTableBuilder::new());
self.use_def_maps.push(UseDefMapBuilder::new());
let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new());
#[allow(unsafe_code)]
@ -116,32 +120,54 @@ where
id
}
fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder<'db> {
fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder {
let scope_id = self.current_scope();
&mut self.symbol_tables[scope_id]
}
fn current_use_def_map(&mut self) -> &mut UseDefMapBuilder<'db> {
let scope_id = self.current_scope();
&mut self.use_def_maps[scope_id]
}
fn current_ast_ids(&mut self) -> &mut AstIdsBuilder {
let scope_id = self.current_scope();
&mut self.ast_ids[scope_id]
}
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId {
let symbol_table = self.current_symbol_table();
symbol_table.add_or_update_symbol(name, flags)
fn flow_snapshot(&mut self) -> FlowSnapshot {
self.current_use_def_map().snapshot()
}
fn add_definition(
fn flow_set(&mut self, state: &FlowSnapshot) {
self.current_use_def_map().set(state);
}
fn flow_merge(&mut self, state: &FlowSnapshot) {
self.current_use_def_map().merge(state);
}
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId {
let symbol_table = self.current_symbol_table();
let (symbol_id, added) = symbol_table.add_or_update_symbol(name, flags);
if added {
let use_def_map = self.current_use_def_map();
use_def_map.add_symbol(symbol_id);
}
symbol_id
}
fn add_definition<'a>(
&mut self,
definition_node: impl Into<DefinitionNodeRef<'ast>>,
symbol_id: ScopedSymbolId,
symbol: ScopedSymbolId,
definition_node: impl Into<DefinitionNodeRef<'a>>,
) -> Definition<'db> {
let definition_node = definition_node.into();
let definition = Definition::new(
self.db,
self.file,
self.current_scope(),
symbol_id,
symbol,
#[allow(unsafe_code)]
unsafe {
definition_node.into_owned(self.module.clone())
@ -150,26 +176,31 @@ where
self.definitions_by_node
.insert(definition_node.key(), definition);
self.current_use_def_map()
.record_definition(symbol, definition);
definition
}
fn add_or_update_symbol_with_definition(
&mut self,
name: Name,
definition: impl Into<DefinitionNodeRef<'ast>>,
) -> (ScopedSymbolId, Definition<'db>) {
let symbol_table = self.current_symbol_table();
let id = symbol_table.add_or_update_symbol(name, SymbolFlags::IS_DEFINED);
let definition = self.add_definition(definition, id);
self.current_symbol_table().add_definition(id, definition);
(id, definition)
/// Record an expression that needs to be a Salsa ingredient, because we need to infer its type
/// standalone (type narrowing tests, RHS of an assignment.)
fn add_standalone_expression(&mut self, expression_node: &ast::Expr) {
let expression = Expression::new(
self.db,
self.file,
self.current_scope(),
#[allow(unsafe_code)]
unsafe {
AstNodeRef::new(self.module.clone(), expression_node)
},
);
self.expressions_by_node
.insert(expression_node.into(), expression);
}
fn with_type_params(
&mut self,
with_params: &WithTypeParams<'ast>,
with_params: &WithTypeParams,
nested: impl FnOnce(&mut Self) -> FileScopeId,
) -> FileScopeId {
let type_params = with_params.type_parameters();
@ -213,7 +244,7 @@ where
self.pop_scope();
assert!(self.scope_stack.is_empty());
assert!(self.current_target.is_none());
assert!(self.current_assignment.is_none());
let mut symbol_tables: IndexVec<_, _> = self
.symbol_tables
@ -221,6 +252,12 @@ where
.map(|builder| Arc::new(builder.finish()))
.collect();
let mut use_def_maps: IndexVec<_, _> = self
.use_def_maps
.into_iter()
.map(|builder| Arc::new(builder.finish()))
.collect();
let mut ast_ids: IndexVec<_, _> = self
.ast_ids
.into_iter()
@ -228,8 +265,9 @@ where
.collect();
self.scopes.shrink_to_fit();
ast_ids.shrink_to_fit();
symbol_tables.shrink_to_fit();
use_def_maps.shrink_to_fit();
ast_ids.shrink_to_fit();
self.scopes_by_expression.shrink_to_fit();
self.definitions_by_node.shrink_to_fit();
@ -240,17 +278,19 @@ where
symbol_tables,
scopes: self.scopes,
definitions_by_node: self.definitions_by_node,
expressions_by_node: self.expressions_by_node,
scope_ids_by_scope: self.scope_ids_by_scope,
ast_ids,
scopes_by_expression: self.scopes_by_expression,
scopes_by_node: self.scopes_by_node,
use_def_maps,
}
}
}
impl<'db, 'ast> Visitor<'ast> for SemanticIndexBuilder<'db, 'ast>
impl<'db, 'ast> Visitor<'ast> for SemanticIndexBuilder<'db>
where
'db: 'ast,
'ast: 'db,
{
fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) {
match stmt {
@ -259,10 +299,9 @@ where
self.visit_decorator(decorator);
}
self.add_or_update_symbol_with_definition(
function_def.name.id.clone(),
function_def,
);
let symbol = self
.add_or_update_symbol(function_def.name.id.clone(), SymbolFlags::IS_DEFINED);
self.add_definition(symbol, function_def);
self.with_type_params(
&WithTypeParams::FunctionDef { node: function_def },
@ -283,7 +322,9 @@ where
self.visit_decorator(decorator);
}
self.add_or_update_symbol_with_definition(class.name.id.clone(), class);
let symbol =
self.add_or_update_symbol(class.name.id.clone(), SymbolFlags::IS_DEFINED);
self.add_definition(symbol, class);
self.with_type_params(&WithTypeParams::ClassDef { node: class }, |builder| {
if let Some(arguments) = &class.arguments {
@ -296,41 +337,84 @@ where
builder.pop_scope()
});
}
ast::Stmt::Import(ast::StmtImport { names, .. }) => {
for alias in names {
ast::Stmt::Import(node) => {
for alias in &node.names {
let symbol_name = if let Some(asname) = &alias.asname {
asname.id.clone()
} else {
Name::new(alias.name.id.split('.').next().unwrap())
};
self.add_or_update_symbol_with_definition(symbol_name, alias);
let symbol = self.add_or_update_symbol(symbol_name, SymbolFlags::IS_DEFINED);
self.add_definition(symbol, alias);
}
}
ast::Stmt::ImportFrom(ast::StmtImportFrom {
module: _,
names,
level: _,
..
}) => {
for alias in names {
ast::Stmt::ImportFrom(node) => {
for (alias_index, alias) in node.names.iter().enumerate() {
let symbol_name = if let Some(asname) = &alias.asname {
&asname.id
} else {
&alias.name.id
};
self.add_or_update_symbol_with_definition(symbol_name.clone(), alias);
let symbol =
self.add_or_update_symbol(symbol_name.clone(), SymbolFlags::IS_DEFINED);
self.add_definition(symbol, ImportFromDefinitionNodeRef { node, alias_index });
}
}
ast::Stmt::Assign(node) => {
debug_assert!(self.current_target.is_none());
debug_assert!(self.current_assignment.is_none());
self.visit_expr(&node.value);
self.add_standalone_expression(&node.value);
self.current_assignment = Some(node.into());
for target in &node.targets {
self.current_target = Some(CurrentTarget::Expr(target));
self.visit_expr(target);
}
self.current_target = None;
self.current_assignment = None;
}
ast::Stmt::AnnAssign(node) => {
debug_assert!(self.current_assignment.is_none());
// TODO deferred annotation visiting
self.visit_expr(&node.annotation);
match &node.value {
Some(value) => {
self.visit_expr(value);
self.current_assignment = Some(node.into());
self.visit_expr(&node.target);
self.current_assignment = None;
}
None => {
// TODO annotation-only assignments
self.visit_expr(&node.target);
}
}
}
ast::Stmt::If(node) => {
self.visit_expr(&node.test);
let pre_if = self.flow_snapshot();
self.visit_body(&node.body);
let mut last_clause_is_else = false;
let mut post_clauses: Vec<FlowSnapshot> = vec![self.flow_snapshot()];
for clause in &node.elif_else_clauses {
// we can only take an elif/else clause if none of the previous ones were taken
self.flow_set(&pre_if);
self.visit_elif_else_clause(clause);
post_clauses.push(self.flow_snapshot());
if clause.test.is_none() {
last_clause_is_else = true;
}
}
let mut post_clause_iter = post_clauses.iter();
if last_clause_is_else {
// if the last clause was an else, the pre_if state can't directly reach the
// post-state; we have to enter one of the clauses.
self.flow_set(post_clause_iter.next().unwrap());
} else {
self.flow_set(&pre_if);
}
for post_clause_state in post_clause_iter {
self.flow_merge(post_clause_state);
}
}
_ => {
walk_stmt(self, stmt);
@ -344,57 +428,64 @@ where
self.current_ast_ids().record_expression(expr);
match expr {
ast::Expr::Name(ast::ExprName { id, ctx, .. }) => {
ast::Expr::Name(name_node) => {
let ast::ExprName { id, ctx, .. } = name_node;
let flags = match ctx {
ast::ExprContext::Load => SymbolFlags::IS_USED,
ast::ExprContext::Store => SymbolFlags::IS_DEFINED,
ast::ExprContext::Del => SymbolFlags::IS_DEFINED,
ast::ExprContext::Invalid => SymbolFlags::empty(),
};
match self.current_target {
Some(target) if flags.contains(SymbolFlags::IS_DEFINED) => {
self.add_or_update_symbol_with_definition(id.clone(), target);
}
_ => {
self.add_or_update_symbol(id.clone(), flags);
let symbol = self.add_or_update_symbol(id.clone(), flags);
if flags.contains(SymbolFlags::IS_DEFINED) {
match self.current_assignment {
Some(CurrentAssignment::Assign(assignment)) => {
self.add_definition(
symbol,
AssignmentDefinitionNodeRef {
assignment,
target: name_node,
},
);
}
Some(CurrentAssignment::AnnAssign(ann_assign)) => {
self.add_definition(symbol, ann_assign);
}
Some(CurrentAssignment::Named(named)) => {
self.add_definition(symbol, named);
}
None => {}
}
}
if flags.contains(SymbolFlags::IS_USED) {
let use_id = self.current_ast_ids().record_use(expr);
self.current_use_def_map().record_use(symbol, use_id);
}
walk_expr(self, expr);
}
ast::Expr::Named(node) => {
debug_assert!(self.current_target.is_none());
self.current_target = Some(CurrentTarget::ExprNamed(node));
debug_assert!(self.current_assignment.is_none());
self.current_assignment = Some(node.into());
// TODO walrus in comprehensions is implicitly nonlocal
self.visit_expr(&node.target);
self.current_target = None;
self.current_assignment = None;
self.visit_expr(&node.value);
}
ast::Expr::If(ast::ExprIf {
body, test, orelse, ..
}) => {
// TODO detect statically known truthy or falsy test (via type inference, not naive
// AST inspection, so we can't simplify here, need to record test expression in CFG
// for later checking)
// AST inspection, so we can't simplify here, need to record test expression for
// later checking)
self.visit_expr(test);
// let if_branch = self.flow_graph_builder.add_branch(self.current_flow_node());
// self.set_current_flow_node(if_branch);
// self.insert_constraint(test);
let pre_if = self.flow_snapshot();
self.visit_expr(body);
// let post_body = self.current_flow_node();
// self.set_current_flow_node(if_branch);
let post_body = self.flow_snapshot();
self.flow_set(&pre_if);
self.visit_expr(orelse);
// let post_else = self
// .flow_graph_builder
// .add_phi(self.current_flow_node(), post_body);
// self.set_current_flow_node(post_else);
self.flow_merge(&post_body);
}
_ => {
walk_expr(self, expr);
@ -418,16 +509,26 @@ impl<'node> WithTypeParams<'node> {
}
#[derive(Copy, Clone, Debug)]
enum CurrentTarget<'a> {
Expr(&'a ast::Expr),
ExprNamed(&'a ast::ExprNamed),
enum CurrentAssignment<'a> {
Assign(&'a ast::StmtAssign),
AnnAssign(&'a ast::StmtAnnAssign),
Named(&'a ast::ExprNamed),
}
impl<'a> From<CurrentTarget<'a>> for DefinitionNodeRef<'a> {
fn from(val: CurrentTarget<'a>) -> Self {
match val {
CurrentTarget::Expr(expression) => DefinitionNodeRef::Target(expression),
CurrentTarget::ExprNamed(named) => DefinitionNodeRef::NamedExpression(named),
}
impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAssign) -> Self {
Self::Assign(value)
}
}
impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAnnAssign) -> Self {
Self::AnnAssign(value)
}
}
impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
fn from(value: &'a ast::ExprNamed) -> Self {
Self::Named(value)
}
}

View file

@ -4,63 +4,111 @@ use ruff_python_ast as ast;
use crate::ast_node_ref::AstNodeRef;
use crate::node_key::NodeKey;
use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId};
use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId};
use crate::Db;
#[salsa::tracked]
pub struct Definition<'db> {
/// The file in which the definition is defined.
/// The file in which the definition occurs.
#[id]
pub(super) file: File,
pub(crate) file: File,
/// The scope in which the definition is defined.
/// The scope in which the definition occurs.
#[id]
pub(crate) scope: FileScopeId,
pub(crate) file_scope: FileScopeId,
/// The id of the corresponding symbol. Mainly used as ID.
/// The symbol defined.
#[id]
symbol_id: ScopedSymbolId,
pub(crate) symbol: ScopedSymbolId,
#[no_eq]
#[return_ref]
pub(crate) node: DefinitionKind,
}
impl<'db> Definition<'db> {
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}
}
#[derive(Copy, Clone, Debug)]
pub(crate) enum DefinitionNodeRef<'a> {
Alias(&'a ast::Alias),
Import(&'a ast::Alias),
ImportFrom(ImportFromDefinitionNodeRef<'a>),
Function(&'a ast::StmtFunctionDef),
Class(&'a ast::StmtClassDef),
NamedExpression(&'a ast::ExprNamed),
Target(&'a ast::Expr),
Assignment(AssignmentDefinitionNodeRef<'a>),
AnnotatedAssignment(&'a ast::StmtAnnAssign),
}
impl<'a> From<&'a ast::Alias> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::Alias) -> Self {
Self::Alias(node)
}
}
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtFunctionDef) -> Self {
Self::Function(node)
}
}
impl<'a> From<&'a ast::StmtClassDef> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtClassDef) -> Self {
Self::Class(node)
}
}
impl<'a> From<&'a ast::ExprNamed> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::ExprNamed) -> Self {
Self::NamedExpression(node)
}
}
impl<'a> From<&'a ast::StmtAnnAssign> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtAnnAssign) -> Self {
Self::AnnotatedAssignment(node)
}
}
impl<'a> From<&'a ast::Alias> for DefinitionNodeRef<'a> {
fn from(node_ref: &'a ast::Alias) -> Self {
Self::Import(node_ref)
}
}
impl<'a> From<ImportFromDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: ImportFromDefinitionNodeRef<'a>) -> Self {
Self::ImportFrom(node_ref)
}
}
impl<'a> From<AssignmentDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: AssignmentDefinitionNodeRef<'a>) -> Self {
Self::Assignment(node_ref)
}
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtImportFrom,
pub(crate) alias_index: usize,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct AssignmentDefinitionNodeRef<'a> {
pub(crate) assignment: &'a ast::StmtAssign,
pub(crate) target: &'a ast::ExprName,
}
impl DefinitionNodeRef<'_> {
#[allow(unsafe_code)]
pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind {
match self {
DefinitionNodeRef::Alias(alias) => {
DefinitionKind::Alias(AstNodeRef::new(parsed, alias))
DefinitionNodeRef::Import(alias) => {
DefinitionKind::Import(AstNodeRef::new(parsed, alias))
}
DefinitionNodeRef::ImportFrom(ImportFromDefinitionNodeRef { node, alias_index }) => {
DefinitionKind::ImportFrom(ImportFromDefinitionKind {
node: AstNodeRef::new(parsed, node),
alias_index,
})
}
DefinitionNodeRef::Function(function) => {
DefinitionKind::Function(AstNodeRef::new(parsed, function))
@ -71,33 +119,111 @@ impl DefinitionNodeRef<'_> {
DefinitionNodeRef::NamedExpression(named) => {
DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named))
}
DefinitionNodeRef::Target(target) => {
DefinitionKind::Target(AstNodeRef::new(parsed, target))
DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { assignment, target }) => {
DefinitionKind::Assignment(AssignmentDefinitionKind {
assignment: AstNodeRef::new(parsed.clone(), assignment),
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::AnnotatedAssignment(assign) => {
DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign))
}
}
}
}
impl DefinitionNodeRef<'_> {
pub(super) fn key(self) -> DefinitionNodeKey {
match self {
Self::Alias(node) => DefinitionNodeKey(NodeKey::from_node(node)),
Self::Function(node) => DefinitionNodeKey(NodeKey::from_node(node)),
Self::Class(node) => DefinitionNodeKey(NodeKey::from_node(node)),
Self::NamedExpression(node) => DefinitionNodeKey(NodeKey::from_node(node)),
Self::Target(node) => DefinitionNodeKey(NodeKey::from_node(node)),
Self::Import(node) => node.into(),
Self::ImportFrom(ImportFromDefinitionNodeRef { node, alias_index }) => {
(&node.names[alias_index]).into()
}
Self::Function(node) => node.into(),
Self::Class(node) => node.into(),
Self::NamedExpression(node) => node.into(),
Self::Assignment(AssignmentDefinitionNodeRef {
assignment: _,
target,
}) => target.into(),
Self::AnnotatedAssignment(node) => node.into(),
}
}
}
#[derive(Clone, Debug)]
pub enum DefinitionKind {
Alias(AstNodeRef<ast::Alias>),
Import(AstNodeRef<ast::Alias>),
ImportFrom(ImportFromDefinitionKind),
Function(AstNodeRef<ast::StmtFunctionDef>),
Class(AstNodeRef<ast::StmtClassDef>),
NamedExpression(AstNodeRef<ast::ExprNamed>),
Target(AstNodeRef<ast::Expr>),
Assignment(AssignmentDefinitionKind),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
}
#[derive(Clone, Debug)]
pub struct ImportFromDefinitionKind {
node: AstNodeRef<ast::StmtImportFrom>,
alias_index: usize,
}
impl ImportFromDefinitionKind {
pub(crate) fn import(&self) -> &ast::StmtImportFrom {
self.node.node()
}
pub(crate) fn alias(&self) -> &ast::Alias {
&self.node.node().names[self.alias_index]
}
}
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct AssignmentDefinitionKind {
assignment: AstNodeRef<ast::StmtAssign>,
target: AstNodeRef<ast::ExprName>,
}
impl AssignmentDefinitionKind {
pub(crate) fn assignment(&self) -> &ast::StmtAssign {
self.assignment.node()
}
}
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(super) struct DefinitionNodeKey(NodeKey);
pub(crate) struct DefinitionNodeKey(NodeKey);
impl From<&ast::Alias> for DefinitionNodeKey {
fn from(node: &ast::Alias) -> Self {
Self(NodeKey::from_node(node))
}
}
impl From<&ast::StmtFunctionDef> for DefinitionNodeKey {
fn from(node: &ast::StmtFunctionDef) -> Self {
Self(NodeKey::from_node(node))
}
}
impl From<&ast::StmtClassDef> for DefinitionNodeKey {
fn from(node: &ast::StmtClassDef) -> Self {
Self(NodeKey::from_node(node))
}
}
impl From<&ast::ExprName> for DefinitionNodeKey {
fn from(node: &ast::ExprName) -> Self {
Self(NodeKey::from_node(node))
}
}
impl From<&ast::ExprNamed> for DefinitionNodeKey {
fn from(node: &ast::ExprNamed) -> Self {
Self(NodeKey::from_node(node))
}
}
impl From<&ast::StmtAnnAssign> for DefinitionNodeKey {
fn from(node: &ast::StmtAnnAssign) -> Self {
Self(NodeKey::from_node(node))
}
}

View file

@ -0,0 +1,31 @@
use crate::ast_node_ref::AstNodeRef;
use crate::db::Db;
use crate::semantic_index::symbol::{FileScopeId, ScopeId};
use ruff_db::files::File;
use ruff_python_ast as ast;
use salsa;
/// An independently type-inferable expression.
///
/// Includes constraint expressions (e.g. if tests) and the RHS of an unpacking assignment.
#[salsa::tracked]
pub(crate) struct Expression<'db> {
/// The file in which the expression occurs.
#[id]
pub(crate) file: File,
/// The scope in which the expression occurs.
#[id]
pub(crate) file_scope: FileScopeId,
/// The expression node.
#[no_eq]
#[return_ref]
pub(crate) node: AstNodeRef<ast::Expr>,
}
impl<'db> Expression<'db> {
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}
}

View file

@ -12,33 +12,23 @@ use rustc_hash::FxHasher;
use crate::ast_node_ref::AstNodeRef;
use crate::node_key::NodeKey;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap};
use crate::semantic_index::{semantic_index, SymbolMap};
use crate::Db;
#[derive(Eq, PartialEq, Debug)]
pub struct Symbol<'db> {
pub struct Symbol {
name: Name,
flags: SymbolFlags,
/// The nodes that define this symbol, in source order.
///
/// TODO: Use smallvec here, but it creates the same lifetime issues as in [QualifiedName](https://github.com/astral-sh/ruff/blob/5109b50bb3847738eeb209352cf26bda392adf62/crates/ruff_python_ast/src/name.rs#L562-L569)
definitions: Vec<Definition<'db>>,
}
impl<'db> Symbol<'db> {
impl Symbol {
fn new(name: Name) -> Self {
Self {
name,
flags: SymbolFlags::empty(),
definitions: Vec::new(),
}
}
fn push_definition(&mut self, definition: Definition<'db>) {
self.definitions.push(definition);
}
fn insert_flags(&mut self, flags: SymbolFlags) {
self.flags.insert(flags);
}
@ -57,10 +47,6 @@ impl<'db> Symbol<'db> {
pub fn is_defined(&self) -> bool {
self.flags.contains(SymbolFlags::IS_DEFINED)
}
pub fn definitions(&self) -> &[Definition] {
&self.definitions
}
}
bitflags! {
@ -75,15 +61,6 @@ bitflags! {
}
}
/// ID that uniquely identifies a public symbol defined in a module's root scope.
#[salsa::tracked]
pub struct PublicSymbolId<'db> {
#[id]
pub(crate) file: File,
#[id]
pub(crate) scoped_symbol_id: ScopedSymbolId,
}
/// ID that uniquely identifies a symbol in a file.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FileSymbolId {
@ -111,47 +88,6 @@ impl From<FileSymbolId> for ScopedSymbolId {
#[newtype_index]
pub struct ScopedSymbolId;
impl ScopedSymbolId {
/// Converts the symbol to a public symbol.
///
/// # Panics
/// May panic if the symbol does not belong to `file` or is not a symbol of `file`'s root scope.
pub(crate) fn to_public_symbol(self, db: &dyn Db, file: File) -> PublicSymbolId {
let symbols = public_symbols_map(db, file);
symbols.public(self)
}
}
#[salsa::tracked(return_ref)]
pub(crate) fn public_symbols_map(db: &dyn Db, file: File) -> PublicSymbolsMap<'_> {
let _span = tracing::trace_span!("public_symbols_map", ?file).entered();
let module_scope = root_scope(db, file);
let symbols = symbol_table(db, module_scope);
let public_symbols: IndexVec<_, _> = symbols
.symbol_ids()
.map(|id| PublicSymbolId::new(db, file, id))
.collect();
PublicSymbolsMap {
symbols: public_symbols,
}
}
/// Maps [`LocalSymbolId`] of a file's root scope to the corresponding [`PublicSymbolId`] (Salsa ingredients).
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct PublicSymbolsMap<'db> {
symbols: IndexVec<ScopedSymbolId, PublicSymbolId<'db>>,
}
impl<'db> PublicSymbolsMap<'db> {
/// Resolve the [`PublicSymbolId`] for the module-level `symbol_id`.
fn public(&self, symbol_id: ScopedSymbolId) -> PublicSymbolId<'db> {
self.symbols[symbol_id]
}
}
/// A cross-module identifier of a scope that can be used as a salsa query parameter.
#[salsa::tracked]
pub struct ScopeId<'db> {
@ -185,8 +121,8 @@ impl<'db> ScopeId<'db> {
pub struct FileScopeId;
impl FileScopeId {
/// Returns the scope id of the Root scope.
pub fn root() -> Self {
/// Returns the scope id of the module-global scope.
pub fn module_global() -> Self {
FileScopeId::from_u32(0)
}
@ -223,15 +159,15 @@ pub enum ScopeKind {
/// Symbol table for a specific [`Scope`].
#[derive(Debug)]
pub struct SymbolTable<'db> {
pub struct SymbolTable {
/// The symbols in this scope.
symbols: IndexVec<ScopedSymbolId, Symbol<'db>>,
symbols: IndexVec<ScopedSymbolId, Symbol>,
/// The symbols indexed by name.
symbols_by_name: SymbolMap,
}
impl<'db> SymbolTable<'db> {
impl SymbolTable {
fn new() -> Self {
Self {
symbols: IndexVec::new(),
@ -243,21 +179,22 @@ impl<'db> SymbolTable<'db> {
self.symbols.shrink_to_fit();
}
pub(crate) fn symbol(&self, symbol_id: impl Into<ScopedSymbolId>) -> &Symbol<'db> {
pub(crate) fn symbol(&self, symbol_id: impl Into<ScopedSymbolId>) -> &Symbol {
&self.symbols[symbol_id.into()]
}
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopedSymbolId> + 'db {
#[allow(unused)]
pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopedSymbolId> {
self.symbols.indices()
}
pub fn symbols(&self) -> impl Iterator<Item = &Symbol<'db>> {
pub fn symbols(&self) -> impl Iterator<Item = &Symbol> {
self.symbols.iter()
}
/// Returns the symbol named `name`.
#[allow(unused)]
pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol<'db>> {
pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol> {
let id = self.symbol_id_by_name(name)?;
Some(self.symbol(id))
}
@ -281,21 +218,21 @@ impl<'db> SymbolTable<'db> {
}
}
impl PartialEq for SymbolTable<'_> {
impl PartialEq for SymbolTable {
fn eq(&self, other: &Self) -> bool {
// We don't need to compare the symbols_by_name because the name is already captured in `Symbol`.
self.symbols == other.symbols
}
}
impl Eq for SymbolTable<'_> {}
impl Eq for SymbolTable {}
#[derive(Debug)]
pub(super) struct SymbolTableBuilder<'db> {
table: SymbolTable<'db>,
pub(super) struct SymbolTableBuilder {
table: SymbolTable,
}
impl<'db> SymbolTableBuilder<'db> {
impl SymbolTableBuilder {
pub(super) fn new() -> Self {
Self {
table: SymbolTable::new(),
@ -306,7 +243,7 @@ impl<'db> SymbolTableBuilder<'db> {
&mut self,
name: Name,
flags: SymbolFlags,
) -> ScopedSymbolId {
) -> (ScopedSymbolId, bool) {
let hash = SymbolTable::hash_name(&name);
let entry = self
.table
@ -319,7 +256,7 @@ impl<'db> SymbolTableBuilder<'db> {
let symbol = &mut self.table.symbols[*entry.key()];
symbol.insert_flags(flags);
*entry.key()
(*entry.key(), false)
}
RawEntryMut::Vacant(entry) => {
let mut symbol = Symbol::new(name);
@ -329,16 +266,12 @@ impl<'db> SymbolTableBuilder<'db> {
entry.insert_with_hasher(hash, id, (), |id| {
SymbolTable::hash_name(self.table.symbols[*id].name().as_str())
});
id
(id, true)
}
}
}
pub(super) fn add_definition(&mut self, symbol: ScopedSymbolId, definition: Definition<'db>) {
self.table.symbols[symbol].push_definition(definition);
}
pub(super) fn finish(mut self) -> SymbolTable<'db> {
pub(super) fn finish(mut self) -> SymbolTable {
self.table.shrink_to_fit();
self.table
}

View file

@ -0,0 +1,164 @@
use crate::semantic_index::ast_ids::ScopedUseId;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::ScopedSymbolId;
use ruff_index::IndexVec;
use std::ops::Range;
/// All definitions that can reach a given use of a name.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct UseDefMap<'db> {
// TODO store constraints with definitions for type narrowing
all_definitions: Vec<Definition<'db>>,
/// Definitions that can reach a [`ScopedUseId`].
definitions_by_use: IndexVec<ScopedUseId, Definitions>,
/// Definitions of a symbol visible to other scopes.
public_definitions: IndexVec<ScopedSymbolId, Definitions>,
}
impl<'db> UseDefMap<'db> {
pub(crate) fn use_definitions(&self, use_id: ScopedUseId) -> &[Definition<'db>] {
&self.all_definitions[self.definitions_by_use[use_id].definitions.clone()]
}
pub(crate) fn use_may_be_unbound(&self, use_id: ScopedUseId) -> bool {
self.definitions_by_use[use_id].may_be_unbound
}
pub(crate) fn public_definitions(&self, symbol: ScopedSymbolId) -> &[Definition<'db>] {
&self.all_definitions[self.public_definitions[symbol].definitions.clone()]
}
pub(crate) fn public_may_be_unbound(&self, symbol: ScopedSymbolId) -> bool {
self.public_definitions[symbol].may_be_unbound
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
struct Definitions {
definitions: Range<usize>,
may_be_unbound: bool,
}
impl Default for Definitions {
fn default() -> Self {
Self {
definitions: Range::default(),
may_be_unbound: true,
}
}
}
#[derive(Debug)]
pub(super) struct FlowSnapshot {
definitions_by_symbol: IndexVec<ScopedSymbolId, Definitions>,
}
pub(super) struct UseDefMapBuilder<'db> {
all_definitions: Vec<Definition<'db>>,
definitions_by_use: IndexVec<ScopedUseId, Definitions>,
/// builder state: currently visible definitions for each symbol
definitions_by_symbol: IndexVec<ScopedSymbolId, Definitions>,
}
impl<'db> UseDefMapBuilder<'db> {
pub(super) fn new() -> Self {
Self {
all_definitions: Vec::new(),
definitions_by_use: IndexVec::new(),
definitions_by_symbol: IndexVec::new(),
}
}
pub(super) fn add_symbol(&mut self, symbol: ScopedSymbolId) {
let new_symbol = self.definitions_by_symbol.push(Definitions::default());
debug_assert_eq!(symbol, new_symbol);
}
pub(super) fn record_definition(
&mut self,
symbol: ScopedSymbolId,
definition: Definition<'db>,
) {
let def_idx = self.all_definitions.len();
self.all_definitions.push(definition);
self.definitions_by_symbol[symbol] = Definitions {
#[allow(clippy::range_plus_one)]
definitions: def_idx..(def_idx + 1),
may_be_unbound: false,
};
}
pub(super) fn record_use(&mut self, symbol: ScopedSymbolId, use_id: ScopedUseId) {
let new_use = self
.definitions_by_use
.push(self.definitions_by_symbol[symbol].clone());
debug_assert_eq!(use_id, new_use);
}
pub(super) fn snapshot(&self) -> FlowSnapshot {
FlowSnapshot {
definitions_by_symbol: self.definitions_by_symbol.clone(),
}
}
pub(super) fn set(&mut self, state: &FlowSnapshot) {
let num_symbols = self.definitions_by_symbol.len();
self.definitions_by_symbol = state.definitions_by_symbol.clone();
self.definitions_by_symbol
.resize(num_symbols, Definitions::default());
}
pub(super) fn merge(&mut self, state: &FlowSnapshot) {
for (symbol_id, to_merge) in state.definitions_by_symbol.iter_enumerated() {
let current = &mut self.definitions_by_symbol[symbol_id];
// if the symbol can be unbound in either predecessor, it can be unbound
current.may_be_unbound |= to_merge.may_be_unbound;
// merge the definition ranges
if current.definitions == to_merge.definitions {
// ranges already identical, nothing to do!
} else if current.definitions.end == to_merge.definitions.start {
// ranges adjacent (current first), just merge them
current.definitions = (current.definitions.start)..(to_merge.definitions.end);
} else if current.definitions.start == to_merge.definitions.end {
// ranges adjacent (to_merge first), just merge them
current.definitions = (to_merge.definitions.start)..(current.definitions.end);
} else if current.definitions.end == self.all_definitions.len() {
// ranges not adjacent but current is at end, copy only to_merge
self.all_definitions
.extend_from_within(to_merge.definitions.clone());
current.definitions.end = self.all_definitions.len();
} else if to_merge.definitions.end == self.all_definitions.len() {
// ranges not adjacent but to_merge is at end, copy only current
self.all_definitions
.extend_from_within(current.definitions.clone());
current.definitions.start = to_merge.definitions.start;
current.definitions.end = self.all_definitions.len();
} else {
// ranges not adjacent and neither at end, must copy both
let start = self.all_definitions.len();
self.all_definitions
.extend_from_within(current.definitions.clone());
self.all_definitions
.extend_from_within(to_merge.definitions.clone());
current.definitions.start = start;
current.definitions.end = self.all_definitions.len();
}
}
}
pub(super) fn finish(mut self) -> UseDefMap<'db> {
self.all_definitions.shrink_to_fit();
self.definitions_by_symbol.shrink_to_fit();
self.definitions_by_use.shrink_to_fit();
UseDefMap {
all_definitions: self.all_definitions,
definitions_by_use: self.definitions_by_use,
public_definitions: self.definitions_by_symbol,
}
}
}

View file

@ -4,9 +4,8 @@ use ruff_python_ast as ast;
use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef};
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::symbol::PublicSymbolId;
use crate::semantic_index::{public_symbol, semantic_index};
use crate::types::{infer_types, public_symbol_ty, Type};
use crate::semantic_index::semantic_index;
use crate::types::{definition_ty, infer_scope_types, module_global_symbol_ty_by_name, Type};
use crate::Db;
pub struct SemanticModel<'db> {
@ -29,12 +28,8 @@ impl<'db> SemanticModel<'db> {
resolve_module(self.db.upcast(), module_name)
}
pub fn public_symbol(&self, module: &Module, symbol_name: &str) -> Option<PublicSymbolId<'db>> {
public_symbol(self.db, module.file(), symbol_name)
}
pub fn public_symbol_ty(&self, symbol: PublicSymbolId<'db>) -> Type {
public_symbol_ty(self.db, symbol)
pub fn module_global_symbol_ty(&self, module: &Module, symbol_name: &str) -> Type<'db> {
module_global_symbol_ty_by_name(self.db, module.file(), symbol_name)
}
}
@ -53,7 +48,7 @@ impl HasTy for ast::ExpressionRef<'_> {
let scope = file_scope.to_scope_id(model.db, model.file);
let expression_id = self.scoped_ast_id(model.db, scope);
infer_types(model.db, scope).expression_ty(expression_id)
infer_scope_types(model.db, scope).expression_ty(expression_id)
}
}
@ -145,11 +140,7 @@ impl HasTy for ast::StmtFunctionDef {
fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file);
let definition = index.definition(self);
let scope = definition.scope(model.db).to_scope_id(model.db, model.file);
let types = infer_types(model.db, scope);
types.definition_ty(definition)
definition_ty(model.db, definition)
}
}
@ -157,11 +148,7 @@ impl HasTy for StmtClassDef {
fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file);
let definition = index.definition(self);
let scope = definition.scope(model.db).to_scope_id(model.db, model.file);
let types = infer_types(model.db, scope);
types.definition_ty(definition)
definition_ty(model.db, definition)
}
}
@ -169,11 +156,7 @@ impl HasTy for ast::Alias {
fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file);
let definition = index.definition(self);
let scope = definition.scope(model.db).to_scope_id(model.db, model.file);
let types = infer_types(model.db, scope);
types.definition_ty(definition)
definition_ty(model.db, definition)
}
}

View file

@ -1,91 +1,92 @@
use ruff_db::files::File;
use ruff_db::parsed::parsed_module;
use ruff_python_ast::name::Name;
use crate::semantic_index::symbol::{NodeWithScopeKind, PublicSymbolId, ScopeId};
use crate::semantic_index::{public_symbol, root_scope, semantic_index, symbol_table};
use crate::types::infer::{TypeInference, TypeInferenceBuilder};
use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId};
use crate::semantic_index::{module_global_scope, symbol_table, use_def_map};
use crate::{Db, FxOrderSet};
mod display;
mod infer;
/// Infers the type of a public symbol.
///
/// This is a Salsa query to get symbol-level invalidation instead of file-level dependency invalidation.
/// Without this being a query, changing any public type of a module would invalidate the type inference
/// for the module scope of its dependents and the transitive dependents because.
///
/// For example if we have
/// ```python
/// # a.py
/// import x from b
///
/// # b.py
///
/// x = 20
/// ```
///
/// And x is now changed from `x = 20` to `x = 30`. The following happens:
///
/// * The module level types of `b.py` change because `x` now is a `Literal[30]`.
/// * The module level types of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
/// * The module level types of any dependents of `a.py` change because the imported symbol `x` now has a `Literal[30]` type
/// * And so on for all transitive dependencies.
///
/// This being a query ensures that the invalidation short-circuits if the type of this symbol didn't change.
#[salsa::tracked]
pub(crate) fn public_symbol_ty<'db>(db: &'db dyn Db, symbol: PublicSymbolId<'db>) -> Type<'db> {
let _span = tracing::trace_span!("public_symbol_ty", ?symbol).entered();
pub(crate) use self::infer::{infer_definition_types, infer_expression_types, infer_scope_types};
let file = symbol.file(db);
let scope = root_scope(db, file);
/// Infer the public type of a symbol (its type as seen from outside its scope).
pub(crate) fn symbol_ty<'db>(
db: &'db dyn Db,
scope: ScopeId<'db>,
symbol: ScopedSymbolId,
) -> Type<'db> {
let _span = tracing::trace_span!("symbol_ty", ?symbol).entered();
// TODO switch to inferring just the definition(s), not the whole scope
let inference = infer_types(db, scope);
inference.symbol_ty(symbol.scoped_symbol_id(db))
let use_def = use_def_map(db, scope);
definitions_ty(
db,
use_def.public_definitions(symbol),
use_def.public_may_be_unbound(symbol),
)
}
/// Shorthand for `public_symbol_ty` that takes a symbol name instead of a [`PublicSymbolId`].
pub(crate) fn public_symbol_ty_by_name<'db>(
/// Shorthand for `symbol_ty` that takes a symbol name instead of an ID.
pub(crate) fn symbol_ty_by_name<'db>(
db: &'db dyn Db,
scope: ScopeId<'db>,
name: &str,
) -> Type<'db> {
let table = symbol_table(db, scope);
table
.symbol_id_by_name(name)
.map(|symbol| symbol_ty(db, scope, symbol))
.unwrap_or(Type::Unbound)
}
/// Shorthand for `symbol_ty` that looks up a module-global symbol in a file.
pub(crate) fn module_global_symbol_ty_by_name<'db>(
db: &'db dyn Db,
file: File,
name: &str,
) -> Option<Type<'db>> {
let symbol = public_symbol(db, file, name)?;
Some(public_symbol_ty(db, symbol))
) -> Type<'db> {
symbol_ty_by_name(db, module_global_scope(db, file), name)
}
/// Infers all types for `scope`.
#[salsa::tracked(return_ref)]
pub(crate) fn infer_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInference<'db> {
let _span = tracing::trace_span!("infer_types", ?scope).entered();
/// Infer the type of a [`Definition`].
pub(crate) fn definition_ty<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Type<'db> {
let inference = infer_definition_types(db, definition);
inference.definition_ty(definition)
}
let file = scope.file(db);
// Using the index here is fine because the code below depends on the AST anyway.
// The isolation of the query is by the return inferred types.
let index = semantic_index(db, file);
/// Infer the combined type of an array of [`Definition`].
/// Will return a union if there are more than definition, or at least one plus the possibility of
/// Unbound.
pub(crate) fn definitions_ty<'db>(
db: &'db dyn Db,
definitions: &[Definition<'db>],
may_be_unbound: bool,
) -> Type<'db> {
let unbound_iter = if may_be_unbound {
[Type::Unbound].iter()
} else {
[].iter()
};
let def_types = definitions.iter().map(|def| definition_ty(db, *def));
let mut all_types = unbound_iter.copied().chain(def_types);
let node = scope.node(db);
let Some(first) = all_types.next() else {
return Type::Unbound;
};
let mut context = TypeInferenceBuilder::new(db, scope, index);
if let Some(second) = all_types.next() {
let mut builder = UnionTypeBuilder::new(db);
builder = builder.add(first).add(second);
match node {
NodeWithScopeKind::Module => {
let parsed = parsed_module(db.upcast(), file);
context.infer_module(parsed.syntax());
}
NodeWithScopeKind::Function(function) => context.infer_function_body(function.node()),
NodeWithScopeKind::Class(class) => context.infer_class_body(class.node()),
NodeWithScopeKind::ClassTypeParameters(class) => {
context.infer_class_type_params(class.node());
}
NodeWithScopeKind::FunctionTypeParameters(function) => {
context.infer_function_type_params(function.node());
for variant in all_types {
builder = builder.add(variant);
}
Type::Union(builder.build())
} else {
first
}
context.finish()
}
/// unique ID for a type
@ -96,9 +97,10 @@ pub enum Type<'db> {
/// the empty set of values
Never,
/// unknown type (no annotation)
/// equivalent to Any, or to object in strict mode
/// equivalent to Any, or possibly to object in strict mode
Unknown,
/// name is not bound to any value
/// name does not exist or is not bound to any value (this represents an error, but with some
/// leniency options it could be silently resolved to Unknown in some cases)
Unbound,
/// the None object (TODO remove this in favor of Instance(types.NoneType)
None,
@ -125,15 +127,16 @@ impl<'db> Type<'db> {
matches!(self, Type::Unknown)
}
pub fn member(&self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
#[must_use]
pub fn member(&self, db: &'db dyn Db, name: &Name) -> Type<'db> {
match self {
Type::Any => Some(Type::Any),
Type::Any => Type::Any,
Type::Never => todo!("attribute lookup on Never type"),
Type::Unknown => Some(Type::Unknown),
Type::Unbound => todo!("attribute lookup on Unbound type"),
Type::Unknown => Type::Unknown,
Type::Unbound => Type::Unbound,
Type::None => todo!("attribute lookup on None type"),
Type::Function(_) => todo!("attribute lookup on Function type"),
Type::Module(file) => public_symbol_ty_by_name(db, *file, name),
Type::Module(file) => module_global_symbol_ty_by_name(db, *file, name),
Type::Class(class) => class.class_member(db, name),
Type::Instance(_) => {
// TODO MRO? get_own_instance_member, get_instance_member
@ -152,7 +155,7 @@ impl<'db> Type<'db> {
}
Type::IntLiteral(_) => {
// TODO raise error
Some(Type::Unknown)
Type::Unknown
}
}
}
@ -188,32 +191,30 @@ impl<'db> ClassType<'db> {
/// Returns the class member of this class named `name`.
///
/// The member resolves to a member of the class itself or any of its bases.
pub fn class_member(self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
if let Some(member) = self.own_class_member(db, name) {
return Some(member);
pub fn class_member(self, db: &'db dyn Db, name: &Name) -> Type<'db> {
let member = self.own_class_member(db, name);
if !member.is_unbound() {
return member;
}
self.inherited_class_member(db, name)
}
/// Returns the inferred type of the class member named `name`.
pub fn own_class_member(self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
pub fn own_class_member(self, db: &'db dyn Db, name: &Name) -> Type<'db> {
let scope = self.body_scope(db);
let symbols = symbol_table(db, scope);
let symbol = symbols.symbol_id_by_name(name)?;
let types = infer_types(db, scope);
Some(types.symbol_ty(symbol))
symbol_ty_by_name(db, scope, name)
}
pub fn inherited_class_member(self, db: &'db dyn Db, name: &Name) -> Option<Type<'db>> {
pub fn inherited_class_member(self, db: &'db dyn Db, name: &Name) -> Type<'db> {
for base in self.bases(db) {
if let Some(member) = base.member(db, name) {
return Some(member);
let member = base.member(db, name);
if !member.is_unbound() {
return member;
}
}
None
Type::Unbound
}
}
@ -268,165 +269,3 @@ pub struct IntersectionType<'db> {
// the intersection type does not include any value in any of these types
negative: FxOrderSet<Type<'db>>,
}
#[cfg(test)]
mod tests {
use red_knot_module_resolver::{
set_module_resolution_settings, RawModuleResolutionSettings, TargetVersion,
};
use ruff_db::files::system_path_to_file;
use ruff_db::parsed::parsed_module;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
use ruff_db::testing::{assert_function_query_was_not_run, assert_function_query_was_run};
use crate::db::tests::TestDb;
use crate::semantic_index::root_scope;
use crate::types::{infer_types, public_symbol_ty_by_name};
use crate::{HasTy, SemanticModel};
fn setup_db() -> TestDb {
let mut db = TestDb::new();
set_module_resolution_settings(
&mut db,
RawModuleResolutionSettings {
target_version: TargetVersion::Py38,
extra_paths: vec![],
workspace_root: SystemPathBuf::from("/src"),
site_packages: None,
custom_typeshed: None,
},
);
db
}
#[test]
fn local_inference() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_file("/src/a.py", "x = 10")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let parsed = parsed_module(&db, a);
let statement = parsed.suite().first().unwrap().as_assign_stmt().unwrap();
let model = SemanticModel::new(&db, a);
let literal_ty = statement.value.ty(&model);
assert_eq!(format!("{}", literal_ty.display(&db)), "Literal[10]");
Ok(())
}
#[test]
fn dependency_public_symbol_type_change() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_files([
("/src/a.py", "from foo import x"),
("/src/foo.py", "x = 10\ndef foo(): ..."),
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
// Change `x` to a different value
db.write_file("/src/foo.py", "x = 20\ndef foo(): ...")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(x_ty_2.display(&db).to_string(), "Literal[20]");
let events = db.take_salsa_events();
let a_root_scope = root_scope(&db, a);
assert_function_query_was_run::<infer_types, _, _>(
&db,
|ty| &ty.function,
&a_root_scope,
&events,
);
Ok(())
}
#[test]
fn dependency_non_public_symbol_change() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_files([
("/src/a.py", "from foo import x"),
("/src/foo.py", "x = 10\ndef foo(): y = 1"),
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
db.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]");
let events = db.take_salsa_events();
let a_root_scope = root_scope(&db, a);
assert_function_query_was_not_run::<infer_types, _, _>(
&db,
|ty| &ty.function,
&a_root_scope,
&events,
);
Ok(())
}
#[test]
fn dependency_unrelated_public_symbol() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_files([
("/src/a.py", "from foo import x"),
("/src/foo.py", "x = 10\ny = 20"),
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
db.write_file("/src/foo.py", "x = 10\ny = 30")?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
db.clear_salsa_events();
let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap();
assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]");
let events = db.take_salsa_events();
let a_root_scope = root_scope(&db, a);
assert_function_query_was_not_run::<infer_types, _, _>(
&db,
|ty| &ty.function,
&a_root_scope,
&events,
);
Ok(())
}
}

File diff suppressed because it is too large Load diff