Address review feedback from 11963 (#12145)

This commit is contained in:
Micha Reiser 2024-07-02 09:05:55 +02:00 committed by GitHub
parent 25080acb7a
commit dcb9523b1e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 203 additions and 241 deletions

View file

@ -12,7 +12,7 @@ use crate::node_key::NodeKey;
use crate::semantic_index::ast_ids::{AstId, AstIds, ScopedClassId, ScopedFunctionId}; use crate::semantic_index::ast_ids::{AstId, AstIds, ScopedClassId, ScopedFunctionId};
use crate::semantic_index::builder::SemanticIndexBuilder; use crate::semantic_index::builder::SemanticIndexBuilder;
use crate::semantic_index::symbol::{ use crate::semantic_index::symbol::{
FileScopeId, PublicSymbolId, Scope, ScopeId, ScopedSymbolId, SymbolTable, FileScopeId, PublicSymbolId, Scope, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable,
}; };
use crate::Db; use crate::Db;
@ -83,17 +83,14 @@ pub struct SemanticIndex {
/// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope). /// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope).
scopes_by_expression: FxHashMap<NodeKey, FileScopeId>, scopes_by_expression: FxHashMap<NodeKey, FileScopeId>,
/// Map from the definition that introduce a scope to the scope they define.
scopes_by_definition: FxHashMap<NodeWithScopeKey, FileScopeId>,
/// Lookup table to map between node ids and ast nodes. /// Lookup table to map between node ids and ast nodes.
/// ///
/// Note: We should not depend on this map when analysing other files or /// Note: We should not depend on this map when analysing other files or
/// changing a file invalidates all dependents. /// changing a file invalidates all dependents.
ast_ids: IndexVec<FileScopeId, AstIds>, ast_ids: IndexVec<FileScopeId, AstIds>,
/// Map from scope to the node that introduces the scope.
nodes_by_scope: IndexVec<FileScopeId, NodeWithScopeId>,
/// Map from nodes that introduce a scope to the scope they define.
scopes_by_node: FxHashMap<NodeWithScopeKey, FileScopeId>,
} }
impl SemanticIndex { impl SemanticIndex {
@ -150,6 +147,7 @@ impl SemanticIndex {
} }
/// Returns an iterator over the direct child scopes of `scope`. /// Returns an iterator over the direct child scopes of `scope`.
#[allow(unused)]
pub(crate) fn child_scopes(&self, scope: FileScopeId) -> ChildrenIter { pub(crate) fn child_scopes(&self, scope: FileScopeId) -> ChildrenIter {
ChildrenIter::new(self, scope) ChildrenIter::new(self, scope)
} }
@ -159,15 +157,45 @@ impl SemanticIndex {
AncestorsIter::new(self, scope) AncestorsIter::new(self, scope)
} }
pub(crate) fn scope_node(&self, scope_id: FileScopeId) -> NodeWithScopeId { /// Returns the scope that is created by `node`.
self.nodes_by_scope[scope_id] pub(crate) fn node_scope(&self, node: impl Into<NodeWithScopeKey>) -> FileScopeId {
self.scopes_by_definition[&node.into()]
} }
/// Returns the scope in which `node_with_scope` is defined.
///
/// The returned scope can be used to lookup the symbol of the definition or its type.
///
/// * Annotation: Returns the direct parent scope
/// * Function and classes: Returns the parent scope unless they have type parameters in which case
/// the grandparent scope is returned.
pub(crate) fn definition_scope( pub(crate) fn definition_scope(
&self, &self,
node_with_scope: impl Into<NodeWithScopeKey>, node_with_scope: impl Into<NodeWithScopeKey>,
) -> FileScopeId { ) -> FileScopeId {
self.scopes_by_node[&node_with_scope.into()] fn resolve_scope(index: &SemanticIndex, node_with_scope: NodeWithScopeKey) -> FileScopeId {
let scope_id = index.node_scope(node_with_scope);
let scope = index.scope(scope_id);
match scope.kind() {
ScopeKind::Module => scope_id,
ScopeKind::Annotation => scope.parent.unwrap(),
ScopeKind::Class | ScopeKind::Function => {
let mut ancestors = index.ancestor_scopes(scope_id).skip(1);
let (mut scope_id, mut scope) = ancestors.next().unwrap();
if scope.kind() == ScopeKind::Annotation {
(scope_id, scope) = ancestors.next().unwrap();
}
debug_assert_ne!(scope.kind(), ScopeKind::Annotation);
scope_id
}
}
}
resolve_scope(self, node_with_scope.into())
} }
} }
@ -307,8 +335,9 @@ mod tests {
use ruff_db::vfs::{system_path_to_file, VfsFile}; use ruff_db::vfs::{system_path_to_file, VfsFile};
use crate::db::tests::TestDb; use crate::db::tests::TestDb;
use crate::semantic_index::symbol::{FileScopeId, FileSymbolId, Scope, ScopeKind, SymbolTable}; use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable};
use crate::semantic_index::{root_scope, semantic_index, symbol_table, SemanticIndex}; use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::Db;
struct TestCase { struct TestCase {
db: TestDb, db: TestDb,
@ -440,18 +469,12 @@ y = 2
let index = semantic_index(&db, file); let index = semantic_index(&db, file);
let root = index.symbol_table(FileScopeId::root());
let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect();
assert_eq!(scopes.len(), 1); assert_eq!(scopes.len(), 1);
let (class_scope_id, class_scope) = scopes[0]; let (class_scope_id, class_scope) = scopes[0];
assert_eq!(class_scope.kind(), ScopeKind::Class); assert_eq!(class_scope.kind(), ScopeKind::Class);
assert_eq!( assert_eq!(class_scope.name(&db, file), "C");
class_scope
.defining_symbol()
.map(super::symbol::FileSymbolId::scoped_symbol_id),
root.symbol_id_by_name("C")
);
let class_table = index.symbol_table(class_scope_id); let class_table = index.symbol_table(class_scope_id);
assert_eq!(names(&class_table), vec!["x"]); assert_eq!(names(&class_table), vec!["x"]);
@ -480,12 +503,7 @@ y = 2
let (function_scope_id, function_scope) = scopes[0]; let (function_scope_id, function_scope) = scopes[0];
assert_eq!(function_scope.kind(), ScopeKind::Function); assert_eq!(function_scope.kind(), ScopeKind::Function);
assert_eq!( assert_eq!(function_scope.name(&db, file), "func");
function_scope
.defining_symbol()
.map(FileSymbolId::scoped_symbol_id),
root_table.symbol_id_by_name("func")
);
let function_table = index.symbol_table(function_scope_id); let function_table = index.symbol_table(function_scope_id);
assert_eq!(names(&function_table), vec!["x"]); assert_eq!(names(&function_table), vec!["x"]);
@ -521,19 +539,9 @@ def func():
assert_eq!(func_scope_1.kind(), ScopeKind::Function); assert_eq!(func_scope_1.kind(), ScopeKind::Function);
assert_eq!( assert_eq!(func_scope_1.name(&db, file), "func");
func_scope_1
.defining_symbol()
.map(FileSymbolId::scoped_symbol_id),
root_table.symbol_id_by_name("func")
);
assert_eq!(func_scope_2.kind(), ScopeKind::Function); assert_eq!(func_scope_2.kind(), ScopeKind::Function);
assert_eq!( assert_eq!(func_scope_2.name(&db, file), "func");
func_scope_2
.defining_symbol()
.map(FileSymbolId::scoped_symbol_id),
root_table.symbol_id_by_name("func")
);
let func1_table = index.symbol_table(func_scope1_id); let func1_table = index.symbol_table(func_scope1_id);
let func2_table = index.symbol_table(func_scope2_id); let func2_table = index.symbol_table(func_scope2_id);
@ -568,12 +576,7 @@ def func[T]():
let (ann_scope_id, ann_scope) = scopes[0]; let (ann_scope_id, ann_scope) = scopes[0];
assert_eq!(ann_scope.kind(), ScopeKind::Annotation); assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
assert_eq!( assert_eq!(ann_scope.name(&db, file), "func");
ann_scope
.defining_symbol()
.map(FileSymbolId::scoped_symbol_id),
root_table.symbol_id_by_name("func")
);
let ann_table = index.symbol_table(ann_scope_id); let ann_table = index.symbol_table(ann_scope_id);
assert_eq!(names(&ann_table), vec!["T"]); assert_eq!(names(&ann_table), vec!["T"]);
@ -581,12 +584,7 @@ def func[T]():
assert_eq!(scopes.len(), 1); assert_eq!(scopes.len(), 1);
let (func_scope_id, func_scope) = scopes[0]; let (func_scope_id, func_scope) = scopes[0];
assert_eq!(func_scope.kind(), ScopeKind::Function); assert_eq!(func_scope.kind(), ScopeKind::Function);
assert_eq!( assert_eq!(func_scope.name(&db, file), "func");
func_scope
.defining_symbol()
.map(FileSymbolId::scoped_symbol_id),
root_table.symbol_id_by_name("func")
);
let func_table = index.symbol_table(func_scope_id); let func_table = index.symbol_table(func_scope_id);
assert_eq!(names(&func_table), vec!["x"]); assert_eq!(names(&func_table), vec!["x"]);
} }
@ -610,12 +608,7 @@ class C[T]:
assert_eq!(scopes.len(), 1); assert_eq!(scopes.len(), 1);
let (ann_scope_id, ann_scope) = scopes[0]; let (ann_scope_id, ann_scope) = scopes[0];
assert_eq!(ann_scope.kind(), ScopeKind::Annotation); assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
assert_eq!( assert_eq!(ann_scope.name(&db, file), "C");
ann_scope
.defining_symbol()
.map(FileSymbolId::scoped_symbol_id),
root_table.symbol_id_by_name("C")
);
let ann_table = index.symbol_table(ann_scope_id); let ann_table = index.symbol_table(ann_scope_id);
assert_eq!(names(&ann_table), vec!["T"]); assert_eq!(names(&ann_table), vec!["T"]);
assert!( assert!(
@ -630,12 +623,7 @@ class C[T]:
let (func_scope_id, class_scope) = scopes[0]; let (func_scope_id, class_scope) = scopes[0];
assert_eq!(class_scope.kind(), ScopeKind::Class); assert_eq!(class_scope.kind(), ScopeKind::Class);
assert_eq!( assert_eq!(class_scope.name(&db, file), "C");
class_scope
.defining_symbol()
.map(FileSymbolId::scoped_symbol_id),
root_table.symbol_id_by_name("C")
);
assert_eq!(names(&index.symbol_table(func_scope_id)), vec!["x"]); assert_eq!(names(&index.symbol_table(func_scope_id)), vec!["x"]);
} }
@ -698,23 +686,13 @@ class C[T]:
fn scope_iterators() { fn scope_iterators() {
fn scope_names<'a>( fn scope_names<'a>(
scopes: impl Iterator<Item = (FileScopeId, &'a Scope)>, scopes: impl Iterator<Item = (FileScopeId, &'a Scope)>,
index: &'a SemanticIndex, db: &'a dyn Db,
file: VfsFile,
) -> Vec<&'a str> { ) -> Vec<&'a str> {
let mut names = Vec::new(); scopes
.into_iter()
for (_, scope) in scopes { .map(|(_, scope)| scope.name(db, file))
if let Some(defining_symbol) = scope.defining_symbol { .collect()
let symbol_table = &index.symbol_tables[defining_symbol.scope()];
let symbol = symbol_table.symbol(defining_symbol.scoped_symbol_id());
names.push(symbol.name().as_str());
} else if scope.parent.is_none() {
names.push("<module>");
} else {
panic!("Unsupported");
}
}
names
} }
let TestCase { db, file } = test_case( let TestCase { db, file } = test_case(
@ -734,16 +712,19 @@ def x():
let descendents = index.descendent_scopes(FileScopeId::root()); let descendents = index.descendent_scopes(FileScopeId::root());
assert_eq!( assert_eq!(
scope_names(descendents, index), scope_names(descendents, &db, file),
vec!["Test", "foo", "bar", "baz", "x"] vec!["Test", "foo", "bar", "baz", "x"]
); );
let children = index.child_scopes(FileScopeId::root()); let children = index.child_scopes(FileScopeId::root());
assert_eq!(scope_names(children, index), vec!["Test", "x"]); assert_eq!(scope_names(children, &db, file), vec!["Test", "x"]);
let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0; let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0;
let test_child_scopes = index.child_scopes(test_class); let test_child_scopes = index.child_scopes(test_class);
assert_eq!(scope_names(test_child_scopes, index), vec!["foo", "baz"]); assert_eq!(
scope_names(test_child_scopes, &db, file),
vec!["foo", "baz"]
);
let bar_scope = index let bar_scope = index
.descendent_scopes(FileScopeId::root()) .descendent_scopes(FileScopeId::root())
@ -753,7 +734,7 @@ def x():
let ancestors = index.ancestor_scopes(bar_scope); let ancestors = index.ancestor_scopes(bar_scope);
assert_eq!( assert_eq!(
scope_names(ancestors, index), scope_names(ancestors, &db, file),
vec!["bar", "foo", "Test", "<module>"] vec!["bar", "foo", "Test", "<module>"]
); );
} }

View file

@ -145,10 +145,6 @@ impl<L: Copy> AstId<L> {
pub(super) fn new(scope: FileScopeId, in_scope_id: L) -> Self { pub(super) fn new(scope: FileScopeId, in_scope_id: L) -> Self {
Self { scope, in_scope_id } Self { scope, in_scope_id }
} }
pub(super) fn in_scope_id(self) -> L {
self.in_scope_id
}
} }
/// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`]. /// Uniquely identifies an [`ast::Expr`] in a [`FileScopeId`].

View file

@ -12,7 +12,7 @@ use crate::node_key::NodeKey;
use crate::semantic_index::ast_ids::{AstId, AstIdsBuilder, ScopedClassId, ScopedFunctionId}; use crate::semantic_index::ast_ids::{AstId, AstIdsBuilder, ScopedClassId, ScopedFunctionId};
use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition}; use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition};
use crate::semantic_index::symbol::{ use crate::semantic_index::symbol::{
FileScopeId, FileSymbolId, Scope, ScopeKind, ScopedSymbolId, SymbolFlags, SymbolTableBuilder, FileScopeId, Scope, ScopeKind, ScopedSymbolId, SymbolFlags, SymbolTableBuilder,
}; };
use crate::semantic_index::{NodeWithScopeId, NodeWithScopeKey, SemanticIndex}; use crate::semantic_index::{NodeWithScopeId, NodeWithScopeKey, SemanticIndex};
@ -27,9 +27,8 @@ pub(super) struct SemanticIndexBuilder<'a> {
scopes: IndexVec<FileScopeId, Scope>, scopes: IndexVec<FileScopeId, Scope>,
symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder>, symbol_tables: IndexVec<FileScopeId, SymbolTableBuilder>,
ast_ids: IndexVec<FileScopeId, AstIdsBuilder>, ast_ids: IndexVec<FileScopeId, AstIdsBuilder>,
expression_scopes: FxHashMap<NodeKey, FileScopeId>, scopes_by_expression: FxHashMap<NodeKey, FileScopeId>,
scope_nodes: IndexVec<FileScopeId, NodeWithScopeId>, scopes_by_definition: FxHashMap<NodeWithScopeKey, FileScopeId>,
node_scopes: FxHashMap<NodeWithScopeKey, FileScopeId>,
} }
impl<'a> SemanticIndexBuilder<'a> { impl<'a> SemanticIndexBuilder<'a> {
@ -42,16 +41,13 @@ impl<'a> SemanticIndexBuilder<'a> {
scopes: IndexVec::new(), scopes: IndexVec::new(),
symbol_tables: IndexVec::new(), symbol_tables: IndexVec::new(),
ast_ids: IndexVec::new(), ast_ids: IndexVec::new(),
expression_scopes: FxHashMap::default(), scopes_by_expression: FxHashMap::default(),
node_scopes: FxHashMap::default(), scopes_by_definition: FxHashMap::default(),
scope_nodes: IndexVec::new(),
}; };
builder.push_scope_with_parent( builder.push_scope_with_parent(
&NodeWithScope::new(parsed.syntax(), NodeWithScopeId::Module), &NodeWithScope::new(parsed.syntax(), NodeWithScopeId::Module),
None, None,
None,
None,
); );
builder builder
@ -64,45 +60,29 @@ impl<'a> SemanticIndexBuilder<'a> {
.expect("Always to have a root scope") .expect("Always to have a root scope")
} }
fn push_scope( fn push_scope(&mut self, node: &NodeWithScope) {
&mut self,
node: &NodeWithScope,
defining_symbol: Option<FileSymbolId>,
definition: Option<Definition>,
) {
let parent = self.current_scope(); let parent = self.current_scope();
self.push_scope_with_parent(node, defining_symbol, definition, Some(parent)); self.push_scope_with_parent(node, Some(parent));
} }
fn push_scope_with_parent( fn push_scope_with_parent(&mut self, node: &NodeWithScope, parent: Option<FileScopeId>) {
&mut self,
node: &NodeWithScope,
defining_symbol: Option<FileSymbolId>,
definition: Option<Definition>,
parent: Option<FileScopeId>,
) {
let children_start = self.scopes.next_index() + 1; let children_start = self.scopes.next_index() + 1;
let node_key = node.key();
let node_id = node.id();
let scope_kind = node.scope_kind();
let scope = Scope { let scope = Scope {
node: node.id(),
parent, parent,
defining_symbol, kind: node.scope_kind(),
definition,
kind: scope_kind,
descendents: children_start..children_start, descendents: children_start..children_start,
}; };
let scope_id = self.scopes.push(scope); let scope_id = self.scopes.push(scope);
self.symbol_tables.push(SymbolTableBuilder::new()); self.symbol_tables.push(SymbolTableBuilder::new());
let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new()); let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new());
let scope_node_id = self.scope_nodes.push(node_id);
debug_assert_eq!(ast_id_scope, scope_id); debug_assert_eq!(ast_id_scope, scope_id);
debug_assert_eq!(scope_id, scope_node_id);
self.scope_stack.push(scope_id); self.scope_stack.push(scope_id);
self.node_scopes.insert(node_key, scope_id); self.scopes_by_definition.insert(node.key(), scope_id);
} }
fn pop_scope(&mut self) -> FileScopeId { fn pop_scope(&mut self) -> FileScopeId {
@ -123,18 +103,9 @@ impl<'a> SemanticIndexBuilder<'a> {
&mut self.ast_ids[scope_id] &mut self.ast_ids[scope_id]
} }
fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> FileSymbolId { fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId {
for scope in self.scope_stack.iter().rev().skip(1) {
let builder = &self.symbol_tables[*scope];
if let Some(symbol) = builder.symbol_by_name(&name) {
return FileSymbolId::new(*scope, symbol);
}
}
let scope = self.current_scope();
let symbol_table = self.current_symbol_table(); let symbol_table = self.current_symbol_table();
FileSymbolId::new(scope, symbol_table.add_or_update_symbol(name, flags, None)) symbol_table.add_or_update_symbol(name, flags, None)
} }
fn add_or_update_symbol_with_definition( fn add_or_update_symbol_with_definition(
@ -150,7 +121,6 @@ impl<'a> SemanticIndexBuilder<'a> {
fn with_type_params( fn with_type_params(
&mut self, &mut self,
with_params: &WithTypeParams, with_params: &WithTypeParams,
defining_symbol: FileSymbolId,
nested: impl FnOnce(&mut Self) -> FileScopeId, nested: impl FnOnce(&mut Self) -> FileScopeId,
) -> FileScopeId { ) -> FileScopeId {
let type_params = with_params.type_parameters(); let type_params = with_params.type_parameters();
@ -161,11 +131,7 @@ impl<'a> SemanticIndexBuilder<'a> {
WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id), WithTypeParams::FunctionDef { id, .. } => NodeWithScopeId::FunctionTypeParams(*id),
}; };
self.push_scope( self.push_scope(&NodeWithScope::new(type_params, type_params_id));
&NodeWithScope::new(type_params, type_params_id),
Some(defining_symbol),
Some(with_params.definition()),
);
for type_param in &type_params.type_params { for type_param in &type_params.type_params {
let name = match type_param { let name = match type_param {
ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name, ast::TypeParam::TypeVar(ast::TypeParamTypeVar { name, .. }) => name,
@ -210,16 +176,14 @@ impl<'a> SemanticIndexBuilder<'a> {
self.scopes.shrink_to_fit(); self.scopes.shrink_to_fit();
ast_ids.shrink_to_fit(); ast_ids.shrink_to_fit();
symbol_tables.shrink_to_fit(); symbol_tables.shrink_to_fit();
self.expression_scopes.shrink_to_fit(); self.scopes_by_expression.shrink_to_fit();
self.scope_nodes.shrink_to_fit();
SemanticIndex { SemanticIndex {
symbol_tables, symbol_tables,
scopes: self.scopes, scopes: self.scopes,
nodes_by_scope: self.scope_nodes, scopes_by_definition: self.scopes_by_definition,
scopes_by_node: self.node_scopes,
ast_ids, ast_ids,
scopes_by_expression: self.expression_scopes, scopes_by_expression: self.scopes_by_expression,
} }
} }
} }
@ -242,31 +206,24 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
let function_id = ScopedFunctionId(statement_id); let function_id = ScopedFunctionId(statement_id);
let definition = Definition::FunctionDef(function_id); let definition = Definition::FunctionDef(function_id);
let scope = self.current_scope(); let scope = self.current_scope();
let symbol = FileSymbolId::new(
scope, self.add_or_update_symbol_with_definition(name.clone(), definition);
self.add_or_update_symbol_with_definition(name.clone(), definition),
);
self.with_type_params( self.with_type_params(
&WithTypeParams::FunctionDef { &WithTypeParams::FunctionDef {
node: function_def, node: function_def,
id: AstId::new(scope, function_id), id: AstId::new(scope, function_id),
}, },
symbol,
|builder| { |builder| {
builder.visit_parameters(&function_def.parameters); builder.visit_parameters(&function_def.parameters);
for expr in &function_def.returns { for expr in &function_def.returns {
builder.visit_annotation(expr); builder.visit_annotation(expr);
} }
builder.push_scope( builder.push_scope(&NodeWithScope::new(
&NodeWithScope::new(
function_def, function_def,
NodeWithScopeId::Function(AstId::new(scope, function_id)), NodeWithScopeId::Function(AstId::new(scope, function_id)),
), ));
Some(symbol),
Some(definition),
);
builder.visit_body(&function_def.body); builder.visit_body(&function_def.body);
builder.pop_scope() builder.pop_scope()
}, },
@ -281,29 +238,23 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
let class_id = ScopedClassId(statement_id); let class_id = ScopedClassId(statement_id);
let definition = Definition::ClassDef(class_id); let definition = Definition::ClassDef(class_id);
let scope = self.current_scope(); let scope = self.current_scope();
let id = FileSymbolId::new(
self.current_scope(), self.add_or_update_symbol_with_definition(name.clone(), definition);
self.add_or_update_symbol_with_definition(name.clone(), definition),
);
self.with_type_params( self.with_type_params(
&WithTypeParams::ClassDef { &WithTypeParams::ClassDef {
node: class, node: class,
id: AstId::new(scope, class_id), id: AstId::new(scope, class_id),
}, },
id,
|builder| { |builder| {
if let Some(arguments) = &class.arguments { if let Some(arguments) = &class.arguments {
builder.visit_arguments(arguments); builder.visit_arguments(arguments);
} }
builder.push_scope( builder.push_scope(&NodeWithScope::new(
&NodeWithScope::new(
class, class,
NodeWithScopeId::Class(AstId::new(scope, class_id)), NodeWithScopeId::Class(AstId::new(scope, class_id)),
), ));
Some(id),
Some(definition),
);
builder.visit_body(&class.body); builder.visit_body(&class.body);
builder.pop_scope() builder.pop_scope()
@ -368,7 +319,7 @@ impl Visitor<'_> for SemanticIndexBuilder<'_> {
self.current_ast_ids().record_expression(expr, module) self.current_ast_ids().record_expression(expr, module)
}; };
self.expression_scopes self.scopes_by_expression
.insert(NodeKey::from_node(expr), self.current_scope()); .insert(NodeKey::from_node(expr), self.current_scope());
match expr { match expr {
@ -449,13 +400,6 @@ impl<'a> WithTypeParams<'a> {
WithTypeParams::FunctionDef { node, .. } => node.type_params.as_deref(), WithTypeParams::FunctionDef { node, .. } => node.type_params.as_deref(),
} }
} }
fn definition(&self) -> Definition {
match self {
WithTypeParams::ClassDef { id, .. } => Definition::ClassDef(id.in_scope_id()),
WithTypeParams::FunctionDef { id, .. } => Definition::FunctionDef(id.in_scope_id()),
}
}
} }
struct NodeWithScope { struct NodeWithScope {

View file

@ -6,13 +6,14 @@ use hashbrown::hash_map::RawEntryMut;
use rustc_hash::FxHasher; use rustc_hash::FxHasher;
use smallvec::SmallVec; use smallvec::SmallVec;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap};
use crate::Db;
use ruff_db::vfs::VfsFile; use ruff_db::vfs::VfsFile;
use ruff_index::{newtype_index, IndexVec}; use ruff_index::{newtype_index, IndexVec};
use ruff_python_ast::name::Name; use ruff_python_ast::name::Name;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::{root_scope, semantic_index, symbol_table, NodeWithScopeId, SymbolMap};
use crate::Db;
#[derive(Eq, PartialEq, Debug)] #[derive(Eq, PartialEq, Debug)]
pub struct Symbol { pub struct Symbol {
name: Name, name: Name,
@ -87,13 +88,6 @@ pub struct FileSymbolId {
} }
impl FileSymbolId { impl FileSymbolId {
pub(super) fn new(scope: FileScopeId, symbol: ScopedSymbolId) -> Self {
Self {
scope,
scoped_symbol_id: symbol,
}
}
pub fn scope(self) -> FileScopeId { pub fn scope(self) -> FileScopeId {
self.scope self.scope
} }
@ -215,19 +209,33 @@ impl FileScopeId {
#[derive(Debug, Eq, PartialEq)] #[derive(Debug, Eq, PartialEq)]
pub struct Scope { pub struct Scope {
pub(super) parent: Option<FileScopeId>, pub(super) parent: Option<FileScopeId>,
pub(super) definition: Option<Definition>, pub(super) node: NodeWithScopeId,
pub(super) defining_symbol: Option<FileSymbolId>,
pub(super) kind: ScopeKind, pub(super) kind: ScopeKind,
pub(super) descendents: Range<FileScopeId>, pub(super) descendents: Range<FileScopeId>,
} }
impl Scope { impl Scope {
pub fn definition(&self) -> Option<Definition> { #[cfg(test)]
self.definition pub(crate) fn name<'db>(&self, db: &'db dyn Db, file: VfsFile) -> &'db str {
use crate::semantic_index::ast_ids::AstIdNode;
use ruff_python_ast as ast;
match self.node {
NodeWithScopeId::Module => "<module>",
NodeWithScopeId::Class(class) | NodeWithScopeId::ClassTypeParams(class) => {
let class = ast::StmtClassDef::lookup(db, file, class);
class.name.as_str()
}
NodeWithScopeId::Function(function) | NodeWithScopeId::FunctionTypeParams(function) => {
let function = ast::StmtFunctionDef::lookup(db, file, function);
function.name.as_str()
}
}
} }
pub fn defining_symbol(&self) -> Option<FileSymbolId> { /// The node that creates this scope.
self.defining_symbol pub(crate) fn node(&self) -> NodeWithScopeId {
self.node
} }
pub fn parent(self) -> Option<FileScopeId> { pub fn parent(self) -> Option<FileScopeId> {
@ -365,10 +373,6 @@ impl SymbolTableBuilder {
} }
} }
pub(super) fn symbol_by_name(&self, name: &str) -> Option<ScopedSymbolId> {
self.table.symbol_id_by_name(name)
}
pub(super) fn finish(mut self) -> SymbolTable { pub(super) fn finish(mut self) -> SymbolTable {
self.table.shrink_to_fit(); self.table.shrink_to_fit();
self.table self.table

View file

@ -5,7 +5,7 @@ use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef};
use crate::semantic_index::ast_ids::HasScopedAstId; use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::definition::Definition; use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::{PublicSymbolId, ScopeKind}; use crate::semantic_index::symbol::PublicSymbolId;
use crate::semantic_index::{public_symbol, semantic_index, NodeWithScopeKey}; use crate::semantic_index::{public_symbol, semantic_index, NodeWithScopeKey};
use crate::types::{infer_types, public_symbol_ty, Type, TypingContext}; use crate::types::{infer_types, public_symbol_ty, Type, TypingContext};
use crate::Db; use crate::Db;
@ -38,6 +38,10 @@ impl<'db> SemanticModel<'db> {
} }
pub trait HasTy { pub trait HasTy {
/// Returns the inferred type of `self`.
///
/// ## Panics
/// May panic if `self` is from another file than `model`.
fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db>; fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db>;
} }
@ -48,7 +52,6 @@ impl HasTy for ast::ExpressionRef<'_> {
let scope = file_scope.to_scope_id(model.db, model.file); let scope = file_scope.to_scope_id(model.db, model.file);
let expression_id = self.scoped_ast_id(model.db, scope); let expression_id = self.scoped_ast_id(model.db, scope);
infer_types(model.db, scope).expression_ty(expression_id) infer_types(model.db, scope).expression_ty(expression_id)
} }
} }
@ -142,15 +145,7 @@ impl HasTy for ast::StmtFunctionDef {
let index = semantic_index(model.db, model.file); let index = semantic_index(model.db, model.file);
let definition_scope = index.definition_scope(NodeWithScopeKey::from(self)); let definition_scope = index.definition_scope(NodeWithScopeKey::from(self));
// SAFETY: A function always has either an enclosing module, function or class scope. let scope = definition_scope.to_scope_id(model.db, model.file);
let mut parent_scope_id = index.parent_scope_id(definition_scope).unwrap();
let parent_scope = index.scope(parent_scope_id);
if parent_scope.kind() == ScopeKind::Annotation {
parent_scope_id = index.parent_scope_id(parent_scope_id).unwrap();
}
let scope = parent_scope_id.to_scope_id(model.db, model.file);
let types = infer_types(model.db, scope); let types = infer_types(model.db, scope);
let definition = Definition::FunctionDef(self.scoped_ast_id(model.db, scope)); let definition = Definition::FunctionDef(self.scoped_ast_id(model.db, scope));
@ -163,16 +158,7 @@ impl HasTy for StmtClassDef {
fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> {
let index = semantic_index(model.db, model.file); let index = semantic_index(model.db, model.file);
let definition_scope = index.definition_scope(NodeWithScopeKey::from(self)); let definition_scope = index.definition_scope(NodeWithScopeKey::from(self));
let scope = definition_scope.to_scope_id(model.db, model.file);
// SAFETY: A class always has either an enclosing module, function or class scope.
let mut parent_scope_id = index.parent_scope_id(definition_scope).unwrap();
let parent_scope = index.scope(parent_scope_id);
if parent_scope.kind() == ScopeKind::Annotation {
parent_scope_id = index.parent_scope_id(parent_scope_id).unwrap();
}
let scope = parent_scope_id.to_scope_id(model.db, model.file);
let types = infer_types(model.db, scope); let types = infer_types(model.db, scope);
let definition = Definition::ClassDef(self.scoped_ast_id(model.db, scope)); let definition = Definition::ClassDef(self.scoped_ast_id(model.db, scope));
@ -180,3 +166,68 @@ impl HasTy for StmtClassDef {
types.definition_ty(definition) types.definition_ty(definition)
} }
} }
#[cfg(test)]
mod tests {
use red_knot_module_resolver::{set_module_resolution_settings, ModuleResolutionSettings};
use ruff_db::file_system::FileSystemPathBuf;
use ruff_db::parsed::parsed_module;
use ruff_db::vfs::system_path_to_file;
use crate::db::tests::TestDb;
use crate::types::Type;
use crate::{HasTy, SemanticModel};
fn setup_db() -> TestDb {
let mut db = TestDb::new();
set_module_resolution_settings(
&mut db,
ModuleResolutionSettings {
extra_paths: vec![],
workspace_root: FileSystemPathBuf::from("/src"),
site_packages: None,
custom_typeshed: None,
},
);
db
}
#[test]
fn function_ty() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("/src/foo.py", "def test(): pass")?;
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
let ast = parsed_module(&db, foo);
let function = ast.suite()[0].as_function_def_stmt().unwrap();
let model = SemanticModel::new(&db, foo);
let ty = function.ty(&model);
assert!(matches!(ty, Type::Function(_)));
Ok(())
}
#[test]
fn class_ty() -> anyhow::Result<()> {
let db = setup_db();
db.memory_file_system()
.write_file("/src/foo.py", "class Test: pass")?;
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
let ast = parsed_module(&db, foo);
let class = ast.suite()[0].as_class_def_stmt().unwrap();
let model = SemanticModel::new(&db, foo);
let ty = class.ty(&model);
assert!(matches!(ty, Type::Class(_)));
Ok(())
}
}

View file

@ -71,7 +71,7 @@ pub(crate) fn infer_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInfe
let index = semantic_index(db, file); let index = semantic_index(db, file);
let scope_id = scope.file_scope_id(db); let scope_id = scope.file_scope_id(db);
let node = index.scope_node(scope_id); let node = index.scope(scope_id).node();
let mut context = TypeInferenceBuilder::new(db, scope, index); let mut context = TypeInferenceBuilder::new(db, scope, index);

View file

@ -11,8 +11,8 @@ use ruff_python_ast::{ExprContext, TypeParams};
use crate::semantic_index::ast_ids::{HasScopedAstId, ScopedExpressionId}; use crate::semantic_index::ast_ids::{HasScopedAstId, ScopedExpressionId};
use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition}; use crate::semantic_index::definition::{Definition, ImportDefinition, ImportFromDefinition};
use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopeKind, ScopedSymbolId, SymbolTable}; use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId, SymbolTable};
use crate::semantic_index::{symbol_table, ChildrenIter, SemanticIndex}; use crate::semantic_index::{symbol_table, SemanticIndex};
use crate::types::{ use crate::types::{
infer_types, ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId, infer_types, ClassType, FunctionType, IntersectionType, ModuleType, ScopedClassTypeId,
ScopedFunctionTypeId, ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext, ScopedFunctionTypeId, ScopedIntersectionTypeId, ScopedUnionTypeId, Type, TypeId, TypingContext,
@ -104,7 +104,6 @@ pub(super) struct TypeInferenceBuilder<'a> {
/// The type inference results /// The type inference results
types: TypeInference<'a>, types: TypeInference<'a>,
children_scopes: ChildrenIter<'a>,
} }
impl<'db> TypeInferenceBuilder<'db> { impl<'db> TypeInferenceBuilder<'db> {
@ -112,7 +111,6 @@ impl<'db> TypeInferenceBuilder<'db> {
pub(super) fn new(db: &'db dyn Db, scope: ScopeId<'db>, index: &'db SemanticIndex) -> Self { pub(super) fn new(db: &'db dyn Db, scope: ScopeId<'db>, index: &'db SemanticIndex) -> Self {
let file_scope_id = scope.file_scope_id(db); let file_scope_id = scope.file_scope_id(db);
let file = scope.file(db); let file = scope.file(db);
let children_scopes = index.child_scopes(file_scope_id);
let symbol_table = index.symbol_table(file_scope_id); let symbol_table = index.symbol_table(file_scope_id);
Self { Self {
@ -124,7 +122,6 @@ impl<'db> TypeInferenceBuilder<'db> {
db, db,
types: TypeInference::default(), types: TypeInference::default(),
children_scopes,
} }
} }
@ -208,14 +205,6 @@ impl<'db> TypeInferenceBuilder<'db> {
decorators: decorator_tys, decorators: decorator_tys,
}); });
// Skip over the function or type params child scope.
let (_, scope) = self.children_scopes.next().unwrap();
assert!(matches!(
scope.kind(),
ScopeKind::Function | ScopeKind::Annotation
));
self.types self.types
.definition_tys .definition_tys
.insert(Definition::FunctionDef(function_id), function_ty); .insert(Definition::FunctionDef(function_id), function_ty);
@ -225,7 +214,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let ast::StmtClassDef { let ast::StmtClassDef {
range: _, range: _,
name, name,
type_params, type_params: _,
decorator_list, decorator_list,
arguments, arguments,
body: _, body: _,
@ -242,16 +231,7 @@ impl<'db> TypeInferenceBuilder<'db> {
.map(|arguments| self.infer_arguments(arguments)) .map(|arguments| self.infer_arguments(arguments))
.unwrap_or(Vec::new()); .unwrap_or(Vec::new());
// If the class has type parameters, then the class body scope is the first child scope of the type parameter's scope let class_body_scope_id = self.index.node_scope(class);
// Otherwise the next scope must be the class definition scope.
let (class_body_scope_id, class_body_scope) = if type_params.is_some() {
let (type_params_scope, _) = self.children_scopes.next().unwrap();
self.index.child_scopes(type_params_scope).next().unwrap()
} else {
self.children_scopes.next().unwrap()
};
assert_eq!(class_body_scope.kind(), ScopeKind::Class);
let class_ty = self.class_ty(ClassType { let class_ty = self.class_ty(ClassType {
name: name.id.clone(), name: name.id.clone(),
@ -539,6 +519,12 @@ impl<'db> TypeInferenceBuilder<'db> {
let symbol_table = symbol_table(self.db, ancestor_scope); let symbol_table = symbol_table(self.db, ancestor_scope);
if let Some(symbol_id) = symbol_table.symbol_id_by_name(id) { if let Some(symbol_id) = symbol_table.symbol_id_by_name(id) {
let symbol = symbol_table.symbol(symbol_id);
if !symbol.is_defined() {
continue;
}
let types = infer_types(self.db, ancestor_scope); let types = infer_types(self.db, ancestor_scope);
return types.symbol_ty(symbol_id); return types.symbol_ty(symbol_id);
} }
@ -696,13 +682,13 @@ impl<'db> TypeInferenceBuilder<'db> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use red_knot_module_resolver::{set_module_resolution_settings, ModuleResolutionSettings};
use ruff_db::file_system::FileSystemPathBuf; use ruff_db::file_system::FileSystemPathBuf;
use ruff_db::vfs::system_path_to_file; use ruff_db::vfs::system_path_to_file;
use ruff_python_ast::name::Name;
use crate::db::tests::TestDb; use crate::db::tests::TestDb;
use crate::types::{public_symbol_ty_by_name, Type, TypingContext}; use crate::types::{public_symbol_ty_by_name, Type, TypingContext};
use red_knot_module_resolver::{set_module_resolution_settings, ModuleResolutionSettings};
use ruff_python_ast::name::Name;
fn setup_db() -> TestDb { fn setup_db() -> TestDb {
let mut db = TestDb::new(); let mut db = TestDb::new();