mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 13:51:37 +00:00

Implements definition-level type inference in Salsa, with basic control flow (only if statements and if expressions so far).

There are a couple of key ideas here:

1) We can do type inference queries at any of three region granularities: an entire scope, a single definition, or a single expression. These are represented by the `InferenceRegion` enum, and the entry points are the Salsa queries `infer_scope_types`, `infer_definition_types`, and `infer_expression_types`. Generally, per-scope will be used for scopes that we are directly checking, and per-definition will be used anytime we are looking up symbol types from another module/scope. Per-expression should be uncommon: used only for the RHS of an unpacking or multi-target assignment (to avoid re-inferring the RHS once per symbol defined in the assignment) and for test nodes in type narrowing (e.g. the `test` of an `If` node). All three queries return a `TypeInference` with a map of types for all definitions and expressions within their region. If you do e.g. scope-level inference, when it hits a definition, or an independently-inferable expression, it should use the relevant query (which may already be cached) to get all types within the smaller region. This avoids double-inferring smaller regions, even though larger regions encompass smaller ones.

2) Instead of building a control-flow graph and lazily traversing it to find definitions which reach a use of a name (which is O(n^2) in the worst case), semantic indexing builds a use-def map, where every use of a name knows which definitions can reach that use. We also no longer track all definitions of a symbol in the symbol itself; instead the use-def map also records which defs remain visible at the end of the scope, and considers these the publicly-visible definitions of the symbol (see below).
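The region granularities from point 1 can be illustrated with a small, self-contained sketch. This is not the code from this PR: the `Db` struct, its hand-rolled `cache` (standing in for Salsa's memoization), the integer ID aliases, and the toy typing rule are all assumptions made for illustration. Only the names `InferenceRegion`, `TypeInference`, `infer_scope_types`, and `infer_definition_types` come from the description above. The point it tries to show is that scope-level inference delegates to the per-definition query, so a definition that was already inferred is not inferred again.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for the real ID types; plain integers here.
type ScopeId = u32;
type DefinitionId = u32;
type ExpressionId = u32;

/// The three granularities at which inference can be requested.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum InferenceRegion {
    Scope(ScopeId),
    Definition(DefinitionId),
    Expression(ExpressionId),
}

/// Toy type representation (assumption, not the real `Type`).
#[derive(Debug, Clone, PartialEq, Eq)]
enum Type {
    Int,
    Str,
}

/// Result of inferring one region: types for the definitions it contains.
#[derive(Debug, Clone, Default)]
struct TypeInference {
    definitions: HashMap<DefinitionId, Type>,
}

/// A hand-rolled memo table standing in for Salsa's query cache.
#[derive(Default)]
struct Db {
    /// Which definitions belong to which scope (assumed input data).
    scope_definitions: HashMap<ScopeId, Vec<DefinitionId>>,
    cache: HashMap<InferenceRegion, TypeInference>,
    /// Counts how often a definition was inferred rather than served from cache.
    definition_inferences: usize,
}

impl Db {
    fn infer_definition_types(&mut self, def: DefinitionId) -> TypeInference {
        let region = InferenceRegion::Definition(def);
        if let Some(cached) = self.cache.get(&region) {
            return cached.clone();
        }
        self.definition_inferences += 1;
        // Toy rule: even definition IDs are ints, odd ones are strings.
        let ty = if def % 2 == 0 { Type::Int } else { Type::Str };
        let mut result = TypeInference::default();
        result.definitions.insert(def, ty);
        self.cache.insert(region, result.clone());
        result
    }

    fn infer_scope_types(&mut self, scope: ScopeId) -> TypeInference {
        let region = InferenceRegion::Scope(scope);
        if let Some(cached) = self.cache.get(&region) {
            return cached.clone();
        }
        let defs = self.scope_definitions.get(&scope).cloned().unwrap_or_default();
        let mut result = TypeInference::default();
        for def in defs {
            // Delegate to the smaller region's query instead of re-inferring it.
            let inner = self.infer_definition_types(def);
            result.definitions.extend(inner.definitions);
        }
        self.cache.insert(region, result.clone());
        result
    }
}

fn main() {
    let mut db = Db::default();
    db.scope_definitions.insert(0, vec![10, 11]);

    // Some other scope looked up definition 10 first, so it is already cached.
    db.infer_definition_types(10);
    let scope = db.infer_scope_types(0);

    assert_eq!(scope.definitions[&10], Type::Int);
    assert_eq!(scope.definitions[&11], Type::Str);
    // Each definition was inferred exactly once.
    assert_eq!(db.definition_inferences, 2);
}
```

In the real implementation Salsa would handle the caching and dependency tracking; the explicit `cache` map here only makes the "may already be cached" behavior visible.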
Major items left as TODOs in this PR, to be done in follow-up PRs:

1) Free/global references aren't supported yet (only lookup based on definitions in the current scope), which means the override-check example doesn't currently work. This is the first thing I'll fix as follow-up to this PR.
2) Control flow outside of if statements and expressions.
3) Type narrowing.

There are also some smaller relevant changes here:

1) Eliminate `Option` in the return type of member lookups; instead always return `Type::Unbound` for a name we can't find. Also use `Type::Unbound` for modules we can't resolve (not 100% sure about this one yet.)
2) Eliminate the use of the terms "public" and "root" to refer to module-global scope or symbols. Instead consistently use the term "module-global". It's longer, but it's the clearest, and the most consistent with typical Python terminology. In particular I don't like "public" for this use because it has other implications around author intent (is an underscore-prefixed module-global symbol "public"?). And "root" is just not commonly used for this in Python.
3) Eliminate the `PublicSymbol` Salsa ingredient. Many non-module-global symbols can also be seen from other scopes (e.g. by a free var in a nested scope, or by class attribute access), and thus need to have a "public type" (that is, the type not as seen from a particular use in the control flow of the same scope, but the type as seen from some other scope.) So all symbols need to have a "public type" (here I want to keep the use of the term "public", unless someone has a better term to suggest -- since it's "public type of a symbol" and not "public symbol" the confusion with e.g. initial underscores is less of an issue.) At least initially, I would like to try not having special handling for module-global symbols vs other symbols.
4) Switch to using "definitions that reach end of scope" rather than "all definitions" in determining the public type of a symbol.
I'm convinced that in general this is the right way to go. We may want to refine this further in future for some free-variable cases, but it can be changed purely by making changes to the building of the use-def map (the `public_definitions` index in it), without affecting any other code. One consequence of combining this with no control-flow support (just last-definition-wins) is that some inference tests now give more wrong-looking results; I left TODO comments on these tests to fix them when control flow is added.

And some potential areas for consideration in the future:

1) Should `symbol_ty` be a Salsa query? This would require making all symbols a Salsa ingredient, and tracking even more dependencies. But it would save some repeated reconstruction of unions, for symbols with multiple public definitions. For now I'm not making it a query, but open to changing this in future with actual perf evidence that it's better.
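To make the use-def map and the "definitions that reach end of scope" rule more concrete, here is a minimal sketch over a toy statement language. Everything except the names `public_definitions`, `symbol_ty`, and `Type::Unbound` is hypothetical (`Stmt`, `DefId`, `Visible`, `defs_by_use`, and the toy `Type` variants), and the real data structures in this PR may look quite different. The sketch also ignores the "possibly unbound" case where only one branch of an `if` defines a name.

```rust
use std::collections::{BTreeMap, BTreeSet};

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum Type {
    Unbound,
    Int,
    Str,
    Union(Vec<Type>),
}

/// Toy statements for illustration (hypothetical, not the real AST).
enum Stmt {
    Def(&'static str, Type),  // `name = <value of this type>`
    Use(&'static str),        // a read of `name`
    If(Vec<Stmt>, Vec<Stmt>), // `if ...: <then> else: <otherwise>`
}

type DefId = usize;
type Visible = BTreeMap<&'static str, BTreeSet<DefId>>;

#[derive(Default)]
struct UseDefMap {
    def_types: Vec<Type>,              // indexed by DefId
    defs_by_use: Vec<BTreeSet<DefId>>, // one entry per use, in source order
    // While building: definitions visible at the current program point.
    // After building: definitions that reach the end of the scope.
    public_definitions: Visible,
}

impl UseDefMap {
    fn visit(&mut self, stmts: &[Stmt]) {
        for stmt in stmts {
            match stmt {
                Stmt::Def(name, ty) => {
                    let id = self.def_types.len();
                    self.def_types.push(ty.clone());
                    // A new definition shadows whatever was visible before.
                    self.public_definitions.insert(*name, BTreeSet::from([id]));
                }
                Stmt::Use(name) => {
                    // Record the reaching definitions at the use itself, so no
                    // later traversal is needed to answer "which defs reach here?".
                    let reaching = self.public_definitions.get(name).cloned();
                    self.defs_by_use.push(reaching.unwrap_or_default());
                }
                Stmt::If(then, otherwise) => {
                    let before = self.public_definitions.clone();
                    self.visit(then);
                    let after_then = std::mem::replace(&mut self.public_definitions, before);
                    self.visit(otherwise);
                    // Merge: a definition from either branch may reach later code.
                    for (name, defs) in after_then {
                        self.public_definitions.entry(name).or_default().extend(defs);
                    }
                }
            }
        }
    }

    /// Public type: union over the definitions that reach the end of the scope.
    fn symbol_ty(&self, name: &str) -> Type {
        let defs = self.public_definitions.get(name).into_iter().flatten();
        let mut types: Vec<Type> = defs.map(|&id| self.def_types[id].clone()).collect();
        types.sort();
        types.dedup();
        match types.len() {
            0 => Type::Unbound, // no `Option`: unknown names are simply unbound
            1 => types.remove(0),
            _ => Type::Union(types),
        }
    }
}

fn main() {
    // x = 1; if ...: x = "a"; <use of x>; y = 2
    let program = vec![
        Stmt::Def("x", Type::Int),
        Stmt::If(vec![Stmt::Def("x", Type::Str)], vec![]),
        Stmt::Use("x"),
        Stmt::Def("y", Type::Int),
    ];
    let mut map = UseDefMap::default();
    map.visit(&program);

    // Both definitions of `x` reach the use and the end of the scope.
    assert_eq!(map.defs_by_use[0], BTreeSet::from([0, 1]));
    assert_eq!(map.symbol_ty("x"), Type::Union(vec![Type::Int, Type::Str]));
    assert_eq!(map.symbol_ty("y"), Type::Int);
    assert_eq!(map.symbol_ty("z"), Type::Unbound);
}
```

Note how `symbol_ty` rebuilds the union on every call; that repeated reconstruction is the cost the question about making it a Salsa query refers to.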
163 lines
4.7 KiB
Rust
use std::hash::Hash;
use std::ops::Deref;

use ruff_db::parsed::ParsedModule;

/// Ref-counted owned reference to an AST node.
///
/// The type holds an owned reference to the node's ref-counted [`ParsedModule`].
/// Holding on to the node's [`ParsedModule`] guarantees that the reference to the
/// node must still be valid.
///
/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModule`] from being released.
///
/// ## Equality
/// Two `AstNodeRef` are considered equal if their wrapped nodes are equal.
#[derive(Clone)]
pub struct AstNodeRef<T> {
    /// Owned reference to the node's [`ParsedModule`].
    ///
    /// The node's reference is guaranteed to remain valid as long as its enclosing
    /// [`ParsedModule`] is alive.
    _parsed: ParsedModule,

    /// Pointer to the referenced node.
    node: std::ptr::NonNull<T>,
}

#[allow(unsafe_code)]
impl<T> AstNodeRef<T> {
    /// Creates a new `AstNodeRef` that references `node`. The `parsed` is the [`ParsedModule`] to
    /// which the `AstNodeRef` belongs.
    ///
    /// ## Safety
    /// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
    /// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
    /// the invariant `node belongs to parsed` is upheld.
    pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
        Self {
            _parsed: parsed,
            node: std::ptr::NonNull::from(node),
        }
    }

    /// Returns a reference to the wrapped node.
    pub fn node(&self) -> &T {
        // SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still
        // alive and not moved.
        unsafe { self.node.as_ref() }
    }
}

impl<T> Deref for AstNodeRef<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.node()
    }
}

impl<T> std::fmt::Debug for AstNodeRef<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("AstNodeRef").field(&self.node()).finish()
    }
}

impl<T> PartialEq for AstNodeRef<T>
where
    T: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        self.node().eq(other.node())
    }
}

impl<T> Eq for AstNodeRef<T> where T: Eq {}

impl<T> Hash for AstNodeRef<T>
where
    T: Hash,
{
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.node().hash(state);
    }
}

#[allow(unsafe_code)]
unsafe impl<T> Send for AstNodeRef<T> where T: Send {}
#[allow(unsafe_code)]
unsafe impl<T> Sync for AstNodeRef<T> where T: Sync {}

#[cfg(test)]
mod tests {
    use crate::ast_node_ref::AstNodeRef;
    use ruff_db::parsed::ParsedModule;
    use ruff_python_ast::PySourceType;
    use ruff_python_parser::parse_unchecked_source;

    #[test]
    #[allow(unsafe_code)]
    fn equality() {
        let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
        let parsed = ParsedModule::new(parsed_raw.clone());

        let stmt = &parsed.syntax().body[0];

        let node1 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
        let node2 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };

        assert_eq!(node1, node2);

        // Compare from different trees
        let cloned = ParsedModule::new(parsed_raw);
        let stmt_cloned = &cloned.syntax().body[0];
        let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) };

        assert_eq!(node1, cloned_node);

        let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
        let other = ParsedModule::new(other_raw);

        let other_stmt = &other.syntax().body[0];
        let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };

        assert_ne!(node1, other_node);
    }

    #[allow(unsafe_code)]
    #[test]
    fn inequality() {
        let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
        let parsed = ParsedModule::new(parsed_raw.clone());

        let stmt = &parsed.syntax().body[0];
        let node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };

        let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
        let other = ParsedModule::new(other_raw);

        let other_stmt = &other.syntax().body[0];
        let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };

        assert_ne!(node, other_node);
    }

    #[test]
    #[allow(unsafe_code)]
    fn debug() {
        let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
        let parsed = ParsedModule::new(parsed_raw.clone());

        let stmt = &parsed.syntax().body[0];

        let stmt_node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };

        let debug = format!("{stmt_node:?}");

        assert_eq!(debug, format!("AstNodeRef({stmt:?})"));
    }
}
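
The ownership pattern used by `AstNodeRef` (a raw pointer kept valid by a ref-counted owner stored alongside it) can be tried out without the `ruff_db` crates. The following is a standard-library-only sketch under stated assumptions: `Parsed`, `NodeRef`, and `first_node` are hypothetical names, and `Parsed` is assumed to behave like `ParsedModule` only in that it is cheaply clonable and its contents are never mutated or moved while any clone is alive.

```rust
use std::ptr::NonNull;
use std::sync::Arc;

/// Hypothetical stand-in for `ParsedModule`: a ref-counted, immutable "tree".
#[derive(Clone)]
struct Parsed(Arc<Vec<String>>);

struct NodeRef<T> {
    /// Keeps the allocation that `node` points into alive.
    _owner: Parsed,
    node: NonNull<T>,
}

impl<T> NodeRef<T> {
    /// # Safety
    /// `node` must live inside the allocation owned by `owner`.
    unsafe fn new(owner: Parsed, node: &T) -> Self {
        Self {
            _owner: owner,
            node: NonNull::from(node),
        }
    }

    fn node(&self) -> &T {
        // SAFETY: `_owner` keeps the `Arc` contents alive, and the `Vec` inside
        // is never mutated or moved, so the pointer is still valid.
        unsafe { self.node.as_ref() }
    }
}

fn first_node(parsed: &Parsed) -> NodeRef<String> {
    // SAFETY: the reference points into the `Vec` owned by `parsed`.
    unsafe { NodeRef::new(parsed.clone(), &parsed.0[0]) }
}

fn main() {
    let parsed = Parsed(Arc::new(vec!["1 + 2".to_string()]));
    let node = first_node(&parsed);

    // Dropping the original handle does not free the tree: `node` holds a clone.
    drop(parsed);
    assert_eq!(node.node().as_str(), "1 + 2");
}
```

The returned `NodeRef` carries no lifetime parameter, which is what makes this kind of type convenient to store in cached query results; the price is the `unsafe` constructor and the invariant its caller has to uphold.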