Add branch detection to the semantic model (#6694)

## Summary

We have a few rules that rely on detecting whether two statements are in
different branches -- for example, different arms of an `if`-`else`.
Historically, given two statement IDs, we'd find the common parent (by
traversing upwards via our `Statements` abstraction); then identify
branches "manually" by matching the parent against `try`, `if`, and
`match`, and returning iterators over the arms; then check whether
there's an arm for which one of the statements is a child and the other
is not.

This has a few drawbacks:

1. First, the code is generally a bit hard to follow (Konsti mentioned
this too when working on the `ElifElseClause` refactor).

2. Second, this is the only place in the codebase where we need to go
from `&Stmt` to `StatementId` -- _everywhere_ else, we only need to go
in the _other_ direction. Supporting these lookups means we need to
maintain a mapping from `&Stmt` to `StatementId` that includes every
`&Stmt` in the program. (We _also_ end up maintaining a `depth` level
for every statement.) I'd like to get rid of these requirements to
improve efficiency, reduce complexity, and enable us to treat AST nodes
more generically in the future. (When I looked at adding `&Expr` to
our existing statement-tracking infrastructure, maintaining a hash map
with all the statements noticeably hurt performance.)

The solution implemented here instead makes branches a first-class
concept in the semantic model. Like with `Statements`, we now have a
`Branches` abstraction, where each branch points to its optional parent.
When we store statements, we store the `BranchId` alongside each
statement. When we need to detect whether two statements are in the same
branch, we just realize each statement's branch path and compare the
two. (Assuming that the two statements are in the same scope, they're
on the same branch if and only if one branch path is a prefix of the
other, starting from the top.) We then add some calls to the visitor to
push and pop branches in the appropriate places, for `if`, `try`, and
`match` statements.
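
As a rough illustration of the idea above, here is a minimal, self-contained sketch of a parent-pointer branch arena and a path comparison. The method names (`insert`, `parent_id`, `ancestor_ids`) mirror those that appear in the diff, but the bodies, the `path` helper, and the free function `different_branches` are assumptions made for this example, not the code in this change. The sketch compares paths from the outermost branch down, as the paragraph above describes.

```rust
// Simplified stand-ins for the types in this change (assumed, not the real ones).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct BranchId(usize);

/// Arena of branches; each entry stores that branch's optional parent.
#[derive(Default)]
struct Branches(Vec<Option<BranchId>>);

impl Branches {
    /// Insert a new branch with the given parent and return its ID.
    fn insert(&mut self, parent: Option<BranchId>) -> BranchId {
        let id = BranchId(self.0.len());
        self.0.push(parent);
        id
    }

    /// Return the parent of `id`, if any.
    fn parent_id(&self, id: BranchId) -> Option<BranchId> {
        self.0[id.0]
    }

    /// Walk from `id` up to the outermost enclosing branch.
    fn ancestor_ids(&self, id: BranchId) -> impl Iterator<Item = BranchId> + '_ {
        std::iter::successors(Some(id), move |&id| self.parent_id(id))
    }

    /// Hypothetical helper: realize a branch path, ordered from the top down.
    fn path(&self, id: Option<BranchId>) -> Vec<BranchId> {
        let mut path: Vec<BranchId> = id
            .into_iter()
            .flat_map(|id| self.ancestor_ids(id))
            .collect();
        path.reverse();
        path
    }
}

/// Two statements are on different branches unless one top-down path is a
/// prefix of the other. `zip` stops at the shorter path, so only the shared
/// prefix length is compared.
fn different_branches(
    branches: &Branches,
    left: Option<BranchId>,
    right: Option<BranchId>,
) -> bool {
    let (left, right) = (branches.path(left), branches.path(right));
    !left.iter().zip(right.iter()).all(|(l, r)| l == r)
}

fn main() {
    let mut branches = Branches::default();
    let if_arm = branches.insert(None);
    let else_arm = branches.insert(None);
    let nested = branches.insert(Some(if_arm));

    // Sibling arms are different branches.
    assert!(different_branches(&branches, Some(if_arm), Some(else_arm)));
    // A branch nested inside `if_arm` is still on the `if_arm` branch.
    assert!(!different_branches(&branches, Some(if_arm), Some(nested)));
    // A statement outside any branch never conflicts with one inside.
    assert!(!different_branches(&branches, None, Some(nested)));
}
```

Storing only a parent pointer per branch keeps the arena a single flat vector; the cost of a comparison is proportional to the nesting depth of the two statements.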

Note that a branch is not 1:1 with a statement; instead, each branch is
closer to a suite, but not _every_ suite is a branch. For example, each
arm in an `if`-`elif`-`else` is a branch, but the `else` in a `for` loop
is not considered a branch.
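
To make the branch-versus-suite distinction concrete, the following toy visitor pushes a branch for each arm of an `if` but not for the `else` of a `for` loop. The `Stmt` enum, the `Model` struct, and its `recorded` field are hypothetical and exist only for this sketch; `push_branch` and `pop_branch` are simplified versions of the methods named in the diff.

```rust
// Toy AST and model; every name here is illustrative, not the real type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct BranchId(usize);

enum Stmt {
    Pass,
    /// Each inner `Vec` is one arm: `if`, any `elif`s, then `else`.
    If(Vec<Vec<Stmt>>),
    For { body: Vec<Stmt>, orelse: Vec<Stmt> },
}

#[derive(Default)]
struct Model {
    /// Parent pointer for every branch created so far.
    parents: Vec<Option<BranchId>>,
    /// The branch currently being visited, if any.
    current: Option<BranchId>,
    /// Branch recorded for each visited statement, in visit order.
    recorded: Vec<Option<BranchId>>,
}

impl Model {
    fn push_branch(&mut self) {
        let id = BranchId(self.parents.len());
        self.parents.push(self.current);
        self.current = Some(id);
    }

    fn pop_branch(&mut self) {
        let id = self.current.expect("Attempted to pop without branch");
        self.current = self.parents[id.0];
    }

    fn visit_body(&mut self, body: &[Stmt]) {
        for stmt in body {
            self.visit(stmt);
        }
    }

    fn visit(&mut self, stmt: &Stmt) {
        // Store the enclosing branch alongside the statement.
        self.recorded.push(self.current);
        match stmt {
            Stmt::Pass => {}
            Stmt::If(arms) => {
                // Every arm of an `if` is its own branch.
                for arm in arms {
                    self.push_branch();
                    self.visit_body(arm);
                    self.pop_branch();
                }
            }
            Stmt::For { body, orelse } => {
                // Both suites of a `for` stay on the enclosing branch.
                self.visit_body(body);
                self.visit_body(orelse);
            }
        }
    }
}

fn main() {
    // Models: `if ...: pass` / `else: for ...: pass; else: pass`.
    let program = Stmt::If(vec![
        vec![Stmt::Pass],
        vec![Stmt::For {
            body: vec![Stmt::Pass],
            orelse: vec![Stmt::Pass],
        }],
    ]);
    let mut model = Model::default();
    model.visit(&program);

    let (b0, b1) = (Some(BranchId(0)), Some(BranchId(1)));
    // The `for`, its body, and its `else` all share the second `if` arm's branch.
    assert_eq!(model.recorded, vec![None, b0, b1, b1, b1]);
}
```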

In addition to being much simpler, this should also be more efficient,
since we've shed the entire `&Stmt` hash map, plus the `depth` that we
track on `StatementWithParent`, in favor of a single `Option<BranchId>`
on `StatementWithParent` plus a single vector for all branches. The
lookups should be faster too, since instead of doing a bunch of jumps
around with the hash map + repeated recursive calls to find the common
parents, we instead just do a few simple lookups in the `Branches`
vector to realize and compare the branch paths.

## Test Plan

`cargo test` -- we have a lot of coverage for this, which we inherited
from PyFlakes.
Charlie Marsh 2023-08-19 17:28:17 -04:00 committed by GitHub
parent 648333b8b2
commit 17af12e57c
24 changed files with 283 additions and 298 deletions


@@ -15,6 +15,7 @@ use crate::binding::{
Binding, BindingFlags, BindingId, BindingKind, Bindings, Exceptions, FromImport, Import,
SubmoduleImport,
};
use crate::branches::{BranchId, Branches};
use crate::context::ExecutionContext;
use crate::definition::{Definition, DefinitionId, Definitions, Member, Module};
use crate::expressions::{ExpressionId, Expressions};
@@ -32,18 +33,24 @@ pub struct SemanticModel<'a> {
typing_modules: &'a [String],
module_path: Option<&'a [String]>,
/// Stack of all visited statements.
/// Stack of statements in the program.
statements: Statements<'a>,
/// The identifier of the current statement.
/// The ID of the current statement.
statement_id: Option<StatementId>,
/// Stack of all visited expressions.
/// Stack of expressions in the program.
expressions: Expressions<'a>,
/// The identifier of the current expression.
/// The ID of the current expression.
expression_id: Option<ExpressionId>,
/// Stack of all branches in the program.
branches: Branches,
/// The ID of the current branch.
branch_id: Option<BranchId>,
/// Stack of all scopes, along with the identifier of the current scope.
pub scopes: Scopes<'a>,
pub scope_id: ScopeId,
@@ -138,6 +145,8 @@ impl<'a> SemanticModel<'a> {
statement_id: None,
expressions: Expressions::default(),
expression_id: None,
branch_id: None,
branches: Branches::default(),
scopes: Scopes::default(),
scope_id: ScopeId::global(),
definitions: Definitions::for_module(module),
@@ -781,7 +790,10 @@ impl<'a> SemanticModel<'a> {
/// Push a [`Stmt`] onto the stack.
pub fn push_statement(&mut self, stmt: &'a Stmt) {
self.statement_id = Some(self.statements.insert(stmt, self.statement_id));
self.statement_id = Some(
self.statements
.insert(stmt, self.statement_id, self.branch_id),
);
}
/// Pop the current [`Stmt`] off the stack.
@@ -831,54 +843,78 @@ impl<'a> SemanticModel<'a> {
self.definition_id = member.parent;
}
/// Return the current `Stmt`.
pub fn current_statement(&self) -> &'a Stmt {
let node_id = self.statement_id.expect("No current statement");
self.statements[node_id]
/// Push a new branch onto the stack, returning its [`BranchId`].
pub fn push_branch(&mut self) -> Option<BranchId> {
self.branch_id = Some(self.branches.insert(self.branch_id));
self.branch_id
}
/// Pop the current [`BranchId`] off the stack.
pub fn pop_branch(&mut self) {
let node_id = self.branch_id.expect("Attempted to pop without branch");
self.branch_id = self.branches.parent_id(node_id);
}
/// Set the current [`BranchId`].
pub fn set_branch(&mut self, branch_id: Option<BranchId>) {
self.branch_id = branch_id;
}
/// Returns an [`Iterator`] over the current statement hierarchy represented as [`StatementId`],
/// from the current [`StatementId`] through to any parents.
pub fn current_statement_ids(&self) -> impl Iterator<Item = StatementId> + '_ {
self.statement_id
.iter()
.flat_map(|id| self.statements.ancestor_ids(*id))
}
/// Returns an [`Iterator`] over the current statement hierarchy, from the current [`Stmt`]
/// through to any parents.
pub fn current_statements(&self) -> impl Iterator<Item = &'a Stmt> + '_ {
self.statement_id
.iter()
.flat_map(|id| {
self.statements
.ancestor_ids(*id)
.map(|id| &self.statements[id])
})
.copied()
self.current_statement_ids().map(|id| self.statements[id])
}
/// Return the parent `Stmt` of the current `Stmt`, if any.
/// Return the [`StatementId`] of the current [`Stmt`].
pub fn current_statement_id(&self) -> StatementId {
self.statement_id.expect("No current statement")
}
/// Return the [`StatementId`] of the current [`Stmt`] parent, if any.
pub fn current_statement_parent_id(&self) -> Option<StatementId> {
self.current_statement_ids().nth(1)
}
/// Return the current [`Stmt`].
pub fn current_statement(&self) -> &'a Stmt {
let node_id = self.statement_id.expect("No current statement");
self.statements[node_id]
}
/// Return the parent [`Stmt`] of the current [`Stmt`], if any.
pub fn current_statement_parent(&self) -> Option<&'a Stmt> {
self.current_statements().nth(1)
}
/// Return the grandparent `Stmt` of the current `Stmt`, if any.
pub fn current_statement_grandparent(&self) -> Option<&'a Stmt> {
self.current_statements().nth(2)
/// Returns an [`Iterator`] over the current expression hierarchy represented as
/// [`ExpressionId`], from the current [`Expr`] through to any parents.
pub fn current_expression_ids(&self) -> impl Iterator<Item = ExpressionId> + '_ {
self.expression_id
.iter()
.flat_map(|id| self.expressions.ancestor_ids(*id))
}
/// Return the current `Expr`.
/// Returns an [`Iterator`] over the current expression hierarchy, from the current [`Expr`]
/// through to any parents.
pub fn current_expressions(&self) -> impl Iterator<Item = &'a Expr> + '_ {
self.current_expression_ids().map(|id| self.expressions[id])
}
/// Return the current [`Expr`].
pub fn current_expression(&self) -> Option<&'a Expr> {
let node_id = self.expression_id?;
Some(self.expressions[node_id])
}
/// Returns an [`Iterator`] over the current statement hierarchy, from the current [`Expr`]
/// through to any parents.
pub fn current_expressions(&self) -> impl Iterator<Item = &'a Expr> + '_ {
self.expression_id
.iter()
.flat_map(|id| {
self.expressions
.ancestor_ids(*id)
.map(|id| &self.expressions[id])
})
.copied()
}
/// Return the parent [`Expr`] of the current [`Expr`], if any.
pub fn current_expression_parent(&self) -> Option<&'a Expr> {
self.current_expressions().nth(1)
@@ -937,17 +973,6 @@ impl<'a> SemanticModel<'a> {
None
}
/// Return the [`Statements`] vector of all statements.
pub const fn statements(&self) -> &Statements<'a> {
&self.statements
}
/// Return the [`StatementId`] corresponding to the given [`Stmt`].
#[inline]
pub fn statement_id(&self, statement: &Stmt) -> Option<StatementId> {
self.statements.statement_id(statement)
}
/// Return the [`Stmt]` corresponding to the given [`StatementId`].
#[inline]
pub fn statement(&self, statement_id: StatementId) -> &'a Stmt {
@@ -962,6 +987,12 @@ impl<'a> SemanticModel<'a> {
.map(|id| self.statements[id])
}
/// Given a [`StatementId`], return the ID of its parent statement, if any.
#[inline]
pub fn parent_statement_id(&self, statement_id: StatementId) -> Option<StatementId> {
self.statements.parent_id(statement_id)
}
/// Set the [`Globals`] for the current [`Scope`].
pub fn set_globals(&mut self, globals: Globals<'a>) {
// If any global bindings don't already exist in the global scope, add them.
@@ -1066,6 +1097,33 @@ impl<'a> SemanticModel<'a> {
false
}
/// Returns `true` if `left` and `right` are on different branches of an `if`, `match`, or
/// `try` statement.
///
/// This implementation assumes that the statements are in the same scope.
pub fn different_branches(&self, left: StatementId, right: StatementId) -> bool {
// Collect the branch path for the left statement.
let left = self
.statements
.branch_id(left)
.iter()
.flat_map(|branch_id| self.branches.ancestor_ids(*branch_id))
.collect::<Vec<_>>();
// Collect the branch path for the right statement.
let right = self
.statements
.branch_id(right)
.iter()
.flat_map(|branch_id| self.branches.ancestor_ids(*branch_id))
.collect::<Vec<_>>();
!left
.iter()
.zip(right.iter())
.all(|(left, right)| left == right)
}
/// Returns `true` if the given [`BindingId`] is used.
pub fn is_used(&self, binding_id: BindingId) -> bool {
self.bindings[binding_id].is_used()