Use a single node hierarchy to track statements and expressions (#6709)

## Summary

This PR is a follow-up to the suggestion in
https://github.com/astral-sh/ruff/pull/6345#discussion_r1285470953 to
use a single stack to store all statements and expressions, rather than
using separate vectors for each, which gives us something closer to a
full-fidelity chain. (We can then generalize this concept to include all
other AST nodes too.)

This is in part made possible by the removal of the hash map from
`&Stmt` to `StatementId` (#6694), which makes it much cheaper to store
these using a single interface (since doing so no longer introduces the
requirement that we hash all expressions).

I'll follow-up with some profiling, but a few notes on how the data
requirements have changed:

- We now store a `BranchId` for every expression, not just every
statement, so that's an extra `u32`.
- We now store a single `NodeId` on every snapshot, rather than separate
`StatementId` and `ExpressionId` IDs, so that's one fewer `u32` for each
snapshot.
- We're probably doing a few more lookups in general, since any calls to
`current_statement()` etc. now have to iterate up the node hierarchy
until they identify the first statement.

## Test Plan

`cargo test`
This commit is contained in:
Charlie Marsh 2023-08-21 21:32:57 -04:00 committed by GitHub
parent abc5065fc7
commit 424b8d4ad2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 268 additions and 299 deletions

View file

@ -18,14 +18,13 @@ use crate::binding::{
use crate::branches::{BranchId, Branches};
use crate::context::ExecutionContext;
use crate::definition::{Definition, DefinitionId, Definitions, Member, Module};
use crate::expressions::{ExpressionId, Expressions};
use crate::globals::{Globals, GlobalsArena};
use crate::nodes::{NodeId, NodeRef, Nodes};
use crate::reference::{
ResolvedReference, ResolvedReferenceId, ResolvedReferences, UnresolvedReference,
UnresolvedReferenceFlags, UnresolvedReferences,
};
use crate::scope::{Scope, ScopeId, ScopeKind, Scopes};
use crate::statements::{StatementId, Statements};
use crate::Imported;
/// A semantic model for a Python module, to enable querying the module's semantic information.
@ -33,17 +32,11 @@ pub struct SemanticModel<'a> {
typing_modules: &'a [String],
module_path: Option<&'a [String]>,
/// Stack of statements in the program.
statements: Statements<'a>,
/// Stack of all AST nodes in the program.
nodes: Nodes<'a>,
/// The ID of the current statement.
statement_id: Option<StatementId>,
/// Stack of expressions in the program.
expressions: Expressions<'a>,
/// The ID of the current expression.
expression_id: Option<ExpressionId>,
/// The ID of the current AST node.
node_id: Option<NodeId>,
/// Stack of all branches in the program.
branches: Branches,
@ -141,12 +134,10 @@ impl<'a> SemanticModel<'a> {
Self {
typing_modules,
module_path: module.path(),
statements: Statements::default(),
statement_id: None,
expressions: Expressions::default(),
expression_id: None,
branch_id: None,
nodes: Nodes::default(),
node_id: None,
branches: Branches::default(),
branch_id: None,
scopes: Scopes::default(),
scope_id: ScopeId::global(),
definitions: Definitions::for_module(module),
@ -236,7 +227,7 @@ impl<'a> SemanticModel<'a> {
flags,
references: Vec::new(),
scope: self.scope_id,
source: self.statement_id,
source: self.node_id,
context: self.execution_context(),
exceptions: self.exceptions(),
})
@ -728,7 +719,7 @@ impl<'a> SemanticModel<'a> {
{
return Some(ImportedName {
name: format!("{name}.{member}"),
range: self.statements[source].range(),
range: self.nodes[source].range(),
context: binding.context,
});
}
@ -752,7 +743,7 @@ impl<'a> SemanticModel<'a> {
{
return Some(ImportedName {
name: (*name).to_string(),
range: self.statements[source].range(),
range: self.nodes[source].range(),
context: binding.context,
});
}
@ -773,7 +764,7 @@ impl<'a> SemanticModel<'a> {
{
return Some(ImportedName {
name: format!("{name}.{member}"),
range: self.statements[source].range(),
range: self.nodes[source].range(),
context: binding.context,
});
}
@ -788,33 +779,15 @@ impl<'a> SemanticModel<'a> {
})
}
/// Push a [`Stmt`] onto the stack.
pub fn push_statement(&mut self, stmt: &'a Stmt) {
self.statement_id = Some(
self.statements
.insert(stmt, self.statement_id, self.branch_id),
);
/// Push an AST node [`NodeRef`] onto the stack.
pub fn push_node<T: Into<NodeRef<'a>>>(&mut self, node: T) {
self.node_id = Some(self.nodes.insert(node.into(), self.node_id, self.branch_id));
}
/// Pop the current [`Stmt`] off the stack.
pub fn pop_statement(&mut self) {
let node_id = self
.statement_id
.expect("Attempted to pop without statement");
self.statement_id = self.statements.parent_id(node_id);
}
/// Push a [`Expr`] onto the stack.
pub fn push_expression(&mut self, expr: &'a Expr) {
self.expression_id = Some(self.expressions.insert(expr, self.expression_id));
}
/// Pop the current [`Expr`] off the stack.
pub fn pop_expression(&mut self) {
let node_id = self
.expression_id
.expect("Attempted to pop without expression");
self.expression_id = self.expressions.parent_id(node_id);
/// Pop the current AST node [`NodeRef`] off the stack.
pub fn pop_node(&mut self) {
let node_id = self.node_id.expect("Attempted to pop without node");
self.node_id = self.nodes.parent_id(node_id);
}
/// Push a [`Scope`] with the given [`ScopeKind`] onto the stack.
@ -860,34 +833,20 @@ impl<'a> SemanticModel<'a> {
self.branch_id = branch_id;
}
/// Returns an [`Iterator`] over the current statement hierarchy represented as [`StatementId`],
/// from the current [`StatementId`] through to any parents.
pub fn current_statement_ids(&self) -> impl Iterator<Item = StatementId> + '_ {
self.statement_id
.iter()
.flat_map(|id| self.statements.ancestor_ids(*id))
}
/// Returns an [`Iterator`] over the current statement hierarchy, from the current [`Stmt`]
/// through to any parents.
pub fn current_statements(&self) -> impl Iterator<Item = &'a Stmt> + '_ {
self.current_statement_ids().map(|id| self.statements[id])
}
/// Return the [`StatementId`] of the current [`Stmt`].
pub fn current_statement_id(&self) -> StatementId {
self.statement_id.expect("No current statement")
}
/// Return the [`StatementId`] of the current [`Stmt`] parent, if any.
pub fn current_statement_parent_id(&self) -> Option<StatementId> {
self.current_statement_ids().nth(1)
let id = self.node_id.expect("No current node");
self.nodes
.ancestor_ids(id)
.filter_map(move |id| self.nodes[id].as_statement())
}
/// Return the current [`Stmt`].
pub fn current_statement(&self) -> &'a Stmt {
let node_id = self.statement_id.expect("No current statement");
self.statements[node_id]
self.current_statements()
.next()
.expect("No current statement")
}
/// Return the parent [`Stmt`] of the current [`Stmt`], if any.
@ -895,24 +854,18 @@ impl<'a> SemanticModel<'a> {
self.current_statements().nth(1)
}
/// Returns an [`Iterator`] over the current expression hierarchy represented as
/// [`ExpressionId`], from the current [`Expr`] through to any parents.
pub fn current_expression_ids(&self) -> impl Iterator<Item = ExpressionId> + '_ {
self.expression_id
.iter()
.flat_map(|id| self.expressions.ancestor_ids(*id))
}
/// Returns an [`Iterator`] over the current expression hierarchy, from the current [`Expr`]
/// through to any parents.
pub fn current_expressions(&self) -> impl Iterator<Item = &'a Expr> + '_ {
self.current_expression_ids().map(|id| self.expressions[id])
let id = self.node_id.expect("No current node");
self.nodes
.ancestor_ids(id)
.filter_map(move |id| self.nodes[id].as_expression())
}
/// Return the current [`Expr`].
pub fn current_expression(&self) -> Option<&'a Expr> {
let node_id = self.expression_id?;
Some(self.expressions[node_id])
self.current_expressions().next()
}
/// Return the parent [`Expr`] of the current [`Expr`], if any.
@ -925,6 +878,27 @@ impl<'a> SemanticModel<'a> {
self.current_expressions().nth(2)
}
/// Returns an [`Iterator`] over the current statement hierarchy represented as [`NodeId`],
/// from the current [`NodeId`] through to any parents.
pub fn current_statement_ids(&self) -> impl Iterator<Item = NodeId> + '_ {
self.node_id
.iter()
.flat_map(|id| self.nodes.ancestor_ids(*id))
.filter(|id| self.nodes[*id].is_statement())
}
/// Return the [`NodeId`] of the current [`Stmt`].
pub fn current_statement_id(&self) -> NodeId {
self.current_statement_ids()
.next()
.expect("No current statement")
}
/// Return the [`NodeId`] of the current [`Stmt`] parent, if any.
pub fn current_statement_parent_id(&self) -> Option<NodeId> {
self.current_statement_ids().nth(1)
}
/// Returns a reference to the global [`Scope`].
pub fn global_scope(&self) -> &Scope<'a> {
self.scopes.global()
@ -973,24 +947,36 @@ impl<'a> SemanticModel<'a> {
None
}
/// Return the [`Stmt]` corresponding to the given [`StatementId`].
/// Return the [`Stmt`] corresponding to the given [`NodeId`].
#[inline]
pub fn statement(&self, statement_id: StatementId) -> &'a Stmt {
self.statements[statement_id]
pub fn node(&self, node_id: NodeId) -> &NodeRef<'a> {
&self.nodes[node_id]
}
/// Return the [`Stmt`] corresponding to the given [`NodeId`].
#[inline]
pub fn statement(&self, node_id: NodeId) -> &'a Stmt {
self.nodes
.ancestor_ids(node_id)
.find_map(|id| self.nodes[id].as_statement())
.expect("No statement found")
}
/// Given a [`Stmt`], return its parent, if any.
#[inline]
pub fn parent_statement(&self, statement_id: StatementId) -> Option<&'a Stmt> {
self.statements
.parent_id(statement_id)
.map(|id| self.statements[id])
pub fn parent_statement(&self, node_id: NodeId) -> Option<&'a Stmt> {
self.nodes
.ancestor_ids(node_id)
.filter_map(|id| self.nodes[id].as_statement())
.nth(1)
}
/// Given a [`StatementId`], return the ID of its parent statement, if any.
#[inline]
pub fn parent_statement_id(&self, statement_id: StatementId) -> Option<StatementId> {
self.statements.parent_id(statement_id)
/// Given a [`NodeId`], return the [`NodeId`] of the parent statement, if any.
pub fn parent_statement_id(&self, node_id: NodeId) -> Option<NodeId> {
self.nodes
.ancestor_ids(node_id)
.filter(|id| self.nodes[*id].is_statement())
.nth(1)
}
/// Set the [`Globals`] for the current [`Scope`].
@ -1007,7 +993,7 @@ impl<'a> SemanticModel<'a> {
range: *range,
references: Vec::new(),
scope: self.scope_id,
source: self.statement_id,
source: self.node_id,
context: self.execution_context(),
exceptions: self.exceptions(),
flags: BindingFlags::empty(),
@ -1053,10 +1039,7 @@ impl<'a> SemanticModel<'a> {
/// Return `true` if the model is at the top level of the module (i.e., in the module scope,
/// and not nested within any statements).
pub fn at_top_level(&self) -> bool {
self.scope_id.is_global()
&& self
.statement_id
.map_or(true, |stmt_id| self.statements.parent_id(stmt_id).is_none())
self.scope_id.is_global() && self.current_statement_parent_id().is_none()
}
/// Return `true` if the model is in an async context.
@ -1101,10 +1084,10 @@ impl<'a> SemanticModel<'a> {
/// `try` statement.
///
/// This implementation assumes that the statements are in the same scope.
pub fn different_branches(&self, left: StatementId, right: StatementId) -> bool {
pub fn different_branches(&self, left: NodeId, right: NodeId) -> bool {
// Collect the branch path for the left statement.
let left = self
.statements
.nodes
.branch_id(left)
.iter()
.flat_map(|branch_id| self.branches.ancestor_ids(*branch_id))
@ -1112,7 +1095,7 @@ impl<'a> SemanticModel<'a> {
// Collect the branch path for the right statement.
let right = self
.statements
.nodes
.branch_id(right)
.iter()
.flat_map(|branch_id| self.branches.ancestor_ids(*branch_id))
@ -1191,8 +1174,7 @@ impl<'a> SemanticModel<'a> {
pub fn snapshot(&self) -> Snapshot {
Snapshot {
scope_id: self.scope_id,
stmt_id: self.statement_id,
expr_id: self.expression_id,
node_id: self.node_id,
branch_id: self.branch_id,
definition_id: self.definition_id,
flags: self.flags,
@ -1203,15 +1185,13 @@ impl<'a> SemanticModel<'a> {
pub fn restore(&mut self, snapshot: Snapshot) {
let Snapshot {
scope_id,
stmt_id,
expr_id,
node_id,
branch_id,
definition_id,
flags,
} = snapshot;
self.scope_id = scope_id;
self.statement_id = stmt_id;
self.expression_id = expr_id;
self.node_id = node_id;
self.branch_id = branch_id;
self.definition_id = definition_id;
self.flags = flags;
@ -1625,8 +1605,7 @@ impl SemanticModelFlags {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Snapshot {
scope_id: ScopeId,
stmt_id: Option<StatementId>,
expr_id: Option<ExpressionId>,
node_id: Option<NodeId>,
branch_id: Option<BranchId>,
definition_id: DefinitionId,
flags: SemanticModelFlags,