ruff/crates/red_knot_python_semantic/src/ast_node_ref.rs
Carl Meyer 595b1aa4a1
[red-knot] per-definition inference, use-def maps (#12269)
Implements definition-level type inference, with basic control flow
(only if statements and if expressions so far) in Salsa.

There are a couple of key ideas here:

1) We can do type inference queries at any of three region
granularities: an entire scope, a single definition, or a single
expression. These are represented by the `InferenceRegion` enum, and the
entry points are the salsa queries `infer_scope_types`,
`infer_definition_types`, and `infer_expression_types`. Generally
per-scope will be used for scopes that we are directly checking and
per-definition will be used anytime we are looking up symbol types from
another module/scope. Per-expression should be uncommon: used only for
the RHS of an unpacking or multi-target assignment (to avoid
re-inferring the RHS once per symbol defined in the assignment) and for
test nodes in type narrowing (e.g. the `test` of an `If` node). All
three queries return a `TypeInference` with a map of types for all
definitions and expressions within their region. If you do e.g.
scope-level inference, when it hits a definition, or an
independently-inferable expression, it should use the relevant query
(which may already be cached) to get all types within the smaller
region. This avoids double-inferring smaller regions, even though larger
regions encompass smaller ones.

2) Instead of building a control-flow graph and lazily traversing it to
find definitions which reach a use of a name (which is O(n^2) in the
worst case), semantic indexing builds a use-def map, where every
use of a name knows which definitions can reach that use. We also no
longer track all definitions of a symbol in the symbol itself; instead
the use-def map also records which defs remain visible at the end of the
scope, and considers these the publicly-visible definitions of the
symbol (see below).
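
As a rough illustration of idea 2 (not the actual red-knot data structures), a use-def map over a made-up mini statement language could be built in a single pass like the sketch below. The `Stmt` enum, the integer ids, and the `visit`/`build` functions are all hypothetical names invented for this example, and the branch merge deliberately ignores the case where a name is unbound on one path:

```rust
use std::collections::{BTreeSet, HashMap};

/// Hypothetical miniature statement language; not the real Python AST.
enum Stmt {
    /// Binds `name`; the `usize` is an arbitrary definition id.
    Def(&'static str, usize),
    /// Reads `name`; the `usize` is an arbitrary use id.
    Use(&'static str, usize),
    /// An `if` statement with a then-branch and an else-branch.
    If(Vec<Stmt>, Vec<Stmt>),
}

/// Definitions currently visible for each name at one point in the control flow.
type Visible = HashMap<&'static str, BTreeSet<usize>>;

#[derive(Default)]
struct UseDefMap {
    /// For every use id, the definition ids that can reach it.
    reaching: HashMap<usize, BTreeSet<usize>>,
    /// Definitions still visible at the end of the scope.
    public: Visible,
}

fn visit(stmts: &[Stmt], visible: &mut Visible, map: &mut UseDefMap) {
    for stmt in stmts {
        match stmt {
            Stmt::Def(name, id) => {
                // A new binding shadows every earlier one on this path.
                visible.insert(*name, BTreeSet::from([*id]));
            }
            Stmt::Use(name, id) => {
                // Record the definitions visible right now; no later graph walk is needed.
                let defs = visible.get(name).cloned().unwrap_or_default();
                map.reaching.insert(*id, defs);
            }
            Stmt::If(then_body, else_body) => {
                // Walk each branch from the same starting state, then merge by union.
                let mut then_state = visible.clone();
                let mut else_state = visible.clone();
                visit(then_body, &mut then_state, map);
                visit(else_body, &mut else_state, map);
                for (name, defs) in else_state {
                    then_state.entry(name).or_default().extend(defs);
                }
                *visible = then_state;
            }
        }
    }
}

/// Builds the map for one scope; whatever is visible at the end becomes "public".
fn build(stmts: &[Stmt]) -> UseDefMap {
    let mut map = UseDefMap::default();
    let mut visible = Visible::new();
    visit(stmts, &mut visible, &mut map);
    map.public = visible;
    map
}
```

For a scope equivalent to `x = 0`, then `if ...: x = 1`, then a read of `x`, both definitions reach the read, and both are still visible at the end of the scope, so both would contribute to the symbol's public type.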

Major items left as TODOs in this PR, to be done in follow-up PRs:

1) Free/global references aren't supported yet (only lookup based on
definitions in current scope), which means the override-check example
doesn't currently work. This is the first thing I'll fix as follow-up to
this PR.

2) Control flow outside of if statements and expressions.

3) Type narrowing.

There are also some smaller relevant changes here:

1) Eliminate `Option` in the return type of member lookups; instead
always return `Type::Unbound` for a name we can't find. Also use
`Type::Unbound` for modules we can't resolve (not 100% sure about this
one yet.)

2) Eliminate the use of the terms "public" and "root" to refer to
module-global scope or symbols. Instead consistently use the term
"module-global". It's longer, but it's the clearest, and the most
consistent with typical Python terminology. In particular I don't like
"public" for this use because it has other implications around author
intent (is an underscore-prefixed module-global symbol "public"?). And
"root" is just not commonly used for this in Python.

3) Eliminate the `PublicSymbol` Salsa ingredient. Many non-module-global
symbols can also be seen from other scopes (e.g. by a free var in a
nested scope, or by class attribute access), and thus need to have a
"public type" (that is, the type not as seen from a particular use in
the control flow of the same scope, but the type as seen from some other
scope.) So all symbols need to have a "public type" (here I want to keep
the use of the term "public", unless someone has a better term to
suggest -- since it's "public type of a symbol" and not "public symbol"
the confusion with e.g. initial underscores is less of an issue.) At
least initially, I would like to try not having special handling for
module-global symbols vs other symbols.

4) Switch to using "definitions that reach end of scope" rather than
"all definitions" in determining the public type of a symbol. I'm
convinced that in general this is the right way to go. We may want to
refine this further in future for some free-variable cases, but it can
be changed purely by making changes to the building of the use-def map
(the `public_definitions` index in it), without affecting any other
code. One consequence of combining this with no control-flow support
(just last-definition-wins) is that some inference tests now give more
wrong-looking results; I left TODO comments on these tests to fix them
when control flow is added.

And some potential areas for consideration in the future:

1) Should `symbol_ty` be a Salsa query? This would require making all
symbols a Salsa ingredient, and tracking even more dependencies. But it
would save some repeated reconstruction of unions, for symbols with
multiple public definitions. For now I'm not making it a query, but open
to changing this in future with actual perf evidence that it's better.
2024-07-16 11:02:30 -07:00

use std::hash::Hash;
use std::ops::Deref;
use ruff_db::parsed::ParsedModule;
/// Ref-counted owned reference to an AST node.
///
/// The type holds an owned reference to the node's ref-counted [`ParsedModule`].
/// Holding on to the node's [`ParsedModule`] guarantees that the reference to the
/// node must still be valid.
///
/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModule`] from being released.
///
/// ## Equality
/// Two `AstNodeRef` are considered equal if their wrapped nodes are equal.
#[derive(Clone)]
pub struct AstNodeRef<T> {
    /// Owned reference to the node's [`ParsedModule`].
    ///
    /// The node's reference is guaranteed to remain valid as long as its enclosing
    /// [`ParsedModule`] is alive.
    _parsed: ParsedModule,

    /// Pointer to the referenced node.
    node: std::ptr::NonNull<T>,
}

#[allow(unsafe_code)]
impl<T> AstNodeRef<T> {
    /// Creates a new `AstNodeRef` that references `node`. The `parsed` is the [`ParsedModule`] to
    /// which the `AstNodeRef` belongs.
    ///
    /// ## Safety
    /// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
    /// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
    /// the invariant `node belongs to parsed` is upheld.
    pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
        Self {
            _parsed: parsed,
            node: std::ptr::NonNull::from(node),
        }
    }

    /// Returns a reference to the wrapped node.
    pub fn node(&self) -> &T {
        // SAFETY: Holding on to `_parsed` ensures that the AST to which `node` belongs is still
        // alive and not moved.
        unsafe { self.node.as_ref() }
    }
}

impl<T> Deref for AstNodeRef<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.node()
    }
}

impl<T> std::fmt::Debug for AstNodeRef<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("AstNodeRef").field(&self.node()).finish()
    }
}

impl<T> PartialEq for AstNodeRef<T>
where
    T: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        self.node().eq(other.node())
    }
}

impl<T> Eq for AstNodeRef<T> where T: Eq {}

impl<T> Hash for AstNodeRef<T>
where
    T: Hash,
{
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.node().hash(state);
    }
}

#[allow(unsafe_code)]
unsafe impl<T> Send for AstNodeRef<T> where T: Send {}

#[allow(unsafe_code)]
unsafe impl<T> Sync for AstNodeRef<T> where T: Sync {}

#[cfg(test)]
mod tests {
    use crate::ast_node_ref::AstNodeRef;
    use ruff_db::parsed::ParsedModule;
    use ruff_python_ast::PySourceType;
    use ruff_python_parser::parse_unchecked_source;

    #[test]
    #[allow(unsafe_code)]
    fn equality() {
        let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
        let parsed = ParsedModule::new(parsed_raw.clone());

        let stmt = &parsed.syntax().body[0];

        let node1 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
        let node2 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };

        assert_eq!(node1, node2);

        // Compare from different trees
        let cloned = ParsedModule::new(parsed_raw);
        let stmt_cloned = &cloned.syntax().body[0];
        let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) };

        assert_eq!(node1, cloned_node);

        let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
        let other = ParsedModule::new(other_raw);

        let other_stmt = &other.syntax().body[0];
        let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };

        assert_ne!(node1, other_node);
    }

    #[allow(unsafe_code)]
    #[test]
    fn inequality() {
        let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
        let parsed = ParsedModule::new(parsed_raw.clone());

        let stmt = &parsed.syntax().body[0];
        let node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };

        let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
        let other = ParsedModule::new(other_raw);

        let other_stmt = &other.syntax().body[0];
        let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };

        assert_ne!(node, other_node);
    }

    #[test]
    #[allow(unsafe_code)]
    fn debug() {
        let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
        let parsed = ParsedModule::new(parsed_raw.clone());

        let stmt = &parsed.syntax().body[0];

        let stmt_node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };

        let debug = format!("{stmt_node:?}");

        assert_eq!(debug, format!("AstNodeRef({stmt:?})"));
    }
}
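
The ownership pattern behind `AstNodeRef` can be tried in isolation with a minimal sketch that uses only the standard library. Everything here is an assumption made for illustration: `Tree` is a hypothetical stand-in for `ParsedModule` (assumed to be a cheaply clonable, `Arc`-like handle), and `NodeRef` and `first_node` are invented names rather than part of the crate:

```rust
use std::ops::Deref;
use std::ptr::NonNull;
use std::sync::Arc;

/// Hypothetical stand-in for a parsed module: a ref-counted handle to heap-allocated nodes.
#[derive(Clone)]
struct Tree(Arc<Vec<String>>);

/// Pairs a raw pointer to a node with a handle that keeps the node's allocation alive.
struct NodeRef<T> {
    _owner: Tree,
    node: NonNull<T>,
}

impl<T> NodeRef<T> {
    /// ## Safety
    /// `node` must point into the allocation that `owner` keeps alive.
    unsafe fn new(owner: Tree, node: &T) -> Self {
        Self {
            _owner: owner,
            node: NonNull::from(node),
        }
    }
}

impl<T> Deref for NodeRef<T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: `_owner` keeps the allocation alive, and the shared `Arc` contents
        // are never mutated or moved while a `NodeRef` exists.
        unsafe { self.node.as_ref() }
    }
}

/// Returns an owning reference to the first node of `tree`.
fn first_node(tree: &Tree) -> NodeRef<String> {
    let node = &tree.0[0];
    // SAFETY: `node` borrows from the allocation that the cloned handle keeps alive.
    unsafe { NodeRef::new(tree.clone(), node) }
}
```

In this sketch a `NodeRef` returned by `first_node` stays valid after the caller drops its own `Tree` handle, because the clone stored in `_owner` keeps the reference count above zero.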