red-knot: Symbol table (#11860)

This commit is contained in:
Micha Reiser 2024-06-18 14:10:45 +01:00 committed by GitHub
parent 26ac805e6d
commit f666d79cd7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 2153 additions and 10 deletions

View file

@ -21,9 +21,11 @@ ruff_text_size = { workspace = true }
bitflags = { workspace = true }
is-macro = { workspace = true }
salsa = { workspace = true, optional = true }
smol_str = { workspace = true, optional = true }
smallvec = { workspace = true, optional = true }
smol_str = { workspace = true }
tracing = { workspace = true, optional = true }
rustc-hash = { workspace = true }
hashbrown = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -34,4 +36,4 @@ tempfile = { workspace = true }
workspace = true
[features]
red_knot = ["dep:salsa", "dep:smol_str", "dep:tracing"]
red_knot = ["dep:salsa", "dep:tracing", "dep:hashbrown", "dep:smallvec"]

View file

@ -1,16 +1,25 @@
use salsa::DbWithJar;
use ruff_db::{Db as SourceDb, Upcast};
use crate::module::resolver::{
file_to_module, internal::ModuleNameIngredient, internal::ModuleResolverSearchPaths,
resolve_module_query,
};
use ruff_db::{Db as SourceDb, Upcast};
use salsa::DbWithJar;
use crate::red_knot::semantic_index::symbol::ScopeId;
use crate::red_knot::semantic_index::{scopes_map, semantic_index, symbol_table};
/// Salsa jar of this crate.
///
/// Registers the module-resolver items imported from `crate::module::resolver`
/// together with the semantic-index items imported from
/// `crate::red_knot::semantic_index` (`ScopeId`, `symbol_table`, `scopes_map`,
/// `semantic_index`).
#[salsa::jar(db=Db)]
pub struct Jar(
ModuleNameIngredient,
ModuleResolverSearchPaths,
ScopeId,
symbol_table,
resolve_module_query,
file_to_module,
scopes_map,
semantic_index,
);
/// Database giving access to semantic information about a Python program.
@ -18,12 +27,15 @@ pub trait Db: SourceDb + DbWithJar<Jar> + Upcast<dyn SourceDb> {}
#[cfg(test)]
pub(crate) mod tests {
use super::{Db, Jar};
use std::sync::Arc;
use salsa::DebugWithDb;
use ruff_db::file_system::{FileSystem, MemoryFileSystem, OsFileSystem};
use ruff_db::vfs::Vfs;
use ruff_db::{Db as SourceDb, Jar as SourceJar, Upcast};
use salsa::DebugWithDb;
use std::sync::Arc;
use super::{Db, Jar};
#[salsa::db(Jar, SourceJar)]
pub(crate) struct TestDb {

View file

@ -9,7 +9,10 @@ mod globals;
mod model;
#[cfg(feature = "red_knot")]
pub mod module;
pub mod name;
mod nodes;
#[cfg(feature = "red_knot")]
pub mod red_knot;
mod reference;
mod scope;
mod star_import;

View file

@ -0,0 +1,56 @@
use std::ops::Deref;
/// A name stored as a [`smol_str::SmolStr`].
///
/// Equality and hashing are derived from the wrapped string. The type also
/// dereferences to `str` and can be compared directly against `str`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Name(smol_str::SmolStr);
impl Name {
    /// Creates a name from `name`.
    #[inline]
    pub fn new(name: &str) -> Self {
        Self(smol_str::SmolStr::new(name))
    }

    /// Creates a name from a string with `'static` lifetime, using
    /// [`smol_str::SmolStr::new_static`].
    #[inline]
    pub fn new_static(name: &'static str) -> Self {
        Self(smol_str::SmolStr::new_static(name))
    }

    /// Returns the name as a string slice.
    // Marked `#[inline]` like the constructors: this accessor is a trivial
    // forwarder that `Deref`, `Display` and the `PartialEq` impls all call.
    #[inline]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl Deref for Name {
type Target = str;
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
/// Builds a [`Name`] from anything convertible into a [`smol_str::SmolStr`].
impl<T: Into<smol_str::SmolStr>> From<T> for Name {
    fn from(value: T) -> Self {
        let inner: smol_str::SmolStr = value.into();
        Self(inner)
    }
}
/// Formats the name exactly like the underlying `str`.
impl std::fmt::Display for Name {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to `str`'s `Display` instead of `write_str` so that width,
        // fill, alignment and precision flags (e.g. `{name:>10}`) are honoured.
        std::fmt::Display::fmt(self.as_str(), f)
    }
}
impl PartialEq<str> for Name {
fn eq(&self, other: &str) -> bool {
self.as_str() == other
}
}
impl PartialEq<Name> for str {
fn eq(&self, other: &Name) -> bool {
other == self
}
}

View file

@ -0,0 +1,162 @@
use std::hash::Hash;
use std::ops::Deref;
use ruff_db::parsed::ParsedModule;
/// Ref-counted owned reference to an AST node.
///
/// The type holds an owned reference to the node's ref-counted [`ParsedModule`].
/// Holding on to the node's [`ParsedModule`] guarantees that the reference to the
/// node must still be valid.
///
/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModule`] from being released.
///
/// ## Equality
/// Two `AstNodeRef` are considered equal if their wrapped nodes are equal.
#[derive(Clone)]
pub struct AstNodeRef<T> {
/// Owned reference to the node's [`ParsedModule`].
///
/// The node's reference is guaranteed to remain valid as long as its enclosing
/// [`ParsedModule`] is alive. The field is never read; it is only stored to keep
/// the module alive.
_parsed: ParsedModule,
/// Pointer to the referenced node.
node: std::ptr::NonNull<T>,
}
// `unsafe` is required here: the node is stored as a raw pointer and
// dereferenced in `node`.
#[allow(unsafe_code)]
impl<T> AstNodeRef<T> {
/// Creates a new `AstNodeRef` that references `node`. The `parsed` is the [`ParsedModule`] to which
/// the `AstNodeRef` belongs.
///
/// ## Safety
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the [`ParsedModule`] to
/// which `node` belongs. It's the caller's responsibility to ensure that the invariant `node belongs to parsed` is upheld.
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
Self {
_parsed: parsed,
node: std::ptr::NonNull::from(node),
}
}
/// Returns a reference to the wrapped node.
///
/// The returned reference borrows from `self`.
pub fn node(&self) -> &T {
// SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still alive
// and not moved.
unsafe { self.node.as_ref() }
}
}
impl<T> Deref for AstNodeRef<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.node()
}
}
/// Prints as `AstNodeRef(<node>)`; the owning parsed module is omitted.
impl<T: std::fmt::Debug> std::fmt::Debug for AstNodeRef<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("AstNodeRef");
        tuple.field(&self.node());
        tuple.finish()
    }
}
/// Two references are equal when the nodes they point to compare equal.
impl<T: PartialEq> PartialEq for AstNodeRef<T> {
    fn eq(&self, other: &Self) -> bool {
        self.node() == other.node()
    }
}

impl<T: Eq> Eq for AstNodeRef<T> {}

/// Hashes the wrapped node only, which keeps `Hash` in line with `PartialEq`.
impl<T: Hash> Hash for AstNodeRef<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        Hash::hash(self.node(), state);
    }
}
// `NonNull<T>` is neither `Send` nor `Sync`, so both marker traits are
// implemented manually.
//
// NOTE(review): `AstNodeRef` only hands out shared `&T` access (see `node`) and
// clones alias the same node. For plain shared references the standard library
// requires `T: Sync` for `Send`; confirm that `T: Send` alone is sufficient for
// every `T` this type is instantiated with.
#[allow(unsafe_code)]
unsafe impl<T> Send for AstNodeRef<T> where T: Send {}
// Sharing an `AstNodeRef` between threads only shares `&T`, hence the `T: Sync` bound.
#[allow(unsafe_code)]
unsafe impl<T> Sync for AstNodeRef<T> where T: Sync {}
#[cfg(test)]
mod tests {
use crate::red_knot::ast_node_ref::AstNodeRef;
use ruff_db::parsed::ParsedModule;
use ruff_python_ast::PySourceType;
use ruff_python_parser::parse_unchecked_source;
/// Equality is decided by the wrapped nodes, not by which tree they live in.
#[test]
#[allow(unsafe_code)]
fn equality() {
    let raw = parse_unchecked_source("1 + 2", PySourceType::Python);
    let module = ParsedModule::new(raw.clone());
    let first_stmt = &module.syntax().body[0];

    // Two references to the very same node.
    let node1 = unsafe { AstNodeRef::new(module.clone(), first_stmt) };
    let node2 = unsafe { AstNodeRef::new(module.clone(), first_stmt) };
    assert_eq!(node1, node2);

    // Same source, but a separately wrapped tree.
    let module_copy = ParsedModule::new(raw);
    let copy_stmt = &module_copy.syntax().body[0];
    let node_in_copy = unsafe { AstNodeRef::new(module_copy.clone(), copy_stmt) };
    assert_eq!(node1, node_in_copy);

    // Different source yields a different node.
    let different_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
    let different_module = ParsedModule::new(different_raw);
    let different_stmt = &different_module.syntax().body[0];
    let different_node =
        unsafe { AstNodeRef::new(different_module.clone(), different_stmt) };
    assert_ne!(node1, different_node);
}
/// Nodes parsed from different sources must not compare equal.
#[allow(unsafe_code)]
#[test]
fn inequality() {
    let left_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
    let left_module = ParsedModule::new(left_raw.clone());
    let left_stmt = &left_module.syntax().body[0];
    let left = unsafe { AstNodeRef::new(left_module.clone(), left_stmt) };

    let right_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
    let right_module = ParsedModule::new(right_raw);
    let right_stmt = &right_module.syntax().body[0];
    let right = unsafe { AstNodeRef::new(right_module.clone(), right_stmt) };

    assert_ne!(left, right);
}
/// The `Debug` output wraps the node's own `Debug` output in `AstNodeRef(..)`.
#[test]
#[allow(unsafe_code)]
fn debug() {
    let raw = parse_unchecked_source("1 + 2", PySourceType::Python);
    let module = ParsedModule::new(raw.clone());
    let stmt = &module.syntax().body[0];
    let node_ref = unsafe { AstNodeRef::new(module.clone(), stmt) };

    let expected = format!("AstNodeRef({stmt:?})");
    let actual = format!("{node_ref:?}");
    assert_eq!(actual, expected);
}
}

View file

@ -0,0 +1,3 @@
pub mod ast_node_ref;
mod node_key;
pub mod semantic_index;

View file

@ -0,0 +1,24 @@
use ruff_python_ast::{AnyNodeRef, NodeKind};
use ruff_text_size::{Ranged, TextRange};
/// Compact key for a node for use in a hash map.
///
/// Compares two nodes by their kind and text range. The key does not borrow
/// from the node it was created from.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub(super) struct NodeKey {
// Kind of the node the key was created from.
kind: NodeKind,
// Source range of the node the key was created from.
range: TextRange,
}
impl NodeKey {
    /// Builds the key for `node` from its kind and its text range.
    pub(super) fn from_node<'a, N>(node: N) -> Self
    where
        N: Into<AnyNodeRef<'a>>,
    {
        let any_ref: AnyNodeRef<'a> = node.into();
        let kind = any_ref.kind();
        let range = any_ref.range();
        Self { kind, range }
    }
}

View file

@ -0,0 +1,655 @@
use std::iter::FusedIterator;
use std::sync::Arc;
use rustc_hash::FxHashMap;
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::VfsFile;
use ruff_index::{IndexSlice, IndexVec};
use ruff_python_ast as ast;
use crate::red_knot::node_key::NodeKey;
use crate::red_knot::semantic_index::ast_ids::AstIds;
use crate::red_knot::semantic_index::builder::SemanticIndexBuilder;
use crate::red_knot::semantic_index::symbol::{
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeSymbolId, ScopesMap, SymbolTable,
};
use crate::Db;
pub mod ast_ids;
mod builder;
pub mod definition;
pub mod symbol;
// Map keyed by `ScopeSymbolId` with unit values and `()` as the hasher parameter.
// NOTE(review): presumably hashes are supplied by the caller (raw-entry style),
// since `()` cannot build a hasher — confirm against the symbol-table code.
type SymbolMap = hashbrown::HashMap<ScopeSymbolId, (), ()>;
/// Builds the semantic index of `file` from its parsed module.
///
/// When only the symbols of a single scope are needed, prefer [`symbol_table`].
#[salsa::tracked(return_ref, no_eq)]
pub(crate) fn semantic_index(db: &dyn Db, file: VfsFile) -> SemanticIndex {
    let builder = SemanticIndexBuilder::new(parsed_module(db.upcast(), file));
    builder.build()
}
/// Returns the symbol table of a single `scope`.
///
/// Compared to going through [`semantic_index`], this query lets Salsa skip
/// invalidating dependent queries as long as this scope's symbol table stays
/// the same.
#[salsa::tracked]
pub(crate) fn symbol_table(db: &dyn Db, scope: ScopeId) -> Arc<SymbolTable> {
    let file = scope.file(db);
    semantic_index(db, file).symbol_table(scope.scope_id(db))
}
/// Maps every file-local [`FileScopeId`] of `file` to a program-wide unique [`ScopeId`].
#[salsa::tracked(return_ref)]
pub(crate) fn scopes_map(db: &dyn Db, file: VfsFile) -> ScopesMap {
    let file_scope_ids = semantic_index(db, file).scopes.indices();

    // One `ScopeId` per scope, in the same order as the file-local ids.
    let scope_ids: IndexVec<_, _> = file_scope_ids
        .map(|file_scope_id| ScopeId::new(db, file, file_scope_id))
        .collect();

    ScopesMap::new(scope_ids)
}
/// Returns the [`ScopeId`] of `file`'s root scope.
pub fn root_scope(db: &dyn Db, file: VfsFile) -> ScopeId {
    let root = FileScopeId::root();
    root.to_scope_id(db, file)
}
/// Looks up `name` in the root scope of `file`.
///
/// Returns `None` if the root scope has no symbol with that name.
pub fn global_symbol(db: &dyn Db, file: VfsFile, name: &str) -> Option<PublicSymbolId> {
    root_scope(db, file).symbol(db, name)
}
/// The symbol tables for an entire file.
///
/// Besides the per-scope symbol tables, the index stores the scopes themselves,
/// the scope of each recorded expression, and the per-scope AST ids.
#[derive(Debug)]
pub struct SemanticIndex {
/// List of all symbol tables in this file, indexed by scope.
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable>>,
/// List of all scopes in this file.
scopes: IndexVec<FileScopeId, Scope>,
/// Maps expressions to their corresponding scope.
/// We can't use [`ExpressionId`] here, because the challenge is how to get from
/// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope).
expression_scopes: FxHashMap<NodeKey, FileScopeId>,
/// Lookup table to map between node ids and ast nodes.
///
/// Note: We should not depend on this map when analysing other files, or
/// changing a file invalidates all dependents.
ast_ids: IndexVec<FileScopeId, AstIds>,
}
impl SemanticIndex {
    /// Hands out a shared handle to the symbol table of `scope_id`.
    ///
    /// Prefer the Salsa cached [`symbol_table`] query when only the symbol
    /// table of a single scope is needed.
    fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
        Arc::clone(&self.symbol_tables[scope_id])
    }

    /// Returns the AST ids recorded for `scope_id`.
    pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds {
        let ids = &self.ast_ids[scope_id];
        ids
    }

    /// Returns the id of the scope that encloses `expression`.
    ///
    /// Panics if no scope was recorded for `expression`.
    #[allow(unused)]
    pub(crate) fn expression_scope_id(&self, expression: &ast::Expr) -> FileScopeId {
        let key = NodeKey::from_node(expression);
        self.expression_scopes[&key]
    }

    /// Returns the [`Scope`] that encloses `expression`.
    #[allow(unused)]
    pub(crate) fn expression_scope(&self, expression: &ast::Expr) -> &Scope {
        let id = self.expression_scope_id(expression);
        &self.scopes[id]
    }

    /// Returns the [`Scope`] stored under `id`.
    #[allow(unused)]
    pub(crate) fn scope(&self, id: FileScopeId) -> &Scope {
        &self.scopes[id]
    }

    /// Returns the id of `scope_id`'s parent scope, if it has one.
    pub(crate) fn parent_scope_id(&self, scope_id: FileScopeId) -> Option<FileScopeId> {
        self.scope(scope_id).parent
    }

    /// Returns the parent [`Scope`] of `scope_id`, if it has one.
    #[allow(unused)]
    pub(crate) fn parent_scope(&self, scope_id: FileScopeId) -> Option<&Scope> {
        self.parent_scope_id(scope_id)
            .map(|parent_id| &self.scopes[parent_id])
    }

    /// Iterates over all scopes nested (directly or indirectly) in `scope`.
    #[allow(unused)]
    pub(crate) fn descendent_scopes(&self, scope: FileScopeId) -> DescendentsIter {
        DescendentsIter::new(self, scope)
    }

    /// Iterates over the scopes whose parent is `scope`.
    #[allow(unused)]
    pub(crate) fn child_scopes(&self, scope: FileScopeId) -> ChildrenIter {
        ChildrenIter::new(self, scope)
    }

    /// Iterates from `scope` itself up through all of its ancestors.
    #[allow(unused)]
    pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter {
        AncestorsIter::new(self, scope)
    }
}
/// Iterator over a scope and all of its ancestors, following each scope's `parent` link.
///
/// Yields `(FileScopeId, &Scope)` pairs, starting with the scope the iterator was created for.
pub struct AncestorsIter<'a> {
// All scopes of the file; indexed with the ids stored in `next_id`.
scopes: &'a IndexSlice<FileScopeId, Scope>,
// Scope to yield next; `None` once a scope without a parent has been yielded.
next_id: Option<FileScopeId>,
}
impl<'a> AncestorsIter<'a> {
    /// Creates an iterator whose first item is `start` itself.
    fn new(module_symbol_table: &'a SemanticIndex, start: FileScopeId) -> Self {
        let scopes = &module_symbol_table.scopes;
        let next_id = Some(start);
        Self { scopes, next_id }
    }
}
impl<'a> Iterator for AncestorsIter<'a> {
    type Item = (FileScopeId, &'a Scope);

    /// Yields the pending scope and queues its parent for the next call.
    fn next(&mut self) -> Option<Self::Item> {
        let id = self.next_id?;
        let scope = &self.scopes[id];
        // A scope without a parent stores `None` here, which ends the iteration.
        self.next_id = scope.parent;
        Some((id, scope))
    }
}

// Once `next_id` is `None`, `next` keeps returning `None`.
impl FusedIterator for AncestorsIter<'_> {}
/// Iterator over the descendent scopes of a scope.
///
/// Yields `(FileScopeId, &Scope)` pairs; ids are assigned by counting up from
/// the id following the start scope.
pub struct DescendentsIter<'a> {
// Id that belongs to the next scope yielded by `descendents`.
next_id: FileScopeId,
// Remaining descendent scopes.
descendents: std::slice::Iter<'a, Scope>,
}
impl<'a> DescendentsIter<'a> {
    /// Creates an iterator over the scopes in `scope_id`'s `descendents` range.
    fn new(symbol_table: &'a SemanticIndex, scope_id: FileScopeId) -> Self {
        let all_scopes = &symbol_table.scopes;
        let range = all_scopes[scope_id].descendents.clone();
        let nested = &all_scopes[range];

        Self {
            // The first yielded scope gets the id right after `scope_id`.
            next_id: scope_id + 1,
            descendents: nested.iter(),
        }
    }
}
impl<'a> Iterator for DescendentsIter<'a> {
    type Item = (FileScopeId, &'a Scope);

    /// Yields the next descendent together with its id and advances the id counter.
    fn next(&mut self) -> Option<Self::Item> {
        let scope = self.descendents.next()?;
        let current = self.next_id;
        self.next_id = current + 1;
        Some((current, scope))
    }

    /// Forwards the hint of the underlying slice iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.descendents.size_hint()
    }
}

impl FusedIterator for DescendentsIter<'_> {}

impl ExactSizeIterator for DescendentsIter<'_> {}
/// Iterator over the direct child scopes of a scope.
///
/// Walks the descendents of `parent` and keeps those whose `parent` field is `parent`.
pub struct ChildrenIter<'a> {
// Scope whose children are yielded.
parent: FileScopeId,
// All descendents of `parent`; filtered in `next`.
descendents: DescendentsIter<'a>,
}
impl<'a> ChildrenIter<'a> {
    /// Creates an iterator over the direct children of `parent`.
    fn new(module_symbol_table: &'a SemanticIndex, parent: FileScopeId) -> Self {
        Self {
            parent,
            descendents: DescendentsIter::new(module_symbol_table, parent),
        }
    }
}
impl<'a> Iterator for ChildrenIter<'a> {
    type Item = (FileScopeId, &'a Scope);

    /// Advances to the next descendent whose parent is `self.parent`.
    fn next(&mut self) -> Option<Self::Item> {
        let wanted_parent = Some(self.parent);
        self.descendents
            .find(|(_, candidate)| candidate.parent == wanted_parent)
    }
}

impl FusedIterator for ChildrenIter<'_> {}
#[cfg(test)]
mod tests {
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::{system_path_to_file, VfsFile};
use crate::db::tests::TestDb;
use crate::red_knot::semantic_index::symbol::{FileScopeId, ScopeKind, SymbolTable};
use crate::red_knot::semantic_index::{root_scope, semantic_index, symbol_table};
// Test fixture: a database together with the file written by `test_case`.
struct TestCase {
db: TestDb,
file: VfsFile,
}
/// Writes `content` to `test.py` in a fresh in-memory file system and returns
/// the database together with the resulting file.
fn test_case(content: impl ToString) -> TestCase {
    let db = TestDb::new();

    let file_system = db.memory_file_system();
    file_system.write_file("test.py", content).unwrap();

    let file = system_path_to_file(&db, "test.py").unwrap();
    TestCase { db, file }
}
/// Collects the names of all symbols in `table`, in iteration order.
fn names(table: &SymbolTable) -> Vec<&str> {
    let mut collected = Vec::new();
    for symbol in table.symbols() {
        collected.push(symbol.name().as_str());
    }
    collected
}
/// An empty file has no symbols in its root scope.
#[test]
fn empty() {
    let TestCase { db, file } = test_case("");
    let scope = root_scope(&db, file);
    let table = symbol_table(&db, scope);

    assert_eq!(names(&table), Vec::<&str>::new());
}
/// A bare name expression registers a symbol.
#[test]
fn simple() {
    let TestCase { db, file } = test_case("x");
    let scope = root_scope(&db, file);
    let table = symbol_table(&db, scope);

    assert_eq!(names(&table), vec!["x"]);
}
/// An annotation without a value registers both the annotation and the target.
#[test]
fn annotation_only() {
    let TestCase { db, file } = test_case("x: int");
    let scope = root_scope(&db, file);
    let table = symbol_table(&db, scope);

    assert_eq!(names(&table), vec!["int", "x"]);
    // TODO record definition
}
/// `import foo` defines `foo` exactly once.
#[test]
fn import() {
    let TestCase { db, file } = test_case("import foo");
    let scope = root_scope(&db, file);
    let table = symbol_table(&db, scope);
    assert_eq!(names(&table), vec!["foo"]);

    let definition_count = table.symbol_by_name("foo").unwrap().definitions().len();
    assert_eq!(definition_count, 1);
}
/// Importing a submodule binds only the top-level package name.
#[test]
fn import_sub() {
    let TestCase { db, file } = test_case("import foo.bar");
    let scope = root_scope(&db, file);
    let table = symbol_table(&db, scope);

    assert_eq!(names(&table), vec!["foo"]);
}
/// An aliased import binds only the alias.
#[test]
fn import_as() {
    let TestCase { db, file } = test_case("import foo.bar as baz");
    let scope = root_scope(&db, file);
    let table = symbol_table(&db, scope);

    assert_eq!(names(&table), vec!["baz"]);
}
/// `from bar import foo` defines `foo` once and flags it as defined but not used.
#[test]
fn import_from() {
    let TestCase { db, file } = test_case("from bar import foo");
    let root_table = symbol_table(&db, root_scope(&db, file));
    assert_eq!(names(&root_table), vec!["foo"]);

    let foo = root_table
        .symbol_by_name("foo")
        .expect("foo symbol should exist");
    assert_eq!(foo.definitions().len(), 1);

    // Both flags must hold. The previous `||` let the assertion pass for any
    // unused symbol, even one that was never flagged as defined.
    assert!(
        foo.is_defined() && !foo.is_used(),
        "symbols that are defined get the defined flag"
    );
}
/// In `x = foo`, `x` is defined once and `foo` is only used.
#[test]
fn assign() {
    let TestCase { db, file } = test_case("x = foo");
    let scope = root_scope(&db, file);
    let table = symbol_table(&db, scope);
    assert_eq!(names(&table), vec!["foo", "x"]);

    let x_definitions = table.symbol_by_name("x").unwrap().definitions().len();
    assert_eq!(x_definitions, 1);

    let foo_is_use_only = table
        .symbol_by_name("foo")
        .is_some_and(|symbol| !symbol.is_defined() && symbol.is_used());
    assert!(
        foo_is_use_only,
        "a symbol used but not defined in a scope should have only the used flag"
    );
}
// Checks that a class body gets its own scope and symbol table.
#[test]
fn class_scope() {
let TestCase { db, file } = test_case(
"
class C:
x = 1
y = 2
",
);
// The root scope is expected to contain only the class and `y`.
let root_table = symbol_table(&db, root_scope(&db, file));
assert_eq!(names(&root_table), vec!["C", "y"]);
let index = semantic_index(&db, file);
// Exactly one child scope: the class body, named after the class.
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(scopes.len(), 1);
let (class_scope_id, class_scope) = scopes[0];
assert_eq!(class_scope.kind(), ScopeKind::Class);
assert_eq!(class_scope.name(), "C");
// `x` is expected in the class scope, with a single definition.
let class_table = index.symbol_table(class_scope_id);
assert_eq!(names(&class_table), vec!["x"]);
assert_eq!(
class_table.symbol_by_name("x").unwrap().definitions().len(),
1
);
}
// Checks that a function body gets its own scope and symbol table.
#[test]
fn function_scope() {
let TestCase { db, file } = test_case(
"
def func():
x = 1
y = 2
",
);
let index = semantic_index(&db, file);
// The root scope is expected to contain only the function and `y`.
let root_table = index.symbol_table(FileScopeId::root());
assert_eq!(names(&root_table), vec!["func", "y"]);
// Exactly one child scope: the function body, named after the function.
let scopes = index.child_scopes(FileScopeId::root()).collect::<Vec<_>>();
assert_eq!(scopes.len(), 1);
let (function_scope_id, function_scope) = scopes[0];
assert_eq!(function_scope.kind(), ScopeKind::Function);
assert_eq!(function_scope.name(), "func");
// `x` is expected in the function scope, with a single definition.
let function_table = index.symbol_table(function_scope_id);
assert_eq!(names(&function_table), vec!["x"]);
assert_eq!(
function_table
.symbol_by_name("x")
.unwrap()
.definitions()
.len(),
1
);
}
// Checks that two functions with the same name share one root symbol (with two
// definitions) but get separate scopes.
#[test]
fn dupes() {
let TestCase { db, file } = test_case(
"
def func():
x = 1
def func():
y = 2
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
assert_eq!(names(&root_table), vec!["func"]);
// One child scope per `def`, both named `func`.
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(scopes.len(), 2);
let (func_scope1_id, func_scope_1) = scopes[0];
let (func_scope2_id, func_scope_2) = scopes[1];
assert_eq!(func_scope_1.kind(), ScopeKind::Function);
assert_eq!(func_scope_1.name(), "func");
assert_eq!(func_scope_2.kind(), ScopeKind::Function);
assert_eq!(func_scope_2.name(), "func");
// Each function scope holds only its own local.
let func1_table = index.symbol_table(func_scope1_id);
let func2_table = index.symbol_table(func_scope2_id);
assert_eq!(names(&func1_table), vec!["x"]);
assert_eq!(names(&func2_table), vec!["y"]);
// The shared root symbol records both definitions.
assert_eq!(
root_table
.symbol_by_name("func")
.unwrap()
.definitions()
.len(),
2
);
}
// Checks that a generic function introduces an annotation scope holding the
// type parameter, with the function scope nested inside it.
#[test]
fn generic_function() {
let TestCase { db, file } = test_case(
"
def func[T]():
x = 1
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
assert_eq!(names(&root_table), vec!["func"]);
// The only child of the root scope is the annotation scope.
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(scopes.len(), 1);
let (ann_scope_id, ann_scope) = scopes[0];
assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
assert_eq!(ann_scope.name(), "func");
let ann_table = index.symbol_table(ann_scope_id);
assert_eq!(names(&ann_table), vec!["T"]);
// The function scope is the only child of the annotation scope.
let scopes: Vec<_> = index.child_scopes(ann_scope_id).collect();
assert_eq!(scopes.len(), 1);
let (func_scope_id, func_scope) = scopes[0];
assert_eq!(func_scope.kind(), ScopeKind::Function);
assert_eq!(func_scope.name(), "func");
let func_table = index.symbol_table(func_scope_id);
assert_eq!(names(&func_table), vec!["x"]);
}
// Checks that a generic class introduces an annotation scope holding the type
// parameter, with the class scope nested inside it.
#[test]
fn generic_class() {
let TestCase { db, file } = test_case(
"
class C[T]:
x = 1
",
);
let index = semantic_index(&db, file);
let root_table = index.symbol_table(FileScopeId::root());
assert_eq!(names(&root_table), vec!["C"]);
// The only child of the root scope is the annotation scope.
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(scopes.len(), 1);
let (ann_scope_id, ann_scope) = scopes[0];
assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
assert_eq!(ann_scope.name(), "C");
let ann_table = index.symbol_table(ann_scope_id);
assert_eq!(names(&ann_table), vec!["T"]);
assert!(
ann_table
.symbol_by_name("T")
.is_some_and(|s| s.is_defined() && !s.is_used()),
"type parameters are defined by the scope that introduces them"
);
// The class scope is the only child of the annotation scope (the bindings are
// named `func_scope*` but hold the class scope).
let scopes: Vec<_> = index.child_scopes(ann_scope_id).collect();
assert_eq!(scopes.len(), 1);
let (func_scope_id, func_scope) = scopes[0];
assert_eq!(func_scope.kind(), ScopeKind::Class);
assert_eq!(func_scope.name(), "C");
assert_eq!(names(&index.symbol_table(func_scope_id)), vec!["x"]);
}
// TODO: After porting the control flow graph.
// #[test]
// fn reachability_trivial() {
// let parsed = parse("x = 1; x");
// let ast = parsed.syntax();
// let index = SemanticIndex::from_ast(ast);
// let table = &index.symbol_table;
// let x_sym = table
// .root_symbol_id_by_name("x")
// .expect("x symbol should exist");
// let ast::Stmt::Expr(ast::StmtExpr { value: x_use, .. }) = &ast.body[1] else {
// panic!("should be an expr")
// };
// let x_defs: Vec<_> = index
// .reachable_definitions(x_sym, x_use)
// .map(|constrained_definition| constrained_definition.definition)
// .collect();
// assert_eq!(x_defs.len(), 1);
// let Definition::Assignment(node_key) = &x_defs[0] else {
// panic!("def should be an assignment")
// };
// let Some(def_node) = node_key.resolve(ast.into()) else {
// panic!("node key should resolve")
// };
// let ast::Expr::NumberLiteral(ast::ExprNumberLiteral {
// value: ast::Number::Int(num),
// ..
// }) = &*def_node.value
// else {
// panic!("should be a number literal")
// };
// assert_eq!(*num, 1);
// }
/// Expressions map to the scope that encloses them.
#[test]
fn expression_scope() {
    let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4");

    let index = semantic_index(&db, file);
    let root_table = index.symbol_table(FileScopeId::root());
    let parsed = parsed_module(&db, file);
    let module = parsed.syntax();

    let x_symbol = root_table
        .symbol_by_name("x")
        .expect("x symbol should exist");

    // Target of the module-level assignment.
    let x_target = &module.body[0].as_assign_stmt().unwrap().targets[0];
    assert_eq!(index.expression_scope(x_target).kind(), ScopeKind::Module);
    assert_eq!(index.expression_scope_id(x_target), x_symbol.scope());

    // Target of the assignment inside the function body.
    let function = module.body[1].as_function_def_stmt().unwrap();
    let y_target = &function.body[0].as_assign_stmt().unwrap().targets[0];
    assert_eq!(index.expression_scope(y_target).kind(), ScopeKind::Function);
}
// Exercises the descendent, child and ancestor scope iterators on nested definitions.
#[test]
fn scope_iterators() {
let TestCase { db, file } = test_case(
r#"
class Test:
def foo():
def bar():
...
def baz():
pass
def x():
pass"#,
);
let index = semantic_index(&db, file);
// All scopes nested below the root scope.
let descendents: Vec<_> = index
.descendent_scopes(FileScopeId::root())
.map(|(_, scope)| scope.name().as_str())
.collect();
assert_eq!(descendents, vec!["Test", "foo", "bar", "baz", "x"]);
// Only the scopes whose parent is the root scope.
let children: Vec<_> = index
.child_scopes(FileScopeId::root())
.map(|(_, scope)| scope.name.as_str())
.collect();
assert_eq!(children, vec!["Test", "x"]);
// Direct children of the first root-level scope (`Test`).
let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0;
let test_child_scopes: Vec<_> = index
.child_scopes(test_class)
.map(|(_, scope)| scope.name.as_str())
.collect();
assert_eq!(test_child_scopes, vec!["foo", "baz"]);
// The third descendent is `bar`; its ancestors start with `bar` itself.
let bar_scope = index
.descendent_scopes(FileScopeId::root())
.nth(2)
.unwrap()
.0;
let ancestors: Vec<_> = index
.ancestor_scopes(bar_scope)
.map(|(_, scope)| scope.name())
.collect();
assert_eq!(ancestors, vec!["bar", "foo", "Test", "<module>"]);
}
}

View file

@ -0,0 +1,384 @@
use rustc_hash::FxHashMap;
use ruff_db::parsed::ParsedModule;
use ruff_db::vfs::VfsFile;
use ruff_index::{newtype_index, IndexVec};
use ruff_python_ast as ast;
use ruff_python_ast::AnyNodeRef;
use crate::red_knot::ast_node_ref::AstNodeRef;
use crate::red_knot::node_key::NodeKey;
use crate::red_knot::semantic_index::semantic_index;
use crate::red_knot::semantic_index::symbol::{FileScopeId, ScopeId};
use crate::Db;
/// AST ids for a single scope.
///
/// The motivation for building the AST ids per scope isn't about reducing invalidation because
/// the struct changes whenever the parsed AST changes. Instead, it's mainly that we can
/// build the AST ids struct when building the symbol table and also keep the property that
/// IDs of outer scopes are unaffected by changes in inner scopes.
///
/// For example, we don't want that adding new statements to `foo` changes the statement id of `x = foo()` in:
///
/// ```python
/// def foo():
///     return 5
///
/// x = foo()
/// ```
pub(crate) struct AstIds {
/// Maps expression ids to their expressions.
expressions: IndexVec<ScopeExpressionId, AstNodeRef<ast::Expr>>,
/// Maps expressions to their expression id. Uses `NodeKey` because it avoids cloning [`Parsed`].
expressions_map: FxHashMap<NodeKey, ScopeExpressionId>,
/// Maps statement ids to their statements.
statements: IndexVec<ScopeStatementId, AstNodeRef<ast::Stmt>>,
/// Maps statements to their statement id, keyed by `NodeKey` like `expressions_map`.
statements_map: FxHashMap<NodeKey, ScopeStatementId>,
}
impl AstIds {
    /// Returns the statement id recorded for `node`.
    ///
    /// Panics if `node` has no entry in the statement map.
    fn statement_id<'a, N>(&self, node: N) -> ScopeStatementId
    where
        N: Into<AnyNodeRef<'a>>,
    {
        let key = NodeKey::from_node(node.into());
        self.statements_map[&key]
    }

    /// Returns the expression id recorded for `node`.
    ///
    /// Panics if `node` has no entry in the expression map.
    fn expression_id<'a, N>(&self, node: N) -> ScopeExpressionId
    where
        N: Into<AnyNodeRef<'a>>,
    {
        let key = NodeKey::from_node(node.into());
        self.expressions_map[&key]
    }
}
// Only the id -> node vectors are printed; the reverse lookup maps are left out,
// hence the lint exemption.
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for AstIds {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = f.debug_struct("AstIds");
        debug.field("expressions", &self.expressions);
        debug.field("statements", &self.statements);
        debug.finish()
    }
}
/// Returns the AST ids of `scope`, read from the semantic index of the scope's file.
fn ast_ids(db: &dyn Db, scope: ScopeId) -> &AstIds {
    let index = semantic_index(db, scope.file(db));
    index.ast_ids(scope.scope_id(db))
}
/// Node that can be uniquely identified by an id in a [`FileScopeId`].
///
/// The two methods are inverse operations: `scope_ast_id` maps a node to its id
/// and `lookup_in_scope` maps an id back to the node.
pub trait ScopeAstIdNode {
/// The type of the ID uniquely identifying the node.
type Id;
/// Returns the ID that uniquely identifies the node in `scope`.
///
/// ## Panics
/// Panics if the node doesn't belong to `file` or is outside `scope`.
fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> Self::Id;
/// Looks up the AST node by its ID.
///
/// ## Panics
/// May panic if the `id` does not belong to the AST of `file`, or is outside `scope`.
fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self
where
Self: Sized;
}
/// Extension trait for AST nodes that can be resolved by an `AstId`.
pub trait AstIdNode {
/// The scope-local id type that is wrapped in the [`AstId`].
type ScopeId;
/// Resolves the AST id of the node.
///
/// ## Panics
/// May panic if the node does not belong to `file`'s AST or is outside of `scope`. It may also
/// return an incorrect node if that's the case.
fn ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> AstId<Self::ScopeId>;
/// Resolves the AST node for `id`.
///
/// ## Panics
/// May panic if the `id` does not belong to the AST of `file` or it returns an incorrect node.
fn lookup(db: &dyn Db, file: VfsFile, id: AstId<Self::ScopeId>) -> &Self
where
Self: Sized;
}
/// Every [`ScopeAstIdNode`] is an [`AstIdNode`]: its file-wide id is the
/// scope plus the node's id inside that scope.
impl<T: ScopeAstIdNode> AstIdNode for T {
    type ScopeId = T::Id;

    fn ast_id(&self, db: &dyn Db, file: VfsFile, scope: FileScopeId) -> AstId<Self::ScopeId> {
        AstId {
            scope,
            in_scope_id: self.scope_ast_id(db, file, scope),
        }
    }

    fn lookup(db: &dyn Db, file: VfsFile, id: AstId<Self::ScopeId>) -> &Self
    where
        Self: Sized,
    {
        let AstId { scope, in_scope_id } = id;
        Self::lookup_in_scope(db, file, scope, in_scope_id)
    }
}
/// Uniquely identifies an AST node in a file.
///
/// Combines the node's scope with the node's id inside that scope; `L` is the
/// scope-local id type.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct AstId<L> {
/// The node's scope.
scope: FileScopeId,
/// The ID of the node inside [`Self::scope`].
in_scope_id: L,
}
/// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`].
#[newtype_index]
pub struct ScopeExpressionId;

impl ScopeAstIdNode for ast::Expr {
    type Id = ScopeExpressionId;

    /// Returns the id recorded for this expression in `file_scope`.
    ///
    /// Panics if the expression has no entry in the scope's expression map.
    fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
        let scope = file_scope.to_scope_id(db, file);
        let ast_ids = ast_ids(db, scope);
        // Go through the shared helper, like the statement impls do with
        // `statement_id`, instead of indexing `expressions_map` directly.
        ast_ids.expression_id(self)
    }

    /// Returns the expression stored under `id` in `file_scope`.
    fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self {
        let scope = file_scope.to_scope_id(db, file);
        let ast_ids = ast_ids(db, scope);
        ast_ids.expressions[id].node()
    }
}
/// Uniquely identifies an [`ast::Stmt`] in a [`FileScopeId`].
#[newtype_index]
pub struct ScopeStatementId;

impl ScopeAstIdNode for ast::Stmt {
    type Id = ScopeStatementId;

    /// Returns the id recorded for this statement in `file_scope`.
    fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
        let ids = ast_ids(db, file_scope.to_scope_id(db, file));
        ids.statement_id(self)
    }

    /// Returns the statement stored under `id` in `file_scope`.
    fn lookup_in_scope(db: &dyn Db, file: VfsFile, file_scope: FileScopeId, id: Self::Id) -> &Self {
        let ids = ast_ids(db, file_scope.to_scope_id(db, file));
        ids.statements[id].node()
    }
}
/// Identifies a function definition by the id of its statement within a scope.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeFunctionId(pub(super) ScopeStatementId);

impl ScopeAstIdNode for ast::StmtFunctionDef {
    type Id = ScopeFunctionId;

    /// Wraps the statement id recorded for this function definition.
    fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
        let ids = ast_ids(db, file_scope.to_scope_id(db, file));
        ScopeFunctionId(ids.statement_id(self))
    }

    /// Looks up the statement and unwraps it as a function definition.
    ///
    /// Panics if the statement stored under `id` is not a function definition.
    fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
        let statement = ast::Stmt::lookup_in_scope(db, file, scope, id.0);
        statement.as_function_def_stmt().unwrap()
    }
}
/// Identifies a class definition by the id of its statement within a scope.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeClassId(pub(super) ScopeStatementId);

impl ScopeAstIdNode for ast::StmtClassDef {
    type Id = ScopeClassId;

    /// Wraps the statement id recorded for this class definition.
    fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
        let ids = ast_ids(db, file_scope.to_scope_id(db, file));
        ScopeClassId(ids.statement_id(self))
    }

    /// Looks up the statement and unwraps it as a class definition.
    ///
    /// Panics if the statement stored under `id` is not a class definition.
    fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
        ast::Stmt::lookup_in_scope(db, file, scope, id.0)
            .as_class_def_stmt()
            .unwrap()
    }
}
/// Identifies an assignment by the id of its statement within a scope.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeAssignmentId(pub(super) ScopeStatementId);

impl ScopeAstIdNode for ast::StmtAssign {
    type Id = ScopeAssignmentId;

    /// Wraps the statement id recorded for this assignment.
    fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
        let ids = ast_ids(db, file_scope.to_scope_id(db, file));
        ScopeAssignmentId(ids.statement_id(self))
    }

    /// Looks up the statement and unwraps it as an assignment.
    ///
    /// Panics if the statement stored under `id` is not an assignment.
    fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
        ast::Stmt::lookup_in_scope(db, file, scope, id.0)
            .as_assign_stmt()
            .unwrap()
    }
}
/// Identifies an annotated assignment by the id of its statement within a scope.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeAnnotatedAssignmentId(ScopeStatementId);

impl ScopeAstIdNode for ast::StmtAnnAssign {
    type Id = ScopeAnnotatedAssignmentId;

    /// Wraps the statement id recorded for this annotated assignment.
    fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
        let ids = ast_ids(db, file_scope.to_scope_id(db, file));
        ScopeAnnotatedAssignmentId(ids.statement_id(self))
    }

    /// Looks up the statement and unwraps it as an annotated assignment.
    ///
    /// Panics if the statement stored under `id` is not an annotated assignment.
    fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
        ast::Stmt::lookup_in_scope(db, file, scope, id.0)
            .as_ann_assign_stmt()
            .unwrap()
    }
}
/// Scope-local id of an `import` statement ([`ast::StmtImport`]).
///
/// Wraps the [`ScopeStatementId`] of the statement.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeImportId(pub(super) ScopeStatementId);

impl ScopeAstIdNode for ast::StmtImport {
    type Id = ScopeImportId;

    /// Asks the AST ids of `file_scope` for this statement's id and wraps it.
    fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
        let ids = ast_ids(db, file_scope.to_scope_id(db, file));
        ScopeImportId(ids.statement_id(self))
    }

    /// Resolves `id` back to the `import` node.
    ///
    /// # Panics
    /// Panics if the statement stored under `id` is not an `import` statement.
    fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
        ast::Stmt::lookup_in_scope(db, file, scope, id.0)
            .as_import_stmt()
            .unwrap()
    }
}
/// Scope-local id of a `from ... import ...` statement ([`ast::StmtImportFrom`]).
///
/// Wraps the [`ScopeStatementId`] of the statement.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeImportFromId(pub(super) ScopeStatementId);

impl ScopeAstIdNode for ast::StmtImportFrom {
    type Id = ScopeImportFromId;

    /// Asks the AST ids of `file_scope` for this statement's id and wraps it.
    fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
        let ids = ast_ids(db, file_scope.to_scope_id(db, file));
        ScopeImportFromId(ids.statement_id(self))
    }

    /// Resolves `id` back to the `from ... import ...` node.
    ///
    /// # Panics
    /// Panics if the statement stored under `id` is not an import-from statement.
    fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self {
        ast::Stmt::lookup_in_scope(db, file, scope, id.0)
            .as_import_from_stmt()
            .unwrap()
    }
}
/// Scope-local id of a named expression, i.e. a walrus `target := value` ([`ast::ExprNamed`]).
///
/// Wraps the [`ScopeExpressionId`] of the expression.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Hash)]
pub struct ScopeNamedExprId(pub(super) ScopeExpressionId);

impl ScopeAstIdNode for ast::ExprNamed {
    type Id = ScopeNamedExprId;

    /// Asks the AST ids of `file_scope` for this expression's id and wraps it.
    fn scope_ast_id(&self, db: &dyn Db, file: VfsFile, file_scope: FileScopeId) -> Self::Id {
        let ids = ast_ids(db, file_scope.to_scope_id(db, file));
        ScopeNamedExprId(ids.expression_id(self))
    }

    /// Resolves `id` back to the named expression node.
    ///
    /// # Panics
    /// Panics if the expression stored under `id` is not a named expression.
    fn lookup_in_scope(db: &dyn Db, file: VfsFile, scope: FileScopeId, id: Self::Id) -> &Self
    where
        Self: Sized,
    {
        ast::Expr::lookup_in_scope(db, file, scope, id.0)
            .as_named_expr()
            .unwrap()
    }
}
#[derive(Debug)]
pub(super) struct AstIdsBuilder {
expressions: IndexVec<ScopeExpressionId, AstNodeRef<ast::Expr>>,
expressions_map: FxHashMap<NodeKey, ScopeExpressionId>,
statements: IndexVec<ScopeStatementId, AstNodeRef<ast::Stmt>>,
statements_map: FxHashMap<NodeKey, ScopeStatementId>,
}
impl AstIdsBuilder {
pub(super) fn new() -> Self {
Self {
expressions: IndexVec::default(),
expressions_map: FxHashMap::default(),
statements: IndexVec::default(),
statements_map: FxHashMap::default(),
}
}
/// Adds `stmt` to the AST ids map and returns its id.
///
/// ## Safety
/// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires
/// that `stmt` is a child of `parsed`.
#[allow(unsafe_code)]
pub(super) unsafe fn record_statement(
&mut self,
stmt: &ast::Stmt,
parsed: &ParsedModule,
) -> ScopeStatementId {
let statement_id = self.statements.push(AstNodeRef::new(parsed.clone(), stmt));
self.statements_map
.insert(NodeKey::from_node(stmt), statement_id);
statement_id
}
/// Adds `expr` to the AST ids map and returns its id.
///
/// ## Safety
/// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires
/// that `expr` is a child of `parsed`.
#[allow(unsafe_code)]
pub(super) unsafe fn record_expression(
&mut self,
expr: &ast::Expr,
parsed: &ParsedModule,
) -> ScopeExpressionId {
let expression_id = self.expressions.push(AstNodeRef::new(parsed.clone(), expr));
self.expressions_map
.insert(NodeKey::from_node(expr), expression_id);
expression_id
}
pub(super) fn finish(mut self) -> AstIds {
self.expressions.shrink_to_fit();
self.expressions_map.shrink_to_fit();
self.statements.shrink_to_fit();
self.statements_map.shrink_to_fit();
AstIds {
expressions: self.expressions,
expressions_map: self.expressions_map,
statements: self.statements,
statements_map: self.statements_map,
}
}
}

View file

@ -0,0 +1,398 @@
use std::sync::Arc;
use rustc_hash::FxHashMap;
use ruff_db::parsed::ParsedModule;
use ruff_index::IndexVec;
use ruff_python_ast as ast;
use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor};
use crate::name::Name;
use crate::red_knot::node_key::NodeKey;
use crate::red_knot::semantic_index::ast_ids::{
AstIdsBuilder, ScopeAssignmentId, ScopeClassId, ScopeFunctionId, ScopeImportFromId,
ScopeImportId, ScopeNamedExprId,
};
use crate::red_knot::semantic_index::definition::{
Definition, ImportDefinition, ImportFromDefinition,
};
use crate::red_knot::semantic_index::symbol::{
FileScopeId, FileSymbolId, Scope, ScopeKind, ScopeSymbolId, SymbolFlags, SymbolTableBuilder,
};
use crate::red_knot::semantic_index::SemanticIndex;
/// Walks a parsed module and accumulates the data that ends up in its [`SemanticIndex`].
pub(super) struct SemanticIndexBuilder<'a> {
// Builder state
/// The parsed module whose AST is being indexed.
module: &'a ParsedModule,
/// Stack of the scopes currently being visited; the last entry is the current scope.
scope_stack: Vec<FileScopeId>,
/// the definition whose target(s) we are currently walking
current_definition: Option<Definition>,
// Semantic Index fields
/// All scopes created so far, indexed by their [`FileScopeId`].
scopes: IndexVec<FileScopeId, Scope>,
/// One symbol table builder per scope, pushed together with the scope.
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder>,
/// One AST ids builder per scope, pushed together with the scope.
ast_ids: IndexVec<FileScopeId, AstIdsBuilder>,
/// The scope that was current when each expression was visited.
expression_scopes: FxHashMap<NodeKey, FileScopeId>,
}
impl<'a> SemanticIndexBuilder<'a> {
pub(super) fn new(parsed: &'a ParsedModule) -> Self {
let mut builder = Self {
module: parsed,
scope_stack: Vec::new(),
current_definition: None,
scopes: IndexVec::new(),
symbol_tables: IndexVec::new(),
ast_ids: IndexVec::new(),
expression_scopes: FxHashMap::default(),
};
builder.push_scope_with_parent(
ScopeKind::Module,
&Name::new_static("<module>"),
None,
None,
None,
);
builder
}
fn current_scope(&self) -> FileScopeId {
*self
.scope_stack
.last()
.expect("Always to have a root scope")
}
fn push_scope(
&mut self,
scope_kind: ScopeKind,
name: &Name,
defining_symbol: Option<FileSymbolId>,
definition: Option<Definition>,
) {
let parent = self.current_scope();
self.push_scope_with_parent(scope_kind, name, defining_symbol, definition, Some(parent));
}
fn push_scope_with_parent(
&mut self,
scope_kind: ScopeKind,
name: &Name,
defining_symbol: Option<FileSymbolId>,
definition: Option<Definition>,
parent: Option<FileScopeId>,
) {
let children_start = self.scopes.next_index() + 1;
let scope = Scope {
name: name.clone(),
parent,
defining_symbol,
definition,
kind: scope_kind,
descendents: children_start..children_start,
};
let scope_id = self.scopes.push(scope);
self.symbol_tables.push(SymbolTableBuilder::new());
self.ast_ids.push(AstIdsBuilder::new());
self.scope_stack.push(scope_id);
}
fn pop_scope(&mut self) -> FileScopeId {
let id = self.scope_stack.pop().expect("Root scope to be present");
let children_end = self.scopes.next_index();
let scope = &mut self.scopes[id];
scope.descendents = scope.descendents.start..children_end;
id
}
fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder {
let scope_id = self.current_scope();
&mut self.symbol_tables[scope_id]
}
fn current_ast_ids(&mut self) -> &mut AstIdsBuilder {
let scope_id = self.current_scope();
&mut self.ast_ids[scope_id]
}
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopeSymbolId {
let scope = self.current_scope();
let symbol_table = self.current_symbol_table();
symbol_table.add_or_update_symbol(name, scope, flags, None)
}
fn add_or_update_symbol_with_definition(
&mut self,
name: Name,
definition: Definition,
) -> ScopeSymbolId {
let scope = self.current_scope();
let symbol_table = self.current_symbol_table();
symbol_table.add_or_update_symbol(name, scope, SymbolFlags::IS_DEFINED, Some(definition))
}
fn with_type_params(
&mut self,
name: &Name,
params: &Option<Box<ast::TypeParams>>,
definition: Option<Definition>,
defining_symbol: FileSymbolId,
nested: impl FnOnce(&mut Self) -> FileScopeId,
) -> FileScopeId {
if let Some(type_params) = params {
self.push_scope(
ScopeKind::Annotation,
name,
Some(defining_symbol),
definition,
);
for type_param in &type_params.type_params {
let name = match type_param {
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name,
ast::TypeParam::ParamSpec(ast::TypeParamParamSpec { name, .. }) => name,
ast::TypeParam::TypeVarTuple(ast::TypeParamTypeVarTuple { name, .. }) => name,
};
self.add_or_update_symbol(Name::new(name), SymbolFlags::IS_DEFINED);
}
}
let nested_scope = nested(self);
if params.is_some() {
self.pop_scope();
}
nested_scope
}
pub(super) fn build(mut self) -> SemanticIndex {
let module = self.module;
self.visit_body(module.suite());
// Pop the root scope
self.pop_scope();
assert!(self.scope_stack.is_empty());
assert!(self.current_definition.is_none());
let mut symbol_tables: IndexVec<_, _> = self
.symbol_tables
.into_iter()
.map(|builder| Arc::new(builder.finish()))
.collect();
let mut ast_ids: IndexVec<_, _> = self
.ast_ids
.into_iter()
.map(super::ast_ids::AstIdsBuilder::finish)
.collect();
self.scopes.shrink_to_fit();
ast_ids.shrink_to_fit();
symbol_tables.shrink_to_fit();
self.expression_scopes.shrink_to_fit();
SemanticIndex {
symbol_tables,
scopes: self.scopes,
ast_ids,
expression_scopes: self.expression_scopes,
}
}
}
impl Visitor<'_> for SemanticIndexBuilder<'_> {
    /// Records `stmt` in the current scope's AST ids and registers the symbols it defines.
    ///
    /// Function definitions, class definitions, `import`, `from ... import` and plain assignment
    /// statements are handled explicitly; every other statement is walked with [`walk_stmt`].
    fn visit_stmt(&mut self, stmt: &ast::Stmt) {
        let module = self.module;
        #[allow(unsafe_code)]
        let statement_id = unsafe {
            // SAFETY: The builder only visits nodes that are part of `module`. This guarantees that
            // the current statement must be a child of `module`.
            self.current_ast_ids().record_statement(stmt, module)
        };

        match stmt {
            ast::Stmt::FunctionDef(function_def) => {
                // Decorators are visited in the enclosing scope, before the function is defined.
                for decorator in &function_def.decorator_list {
                    self.visit_decorator(decorator);
                }

                let name = Name::new(&function_def.name.id);
                let definition = Definition::FunctionDef(ScopeFunctionId(statement_id));
                let scope = self.current_scope();
                let symbol = FileSymbolId::new(
                    scope,
                    self.add_or_update_symbol_with_definition(name.clone(), definition),
                );

                self.with_type_params(
                    &name,
                    &function_def.type_params,
                    Some(definition),
                    symbol,
                    |builder| {
                        // Parameters and the return annotation are visited before the function
                        // scope is pushed, i.e. in the annotation scope (if any) or the
                        // enclosing scope.
                        builder.visit_parameters(&function_def.parameters);
                        if let Some(returns) = &function_def.returns {
                            builder.visit_annotation(returns);
                        }

                        builder.push_scope(
                            ScopeKind::Function,
                            &name,
                            Some(symbol),
                            Some(definition),
                        );
                        builder.visit_body(&function_def.body);
                        builder.pop_scope()
                    },
                );
            }
            ast::Stmt::ClassDef(class) => {
                // Decorators are visited in the enclosing scope, before the class is defined.
                for decorator in &class.decorator_list {
                    self.visit_decorator(decorator);
                }

                let name = Name::new(&class.name.id);
                let definition = Definition::from(ScopeClassId(statement_id));
                let id = FileSymbolId::new(
                    self.current_scope(),
                    self.add_or_update_symbol_with_definition(name.clone(), definition),
                );

                self.with_type_params(&name, &class.type_params, Some(definition), id, |builder| {
                    // Base classes and keywords are visited before the class scope is pushed.
                    if let Some(arguments) = &class.arguments {
                        builder.visit_arguments(arguments);
                    }

                    builder.push_scope(ScopeKind::Class, &name, Some(id), Some(definition));
                    builder.visit_body(&class.body);
                    builder.pop_scope()
                });
            }
            ast::Stmt::Import(ast::StmtImport { names, .. }) => {
                for (i, alias) in names.iter().enumerate() {
                    // `import a.b.c` binds `a`, `import a.b.c as d` binds `d`.
                    let symbol_name = if let Some(asname) = &alias.asname {
                        asname.id.as_str()
                    } else {
                        alias
                            .name
                            .id
                            .split('.')
                            .next()
                            .expect("`str::split` always yields at least one item")
                    };

                    let def = Definition::Import(ImportDefinition {
                        import_id: ScopeImportId(statement_id),
                        alias: u32::try_from(i).unwrap(),
                    });

                    self.add_or_update_symbol_with_definition(Name::new(symbol_name), def);
                }
            }
            ast::Stmt::ImportFrom(ast::StmtImportFrom {
                module: _,
                names,
                level: _,
                ..
            }) => {
                for (i, alias) in names.iter().enumerate() {
                    // `from m import a` binds `a`, `from m import a as b` binds `b`.
                    let symbol_name = if let Some(asname) = &alias.asname {
                        asname.id.as_str()
                    } else {
                        alias.name.id.as_str()
                    };

                    let def = Definition::ImportFrom(ImportFromDefinition {
                        import_id: ScopeImportFromId(statement_id),
                        name: u32::try_from(i).unwrap(),
                    });

                    self.add_or_update_symbol_with_definition(Name::new(symbol_name), def);
                }
            }
            ast::Stmt::Assign(node) => {
                debug_assert!(self.current_definition.is_none());

                // The value is visited first, without an active definition; only names stored
                // to while walking the targets are attributed to this assignment.
                self.visit_expr(&node.value);

                self.current_definition =
                    Some(Definition::Assignment(ScopeAssignmentId(statement_id)));
                for target in &node.targets {
                    self.visit_expr(target);
                }
                self.current_definition = None;
            }
            _ => {
                walk_stmt(self, stmt);
            }
        }
    }

    /// Records `expr` in the current scope's AST ids, remembers the scope it belongs to and
    /// updates the symbols it uses or defines.
    fn visit_expr(&mut self, expr: &'_ ast::Expr) {
        let module = self.module;
        #[allow(unsafe_code)]
        let expression_id = unsafe {
            // SAFETY: The builder only visits nodes that are part of `module`. This guarantees that
            // the current expression must be a child of `module`.
            self.current_ast_ids().record_expression(expr, module)
        };

        self.expression_scopes
            .insert(NodeKey::from_node(expr), self.current_scope());

        match expr {
            ast::Expr::Name(ast::ExprName { id, ctx, .. }) => {
                let flags = match ctx {
                    ast::ExprContext::Load => SymbolFlags::IS_USED,
                    ast::ExprContext::Store => SymbolFlags::IS_DEFINED,
                    ast::ExprContext::Del => SymbolFlags::IS_DEFINED,
                    ast::ExprContext::Invalid => SymbolFlags::empty(),
                };

                // A name that is being defined is attributed to the definition that is
                // currently being walked, if there is one.
                match self.current_definition {
                    Some(definition) if flags.contains(SymbolFlags::IS_DEFINED) => {
                        self.add_or_update_symbol_with_definition(Name::new(id), definition);
                    }
                    _ => {
                        self.add_or_update_symbol(Name::new(id), flags);
                    }
                }

                walk_expr(self, expr);
            }
            ast::Expr::Named(node) => {
                // A walrus can be nested inside the target of another definition, e.g.
                // `x[(y := 0)] = z = 1`. Remember the enclosing definition and restore it
                // afterwards so that the remaining targets are still attributed to it.
                let enclosing_definition = self
                    .current_definition
                    .replace(Definition::NamedExpr(ScopeNamedExprId(expression_id)));

                // TODO walrus in comprehensions is implicitly nonlocal
                self.visit_expr(&node.target);

                // The value is not a binding target of any definition.
                self.current_definition = None;
                self.visit_expr(&node.value);

                self.current_definition = enclosing_definition;
            }
            ast::Expr::If(ast::ExprIf {
                body, test, orelse, ..
            }) => {
                // TODO detect statically known truthy or falsy test (via type inference, not naive
                // AST inspection, so we can't simplify here, need to record test expression in CFG
                // for later checking)

                self.visit_expr(test);

                // let if_branch = self.flow_graph_builder.add_branch(self.current_flow_node());

                // self.set_current_flow_node(if_branch);
                // self.insert_constraint(test);
                self.visit_expr(body);

                // let post_body = self.current_flow_node();

                // self.set_current_flow_node(if_branch);
                self.visit_expr(orelse);

                // let post_else = self
                //     .flow_graph_builder
                //     .add_phi(self.current_flow_node(), post_body);

                // self.set_current_flow_node(post_else);
            }
            _ => {
                walk_expr(self, expr);
            }
        }
    }
}

View file

@ -0,0 +1,76 @@
use crate::red_knot::semantic_index::ast_ids::{
ScopeAnnotatedAssignmentId, ScopeAssignmentId, ScopeClassId, ScopeFunctionId,
ScopeImportFromId, ScopeImportId, ScopeNamedExprId,
};
/// A place in the source that defines a symbol, identified by the scope-local id of the
/// defining AST node.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Definition {
    /// One alias of an `import` statement.
    Import(ImportDefinition),
    /// One name of a `from ... import ...` statement.
    ImportFrom(ImportFromDefinition),
    /// A class definition statement.
    ClassDef(ScopeClassId),
    /// A function definition statement.
    FunctionDef(ScopeFunctionId),
    /// A plain assignment statement.
    Assignment(ScopeAssignmentId),
    /// An annotated assignment statement.
    AnnotatedAssignment(ScopeAnnotatedAssignmentId),
    /// A named expression (`target := value`).
    NamedExpr(ScopeNamedExprId),
    /// represents the implicit initial definition of every name as "unbound"
    Unbound,
    // TODO with statements, except handlers, function args...
}

impl From<ImportDefinition> for Definition {
    fn from(import: ImportDefinition) -> Self {
        Definition::Import(import)
    }
}

impl From<ImportFromDefinition> for Definition {
    fn from(import_from: ImportFromDefinition) -> Self {
        Definition::ImportFrom(import_from)
    }
}

impl From<ScopeClassId> for Definition {
    fn from(class_id: ScopeClassId) -> Self {
        Definition::ClassDef(class_id)
    }
}

impl From<ScopeFunctionId> for Definition {
    fn from(function_id: ScopeFunctionId) -> Self {
        Definition::FunctionDef(function_id)
    }
}

impl From<ScopeAssignmentId> for Definition {
    fn from(assignment_id: ScopeAssignmentId) -> Self {
        Definition::Assignment(assignment_id)
    }
}

impl From<ScopeAnnotatedAssignmentId> for Definition {
    fn from(assignment_id: ScopeAnnotatedAssignmentId) -> Self {
        Definition::AnnotatedAssignment(assignment_id)
    }
}

impl From<ScopeNamedExprId> for Definition {
    fn from(named_expr_id: ScopeNamedExprId) -> Self {
        Definition::NamedExpr(named_expr_id)
    }
}
/// The definition created by a single alias of an `import` statement.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct ImportDefinition {
/// Id of the `import` statement that contains the alias.
pub(super) import_id: ScopeImportId,
/// Index into [`ruff_python_ast::StmtImport::names`].
pub(super) alias: u32,
}
/// The definition created by a single name of a `from ... import ...` statement.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct ImportFromDefinition {
/// Id of the `from ... import ...` statement that contains the name.
pub(super) import_id: ScopeImportFromId,
/// Index into [`ruff_python_ast::StmtImportFrom::names`].
pub(super) name: u32,
}

View file

@ -0,0 +1,362 @@
// Allow unused underscore violations generated by the salsa macro
// TODO(micha): Contribute fix upstream
#![allow(clippy::used_underscore_binding)]
use std::hash::{Hash, Hasher};
use std::ops::Range;
use bitflags::bitflags;
use hashbrown::hash_map::RawEntryMut;
use rustc_hash::FxHasher;
use smallvec::SmallVec;
use ruff_db::vfs::VfsFile;
use ruff_index::{newtype_index, IndexVec};
use crate::name::Name;
use crate::red_knot::semantic_index::definition::Definition;
use crate::red_knot::semantic_index::{scopes_map, symbol_table, SymbolMap};
use crate::Db;
#[derive(Eq, PartialEq, Debug)]
pub struct Symbol {
name: Name,
flags: SymbolFlags,
scope: FileScopeId,
/// The nodes that define this symbol, in source order.
definitions: SmallVec<[Definition; 4]>,
}
impl Symbol {
fn new(name: Name, scope: FileScopeId, definition: Option<Definition>) -> Self {
Self {
name,
scope,
flags: SymbolFlags::empty(),
definitions: definition.into_iter().collect(),
}
}
fn push_definition(&mut self, definition: Definition) {
self.definitions.push(definition);
}
fn insert_flags(&mut self, flags: SymbolFlags) {
self.flags.insert(flags);
}
/// The symbol's name.
pub fn name(&self) -> &Name {
&self.name
}
/// The scope in which this symbol is defined.
pub fn scope(&self) -> FileScopeId {
self.scope
}
/// Is the symbol used in its containing scope?
pub fn is_used(&self) -> bool {
self.flags.contains(SymbolFlags::IS_USED)
}
/// Is the symbol defined in its containing scope?
pub fn is_defined(&self) -> bool {
self.flags.contains(SymbolFlags::IS_DEFINED)
}
pub fn definitions(&self) -> &[Definition] {
&self.definitions
}
}
bitflags! {
/// Flags describing how a symbol occurs in its scope.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(super) struct SymbolFlags: u8 {
/// The symbol is read in its scope (reported by `Symbol::is_used`).
const IS_USED = 1 << 0;
/// The symbol is bound in its scope (reported by `Symbol::is_defined`).
const IS_DEFINED = 1 << 1;
/// TODO: This flag is not yet set by anything
const MARKED_GLOBAL = 1 << 2;
/// TODO: This flag is not yet set by anything
const MARKED_NONLOCAL = 1 << 3;
}
}
/// ID that uniquely identifies a public symbol defined in a module's root scope.
///
/// Pairs the program-wide [`ScopeId`] with the symbol's id inside that scope.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct PublicSymbolId {
    scope: ScopeId,
    symbol: ScopeSymbolId,
}

impl PublicSymbolId {
    /// Combines a scope and the id of one of its symbols.
    pub(crate) fn new(scope: ScopeId, symbol: ScopeSymbolId) -> Self {
        PublicSymbolId { scope, symbol }
    }

    /// The scope the symbol belongs to.
    pub fn scope(self) -> ScopeId {
        let PublicSymbolId { scope, .. } = self;
        scope
    }

    /// The symbol's id inside of its scope.
    pub(crate) fn scope_symbol(self) -> ScopeSymbolId {
        let PublicSymbolId { symbol, .. } = self;
        symbol
    }
}

impl From<PublicSymbolId> for ScopeSymbolId {
    /// Drops the scope and keeps only the scope-local symbol id.
    fn from(public_id: PublicSymbolId) -> Self {
        public_id.symbol
    }
}
/// ID that uniquely identifies a symbol in a file.
///
/// Pairs the file-local [`FileScopeId`] with the symbol's id inside that scope.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct FileSymbolId {
    scope: FileScopeId,
    symbol: ScopeSymbolId,
}

impl FileSymbolId {
    /// Combines a file-local scope and the id of one of its symbols.
    pub(super) fn new(scope: FileScopeId, symbol: ScopeSymbolId) -> Self {
        FileSymbolId { scope, symbol }
    }

    /// The scope the symbol belongs to.
    pub fn scope(self) -> FileScopeId {
        let FileSymbolId { scope, .. } = self;
        scope
    }

    /// The symbol's id inside of its scope.
    pub(crate) fn symbol(self) -> ScopeSymbolId {
        let FileSymbolId { symbol, .. } = self;
        symbol
    }
}

impl From<FileSymbolId> for ScopeSymbolId {
    /// Drops the scope and keeps only the scope-local symbol id.
    fn from(file_id: FileSymbolId) -> Self {
        file_id.symbol
    }
}
/// Symbol ID that uniquely identifies a symbol inside a [`Scope`].
///
/// It is the index type of the `symbols` vector of a [`SymbolTable`], so it is only meaningful
/// together with the scope whose table it indexes.
#[newtype_index]
pub(crate) struct ScopeSymbolId;
/// Maps from the file specific [`FileScopeId`] to the global [`ScopeId`] that can be used as a Salsa query parameter.
///
/// The [`SemanticIndex`] uses [`FileScopeId`] on a per-file level to identify scopes
/// because they allow for more efficient storage of associated data
/// (use of an [`IndexVec`] keyed by [`FileScopeId`] over an [`FxHashMap`] keyed by [`ScopeId`]).
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct ScopesMap {
    scopes: IndexVec<FileScopeId, ScopeId>,
}

impl ScopesMap {
    /// Wraps the per-file table of global scope ids.
    pub(super) fn new(scopes: IndexVec<FileScopeId, ScopeId>) -> Self {
        ScopesMap { scopes }
    }

    /// Gets the program-wide unique scope id for the given file specific `scope_id`.
    fn get(&self, scope_id: FileScopeId) -> ScopeId {
        let ScopesMap { scopes } = self;
        scopes[scope_id]
    }
}
/// A cross-module identifier of a scope that can be used as a salsa query parameter.
#[salsa::tracked]
pub struct ScopeId {
    #[allow(clippy::used_underscore_binding)]
    pub file: VfsFile,
    pub scope_id: FileScopeId,
}

impl ScopeId {
    /// Resolves the symbol named `name` in this scope.
    ///
    /// Returns `None` if this scope's symbol table has no symbol with that name.
    pub fn symbol(self, db: &dyn Db, name: &str) -> Option<PublicSymbolId> {
        symbol_table(db, self)
            .symbol_id_by_name(name)
            .map(|in_scope_id| PublicSymbolId::new(self, in_scope_id))
    }
}
/// ID that uniquely identifies a scope inside of a module.
#[newtype_index]
pub struct FileScopeId;

impl FileScopeId {
    /// Returns the scope id of the Root scope.
    pub fn root() -> Self {
        Self::from_u32(0)
    }

    /// Translates this file-local id into the program-wide [`ScopeId`] of the same scope in
    /// `file`.
    pub fn to_scope_id(self, db: &dyn Db, file: VfsFile) -> ScopeId {
        let file_scopes = scopes_map(db, file);
        file_scopes.get(self)
    }
}
/// A single scope of a file: the module itself, a class body, a function body or the
/// annotation scope that holds type parameters.
#[derive(Debug, Eq, PartialEq)]
pub struct Scope {
    /// The name the scope was created with.
    pub(super) name: Name,
    /// The enclosing scope; `None` for a scope without a parent.
    pub(super) parent: Option<FileScopeId>,
    /// The definition that introduced this scope, if any.
    pub(super) definition: Option<Definition>,
    /// The symbol in the enclosing scope that this scope defines, if any.
    pub(super) defining_symbol: Option<FileSymbolId>,
    /// What kind of construct introduced the scope.
    pub(super) kind: ScopeKind,
    /// The range of scope ids that are nested (directly or indirectly) in this scope.
    pub(super) descendents: Range<FileScopeId>,
}

impl Scope {
    /// The scope's name.
    pub fn name(&self) -> &Name {
        &self.name
    }

    /// The definition that introduced this scope, if any.
    pub fn definition(&self) -> Option<Definition> {
        self.definition
    }

    /// The symbol that this scope defines, if any.
    pub fn defining_symbol(&self) -> Option<FileSymbolId> {
        self.defining_symbol
    }

    /// The id of the enclosing scope, or `None` if the scope has no parent.
    ///
    /// Borrows the scope like every other getter. `Scope` is not `Copy`, so taking `self` by
    /// value would consume the scope and make the method unusable through a `&Scope`.
    pub fn parent(&self) -> Option<FileScopeId> {
        self.parent
    }

    /// What kind of construct introduced the scope.
    pub fn kind(&self) -> ScopeKind {
        self.kind
    }
}
/// The kind of construct that introduced a [`Scope`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ScopeKind {
/// The scope of a module.
Module,
/// An annotation scope.
Annotation,
/// The scope of a class.
Class,
/// The scope of a function.
Function,
}
/// Symbol table for a specific [`Scope`].
#[derive(Debug)]
pub struct SymbolTable {
    /// The symbols in this scope.
    symbols: IndexVec<ScopeSymbolId, Symbol>,

    /// The symbols indexed by name.
    symbols_by_name: SymbolMap,
}

impl SymbolTable {
    /// Creates a table without any symbols.
    fn new() -> Self {
        Self {
            symbols_by_name: SymbolMap::default(),
            symbols: IndexVec::new(),
        }
    }

    /// Shrinks the symbol storage to fit; `symbols_by_name` is left untouched.
    fn shrink_to_fit(&mut self) {
        self.symbols.shrink_to_fit();
    }

    /// Returns the symbol with the given id.
    pub(crate) fn symbol(&self, symbol_id: impl Into<ScopeSymbolId>) -> &Symbol {
        let id: ScopeSymbolId = symbol_id.into();
        &self.symbols[id]
    }

    /// Iterates over the ids of all symbols in this table.
    #[allow(unused)]
    pub(crate) fn symbol_ids(&self) -> impl Iterator<Item = ScopeSymbolId> {
        self.symbols.indices()
    }

    /// Iterates over all symbols in this table.
    pub fn symbols(&self) -> impl Iterator<Item = &Symbol> {
        self.symbols.iter()
    }

    /// Returns the symbol named `name`.
    #[allow(unused)]
    pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol> {
        self.symbol_id_by_name(name).map(|id| self.symbol(id))
    }

    /// Returns the [`ScopeSymbolId`] of the symbol named `name`.
    pub(crate) fn symbol_id_by_name(&self, name: &str) -> Option<ScopeSymbolId> {
        let hash = Self::hash_name(name);

        // The map only stores ids, so candidates are compared through the symbol they refer to.
        self.symbols_by_name
            .raw_entry()
            .from_hash(hash, |candidate| {
                self.symbol(*candidate).name().as_str() == name
            })
            .map(|(id, ())| *id)
    }

    /// Hashes `name` with [`FxHasher`]; used for both inserting into and querying
    /// `symbols_by_name`.
    fn hash_name(name: &str) -> u64 {
        let mut state = FxHasher::default();
        name.hash(&mut state);
        state.finish()
    }
}

impl PartialEq for SymbolTable {
    fn eq(&self, other: &Self) -> bool {
        // We don't need to compare the symbols_by_name because the name is already captured in `Symbol`.
        let (ours, theirs) = (&self.symbols, &other.symbols);
        ours == theirs
    }
}

impl Eq for SymbolTable {}
/// Builds the [`SymbolTable`] of a single scope.
#[derive(Debug)]
pub(super) struct SymbolTableBuilder {
    table: SymbolTable,
}

impl SymbolTableBuilder {
    /// Creates a builder with an empty table.
    pub(super) fn new() -> Self {
        let table = SymbolTable::new();
        Self { table }
    }

    /// Adds `flags` (and `definition`, if any) to the symbol named `name`, creating the symbol
    /// in `scope` if it does not exist yet. Returns the symbol's id.
    pub(super) fn add_or_update_symbol(
        &mut self,
        name: Name,
        scope: FileScopeId,
        flags: SymbolFlags,
        definition: Option<Definition>,
    ) -> ScopeSymbolId {
        let hash = SymbolTable::hash_name(&name);

        // The map only stores ids, so candidates are compared through the symbol they refer to.
        let slot = self
            .table
            .symbols_by_name
            .raw_entry_mut()
            .from_hash(hash, |existing| self.table.symbols[*existing].name() == &name);

        match slot {
            RawEntryMut::Occupied(occupied) => {
                let id = *occupied.key();

                let symbol = &mut self.table.symbols[id];
                symbol.insert_flags(flags);
                if let Some(definition) = definition {
                    symbol.push_definition(definition);
                }

                id
            }
            RawEntryMut::Vacant(vacant) => {
                let mut symbol = Symbol::new(name, scope, definition);
                symbol.insert_flags(flags);

                let id = self.table.symbols.push(symbol);
                // Keys are re-hashed by the name of the symbol they refer to.
                vacant.insert_with_hasher(hash, id, (), |stored| {
                    SymbolTable::hash_name(self.table.symbols[*stored].name().as_str())
                });

                id
            }
        }
    }

    /// Shrinks the table and returns it.
    pub(super) fn finish(self) -> SymbolTable {
        let Self { mut table } = self;
        table.shrink_to_fit();
        table
    }
}