[ty] Add infrastructure for AST garbage collection (#18445)

## Summary

https://github.com/astral-sh/ty/issues/214 will require a couple of
invasive changes that I would like to get merged even before garbage
collection is fully implemented (to avoid rebasing):
- `ParsedModule` can no longer be dereferenced directly. Instead you
need to load a `ParsedModuleRef` to access the AST, which requires a
reference to the salsa database (as it may require re-parsing the AST if
it was collected).
- `AstNodeRef` can only be dereferenced with the `node` method, which
takes a reference to the `ParsedModuleRef`. This allows us to encode the
fact that ASTs do not live as long as the database and may be collected
as soon as a given instance of a `ParsedModuleRef` is dropped. There are a
number of places where we currently merge the `'db` and `'ast`
lifetimes, so this requires giving some types/functions two separate
lifetime parameters.
This commit is contained in:
Ibraheem Ahmed 2025-06-05 11:43:18 -04:00 committed by GitHub
parent 55100209c7
commit 8531f4b3ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 886 additions and 689 deletions

View file

@ -1,15 +1,14 @@
use std::hash::Hash;
use std::ops::Deref;
use std::sync::Arc;
use ruff_db::parsed::ParsedModule;
use ruff_db::parsed::ParsedModuleRef;
/// Ref-counted owned reference to an AST node.
///
/// The type holds an owned reference to the node's ref-counted [`ParsedModule`].
/// Holding on to the node's [`ParsedModule`] guarantees that the reference to the
/// The type holds an owned reference to the node's ref-counted [`ParsedModuleRef`].
/// Holding on to the node's [`ParsedModuleRef`] guarantees that the reference to the
/// node must still be valid.
///
/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModule`] from being released.
/// Holding on to any [`AstNodeRef`] prevents the [`ParsedModuleRef`] from being released.
///
/// ## Equality
/// Two `AstNodeRef` are considered equal if their pointer addresses are equal.
@ -33,11 +32,11 @@ use ruff_db::parsed::ParsedModule;
/// run on every AST change. All other queries only run when the expression's identity changes.
#[derive(Clone)]
pub struct AstNodeRef<T> {
/// Owned reference to the node's [`ParsedModule`].
/// Owned reference to the node's [`ParsedModuleRef`].
///
/// The node's reference is guaranteed to remain valid as long as it's enclosing
/// [`ParsedModule`] is alive.
parsed: ParsedModule,
/// [`ParsedModuleRef`] is alive.
parsed: ParsedModuleRef,
/// Pointer to the referenced node.
node: std::ptr::NonNull<T>,
@ -45,15 +44,15 @@ pub struct AstNodeRef<T> {
#[expect(unsafe_code)]
impl<T> AstNodeRef<T> {
/// Creates a new `AstNodeRef` that references `node`. The `parsed` is the [`ParsedModule`] to
/// Creates a new `AstNodeRef` that references `node`. The `parsed` is the [`ParsedModuleRef`] to
/// which the `AstNodeRef` belongs.
///
/// ## Safety
///
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
/// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
/// [`ParsedModuleRef`] to which `node` belongs. It's the caller's responsibility to ensure that
/// the invariant `node belongs to parsed` is upheld.
pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
pub(super) unsafe fn new(parsed: ParsedModuleRef, node: &T) -> Self {
Self {
parsed,
node: std::ptr::NonNull::from(node),
@ -61,54 +60,26 @@ impl<T> AstNodeRef<T> {
}
/// Returns a reference to the wrapped node.
pub const fn node(&self) -> &T {
///
/// Note that this method will panic if the provided module is from a different file or Salsa revision
/// than the module this node was created with.
pub fn node<'ast>(&self, parsed: &'ast ParsedModuleRef) -> &'ast T {
debug_assert!(Arc::ptr_eq(self.parsed.as_arc(), parsed.as_arc()));
// SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still
// alive and not moved.
unsafe { self.node.as_ref() }
}
}
impl<T> Deref for AstNodeRef<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.node()
}
}
impl<T> std::fmt::Debug for AstNodeRef<T>
where
T: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("AstNodeRef").field(&self.node()).finish()
}
}
impl<T> PartialEq for AstNodeRef<T>
where
T: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
if self.parsed == other.parsed {
// Comparing the pointer addresses is sufficient to determine equality
// if the parsed are the same.
self.node.eq(&other.node)
} else {
// Otherwise perform a deep comparison.
self.node().eq(other.node())
}
}
}
impl<T> Eq for AstNodeRef<T> where T: Eq {}
impl<T> Hash for AstNodeRef<T>
where
T: Hash,
{
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.node().hash(state);
f.debug_tuple("AstNodeRef")
.field(self.node(&self.parsed))
.finish()
}
}
@ -117,7 +88,9 @@ unsafe impl<T> salsa::Update for AstNodeRef<T> {
unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> bool {
let old_ref = unsafe { &mut (*old_pointer) };
if old_ref.parsed == new_value.parsed && old_ref.node.eq(&new_value.node) {
if Arc::ptr_eq(old_ref.parsed.as_arc(), new_value.parsed.as_arc())
&& old_ref.node.eq(&new_value.node)
{
false
} else {
*old_ref = new_value;
@ -130,73 +103,3 @@ unsafe impl<T> salsa::Update for AstNodeRef<T> {
unsafe impl<T> Send for AstNodeRef<T> where T: Send {}
#[expect(unsafe_code)]
unsafe impl<T> Sync for AstNodeRef<T> where T: Sync {}
#[cfg(test)]
mod tests {
use crate::ast_node_ref::AstNodeRef;
use ruff_db::parsed::ParsedModule;
use ruff_python_ast::PySourceType;
use ruff_python_parser::parse_unchecked_source;
#[test]
#[expect(unsafe_code)]
fn equality() {
let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
let parsed = ParsedModule::new(parsed_raw.clone());
let stmt = &parsed.syntax().body[0];
let node1 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
let node2 = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
assert_eq!(node1, node2);
// Compare from different trees
let cloned = ParsedModule::new(parsed_raw);
let stmt_cloned = &cloned.syntax().body[0];
let cloned_node = unsafe { AstNodeRef::new(cloned.clone(), stmt_cloned) };
assert_eq!(node1, cloned_node);
let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
let other = ParsedModule::new(other_raw);
let other_stmt = &other.syntax().body[0];
let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };
assert_ne!(node1, other_node);
}
#[expect(unsafe_code)]
#[test]
fn inequality() {
let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
let parsed = ParsedModule::new(parsed_raw);
let stmt = &parsed.syntax().body[0];
let node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
let other_raw = parse_unchecked_source("2 + 2", PySourceType::Python);
let other = ParsedModule::new(other_raw);
let other_stmt = &other.syntax().body[0];
let other_node = unsafe { AstNodeRef::new(other.clone(), other_stmt) };
assert_ne!(node, other_node);
}
#[test]
#[expect(unsafe_code)]
fn debug() {
let parsed_raw = parse_unchecked_source("1 + 2", PySourceType::Python);
let parsed = ParsedModule::new(parsed_raw);
let stmt = &parsed.syntax().body[0];
let stmt_node = unsafe { AstNodeRef::new(parsed.clone(), stmt) };
let debug = format!("{stmt_node:?}");
assert_eq!(debug, format!("AstNodeRef({stmt:?})"));
}
}

View file

@ -32,7 +32,7 @@ fn dunder_all_names_cycle_initial(_db: &dyn Db, _file: File) -> Option<FxHashSet
pub(crate) fn dunder_all_names(db: &dyn Db, file: File) -> Option<FxHashSet<Name>> {
let _span = tracing::trace_span!("dunder_all_names", file=?file.path(db)).entered();
let module = parsed_module(db.upcast(), file);
let module = parsed_module(db.upcast(), file).load(db.upcast());
let index = semantic_index(db, file);
let mut collector = DunderAllNamesCollector::new(db, file, index);
collector.visit_body(module.suite());

View file

@ -50,9 +50,9 @@ type PlaceSet = hashbrown::HashMap<ScopedPlaceId, (), FxBuildHasher>;
pub(crate) fn semantic_index(db: &dyn Db, file: File) -> SemanticIndex<'_> {
let _span = tracing::trace_span!("semantic_index", ?file).entered();
let parsed = parsed_module(db.upcast(), file);
let module = parsed_module(db.upcast(), file).load(db.upcast());
SemanticIndexBuilder::new(db, file, parsed).build()
SemanticIndexBuilder::new(db, file, &module).build()
}
/// Returns the place table for a specific `scope`.
@ -129,10 +129,11 @@ pub(crate) fn attribute_scopes<'db, 's>(
class_body_scope: ScopeId<'db>,
) -> impl Iterator<Item = FileScopeId> + use<'s, 'db> {
let file = class_body_scope.file(db);
let module = parsed_module(db.upcast(), file).load(db.upcast());
let index = semantic_index(db, file);
let class_scope_id = class_body_scope.file_scope_id(db);
ChildrenIter::new(index, class_scope_id).filter_map(|(child_scope_id, scope)| {
ChildrenIter::new(index, class_scope_id).filter_map(move |(child_scope_id, scope)| {
let (function_scope_id, function_scope) =
if scope.node().scope_kind() == ScopeKind::Annotation {
// This could be a generic method with a type-params scope.
@ -144,7 +145,7 @@ pub(crate) fn attribute_scopes<'db, 's>(
(child_scope_id, scope)
};
function_scope.node().as_function()?;
function_scope.node().as_function(&module)?;
Some(function_scope_id)
})
}
@ -559,7 +560,7 @@ impl FusedIterator for ChildrenIter<'_> {}
#[cfg(test)]
mod tests {
use ruff_db::files::{File, system_path_to_file};
use ruff_db::parsed::parsed_module;
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_python_ast::{self as ast};
use ruff_text_size::{Ranged, TextRange};
@ -742,6 +743,7 @@ y = 2
assert_eq!(names(global_table), vec!["C", "y"]);
let module = parsed_module(&db, file).load(&db);
let index = semantic_index(&db, file);
let [(class_scope_id, class_scope)] = index
@ -751,7 +753,10 @@ y = 2
panic!("expected one child scope")
};
assert_eq!(class_scope.kind(), ScopeKind::Class);
assert_eq!(class_scope_id.to_scope_id(&db, file).name(&db), "C");
assert_eq!(
class_scope_id.to_scope_id(&db, file).name(&db, &module),
"C"
);
let class_table = index.place_table(class_scope_id);
assert_eq!(names(&class_table), vec!["x"]);
@ -772,6 +777,7 @@ def func():
y = 2
",
);
let module = parsed_module(&db, file).load(&db);
let index = semantic_index(&db, file);
let global_table = index.place_table(FileScopeId::global());
@ -784,7 +790,10 @@ y = 2
panic!("expected one child scope")
};
assert_eq!(function_scope.kind(), ScopeKind::Function);
assert_eq!(function_scope_id.to_scope_id(&db, file).name(&db), "func");
assert_eq!(
function_scope_id.to_scope_id(&db, file).name(&db, &module),
"func"
);
let function_table = index.place_table(function_scope_id);
assert_eq!(names(&function_table), vec!["x"]);
@ -921,6 +930,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
",
);
let module = parsed_module(&db, file).load(&db);
let index = semantic_index(&db, file);
let global_table = index.place_table(FileScopeId::global());
@ -935,7 +945,9 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
assert_eq!(comprehension_scope.kind(), ScopeKind::Comprehension);
assert_eq!(
comprehension_scope_id.to_scope_id(&db, file).name(&db),
comprehension_scope_id
.to_scope_id(&db, file)
.name(&db, &module),
"<listcomp>"
);
@ -979,8 +991,9 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let use_def = index.use_def_map(comprehension_scope_id);
let module = parsed_module(&db, file).syntax();
let element = module.body[0]
let module = parsed_module(&db, file).load(&db);
let syntax = module.syntax();
let element = syntax.body[0]
.as_expr_stmt()
.unwrap()
.value
@ -996,7 +1009,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
let DefinitionKind::Comprehension(comprehension) = binding.kind(&db) else {
panic!("expected generator definition")
};
let target = comprehension.target();
let target = comprehension.target(&module);
let name = target.as_name_expr().unwrap().id().as_str();
assert_eq!(name, "x");
@ -1014,6 +1027,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
",
);
let module = parsed_module(&db, file).load(&db);
let index = semantic_index(&db, file);
let global_table = index.place_table(FileScopeId::global());
@ -1028,7 +1042,9 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
assert_eq!(comprehension_scope.kind(), ScopeKind::Comprehension);
assert_eq!(
comprehension_scope_id.to_scope_id(&db, file).name(&db),
comprehension_scope_id
.to_scope_id(&db, file)
.name(&db, &module),
"<listcomp>"
);
@ -1047,7 +1063,7 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
assert_eq!(
inner_comprehension_scope_id
.to_scope_id(&db, file)
.name(&db),
.name(&db, &module),
"<setcomp>"
);
@ -1112,6 +1128,7 @@ def func():
y = 2
",
);
let module = parsed_module(&db, file).load(&db);
let index = semantic_index(&db, file);
let global_table = index.place_table(FileScopeId::global());
@ -1128,9 +1145,15 @@ def func():
assert_eq!(func_scope_1.kind(), ScopeKind::Function);
assert_eq!(func_scope1_id.to_scope_id(&db, file).name(&db), "func");
assert_eq!(
func_scope1_id.to_scope_id(&db, file).name(&db, &module),
"func"
);
assert_eq!(func_scope_2.kind(), ScopeKind::Function);
assert_eq!(func_scope2_id.to_scope_id(&db, file).name(&db), "func");
assert_eq!(
func_scope2_id.to_scope_id(&db, file).name(&db, &module),
"func"
);
let func1_table = index.place_table(func_scope1_id);
let func2_table = index.place_table(func_scope2_id);
@ -1157,6 +1180,7 @@ def func[T]():
",
);
let module = parsed_module(&db, file).load(&db);
let index = semantic_index(&db, file);
let global_table = index.place_table(FileScopeId::global());
@ -1170,7 +1194,10 @@ def func[T]():
};
assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
assert_eq!(ann_scope_id.to_scope_id(&db, file).name(&db), "func");
assert_eq!(
ann_scope_id.to_scope_id(&db, file).name(&db, &module),
"func"
);
let ann_table = index.place_table(ann_scope_id);
assert_eq!(names(&ann_table), vec!["T"]);
@ -1180,7 +1207,10 @@ def func[T]():
panic!("expected one child scope");
};
assert_eq!(func_scope.kind(), ScopeKind::Function);
assert_eq!(func_scope_id.to_scope_id(&db, file).name(&db), "func");
assert_eq!(
func_scope_id.to_scope_id(&db, file).name(&db, &module),
"func"
);
let func_table = index.place_table(func_scope_id);
assert_eq!(names(&func_table), vec!["x"]);
}
@ -1194,6 +1224,7 @@ class C[T]:
",
);
let module = parsed_module(&db, file).load(&db);
let index = semantic_index(&db, file);
let global_table = index.place_table(FileScopeId::global());
@ -1207,7 +1238,7 @@ class C[T]:
};
assert_eq!(ann_scope.kind(), ScopeKind::Annotation);
assert_eq!(ann_scope_id.to_scope_id(&db, file).name(&db), "C");
assert_eq!(ann_scope_id.to_scope_id(&db, file).name(&db, &module), "C");
let ann_table = index.place_table(ann_scope_id);
assert_eq!(names(&ann_table), vec!["T"]);
assert!(
@ -1224,16 +1255,19 @@ class C[T]:
};
assert_eq!(class_scope.kind(), ScopeKind::Class);
assert_eq!(class_scope_id.to_scope_id(&db, file).name(&db), "C");
assert_eq!(
class_scope_id.to_scope_id(&db, file).name(&db, &module),
"C"
);
assert_eq!(names(&index.place_table(class_scope_id)), vec!["x"]);
}
#[test]
fn reachability_trivial() {
let TestCase { db, file } = test_case("x = 1; x");
let parsed = parsed_module(&db, file);
let module = parsed_module(&db, file).load(&db);
let scope = global_scope(&db, file);
let ast = parsed.syntax();
let ast = module.syntax();
let ast::Stmt::Expr(ast::StmtExpr {
value: x_use_expr, ..
}) = &ast.body[1]
@ -1252,7 +1286,7 @@ class C[T]:
let ast::Expr::NumberLiteral(ast::ExprNumberLiteral {
value: ast::Number::Int(num),
..
}) = assignment.value()
}) = assignment.value(&module)
else {
panic!("should be a number literal")
};
@ -1264,8 +1298,8 @@ class C[T]:
let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4");
let index = semantic_index(&db, file);
let parsed = parsed_module(&db, file);
let ast = parsed.syntax();
let module = parsed_module(&db, file).load(&db);
let ast = module.syntax();
let x_stmt = ast.body[0].as_assign_stmt().unwrap();
let x = &x_stmt.targets[0];
@ -1282,14 +1316,15 @@ class C[T]:
#[test]
fn scope_iterators() {
fn scope_names<'a>(
scopes: impl Iterator<Item = (FileScopeId, &'a Scope)>,
db: &'a dyn Db,
fn scope_names<'a, 'db>(
scopes: impl Iterator<Item = (FileScopeId, &'db Scope)>,
db: &'db dyn Db,
file: File,
module: &'a ParsedModuleRef,
) -> Vec<&'a str> {
scopes
.into_iter()
.map(|(scope_id, _)| scope_id.to_scope_id(db, file).name(db))
.map(|(scope_id, _)| scope_id.to_scope_id(db, file).name(db, module))
.collect()
}
@ -1306,21 +1341,22 @@ def x():
pass",
);
let module = parsed_module(&db, file).load(&db);
let index = semantic_index(&db, file);
let descendants = index.descendent_scopes(FileScopeId::global());
assert_eq!(
scope_names(descendants, &db, file),
scope_names(descendants, &db, file, &module),
vec!["Test", "foo", "bar", "baz", "x"]
);
let children = index.child_scopes(FileScopeId::global());
assert_eq!(scope_names(children, &db, file), vec!["Test", "x"]);
assert_eq!(scope_names(children, &db, file, &module), vec!["Test", "x"]);
let test_class = index.child_scopes(FileScopeId::global()).next().unwrap().0;
let test_child_scopes = index.child_scopes(test_class);
assert_eq!(
scope_names(test_child_scopes, &db, file),
scope_names(test_child_scopes, &db, file, &module),
vec!["foo", "baz"]
);
@ -1332,7 +1368,7 @@ def x():
let ancestors = index.ancestor_scopes(bar_scope);
assert_eq!(
scope_names(ancestors, &db, file),
scope_names(ancestors, &db, file, &module),
vec!["bar", "foo", "Test", "<module>"]
);
}

View file

@ -5,7 +5,7 @@ use except_handlers::TryNodeContextStackManager;
use rustc_hash::{FxHashMap, FxHashSet};
use ruff_db::files::File;
use ruff_db::parsed::ParsedModule;
use ruff_db::parsed::ParsedModuleRef;
use ruff_db::source::{SourceText, source_text};
use ruff_index::IndexVec;
use ruff_python_ast::name::Name;
@ -69,20 +69,20 @@ struct ScopeInfo {
current_loop: Option<Loop>,
}
pub(super) struct SemanticIndexBuilder<'db> {
pub(super) struct SemanticIndexBuilder<'db, 'ast> {
// Builder state
db: &'db dyn Db,
file: File,
source_type: PySourceType,
module: &'db ParsedModule,
module: &'ast ParsedModuleRef,
scope_stack: Vec<ScopeInfo>,
/// The assignments we're currently visiting, with
/// the most recent visit at the end of the Vec
current_assignments: Vec<CurrentAssignment<'db>>,
current_assignments: Vec<CurrentAssignment<'ast, 'db>>,
/// The match case we're currently visiting.
current_match_case: Option<CurrentMatchCase<'db>>,
current_match_case: Option<CurrentMatchCase<'ast>>,
/// The name of the first function parameter of the innermost function that we're currently visiting.
current_first_parameter_name: Option<&'db str>,
current_first_parameter_name: Option<&'ast str>,
/// Per-scope contexts regarding nested `try`/`except` statements
try_node_context_stack_manager: TryNodeContextStackManager,
@ -116,13 +116,13 @@ pub(super) struct SemanticIndexBuilder<'db> {
semantic_syntax_errors: RefCell<Vec<SemanticSyntaxError>>,
}
impl<'db> SemanticIndexBuilder<'db> {
pub(super) fn new(db: &'db dyn Db, file: File, parsed: &'db ParsedModule) -> Self {
impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
pub(super) fn new(db: &'db dyn Db, file: File, module_ref: &'ast ParsedModuleRef) -> Self {
let mut builder = Self {
db,
file,
source_type: file.source_type(db.upcast()),
module: parsed,
module: module_ref,
scope_stack: Vec::new(),
current_assignments: vec![],
current_match_case: None,
@ -423,7 +423,7 @@ impl<'db> SemanticIndexBuilder<'db> {
fn add_definition(
&mut self,
place: ScopedPlaceId,
definition_node: impl Into<DefinitionNodeRef<'db>> + std::fmt::Debug + Copy,
definition_node: impl Into<DefinitionNodeRef<'ast, 'db>> + std::fmt::Debug + Copy,
) -> Definition<'db> {
let (definition, num_definitions) = self.push_additional_definition(place, definition_node);
debug_assert_eq!(
@ -463,16 +463,18 @@ impl<'db> SemanticIndexBuilder<'db> {
fn push_additional_definition(
&mut self,
place: ScopedPlaceId,
definition_node: impl Into<DefinitionNodeRef<'db>>,
definition_node: impl Into<DefinitionNodeRef<'ast, 'db>>,
) -> (Definition<'db>, usize) {
let definition_node: DefinitionNodeRef<'_> = definition_node.into();
let definition_node: DefinitionNodeRef<'ast, 'db> = definition_node.into();
#[expect(unsafe_code)]
// SAFETY: `definition_node` is guaranteed to be a child of `self.module`
let kind = unsafe { definition_node.into_owned(self.module.clone()) };
let category = kind.category(self.source_type.is_stub());
let category = kind.category(self.source_type.is_stub(), self.module);
let is_reexported = kind.is_reexported();
let definition = Definition::new(
let definition: Definition<'db> = Definition::new(
self.db,
self.file,
self.current_scope(),
@ -658,7 +660,7 @@ impl<'db> SemanticIndexBuilder<'db> {
.record_reachability_constraint(negated_constraint);
}
fn push_assignment(&mut self, assignment: CurrentAssignment<'db>) {
fn push_assignment(&mut self, assignment: CurrentAssignment<'ast, 'db>) {
self.current_assignments.push(assignment);
}
@ -667,11 +669,11 @@ impl<'db> SemanticIndexBuilder<'db> {
debug_assert!(popped_assignment.is_some());
}
fn current_assignment(&self) -> Option<CurrentAssignment<'db>> {
fn current_assignment(&self) -> Option<CurrentAssignment<'ast, 'db>> {
self.current_assignments.last().copied()
}
fn current_assignment_mut(&mut self) -> Option<&mut CurrentAssignment<'db>> {
fn current_assignment_mut(&mut self) -> Option<&mut CurrentAssignment<'ast, 'db>> {
self.current_assignments.last_mut()
}
@ -792,7 +794,7 @@ impl<'db> SemanticIndexBuilder<'db> {
fn with_type_params(
&mut self,
with_scope: NodeWithScopeRef,
type_params: Option<&'db ast::TypeParams>,
type_params: Option<&'ast ast::TypeParams>,
nested: impl FnOnce(&mut Self) -> FileScopeId,
) -> FileScopeId {
if let Some(type_params) = type_params {
@ -858,7 +860,7 @@ impl<'db> SemanticIndexBuilder<'db> {
fn with_generators_scope(
&mut self,
scope: NodeWithScopeRef,
generators: &'db [ast::Comprehension],
generators: &'ast [ast::Comprehension],
visit_outer_elt: impl FnOnce(&mut Self),
) {
let mut generators_iter = generators.iter();
@ -908,7 +910,7 @@ impl<'db> SemanticIndexBuilder<'db> {
self.pop_scope();
}
fn declare_parameters(&mut self, parameters: &'db ast::Parameters) {
fn declare_parameters(&mut self, parameters: &'ast ast::Parameters) {
for parameter in parameters.iter_non_variadic_params() {
self.declare_parameter(parameter);
}
@ -925,7 +927,7 @@ impl<'db> SemanticIndexBuilder<'db> {
}
}
fn declare_parameter(&mut self, parameter: &'db ast::ParameterWithDefault) {
fn declare_parameter(&mut self, parameter: &'ast ast::ParameterWithDefault) {
let symbol = self.add_symbol(parameter.name().id().clone());
let definition = self.add_definition(symbol, parameter);
@ -946,8 +948,8 @@ impl<'db> SemanticIndexBuilder<'db> {
/// for statements, etc.
fn add_unpackable_assignment(
&mut self,
unpackable: &Unpackable<'db>,
target: &'db ast::Expr,
unpackable: &Unpackable<'ast>,
target: &'ast ast::Expr,
value: Expression<'db>,
) {
// We only handle assignments to names and unpackings here, other targets like
@ -1010,8 +1012,7 @@ impl<'db> SemanticIndexBuilder<'db> {
}
pub(super) fn build(mut self) -> SemanticIndex<'db> {
let module = self.module;
self.visit_body(module.suite());
self.visit_body(self.module.suite());
// Pop the root scope
self.pop_scope();
@ -1081,10 +1082,7 @@ impl<'db> SemanticIndexBuilder<'db> {
}
}
impl<'db, 'ast> Visitor<'ast> for SemanticIndexBuilder<'db>
where
'ast: 'db,
{
impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) {
self.with_semantic_checker(|semantic, context| semantic.visit_stmt(stmt, context));
@ -2299,7 +2297,7 @@ where
}
}
impl SemanticSyntaxContext for SemanticIndexBuilder<'_> {
impl SemanticSyntaxContext for SemanticIndexBuilder<'_, '_> {
fn future_annotations_or_stub(&self) -> bool {
self.has_future_annotations
}
@ -2324,7 +2322,7 @@ impl SemanticSyntaxContext for SemanticIndexBuilder<'_> {
match scope.kind() {
ScopeKind::Class | ScopeKind::Lambda => return false,
ScopeKind::Function => {
return scope.node().expect_function().is_async;
return scope.node().expect_function(self.module).is_async;
}
ScopeKind::Comprehension
| ScopeKind::Module
@ -2366,9 +2364,9 @@ impl SemanticSyntaxContext for SemanticIndexBuilder<'_> {
for scope_info in self.scope_stack.iter().rev() {
let scope = &self.scopes[scope_info.file_scope_id];
let generators = match scope.node() {
NodeWithScopeKind::ListComprehension(node) => &node.generators,
NodeWithScopeKind::SetComprehension(node) => &node.generators,
NodeWithScopeKind::DictComprehension(node) => &node.generators,
NodeWithScopeKind::ListComprehension(node) => &node.node(self.module).generators,
NodeWithScopeKind::SetComprehension(node) => &node.node(self.module).generators,
NodeWithScopeKind::DictComprehension(node) => &node.node(self.module).generators,
_ => continue,
};
if generators
@ -2409,31 +2407,31 @@ impl SemanticSyntaxContext for SemanticIndexBuilder<'_> {
}
#[derive(Copy, Clone, Debug, PartialEq)]
enum CurrentAssignment<'a> {
enum CurrentAssignment<'ast, 'db> {
Assign {
node: &'a ast::StmtAssign,
unpack: Option<(UnpackPosition, Unpack<'a>)>,
node: &'ast ast::StmtAssign,
unpack: Option<(UnpackPosition, Unpack<'db>)>,
},
AnnAssign(&'a ast::StmtAnnAssign),
AugAssign(&'a ast::StmtAugAssign),
AnnAssign(&'ast ast::StmtAnnAssign),
AugAssign(&'ast ast::StmtAugAssign),
For {
node: &'a ast::StmtFor,
unpack: Option<(UnpackPosition, Unpack<'a>)>,
node: &'ast ast::StmtFor,
unpack: Option<(UnpackPosition, Unpack<'db>)>,
},
Named(&'a ast::ExprNamed),
Named(&'ast ast::ExprNamed),
Comprehension {
node: &'a ast::Comprehension,
node: &'ast ast::Comprehension,
first: bool,
unpack: Option<(UnpackPosition, Unpack<'a>)>,
unpack: Option<(UnpackPosition, Unpack<'db>)>,
},
WithItem {
item: &'a ast::WithItem,
item: &'ast ast::WithItem,
is_async: bool,
unpack: Option<(UnpackPosition, Unpack<'a>)>,
unpack: Option<(UnpackPosition, Unpack<'db>)>,
},
}
impl CurrentAssignment<'_> {
impl CurrentAssignment<'_, '_> {
fn unpack_position_mut(&mut self) -> Option<&mut UnpackPosition> {
match self {
Self::Assign { unpack, .. }
@ -2445,28 +2443,28 @@ impl CurrentAssignment<'_> {
}
}
impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAnnAssign) -> Self {
impl<'ast> From<&'ast ast::StmtAnnAssign> for CurrentAssignment<'ast, '_> {
fn from(value: &'ast ast::StmtAnnAssign) -> Self {
Self::AnnAssign(value)
}
}
impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAugAssign) -> Self {
impl<'ast> From<&'ast ast::StmtAugAssign> for CurrentAssignment<'ast, '_> {
fn from(value: &'ast ast::StmtAugAssign) -> Self {
Self::AugAssign(value)
}
}
impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
fn from(value: &'a ast::ExprNamed) -> Self {
impl<'ast> From<&'ast ast::ExprNamed> for CurrentAssignment<'ast, '_> {
fn from(value: &'ast ast::ExprNamed) -> Self {
Self::Named(value)
}
}
#[derive(Debug, PartialEq)]
struct CurrentMatchCase<'a> {
struct CurrentMatchCase<'ast> {
/// The pattern that's part of the current match case.
pattern: &'a ast::Pattern,
pattern: &'ast ast::Pattern,
/// The index of the sub-pattern that's being currently visited within the pattern.
///
@ -2488,20 +2486,20 @@ impl<'a> CurrentMatchCase<'a> {
}
}
enum Unpackable<'a> {
Assign(&'a ast::StmtAssign),
For(&'a ast::StmtFor),
enum Unpackable<'ast> {
Assign(&'ast ast::StmtAssign),
For(&'ast ast::StmtFor),
WithItem {
item: &'a ast::WithItem,
item: &'ast ast::WithItem,
is_async: bool,
},
Comprehension {
first: bool,
node: &'a ast::Comprehension,
node: &'ast ast::Comprehension,
},
}
impl<'a> Unpackable<'a> {
impl<'ast> Unpackable<'ast> {
const fn kind(&self) -> UnpackKind {
match self {
Unpackable::Assign(_) => UnpackKind::Assign,
@ -2510,7 +2508,10 @@ impl<'a> Unpackable<'a> {
}
}
fn as_current_assignment(&self, unpack: Option<Unpack<'a>>) -> CurrentAssignment<'a> {
fn as_current_assignment<'db>(
&self,
unpack: Option<Unpack<'db>>,
) -> CurrentAssignment<'ast, 'db> {
let unpack = unpack.map(|unpack| (UnpackPosition::First, unpack));
match self {
Unpackable::Assign(stmt) => CurrentAssignment::Assign { node: stmt, unpack },

View file

@ -1,7 +1,7 @@
use std::ops::Deref;
use ruff_db::files::{File, FileRange};
use ruff_db::parsed::ParsedModule;
use ruff_db::parsed::ParsedModuleRef;
use ruff_python_ast as ast;
use ruff_text_size::{Ranged, TextRange};
@ -49,12 +49,12 @@ impl<'db> Definition<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}
pub fn full_range(self, db: &'db dyn Db) -> FileRange {
FileRange::new(self.file(db), self.kind(db).full_range())
pub fn full_range(self, db: &'db dyn Db, module: &ParsedModuleRef) -> FileRange {
FileRange::new(self.file(db), self.kind(db).full_range(module))
}
pub fn focus_range(self, db: &'db dyn Db) -> FileRange {
FileRange::new(self.file(db), self.kind(db).target_range())
pub fn focus_range(self, db: &'db dyn Db, module: &ParsedModuleRef) -> FileRange {
FileRange::new(self.file(db), self.kind(db).target_range(module))
}
}
@ -123,218 +123,218 @@ impl<'db> DefinitionState<'db> {
}
#[derive(Copy, Clone, Debug)]
pub(crate) enum DefinitionNodeRef<'a> {
Import(ImportDefinitionNodeRef<'a>),
ImportFrom(ImportFromDefinitionNodeRef<'a>),
ImportStar(StarImportDefinitionNodeRef<'a>),
For(ForStmtDefinitionNodeRef<'a>),
Function(&'a ast::StmtFunctionDef),
Class(&'a ast::StmtClassDef),
TypeAlias(&'a ast::StmtTypeAlias),
NamedExpression(&'a ast::ExprNamed),
Assignment(AssignmentDefinitionNodeRef<'a>),
AnnotatedAssignment(AnnotatedAssignmentDefinitionNodeRef<'a>),
AugmentedAssignment(&'a ast::StmtAugAssign),
Comprehension(ComprehensionDefinitionNodeRef<'a>),
VariadicPositionalParameter(&'a ast::Parameter),
VariadicKeywordParameter(&'a ast::Parameter),
Parameter(&'a ast::ParameterWithDefault),
WithItem(WithItemDefinitionNodeRef<'a>),
MatchPattern(MatchPatternDefinitionNodeRef<'a>),
ExceptHandler(ExceptHandlerDefinitionNodeRef<'a>),
TypeVar(&'a ast::TypeParamTypeVar),
ParamSpec(&'a ast::TypeParamParamSpec),
TypeVarTuple(&'a ast::TypeParamTypeVarTuple),
pub(crate) enum DefinitionNodeRef<'ast, 'db> {
Import(ImportDefinitionNodeRef<'ast>),
ImportFrom(ImportFromDefinitionNodeRef<'ast>),
ImportStar(StarImportDefinitionNodeRef<'ast>),
For(ForStmtDefinitionNodeRef<'ast, 'db>),
Function(&'ast ast::StmtFunctionDef),
Class(&'ast ast::StmtClassDef),
TypeAlias(&'ast ast::StmtTypeAlias),
NamedExpression(&'ast ast::ExprNamed),
Assignment(AssignmentDefinitionNodeRef<'ast, 'db>),
AnnotatedAssignment(AnnotatedAssignmentDefinitionNodeRef<'ast>),
AugmentedAssignment(&'ast ast::StmtAugAssign),
Comprehension(ComprehensionDefinitionNodeRef<'ast, 'db>),
VariadicPositionalParameter(&'ast ast::Parameter),
VariadicKeywordParameter(&'ast ast::Parameter),
Parameter(&'ast ast::ParameterWithDefault),
WithItem(WithItemDefinitionNodeRef<'ast, 'db>),
MatchPattern(MatchPatternDefinitionNodeRef<'ast>),
ExceptHandler(ExceptHandlerDefinitionNodeRef<'ast>),
TypeVar(&'ast ast::TypeParamTypeVar),
ParamSpec(&'ast ast::TypeParamParamSpec),
TypeVarTuple(&'ast ast::TypeParamTypeVarTuple),
}
impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtFunctionDef) -> Self {
impl<'ast> From<&'ast ast::StmtFunctionDef> for DefinitionNodeRef<'ast, '_> {
fn from(node: &'ast ast::StmtFunctionDef) -> Self {
Self::Function(node)
}
}
impl<'a> From<&'a ast::StmtClassDef> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtClassDef) -> Self {
impl<'ast> From<&'ast ast::StmtClassDef> for DefinitionNodeRef<'ast, '_> {
fn from(node: &'ast ast::StmtClassDef) -> Self {
Self::Class(node)
}
}
impl<'a> From<&'a ast::StmtTypeAlias> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtTypeAlias) -> Self {
impl<'ast> From<&'ast ast::StmtTypeAlias> for DefinitionNodeRef<'ast, '_> {
fn from(node: &'ast ast::StmtTypeAlias) -> Self {
Self::TypeAlias(node)
}
}
impl<'a> From<&'a ast::ExprNamed> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::ExprNamed) -> Self {
impl<'ast> From<&'ast ast::ExprNamed> for DefinitionNodeRef<'ast, '_> {
fn from(node: &'ast ast::ExprNamed) -> Self {
Self::NamedExpression(node)
}
}
impl<'a> From<&'a ast::StmtAugAssign> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtAugAssign) -> Self {
impl<'ast> From<&'ast ast::StmtAugAssign> for DefinitionNodeRef<'ast, '_> {
fn from(node: &'ast ast::StmtAugAssign) -> Self {
Self::AugmentedAssignment(node)
}
}
impl<'a> From<&'a ast::TypeParamTypeVar> for DefinitionNodeRef<'a> {
fn from(value: &'a ast::TypeParamTypeVar) -> Self {
impl<'ast> From<&'ast ast::TypeParamTypeVar> for DefinitionNodeRef<'ast, '_> {
fn from(value: &'ast ast::TypeParamTypeVar) -> Self {
Self::TypeVar(value)
}
}
impl<'a> From<&'a ast::TypeParamParamSpec> for DefinitionNodeRef<'a> {
fn from(value: &'a ast::TypeParamParamSpec) -> Self {
impl<'ast> From<&'ast ast::TypeParamParamSpec> for DefinitionNodeRef<'ast, '_> {
fn from(value: &'ast ast::TypeParamParamSpec) -> Self {
Self::ParamSpec(value)
}
}
impl<'a> From<&'a ast::TypeParamTypeVarTuple> for DefinitionNodeRef<'a> {
fn from(value: &'a ast::TypeParamTypeVarTuple) -> Self {
impl<'ast> From<&'ast ast::TypeParamTypeVarTuple> for DefinitionNodeRef<'ast, '_> {
fn from(value: &'ast ast::TypeParamTypeVarTuple) -> Self {
Self::TypeVarTuple(value)
}
}
impl<'a> From<ImportDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: ImportDefinitionNodeRef<'a>) -> Self {
impl<'ast> From<ImportDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, '_> {
fn from(node_ref: ImportDefinitionNodeRef<'ast>) -> Self {
Self::Import(node_ref)
}
}
impl<'a> From<ImportFromDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: ImportFromDefinitionNodeRef<'a>) -> Self {
impl<'ast> From<ImportFromDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, '_> {
fn from(node_ref: ImportFromDefinitionNodeRef<'ast>) -> Self {
Self::ImportFrom(node_ref)
}
}
impl<'a> From<ForStmtDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(value: ForStmtDefinitionNodeRef<'a>) -> Self {
impl<'ast, 'db> From<ForStmtDefinitionNodeRef<'ast, 'db>> for DefinitionNodeRef<'ast, 'db> {
fn from(value: ForStmtDefinitionNodeRef<'ast, 'db>) -> Self {
Self::For(value)
}
}
impl<'a> From<AssignmentDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: AssignmentDefinitionNodeRef<'a>) -> Self {
impl<'ast, 'db> From<AssignmentDefinitionNodeRef<'ast, 'db>> for DefinitionNodeRef<'ast, 'db> {
fn from(node_ref: AssignmentDefinitionNodeRef<'ast, 'db>) -> Self {
Self::Assignment(node_ref)
}
}
impl<'a> From<AnnotatedAssignmentDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: AnnotatedAssignmentDefinitionNodeRef<'a>) -> Self {
impl<'ast> From<AnnotatedAssignmentDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, '_> {
fn from(node_ref: AnnotatedAssignmentDefinitionNodeRef<'ast>) -> Self {
Self::AnnotatedAssignment(node_ref)
}
}
impl<'a> From<WithItemDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: WithItemDefinitionNodeRef<'a>) -> Self {
impl<'ast, 'db> From<WithItemDefinitionNodeRef<'ast, 'db>> for DefinitionNodeRef<'ast, 'db> {
fn from(node_ref: WithItemDefinitionNodeRef<'ast, 'db>) -> Self {
Self::WithItem(node_ref)
}
}
impl<'a> From<ComprehensionDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node: ComprehensionDefinitionNodeRef<'a>) -> Self {
impl<'ast, 'db> From<ComprehensionDefinitionNodeRef<'ast, 'db>> for DefinitionNodeRef<'ast, 'db> {
fn from(node: ComprehensionDefinitionNodeRef<'ast, 'db>) -> Self {
Self::Comprehension(node)
}
}
impl<'a> From<&'a ast::ParameterWithDefault> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::ParameterWithDefault) -> Self {
impl<'ast> From<&'ast ast::ParameterWithDefault> for DefinitionNodeRef<'ast, '_> {
fn from(node: &'ast ast::ParameterWithDefault) -> Self {
Self::Parameter(node)
}
}
impl<'a> From<MatchPatternDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node: MatchPatternDefinitionNodeRef<'a>) -> Self {
impl<'ast> From<MatchPatternDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, '_> {
fn from(node: MatchPatternDefinitionNodeRef<'ast>) -> Self {
Self::MatchPattern(node)
}
}
impl<'a> From<StarImportDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node: StarImportDefinitionNodeRef<'a>) -> Self {
impl<'ast> From<StarImportDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, '_> {
fn from(node: StarImportDefinitionNodeRef<'ast>) -> Self {
Self::ImportStar(node)
}
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ImportDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtImport,
pub(crate) struct ImportDefinitionNodeRef<'ast> {
pub(crate) node: &'ast ast::StmtImport,
pub(crate) alias_index: usize,
pub(crate) is_reexported: bool,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct StarImportDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtImportFrom,
pub(crate) struct StarImportDefinitionNodeRef<'ast> {
pub(crate) node: &'ast ast::StmtImportFrom,
pub(crate) place_id: ScopedPlaceId,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtImportFrom,
pub(crate) struct ImportFromDefinitionNodeRef<'ast> {
pub(crate) node: &'ast ast::StmtImportFrom,
pub(crate) alias_index: usize,
pub(crate) is_reexported: bool,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct AssignmentDefinitionNodeRef<'a> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>,
pub(crate) value: &'a ast::Expr,
pub(crate) target: &'a ast::Expr,
pub(crate) struct AssignmentDefinitionNodeRef<'ast, 'db> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'db>)>,
pub(crate) value: &'ast ast::Expr,
pub(crate) target: &'ast ast::Expr,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct AnnotatedAssignmentDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtAnnAssign,
pub(crate) annotation: &'a ast::Expr,
pub(crate) value: Option<&'a ast::Expr>,
pub(crate) target: &'a ast::Expr,
pub(crate) struct AnnotatedAssignmentDefinitionNodeRef<'ast> {
pub(crate) node: &'ast ast::StmtAnnAssign,
pub(crate) annotation: &'ast ast::Expr,
pub(crate) value: Option<&'ast ast::Expr>,
pub(crate) target: &'ast ast::Expr,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct WithItemDefinitionNodeRef<'a> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>,
pub(crate) context_expr: &'a ast::Expr,
pub(crate) target: &'a ast::Expr,
pub(crate) struct WithItemDefinitionNodeRef<'ast, 'db> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'db>)>,
pub(crate) context_expr: &'ast ast::Expr,
pub(crate) target: &'ast ast::Expr,
pub(crate) is_async: bool,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>,
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::Expr,
pub(crate) struct ForStmtDefinitionNodeRef<'ast, 'db> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'db>)>,
pub(crate) iterable: &'ast ast::Expr,
pub(crate) target: &'ast ast::Expr,
pub(crate) is_async: bool,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ExceptHandlerDefinitionNodeRef<'a> {
pub(crate) handler: &'a ast::ExceptHandlerExceptHandler,
pub(crate) struct ExceptHandlerDefinitionNodeRef<'ast> {
pub(crate) handler: &'ast ast::ExceptHandlerExceptHandler,
pub(crate) is_star: bool,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'a>)>,
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::Expr,
pub(crate) struct ComprehensionDefinitionNodeRef<'ast, 'db> {
pub(crate) unpack: Option<(UnpackPosition, Unpack<'db>)>,
pub(crate) iterable: &'ast ast::Expr,
pub(crate) target: &'ast ast::Expr,
pub(crate) first: bool,
pub(crate) is_async: bool,
}
#[derive(Copy, Clone, Debug)]
pub(crate) struct MatchPatternDefinitionNodeRef<'a> {
pub(crate) struct MatchPatternDefinitionNodeRef<'ast> {
/// The outermost pattern node in which the identifier being defined occurs.
pub(crate) pattern: &'a ast::Pattern,
pub(crate) pattern: &'ast ast::Pattern,
/// The identifier being defined.
pub(crate) identifier: &'a ast::Identifier,
pub(crate) identifier: &'ast ast::Identifier,
/// The index of the identifier in the pattern when visiting the `pattern` node in evaluation
/// order.
pub(crate) index: u32,
}
impl<'db> DefinitionNodeRef<'db> {
impl<'db> DefinitionNodeRef<'_, 'db> {
#[expect(unsafe_code)]
pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind<'db> {
pub(super) unsafe fn into_owned(self, parsed: ParsedModuleRef) -> DefinitionKind<'db> {
match self {
DefinitionNodeRef::Import(ImportDefinitionNodeRef {
node,
@ -626,60 +626,74 @@ impl DefinitionKind<'_> {
///
/// A definition target would mainly be the node representing the place being defined i.e.,
/// [`ast::ExprName`], [`ast::Identifier`], [`ast::ExprAttribute`] or [`ast::ExprSubscript`] but could also be other nodes.
pub(crate) fn target_range(&self) -> TextRange {
pub(crate) fn target_range(&self, module: &ParsedModuleRef) -> TextRange {
match self {
DefinitionKind::Import(import) => import.alias().range(),
DefinitionKind::ImportFrom(import) => import.alias().range(),
DefinitionKind::StarImport(import) => import.alias().range(),
DefinitionKind::Function(function) => function.name.range(),
DefinitionKind::Class(class) => class.name.range(),
DefinitionKind::TypeAlias(type_alias) => type_alias.name.range(),
DefinitionKind::NamedExpression(named) => named.target.range(),
DefinitionKind::Assignment(assignment) => assignment.target.range(),
DefinitionKind::AnnotatedAssignment(assign) => assign.target.range(),
DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.target.range(),
DefinitionKind::For(for_stmt) => for_stmt.target.range(),
DefinitionKind::Comprehension(comp) => comp.target().range(),
DefinitionKind::VariadicPositionalParameter(parameter) => parameter.name.range(),
DefinitionKind::VariadicKeywordParameter(parameter) => parameter.name.range(),
DefinitionKind::Parameter(parameter) => parameter.parameter.name.range(),
DefinitionKind::WithItem(with_item) => with_item.target.range(),
DefinitionKind::MatchPattern(match_pattern) => match_pattern.identifier.range(),
DefinitionKind::ExceptHandler(handler) => handler.node().range(),
DefinitionKind::TypeVar(type_var) => type_var.name.range(),
DefinitionKind::ParamSpec(param_spec) => param_spec.name.range(),
DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.name.range(),
DefinitionKind::Import(import) => import.alias(module).range(),
DefinitionKind::ImportFrom(import) => import.alias(module).range(),
DefinitionKind::StarImport(import) => import.alias(module).range(),
DefinitionKind::Function(function) => function.node(module).name.range(),
DefinitionKind::Class(class) => class.node(module).name.range(),
DefinitionKind::TypeAlias(type_alias) => type_alias.node(module).name.range(),
DefinitionKind::NamedExpression(named) => named.node(module).target.range(),
DefinitionKind::Assignment(assignment) => assignment.target.node(module).range(),
DefinitionKind::AnnotatedAssignment(assign) => assign.target.node(module).range(),
DefinitionKind::AugmentedAssignment(aug_assign) => {
aug_assign.node(module).target.range()
}
DefinitionKind::For(for_stmt) => for_stmt.target.node(module).range(),
DefinitionKind::Comprehension(comp) => comp.target(module).range(),
DefinitionKind::VariadicPositionalParameter(parameter) => {
parameter.node(module).name.range()
}
DefinitionKind::VariadicKeywordParameter(parameter) => {
parameter.node(module).name.range()
}
DefinitionKind::Parameter(parameter) => parameter.node(module).parameter.name.range(),
DefinitionKind::WithItem(with_item) => with_item.target.node(module).range(),
DefinitionKind::MatchPattern(match_pattern) => {
match_pattern.identifier.node(module).range()
}
DefinitionKind::ExceptHandler(handler) => handler.node(module).range(),
DefinitionKind::TypeVar(type_var) => type_var.node(module).name.range(),
DefinitionKind::ParamSpec(param_spec) => param_spec.node(module).name.range(),
DefinitionKind::TypeVarTuple(type_var_tuple) => {
type_var_tuple.node(module).name.range()
}
}
}
/// Returns the [`TextRange`] of the entire definition.
pub(crate) fn full_range(&self) -> TextRange {
pub(crate) fn full_range(&self, module: &ParsedModuleRef) -> TextRange {
match self {
DefinitionKind::Import(import) => import.alias().range(),
DefinitionKind::ImportFrom(import) => import.alias().range(),
DefinitionKind::StarImport(import) => import.import().range(),
DefinitionKind::Function(function) => function.range(),
DefinitionKind::Class(class) => class.range(),
DefinitionKind::TypeAlias(type_alias) => type_alias.range(),
DefinitionKind::NamedExpression(named) => named.range(),
DefinitionKind::Assignment(assignment) => assignment.target.range(),
DefinitionKind::AnnotatedAssignment(assign) => assign.target.range(),
DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.range(),
DefinitionKind::For(for_stmt) => for_stmt.target.range(),
DefinitionKind::Comprehension(comp) => comp.target().range(),
DefinitionKind::VariadicPositionalParameter(parameter) => parameter.range(),
DefinitionKind::VariadicKeywordParameter(parameter) => parameter.range(),
DefinitionKind::Parameter(parameter) => parameter.parameter.range(),
DefinitionKind::WithItem(with_item) => with_item.target.range(),
DefinitionKind::MatchPattern(match_pattern) => match_pattern.identifier.range(),
DefinitionKind::ExceptHandler(handler) => handler.node().range(),
DefinitionKind::TypeVar(type_var) => type_var.range(),
DefinitionKind::ParamSpec(param_spec) => param_spec.range(),
DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.range(),
DefinitionKind::Import(import) => import.alias(module).range(),
DefinitionKind::ImportFrom(import) => import.alias(module).range(),
DefinitionKind::StarImport(import) => import.import(module).range(),
DefinitionKind::Function(function) => function.node(module).range(),
DefinitionKind::Class(class) => class.node(module).range(),
DefinitionKind::TypeAlias(type_alias) => type_alias.node(module).range(),
DefinitionKind::NamedExpression(named) => named.node(module).range(),
DefinitionKind::Assignment(assignment) => assignment.target.node(module).range(),
DefinitionKind::AnnotatedAssignment(assign) => assign.target.node(module).range(),
DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.node(module).range(),
DefinitionKind::For(for_stmt) => for_stmt.target.node(module).range(),
DefinitionKind::Comprehension(comp) => comp.target(module).range(),
DefinitionKind::VariadicPositionalParameter(parameter) => {
parameter.node(module).range()
}
DefinitionKind::VariadicKeywordParameter(parameter) => parameter.node(module).range(),
DefinitionKind::Parameter(parameter) => parameter.node(module).parameter.range(),
DefinitionKind::WithItem(with_item) => with_item.target.node(module).range(),
DefinitionKind::MatchPattern(match_pattern) => {
match_pattern.identifier.node(module).range()
}
DefinitionKind::ExceptHandler(handler) => handler.node(module).range(),
DefinitionKind::TypeVar(type_var) => type_var.node(module).range(),
DefinitionKind::ParamSpec(param_spec) => param_spec.node(module).range(),
DefinitionKind::TypeVarTuple(type_var_tuple) => type_var_tuple.node(module).range(),
}
}
pub(crate) fn category(&self, in_stub: bool) -> DefinitionCategory {
pub(crate) fn category(&self, in_stub: bool, module: &ParsedModuleRef) -> DefinitionCategory {
match self {
// functions, classes, and imports always bind, and we consider them declarations
DefinitionKind::Function(_)
@ -694,7 +708,7 @@ impl DefinitionKind<'_> {
// a parameter always binds a value, but is only a declaration if annotated
DefinitionKind::VariadicPositionalParameter(parameter)
| DefinitionKind::VariadicKeywordParameter(parameter) => {
if parameter.annotation.is_some() {
if parameter.node(module).annotation.is_some() {
DefinitionCategory::DeclarationAndBinding
} else {
DefinitionCategory::Binding
@ -702,7 +716,12 @@ impl DefinitionKind<'_> {
}
// presence of a default is irrelevant, same logic as for a no-default parameter
DefinitionKind::Parameter(parameter_with_default) => {
if parameter_with_default.parameter.annotation.is_some() {
if parameter_with_default
.node(module)
.parameter
.annotation
.is_some()
{
DefinitionCategory::DeclarationAndBinding
} else {
DefinitionCategory::Binding
@ -753,15 +772,15 @@ pub struct StarImportDefinitionKind {
}
impl StarImportDefinitionKind {
pub(crate) fn import(&self) -> &ast::StmtImportFrom {
self.node.node()
pub(crate) fn import<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtImportFrom {
self.node.node(module)
}
pub(crate) fn alias(&self) -> &ast::Alias {
pub(crate) fn alias<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Alias {
// INVARIANT: for an invalid-syntax statement such as `from foo import *, bar, *`,
// we only create a `StarImportDefinitionKind` for the *first* `*` alias in the names list.
self.node
.node()
.node(module)
.names
.iter()
.find(|alias| &alias.name == "*")
@ -784,8 +803,8 @@ pub struct MatchPatternDefinitionKind {
}
impl MatchPatternDefinitionKind {
pub(crate) fn pattern(&self) -> &ast::Pattern {
self.pattern.node()
pub(crate) fn pattern<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Pattern {
self.pattern.node(module)
}
pub(crate) fn index(&self) -> u32 {
@ -808,16 +827,16 @@ pub struct ComprehensionDefinitionKind<'db> {
}
impl<'db> ComprehensionDefinitionKind<'db> {
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
pub(crate) fn iterable<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.iterable.node(module)
}
pub(crate) fn target_kind(&self) -> TargetKind<'db> {
self.target_kind
}
pub(crate) fn target(&self) -> &ast::Expr {
self.target.node()
pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.target.node(module)
}
pub(crate) fn is_first(&self) -> bool {
@ -837,12 +856,12 @@ pub struct ImportDefinitionKind {
}
impl ImportDefinitionKind {
pub(crate) fn import(&self) -> &ast::StmtImport {
self.node.node()
pub(crate) fn import<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtImport {
self.node.node(module)
}
pub(crate) fn alias(&self) -> &ast::Alias {
&self.node.node().names[self.alias_index]
pub(crate) fn alias<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Alias {
&self.node.node(module).names[self.alias_index]
}
pub(crate) fn is_reexported(&self) -> bool {
@ -858,12 +877,12 @@ pub struct ImportFromDefinitionKind {
}
impl ImportFromDefinitionKind {
pub(crate) fn import(&self) -> &ast::StmtImportFrom {
self.node.node()
pub(crate) fn import<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtImportFrom {
self.node.node(module)
}
pub(crate) fn alias(&self) -> &ast::Alias {
&self.node.node().names[self.alias_index]
pub(crate) fn alias<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Alias {
&self.node.node(module).names[self.alias_index]
}
pub(crate) fn is_reexported(&self) -> bool {
@ -883,12 +902,12 @@ impl<'db> AssignmentDefinitionKind<'db> {
self.target_kind
}
pub(crate) fn value(&self) -> &ast::Expr {
self.value.node()
pub(crate) fn value<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.value.node(module)
}
pub(crate) fn target(&self) -> &ast::Expr {
self.target.node()
pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.target.node(module)
}
}
@ -900,16 +919,16 @@ pub struct AnnotatedAssignmentDefinitionKind {
}
impl AnnotatedAssignmentDefinitionKind {
pub(crate) fn value(&self) -> Option<&ast::Expr> {
self.value.as_deref()
pub(crate) fn value<'ast>(&self, module: &'ast ParsedModuleRef) -> Option<&'ast ast::Expr> {
self.value.as_ref().map(|value| value.node(module))
}
pub(crate) fn annotation(&self) -> &ast::Expr {
self.annotation.node()
pub(crate) fn annotation<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.annotation.node(module)
}
pub(crate) fn target(&self) -> &ast::Expr {
self.target.node()
pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.target.node(module)
}
}
@ -922,16 +941,16 @@ pub struct WithItemDefinitionKind<'db> {
}
impl<'db> WithItemDefinitionKind<'db> {
pub(crate) fn context_expr(&self) -> &ast::Expr {
self.context_expr.node()
pub(crate) fn context_expr<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.context_expr.node(module)
}
pub(crate) fn target_kind(&self) -> TargetKind<'db> {
self.target_kind
}
pub(crate) fn target(&self) -> &ast::Expr {
self.target.node()
pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.target.node(module)
}
pub(crate) const fn is_async(&self) -> bool {
@ -948,16 +967,16 @@ pub struct ForStmtDefinitionKind<'db> {
}
impl<'db> ForStmtDefinitionKind<'db> {
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
pub(crate) fn iterable<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.iterable.node(module)
}
pub(crate) fn target_kind(&self) -> TargetKind<'db> {
self.target_kind
}
pub(crate) fn target(&self) -> &ast::Expr {
self.target.node()
pub(crate) fn target<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::Expr {
self.target.node(module)
}
pub(crate) const fn is_async(&self) -> bool {
@ -972,12 +991,18 @@ pub struct ExceptHandlerDefinitionKind {
}
impl ExceptHandlerDefinitionKind {
pub(crate) fn node(&self) -> &ast::ExceptHandlerExceptHandler {
self.handler.node()
pub(crate) fn node<'ast>(
&self,
module: &'ast ParsedModuleRef,
) -> &'ast ast::ExceptHandlerExceptHandler {
self.handler.node(module)
}
pub(crate) fn handled_exceptions(&self) -> Option<&ast::Expr> {
self.node().type_.as_deref()
pub(crate) fn handled_exceptions<'ast>(
&self,
module: &'ast ParsedModuleRef,
) -> Option<&'ast ast::Expr> {
self.node(module).type_.as_deref()
}
pub(crate) fn is_star(&self) -> bool {

View file

@ -2,6 +2,7 @@ use crate::ast_node_ref::AstNodeRef;
use crate::db::Db;
use crate::semantic_index::place::{FileScopeId, ScopeId};
use ruff_db::files::File;
use ruff_db::parsed::ParsedModuleRef;
use ruff_python_ast as ast;
use salsa;
@ -41,8 +42,8 @@ pub(crate) struct Expression<'db> {
/// The expression node.
#[no_eq]
#[tracked]
#[returns(deref)]
pub(crate) node_ref: AstNodeRef<ast::Expr>,
#[returns(ref)]
pub(crate) _node_ref: AstNodeRef<ast::Expr>,
/// An assignment statement, if this expression is immediately used as the rhs of that
/// assignment.
@ -62,6 +63,14 @@ pub(crate) struct Expression<'db> {
}
impl<'db> Expression<'db> {
pub(crate) fn node_ref<'ast>(
self,
db: &'db dyn Db,
parsed: &'ast ParsedModuleRef,
) -> &'ast ast::Expr {
self._node_ref(db).node(parsed)
}
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}

View file

@ -5,7 +5,7 @@ use std::ops::Range;
use bitflags::bitflags;
use hashbrown::hash_map::RawEntryMut;
use ruff_db::files::File;
use ruff_db::parsed::ParsedModule;
use ruff_db::parsed::ParsedModuleRef;
use ruff_index::{IndexVec, newtype_index};
use ruff_python_ast as ast;
use ruff_python_ast::name::Name;
@ -381,16 +381,19 @@ impl<'db> ScopeId<'db> {
}
#[cfg(test)]
pub(crate) fn name(self, db: &'db dyn Db) -> &'db str {
pub(crate) fn name<'ast>(self, db: &'db dyn Db, module: &'ast ParsedModuleRef) -> &'ast str {
match self.node(db) {
NodeWithScopeKind::Module => "<module>",
NodeWithScopeKind::Class(class) | NodeWithScopeKind::ClassTypeParameters(class) => {
class.name.as_str()
class.node(module).name.as_str()
}
NodeWithScopeKind::Function(function)
| NodeWithScopeKind::FunctionTypeParameters(function) => function.name.as_str(),
| NodeWithScopeKind::FunctionTypeParameters(function) => {
function.node(module).name.as_str()
}
NodeWithScopeKind::TypeAlias(type_alias)
| NodeWithScopeKind::TypeAliasTypeParameters(type_alias) => type_alias
.node(module)
.name
.as_name_expr()
.map(|name| name.id.as_str())
@ -778,7 +781,7 @@ impl NodeWithScopeRef<'_> {
/// # Safety
/// The node wrapped by `self` must be a child of `module`.
#[expect(unsafe_code)]
pub(super) unsafe fn to_kind(self, module: ParsedModule) -> NodeWithScopeKind {
pub(super) unsafe fn to_kind(self, module: ParsedModuleRef) -> NodeWithScopeKind {
unsafe {
match self {
NodeWithScopeRef::Module => NodeWithScopeKind::Module,
@ -892,34 +895,46 @@ impl NodeWithScopeKind {
}
}
pub fn expect_class(&self) -> &ast::StmtClassDef {
pub fn expect_class<'ast>(&self, module: &'ast ParsedModuleRef) -> &'ast ast::StmtClassDef {
match self {
Self::Class(class) => class.node(),
Self::Class(class) => class.node(module),
_ => panic!("expected class"),
}
}
pub(crate) const fn as_class(&self) -> Option<&ast::StmtClassDef> {
pub(crate) fn as_class<'ast>(
&self,
module: &'ast ParsedModuleRef,
) -> Option<&'ast ast::StmtClassDef> {
match self {
Self::Class(class) => Some(class.node()),
Self::Class(class) => Some(class.node(module)),
_ => None,
}
}
pub fn expect_function(&self) -> &ast::StmtFunctionDef {
self.as_function().expect("expected function")
pub fn expect_function<'ast>(
&self,
module: &'ast ParsedModuleRef,
) -> &'ast ast::StmtFunctionDef {
self.as_function(module).expect("expected function")
}
pub fn expect_type_alias(&self) -> &ast::StmtTypeAlias {
pub fn expect_type_alias<'ast>(
&self,
module: &'ast ParsedModuleRef,
) -> &'ast ast::StmtTypeAlias {
match self {
Self::TypeAlias(type_alias) => type_alias.node(),
Self::TypeAlias(type_alias) => type_alias.node(module),
_ => panic!("expected type alias"),
}
}
pub const fn as_function(&self) -> Option<&ast::StmtFunctionDef> {
pub fn as_function<'ast>(
&self,
module: &'ast ParsedModuleRef,
) -> Option<&'ast ast::StmtFunctionDef> {
match self {
Self::Function(function) => Some(function.node()),
Self::Function(function) => Some(function.node(module)),
_ => None,
}
}

View file

@ -45,7 +45,7 @@ fn exports_cycle_initial(_db: &dyn Db, _file: File) -> Box<[Name]> {
#[salsa::tracked(returns(deref), cycle_fn=exports_cycle_recover, cycle_initial=exports_cycle_initial)]
pub(super) fn exported_names(db: &dyn Db, file: File) -> Box<[Name]> {
let module = parsed_module(db.upcast(), file);
let module = parsed_module(db.upcast(), file).load(db.upcast());
let mut finder = ExportFinder::new(db, file);
finder.visit_body(module.suite());
finder.resolve_exports()

View file

@ -232,7 +232,7 @@ mod tests {
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
let ast = parsed_module(&db, foo);
let ast = parsed_module(&db, foo).load(&db);
let function = ast.suite()[0].as_function_def_stmt().unwrap();
let model = SemanticModel::new(&db, foo);
@ -251,7 +251,7 @@ mod tests {
let foo = system_path_to_file(&db, "/src/foo.py").unwrap();
let ast = parsed_module(&db, foo);
let ast = parsed_module(&db, foo).load(&db);
let class = ast.suite()[0].as_class_def_stmt().unwrap();
let model = SemanticModel::new(&db, foo);
@ -271,7 +271,7 @@ mod tests {
let bar = system_path_to_file(&db, "/src/bar.py").unwrap();
let ast = parsed_module(&db, bar);
let ast = parsed_module(&db, bar).load(&db);
let import = ast.suite()[0].as_import_from_stmt().unwrap();
let alias = &import.names[0];

View file

@ -88,7 +88,7 @@ declare_lint! {
#[salsa::tracked(returns(ref))]
pub(crate) fn suppressions(db: &dyn Db, file: File) -> Suppressions {
let parsed = parsed_module(db.upcast(), file);
let parsed = parsed_module(db.upcast(), file).load(db.upcast());
let source = source_text(db.upcast(), file);
let mut builder = SuppressionsBuilder::new(&source, db.lint_registry());

View file

@ -1,5 +1,6 @@
use infer::nearest_enclosing_class;
use itertools::Either;
use ruff_db::parsed::parsed_module;
use std::slice::Iter;
@ -5065,8 +5066,9 @@ impl<'db> Type<'db> {
SpecialFormType::Callable => Ok(CallableType::unknown(db)),
SpecialFormType::TypingSelf => {
let module = parsed_module(db.upcast(), scope_id.file(db)).load(db.upcast());
let index = semantic_index(db, scope_id.file(db));
let Some(class) = nearest_enclosing_class(db, index, scope_id) else {
let Some(class) = nearest_enclosing_class(db, index, scope_id, &module) else {
return Err(InvalidTypeExpressionError {
fallback_type: Type::unknown(),
invalid_expressions: smallvec::smallvec![
@ -6302,7 +6304,7 @@ impl<'db> ContextManagerError<'db> {
fn report_diagnostic(
&self,
context: &InferContext<'db>,
context: &InferContext<'db, '_>,
context_expression_type: Type<'db>,
context_expression_node: ast::AnyNodeRef,
) {
@ -6475,7 +6477,7 @@ impl<'db> IterationError<'db> {
/// Reports the diagnostic for this error.
fn report_diagnostic(
&self,
context: &InferContext<'db>,
context: &InferContext<'db, '_>,
iterable_type: Type<'db>,
iterable_node: ast::AnyNodeRef,
) {
@ -6951,7 +6953,7 @@ impl<'db> ConstructorCallError<'db> {
fn report_diagnostic(
&self,
context: &InferContext<'db>,
context: &InferContext<'db, '_>,
context_expression_type: Type<'db>,
context_expression_node: ast::AnyNodeRef,
) {
@ -7578,7 +7580,8 @@ pub struct PEP695TypeAliasType<'db> {
impl<'db> PEP695TypeAliasType<'db> {
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
let scope = self.rhs_scope(db);
let type_alias_stmt_node = scope.node(db).expect_type_alias();
let module = parsed_module(db.upcast(), scope.file(db)).load(db.upcast());
let type_alias_stmt_node = scope.node(db).expect_type_alias(&module);
semantic_index(db, scope.file(db)).expect_single_definition(type_alias_stmt_node)
}
@ -7586,7 +7589,8 @@ impl<'db> PEP695TypeAliasType<'db> {
#[salsa::tracked]
pub(crate) fn value_type(self, db: &'db dyn Db) -> Type<'db> {
let scope = self.rhs_scope(db);
let type_alias_stmt_node = scope.node(db).expect_type_alias();
let module = parsed_module(db.upcast(), scope.file(db)).load(db.upcast());
let type_alias_stmt_node = scope.node(db).expect_type_alias(&module);
let definition = self.definition(db);
definition_expression_type(db, definition, &type_alias_stmt_node.value)
}
@ -8654,10 +8658,8 @@ pub(crate) mod tests {
);
let events = db.take_salsa_events();
let call = &*parsed_module(&db, bar).syntax().body[1]
.as_assign_stmt()
.unwrap()
.value;
let module = parsed_module(&db, bar).load(&db);
let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value;
let foo_call = semantic_index(&db, bar).expression(call);
assert_function_query_was_not_run(&db, infer_expression_types, foo_call, &events);

View file

@ -4,6 +4,7 @@
//! union of types, each of which might contain multiple overloads.
use itertools::Itertools;
use ruff_db::parsed::parsed_module;
use smallvec::{SmallVec, smallvec};
use super::{
@ -198,7 +199,11 @@ impl<'db> Bindings<'db> {
/// report a single diagnostic if we couldn't match any union element or overload.
/// TODO: Update this to add subdiagnostics about how we failed to match each union element and
/// overload.
pub(crate) fn report_diagnostics(&self, context: &InferContext<'db>, node: ast::AnyNodeRef) {
pub(crate) fn report_diagnostics(
&self,
context: &InferContext<'db, '_>,
node: ast::AnyNodeRef,
) {
// If all union elements are not callable, report that the union as a whole is not
// callable.
if self.into_iter().all(|b| !b.is_callable()) {
@ -1367,7 +1372,7 @@ impl<'db> CallableBinding<'db> {
fn report_diagnostics(
&self,
context: &InferContext<'db>,
context: &InferContext<'db, '_>,
node: ast::AnyNodeRef,
union_diag: Option<&UnionDiagnostic<'_, '_>>,
) {
@ -1840,7 +1845,7 @@ impl<'db> Binding<'db> {
fn report_diagnostics(
&self,
context: &InferContext<'db>,
context: &InferContext<'db, '_>,
node: ast::AnyNodeRef,
callable_ty: Type<'db>,
callable_description: Option<&CallableDescription>,
@ -2128,7 +2133,7 @@ pub(crate) enum BindingError<'db> {
impl<'db> BindingError<'db> {
fn report_diagnostic(
&self,
context: &InferContext<'db>,
context: &InferContext<'db, '_>,
node: ast::AnyNodeRef,
callable_ty: Type<'db>,
callable_description: Option<&CallableDescription>,
@ -2285,7 +2290,10 @@ impl<'db> BindingError<'db> {
));
if let Some(typevar_definition) = typevar.definition(context.db()) {
let typevar_range = typevar_definition.full_range(context.db());
let module =
parsed_module(context.db().upcast(), typevar_definition.file(context.db()))
.load(context.db().upcast());
let typevar_range = typevar_definition.full_range(context.db(), &module);
let mut sub = SubDiagnostic::new(Severity::Info, "Type variable defined here");
sub.annotate(Annotation::primary(typevar_range.into()));
diag.sub(sub);

View file

@ -37,6 +37,7 @@ use indexmap::IndexSet;
use itertools::Itertools as _;
use ruff_db::diagnostic::Span;
use ruff_db::files::File;
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_python_ast::name::Name;
use ruff_python_ast::{self as ast, PythonVersion};
use ruff_text_size::{Ranged, TextRange};
@ -715,7 +716,8 @@ impl<'db> ClassLiteral<'db> {
#[salsa::tracked(cycle_fn=pep695_generic_context_cycle_recover, cycle_initial=pep695_generic_context_cycle_initial)]
pub(crate) fn pep695_generic_context(self, db: &'db dyn Db) -> Option<GenericContext<'db>> {
let scope = self.body_scope(db);
let class_def_node = scope.node(db).expect_class();
let parsed = parsed_module(db.upcast(), scope.file(db)).load(db.upcast());
let class_def_node = scope.node(db).expect_class(&parsed);
class_def_node.type_params.as_ref().map(|type_params| {
let index = semantic_index(db, scope.file(db));
GenericContext::from_type_params(db, index, type_params)
@ -754,14 +756,16 @@ impl<'db> ClassLiteral<'db> {
/// ## Note
/// Only call this function from queries in the same file or your
/// query depends on the AST of another file (bad!).
fn node(self, db: &'db dyn Db) -> &'db ast::StmtClassDef {
self.body_scope(db).node(db).expect_class()
fn node<'ast>(self, db: &'db dyn Db, module: &'ast ParsedModuleRef) -> &'ast ast::StmtClassDef {
let scope = self.body_scope(db);
scope.node(db).expect_class(module)
}
pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> {
let body_scope = self.body_scope(db);
let module = parsed_module(db.upcast(), body_scope.file(db)).load(db.upcast());
let index = semantic_index(db, body_scope.file(db));
index.expect_single_definition(body_scope.node(db).expect_class())
index.expect_single_definition(body_scope.node(db).expect_class(&module))
}
pub(crate) fn apply_optional_specialization(
@ -835,7 +839,8 @@ impl<'db> ClassLiteral<'db> {
pub(super) fn explicit_bases(self, db: &'db dyn Db) -> Box<[Type<'db>]> {
tracing::trace!("ClassLiteral::explicit_bases_query: {}", self.name(db));
let class_stmt = self.node(db);
let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast());
let class_stmt = self.node(db, &module);
let class_definition =
semantic_index(db, self.file(db)).expect_single_definition(class_stmt);
@ -897,7 +902,9 @@ impl<'db> ClassLiteral<'db> {
fn decorators(self, db: &'db dyn Db) -> Box<[Type<'db>]> {
tracing::trace!("ClassLiteral::decorators: {}", self.name(db));
let class_stmt = self.node(db);
let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast());
let class_stmt = self.node(db, &module);
if class_stmt.decorator_list.is_empty() {
return Box::new([]);
}
@ -983,8 +990,8 @@ impl<'db> ClassLiteral<'db> {
/// ## Note
/// Only call this function from queries in the same file or your
/// query depends on the AST of another file (bad!).
fn explicit_metaclass(self, db: &'db dyn Db) -> Option<Type<'db>> {
let class_stmt = self.node(db);
fn explicit_metaclass(self, db: &'db dyn Db, module: &ParsedModuleRef) -> Option<Type<'db>> {
let class_stmt = self.node(db, module);
let metaclass_node = &class_stmt
.arguments
.as_ref()?
@ -1039,7 +1046,9 @@ impl<'db> ClassLiteral<'db> {
return Ok((SubclassOfType::subclass_of_unknown(), None));
}
let explicit_metaclass = self.explicit_metaclass(db);
let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast());
let explicit_metaclass = self.explicit_metaclass(db, &module);
let (metaclass, class_metaclass_was_from) = if let Some(metaclass) = explicit_metaclass {
(metaclass, self)
} else if let Some(base_class) = base_classes.next() {
@ -1608,6 +1617,7 @@ impl<'db> ClassLiteral<'db> {
let mut is_attribute_bound = Truthiness::AlwaysFalse;
let file = class_body_scope.file(db);
let module = parsed_module(db.upcast(), file).load(db.upcast());
let index = semantic_index(db, file);
let class_map = use_def_map(db, class_body_scope);
let class_table = place_table(db, class_body_scope);
@ -1619,19 +1629,20 @@ impl<'db> ClassLiteral<'db> {
let method_map = use_def_map(db, method_scope);
// The attribute assignment inherits the visibility of the method which contains it
let is_method_visible = if let Some(method_def) = method_scope.node(db).as_function() {
let method = index.expect_single_definition(method_def);
let method_place = class_table.place_id_by_name(&method_def.name).unwrap();
class_map
.public_bindings(method_place)
.find_map(|bind| {
(bind.binding.is_defined_and(|def| def == method))
.then(|| class_map.is_binding_visible(db, &bind))
})
.unwrap_or(Truthiness::AlwaysFalse)
} else {
Truthiness::AlwaysFalse
};
let is_method_visible =
if let Some(method_def) = method_scope.node(db).as_function(&module) {
let method = index.expect_single_definition(method_def);
let method_place = class_table.place_id_by_name(&method_def.name).unwrap();
class_map
.public_bindings(method_place)
.find_map(|bind| {
(bind.binding.is_defined_and(|def| def == method))
.then(|| class_map.is_binding_visible(db, &bind))
})
.unwrap_or(Truthiness::AlwaysFalse)
} else {
Truthiness::AlwaysFalse
};
if is_method_visible.is_always_false() {
continue;
}
@ -1688,8 +1699,10 @@ impl<'db> ClassLiteral<'db> {
// self.name: <annotation>
// self.name: <annotation> = …
let annotation_ty =
infer_expression_type(db, index.expression(ann_assign.annotation()));
let annotation_ty = infer_expression_type(
db,
index.expression(ann_assign.annotation(&module)),
);
// TODO: check if there are conflicting declarations
match is_attribute_bound {
@ -1714,8 +1727,9 @@ impl<'db> ClassLiteral<'db> {
// [.., self.name, ..] = <value>
let unpacked = infer_unpack_types(db, unpack);
let target_ast_id =
assign.target().scoped_expression_id(db, method_scope);
let target_ast_id = assign
.target(&module)
.scoped_expression_id(db, method_scope);
let inferred_ty = unpacked.expression_type(target_ast_id);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
@ -1725,8 +1739,10 @@ impl<'db> ClassLiteral<'db> {
//
// self.name = <value>
let inferred_ty =
infer_expression_type(db, index.expression(assign.value()));
let inferred_ty = infer_expression_type(
db,
index.expression(assign.value(&module)),
);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
}
@ -1740,8 +1756,9 @@ impl<'db> ClassLiteral<'db> {
// for .., self.name, .. in <iterable>:
let unpacked = infer_unpack_types(db, unpack);
let target_ast_id =
for_stmt.target().scoped_expression_id(db, method_scope);
let target_ast_id = for_stmt
.target(&module)
.scoped_expression_id(db, method_scope);
let inferred_ty = unpacked.expression_type(target_ast_id);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
@ -1753,7 +1770,7 @@ impl<'db> ClassLiteral<'db> {
let iterable_ty = infer_expression_type(
db,
index.expression(for_stmt.iterable()),
index.expression(for_stmt.iterable(&module)),
);
// TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty = iterable_ty.iterate(db);
@ -1770,8 +1787,9 @@ impl<'db> ClassLiteral<'db> {
// with <context_manager> as .., self.name, ..:
let unpacked = infer_unpack_types(db, unpack);
let target_ast_id =
with_item.target().scoped_expression_id(db, method_scope);
let target_ast_id = with_item
.target(&module)
.scoped_expression_id(db, method_scope);
let inferred_ty = unpacked.expression_type(target_ast_id);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
@ -1783,7 +1801,7 @@ impl<'db> ClassLiteral<'db> {
let context_ty = infer_expression_type(
db,
index.expression(with_item.context_expr()),
index.expression(with_item.context_expr(&module)),
);
let inferred_ty = context_ty.enter(db);
@ -1800,7 +1818,7 @@ impl<'db> ClassLiteral<'db> {
let unpacked = infer_unpack_types(db, unpack);
let target_ast_id = comprehension
.target()
.target(&module)
.scoped_expression_id(db, unpack.target_scope(db));
let inferred_ty = unpacked.expression_type(target_ast_id);
@ -1813,7 +1831,7 @@ impl<'db> ClassLiteral<'db> {
let iterable_ty = infer_expression_type(
db,
index.expression(comprehension.iterable()),
index.expression(comprehension.iterable(&module)),
);
// TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty = iterable_ty.iterate(db);
@ -2003,8 +2021,8 @@ impl<'db> ClassLiteral<'db> {
/// Returns a [`Span`] with the range of the class's header.
///
/// See [`Self::header_range`] for more details.
pub(super) fn header_span(self, db: &'db dyn Db) -> Span {
Span::from(self.file(db)).with_range(self.header_range(db))
pub(super) fn header_span(self, db: &'db dyn Db, module: &ParsedModuleRef) -> Span {
Span::from(self.file(db)).with_range(self.header_range(db, module))
}
/// Returns the range of the class's "header": the class name
@ -2014,9 +2032,9 @@ impl<'db> ClassLiteral<'db> {
/// class Foo(Bar, metaclass=Baz): ...
/// ^^^^^^^^^^^^^^^^^^^^^^^
/// ```
pub(super) fn header_range(self, db: &'db dyn Db) -> TextRange {
pub(super) fn header_range(self, db: &'db dyn Db, module: &ParsedModuleRef) -> TextRange {
let class_scope = self.body_scope(db);
let class_node = class_scope.node(db).expect_class();
let class_node = class_scope.node(db).expect_class(module);
let class_name = &class_node.name;
TextRange::new(
class_name.start(),

View file

@ -2,6 +2,7 @@ use std::fmt;
use drop_bomb::DebugDropBomb;
use ruff_db::diagnostic::{DiagnosticTag, SubDiagnostic};
use ruff_db::parsed::ParsedModuleRef;
use ruff_db::{
diagnostic::{Annotation, Diagnostic, DiagnosticId, IntoDiagnosticMessage, Severity, Span},
files::File,
@ -32,20 +33,22 @@ use crate::{
/// It's important that the context is explicitly consumed before dropping by calling
/// [`InferContext::finish`] and the returned diagnostics must be stored
/// on the current [`TypeInference`](super::infer::TypeInference) result.
pub(crate) struct InferContext<'db> {
pub(crate) struct InferContext<'db, 'ast> {
db: &'db dyn Db,
scope: ScopeId<'db>,
file: File,
module: &'ast ParsedModuleRef,
diagnostics: std::cell::RefCell<TypeCheckDiagnostics>,
no_type_check: InNoTypeCheck,
bomb: DebugDropBomb,
}
impl<'db> InferContext<'db> {
pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>) -> Self {
impl<'db, 'ast> InferContext<'db, 'ast> {
pub(crate) fn new(db: &'db dyn Db, scope: ScopeId<'db>, module: &'ast ParsedModuleRef) -> Self {
Self {
db,
scope,
module,
file: scope.file(db),
diagnostics: std::cell::RefCell::new(TypeCheckDiagnostics::default()),
no_type_check: InNoTypeCheck::default(),
@ -60,6 +63,11 @@ impl<'db> InferContext<'db> {
self.file
}
/// The module for which the types are inferred.
pub(crate) fn module(&self) -> &'ast ParsedModuleRef {
self.module
}
/// Create a span with the range of the given expression
/// in the file being currently type checked.
///
@ -160,7 +168,7 @@ impl<'db> InferContext<'db> {
// Inspect all ancestor function scopes by walking bottom up and infer the function's type.
let mut function_scope_tys = index
.ancestor_scopes(scope_id)
.filter_map(|(_, scope)| scope.node().as_function())
.filter_map(|(_, scope)| scope.node().as_function(self.module()))
.map(|node| binding_type(self.db, index.expect_single_definition(node)))
.filter_map(Type::into_function_literal);
@ -187,7 +195,7 @@ impl<'db> InferContext<'db> {
}
}
impl fmt::Debug for InferContext<'_> {
impl fmt::Debug for InferContext<'_, '_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("TyContext")
.field("file", &self.file)
@ -221,7 +229,7 @@ pub(crate) enum InNoTypeCheck {
/// will attach a message to the primary span on the diagnostic.
pub(super) struct LintDiagnosticGuard<'db, 'ctx> {
/// The typing context.
ctx: &'ctx InferContext<'db>,
ctx: &'ctx InferContext<'db, 'ctx>,
/// The diagnostic that we want to report.
///
/// This is always `Some` until the `Drop` impl.
@ -363,7 +371,7 @@ impl Drop for LintDiagnosticGuard<'_, '_> {
/// it is known that the diagnostic should not be reported. This can happen
/// when the diagnostic is disabled or suppressed (among other reasons).
pub(super) struct LintDiagnosticGuardBuilder<'db, 'ctx> {
ctx: &'ctx InferContext<'db>,
ctx: &'ctx InferContext<'db, 'ctx>,
id: DiagnosticId,
severity: Severity,
source: LintSource,
@ -372,7 +380,7 @@ pub(super) struct LintDiagnosticGuardBuilder<'db, 'ctx> {
impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> {
fn new(
ctx: &'ctx InferContext<'db>,
ctx: &'ctx InferContext<'db, 'ctx>,
lint: &'static LintMetadata,
range: TextRange,
) -> Option<LintDiagnosticGuardBuilder<'db, 'ctx>> {
@ -462,7 +470,7 @@ impl<'db, 'ctx> LintDiagnosticGuardBuilder<'db, 'ctx> {
/// if either is violated, then the `Drop` impl on `DiagnosticGuard` will
/// panic.
pub(super) struct DiagnosticGuard<'db, 'ctx> {
ctx: &'ctx InferContext<'db>,
ctx: &'ctx InferContext<'db, 'ctx>,
/// The diagnostic that we want to report.
///
/// This is always `Some` until the `Drop` impl.
@ -550,14 +558,14 @@ impl Drop for DiagnosticGuard<'_, '_> {
/// minimal amount of information with which to construct a diagnostic) before
/// one can mutate the diagnostic.
pub(super) struct DiagnosticGuardBuilder<'db, 'ctx> {
ctx: &'ctx InferContext<'db>,
ctx: &'ctx InferContext<'db, 'ctx>,
id: DiagnosticId,
severity: Severity,
}
impl<'db, 'ctx> DiagnosticGuardBuilder<'db, 'ctx> {
fn new(
ctx: &'ctx InferContext<'db>,
ctx: &'ctx InferContext<'db, 'ctx>,
id: DiagnosticId,
severity: Severity,
) -> Option<DiagnosticGuardBuilder<'db, 'ctx>> {

View file

@ -1,6 +1,7 @@
use crate::semantic_index::definition::Definition;
use crate::{Db, Module};
use ruff_db::files::FileRange;
use ruff_db::parsed::parsed_module;
use ruff_db::source::source_text;
use ruff_text_size::{TextLen, TextRange};
@ -20,7 +21,10 @@ impl TypeDefinition<'_> {
Self::Class(definition)
| Self::Function(definition)
| Self::TypeVar(definition)
| Self::TypeAlias(definition) => Some(definition.focus_range(db)),
| Self::TypeAlias(definition) => {
let module = parsed_module(db.upcast(), definition.file(db)).load(db.upcast());
Some(definition.focus_range(db, &module))
}
}
}
@ -34,7 +38,10 @@ impl TypeDefinition<'_> {
Self::Class(definition)
| Self::Function(definition)
| Self::TypeVar(definition)
| Self::TypeAlias(definition) => Some(definition.full_range(db)),
| Self::TypeAlias(definition) => {
let module = parsed_module(db.upcast(), definition.file(db)).load(db.upcast());
Some(definition.full_range(db, &module))
}
}
}
}

View file

@ -1730,7 +1730,7 @@ pub(super) fn report_implicit_return_type(
or `typing_extensions.Protocol` are considered protocol classes",
);
sub_diagnostic.annotate(
Annotation::primary(class.header_span(db)).message(format_args!(
Annotation::primary(class.header_span(db, context.module())).message(format_args!(
"`Protocol` not present in `{class}`'s immediate bases",
class = class.name(db)
)),
@ -1850,7 +1850,7 @@ pub(crate) fn report_bad_argument_to_get_protocol_members(
class.name(db)
),
);
class_def_diagnostic.annotate(Annotation::primary(class.header_span(db)));
class_def_diagnostic.annotate(Annotation::primary(class.header_span(db, context.module())));
diagnostic.sub(class_def_diagnostic);
diagnostic.info(
@ -1910,7 +1910,7 @@ pub(crate) fn report_runtime_check_against_non_runtime_checkable_protocol(
),
);
class_def_diagnostic.annotate(
Annotation::primary(protocol.header_span(db))
Annotation::primary(protocol.header_span(db, context.module()))
.message(format_args!("`{class_name}` declared here")),
);
diagnostic.sub(class_def_diagnostic);
@ -1941,7 +1941,7 @@ pub(crate) fn report_attempted_protocol_instantiation(
format_args!("Protocol classes cannot be instantiated"),
);
class_def_diagnostic.annotate(
Annotation::primary(protocol.header_span(db))
Annotation::primary(protocol.header_span(db, context.module()))
.message(format_args!("`{class_name}` declared as a protocol here")),
);
diagnostic.sub(class_def_diagnostic);
@ -1955,7 +1955,9 @@ pub(crate) fn report_duplicate_bases(
) {
let db = context.db();
let Some(builder) = context.report_lint(&DUPLICATE_BASE, class.header_range(db)) else {
let Some(builder) =
context.report_lint(&DUPLICATE_BASE, class.header_range(db, context.module()))
else {
return;
};
@ -2104,7 +2106,7 @@ fn report_unsupported_base(
}
fn report_invalid_base<'ctx, 'db>(
context: &'ctx InferContext<'db>,
context: &'ctx InferContext<'db, '_>,
base_node: &ast::Expr,
base_type: Type<'db>,
class: ClassLiteral<'db>,

View file

@ -54,6 +54,7 @@ use std::str::FromStr;
use bitflags::bitflags;
use ruff_db::diagnostic::Span;
use ruff_db::files::{File, FileRange};
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_python_ast as ast;
use ruff_text_size::Ranged;
@ -187,7 +188,12 @@ impl<'db> OverloadLiteral<'db> {
self.has_known_decorator(db, FunctionDecorators::OVERLOAD)
}
fn node(self, db: &'db dyn Db, file: File) -> &'db ast::StmtFunctionDef {
fn node<'ast>(
self,
db: &dyn Db,
file: File,
module: &'ast ParsedModuleRef,
) -> &'ast ast::StmtFunctionDef {
debug_assert_eq!(
file,
self.file(db),
@ -195,14 +201,18 @@ impl<'db> OverloadLiteral<'db> {
the function is defined."
);
self.body_scope(db).node(db).expect_function()
self.body_scope(db).node(db).expect_function(module)
}
/// Returns the [`FileRange`] of the function's name.
pub(crate) fn focus_range(self, db: &dyn Db) -> FileRange {
pub(crate) fn focus_range(self, db: &dyn Db, module: &ParsedModuleRef) -> FileRange {
FileRange::new(
self.file(db),
self.body_scope(db).node(db).expect_function().name.range,
self.body_scope(db)
.node(db)
.expect_function(module)
.name
.range,
)
}
@ -216,8 +226,9 @@ impl<'db> OverloadLiteral<'db> {
/// over-invalidation.
fn definition(self, db: &'db dyn Db) -> Definition<'db> {
let body_scope = self.body_scope(db);
let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast());
let index = semantic_index(db, body_scope.file(db));
index.expect_single_definition(body_scope.node(db).expect_function())
index.expect_single_definition(body_scope.node(db).expect_function(&module))
}
/// Returns the overload immediately before this one in the AST. Returns `None` if there is no
@ -226,11 +237,12 @@ impl<'db> OverloadLiteral<'db> {
// The semantic model records a use for each function on the name node. This is used
// here to get the previous function definition with the same name.
let scope = self.definition(db).scope(db);
let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast());
let use_def = semantic_index(db, scope.file(db)).use_def_map(scope.file_scope_id(db));
let use_id = self
.body_scope(db)
.node(db)
.expect_function()
.expect_function(&module)
.name
.scoped_use_id(db, scope);
@ -266,7 +278,8 @@ impl<'db> OverloadLiteral<'db> {
inherited_generic_context: Option<GenericContext<'db>>,
) -> Signature<'db> {
let scope = self.body_scope(db);
let function_stmt_node = scope.node(db).expect_function();
let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast());
let function_stmt_node = scope.node(db).expect_function(&module);
let definition = self.definition(db);
let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| {
let index = semantic_index(db, scope.file(db));
@ -289,7 +302,8 @@ impl<'db> OverloadLiteral<'db> {
let function_scope = self.body_scope(db);
let span = Span::from(function_scope.file(db));
let node = function_scope.node(db);
let func_def = node.as_function()?;
let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast());
let func_def = node.as_function(&module)?;
let range = parameter_index
.and_then(|parameter_index| {
func_def
@ -308,7 +322,8 @@ impl<'db> OverloadLiteral<'db> {
let function_scope = self.body_scope(db);
let span = Span::from(function_scope.file(db));
let node = function_scope.node(db);
let func_def = node.as_function()?;
let module = parsed_module(db.upcast(), self.file(db)).load(db.upcast());
let func_def = node.as_function(&module)?;
let return_type_range = func_def.returns.as_ref().map(|returns| returns.range());
let mut signature = func_def.name.range.cover(func_def.parameters.range);
if let Some(return_type_range) = return_type_range {
@ -553,8 +568,13 @@ impl<'db> FunctionType<'db> {
}
/// Returns the AST node for this function.
pub(crate) fn node(self, db: &'db dyn Db, file: File) -> &'db ast::StmtFunctionDef {
self.literal(db).last_definition(db).node(db, file)
pub(crate) fn node<'ast>(
self,
db: &dyn Db,
file: File,
module: &'ast ParsedModuleRef,
) -> &'ast ast::StmtFunctionDef {
self.literal(db).last_definition(db).node(db, file, module)
}
pub(crate) fn name(self, db: &'db dyn Db) -> &'db ast::name::Name {

View file

@ -36,7 +36,7 @@
use itertools::{Either, Itertools};
use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity};
use ruff_db::files::File;
use ruff_db::parsed::parsed_module;
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_python_ast::visitor::{Visitor, walk_expr};
use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, PythonVersion};
use ruff_python_stdlib::builtins::version_builtin_was_added;
@ -136,11 +136,13 @@ pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Ty
let file = scope.file(db);
let _span = tracing::trace_span!("infer_scope_types", scope=?scope.as_id(), ?file).entered();
let module = parsed_module(db.upcast(), file).load(db.upcast());
// Using the index here is fine because the code below depends on the AST anyway.
// The isolation of the query is by the return inferred types.
let index = semantic_index(db, file);
TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index).finish()
TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index, &module).finish()
}
fn scope_cycle_recover<'db>(
@ -164,16 +166,17 @@ pub(crate) fn infer_definition_types<'db>(
definition: Definition<'db>,
) -> TypeInference<'db> {
let file = definition.file(db);
let module = parsed_module(db.upcast(), file).load(db.upcast());
let _span = tracing::trace_span!(
"infer_definition_types",
range = ?definition.kind(db).target_range(),
range = ?definition.kind(db).target_range(&module),
?file
)
.entered();
let index = semantic_index(db, file);
TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index).finish()
TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index, &module).finish()
}
fn definition_cycle_recover<'db>(
@ -202,17 +205,18 @@ pub(crate) fn infer_deferred_types<'db>(
definition: Definition<'db>,
) -> TypeInference<'db> {
let file = definition.file(db);
let module = parsed_module(db.upcast(), file).load(db.upcast());
let _span = tracing::trace_span!(
"infer_deferred_types",
definition = ?definition.as_id(),
range = ?definition.kind(db).target_range(),
range = ?definition.kind(db).target_range(&module),
?file
)
.entered();
let index = semantic_index(db, file);
TypeInferenceBuilder::new(db, InferenceRegion::Deferred(definition), index).finish()
TypeInferenceBuilder::new(db, InferenceRegion::Deferred(definition), index, &module).finish()
}
fn deferred_cycle_recover<'db>(
@ -238,17 +242,18 @@ pub(crate) fn infer_expression_types<'db>(
expression: Expression<'db>,
) -> TypeInference<'db> {
let file = expression.file(db);
let module = parsed_module(db.upcast(), file).load(db.upcast());
let _span = tracing::trace_span!(
"infer_expression_types",
expression = ?expression.as_id(),
range = ?expression.node_ref(db).range(),
range = ?expression.node_ref(db, &module).range(),
?file
)
.entered();
let index = semantic_index(db, file);
TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index).finish()
TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index, &module).finish()
}
fn expression_cycle_recover<'db>(
@ -275,10 +280,15 @@ fn expression_cycle_initial<'db>(
pub(super) fn infer_same_file_expression_type<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
parsed: &ParsedModuleRef,
) -> Type<'db> {
let inference = infer_expression_types(db, expression);
let scope = expression.scope(db);
inference.expression_type(expression.node_ref(db).scoped_expression_id(db, scope))
inference.expression_type(
expression
.node_ref(db, parsed)
.scoped_expression_id(db, scope),
)
}
/// Infers the type of an expression where the expression might come from another file.
@ -293,8 +303,11 @@ pub(crate) fn infer_expression_type<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Type<'db> {
let file = expression.file(db);
let module = parsed_module(db.upcast(), file).load(db.upcast());
// It's okay to call the "same file" version here because we're inside a salsa query.
infer_same_file_expression_type(db, expression)
infer_same_file_expression_type(db, expression, &module)
}
fn single_expression_cycle_recover<'db>(
@ -322,11 +335,12 @@ fn single_expression_cycle_initial<'db>(
#[salsa::tracked(returns(ref), cycle_fn=unpack_cycle_recover, cycle_initial=unpack_cycle_initial)]
pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> {
let file = unpack.file(db);
let _span =
tracing::trace_span!("infer_unpack_types", range=?unpack.range(db), ?file).entered();
let module = parsed_module(db.upcast(), file).load(db.upcast());
let _span = tracing::trace_span!("infer_unpack_types", range=?unpack.range(db, &module), ?file)
.entered();
let mut unpacker = Unpacker::new(db, unpack.target_scope(db), unpack.value_scope(db));
unpacker.unpack(unpack.target(db), unpack.value(db));
let mut unpacker = Unpacker::new(db, unpack.target_scope(db), unpack.value_scope(db), &module);
unpacker.unpack(unpack.target(db, &module), unpack.value(db));
unpacker.finish()
}
@ -356,11 +370,12 @@ pub(crate) fn nearest_enclosing_class<'db>(
db: &'db dyn Db,
semantic: &SemanticIndex<'db>,
scope: ScopeId,
parsed: &ParsedModuleRef,
) -> Option<ClassLiteral<'db>> {
semantic
.ancestor_scopes(scope.file_scope_id(db))
.find_map(|(_, ancestor_scope)| {
let class = ancestor_scope.node().as_class()?;
let class = ancestor_scope.node().as_class(parsed)?;
let definition = semantic.expect_single_definition(class);
infer_definition_types(db, definition)
.declaration_type(definition)
@ -569,8 +584,8 @@ enum DeclaredAndInferredType<'db> {
/// Similarly, when we encounter a standalone-inferable expression (right-hand side of an
/// assignment, type narrowing guard), we use the [`infer_expression_types()`] query to ensure we
/// don't infer its types more than once.
pub(super) struct TypeInferenceBuilder<'db> {
context: InferContext<'db>,
pub(super) struct TypeInferenceBuilder<'db, 'ast> {
context: InferContext<'db, 'ast>,
index: &'db SemanticIndex<'db>,
region: InferenceRegion<'db>,
@ -617,7 +632,7 @@ pub(super) struct TypeInferenceBuilder<'db> {
deferred_state: DeferredExpressionState,
}
impl<'db> TypeInferenceBuilder<'db> {
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// How big a string do we build before bailing?
///
/// This is a fairly arbitrary number. It should be *far* more than enough
@ -629,11 +644,12 @@ impl<'db> TypeInferenceBuilder<'db> {
db: &'db dyn Db,
region: InferenceRegion<'db>,
index: &'db SemanticIndex<'db>,
module: &'ast ParsedModuleRef,
) -> Self {
let scope = region.scope(db);
Self {
context: InferContext::new(db, scope),
context: InferContext::new(db, scope, module),
index,
region,
return_types_and_ranges: vec![],
@ -659,6 +675,10 @@ impl<'db> TypeInferenceBuilder<'db> {
self.context.file()
}
fn module(&self) -> &'ast ParsedModuleRef {
self.context.module()
}
fn db(&self) -> &'db dyn Db {
self.context.db()
}
@ -756,35 +776,36 @@ impl<'db> TypeInferenceBuilder<'db> {
let node = scope.node(self.db());
match node {
NodeWithScopeKind::Module => {
let parsed = parsed_module(self.db().upcast(), self.file());
self.infer_module(parsed.syntax());
self.infer_module(self.module().syntax());
}
NodeWithScopeKind::Function(function) => self.infer_function_body(function.node()),
NodeWithScopeKind::Lambda(lambda) => self.infer_lambda_body(lambda.node()),
NodeWithScopeKind::Class(class) => self.infer_class_body(class.node()),
NodeWithScopeKind::Function(function) => {
self.infer_function_body(function.node(self.module()));
}
NodeWithScopeKind::Lambda(lambda) => self.infer_lambda_body(lambda.node(self.module())),
NodeWithScopeKind::Class(class) => self.infer_class_body(class.node(self.module())),
NodeWithScopeKind::ClassTypeParameters(class) => {
self.infer_class_type_params(class.node());
self.infer_class_type_params(class.node(self.module()));
}
NodeWithScopeKind::FunctionTypeParameters(function) => {
self.infer_function_type_params(function.node());
self.infer_function_type_params(function.node(self.module()));
}
NodeWithScopeKind::TypeAliasTypeParameters(type_alias) => {
self.infer_type_alias_type_params(type_alias.node());
self.infer_type_alias_type_params(type_alias.node(self.module()));
}
NodeWithScopeKind::TypeAlias(type_alias) => {
self.infer_type_alias(type_alias.node());
self.infer_type_alias(type_alias.node(self.module()));
}
NodeWithScopeKind::ListComprehension(comprehension) => {
self.infer_list_comprehension_expression_scope(comprehension.node());
self.infer_list_comprehension_expression_scope(comprehension.node(self.module()));
}
NodeWithScopeKind::SetComprehension(comprehension) => {
self.infer_set_comprehension_expression_scope(comprehension.node());
self.infer_set_comprehension_expression_scope(comprehension.node(self.module()));
}
NodeWithScopeKind::DictComprehension(comprehension) => {
self.infer_dict_comprehension_expression_scope(comprehension.node());
self.infer_dict_comprehension_expression_scope(comprehension.node(self.module()));
}
NodeWithScopeKind::GeneratorExpression(generator) => {
self.infer_generator_expression_scope(generator.node());
self.infer_generator_expression_scope(generator.node(self.module()));
}
}
@ -823,7 +844,7 @@ impl<'db> TypeInferenceBuilder<'db> {
if let DefinitionKind::Class(class) = definition.kind(self.db()) {
ty.inner_type()
.into_class_literal()
.map(|class_literal| (class_literal, class.node()))
.map(|class_literal| (class_literal, class.node(self.module())))
} else {
None
}
@ -1143,7 +1164,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// Check that the overloaded function has at least two overloads
if let [single_overload] = overloads.as_ref() {
let function_node = function.node(self.db(), self.file());
let function_node = function.node(self.db(), self.file(), self.module());
if let Some(builder) = self
.context
.report_lint(&INVALID_OVERLOAD, &function_node.name)
@ -1154,7 +1175,7 @@ impl<'db> TypeInferenceBuilder<'db> {
));
diagnostic.annotate(
self.context
.secondary(single_overload.focus_range(self.db()))
.secondary(single_overload.focus_range(self.db(), self.module()))
.message(format_args!("Only one overload defined here")),
);
}
@ -1169,7 +1190,8 @@ impl<'db> TypeInferenceBuilder<'db> {
if let NodeWithScopeKind::Class(class_node_ref) = scope {
let class = binding_type(
self.db(),
self.index.expect_single_definition(class_node_ref.node()),
self.index
.expect_single_definition(class_node_ref.node(self.module())),
)
.expect_class_literal();
@ -1187,7 +1209,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
if implementation_required {
let function_node = function.node(self.db(), self.file());
let function_node = function.node(self.db(), self.file(), self.module());
if let Some(builder) = self
.context
.report_lint(&INVALID_OVERLOAD, &function_node.name)
@ -1222,7 +1244,7 @@ impl<'db> TypeInferenceBuilder<'db> {
continue;
}
let function_node = function.node(self.db(), self.file());
let function_node = function.node(self.db(), self.file(), self.module());
if let Some(builder) = self
.context
.report_lint(&INVALID_OVERLOAD, &function_node.name)
@ -1235,7 +1257,7 @@ impl<'db> TypeInferenceBuilder<'db> {
for function in decorator_missing {
diagnostic.annotate(
self.context
.secondary(function.focus_range(self.db()))
.secondary(function.focus_range(self.db(), self.module()))
.message(format_args!("Missing here")),
);
}
@ -1251,7 +1273,7 @@ impl<'db> TypeInferenceBuilder<'db> {
if !overload.has_known_decorator(self.db(), decorator) {
continue;
}
let function_node = function.node(self.db(), self.file());
let function_node = function.node(self.db(), self.file(), self.module());
let Some(builder) = self
.context
.report_lint(&INVALID_OVERLOAD, &function_node.name)
@ -1264,7 +1286,7 @@ impl<'db> TypeInferenceBuilder<'db> {
));
diagnostic.annotate(
self.context
.secondary(implementation.focus_range(self.db()))
.secondary(implementation.focus_range(self.db(), self.module()))
.message(format_args!("Implementation defined here")),
);
}
@ -1277,7 +1299,7 @@ impl<'db> TypeInferenceBuilder<'db> {
if !overload.has_known_decorator(self.db(), decorator) {
continue;
}
let function_node = function.node(self.db(), self.file());
let function_node = function.node(self.db(), self.file(), self.module());
let Some(builder) = self
.context
.report_lint(&INVALID_OVERLOAD, &function_node.name)
@ -1290,7 +1312,7 @@ impl<'db> TypeInferenceBuilder<'db> {
));
diagnostic.annotate(
self.context
.secondary(first_overload.focus_range(self.db()))
.secondary(first_overload.focus_range(self.db(), self.module()))
.message(format_args!("First overload defined here")),
);
}
@ -1302,24 +1324,34 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_region_definition(&mut self, definition: Definition<'db>) {
match definition.kind(self.db()) {
DefinitionKind::Function(function) => {
self.infer_function_definition(function.node(), definition);
self.infer_function_definition(function.node(self.module()), definition);
}
DefinitionKind::Class(class) => {
self.infer_class_definition(class.node(self.module()), definition);
}
DefinitionKind::Class(class) => self.infer_class_definition(class.node(), definition),
DefinitionKind::TypeAlias(type_alias) => {
self.infer_type_alias_definition(type_alias.node(), definition);
self.infer_type_alias_definition(type_alias.node(self.module()), definition);
}
DefinitionKind::Import(import) => {
self.infer_import_definition(import.import(), import.alias(), definition);
self.infer_import_definition(
import.import(self.module()),
import.alias(self.module()),
definition,
);
}
DefinitionKind::ImportFrom(import_from) => {
self.infer_import_from_definition(
import_from.import(),
import_from.alias(),
import_from.import(self.module()),
import_from.alias(self.module()),
definition,
);
}
DefinitionKind::StarImport(import) => {
self.infer_import_from_definition(import.import(), import.alias(), definition);
self.infer_import_from_definition(
import.import(self.module()),
import.alias(self.module()),
definition,
);
}
DefinitionKind::Assignment(assignment) => {
self.infer_assignment_definition(assignment, definition);
@ -1328,32 +1360,47 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_annotated_assignment_definition(annotated_assignment, definition);
}
DefinitionKind::AugmentedAssignment(augmented_assignment) => {
self.infer_augment_assignment_definition(augmented_assignment.node(), definition);
self.infer_augment_assignment_definition(
augmented_assignment.node(self.module()),
definition,
);
}
DefinitionKind::For(for_statement_definition) => {
self.infer_for_statement_definition(for_statement_definition, definition);
}
DefinitionKind::NamedExpression(named_expression) => {
self.infer_named_expression_definition(named_expression.node(), definition);
self.infer_named_expression_definition(
named_expression.node(self.module()),
definition,
);
}
DefinitionKind::Comprehension(comprehension) => {
self.infer_comprehension_definition(comprehension, definition);
}
DefinitionKind::VariadicPositionalParameter(parameter) => {
self.infer_variadic_positional_parameter_definition(parameter, definition);
self.infer_variadic_positional_parameter_definition(
parameter.node(self.module()),
definition,
);
}
DefinitionKind::VariadicKeywordParameter(parameter) => {
self.infer_variadic_keyword_parameter_definition(parameter, definition);
self.infer_variadic_keyword_parameter_definition(
parameter.node(self.module()),
definition,
);
}
DefinitionKind::Parameter(parameter_with_default) => {
self.infer_parameter_definition(parameter_with_default, definition);
self.infer_parameter_definition(
parameter_with_default.node(self.module()),
definition,
);
}
DefinitionKind::WithItem(with_item_definition) => {
self.infer_with_item_definition(with_item_definition, definition);
}
DefinitionKind::MatchPattern(match_pattern) => {
self.infer_match_pattern_definition(
match_pattern.pattern(),
match_pattern.pattern(self.module()),
match_pattern.index(),
definition,
);
@ -1362,13 +1409,13 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_except_handler_definition(except_handler_definition, definition);
}
DefinitionKind::TypeVar(node) => {
self.infer_typevar_definition(node, definition);
self.infer_typevar_definition(node.node(self.module()), definition);
}
DefinitionKind::ParamSpec(node) => {
self.infer_paramspec_definition(node, definition);
self.infer_paramspec_definition(node.node(self.module()), definition);
}
DefinitionKind::TypeVarTuple(node) => {
self.infer_typevartuple_definition(node, definition);
self.infer_typevartuple_definition(node.node(self.module()), definition);
}
}
}
@ -1384,8 +1431,10 @@ impl<'db> TypeInferenceBuilder<'db> {
// implementation to allow this "split" to happen.
match definition.kind(self.db()) {
DefinitionKind::Function(function) => self.infer_function_deferred(function.node()),
DefinitionKind::Class(class) => self.infer_class_deferred(class.node()),
DefinitionKind::Function(function) => {
self.infer_function_deferred(function.node(self.module()));
}
DefinitionKind::Class(class) => self.infer_class_deferred(class.node(self.module())),
_ => {}
}
}
@ -1393,10 +1442,10 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_region_expression(&mut self, expression: Expression<'db>) {
match expression.kind(self.db()) {
ExpressionKind::Normal => {
self.infer_expression_impl(expression.node_ref(self.db()));
self.infer_expression_impl(expression.node_ref(self.db(), self.module()));
}
ExpressionKind::TypeExpression => {
self.infer_type_expression(expression.node_ref(self.db()));
self.infer_type_expression(expression.node_ref(self.db(), self.module()));
}
}
}
@ -1441,7 +1490,7 @@ impl<'db> TypeInferenceBuilder<'db> {
debug_assert!(
binding
.kind(self.db())
.category(self.context.in_stub())
.category(self.context.in_stub(), self.module())
.is_binding()
);
@ -1555,7 +1604,7 @@ impl<'db> TypeInferenceBuilder<'db> {
debug_assert!(
declaration
.kind(self.db())
.category(self.context.in_stub())
.category(self.context.in_stub(), self.module())
.is_declaration()
);
let use_def = self.index.use_def_map(declaration.file_scope(self.db()));
@ -1601,13 +1650,13 @@ impl<'db> TypeInferenceBuilder<'db> {
debug_assert!(
definition
.kind(self.db())
.category(self.context.in_stub())
.category(self.context.in_stub(), self.module())
.is_binding()
);
debug_assert!(
definition
.kind(self.db())
.category(self.context.in_stub())
.category(self.context.in_stub(), self.module())
.is_declaration()
);
@ -1763,7 +1812,7 @@ impl<'db> TypeInferenceBuilder<'db> {
_ => return None,
};
let class_stmt = class_scope.node().as_class()?;
let class_stmt = class_scope.node().as_class(self.module())?;
let class_definition = self.index.expect_single_definition(class_stmt);
binding_type(self.db(), class_definition).into_class_literal()
}
@ -1784,17 +1833,21 @@ impl<'db> TypeInferenceBuilder<'db> {
return false;
};
node_ref.decorator_list.iter().any(|decorator| {
let decorator_type = self.file_expression_type(&decorator.expression);
node_ref
.node(self.module())
.decorator_list
.iter()
.any(|decorator| {
let decorator_type = self.file_expression_type(&decorator.expression);
match decorator_type {
Type::FunctionLiteral(function) => matches!(
function.known(self.db()),
Some(KnownFunction::Overload | KnownFunction::AbstractMethod)
),
_ => false,
}
})
match decorator_type {
Type::FunctionLiteral(function) => matches!(
function.known(self.db()),
Some(KnownFunction::Overload | KnownFunction::AbstractMethod)
),
_ => false,
}
})
}
fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) {
@ -2558,8 +2611,8 @@ impl<'db> TypeInferenceBuilder<'db> {
with_item: &WithItemDefinitionKind<'db>,
definition: Definition<'db>,
) {
let context_expr = with_item.context_expr();
let target = with_item.target();
let context_expr = with_item.context_expr(self.module());
let target = with_item.target(self.module());
let context_expr_ty = self.infer_standalone_expression(context_expr);
@ -2707,12 +2760,12 @@ impl<'db> TypeInferenceBuilder<'db> {
definition: Definition<'db>,
) {
let symbol_ty = self.infer_exception(
except_handler_definition.handled_exceptions(),
except_handler_definition.handled_exceptions(self.module()),
except_handler_definition.is_star(),
);
self.add_binding(
except_handler_definition.node().into(),
except_handler_definition.node(self.module()).into(),
definition,
symbol_ty,
);
@ -3001,7 +3054,7 @@ impl<'db> TypeInferenceBuilder<'db> {
/// `target`.
fn infer_target<F>(&mut self, target: &ast::Expr, value: &ast::Expr, infer_value_expr: F)
where
F: Fn(&mut TypeInferenceBuilder<'db>, &ast::Expr) -> Type<'db>,
F: Fn(&mut TypeInferenceBuilder<'db, '_>, &ast::Expr) -> Type<'db>,
{
let assigned_ty = match target {
ast::Expr::Name(_) => None,
@ -3542,8 +3595,8 @@ impl<'db> TypeInferenceBuilder<'db> {
assignment: &AssignmentDefinitionKind<'db>,
definition: Definition<'db>,
) {
let value = assignment.value();
let target = assignment.target();
let value = assignment.value(self.module());
let target = assignment.target(self.module());
let value_ty = self.infer_standalone_expression(value);
@ -3625,9 +3678,9 @@ impl<'db> TypeInferenceBuilder<'db> {
assignment: &'db AnnotatedAssignmentDefinitionKind,
definition: Definition<'db>,
) {
let annotation = assignment.annotation();
let target = assignment.target();
let value = assignment.value();
let annotation = assignment.annotation(self.module());
let target = assignment.target(self.module());
let value = assignment.value(self.module());
let mut declared_ty = self.infer_annotation_expression(
annotation,
@ -3858,8 +3911,8 @@ impl<'db> TypeInferenceBuilder<'db> {
for_stmt: &ForStmtDefinitionKind<'db>,
definition: Definition<'db>,
) {
let iterable = for_stmt.iterable();
let target = for_stmt.target();
let iterable = for_stmt.iterable(self.module());
let target = for_stmt.target(self.module());
let iterable_type = self.infer_standalone_expression(iterable);
@ -3967,7 +4020,7 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_import_definition(
&mut self,
node: &ast::StmtImport,
alias: &'db ast::Alias,
alias: &ast::Alias,
definition: Definition<'db>,
) {
let ast::Alias {
@ -4146,7 +4199,7 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_import_from_definition(
&mut self,
import_from: &'db ast::StmtImportFrom,
import_from: &ast::StmtImportFrom,
alias: &ast::Alias,
definition: Definition<'db>,
) {
@ -4804,7 +4857,11 @@ impl<'db> TypeInferenceBuilder<'db> {
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `[... for a.x in not_iterable]
if is_first {
infer_same_file_expression_type(builder.db(), builder.index.expression(iter_expr))
infer_same_file_expression_type(
builder.db(),
builder.index.expression(iter_expr),
builder.module(),
)
} else {
builder.infer_standalone_expression(iter_expr)
}
@ -4820,8 +4877,8 @@ impl<'db> TypeInferenceBuilder<'db> {
comprehension: &ComprehensionDefinitionKind<'db>,
definition: Definition<'db>,
) {
let iterable = comprehension.iterable();
let target = comprehension.target();
let iterable = comprehension.iterable(self.module());
let target = comprehension.target(self.module());
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db(), expression);
@ -5009,8 +5066,10 @@ impl<'db> TypeInferenceBuilder<'db> {
/// Returns `None` if the scope is not function-like, or has no parameters.
fn first_param_type_in_scope(&self, scope: ScopeId) -> Option<Type<'db>> {
let first_param = match scope.node(self.db()) {
NodeWithScopeKind::Function(f) => f.parameters.iter().next(),
NodeWithScopeKind::Lambda(l) => l.parameters.as_ref()?.iter().next(),
NodeWithScopeKind::Function(f) => f.node(self.module()).parameters.iter().next(),
NodeWithScopeKind::Lambda(l) => {
l.node(self.module()).parameters.as_ref()?.iter().next()
}
_ => None,
}?;
@ -5371,6 +5430,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.db(),
self.index,
scope,
self.module(),
) else {
overload.set_return_type(Type::unknown());
BoundSuperError::UnavailableImplicitArguments
@ -5435,7 +5495,11 @@ impl<'db> TypeInferenceBuilder<'db> {
let Some(target) =
assigned_to.as_ref().and_then(|assigned_to| {
match assigned_to.node().targets.as_slice() {
match assigned_to
.node(self.module())
.targets
.as_slice()
{
[ast::Expr::Name(target)] => Some(target),
_ => None,
}
@ -5605,7 +5669,11 @@ impl<'db> TypeInferenceBuilder<'db> {
let containing_assignment =
assigned_to.as_ref().and_then(|assigned_to| {
match assigned_to.node().targets.as_slice() {
match assigned_to
.node(self.module())
.targets
.as_slice()
{
[ast::Expr::Name(target)] => Some(
self.index.expect_single_definition(target),
),
@ -8125,7 +8193,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
/// Annotation expressions.
impl<'db> TypeInferenceBuilder<'db> {
impl<'db> TypeInferenceBuilder<'db, '_> {
/// Infer the type of an annotation expression with the given [`DeferredExpressionState`].
fn infer_annotation_expression(
&mut self,
@ -8314,7 +8382,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
/// Type expressions
impl<'db> TypeInferenceBuilder<'db> {
impl<'db> TypeInferenceBuilder<'db, '_> {
/// Infer the type of a type expression.
fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
let ty = self.infer_type_expression_no_store(expression);
@ -9340,10 +9408,10 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
fn infer_literal_parameter_type<'ast>(
fn infer_literal_parameter_type<'param>(
&mut self,
parameters: &'ast ast::Expr,
) -> Result<Type<'db>, Vec<&'ast ast::Expr>> {
parameters: &'param ast::Expr,
) -> Result<Type<'db>, Vec<&'param ast::Expr>> {
Ok(match parameters {
// TODO handle type aliases
ast::Expr::Subscript(ast::ExprSubscript { value, slice, .. }) => {
@ -9723,6 +9791,7 @@ mod tests {
symbol_name: &str,
) -> Place<'db> {
let file = system_path_to_file(db, file_name).expect("file to exist");
let module = parsed_module(db, file).load(db);
let index = semantic_index(db, file);
let mut file_scope_id = FileScopeId::global();
let mut scope = file_scope_id.to_scope_id(db, file);
@ -9733,7 +9802,7 @@ mod tests {
.unwrap_or_else(|| panic!("scope of {expected_scope_name}"))
.0;
scope = file_scope_id.to_scope_id(db, file);
assert_eq!(scope.name(db), *expected_scope_name);
assert_eq!(scope.name(db, &module), *expected_scope_name);
}
symbol(db, scope, symbol_name).place
@ -10087,7 +10156,7 @@ mod tests {
fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
fn x_rhs_expression(db: &TestDb) -> Expression<'_> {
let file_main = system_path_to_file(db, "/src/main.py").unwrap();
let ast = parsed_module(db, file_main);
let ast = parsed_module(db, file_main).load(db);
// Get the second statement in `main.py` (x = …) and extract the expression
// node on the right-hand side:
let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value;
@ -10170,7 +10239,7 @@ mod tests {
fn dependency_own_instance_member() -> anyhow::Result<()> {
fn x_rhs_expression(db: &TestDb) -> Expression<'_> {
let file_main = system_path_to_file(db, "/src/main.py").unwrap();
let ast = parsed_module(db, file_main);
let ast = parsed_module(db, file_main).load(db);
// Get the second statement in `main.py` (x = …) and extract the expression
// node on the right-hand side:
let x_rhs_node = &ast.syntax().body[1].as_assign_stmt().unwrap().value;

View file

@ -13,6 +13,7 @@ use crate::types::{
infer_expression_types,
};
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
use ruff_python_stdlib::identifiers::is_identifier;
use itertools::Itertools;
@ -73,7 +74,8 @@ fn all_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
pattern: PatternPredicate<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), true).finish()
let module = parsed_module(db.upcast(), pattern.file(db)).load(db.upcast());
NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), true).finish()
}
#[salsa::tracked(
@ -85,7 +87,9 @@ fn all_narrowing_constraints_for_expression<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), true).finish()
let module = parsed_module(db.upcast(), expression.file(db)).load(db.upcast());
NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), true)
.finish()
}
#[salsa::tracked(
@ -97,7 +101,9 @@ fn all_negative_narrowing_constraints_for_expression<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Expression(expression), false).finish()
let module = parsed_module(db.upcast(), expression.file(db)).load(db.upcast());
NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Expression(expression), false)
.finish()
}
#[salsa::tracked(returns(as_ref))]
@ -105,7 +111,8 @@ fn all_negative_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
pattern: PatternPredicate<'db>,
) -> Option<NarrowingConstraints<'db>> {
NarrowingConstraintsBuilder::new(db, PredicateNode::Pattern(pattern), false).finish()
let module = parsed_module(db.upcast(), pattern.file(db)).load(db.upcast());
NarrowingConstraintsBuilder::new(db, &module, PredicateNode::Pattern(pattern), false).finish()
}
#[expect(clippy::ref_option)]
@ -251,16 +258,23 @@ fn expr_name(expr: &ast::Expr) -> Option<&ast::name::Name> {
}
}
struct NarrowingConstraintsBuilder<'db> {
struct NarrowingConstraintsBuilder<'db, 'ast> {
db: &'db dyn Db,
module: &'ast ParsedModuleRef,
predicate: PredicateNode<'db>,
is_positive: bool,
}
impl<'db> NarrowingConstraintsBuilder<'db> {
fn new(db: &'db dyn Db, predicate: PredicateNode<'db>, is_positive: bool) -> Self {
impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
fn new(
db: &'db dyn Db,
module: &'ast ParsedModuleRef,
predicate: PredicateNode<'db>,
is_positive: bool,
) -> Self {
Self {
db,
module,
predicate,
is_positive,
}
@ -289,7 +303,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let expression_node = expression.node_ref(self.db);
let expression_node = expression.node_ref(self.db, self.module);
self.evaluate_expression_node_predicate(expression_node, expression, is_positive)
}
@ -775,7 +789,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
subject: Expression<'db>,
singleton: ast::Singleton,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let symbol = self
.expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id);
let ty = match singleton {
ast::Singleton::None => Type::none(self.db),
@ -790,8 +805,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
subject: Expression<'db>,
cls: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, cls).to_instance(self.db)?;
let symbol = self
.expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, cls, self.module).to_instance(self.db)?;
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}
@ -801,8 +817,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
subject: Expression<'db>,
value: Expression<'db>,
) -> Option<NarrowingConstraints<'db>> {
let symbol = self.expect_expr_name_symbol(&subject.node_ref(self.db).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, value);
let symbol = self
.expect_expr_name_symbol(&subject.node_ref(self.db, self.module).as_name_expr()?.id);
let ty = infer_same_file_expression_type(self.db, value, self.module);
Some(NarrowingConstraints::from_iter([(symbol, ty)]))
}

View file

@ -1,6 +1,7 @@
use std::borrow::Cow;
use std::cmp::Ordering;
use ruff_db::parsed::ParsedModuleRef;
use rustc_hash::FxHashMap;
use ruff_python_ast::{self as ast, AnyNodeRef};
@ -16,21 +17,22 @@ use super::diagnostic::INVALID_ASSIGNMENT;
use super::{KnownClass, TupleType, UnionType};
/// Unpacks the value expression type to their respective targets.
pub(crate) struct Unpacker<'db> {
context: InferContext<'db>,
pub(crate) struct Unpacker<'db, 'ast> {
context: InferContext<'db, 'ast>,
target_scope: ScopeId<'db>,
value_scope: ScopeId<'db>,
targets: FxHashMap<ScopedExpressionId, Type<'db>>,
}
impl<'db> Unpacker<'db> {
impl<'db, 'ast> Unpacker<'db, 'ast> {
pub(crate) fn new(
db: &'db dyn Db,
target_scope: ScopeId<'db>,
value_scope: ScopeId<'db>,
module: &'ast ParsedModuleRef,
) -> Self {
Self {
context: InferContext::new(db, target_scope),
context: InferContext::new(db, target_scope, module),
targets: FxHashMap::default(),
target_scope,
value_scope,
@ -41,6 +43,10 @@ impl<'db> Unpacker<'db> {
self.context.db()
}
fn module(&self) -> &'ast ParsedModuleRef {
self.context.module()
}
/// Unpack the value to the target expression.
pub(crate) fn unpack(&mut self, target: &ast::Expr, value: UnpackValue<'db>) {
debug_assert!(
@ -48,15 +54,16 @@ impl<'db> Unpacker<'db> {
"Unpacking target must be a list or tuple expression"
);
let value_type = infer_expression_types(self.db(), value.expression())
.expression_type(value.scoped_expression_id(self.db(), self.value_scope));
let value_type = infer_expression_types(self.db(), value.expression()).expression_type(
value.scoped_expression_id(self.db(), self.value_scope, self.module()),
);
let value_type = match value.kind() {
UnpackKind::Assign => {
if self.context.in_stub()
&& value
.expression()
.node_ref(self.db())
.node_ref(self.db(), self.module())
.is_ellipsis_literal_expr()
{
Type::unknown()
@ -65,22 +72,34 @@ impl<'db> Unpacker<'db> {
}
}
UnpackKind::Iterable => value_type.try_iterate(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, value_type, value.as_any_node_ref(self.db()));
err.report_diagnostic(
&self.context,
value_type,
value.as_any_node_ref(self.db(), self.module()),
);
err.fallback_element_type(self.db())
}),
UnpackKind::ContextManager => value_type.try_enter(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, value_type, value.as_any_node_ref(self.db()));
err.report_diagnostic(
&self.context,
value_type,
value.as_any_node_ref(self.db(), self.module()),
);
err.fallback_enter_type(self.db())
}),
};
self.unpack_inner(target, value.as_any_node_ref(self.db()), value_type);
self.unpack_inner(
target,
value.as_any_node_ref(self.db(), self.module()),
value_type,
);
}
fn unpack_inner(
&mut self,
target: &ast::Expr,
value_expr: AnyNodeRef<'db>,
value_expr: AnyNodeRef<'_>,
value_ty: Type<'db>,
) {
match target {

View file

@ -1,4 +1,5 @@
use ruff_db::files::File;
use ruff_db::parsed::ParsedModuleRef;
use ruff_python_ast::{self as ast, AnyNodeRef};
use ruff_text_size::{Ranged, TextRange};
@ -37,9 +38,9 @@ pub(crate) struct Unpack<'db> {
/// The target expression that is being unpacked. For example, in `(a, b) = (1, 2)`, the target
/// expression is `(a, b)`.
#[no_eq]
#[returns(deref)]
#[tracked]
pub(crate) target: AstNodeRef<ast::Expr>,
#[returns(ref)]
pub(crate) _target: AstNodeRef<ast::Expr>,
/// The ingredient representing the value expression of the unpacking. For example, in
/// `(a, b) = (1, 2)`, the value expression is `(1, 2)`.
@ -49,6 +50,14 @@ pub(crate) struct Unpack<'db> {
}
impl<'db> Unpack<'db> {
pub(crate) fn target<'ast>(
self,
db: &'db dyn Db,
parsed: &'ast ParsedModuleRef,
) -> &'ast ast::Expr {
self._target(db).node(parsed)
}
/// Returns the scope in which the unpack value expression belongs.
///
/// The scope in which the target and value expression belongs to are usually the same
@ -65,8 +74,8 @@ impl<'db> Unpack<'db> {
}
/// Returns the range of the unpack target expression.
pub(crate) fn range(self, db: &'db dyn Db) -> TextRange {
self.target(db).range()
pub(crate) fn range(self, db: &'db dyn Db, module: &ParsedModuleRef) -> TextRange {
self.target(db, module).range()
}
}
@ -94,15 +103,20 @@ impl<'db> UnpackValue<'db> {
self,
db: &'db dyn Db,
scope: ScopeId<'db>,
module: &ParsedModuleRef,
) -> ScopedExpressionId {
self.expression()
.node_ref(db)
.node_ref(db, module)
.scoped_expression_id(db, scope)
}
/// Returns the expression as an [`AnyNodeRef`].
pub(crate) fn as_any_node_ref(self, db: &'db dyn Db) -> AnyNodeRef<'db> {
self.expression().node_ref(db).into()
pub(crate) fn as_any_node_ref<'ast>(
self,
db: &'db dyn Db,
module: &'ast ParsedModuleRef,
) -> AnyNodeRef<'ast> {
self.expression().node_ref(db, module).into()
}
pub(crate) const fn kind(self) -> UnpackKind {